bitflags_serde_legacy/
lib.rs

1#![no_std]
2
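//! This library provides generic implementations of `Serialize` and `Deserialize` that can be
//! used by any flags type generated by `bitflags!`. They (de)serialize a flags value in the same
//! format that `#[derive(Serialize, Deserialize)]` on a `bitflags` `1.x` type did: a struct with a
//! single `bits` field.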
5//!
6//! # Usage
7//!
8//! Add `bitflags-serde-legacy` to your `Cargo.toml`:
9//!
10//! ```toml
11//! [dependencies.bitflags_serde_legacy]
12//! version = "0.1.1"
13//! ```
14//!
15//! Then, replace an existing `#[derive(Serialize, Deserialize)]` on your `bitflags!`
16//! generated types with the following manual implementations:
17//!
18//! ```
19//! use bitflags::bitflags;
20//!
21//! bitflags! {
22//!     // #[derive(Serialize, Deserialize)]
23//!     struct Flags: u32 {
24//!         const A = 0b00000001;
25//!         const B = 0b00000010;
26//!         const C = 0b00000100;
27//!         const ABC = Self::A.bits() | Self::B.bits() | Self::C.bits();
28//!     }
29//! }
30//!
31//! impl serde::Serialize for Flags {
32//!     fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
33//!         bitflags_serde_legacy::serialize(self, "Flags", serializer)
34//!     }
35//! }
36//!
37//! impl<'de> serde::Deserialize<'de> for Flags {
38//!     fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
39//!         bitflags_serde_legacy::deserialize("Flags", deserializer)
40//!     }
41//! }
42//! ```
43
44use core::fmt;
45use serde::{
46    de::{Error, MapAccess, Visitor},
47    ser::SerializeStruct,
48    Deserialize, Deserializer, Serialize, Serializer,
49};
50
51use bitflags::Flags;
52
53/// Serialize a flags type equivalently to how `#[derive(Serialize)]` on a flags type
54/// from `bitflags` `1.x` would.
55pub fn serialize<T: Flags, S: Serializer>(
56    flags: &T,
57    name: &'static str,
58    serializer: S,
59) -> Result<S::Ok, S::Error>
60where
61    T::Bits: Serialize,
62{
63    let mut serialize_struct = serializer.serialize_struct(name, 1)?;
64    serialize_struct.serialize_field("bits", &flags.bits())?;
65    serialize_struct.end()
66}
67
68/// Deserialize a flags type equivalently to how `#[derive(Deserialize)]` on a flags type
69/// from `bitflags` `1.x` would.
70pub fn deserialize<'de, T: Flags, D: Deserializer<'de>>(
71    name: &'static str,
72    deserializer: D,
73) -> Result<T, D::Error>
74where
75    T::Bits: Deserialize<'de>,
76{
77    struct BitsVisitor<T>(core::marker::PhantomData<T>);
78
79    impl<'de, T: Deserialize<'de>> Visitor<'de> for BitsVisitor<T> {
80        type Value = T;
81
82        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
83            formatter.write_str("a primitive bitflags value wrapped in a struct")
84        }
85
86        fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
87            let mut bits = None;
88
89            struct Field;
90
91            impl<'de> Deserialize<'de> for Field {
92                fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
93                    struct FieldVisitor;
94
95                    impl<'de> Visitor<'de> for FieldVisitor {
96                        type Value = Field;
97
98                        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
99                            formatter.write_str("field identifier")
100                        }
101
102                        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
103                        where
104                            E: Error,
105                        {
106                            match v {
107                                "bits" => Ok(Field),
108                                field => Err(E::unknown_field(field, &["bits"])),
109                            }
110                        }
111                    }
112
113                    deserializer.deserialize_identifier(FieldVisitor)
114                }
115            }
116
117            while map.next_key::<Field>()?.is_some() {
118                if bits.is_some() {
119                    return Err(Error::duplicate_field("bits"));
120                }
121
122                bits = Some(map.next_value()?);
123            }
124
125            bits.ok_or_else(|| Error::missing_field("bits"))
126        }
127    }
128
129    let bits = deserializer.deserialize_struct(name, &["bits"], BitsVisitor(Default::default()))?;
130
131    Ok(T::from_bits_retain(bits))
132}
133
#[cfg(test)]
mod tests {
    use serde_test::{assert_tokens, Configure as _, Token};

    // Reference type: `bitflags` 1.x with the derived serde impls, i.e. the legacy format.
    bitflags1::bitflags! {
        #[derive(serde_derive::Serialize, serde_derive::Deserialize)]
        struct Flags1: u32 {
            const A = 0b00000001;
            const B = 0b00000010;
            const C = 0b00000100;
            const ABC = Self::A.bits | Self::B.bits | Self::C.bits;
        }
    }

    // The same flags declared with the current `bitflags`, (de)serialized through this crate.
    bitflags::bitflags! {
        #[derive(Debug, PartialEq, Eq)]
        struct Flags2: u32 {
            const A = 0b00000001;
            const B = 0b00000010;
            const C = 0b00000100;
            const ABC = Self::A.bits() | Self::B.bits() | Self::C.bits();
        }
    }

    // Both impls deliberately reuse the struct name "Flags1" so that the two types can be
    // checked against one shared token stream.
    impl serde::Serialize for Flags2 {
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            crate::serialize(self, "Flags1", serializer)
        }
    }

    impl<'de> serde::Deserialize<'de> for Flags2 {
        fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            crate::deserialize("Flags1", deserializer)
        }
    }

    // Token stream for the `B` flag (bits = 0b10) in the legacy single-field struct layout.
    const EXPECTED: &[Token] = &[
        Token::Struct {
            name: "Flags1",
            len: 1,
        },
        Token::Str("bits"),
        Token::U32(0b00000010),
        Token::StructEnd,
    ];

    #[test]
    fn serde_compat() {
        // Human-readable configuration.
        assert_tokens(&Flags1::B.readable(), EXPECTED);
        assert_tokens(&Flags2::B.readable(), EXPECTED);

        // Compact configuration.
        assert_tokens(&Flags1::B.compact(), EXPECTED);
        assert_tokens(&Flags2::B.compact(), EXPECTED);
    }
}