oxi_agent/agent_loop/
streaming.rs1use anyhow::{Error, Result};
9use futures::StreamExt;
10use oxi_ai::{
11 ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
12};
13use std::collections::HashSet;
14
15pub(crate) async fn stream_assistant_response(
16 loop_ref: &super::AgentLoop,
17 messages: &mut Vec<Message>,
18 emit: &super::EmitFn,
19) -> Result<oxi_ai::AssistantMessage> {
20 let model = loop_ref.resolve_model()?;
21
22 let mut context = Context::new();
23
24 if let Some(ref system_prompt) = loop_ref.config.system_prompt {
25 context.set_system_prompt(system_prompt.clone());
26 }
27
28 for msg in messages.iter() {
29 context.add_message(msg.clone());
30 }
31
32 let tool_defs = loop_ref.tools.definitions();
34 if !tool_defs.is_empty() {
35 let mut oxi_tools = Vec::with_capacity(tool_defs.len());
36 for def in &tool_defs {
37 let schema = serde_json::to_value(&def.input_schema)
38 .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
39 oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
40 }
41 context.set_tools(oxi_tools);
42 }
43
44 let stream_options = StreamOptions {
45 temperature: Some(loop_ref.config.temperature as f64),
46 max_tokens: Some(loop_ref.config.max_tokens as usize),
47 api_key: loop_ref.config.api_key.clone(),
48 ..Default::default()
49 };
50
51 let stream =
52 super::retry::stream_with_retry(loop_ref, &model, &context, Some(stream_options), emit)
53 .await?;
54
55 let mut added_partial = false;
59 let mut event_count = 0u32;
60
61 let mut rx = stream;
62 while let Some(event) = rx.next().await {
63 event_count += 1;
64 match event {
65 ProviderEvent::Start { partial } => {
66 tracing::info!("Stream event #{}: Start", event_count);
67 messages.push(Message::Assistant(partial));
68 added_partial = true;
69 emit(super::AgentEvent::MessageStart {
70 message: messages.last().expect("non-empty after push").clone(),
71 });
72 }
73
74 ProviderEvent::TextDelta { delta, partial, .. } => {
75 if added_partial {
78 let last_idx = messages.len() - 1;
79 if let Message::Assistant(ref mut m) = messages[last_idx] {
80 *m = partial;
81 }
82 }
83 let last_msg = messages.last().expect("non-empty").clone();
84 emit(super::AgentEvent::MessageUpdate {
85 message: last_msg,
86 delta: Some(delta),
87 });
88 }
89
90 ProviderEvent::ThinkingStart { partial, .. }
91 if added_partial => {
94 let last_idx = messages.len() - 1;
95 if let Message::Assistant(ref mut m) = messages[last_idx] {
96 *m = partial;
97 }
98 }
99
100 ProviderEvent::ThinkingDelta { delta, partial, .. } => {
101 if added_partial {
102 let last_idx = messages.len() - 1;
103 if let Message::Assistant(ref mut m) = messages[last_idx] {
104 *m = partial;
105 }
106 }
107 let last_msg = messages.last().expect("non-empty").clone();
108 emit(super::AgentEvent::MessageUpdate {
109 message: last_msg,
110 delta: Some(delta),
111 });
112 }
113
114 ProviderEvent::ToolCallStart { partial, .. }
115 if added_partial => {
116 let last_idx = messages.len() - 1;
117 if let Message::Assistant(ref mut m) = messages[last_idx] {
118 *m = partial;
119 }
120 }
121
122 ProviderEvent::ToolCallDelta { partial, .. }
123 if added_partial => {
124 let last_idx = messages.len() - 1;
125 if let Message::Assistant(ref mut m) = messages[last_idx] {
126 *m = partial;
127 }
128 }
129
130 ProviderEvent::ToolCallEnd { tool_call, .. }
131 if added_partial => {
133 let last_idx = messages.len() - 1;
134 if let Message::Assistant(ref mut m) = messages[last_idx] {
135 m.content.push(ContentBlock::ToolCall(tool_call));
136 }
137 let last_msg = messages.last().expect("non-empty").clone();
141 emit(super::AgentEvent::MessageUpdate {
142 message: last_msg,
143 delta: None,
144 });
145 }
146
147 ProviderEvent::Done { message, .. } => {
148 tracing::info!(
149 "Stream event #{}: Done (stop_reason={:?})",
150 event_count,
151 message.stop_reason
152 );
153 if added_partial {
154 let last_idx = messages.len() - 1;
155 if let Message::Assistant(ref mut m) = messages[last_idx] {
156 let mut preserved_tool_calls: Vec<ContentBlock> = m
160 .content
161 .drain(..)
162 .filter(|b| matches!(b, ContentBlock::ToolCall(_)))
163 .collect();
164
165 let mut seen: HashSet<String> = message
166 .content
167 .iter()
168 .filter_map(|b| match b {
169 ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
170 _ => None,
171 })
172 .collect();
173
174 preserved_tool_calls.retain(|b| match b {
175 ContentBlock::ToolCall(tc) => seen.insert(tc.id.clone()),
176 _ => true,
177 });
178
179 tracing::info!(
180 "Done: preserving {} tool_calls (deduped), Done message has {} content blocks",
181 preserved_tool_calls.len(),
182 message.content.len()
183 );
184
185 *m = message.clone();
186 m.content.extend(preserved_tool_calls);
187 tracing::info!(
188 "Done: final message has {} content blocks, stop_reason={:?}",
189 m.content.len(),
190 m.stop_reason
191 );
192 }
193 } else {
194 messages.push(Message::Assistant(message.clone()));
195 }
196 let last_msg = messages.last().expect("non-empty").clone();
197 emit(super::AgentEvent::MessageEnd {
198 message: last_msg.clone(),
199 });
200 if let Message::Assistant(m) = &last_msg {
202 return Ok(m.clone());
203 } else {
204 return Ok(message);
205 }
206 }
207
208 ProviderEvent::Error { mut error, .. } => {
209 tracing::info!("Stream event #{}: Error", event_count);
210 let raw_msg = error.text_content();
211 let friendly = if raw_msg.is_empty() {
212 "Unknown provider error".to_string()
213 } else {
214 raw_msg
215 };
216 tracing::error!(session_id = ?loop_ref.session_id, "Provider stream error: {}", friendly);
217
218 error.stop_reason = StopReason::Error;
219
220 if added_partial {
221 let last_idx = messages.len() - 1;
222 if let Message::Assistant(ref mut m) = messages[last_idx] {
223 *m = error.clone();
224 }
225 } else {
226 messages.push(Message::Assistant(error.clone()));
227 }
228
229 emit(super::AgentEvent::MessageEnd {
230 message: Message::Assistant(error.clone()),
231 });
232 emit(super::AgentEvent::Error {
233 message: format!("⚠ {}", friendly),
234 session_id: loop_ref.session_id.clone(),
235 });
236
237 return Ok(error);
238 }
239
240 _ => {}
241 }
242 }
243
244 tracing::info!("Stream ended after {} events", event_count);
245
246 let final_message = messages
247 .last()
248 .and_then(|m| match m {
249 Message::Assistant(a) => Some(a.clone()),
250 _ => None,
251 })
252 .ok_or_else(|| Error::msg("No assistant message in context"))?;
253
254 emit(super::AgentEvent::MessageEnd {
255 message: Message::Assistant(final_message.clone()),
256 });
257 Ok(final_message)
258}