highflame_shield/stream.rs
1//! SSE streaming support for `POST /v1/guard/stream`.
2
3use crate::ShieldError;
4use async_stream::stream;
5use futures_util::StreamExt as _;
6use std::collections::HashMap;
7
8/// A single server-sent event from `POST /v1/guard/stream`.
9#[derive(Debug, Clone)]
10pub struct ShieldStreamEvent {
11 /// Event type: `"detection"`, `"decision"`, `"error"`, or `"done"`.
12 pub r#type: String,
13 /// Parsed JSON payload.
14 pub data: HashMap<String, serde_json::Value>,
15}
16
17/// Parse a raw SSE response body into a stream of [`ShieldStreamEvent`]s.
18///
19/// Implements the SSE wire protocol (RFC 8895) manually so the SDK stays
20/// dependency-free beyond `reqwest` and `async-stream`. Handles:
21/// - Multi-chunk delivery (lines that span chunk boundaries)
22/// - Multiple events per chunk
23/// - `event:` / `data:` field pairs
24/// - `id:`, `retry:`, and comment lines are silently ignored
25pub(crate) fn parse_sse_response(
26 response: reqwest::Response,
27) -> impl futures_util::Stream<Item = Result<ShieldStreamEvent, ShieldError>> {
28 stream! {
29 let mut byte_stream = response.bytes_stream();
30 // Accumulates bytes across chunk boundaries until a complete line is seen.
31 let mut line_buf = String::new();
32 // SSE parser state for the current event block.
33 let mut current_event: Option<String> = None;
34 let mut current_data: Option<String> = None;
35
36 while let Some(chunk) = byte_stream.next().await {
37 let chunk = match chunk {
38 Ok(c) => c,
39 Err(e) => {
40 yield Err(ShieldError::Connection(e.to_string()));
41 return;
42 }
43 };
44
45 line_buf.push_str(&String::from_utf8_lossy(&chunk));
46
47 // Drain complete lines from the buffer.
48 loop {
49 match line_buf.find('\n') {
50 None => break,
51 Some(pos) => {
52 let raw = line_buf[..pos].trim_end_matches('\r').to_string();
53 line_buf.drain(..=pos);
54
55 if raw.is_empty() {
56 // Blank line — dispatch the accumulated event.
57 if let Some(data) = current_data.take() {
58 let event_type = current_event
59 .take()
60 .unwrap_or_else(|| "detection".to_string());
61 match serde_json::from_str::<HashMap<String, serde_json::Value>>(
62 &data,
63 ) {
64 Ok(parsed) => yield Ok(ShieldStreamEvent {
65 r#type: event_type,
66 data: parsed,
67 }),
68 Err(e) => {
69 yield Err(ShieldError::Deserialisation(e));
70 return;
71 }
72 }
73 } else {
74 current_event = None;
75 }
76 } else if let Some(rest) = raw.strip_prefix("data:") {
77 current_data = Some(rest.trim().to_string());
78 } else if let Some(rest) = raw.strip_prefix("event:") {
79 current_event = Some(rest.trim().to_string());
80 }
81 // Ignore id:, retry:, and comment lines (:...).
82 }
83 }
84 }
85 }
86
87 // Flush any unterminated trailing message (server closed without blank line).
88 if let Some(data) = current_data.take() {
89 let event_type = current_event
90 .take()
91 .unwrap_or_else(|| "detection".to_string());
92 match serde_json::from_str::<HashMap<String, serde_json::Value>>(&data) {
93 Ok(parsed) => yield Ok(ShieldStreamEvent { r#type: event_type, data: parsed }),
94 Err(e) => yield Err(ShieldError::Deserialisation(e)),
95 }
96 }
97 }
98}