Skip to main content

braid_http/client/
headers.rs

1//! Braid-specific HTTP header handling.
2
3use crate::error::{BraidError, Result};
4use crate::protocol;
5use crate::types::Version;
6use http::header::{HeaderMap, HeaderValue};
7
8/// Braid-specific HTTP headers for requests and responses.
9#[derive(Clone, Debug, Default)]
10pub struct BraidHeaders {
11    /// Version identifier(s) from `Version` header
12    pub version: Option<Vec<Version>>,
13    /// Parent version(s) from `Parents` header
14    pub parents: Option<Vec<Version>>,
15    /// Current version(s) from `Current-Version` header
16    pub current_version: Option<Vec<Version>>,
17    /// Subscribe header indicating subscription mode
18    pub subscribe: bool,
19    /// Peer identifier from `Peer` header
20    pub peer: Option<String>,
21    /// Heartbeat interval from `Heartbeats` header
22    pub heartbeat: Option<String>,
23    /// Merge type from `Merge-Type` header
24    pub merge_type: Option<String>,
25    /// Number of patches from `Patches` header
26    pub patches_count: Option<usize>,
27    /// Content range from `Content-Range` header
28    pub content_range: Option<String>,
29    /// Retry-After header for backoff guidance
30    pub retry_after: Option<String>,
31    /// Additional non-Braid headers
32    pub extra: std::collections::BTreeMap<String, String>,
33}
34
35impl BraidHeaders {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn with_version(mut self, version: Version) -> Self {
41        let mut versions = self.version.unwrap_or_default();
42        versions.push(version);
43        self.version = Some(versions);
44        self
45    }
46
47    pub fn with_versions(mut self, versions: Vec<Version>) -> Self {
48        self.version = Some(versions);
49        self
50    }
51
52    pub fn with_parent(mut self, parent: Version) -> Self {
53        let mut parents = self.parents.unwrap_or_default();
54        parents.push(parent);
55        self.parents = Some(parents);
56        self
57    }
58
59    pub fn with_parents(mut self, parents: Vec<Version>) -> Self {
60        self.parents = Some(parents);
61        self
62    }
63
64    pub fn with_current_version(mut self, version: Version) -> Self {
65        let mut versions = self.current_version.unwrap_or_default();
66        versions.push(version);
67        self.current_version = Some(versions);
68        self
69    }
70
71    pub fn with_current_versions(mut self, versions: Vec<Version>) -> Self {
72        self.current_version = Some(versions);
73        self
74    }
75
76    pub fn with_subscribe(mut self) -> Self {
77        self.subscribe = true;
78        self
79    }
80
81    pub fn with_merge_type(mut self, merge_type: impl Into<String>) -> Self {
82        self.merge_type = Some(merge_type.into());
83        self
84    }
85
86    pub fn with_content_range(mut self, content_range: impl Into<String>) -> Self {
87        self.content_range = Some(content_range.into());
88        self
89    }
90
91    pub fn with_heartbeat(mut self, interval: String) -> Self {
92        self.heartbeat = Some(interval);
93        self
94    }
95
96    pub fn with_peer(mut self, peer: String) -> Self {
97        self.peer = Some(peer);
98        self
99    }
100
101    /// Convert to HTTP HeaderMap.
102    pub fn to_header_map(&self) -> Result<HeaderMap> {
103        let mut headers = HeaderMap::new();
104
105        if let Some(ref versions) = self.version {
106            let version_str = protocol::format_version_header(versions);
107            headers.insert(
108                "Version",
109                HeaderValue::from_str(&version_str)
110                    .map_err(|e| BraidError::Config(e.to_string()))?,
111            );
112        }
113
114        if let Some(ref parents) = self.parents {
115            let parents_str = protocol::format_version_header(parents);
116            headers.insert(
117                "Parents",
118                HeaderValue::from_str(&parents_str)
119                    .map_err(|e| BraidError::Config(e.to_string()))?,
120            );
121        }
122
123        if let Some(ref current_versions) = self.current_version {
124            let current_version_str = protocol::format_version_header(current_versions);
125            headers.insert(
126                "Current-Version",
127                HeaderValue::from_str(&current_version_str)
128                    .map_err(|e| BraidError::Config(e.to_string()))?,
129            );
130        }
131
132        if self.subscribe {
133            headers.insert("Subscribe", HeaderValue::from_static("true"));
134        }
135
136        if let Some(ref peer) = self.peer {
137            headers.insert(
138                "Peer",
139                HeaderValue::from_str(peer).map_err(|e| BraidError::Config(e.to_string()))?,
140            );
141        }
142
143        if let Some(ref heartbeat) = self.heartbeat {
144            headers.insert(
145                "Heartbeats",
146                HeaderValue::from_str(heartbeat).map_err(|e| BraidError::Config(e.to_string()))?,
147            );
148        }
149
150        if let Some(ref merge_type) = self.merge_type {
151            headers.insert(
152                "Merge-Type",
153                HeaderValue::from_str(merge_type).map_err(|e| BraidError::Config(e.to_string()))?,
154            );
155        }
156
157        if let Some(count) = self.patches_count {
158            headers.insert(
159                "Patches",
160                HeaderValue::from_str(&count.to_string())
161                    .map_err(|e| BraidError::Config(e.to_string()))?,
162            );
163        }
164
165        if let Some(ref content_range) = self.content_range {
166            headers.insert(
167                "Content-Range",
168                HeaderValue::from_str(content_range)
169                    .map_err(|e| BraidError::Config(e.to_string()))?,
170            );
171        }
172
173        Ok(headers)
174    }
175
176    /// Parse from HTTP HeaderMap.
177    pub fn from_header_map(headers: &HeaderMap) -> Result<Self> {
178        let mut braid_headers = BraidHeaders::new();
179        for (name, value) in headers.iter() {
180            let name_lower = name.as_str().to_lowercase();
181            let value_str = value
182                .to_str()
183                .map_err(|_| BraidError::HeaderParse("Invalid header value".to_string()))?;
184
185            match name_lower.as_str() {
186                "version" => {
187                    braid_headers.version = Some(protocol::parse_version_header(value_str)?);
188                }
189                "parents" => {
190                    braid_headers.parents = Some(protocol::parse_version_header(value_str)?);
191                }
192                "current-version" => {
193                    braid_headers.current_version =
194                        Some(protocol::parse_version_header(value_str)?);
195                }
196                "subscribe" => {
197                    braid_headers.subscribe = value_str.to_lowercase() == "true";
198                }
199                "peer" => {
200                    braid_headers.peer = Some(value_str.to_string());
201                }
202                "heartbeats" => {
203                    braid_headers.heartbeat = Some(value_str.to_string());
204                }
205                "merge-type" => {
206                    braid_headers.merge_type = Some(value_str.to_string());
207                }
208                "patches" => {
209                    braid_headers.patches_count = value_str.parse().ok();
210                }
211                "content-range" => {
212                    braid_headers.content_range = Some(value_str.to_string());
213                }
214                "retry-after" => {
215                    braid_headers.retry_after = Some(value_str.to_string());
216                }
217                _ => {
218                    braid_headers
219                        .extra
220                        .insert(name_lower, value_str.to_string());
221                }
222            }
223        }
224
225        Ok(braid_headers)
226    }
227}
228
229/// Utility for parsing Braid protocol headers.
230pub struct HeaderParser;
231
232impl HeaderParser {
233    pub fn parse_version(value: &str) -> Result<Vec<Version>> {
234        protocol::parse_version_header(value)
235    }
236
237    pub fn parse_content_range(value: &str) -> Result<(String, String)> {
238        protocol::parse_content_range(value)
239    }
240
241    pub fn format_version(versions: &[Version]) -> String {
242        protocol::format_version_header(versions)
243    }
244
245    pub fn format_content_range(unit: &str, range: &str) -> String {
246        protocol::format_content_range(unit, range)
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// A version plus the subscribe flag must both appear in the generated map.
    #[test]
    fn test_braid_headers_to_map() {
        let map = BraidHeaders::new()
            .with_version(Version::String("v1".to_string()))
            .with_subscribe()
            .to_header_map()
            .unwrap();

        for expected in ["Version", "Subscribe"] {
            assert!(map.contains_key(expected));
        }
    }

    /// Three quoted, comma-separated identifiers parse into three versions.
    #[test]
    fn test_parse_version_header() {
        let versions = protocol::parse_version_header("\"v1\", \"v2\", \"v3\"").unwrap();
        assert_eq!(versions.len(), 3);
    }

    /// The unit and the range are returned as separate strings.
    #[test]
    fn test_parse_content_range() {
        let (unit, range) = protocol::parse_content_range("json .field").unwrap();
        assert_eq!((unit.as_str(), range.as_str()), ("json", ".field"));
    }
}