1use futures::StreamExt;
12use oxi_ai::{
13 ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
14};
15use std::collections::HashSet;
16
17use super::helpers::sanitize_orphaned_tool_results;
18use super::stream_outcome::StreamOutcome;
19use super::ttsr::{MatchSource, TtsrEngine, TtsrMatchContext};
20
21pub(crate) async fn stream_assistant_response(
22 loop_ref: &super::AgentLoop,
23 messages: &mut Vec<Message>,
24 emit: &super::EmitFn,
25 ttsr: Option<&TtsrEngine>,
26) -> StreamOutcome {
27 let model = match loop_ref.resolve_model() {
28 Ok(m) => m,
29 Err(_) => {
30 return StreamOutcome::Error {
31 message: oxi_ai::AssistantMessage::new(
32 oxi_ai::Api::OpenAiCompletions,
33 "agent",
34 &loop_ref.config.model_id,
35 ),
36 detail: "Failed to resolve model".to_string(),
37 };
38 }
39 };
40
41 let removed = sanitize_orphaned_tool_results(messages);
45 if removed > 0 {
46 tracing::warn!(
47 session_id = ?loop_ref.session_id,
48 removed,
49 "Sanitized orphaned tool results before streaming"
50 );
51 }
52
53 let mut context = Context::new();
54
55 if let Some(ref system_prompt) = loop_ref.config.system_prompt {
56 context.set_system_prompt(system_prompt.clone());
57 }
58
59 for msg in messages.iter() {
60 context.add_message(msg.clone());
61 }
62
63 let tool_defs = loop_ref.tools.definitions();
64 if !tool_defs.is_empty() {
65 let mut oxi_tools = Vec::with_capacity(tool_defs.len());
66 for def in &tool_defs {
67 let schema = serde_json::to_value(&def.input_schema)
68 .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
69 oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
70 }
71 context.set_tools(oxi_tools);
72 }
73
74 let stream_options = StreamOptions {
75 temperature: Some(loop_ref.config.temperature as f64),
76 max_tokens: Some(loop_ref.config.max_tokens as usize),
77 api_key: loop_ref.config.api_key.clone(),
78 provider_options: loop_ref.config.provider_options.clone(),
79 ..Default::default()
80 };
81
82 let stream = match super::retry::stream_with_retry(
83 loop_ref,
84 &model,
85 &context,
86 Some(stream_options),
87 emit,
88 )
89 .await
90 {
91 Ok(s) => s,
92 Err(e) => {
93 return StreamOutcome::Error {
94 message: oxi_ai::AssistantMessage::new(
95 oxi_ai::Api::OpenAiCompletions,
96 "agent",
97 &loop_ref.config.model_id,
98 ),
99 detail: e.to_string(),
100 };
101 }
102 };
103
104 let mut added_partial = false;
105 let mut event_count = 0u32;
106
107 let mut rx = stream;
108 let stream_idle_timeout = std::time::Duration::from_secs(30);
109 let cancel_check_interval = std::time::Duration::from_millis(500);
110 let mut last_event_at = std::time::Instant::now();
111
112 loop {
113 let next_event = tokio::select! {
114 event = rx.next() => event,
115 _ = tokio::time::sleep(cancel_check_interval) => {
116 if loop_ref.is_cancelled() {
117 tracing::info!(
118 "Stream cancelled (detected in periodic check)"
119 );
120 if added_partial {
121 let last_idx = messages.len() - 1;
122 if let Message::Assistant(ref mut m) = messages[last_idx] {
123 m.stop_reason = StopReason::Aborted;
124 }
125 let last_msg = messages.last().expect("non-empty").clone();
126 emit(super::AgentEvent::MessageEnd {
127 message: last_msg.clone(),
128 });
129 if let Message::Assistant(m) = &last_msg {
130 return StreamOutcome::Cancelled(m.clone());
131 }
132 }
133 return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
134 oxi_ai::Api::OpenAiCompletions,
135 "agent",
136 &loop_ref.config.model_id,
137 ));
138 }
139
140 if last_event_at.elapsed() >= stream_idle_timeout {
141 tracing::warn!(
142 "Stream idle timeout ({:?}) reached after {} events",
143 stream_idle_timeout, event_count
144 );
145 let mut err_asst = oxi_ai::AssistantMessage::new(
146 oxi_ai::Api::OpenAiCompletions,
147 "agent",
148 &loop_ref.config.model_id,
149 );
150 err_asst.stop_reason = StopReason::Error;
151 err_asst.error_message = Some(format!(
152 "Stream timed out after {:?} of inactivity",
153 stream_idle_timeout
154 ));
155 if added_partial {
156 let last_idx = messages.len() - 1;
157 if let Message::Assistant(ref mut m) = messages[last_idx] {
158 m.stop_reason = StopReason::Error;
159 }
160 }
161 emit(super::AgentEvent::MessageEnd {
162 message: Message::Assistant(err_asst.clone()),
163 });
164 emit(super::AgentEvent::Error {
165 message: format!(
166 "Stream timed out after {:?} of inactivity",
167 stream_idle_timeout
168 ),
169 session_id: loop_ref.session_id.clone(),
170 });
171 return StreamOutcome::Error { message: err_asst, detail: format!("Stream timed out after {:?} of inactivity", stream_idle_timeout) };
172 }
173
174 continue;
175 }
176 };
177
178 let event = match next_event {
179 Some(e) => e,
180 None => break,
181 };
182
183 last_event_at = std::time::Instant::now();
184
185 if loop_ref.is_cancelled() {
186 tracing::info!("Stream cancelled after {} events", event_count);
187 if added_partial {
188 let last_idx = messages.len() - 1;
189 if let Message::Assistant(ref mut m) = messages[last_idx] {
190 m.stop_reason = StopReason::Aborted;
191 }
192 let last_msg = messages.last().expect("non-empty").clone();
193 emit(super::AgentEvent::MessageEnd {
194 message: last_msg.clone(),
195 });
196 if let Message::Assistant(m) = &last_msg {
197 return StreamOutcome::Cancelled(m.clone());
198 }
199 }
200 return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
201 oxi_ai::Api::OpenAiCompletions,
202 "agent",
203 &loop_ref.config.model_id,
204 ));
205 }
206
207 event_count += 1;
208 match event {
209 ProviderEvent::Start { partial } => {
210 tracing::info!("Stream event #{}: Start", event_count);
211 messages.push(Message::Assistant((*partial).clone()));
212 added_partial = true;
213 emit(super::AgentEvent::MessageStart {
214 message: messages.last().expect("non-empty after push").clone(),
215 });
216 }
217
218 ProviderEvent::FallbackStart {
219 from_model,
220 to_model,
221 ..
222 } => {
223 tracing::info!(
224 "Stream event #{}: Fallback from {} to {}",
225 event_count,
226 from_model,
227 to_model
228 );
229 emit(super::AgentEvent::Fallback {
230 from_model,
231 to_model,
232 });
233 }
234
235 ProviderEvent::FallbackExhausted {
236 models_tried,
237 final_error,
238 } => {
239 tracing::warn!(
240 "Stream event #{}: All fallback models exhausted. Tried: {:?}, error: {}",
241 event_count,
242 models_tried,
243 final_error
244 );
245 if let Some(last_model) = models_tried.last() {
246 emit(super::AgentEvent::Fallback {
247 from_model: last_model.clone(),
248 to_model: "none".to_string(),
249 });
250 }
251 }
252
253 ProviderEvent::TextDelta { delta, partial, .. } => {
254 if added_partial {
255 let last_idx = messages.len() - 1;
256 if let Message::Assistant(ref mut m) = messages[last_idx] {
257 *m = (*partial).clone();
258 }
259 }
260 let last_msg = messages.last().expect("non-empty").clone();
261 let delta_clone = delta.clone();
262 emit(super::AgentEvent::MessageUpdate {
263 message: last_msg,
264 delta: Some(delta),
265 });
266
267 if let Some(engine) = ttsr {
269 let ctx = TtsrMatchContext {
270 source: MatchSource::Text,
271 file_paths: vec![],
272 tool_name: None,
273 };
274 let violations = engine.check_delta(&delta_clone, &ctx);
275 if !violations.is_empty() {
276 let mut partial_msg = messages
277 .last()
278 .and_then(|m| match m {
279 Message::Assistant(a) => Some(a.clone()),
280 _ => None,
281 })
282 .unwrap_or_else(|| {
283 oxi_ai::AssistantMessage::new(
284 oxi_ai::Api::OpenAiCompletions,
285 "agent",
286 &loop_ref.config.model_id,
287 )
288 });
289 partial_msg.stop_reason = StopReason::Aborted;
290 return StreamOutcome::RuleInterrupt {
291 partial: partial_msg,
292 rule: violations.into_iter().next().expect("non-empty"),
293 };
294 }
295 }
296 }
297
298 ProviderEvent::ThinkingStart { partial, .. } if added_partial => {
299 let last_idx = messages.len() - 1;
300 if let Message::Assistant(ref mut m) = messages[last_idx] {
301 *m = (*partial).clone();
302 }
303 }
304
305 ProviderEvent::ThinkingDelta { delta, partial, .. } => {
306 if added_partial {
307 let last_idx = messages.len() - 1;
308 if let Message::Assistant(ref mut m) = messages[last_idx] {
309 *m = (*partial).clone();
310 }
311 }
312 let last_msg = messages.last().expect("non-empty").clone();
313 emit(super::AgentEvent::MessageUpdate {
314 message: last_msg,
315 delta: Some(delta),
316 });
317 }
318
319 ProviderEvent::ToolCallStart { partial, .. } if added_partial => {
320 let last_idx = messages.len() - 1;
321 if let Message::Assistant(ref mut m) = messages[last_idx] {
322 *m = (*partial).clone();
323 }
324 }
325
326 ProviderEvent::ToolCallDelta { partial, .. } if added_partial => {
327 let last_idx = messages.len() - 1;
328 if let Message::Assistant(ref mut m) = messages[last_idx] {
329 *m = (*partial).clone();
330 }
331 }
332
333 ProviderEvent::ToolCallEnd { tool_call, .. } if added_partial => {
334 let last_idx = messages.len() - 1;
335 if let Message::Assistant(ref mut m) = messages[last_idx] {
336 m.content.push(ContentBlock::ToolCall(tool_call));
337 }
338 let last_msg = messages.last().expect("non-empty").clone();
339 emit(super::AgentEvent::MessageUpdate {
340 message: last_msg,
341 delta: None,
342 });
343 }
344
345 ProviderEvent::Done { message, .. } => {
346 loop_ref.circuit_breaker.record_success();
347
348 let (input, output) = (message.usage.input, message.usage.output);
349 if input > 0 || output > 0 {
350 loop_ref.state.update(|s| {
351 s.record_usage(input, output);
352 });
353 emit(super::AgentEvent::Usage {
354 input_tokens: input,
355 output_tokens: output,
356 });
357 }
358
359 tracing::info!(
360 "Stream event #{}: Done (stop_reason={:?})",
361 event_count,
362 message.stop_reason
363 );
364
365 if added_partial {
366 let last_idx = messages.len() - 1;
367 if let Message::Assistant(ref mut m) = messages[last_idx] {
368 let mut seen_ids: HashSet<String> = message
369 .content
370 .iter()
371 .filter_map(|b| match b {
372 ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
373 _ => None,
374 })
375 .collect();
376
377 let extra_tool_calls: Vec<ContentBlock> = m
378 .content
379 .iter()
380 .filter(|b| match b {
381 ContentBlock::ToolCall(tc) => seen_ids.insert(tc.id.clone()),
382 _ => false,
383 })
384 .cloned()
385 .collect();
386
387 let tc_count = extra_tool_calls.len();
388 *m = message.clone();
389 m.content.extend(extra_tool_calls);
390
391 tracing::info!(
392 "Done: merged {} extra tool_calls, final has {} content blocks, stop_reason={:?}",
393 tc_count,
394 m.content.len(),
395 m.stop_reason
396 );
397 }
398 } else {
399 messages.push(Message::Assistant(message.clone()));
400 }
401 let last_msg = messages.last().expect("non-empty").clone();
402 emit(super::AgentEvent::MessageEnd {
403 message: last_msg.clone(),
404 });
405 if let Message::Assistant(m) = &last_msg {
406 return StreamOutcome::Complete(m.clone());
407 } else {
408 return StreamOutcome::Complete(message);
409 }
410 }
411
412 ProviderEvent::Error { mut error, .. } => {
413 loop_ref.circuit_breaker.record_failure();
414
415 tracing::info!("Stream event #{}: Error", event_count);
416 let raw_msg = error.text_content();
417 let friendly = if raw_msg.is_empty() {
418 "Unknown provider error".to_string()
419 } else {
420 raw_msg
421 };
422 tracing::error!(
423 session_id = ?loop_ref.session_id,
424 "Provider stream error: {}", friendly
425 );
426
427 error.stop_reason = StopReason::Error;
428
429 if added_partial {
430 let last_idx = messages.len() - 1;
431 if let Message::Assistant(ref mut m) = messages[last_idx] {
432 *m = error.clone();
433 }
434 } else {
435 messages.push(Message::Assistant(error.clone()));
436 }
437
438 emit(super::AgentEvent::MessageEnd {
439 message: Message::Assistant(error.clone()),
440 });
441 emit(super::AgentEvent::Error {
442 message: format!("⚠ {}", friendly),
443 session_id: loop_ref.session_id.clone(),
444 });
445
446 return StreamOutcome::Error {
447 message: error,
448 detail: format!("⚠ {}", friendly),
449 };
450 }
451
452 _ => {}
453 }
454 }
455
456 tracing::info!("Stream ended after {} events", event_count);
457
458 let final_message = match messages.last().and_then(|m| match m {
459 Message::Assistant(a) => Some(a.clone()),
460 _ => None,
461 }) {
462 Some(m) => m,
463 None => {
464 return StreamOutcome::Error {
465 message: oxi_ai::AssistantMessage::new(
466 oxi_ai::Api::OpenAiCompletions,
467 "agent",
468 &loop_ref.config.model_id,
469 ),
470 detail: "No final assistant message in stream".to_string(),
471 };
472 }
473 };
474
475 if !added_partial {
476 tracing::warn!("Stream ended without Start event, emitting synthetic MessageStart");
477 emit(super::AgentEvent::MessageStart {
478 message: Message::Assistant(final_message.clone()),
479 });
480 }
481
482 emit(super::AgentEvent::MessageEnd {
483 message: Message::Assistant(final_message.clone()),
484 });
485 StreamOutcome::Complete(final_message)
486}