1use serde::Serialize;
2use serde::de::DeserializeOwned;
3
/// Name of the header that carries the wire-format token.
///
/// NOTE(review): presumably the header's value is what [`WireFormat::from_header`]
/// parses and what [`WireFormat::as_str`] produces — confirm against the code that
/// sets/reads this header.
pub const WIRE_FORMAT_HEADER: &str = "X-Lockbook-Wire-Format";

/// Encoding used to turn values into bytes and back
/// (see [`WireFormat::serialize`] and [`WireFormat::deserialize`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WireFormat {
    /// Text encoding backed by `serde_json`.
    Json,
    /// Binary encoding backed by `bincode`.
    Bincode,
}
12
13impl WireFormat {
14 pub const CLIENT_DEFAULT: Self = WireFormat::Bincode;
15
16 pub fn as_str(self) -> &'static str {
17 match self {
18 WireFormat::Json => "json",
19 WireFormat::Bincode => "bincode",
20 }
21 }
22
23 pub fn from_header(value: Option<&str>) -> Self {
24 match value.map(|v| v.trim()) {
25 Some(v) if v.eq_ignore_ascii_case("bincode") => WireFormat::Bincode,
26 _ => WireFormat::Json,
27 }
28 }
29
30 pub fn serialize<T: Serialize + ?Sized>(self, value: &T) -> Result<Vec<u8>, WireError> {
31 match self {
32 WireFormat::Json => {
33 serde_json::to_vec(value).map_err(|e| WireError::Serialize(e.to_string()))
34 }
35 WireFormat::Bincode => {
36 bincode::serialize(value).map_err(|e| WireError::Serialize(e.to_string()))
37 }
38 }
39 }
40
41 pub fn deserialize<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T, WireError> {
42 match self {
43 WireFormat::Json => {
44 serde_json::from_slice(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
45 }
46 WireFormat::Bincode => {
47 bincode::deserialize(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
48 }
49 }
50 }
51}
52
/// Failure produced by [`WireFormat::serialize`] or [`WireFormat::deserialize`].
///
/// Each variant carries the underlying encoder/decoder error rendered as text.
#[derive(Debug)]
pub enum WireError {
    /// Encoding a value into bytes failed.
    Serialize(String),
    /// Decoding bytes into a value failed.
    Deserialize(String),
}

impl std::fmt::Display for WireError {
    /// Renders as `"<stage>: <detail>"`, e.g. `"deserialize: unexpected end of input"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (stage, detail) = match self {
            Self::Serialize(detail) => ("serialize", detail),
            Self::Deserialize(detail) => ("deserialize", detail),
        };
        write!(f, "{stage}: {detail}")
    }
}

// Lets `WireError` be boxed as `dyn std::error::Error` and propagated with `?`
// into error types that accept any `std::error::Error`.
impl std::error::Error for WireError {}
69
#[cfg(test)]
mod tests {
    use super::*;

    // Header parsing is deliberately lenient: it trims, ignores ASCII case, and
    // falls back to JSON for anything it does not recognize.
    #[test]
    fn from_header_handles_common_inputs() {
        assert_eq!(WireFormat::from_header(None), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("")), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("json")), WireFormat::Json);
        assert_eq!(WireFormat::from_header(Some("bincode")), WireFormat::Bincode);
        assert_eq!(WireFormat::from_header(Some("BINCODE")), WireFormat::Bincode);
        assert_eq!(WireFormat::from_header(Some(" bincode ")), WireFormat::Bincode);
        assert_eq!(WireFormat::from_header(Some("msgpack")), WireFormat::Json);
    }

    // The token emitted by `as_str` must parse back to the same variant;
    // because `from_header` silently falls back to JSON, a drift between the
    // two would otherwise go unnoticed.
    #[test]
    fn as_str_roundtrips_through_from_header() {
        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            assert_eq!(
                WireFormat::from_header(Some(fmt.as_str())),
                fmt,
                "{fmt:?} token did not roundtrip"
            );
        }
    }

    #[test]
    fn roundtrip_byte_heavy_payload() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let doc = Doc { bytes: vec![0u8, 1, 2, 200, 255] };

        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            let encoded = fmt.serialize(&doc).unwrap();
            let decoded: Doc = fmt.deserialize(&encoded).unwrap();
            assert_eq!(doc, decoded, "{fmt:?} roundtrip mismatch");
        }
    }

    // Empty or garbage input must surface as `WireError::Deserialize`
    // rather than a panic or a mislabeled error variant.
    #[test]
    fn deserialize_rejects_malformed_input() {
        let inputs: [&[u8]; 2] = [b"", &[0xff]];
        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            for bytes in inputs {
                let result: Result<String, WireError> = fmt.deserialize(bytes);
                assert!(
                    matches!(result, Err(WireError::Deserialize(_))),
                    "{fmt:?} did not reject malformed input {bytes:?}"
                );
            }
        }
    }

    #[test]
    fn bincode_is_more_compact_for_bytes_than_json() {
        #[derive(serde::Serialize, serde::Deserialize)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let doc = Doc { bytes: (0u8..=255).collect() };
        let json_len = WireFormat::Json.serialize(&doc).unwrap().len();
        let bincode_len = WireFormat::Bincode.serialize(&doc).unwrap().len();
        assert!(
            bincode_len * 3 < json_len,
            "bincode={bincode_len} json={json_len} — bincode should be much smaller"
        );
    }
}