1use crate::error::{BraidError, Result};
4use crate::types::Patch;
5use bytes::{Buf, Bytes, BytesMut};
6use once_cell::sync::Lazy;
7use regex::Regex;
8use std::collections::BTreeMap;
9
/// Phase of the incremental parse that `MessageParser::feed` is currently in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseState {
    /// Waiting for a complete header block (or encoding block) to arrive.
    WaitingForHeaders,
    /// Not assigned by any code visible in this file; `feed` stops looping on it.
    ParsingHeaders,
    /// Headers are parsed; waiting for the message body (or switching to patches).
    WaitingForBody,
    /// Waiting for the header block of the next patch.
    WaitingForPatchHeaders,
    /// Waiting for the body bytes of the current patch.
    WaitingForPatchBody,
    /// Skipping the line break that separates two consecutive patches.
    SkippingSeparator,
    /// Not assigned by any code visible in this file; `feed` stops looping on it.
    Complete,
    /// Not assigned by any code visible in this file; `feed` stops looping on it.
    Error,
}
21
/// Incremental parser that turns a byte stream into `Message` values.
///
/// Bytes are pushed in with `feed`; state for the message currently being
/// assembled is kept in the fields below between calls.
#[derive(Debug)]
pub struct MessageParser {
    /// Input bytes received by `feed` that have not been consumed yet.
    buffer: BytesMut,
    /// Current phase of the state machine driven by `feed`.
    state: ParseState,
    /// Headers of the message in progress (lower-cased names, trimmed values).
    headers: BTreeMap<String, String>,
    /// Body bytes collected so far for the message in progress.
    body_buffer: BytesMut,
    /// Body length announced by `content-length` / `length` or an encoding block.
    expected_body_length: usize,
    /// Number of body bytes already moved into `body_buffer`.
    read_body_length: usize,
    /// Patches completed so far for the message in progress.
    patches: Vec<Patch>,
    /// Patch count announced by the `patches` header (0 when absent or unparsable).
    expected_patches: usize,
    /// Number of patches fully read for the message in progress.
    patches_read: usize,
    /// Headers of the patch currently being parsed.
    patch_headers: BTreeMap<String, String>,
    /// Body length announced by the current patch's `content-length` header.
    expected_patch_length: usize,
    /// Number of bytes of the current patch body already consumed.
    read_patch_length: usize,
    /// Set when an encoding block was recognised and cleared by `reset`;
    /// not read by any code visible in this file.
    is_encoding_block: bool,
}
38
/// Matches an HTTP status line at the very start of the input and captures the
/// three-digit status code as group 1. The version part is optional, so both
/// `HTTP/1.1 200` and `HTTP 209` match.
static HTTP_STATUS_REGEX: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^HTTP/?\d*\.?\d* (\d{3})").unwrap());
41
42static ENCODING_BLOCK_REGEX: Lazy<Regex> =
43 Lazy::new(|| Regex::new(r"(?i)Encoding:\s*(\w+)\r?\nLength:\s*(\d+)\r?\n").unwrap());
44
45impl MessageParser {
46 pub fn new() -> Self {
47 MessageParser {
48 buffer: BytesMut::with_capacity(8192),
49 state: ParseState::WaitingForHeaders,
50 headers: BTreeMap::new(),
51 body_buffer: BytesMut::new(),
52 expected_body_length: 0,
53 read_body_length: 0,
54 patches: Vec::new(),
55 expected_patches: 0,
56 patches_read: 0,
57 patch_headers: BTreeMap::new(),
58 expected_patch_length: 0,
59 read_patch_length: 0,
60 is_encoding_block: false,
61 }
62 }
63
64 pub fn new_with_state(headers: BTreeMap<String, String>, content_length: usize) -> Self {
65 let mut parser = MessageParser::new();
66 parser.headers = headers;
67 parser.expected_body_length = content_length;
68 if content_length > 0 {
69 parser.state = ParseState::WaitingForBody;
70 } else {
71 parser.state = ParseState::WaitingForBody;
74 }
75 parser
76 }
77
78 pub fn feed(&mut self, data: &[u8]) -> Result<Vec<Message>> {
79 self.buffer.extend_from_slice(data);
80 let mut messages = Vec::new();
81
82 loop {
83 match self.state {
84 ParseState::WaitingForHeaders => {
85 while !self.buffer.is_empty()
86 && (self.buffer[0] == b'\r' || self.buffer[0] == b'\n')
87 {
88 self.buffer.advance(1);
89 }
90
91 if self.buffer.is_empty() {
92 break;
93 }
94
95 if self.check_encoding_block()? {
96 self.state = ParseState::WaitingForBody;
97 continue;
98 }
99
100 if let Some(pos) = self.find_header_end() {
101 self.parse_headers(pos)?;
102 self.state = ParseState::WaitingForBody;
103 } else {
104 break;
105 }
106 }
107 ParseState::WaitingForBody => {
108 if self.expected_patches > 0 {
109 self.state = ParseState::WaitingForPatchHeaders;
110 } else if self.try_parse_body()? {
111 if let Some(msg) = self.finalize_message()? {
112 messages.push(msg);
113 }
114 self.reset();
115 self.state = ParseState::WaitingForHeaders;
116 } else {
117 break;
118 }
119 }
120 ParseState::WaitingForPatchHeaders => {
121 if let Some(pos) = self.find_header_end() {
122 self.parse_patch_headers(pos)?;
123 self.state = ParseState::WaitingForPatchBody;
124 } else {
125 break;
126 }
127 }
128 ParseState::WaitingForPatchBody => {
129 if self.try_parse_patch_body()? {
130 self.patches_read += 1;
131 if self.patches_read < self.expected_patches {
132 self.state = ParseState::SkippingSeparator;
133 } else {
134 if let Some(msg) = self.finalize_message()? {
135 messages.push(msg);
136 }
137 self.reset();
138 self.state = ParseState::WaitingForHeaders;
139 }
140 } else {
141 break;
142 }
143 }
144 ParseState::SkippingSeparator => {
145 if self.buffer.len() >= 2 {
146 if &self.buffer[..2] == b"\r\n" {
147 self.buffer.advance(2);
148 } else if self.buffer[0] == b'\n' {
149 self.buffer.advance(1);
150 }
151 self.state = ParseState::WaitingForPatchHeaders;
152 } else if self.buffer.len() == 1 && self.buffer[0] == b'\n' {
153 self.buffer.advance(1);
154 self.state = ParseState::WaitingForPatchHeaders;
155 } else {
156 break;
157 }
158 }
159 _ => break,
160 }
161 }
162 Ok(messages)
163 }
164
165 fn check_encoding_block(&mut self) -> Result<bool> {
166 if self.buffer.is_empty() || (self.buffer[0] != b'E' && self.buffer[0] != b'e') {
167 return Ok(false);
168 }
169
170 if let Some(end) = self.find_double_newline() {
171 let header_bytes = &self.buffer[..end];
172 let header_str = std::str::from_utf8(header_bytes).map_err(|e| {
173 BraidError::Protocol(format!("Invalid encoding block UTF-8: {}", e))
174 })?;
175
176 if let Some(caps) = ENCODING_BLOCK_REGEX.captures(header_str) {
177 let encoding = caps.get(1).unwrap().as_str().to_string();
178 let length: usize = caps.get(2).unwrap().as_str().parse().map_err(|_| {
179 BraidError::Protocol("Invalid length in encoding block".to_string())
180 })?;
181
182 let _ = self.buffer.split_to(end);
183 self.headers.insert("encoding".to_string(), encoding);
184 self.headers
185 .insert("length".to_string(), length.to_string());
186 self.expected_body_length = length;
187 self.is_encoding_block = true;
188 return Ok(true);
189 }
190 }
191 Ok(false)
192 }
193
194 fn find_double_newline(&self) -> Option<usize> {
195 if let Some(pos) = self.buffer.windows(4).position(|w| w == b"\r\n\r\n") {
196 return Some(pos + 4);
197 }
198 if let Some(pos) = self.buffer.windows(2).position(|w| w == b"\n\n") {
199 return Some(pos + 2);
200 }
201 None
202 }
203
204 fn find_header_end(&self) -> Option<usize> {
205 self.buffer
206 .windows(4)
207 .position(|w| w == b"\r\n\r\n")
208 .map(|p| p + 4)
209 }
210
211 fn parse_headers(&mut self, end: usize) -> Result<()> {
212 let header_bytes = self.buffer.split_to(end);
213 let mut header_str = String::from_utf8(header_bytes[..header_bytes.len() - 4].to_vec())?;
214
215 if let Some(caps) = HTTP_STATUS_REGEX.captures(&header_str) {
216 if let Some(status_match) = caps.get(1) {
217 let status = status_match.as_str();
218 if let Some(first_newline) = header_str.find('\n') {
219 let replacement = format!(":status: {}\r", status);
220 header_str = replacement + &header_str[first_newline..];
221 }
222 }
223 }
224
225 for line in header_str.lines() {
226 if let Some(colon_pos) = line.find(':') {
227 let key = line[..colon_pos].trim().to_lowercase();
228 let value = line[colon_pos + 1..].trim().to_string();
229 self.headers.insert(key, value);
230 }
231 }
232
233 if let Some(patches_str) = self.headers.get("patches") {
234 self.expected_patches = patches_str.parse().unwrap_or(0);
235 }
236
237 if let Some(len_str) = self
238 .headers
239 .get("content-length")
240 .or_else(|| self.headers.get("length"))
241 {
242 self.expected_body_length = len_str.parse().map_err(|_| {
243 BraidError::HeaderParse(format!("Invalid content-length: {}", len_str))
244 })?;
245 }
246 Ok(())
247 }
248
249 fn parse_patch_headers(&mut self, end: usize) -> Result<()> {
250 let header_bytes = self.buffer.split_to(end);
251 let header_str = String::from_utf8(header_bytes[..header_bytes.len() - 4].to_vec())?;
252
253 self.patch_headers.clear();
254 for line in header_str.lines() {
255 if let Some(colon_pos) = line.find(':') {
256 let key = line[..colon_pos].trim().to_lowercase();
257 let value = line[colon_pos + 1..].trim().to_string();
258 self.patch_headers.insert(key, value);
259 }
260 }
261
262 if let Some(len_str) = self.patch_headers.get("content-length") {
263 self.expected_patch_length = len_str.parse().map_err(|_| {
264 BraidError::HeaderParse(format!("Invalid patch content-length: {}", len_str))
265 })?;
266 } else {
267 return Err(BraidError::Protocol(
268 "Every patch MUST include Content-Length".to_string(),
269 ));
270 }
271
272 self.read_patch_length = 0;
273 Ok(())
274 }
275
276 fn try_parse_patch_body(&mut self) -> Result<bool> {
277 let remaining = self.expected_patch_length - self.read_patch_length;
278 if self.buffer.len() >= remaining {
279 let body_chunk = self.buffer.split_to(remaining);
280 let unit = self
281 .patch_headers
282 .get("content-range")
283 .and_then(|cr| cr.split_whitespace().next())
284 .unwrap_or("bytes")
285 .to_string();
286 let range = self
287 .patch_headers
288 .get("content-range")
289 .and_then(|cr| cr.split_whitespace().nth(1))
290 .unwrap_or("")
291 .to_string();
292 let patch = Patch::with_length(unit, range, body_chunk, self.expected_patch_length);
293 self.patches.push(patch);
294 self.read_patch_length += remaining;
295 Ok(true)
296 } else {
297 Ok(false)
298 }
299 }
300
301 fn try_parse_body(&mut self) -> Result<bool> {
302 if self.expected_body_length == 0 {
303 return Ok(true);
304 }
305 let remaining = self.expected_body_length - self.read_body_length;
306 if self.buffer.len() >= remaining {
307 let body_chunk = self.buffer.split_to(remaining);
308 self.body_buffer.extend_from_slice(&body_chunk);
309 self.read_body_length += body_chunk.len();
310 Ok(true)
311 } else {
312 let chunk_len = self.buffer.len();
313 self.body_buffer
314 .extend_from_slice(&self.buffer.split_to(chunk_len));
315 self.read_body_length += chunk_len;
316 Ok(false)
317 }
318 }
319
320 fn finalize_message(&mut self) -> Result<Option<Message>> {
321 let body = self.body_buffer.split().freeze();
322 let headers = std::mem::take(&mut self.headers);
323 let url = headers.get("content-location").cloned();
324 let encoding = headers.get("encoding").cloned();
325
326 Ok(Some(Message {
327 headers,
328 body,
329 patches: std::mem::take(&mut self.patches),
330 status_code: None,
331 encoding,
332 url,
333 }))
334 }
335
336 fn reset(&mut self) {
337 self.headers.clear();
338 self.body_buffer.clear();
339 self.expected_body_length = 0;
340 self.read_body_length = 0;
341 self.patches.clear();
342 self.expected_patches = 0;
343 self.patches_read = 0;
344 self.patch_headers.clear();
345 self.expected_patch_length = 0;
346 self.read_patch_length = 0;
347 self.is_encoding_block = false;
348 }
349
/// Returns the parser's current state.
pub fn state(&self) -> ParseState {
    self.state
}
/// Returns the headers collected so far for the message in progress.
pub fn headers(&self) -> &BTreeMap<String, String> {
    &self.headers
}
/// Returns the body bytes collected so far for the message in progress.
pub fn body(&self) -> &[u8] {
    &self.body_buffer
}
359}
360
361impl Default for MessageParser {
362 fn default() -> Self {
363 Self::new()
364 }
365}
366
/// A fully parsed message produced by `MessageParser::feed`.
#[derive(Debug, Clone)]
pub struct Message {
    /// All headers of the message (lower-cased names, trimmed values).
    pub headers: BTreeMap<String, String>,
    /// Raw body bytes; empty when the message carried none.
    pub body: Bytes,
    /// Patches carried by the message, in the order they were read.
    pub patches: Vec<Patch>,
    /// Explicit status code; the parser in this file always leaves it `None`
    /// and `status()` then falls back to the `:status` header.
    pub status_code: Option<u16>,
    /// Copy of the `encoding` header, if present.
    pub encoding: Option<String>,
    /// Copy of the `content-location` header, if present.
    pub url: Option<String>,
}
376
377impl Message {
378 pub fn status(&self) -> Option<u16> {
379 self.status_code
380 .or_else(|| self.headers.get(":status").and_then(|v| v.parse().ok()))
381 }
382
383 pub fn version(&self) -> Option<&str> {
384 self.headers.get("version").map(|s| s.as_str())
385 }
386 pub fn current_version(&self) -> Option<&str> {
387 self.headers.get("current-version").map(|s| s.as_str())
388 }
389 pub fn parents(&self) -> Option<&str> {
390 self.headers.get("parents").map(|s| s.as_str())
391 }
392
393 pub fn decode_body(&self) -> Result<Bytes> {
394 match self.encoding.as_deref() {
395 Some("dt") => Ok(self.body.clone()),
396 Some(enc) => Err(BraidError::Protocol(format!("Unknown encoding: {}", enc))),
397 None => Ok(self.body.clone()),
398 }
399 }
400
401 pub fn extra_headers(&self) -> BTreeMap<String, String> {
402 const KNOWN_HEADERS: &[&str] = &[
403 "version",
404 "parents",
405 "current-version",
406 "patches",
407 "content-length",
408 "content-range",
409 ":status",
410 ];
411 self.headers
412 .iter()
413 .filter(|(k, _)| !KNOWN_HEADERS.contains(&k.as_str()))
414 .map(|(k, v)| (k.clone(), v.clone()))
415 .collect()
416 }
417
418 pub fn body_text(&self) -> Option<String> {
419 std::str::from_utf8(&self.body).ok().map(|s| s.to_string())
420 }
421}
422
/// Extracts the status code from an HTTP status line.
///
/// The line is split on whitespace. The first token must start with `HTTP`
/// (compared case-insensitively) and the second token must be exactly three
/// ASCII digits, matching the HTTP grammar for a status code. Anything else,
/// such as a missing code, a signed value like `+200` or a code of the wrong
/// length, yields `None`.
pub fn parse_status_line(line: &str) -> Option<u16> {
    let mut tokens = line.split_whitespace();
    let protocol = tokens.next()?;
    let code = tokens.next()?;

    // `get` returns `None` when the token is shorter than four bytes or the
    // cut would fall inside a multi-byte character.
    let is_http = protocol
        .get(..4)
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case("HTTP"));
    if !is_http {
        return None;
    }

    // Reject forms that integer parsing alone would accept ("+200", "7").
    if code.len() != 3 || !code.bytes().all(|byte| byte.is_ascii_digit()) {
        return None;
    }
    code.parse().ok()
}
431
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh parser starts out waiting for a header block.
    #[test]
    fn test_parser_creation() {
        let parser = MessageParser::new();
        assert_eq!(parser.state(), ParseState::WaitingForHeaders);
    }

    // Headers plus a complete body in one chunk produce one message.
    #[test]
    fn test_simple_message_parsing() {
        let mut parser = MessageParser::new();
        let data = b"Content-Length: 5\r\n\r\nHello";
        let messages = parser.feed(data).unwrap();
        assert!(!messages.is_empty());
        assert_eq!(messages[0].body, Bytes::from_static(b"Hello"));
    }

    // Status lines with a full version, with no version and with no reason phrase.
    #[test]
    fn test_parse_status_line() {
        assert_eq!(parse_status_line("HTTP/1.1 200 OK"), Some(200));
        assert_eq!(parse_status_line("HTTP 209 Subscription"), Some(209));
        assert_eq!(parse_status_line("HTTP/2 404"), Some(404));
    }

    // Known headers are filtered out; unknown ones are kept.
    #[test]
    fn test_message_extra_headers() {
        let mut headers = BTreeMap::new();
        headers.insert("version".to_string(), "\"v1\"".to_string());
        headers.insert("x-custom-header".to_string(), "value".to_string());

        let msg = Message {
            headers,
            body: Bytes::new(),
            patches: vec![],
            status_code: None,
            encoding: None,
            url: None,
        };

        let extra = msg.extra_headers();
        assert_eq!(extra.len(), 1);
        assert!(extra.contains_key("x-custom-header"));
        assert!(!extra.contains_key("version"));
    }

    // Two patches separated by a CRLF line break end up in a single message.
    #[test]
    fn test_multi_patch_parsing() {
        let mut parser = MessageParser::new();
        let data = b"Patches: 2\r\n\r\n\
 Content-Length: 5\r\n\
 Content-Range: json .a\r\n\r\n\
 hello\r\n\
 Content-Length: 5\r\n\
 Content-Range: json .b\r\n\r\n\
 world\r\n";

        let messages = parser.feed(data).unwrap();
        assert!(!messages.is_empty());
        let msg = &messages[0];
        assert_eq!(msg.patches.len(), 2);
    }
}