use bincode::Options;
use serde::Serialize;
use serde::de::DeserializeOwned;
3
/// HTTP header carrying the requested wire format.
///
/// Its value is presumably what gets passed to `WireFormat::from_header`
/// (`"json"` / `"bincode"`); confirm against the HTTP layer.
pub const WIRE_FORMAT_HEADER: &str = "X-Lockbook-Wire-Format";

/// HTTP header naming the caller's operating system (inferred from the name; verify).
pub const OS_HEADER: &str = "X-Lockbook-OS";

/// HTTP header naming the calling client application (inferred from the name; verify).
pub const CLIENT_HEADER: &str = "X-Lockbook-Client";
11
/// Serialization format used to encode and decode payloads sent over the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WireFormat {
    /// Human-readable JSON, encoded with `serde_json`.
    Json,
    /// Compact binary encoding, encoded with `bincode`.
    Bincode,
}
17
18impl WireFormat {
19 pub const CLIENT_DEFAULT: Self = WireFormat::Bincode;
20
21 pub fn as_str(self) -> &'static str {
22 match self {
23 WireFormat::Json => "json",
24 WireFormat::Bincode => "bincode",
25 }
26 }
27
28 pub fn from_header(value: Option<&str>) -> Self {
29 match value.map(|v| v.trim()) {
30 Some(v) if v.eq_ignore_ascii_case("bincode") => WireFormat::Bincode,
31 _ => WireFormat::Json,
32 }
33 }
34
35 pub fn serialize<T: Serialize + ?Sized>(self, value: &T) -> Result<Vec<u8>, WireError> {
36 match self {
37 WireFormat::Json => {
38 serde_json::to_vec(value).map_err(|e| WireError::Serialize(e.to_string()))
39 }
40 WireFormat::Bincode => {
41 bincode::serialize(value).map_err(|e| WireError::Serialize(e.to_string()))
42 }
43 }
44 }
45
46 pub fn deserialize<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T, WireError> {
47 match self {
48 WireFormat::Json => {
49 serde_json::from_slice(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
50 }
51 WireFormat::Bincode => {
52 bincode::deserialize(bytes).map_err(|e| WireError::Deserialize(e.to_string()))
53 }
54 }
55 }
56}
57
/// Failure while encoding or decoding a payload in a given wire format.
///
/// The underlying encoder/decoder error is carried only as its message text.
#[derive(Debug)]
pub enum WireError {
    /// Turning a value into bytes failed.
    Serialize(String),
    /// Turning bytes back into a value failed.
    Deserialize(String),
}

impl std::fmt::Display for WireError {
    /// Renders as `"<stage>: <message>"`, e.g. `"deserialize: unexpected EOF"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (stage, detail) = match self {
            WireError::Serialize(msg) => ("serialize", msg),
            WireError::Deserialize(msg) => ("deserialize", msg),
        };
        write!(f, "{stage}: {detail}")
    }
}

impl std::error::Error for WireError {}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of header values and the format each must map to.
    #[test]
    fn from_header_handles_common_inputs() {
        let cases = [
            (None, WireFormat::Json),
            (Some(""), WireFormat::Json),
            (Some("json"), WireFormat::Json),
            (Some("bincode"), WireFormat::Bincode),
            (Some("BINCODE"), WireFormat::Bincode),
            (Some(" bincode "), WireFormat::Bincode),
            (Some("msgpack"), WireFormat::Json),
        ];
        for (header, expected) in cases {
            assert_eq!(WireFormat::from_header(header), expected);
        }
    }

    /// Every format must decode exactly what it encoded, including non-ASCII bytes.
    #[test]
    fn roundtrip_byte_heavy_payload() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let original = Doc { bytes: vec![0u8, 1, 2, 200, 255] };

        for fmt in [WireFormat::Json, WireFormat::Bincode] {
            let wire_bytes = fmt.serialize(&original).unwrap();
            let restored: Doc = fmt.deserialize(&wire_bytes).unwrap();
            assert_eq!(original, restored, "{fmt:?} roundtrip mismatch");
        }
    }

    /// Byte-heavy payloads should be far smaller in bincode than in JSON.
    #[test]
    fn bincode_is_more_compact_for_bytes_than_json() {
        #[derive(serde::Serialize, serde::Deserialize)]
        struct Doc {
            #[serde(with = "serde_bytes")]
            bytes: Vec<u8>,
        }

        let doc = Doc { bytes: (0u8..=255).collect() };
        let encoded_len = |fmt: WireFormat| fmt.serialize(&doc).unwrap().len();
        let json_len = encoded_len(WireFormat::Json);
        let bincode_len = encoded_len(WireFormat::Bincode);
        assert!(
            bincode_len * 3 < json_len,
            "bincode={bincode_len} json={json_len} — bincode should be much smaller"
        );
    }
}