Skip to main content

acktor/codec/
common_codec.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use bytes::{Bytes, BytesMut};
5use prost::Message as _;
6
7use acktor_ipc_proto::{
8    optional_length_delimited_field_encoded_len,
9    utils::{ProtoOption, ProtoResult, ProtoResultType},
10};
11
12use super::error::{DecodeError, EncodeError};
13use super::protobuf_helper::LENGTH_DELIMITED_TAGS;
14use super::{Decode, DecodeContext, Encode, EncodeContext};
15
16impl<T> Encode for Box<T>
17where
18    T: Encode,
19{
20    fn encoded_len(&self) -> usize {
21        self.as_ref().encoded_len()
22    }
23
24    fn encode(
25        &self,
26        buf: &mut BytesMut,
27        ctx: Option<&dyn EncodeContext>,
28    ) -> Result<(), EncodeError> {
29        self.as_ref().encode(buf, ctx)
30    }
31}
32
33impl<T> Decode for Box<T>
34where
35    T: Decode,
36{
37    fn decode(buf: Bytes, ctx: Option<&dyn DecodeContext>) -> Result<Self, DecodeError> {
38        T::decode(buf, ctx).map(Box::new)
39    }
40}
41
42impl<T> Encode for Arc<T>
43where
44    T: Encode,
45{
46    fn encoded_len(&self) -> usize {
47        self.as_ref().encoded_len()
48    }
49
50    fn encode(
51        &self,
52        buf: &mut BytesMut,
53        ctx: Option<&dyn EncodeContext>,
54    ) -> Result<(), EncodeError> {
55        self.as_ref().encode(buf, ctx)
56    }
57}
58
59impl<T> Decode for Arc<T>
60where
61    T: Decode,
62{
63    fn decode(buf: Bytes, ctx: Option<&dyn DecodeContext>) -> Result<Self, DecodeError> {
64        T::decode(buf, ctx).map(Arc::new)
65    }
66}
67
68// Encode a `Result<T, E>`.
69//
70// The `Err` variant is lossy: the error is serialized as its `Display` string and decoded back
71// via `E::from(String)` (see the `Decode` impl below). Any structured data on the error type —
72// enum discriminants, numeric codes, nested fields — is collapsed to the rendered message.
73impl<T, E> Encode for Result<T, E>
74where
75    T: Encode,
76    E: Display,
77{
78    fn encoded_len(&self) -> usize {
79        let inner_len = match self {
80            Ok(ok) => ok.encoded_len(),
81            Err(err) => err.to_string().len(),
82        };
83        // Result is a `oneof` message, so use optional_length_delimited_field_encoded_len
84        optional_length_delimited_field_encoded_len(1, inner_len)
85    }
86
87    fn encode(
88        &self,
89        buf: &mut BytesMut,
90        ctx: Option<&dyn EncodeContext>,
91    ) -> Result<(), EncodeError> {
92        match self {
93            Ok(ok) => {
94                // field 1, wire type LengthDelimited (bytes)
95                buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[1]]);
96                // use encode_varint since it bypassed the capacity check and triggers resize if
97                // buf is a BytesMut
98                prost::encoding::encode_varint(ok.encoded_len() as u64, buf);
99                ok.encode(buf, ctx)?;
100            }
101            Err(err) => {
102                // field 2, wire type LengthDelimited (string)
103                let err_str = err.to_string();
104                buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[2]]);
105                // use encode_varint since it bypassed the capacity check and triggers resize if
106                // buf is a BytesMut
107                prost::encoding::encode_varint(err_str.len() as u64, buf);
108                buf.extend_from_slice(err_str.as_bytes());
109            }
110        }
111
112        Ok(())
113    }
114
115    fn encode_to_bytes(&self, ctx: Option<&dyn EncodeContext>) -> Result<Bytes, EncodeError> {
116        match self {
117            Ok(ok) => {
118                // field 1, wire type LengthDelimited (bytes)
119                let inner_len = ok.encoded_len();
120                let total = optional_length_delimited_field_encoded_len(1, inner_len);
121                let mut buf = BytesMut::with_capacity(total);
122                buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[1]]);
123                prost::encoding::encode_varint(inner_len as u64, &mut buf);
124                ok.encode(&mut buf, ctx)?;
125
126                Ok(buf.freeze())
127            }
128            Err(err) => {
129                // field 2, wire type LengthDelimited (string)
130                let err_string = err.to_string();
131                let inner_len = err_string.len();
132                let total = optional_length_delimited_field_encoded_len(1, inner_len);
133                let mut buf = BytesMut::with_capacity(total);
134                buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[2]]);
135                prost::encoding::encode_varint(inner_len as u64, &mut buf);
136                buf.extend_from_slice(err_string.as_bytes());
137
138                Ok(buf.freeze())
139            }
140        }
141    }
142}
143
144impl<T, E> Decode for Result<T, E>
145where
146    T: Decode,
147    E: From<String>,
148{
149    #[inline]
150    fn decode(buf: Bytes, ctx: Option<&dyn DecodeContext>) -> Result<Self, DecodeError> {
151        let result = ProtoResult::decode(buf)?;
152        match result.result {
153            Some(ProtoResultType::Ok(ok)) => Ok(Ok(T::decode(ok, ctx)?)),
154            Some(ProtoResultType::Err(err)) => Ok(Err(E::from(err))),
155            _ => Err("missing field `result` in the `Result` message".into()),
156        }
157    }
158}
159
160impl<T> Encode for Option<T>
161where
162    T: Encode,
163{
164    fn encoded_len(&self) -> usize {
165        match self {
166            // bytes field: 1 byte tag + varint length + data
167            Some(some) => {
168                let inner_len = some.encoded_len();
169                // Option::Some is a `optional` field, so use
170                // optional_length_delimited_field_encoded_len
171                optional_length_delimited_field_encoded_len(1, inner_len)
172            }
173            // empty message
174            None => 0,
175        }
176    }
177
178    fn encode(
179        &self,
180        buf: &mut BytesMut,
181        ctx: Option<&dyn EncodeContext>,
182    ) -> Result<(), EncodeError> {
183        if let Some(some) = self {
184            // field 1, wire type LengthDelimited (bytes)
185            buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[1]]);
186            // use encode_varint since it bypassed the capacity check and triggers resize if
187            // buf is a BytesMut
188            prost::encoding::encode_varint(some.encoded_len() as u64, buf);
189            some.encode(buf, ctx)?;
190        }
191
192        Ok(())
193    }
194
195    fn encode_to_bytes(&self, ctx: Option<&dyn EncodeContext>) -> Result<Bytes, EncodeError> {
196        match self {
197            Some(some) => {
198                // field 1, wire type LengthDelimited (bytes)
199                let inner_len = some.encoded_len();
200                let total = optional_length_delimited_field_encoded_len(1, inner_len);
201                let mut buf = BytesMut::with_capacity(total);
202                buf.extend_from_slice(&[LENGTH_DELIMITED_TAGS[1]]);
203                prost::encoding::encode_varint(inner_len as u64, &mut buf);
204                some.encode(&mut buf, ctx)?;
205
206                Ok(buf.freeze())
207            }
208            None => Ok(Bytes::new()),
209        }
210    }
211}
212
213impl<T> Decode for Option<T>
214where
215    T: Decode,
216{
217    fn decode(buf: Bytes, ctx: Option<&dyn DecodeContext>) -> Result<Self, DecodeError> {
218        let option = ProtoOption::decode(buf)?;
219        match option.option {
220            Some(bytes) => Ok(Some(T::decode(bytes, ctx)?)),
221            None => Ok(None),
222        }
223    }
224}