// cersei_provider/stream.rs
use std::collections::HashMap;

use cersei_types::*;
6pub struct StreamAccumulator {
8 content_blocks: Vec<ContentBlock>,
9 partial_text: HashMap<usize, String>,
10 partial_json: HashMap<usize, String>,
11 partial_thinking: HashMap<usize, String>,
12 block_types: HashMap<usize, String>,
13 tool_use_ids: HashMap<usize, String>,
14 tool_use_names: HashMap<usize, String>,
15 stop_reason: Option<StopReason>,
16 usage: Usage,
17 model: Option<String>,
18 message_id: Option<String>,
19}
20
21impl StreamAccumulator {
22 pub fn new() -> Self {
23 Self {
24 content_blocks: Vec::new(),
25 partial_text: HashMap::new(),
26 partial_json: HashMap::new(),
27 partial_thinking: HashMap::new(),
28 block_types: HashMap::new(),
29 tool_use_ids: HashMap::new(),
30 tool_use_names: HashMap::new(),
31 stop_reason: None,
32 usage: Usage::default(),
33 model: None,
34 message_id: None,
35 }
36 }
37
38 pub fn process_event(&mut self, event: StreamEvent) {
39 match event {
40 StreamEvent::MessageStart { id, model } => {
41 self.message_id = Some(id);
42 self.model = Some(model);
43 }
44 StreamEvent::ContentBlockStart {
45 index,
46 block_type,
47 id,
48 name,
49 } => {
50 self.block_types.insert(index, block_type);
51 if let Some(id) = id {
52 self.tool_use_ids.insert(index, id);
53 }
54 if let Some(name) = name {
55 self.tool_use_names.insert(index, name);
56 }
57 }
58 StreamEvent::TextDelta { index, text } => {
59 self.partial_text.entry(index).or_default().push_str(&text);
60 }
61 StreamEvent::InputJsonDelta {
62 index,
63 partial_json,
64 } => {
65 self.partial_json
66 .entry(index)
67 .or_default()
68 .push_str(&partial_json);
69 }
70 StreamEvent::ThinkingDelta { index, thinking } => {
71 self.partial_thinking
72 .entry(index)
73 .or_default()
74 .push_str(&thinking);
75 }
76 StreamEvent::ContentBlockStop { index } => {
77 let block_type = self.block_types.get(&index).cloned().unwrap_or_default();
78 let block = match block_type.as_str() {
79 "text" => ContentBlock::Text {
80 text: self.partial_text.remove(&index).unwrap_or_default(),
81 },
82 "tool_use" => {
83 let json_str = self.partial_json.remove(&index).unwrap_or_default();
84 let input =
85 serde_json::from_str(&json_str).unwrap_or(serde_json::Value::Null);
86 ContentBlock::ToolUse {
87 id: self.tool_use_ids.remove(&index).unwrap_or_default(),
88 name: self.tool_use_names.remove(&index).unwrap_or_default(),
89 input,
90 }
91 }
92 "thinking" => ContentBlock::Thinking {
93 thinking: self.partial_thinking.remove(&index).unwrap_or_default(),
94 signature: String::new(),
95 },
96 _ => ContentBlock::Text {
97 text: self.partial_text.remove(&index).unwrap_or_default(),
98 },
99 };
100 while self.content_blocks.len() <= index {
102 self.content_blocks.push(ContentBlock::Text {
103 text: String::new(),
104 });
105 }
106 self.content_blocks[index] = block;
107 }
108 StreamEvent::MessageDelta { stop_reason, usage } => {
109 if let Some(sr) = stop_reason {
110 self.stop_reason = Some(sr);
111 }
112 if let Some(u) = usage {
113 self.usage.merge(&u);
114 }
115 }
116 StreamEvent::MessageStop => {}
117 StreamEvent::Ping => {}
118 StreamEvent::Error { .. } => {}
119 }
120 }
121
122 pub fn into_response(self) -> Result<super::CompletionResponse> {
123 let message = Message {
124 role: Role::Assistant,
125 content: if self.content_blocks.is_empty() {
126 MessageContent::Text(String::new())
127 } else {
128 MessageContent::Blocks(self.content_blocks)
129 },
130 id: self.message_id,
131 metadata: Some(MessageMetadata {
132 model: self.model,
133 usage: Some(self.usage.clone()),
134 stop_reason: self.stop_reason.clone(),
135 provider_data: serde_json::Value::Null,
136 }),
137 };
138
139 Ok(super::CompletionResponse {
140 message,
141 usage: self.usage,
142 stop_reason: self.stop_reason.unwrap_or(StopReason::EndTurn),
143 })
144 }
145
146 pub fn current_text(&self) -> String {
148 self.partial_text.values().cloned().collect()
149 }
150}
151
152impl Default for StreamAccumulator {
153 fn default() -> Self {
154 Self::new()
155 }
156}