// cortexai_cloudflare/streaming.rs
//! Streaming response support for Cloudflare Workers

use cortexai_llm_client::Provider;
use serde::{Deserialize, Serialize};

6/// A chunk from a streaming response
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct StreamChunk {
9    /// Text content delta
10    pub content: String,
11    /// Whether this is the final chunk
12    pub done: bool,
13    /// Token usage (only on final chunk)
14    pub usage: Option<StreamUsage>,
15}
16
17/// Token usage information
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StreamUsage {
20    pub prompt_tokens: u32,
21    pub completion_tokens: u32,
22    pub total_tokens: u32,
23}
24
25/// Iterator over streaming response chunks
26pub struct StreamingResponse {
27    buffer: String,
28    provider: Provider,
29    done: bool,
30}
31
32impl StreamingResponse {
33    /// Create a new streaming response parser
34    pub fn new(provider: Provider) -> Self {
35        Self {
36            buffer: String::new(),
37            provider,
38            done: false,
39        }
40    }
41
42    /// Process incoming data and extract chunks
43    pub fn process(&mut self, data: &str) -> Vec<StreamChunk> {
44        self.buffer.push_str(data);
45        let mut chunks = Vec::new();
46
47        // Process complete SSE lines
48        while let Some(pos) = self.buffer.find("\n\n") {
49            let line = self.buffer[..pos].to_string();
50            self.buffer = self.buffer[pos + 2..].to_string();
51
52            if let Some(chunk) = self.parse_sse_line(&line) {
53                if chunk.done {
54                    self.done = true;
55                }
56                chunks.push(chunk);
57            }
58        }
59
60        chunks
61    }
62
63    /// Check if the stream is complete
64    pub fn is_done(&self) -> bool {
65        self.done
66    }
67
68    /// Parse a single SSE line
69    fn parse_sse_line(&self, line: &str) -> Option<StreamChunk> {
70        // Handle data: prefix
71        let data = line.strip_prefix("data: ")?;
72
73        // Handle [DONE] marker
74        if data.trim() == "[DONE]" {
75            return Some(StreamChunk {
76                content: String::new(),
77                done: true,
78                usage: None,
79            });
80        }
81
82        // Parse JSON based on provider
83        match self.provider {
84            Provider::OpenAI | Provider::OpenRouter => self.parse_openai_chunk(data),
85            Provider::Anthropic => self.parse_anthropic_chunk(data),
86        }
87    }
88
89    /// Parse OpenAI/OpenRouter format chunk
90    fn parse_openai_chunk(&self, data: &str) -> Option<StreamChunk> {
91        #[derive(Deserialize)]
92        struct OpenAiChunk {
93            choices: Vec<OpenAiChoice>,
94            usage: Option<OpenAiUsage>,
95        }
96
97        #[derive(Deserialize)]
98        struct OpenAiChoice {
99            delta: OpenAiDelta,
100            finish_reason: Option<String>,
101        }
102
103        #[derive(Deserialize)]
104        struct OpenAiDelta {
105            content: Option<String>,
106        }
107
108        #[derive(Deserialize)]
109        struct OpenAiUsage {
110            prompt_tokens: u32,
111            completion_tokens: u32,
112            total_tokens: u32,
113        }
114
115        let chunk: OpenAiChunk = serde_json::from_str(data).ok()?;
116        let choice = chunk.choices.first()?;
117
118        let done = choice.finish_reason.is_some();
119        let content = choice.delta.content.clone().unwrap_or_default();
120        let usage = chunk.usage.map(|u| StreamUsage {
121            prompt_tokens: u.prompt_tokens,
122            completion_tokens: u.completion_tokens,
123            total_tokens: u.total_tokens,
124        });
125
126        Some(StreamChunk {
127            content,
128            done,
129            usage,
130        })
131    }
132
133    /// Parse Anthropic format chunk
134    fn parse_anthropic_chunk(&self, data: &str) -> Option<StreamChunk> {
135        #[derive(Deserialize)]
136        struct AnthropicEvent {
137            #[serde(rename = "type")]
138            event_type: String,
139            delta: Option<AnthropicDelta>,
140            usage: Option<AnthropicUsage>,
141        }
142
143        #[derive(Deserialize)]
144        struct AnthropicDelta {
145            #[serde(rename = "type")]
146            _delta_type: Option<String>,
147            text: Option<String>,
148        }
149
150        #[derive(Deserialize)]
151        struct AnthropicUsage {
152            input_tokens: u32,
153            output_tokens: u32,
154        }
155
156        let event: AnthropicEvent = serde_json::from_str(data).ok()?;
157
158        match event.event_type.as_str() {
159            "content_block_delta" => {
160                let text = event.delta.and_then(|d| d.text).unwrap_or_default();
161                Some(StreamChunk {
162                    content: text,
163                    done: false,
164                    usage: None,
165                })
166            }
167            "message_stop" => Some(StreamChunk {
168                content: String::new(),
169                done: true,
170                usage: event.usage.map(|u| StreamUsage {
171                    prompt_tokens: u.input_tokens,
172                    completion_tokens: u.output_tokens,
173                    total_tokens: u.input_tokens + u.output_tokens,
174                }),
175            }),
176            _ => None,
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_openai_delta() {
        let mut parser = StreamingResponse::new(Provider::OpenAI);

        let payload =
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
        let chunks = parser.process(payload);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "Hello");
        assert!(!chunks[0].done);
    }

    #[test]
    fn recognizes_done_marker() {
        let mut parser = StreamingResponse::new(Provider::OpenAI);

        let chunks = parser.process("data: [DONE]\n\n");

        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].done);
        assert!(parser.is_done());
    }

    #[test]
    fn parses_anthropic_delta() {
        let mut parser = StreamingResponse::new(Provider::Anthropic);

        let payload =
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n";
        let chunks = parser.process(payload);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "Hi");
        assert!(!chunks[0].done);
    }
}