Skip to main content

mlua_extras/
ser.rs

1use serde::ser::{
2    self, Error as _, Serialize, SerializeMap, SerializeSeq, SerializeStruct, Serializer,
3};
4use std::fmt::{self, Write};
5
/// Error produced while converting a value into Lua source text.
///
/// Wraps a human-readable message. It comes either from a `Serialize` implementation (via
/// [`serde::ser::Error::custom`]) or from this module, e.g. when a map key cannot be written
/// as a Lua table key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Error(String);
8
9impl ser::Error for Error {
10    fn custom<T: fmt::Display>(msg: T) -> Self {
11        Error(msg.to_string())
12    }
13}
14
15impl fmt::Display for Error {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        f.write_str(&self.0)
18    }
19}
20
21impl std::error::Error for Error {}
22
23/// Convert any type that is Serializable into it's lua value representation
24pub fn to_lua_repr<T: Serialize>(value: &T) -> Result<String, Error> {
25    let mut out = String::new();
26    value.serialize(&mut LuaSerializer { out: &mut out })?;
27    Ok(out)
28}
29
/// Serializer state: the generated Lua source is appended to the borrowed buffer `out`.
struct LuaSerializer<'buf> {
    out: &'buf mut String,
}
33
34impl<'a, 'b> Serializer for &'a mut LuaSerializer<'b> {
35    type Ok = ();
36    type Error = Error;
37
38    type SerializeSeq = LuaSeq<'a, 'b>;
39    type SerializeTuple = LuaSeq<'a, 'b>;
40    type SerializeTupleStruct = LuaSeq<'a, 'b>;
41    type SerializeTupleVariant = LuaSeq<'a, 'b>;
42    type SerializeMap = LuaMap<'a, 'b>;
43    type SerializeStruct = LuaMap<'a, 'b>;
44    type SerializeStructVariant = LuaMap<'a, 'b>;
45
46    fn serialize_bool(self, v: bool) -> Result<(), Error> {
47        self.out.push_str(if v { "true" } else { "false" });
48        Ok(())
49    }
50
51    fn serialize_i8(self, v: i8) -> Result<(), Error> {
52        write!(self.out, "{v}").map_err(Error::custom)
53    }
54    fn serialize_i16(self, v: i16) -> Result<(), Error> {
55        write!(self.out, "{v}").map_err(Error::custom)
56    }
57    fn serialize_i32(self, v: i32) -> Result<(), Error> {
58        write!(self.out, "{v}").map_err(Error::custom)
59    }
60    fn serialize_i64(self, v: i64) -> Result<(), Error> {
61        write!(self.out, "{v}").map_err(Error::custom)
62    }
63
64    fn serialize_u8(self, v: u8) -> Result<(), Error> {
65        write!(self.out, "{v}").map_err(Error::custom)
66    }
67    fn serialize_u16(self, v: u16) -> Result<(), Error> {
68        write!(self.out, "{v}").map_err(Error::custom)
69    }
70    fn serialize_u32(self, v: u32) -> Result<(), Error> {
71        write!(self.out, "{v}").map_err(Error::custom)
72    }
73    fn serialize_u64(self, v: u64) -> Result<(), Error> {
74        write!(self.out, "{v}").map_err(Error::custom)
75    }
76
77    fn serialize_f32(self, v: f32) -> Result<(), Error> {
78        if v.is_nan() {
79            self.out.push_str("0/0");
80        } else if v.is_infinite() {
81            if  v.is_sign_positive() {
82                self.out.push_str("math.huge");
83            } else {
84                self.out.push_str("-math.huge");
85            }
86        } else {
87            let mut buf = ryu::Buffer::new();
88            let s = buf.format_finite(v);
89
90            if s.contains('.') {
91                self.out.push_str(s);
92            } else {
93                let s = format!("{s}.0");
94                self.out.push_str(&s);
95            }
96        }
97
98        Ok(())
99    }
100
101    fn serialize_f64(self, v: f64) -> Result<(), Error> {
102        let roundedf32 = v as f32;
103        if (roundedf32 as f64) == v {
104            self.serialize_f32(roundedf32)?;
105        } else {
106            if v.is_nan() {
107                self.out.push_str("0/0");
108            } else if v.is_infinite() {
109                if  v.is_sign_positive() {
110                    self.out.push_str("math.huge");
111                } else {
112                    self.out.push_str("-math.huge");
113                }
114            } else {
115                let mut buf = ryu::Buffer::new();
116                let s = buf.format_finite(v);
117
118                if s.contains('.') {
119                    self.out.push_str(s);
120                } else {
121                    let s = format!("{s}.0");
122                    self.out.push_str(&s);
123                }
124            }
125        }
126
127        Ok(())
128    }
129
130    fn serialize_char(self, v: char) -> Result<(), Error> {
131        self.serialize_str(&v.to_string())
132    }
133
134    fn serialize_str(self, v: &str) -> Result<(), Error> {
135        self.out.push('"');
136        for ch in v.chars() {
137            match ch {
138                '\\' => self.out.push_str("\\\\"),
139                '"' => self.out.push_str("\\\""),
140                '\n' => self.out.push_str("\\n"),
141                '\r' => self.out.push_str("\\r"),
142                '\t' => self.out.push_str("\\t"),
143                c => self.out.push(c),
144            }
145        }
146        self.out.push('"');
147        Ok(())
148    }
149
150    fn serialize_bytes(self, v: &[u8]) -> Result<(), Error> {
151        let mut seq = self.serialize_seq(Some(v.len()))?;
152        for b in v {
153            seq.serialize_element(b)?;
154        }
155        seq.end()
156    }
157
158    fn serialize_none(self) -> Result<(), Error> {
159        self.out.push_str("nil");
160        Ok(())
161    }
162
163    fn serialize_some<T: ?Sized + Serialize>(self, value: &T) -> Result<(), Error> {
164        value.serialize(self)
165    }
166
167    fn serialize_unit(self) -> Result<(), Error> {
168        self.out.push_str("nil");
169        Ok(())
170    }
171
172    fn serialize_unit_struct(self, _name: &'static str) -> Result<(), Error> {
173        self.out.push_str("nil");
174        Ok(())
175    }
176
177    fn serialize_unit_variant(
178        self,
179        _name: &'static str,
180        _variant_index: u32,
181        variant: &'static str,
182    ) -> Result<(), Error> {
183        self.serialize_str(variant)
184    }
185
186    fn serialize_newtype_struct<T: ?Sized + Serialize>(
187        self,
188        _name: &'static str,
189        value: &T,
190    ) -> Result<(), Error> {
191        value.serialize(self)
192    }
193
194    fn serialize_newtype_variant<T: ?Sized + Serialize>(
195        self,
196        _name: &'static str,
197        _variant_index: u32,
198        variant: &'static str,
199        value: &T,
200    ) -> Result<(), Error> {
201        self.out.push('{');
202        write!(self.out, "{} = ", lua_ident_or_bracket(variant)).map_err(Error::custom)?;
203        value.serialize(&mut *self)?;
204        self.out.push('}');
205        Ok(())
206    }
207
208    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Error> {
209        self.out.push('{');
210        Ok(LuaSeq {
211            ser: self,
212            first: true,
213            closes_twice: false,
214        })
215    }
216
217    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Error> {
218        self.serialize_seq(None)
219    }
220
221    fn serialize_tuple_struct(
222        self,
223        _name: &'static str,
224        _len: usize,
225    ) -> Result<Self::SerializeTupleStruct, Error> {
226        self.serialize_seq(None)
227    }
228
229    fn serialize_tuple_variant(
230        self,
231        _name: &'static str,
232        _variant_index: u32,
233        variant: &'static str,
234        _len: usize,
235    ) -> Result<Self::SerializeTupleVariant, Error> {
236        self.out.push('{');
237        write!(self.out, "{} = ", lua_ident_or_bracket(variant)).map_err(Error::custom)?;
238        self.out.push('{');
239        Ok(LuaSeq {
240            ser: self,
241            first: true,
242            closes_twice: true,
243        })
244    }
245
246    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Error> {
247        self.out.push('{');
248        Ok(LuaMap {
249            ser: self,
250            first: true,
251            closes_twice: false,
252        })
253    }
254
255    fn serialize_struct(
256        self,
257        _name: &'static str,
258        _len: usize,
259    ) -> Result<Self::SerializeStruct, Error> {
260        self.serialize_map(None)
261    }
262
263    fn serialize_struct_variant(
264        self,
265        _name: &'static str,
266        _variant_index: u32,
267        variant: &'static str,
268        _len: usize,
269    ) -> Result<Self::SerializeStructVariant, Error> {
270        self.out.push('{');
271        write!(self.out, "{} = {{", lua_ident_or_bracket(variant)).map_err(Error::custom)?;
272        Ok(LuaMap {
273            ser: self,
274            first: true,
275            closes_twice: true,
276        })
277    }
278}
279
280struct LuaSeq<'a, 'b> {
281    ser: &'a mut LuaSerializer<'b>,
282    first: bool,
283    closes_twice: bool,
284}
285
286impl<'a, 'b> LuaSeq<'a, 'b> {
287    fn comma(&mut self) {
288        if !self.first {
289            self.ser.out.push_str(", ");
290        }
291        self.first = false;
292    }
293}
294
295impl<'a, 'b> SerializeSeq for LuaSeq<'a, 'b> {
296    type Ok = ();
297    type Error = Error;
298
299    fn serialize_element<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Error> {
300        self.comma();
301        value.serialize(&mut *self.ser)
302    }
303
304    fn end(self) -> Result<(), Error> {
305        self.ser.out.push('}');
306        if self.closes_twice {
307            self.ser.out.push('}');
308        }
309        Ok(())
310    }
311}
312
313impl<'a, 'b> ser::SerializeTuple for LuaSeq<'a, 'b> {
314    type Ok = ();
315    type Error = Error;
316    fn serialize_element<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Error> {
317        SerializeSeq::serialize_element(self, value)
318    }
319    fn end(self) -> Result<(), Error> {
320        SerializeSeq::end(self)
321    }
322}
323
324impl<'a, 'b> ser::SerializeTupleStruct for LuaSeq<'a, 'b> {
325    type Ok = ();
326    type Error = Error;
327    fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Error> {
328        SerializeSeq::serialize_element(self, value)
329    }
330    fn end(self) -> Result<(), Error> {
331        SerializeSeq::end(self)
332    }
333}
334
335impl<'a, 'b> ser::SerializeTupleVariant for LuaSeq<'a, 'b> {
336    type Ok = ();
337    type Error = Error;
338    fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Error> {
339        SerializeSeq::serialize_element(self, value)
340    }
341    fn end(self) -> Result<(), Error> {
342        SerializeSeq::end(self)
343    }
344}
345
346struct LuaMap<'a, 'b> {
347    ser: &'a mut LuaSerializer<'b>,
348    first: bool,
349    closes_twice: bool,
350}
351
352impl<'a, 'b> LuaMap<'a, 'b> {
353    fn comma(&mut self) {
354        if !self.first {
355            self.ser.out.push_str(", ");
356        }
357        self.first = false;
358    }
359}
360
361impl<'a, 'b> SerializeMap for LuaMap<'a, 'b> {
362    type Ok = ();
363    type Error = Error;
364
365    fn serialize_key<T: ?Sized + Serialize>(&mut self, key: &T) -> Result<(), Error> {
366        self.comma();
367        key.serialize(KeySerializer { out: self.ser.out })?;
368        self.ser.out.push_str(" = ");
369        Ok(())
370    }
371
372    fn serialize_value<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Error> {
373        value.serialize(&mut *self.ser)
374    }
375
376    fn end(self) -> Result<(), Error> {
377        self.ser.out.push('}');
378        if self.closes_twice {
379            self.ser.out.push('}');
380        }
381        Ok(())
382    }
383}
384
385impl<'a, 'b> SerializeStruct for LuaMap<'a, 'b> {
386    type Ok = ();
387    type Error = Error;
388
389    fn serialize_field<T: ?Sized + Serialize>(
390        &mut self,
391        key: &'static str,
392        value: &T,
393    ) -> Result<(), Error> {
394        self.comma();
395        self.ser.out.push_str(lua_ident_or_bracket(key));
396        self.ser.out.push_str(" = ");
397        value.serialize(&mut *self.ser)
398    }
399
400    fn end(self) -> Result<(), Error> {
401        SerializeMap::end(self)
402    }
403}
404
405impl<'a, 'b> ser::SerializeStructVariant for LuaMap<'a, 'b> {
406    type Ok = ();
407    type Error = Error;
408
409    fn serialize_field<T: ?Sized + Serialize>(
410        &mut self,
411        key: &'static str,
412        value: &T,
413    ) -> Result<(), Error> {
414        SerializeStruct::serialize_field(self, key, value)
415    }
416
417    fn end(self) -> Result<(), Error> {
418        SerializeMap::end(self)
419    }
420}
421
422struct KeySerializer<'a> {
423    out: &'a mut String,
424}
425
426impl<'a> Serializer for KeySerializer<'a> {
427    type Ok = ();
428    type Error = Error;
429
430    type SerializeSeq = ser::Impossible<(), Error>;
431    type SerializeTuple = ser::Impossible<(), Error>;
432    type SerializeTupleStruct = ser::Impossible<(), Error>;
433    type SerializeTupleVariant = ser::Impossible<(), Error>;
434    type SerializeMap = ser::Impossible<(), Error>;
435    type SerializeStruct = ser::Impossible<(), Error>;
436    type SerializeStructVariant = ser::Impossible<(), Error>;
437
438    fn serialize_str(self, v: &str) -> Result<(), Error> {
439        self.out.push_str(lua_ident_or_bracket(v));
440        Ok(())
441    }
442
443    fn serialize_bool(self, v: bool) -> Result<(), Error> {
444        write!(self.out, "[{}]", if v { "true" } else { "false" }).map_err(Error::custom)
445    }
446
447    fn serialize_i64(self, v: i64) -> Result<(), Error> {
448        write!(self.out, "[{v}]").map_err(Error::custom)
449    }
450
451    fn serialize_u64(self, v: u64) -> Result<(), Error> {
452        write!(self.out, "[{v}]").map_err(Error::custom)
453    }
454
455    fn serialize_unit(self) -> Result<(), Error> {
456        Err(Error::custom("unit cannot be a Lua table key"))
457    }
458
459    fn serialize_none(self) -> Result<(), Error> {
460        Err(Error::custom("nil cannot be a Lua table key"))
461    }
462
463    fn serialize_some<T: ?Sized + Serialize>(self, value: &T) -> Result<(), Error> {
464        value.serialize(self)
465    }
466
467    fn serialize_i8(self, v: i8) -> Result<(), Error> {
468        self.serialize_i64(v as i64)
469    }
470    fn serialize_i16(self, v: i16) -> Result<(), Error> {
471        self.serialize_i64(v as i64)
472    }
473    fn serialize_i32(self, v: i32) -> Result<(), Error> {
474        self.serialize_i64(v as i64)
475    }
476    fn serialize_u8(self, v: u8) -> Result<(), Error> {
477        self.serialize_u64(v as u64)
478    }
479    fn serialize_u16(self, v: u16) -> Result<(), Error> {
480        self.serialize_u64(v as u64)
481    }
482    fn serialize_u32(self, v: u32) -> Result<(), Error> {
483        self.serialize_u64(v as u64)
484    }
485
486    fn serialize_f32(self, _v: f32) -> Result<(), Error> {
487        Err(Error::custom("float keys not supported"))
488    }
489    fn serialize_f64(self, _v: f64) -> Result<(), Error> {
490        Err(Error::custom("float keys not supported"))
491    }
492    fn serialize_char(self, v: char) -> Result<(), Error> {
493        self.serialize_str(&v.to_string())
494    }
495    fn serialize_bytes(self, _v: &[u8]) -> Result<(), Error> {
496        Err(Error::custom("bytes keys not supported"))
497    }
498    fn serialize_unit_struct(self, _: &'static str) -> Result<(), Error> {
499        self.serialize_unit()
500    }
501    fn serialize_unit_variant(
502        self,
503        _: &'static str,
504        _: u32,
505        variant: &'static str,
506    ) -> Result<(), Error> {
507        self.serialize_str(variant)
508    }
509    fn serialize_newtype_struct<T: ?Sized + Serialize>(
510        self,
511        _: &'static str,
512        value: &T,
513    ) -> Result<(), Error> {
514        value.serialize(self)
515    }
516    fn serialize_newtype_variant<T: ?Sized + Serialize>(
517        self,
518        _: &'static str,
519        _: u32,
520        _: &'static str,
521        _: &T,
522    ) -> Result<(), Error> {
523        Err(Error::custom("complex keys not supported"))
524    }
525    fn serialize_seq(self, _: Option<usize>) -> Result<Self::SerializeSeq, Error> {
526        Err(Error::custom("complex keys not supported"))
527    }
528    fn serialize_tuple(self, _: usize) -> Result<Self::SerializeTuple, Error> {
529        Err(Error::custom("complex keys not supported"))
530    }
531    fn serialize_tuple_struct(
532        self,
533        _: &'static str,
534        _: usize,
535    ) -> Result<Self::SerializeTupleStruct, Error> {
536        Err(Error::custom("complex keys not supported"))
537    }
538    fn serialize_tuple_variant(
539        self,
540        _: &'static str,
541        _: u32,
542        _: &'static str,
543        _: usize,
544    ) -> Result<Self::SerializeTupleVariant, Error> {
545        Err(Error::custom("complex keys not supported"))
546    }
547    fn serialize_map(self, _: Option<usize>) -> Result<Self::SerializeMap, Error> {
548        Err(Error::custom("complex keys not supported"))
549    }
550    fn serialize_struct(self, _: &'static str, _: usize) -> Result<Self::SerializeStruct, Error> {
551        Err(Error::custom("complex keys not supported"))
552    }
553    fn serialize_struct_variant(
554        self,
555        _: &'static str,
556        _: u32,
557        _: &'static str,
558        _: usize,
559    ) -> Result<Self::SerializeStructVariant, Error> {
560        Err(Error::custom("complex keys not supported"))
561    }
562}
563
564fn lua_ident_or_bracket(s: &str) -> &str {
565    if is_lua_ident(s) {
566        s
567    } else {
568        // This helper returns only &str, so callers that need brackets for
569        // arbitrary strings should handle that separately if needed.
570        // For simple struct fields, keep them identifier-safe or rename them.
571        panic!("non-identifier key: {s}");
572    }
573}
574
/// Reports whether `s` has the lexical shape of a Lua name: a non-empty run of ASCII letters,
/// digits and underscores that does not start with a digit.
///
/// Only characters are inspected, so reserved words such as `end` or `nil` pass this check.
fn is_lua_ident(s: &str) -> bool {
    match s.as_bytes().split_first() {
        None => false,
        Some((head, tail)) => {
            (*head == b'_' || head.is_ascii_alphabetic())
                && tail.iter().all(|b| *b == b'_' || b.is_ascii_alphanumeric())
        }
    }
}
583
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use serde::Serialize;

    use super::*;

    /// Asserts that serializing `$value` (a reference to a `$t`) yields exactly `$expected`.
    ///
    /// Usage: `expect!(type, &value, "expected lua")`, for example:
    ///
    /// ```ignore
    /// expect!(u8, &3, "3");
    /// ```
    macro_rules! expect {
        ( $t:ty, $value:expr, $expected:literal ) => {
            let result = to_lua_repr::<$t>($value).unwrap();
            assert_eq!(result, $expected);
        };
    }

    #[test]
    fn test_rust_literal_values() {
        expect!(u8, &3, "3");
        expect!(u16, &3, "3");
        expect!(u32, &3, "3");
        expect!(u64, &3, "3");
        expect!(i8, &3, "3");
        expect!(i16, &3, "3");
        expect!(i32, &3, "3");
        expect!(i64, &3, "3");
        expect!(f32, &3.0, "3.0");
        expect!(f32, &3.1, "3.1");
        expect!(f64, &3.0, "3.0");
        expect!(f64, &3.1, "3.1");

        expect!(&str, &"test", "\"test\"");
        expect!(String, &"test".to_string(), "\"test\"");

        expect!(bool, &true, "true");
        expect!(bool, &false, "false");
    }

    #[test]
    fn test_non_finite_floats() {
        expect!(f64, &f64::INFINITY, "math.huge");
        expect!(f64, &f64::NEG_INFINITY, "-math.huge");
        expect!(f64, &f64::NAN, "0/0");
        expect!(f32, &f32::NAN, "0/0");
    }

    #[test]
    fn test_string_escapes() {
        expect!(&str, &"a\"b\\c\nd\te\r", "\"a\\\"b\\\\c\\nd\\te\\r\"");
        expect!(char, &'"', "\"\\\"\"");
    }

    #[test]
    fn test_rust_builtin_types() {
        expect!(Option<bool>, &None, "nil");
        expect!(Option<bool>, &Some(true), "true");
        expect!(&[u8], &b"test".as_slice(), "{116, 101, 115, 116}");
        expect!(BTreeMap<&str, bool>, &BTreeMap::from([("test", false), ("test2", true)]), "{test = false, test2 = true}");
        expect!(
            Vec<&str>,
            &Vec::from(["test", "test2"]),
            "{\"test\", \"test2\"}"
        );
    }

    #[test]
    fn test_rust_composites() {
        expect!((), &(), "nil");
        expect!((u8, bool), &(1, true), "{1, true}");
        expect!(Vec<Option<u8>>, &vec![Some(1), None], "{1, nil}");
        expect!(
            BTreeMap<u8, &str>,
            &BTreeMap::from([(1, "a"), (2, "b")]),
            "{[1] = \"a\", [2] = \"b\"}"
        );
    }

    #[derive(Serialize)]
    struct Person {
        name: String,
        age: usize,
    }

    #[test]
    fn test_rust_custom_types() {
        expect!(
            Person,
            &Person {
                name: "Test".into(),
                age: 10
            },
            "{name = \"Test\", age = 10}"
        );
    }

    #[derive(Serialize)]
    enum Shape {
        Empty,
        Circle(f64),
        Point(i32, i32),
        Rect { w: u32, h: u32 },
    }

    #[test]
    fn test_rust_enum_variants() {
        expect!(Shape, &Shape::Empty, "\"Empty\"");
        expect!(Shape, &Shape::Circle(1.5), "{Circle = 1.5}");
        expect!(Shape, &Shape::Point(1, -2), "{Point = {1, -2}}");
        expect!(Shape, &Shape::Rect { w: 3, h: 4 }, "{Rect = {w = 3, h = 4}}");
    }
}