1extern crate alloc;
10
11use alloc::vec::Vec;
12
13use crate::error::XdrSerializeError;
14use facet_core::ScalarType;
15use facet_format::{FormatSerializer, ScalarValue, SerializeError};
16use facet_reflect::Peek;
17
18pub struct XdrSerializer {
20 out: Vec<u8>,
21 stack: Vec<ContainerState>,
23}
24
/// Bookkeeping entry for a container that is currently open on the serializer stack.
#[derive(Debug)]
enum ContainerState {
    /// A struct is open. Nothing needs to be tracked for it beyond its presence.
    Struct,
    /// A sequence is open and its length prefix still has to be filled in.
    Seq {
        /// Number of elements recorded for this sequence so far.
        count: usize,
        /// Byte offset in the output buffer of the 4-byte length placeholder.
        count_pos: usize,
    },
}
30
31impl XdrSerializer {
32 pub const fn new() -> Self {
34 Self {
35 out: Vec::new(),
36 stack: Vec::new(),
37 }
38 }
39
40 pub fn finish(mut self) -> Vec<u8> {
42 while let Some(state) = self.stack.pop() {
44 if let ContainerState::Seq { count, count_pos } = state {
45 self.patch_seq_count(count_pos, count);
46 }
47 }
48 self.out
49 }
50
51 fn write_padding(&mut self, data_len: usize) {
53 let pad = (4 - (data_len % 4)) % 4;
54 for _ in 0..pad {
55 self.out.push(0);
56 }
57 }
58
59 fn write_u32(&mut self, val: u32) {
61 self.out.extend_from_slice(&val.to_be_bytes());
62 }
63
64 fn write_u64(&mut self, val: u64) {
66 self.out.extend_from_slice(&val.to_be_bytes());
67 }
68
69 fn write_i32(&mut self, val: i32) {
71 self.out.extend_from_slice(&val.to_be_bytes());
72 }
73
74 fn write_i64(&mut self, val: i64) {
76 self.out.extend_from_slice(&val.to_be_bytes());
77 }
78
79 fn write_f32(&mut self, val: f32) {
81 self.out.extend_from_slice(&val.to_be_bytes());
82 }
83
84 fn write_f64(&mut self, val: f64) {
86 self.out.extend_from_slice(&val.to_be_bytes());
87 }
88
89 fn write_bool(&mut self, val: bool) {
91 self.write_u32(if val { 1 } else { 0 });
92 }
93
94 fn write_string(&mut self, s: &str) {
96 let bytes = s.as_bytes();
97 self.write_u32(bytes.len() as u32);
98 self.out.extend_from_slice(bytes);
99 self.write_padding(bytes.len());
100 }
101
102 fn write_opaque(&mut self, bytes: &[u8]) {
104 self.write_u32(bytes.len() as u32);
105 self.out.extend_from_slice(bytes);
106 self.write_padding(bytes.len());
107 }
108
109 fn begin_seq(&mut self) -> usize {
111 let count_pos = self.out.len();
112 self.write_u32(0); count_pos
114 }
115
116 fn patch_seq_count(&mut self, count_pos: usize, count: usize) {
118 let count_bytes = (count as u32).to_be_bytes();
119 self.out[count_pos..count_pos + 4].copy_from_slice(&count_bytes);
120 }
121}
122
123impl Default for XdrSerializer {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl FormatSerializer for XdrSerializer {
130 type Error = XdrSerializeError;
131
132 fn begin_struct(&mut self) -> Result<(), Self::Error> {
133 self.stack.push(ContainerState::Struct);
135 Ok(())
136 }
137
138 fn begin_option_some(&mut self) -> Result<(), Self::Error> {
139 self.write_u32(1);
141 Ok(())
142 }
143
144 fn serialize_none(&mut self) -> Result<(), Self::Error> {
145 self.write_u32(0);
147 Ok(())
148 }
149
150 fn field_key(&mut self, _key: &str) -> Result<(), Self::Error> {
151 Ok(())
153 }
154
155 fn end_struct(&mut self) -> Result<(), Self::Error> {
156 match self.stack.pop() {
157 Some(ContainerState::Struct) => Ok(()),
158 _ => Err(XdrSerializeError::new(
159 "end_struct called without matching begin_struct",
160 )),
161 }
162 }
163
164 fn begin_seq(&mut self) -> Result<(), Self::Error> {
165 let count_pos = self.begin_seq();
166 self.stack.push(ContainerState::Seq {
167 count: 0,
168 count_pos,
169 });
170 Ok(())
171 }
172
173 fn end_seq(&mut self) -> Result<(), Self::Error> {
174 match self.stack.pop() {
175 Some(ContainerState::Seq { count, count_pos }) => {
176 self.patch_seq_count(count_pos, count);
177 Ok(())
178 }
179 _ => Err(XdrSerializeError::new(
180 "end_seq called without matching begin_seq",
181 )),
182 }
183 }
184
185 fn scalar(&mut self, scalar: ScalarValue<'_>) -> Result<(), Self::Error> {
186 if let Some(ContainerState::Seq { count, .. }) = self.stack.last_mut() {
188 *count += 1;
189 }
190
191 match scalar {
192 ScalarValue::Null | ScalarValue::Unit => {
193 self.write_u32(0);
196 }
197 ScalarValue::Bool(v) => self.write_bool(v),
198 ScalarValue::Char(c) => {
199 let mut buf = [0u8; 4];
200 self.write_string(c.encode_utf8(&mut buf));
201 }
202 ScalarValue::U64(n) => {
203 if n <= u32::MAX as u64 {
205 self.write_u32(n as u32);
206 } else {
207 self.write_u64(n);
208 }
209 }
210 ScalarValue::I64(n) => {
211 if n >= i32::MIN as i64 && n <= i32::MAX as i64 {
213 self.write_i32(n as i32);
214 } else {
215 self.write_i64(n);
216 }
217 }
218 ScalarValue::U128(_n) => {
219 return Err(XdrSerializeError::new("XDR does not support u128"));
220 }
221 ScalarValue::I128(_n) => {
222 return Err(XdrSerializeError::new("XDR does not support i128"));
223 }
224 ScalarValue::F64(n) => {
225 let as_f32 = n as f32;
228 if as_f32 as f64 == n && n.is_finite() {
229 self.write_f32(as_f32);
230 } else {
231 self.write_f64(n);
232 }
233 }
234 ScalarValue::Str(s) => self.write_string(&s),
235 ScalarValue::Bytes(bytes) => self.write_opaque(&bytes),
236 }
237 Ok(())
238 }
239
240 fn typed_scalar(
241 &mut self,
242 scalar_type: ScalarType,
243 value: Peek<'_, '_>,
244 ) -> Result<(), Self::Error> {
245 if let Some(ContainerState::Seq { count, .. }) = self.stack.last_mut() {
247 *count += 1;
248 }
249
250 match scalar_type {
251 ScalarType::Unit => {
252 }
254 ScalarType::Bool => {
255 let v = *value.get::<bool>().unwrap();
256 self.write_bool(v);
257 }
258 ScalarType::Char => {
259 let c = *value.get::<char>().unwrap();
261 self.write_u32(c as u32);
262 }
263 ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
264 if let Some(s) = value.as_str() {
265 self.write_string(s);
266 }
267 }
268 ScalarType::F32 => {
269 let v = *value.get::<f32>().unwrap();
270 self.write_f32(v);
271 }
272 ScalarType::F64 => {
273 let v = *value.get::<f64>().unwrap();
274 self.write_f64(v);
275 }
276 ScalarType::U8 => {
277 let v = *value.get::<u8>().unwrap();
278 self.write_u32(v as u32);
279 }
280 ScalarType::U16 => {
281 let v = *value.get::<u16>().unwrap();
282 self.write_u32(v as u32);
283 }
284 ScalarType::U32 => {
285 let v = *value.get::<u32>().unwrap();
286 self.write_u32(v);
287 }
288 ScalarType::U64 => {
289 let v = *value.get::<u64>().unwrap();
290 self.write_u64(v);
291 }
292 ScalarType::U128 => {
293 return Err(XdrSerializeError::new("XDR does not support u128"));
294 }
295 ScalarType::USize => {
296 let v = *value.get::<usize>().unwrap();
297 self.write_u64(v as u64);
298 }
299 ScalarType::I8 => {
300 let v = *value.get::<i8>().unwrap();
301 self.write_i32(v as i32);
302 }
303 ScalarType::I16 => {
304 let v = *value.get::<i16>().unwrap();
305 self.write_i32(v as i32);
306 }
307 ScalarType::I32 => {
308 let v = *value.get::<i32>().unwrap();
309 self.write_i32(v);
310 }
311 ScalarType::I64 => {
312 let v = *value.get::<i64>().unwrap();
313 self.write_i64(v);
314 }
315 ScalarType::I128 => {
316 return Err(XdrSerializeError::new("XDR does not support i128"));
317 }
318 ScalarType::ISize => {
319 let v = *value.get::<isize>().unwrap();
320 self.write_i64(v as i64);
321 }
322 _ => {
323 if let Some(s) = value.as_str() {
325 self.write_string(s);
326 }
327 }
328 }
329 Ok(())
330 }
331}
332
333pub fn to_vec<'facet, T>(value: &T) -> Result<Vec<u8>, SerializeError<XdrSerializeError>>
335where
336 T: facet_core::Facet<'facet>,
337{
338 let mut ser = XdrSerializer::new();
339 facet_format::serialize_root(&mut ser, facet_reflect::Peek::new(value))?;
340 Ok(ser.finish())
341}
342
343pub fn to_writer<'facet, T, W>(writer: &mut W, value: &T) -> Result<(), std::io::Error>
345where
346 T: facet_core::Facet<'facet>,
347 W: std::io::Write,
348{
349 let bytes = to_vec(value).map_err(|e| std::io::Error::other(e.to_string()))?;
350 writer.write_all(&bytes)
351}