Skip to main content

cersei_provider/
stream.rs

//! Stream accumulator: collects SSE stream events into a complete response.
2
3use cersei_types::*;
4use std::collections::HashMap;
5
6/// Accumulates streaming events into content blocks.
7pub struct StreamAccumulator {
8    content_blocks: Vec<ContentBlock>,
9    partial_text: HashMap<usize, String>,
10    partial_json: HashMap<usize, String>,
11    partial_thinking: HashMap<usize, String>,
12    block_types: HashMap<usize, String>,
13    tool_use_ids: HashMap<usize, String>,
14    tool_use_names: HashMap<usize, String>,
15    stop_reason: Option<StopReason>,
16    usage: Usage,
17    model: Option<String>,
18    message_id: Option<String>,
19}
20
21impl StreamAccumulator {
22    pub fn new() -> Self {
23        Self {
24            content_blocks: Vec::new(),
25            partial_text: HashMap::new(),
26            partial_json: HashMap::new(),
27            partial_thinking: HashMap::new(),
28            block_types: HashMap::new(),
29            tool_use_ids: HashMap::new(),
30            tool_use_names: HashMap::new(),
31            stop_reason: None,
32            usage: Usage::default(),
33            model: None,
34            message_id: None,
35        }
36    }
37
38    pub fn process_event(&mut self, event: StreamEvent) {
39        match event {
40            StreamEvent::MessageStart { id, model } => {
41                self.message_id = Some(id);
42                self.model = Some(model);
43            }
44            StreamEvent::ContentBlockStart {
45                index,
46                block_type,
47                id,
48                name,
49            } => {
50                self.block_types.insert(index, block_type);
51                if let Some(id) = id {
52                    self.tool_use_ids.insert(index, id);
53                }
54                if let Some(name) = name {
55                    self.tool_use_names.insert(index, name);
56                }
57            }
58            StreamEvent::TextDelta { index, text } => {
59                self.partial_text.entry(index).or_default().push_str(&text);
60            }
61            StreamEvent::InputJsonDelta {
62                index,
63                partial_json,
64            } => {
65                self.partial_json
66                    .entry(index)
67                    .or_default()
68                    .push_str(&partial_json);
69            }
70            StreamEvent::ThinkingDelta { index, thinking } => {
71                self.partial_thinking
72                    .entry(index)
73                    .or_default()
74                    .push_str(&thinking);
75            }
76            StreamEvent::ContentBlockStop { index } => {
77                let block_type = self.block_types.get(&index).cloned().unwrap_or_default();
78                let block = match block_type.as_str() {
79                    "text" => ContentBlock::Text {
80                        text: self.partial_text.remove(&index).unwrap_or_default(),
81                    },
82                    "tool_use" => {
83                        let json_str = self.partial_json.remove(&index).unwrap_or_default();
84                        let input =
85                            serde_json::from_str(&json_str).unwrap_or(serde_json::Value::Null);
86                        ContentBlock::ToolUse {
87                            id: self.tool_use_ids.remove(&index).unwrap_or_default(),
88                            name: self.tool_use_names.remove(&index).unwrap_or_default(),
89                            input,
90                        }
91                    }
92                    "thinking" => ContentBlock::Thinking {
93                        thinking: self.partial_thinking.remove(&index).unwrap_or_default(),
94                        signature: String::new(),
95                    },
96                    _ => ContentBlock::Text {
97                        text: self.partial_text.remove(&index).unwrap_or_default(),
98                    },
99                };
100                // Ensure we have enough slots
101                while self.content_blocks.len() <= index {
102                    self.content_blocks.push(ContentBlock::Text {
103                        text: String::new(),
104                    });
105                }
106                self.content_blocks[index] = block;
107            }
108            StreamEvent::MessageDelta { stop_reason, usage } => {
109                if let Some(sr) = stop_reason {
110                    self.stop_reason = Some(sr);
111                }
112                if let Some(u) = usage {
113                    self.usage.merge(&u);
114                }
115            }
116            StreamEvent::MessageStop => {}
117            StreamEvent::Ping => {}
118            StreamEvent::Error { .. } => {}
119        }
120    }
121
122    pub fn into_response(self) -> Result<super::CompletionResponse> {
123        let message = Message {
124            role: Role::Assistant,
125            content: if self.content_blocks.is_empty() {
126                MessageContent::Text(String::new())
127            } else {
128                MessageContent::Blocks(self.content_blocks)
129            },
130            id: self.message_id,
131            metadata: Some(MessageMetadata {
132                model: self.model,
133                usage: Some(self.usage.clone()),
134                stop_reason: self.stop_reason.clone(),
135                provider_data: serde_json::Value::Null,
136            }),
137        };
138
139        Ok(super::CompletionResponse {
140            message,
141            usage: self.usage,
142            stop_reason: self.stop_reason.unwrap_or(StopReason::EndTurn),
143        })
144    }
145
146    /// Get accumulated text so far (for streaming display).
147    pub fn current_text(&self) -> String {
148        self.partial_text.values().cloned().collect()
149    }
150}
151
152impl Default for StreamAccumulator {
153    fn default() -> Self {
154        Self::new()
155    }
156}