Skip to main content

lb_rs/model/
wire.rs

1use serde::Serialize;
2use serde::de::DeserializeOwned;
3
/// Name of the HTTP header through which client and server negotiate the body
/// encoding; its value is parsed by `WireFormat::from_header`.
pub const WIRE_FORMAT_HEADER: &str = "X-Lockbook-Wire-Format";

/// Name of the `X-Lockbook-OS` header (presumably reports the caller's OS —
/// see the linked issue for the motivation).
// https://github.com/lockbook/lockbook/issues/4768
pub const OS_HEADER: &str = "X-Lockbook-OS";

/// Name of the `X-Lockbook-Client` header.
pub const CLIENT_HEADER: &str = "X-Lockbook-Client";
11
/// Body encoding negotiated between client and server.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WireFormat {
    /// Bodies encoded as JSON (via `serde_json`).
    Json,
    /// Bodies encoded with `bincode`.
    Bincode,
}
17
18impl WireFormat {
19    pub const CLIENT_DEFAULT: Self = WireFormat::Bincode;
20
21    pub fn as_str(self) -> &'static str {
22        match self {
23            WireFormat::Json => "json",
24            WireFormat::Bincode => "bincode",
25        }
26    }
27
28    pub fn from_header(value: Option<&str>) -> Self {
29        match value.map(|v| v.trim()) {
30            Some(v) if v.eq_ignore_ascii_case("bincode") => WireFormat::Bincode,
31            _ => WireFormat::Json,
32        }
33    }
34
35    pub fn serialize<T: Serialize + ?Sized>(self, value: &T) -> Result<Vec<u8>, WireError> {
36        match self {
37            WireFormat::Json => {
38                serde_json::to_vec(value).map_err(|e| WireError::Serialize(e.to_string()))
39            }
40            WireFormat::Bincode => {
41                bincode::serialize(value).map_err(|e| WireError::Serialize(e.to_string()))
42            }
43        }
44    }
45
46    pub fn deserialize<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T, WireError> {
47        match self {
48            WireFormat::Json => {
49                serde_json::from_slice(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
50            }
51            WireFormat::Bincode => {
52                bincode::deserialize(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
53            }
54        }
55    }
56}
57
/// Failure while encoding or decoding a body with a `WireFormat`.
///
/// Each variant stores the underlying library's error rendered as text.
#[derive(Debug)]
pub enum WireError {
    /// Turning a value into bytes failed.
    Serialize(String),
    /// Turning bytes back into a value failed.
    Deserialize(String),
}

impl std::fmt::Display for WireError {
    /// Renders as `"<stage>: <message>"`, where stage is `serialize` or
    /// `deserialize`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (stage, detail) = match self {
            WireError::Serialize(msg) => ("serialize", msg),
            WireError::Deserialize(msg) => ("deserialize", msg),
        };
        write!(f, "{stage}: {detail}")
    }
}

impl std::error::Error for WireError {}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_header_handles_common_inputs() {
        assert_eq!(WireFormat::from_header(None), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("")), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("json")), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("bincode")), WireFormat::Bincode);
        assert_eq!(WireFormat::from_header(Some("BINCODE")), WireFormat::Bincode);
        assert_eq!(WireFormat::from_header(Some(" bincode ")), WireFormat::Bincode);
        // Unknown values fall back to JSON.
        assert_eq!(WireFormat::from_header(Some("msgpack")), WireFormat::Json);
    }

    #[test]
    fn as_str_roundtrips_through_from_header() {
        // The token produced by `as_str` must be parsed back to the same format
        // by `from_header`; otherwise one side of the negotiation would
        // silently fall back to JSON.
        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            assert_eq!(WireFormat::from_header(Some(fmt.as_str())), fmt, "{fmt:?}");
        }
        assert_eq!(
            WireFormat::from_header(Some(WireFormat::CLIENT_DEFAULT.as_str())),
            WireFormat::CLIENT_DEFAULT
        );
    }

    #[test]
    fn roundtrip_byte_heavy_payload() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let doc = Doc { bytes: vec![0u8, 1, 2, 200, 255] };

        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            let encoded = fmt.serialize(&doc).unwrap();
            let decoded: Doc = fmt.deserialize(&encoded).unwrap();
            assert_eq!(doc, decoded, "{fmt:?} roundtrip mismatch");
        }
    }

    #[test]
    fn deserialize_reports_malformed_input() {
        // Not valid JSON (leading 0xff), and too short for bincode's 8-byte
        // length prefix, so both formats must reject it as a decode error.
        let garbage = [0xffu8, 0x00, 0x01];
        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            let result: Result<Vec<u8>, WireError> = fmt.deserialize(&garbage);
            assert!(
                matches!(result, Err(WireError::Deserialize(_))),
                "{fmt:?} did not report a deserialize error"
            );
        }
    }

    #[test]
    fn bincode_is_more_compact_for_bytes_than_json() {
        #[derive(serde::Serialize, serde::Deserialize)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let doc = Doc { bytes: (0u8..=255).collect() };
        let json_len = WireFormat::Json.serialize(&doc).unwrap().len();
        let bincode_len = WireFormat::Bincode.serialize(&doc).unwrap().len();
        assert!(
            bincode_len * 3 < json_len,
            "bincode={bincode_len} json={json_len} — bincode should be much smaller"
        );
    }
}