1use futures::StreamExt;
12use oxi_ai::{
13 ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
14};
15use std::collections::HashSet;
16
17use super::helpers::sanitize_orphaned_tool_results;
18use super::stream_outcome::StreamOutcome;
19use super::ttsr::{MatchSource, TtsrEngine, TtsrMatchContext};
20
21pub(crate) async fn stream_assistant_response(
22 loop_ref: &super::AgentLoop,
23 messages: &mut Vec<Message>,
24 emit: &super::EmitFn,
25 ttsr: Option<&TtsrEngine>,
26) -> StreamOutcome {
27 let model = match loop_ref.resolve_model() {
28 Ok(m) => m,
29 Err(_) => {
30 return StreamOutcome::Error {
31 message: oxi_ai::AssistantMessage::new(
32 oxi_ai::Api::OpenAiCompletions,
33 "agent",
34 &loop_ref.config.model_id,
35 ),
36 detail: "Failed to resolve model".to_string(),
37 };
38 }
39 };
40
41 let removed = sanitize_orphaned_tool_results(messages);
45 if removed > 0 {
46 tracing::warn!(
47 session_id = ?loop_ref.session_id,
48 removed,
49 "Sanitized orphaned tool results before streaming"
50 );
51 }
52
53 let mut context = Context::new();
54
55 if let Some(ref system_prompt) = loop_ref.config.system_prompt {
56 context.set_system_prompt(system_prompt.clone());
57 }
58
59 for msg in messages.iter() {
60 context.add_message(msg.clone());
61 }
62
63 let tool_defs = loop_ref.tools.definitions();
64 if !tool_defs.is_empty() {
65 let mut oxi_tools = Vec::with_capacity(tool_defs.len());
66 for def in &tool_defs {
67 let schema = serde_json::to_value(&def.input_schema)
68 .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
69 oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
70 }
71 context.set_tools(oxi_tools);
72 }
73
74 let stream_options = StreamOptions {
75 temperature: Some(loop_ref.config.temperature as f64),
76 max_tokens: Some(loop_ref.config.max_tokens as usize),
77 api_key: loop_ref.config.api_key.clone(),
78 provider_options: loop_ref.config.provider_options.clone(),
79 ..Default::default()
80 };
81
82 let stream = match super::retry::stream_with_retry(
83 loop_ref,
84 &model,
85 &context,
86 Some(stream_options),
87 emit,
88 )
89 .await
90 {
91 Ok(s) => s,
92 Err(e) => {
93 return StreamOutcome::Error {
94 message: oxi_ai::AssistantMessage::new(
95 oxi_ai::Api::OpenAiCompletions,
96 "agent",
97 &loop_ref.config.model_id,
98 ),
99 detail: e.to_string(),
100 };
101 }
102 };
103
104 let mut added_partial = false;
105 let mut event_count = 0u32;
106
107 let mut rx = stream;
108 let stream_idle_timeout = std::time::Duration::from_secs(30);
109 let cancel_check_interval = std::time::Duration::from_millis(500);
110 let mut last_event_at = std::time::Instant::now();
111
112 loop {
113 let next_event = tokio::select! {
114 event = rx.next() => event,
115 _ = tokio::time::sleep(cancel_check_interval) => {
116 if loop_ref.is_cancelled() {
117 tracing::info!(
118 "Stream cancelled (detected in periodic check)"
119 );
120 if added_partial {
121 let last_idx = messages.len() - 1;
122 if let Message::Assistant(ref mut m) = messages[last_idx] {
123 m.stop_reason = StopReason::Aborted;
124 }
125 let last_msg = messages.last().expect("non-empty").clone();
126 emit(super::AgentEvent::MessageEnd {
127 message: last_msg.clone(),
128 });
129 if let Message::Assistant(m) = &last_msg {
130 return StreamOutcome::Cancelled(m.clone());
131 }
132 }
133 return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
134 oxi_ai::Api::OpenAiCompletions,
135 "agent",
136 &loop_ref.config.model_id,
137 ));
138 }
139
140 if last_event_at.elapsed() >= stream_idle_timeout {
141 tracing::warn!(
142 "Stream idle timeout ({:?}) reached after {} events",
143 stream_idle_timeout, event_count
144 );
145 let mut err_asst = oxi_ai::AssistantMessage::new(
146 oxi_ai::Api::OpenAiCompletions,
147 "agent",
148 &loop_ref.config.model_id,
149 );
150 err_asst.stop_reason = StopReason::Error;
151 err_asst.error_message = Some(format!(
152 "Stream timed out after {:?} of inactivity",
153 stream_idle_timeout
154 ));
155 if added_partial {
156 let last_idx = messages.len() - 1;
157 if let Message::Assistant(ref mut m) = messages[last_idx] {
158 m.stop_reason = StopReason::Error;
159 }
160 }
161 emit(super::AgentEvent::MessageEnd {
162 message: Message::Assistant(err_asst.clone()),
163 });
164 emit(super::AgentEvent::Error {
165 message: format!(
166 "Stream timed out after {:?} of inactivity",
167 stream_idle_timeout
168 ),
169 session_id: loop_ref.session_id.clone(),
170 });
171 return StreamOutcome::Error { message: err_asst, detail: format!("Stream timed out after {:?} of inactivity", stream_idle_timeout) };
172 }
173
174 continue;
175 }
176 };
177
178 let event = match next_event {
179 Some(e) => e,
180 None => break,
181 };
182
183 last_event_at = std::time::Instant::now();
184
185 if loop_ref.is_cancelled() {
186 tracing::info!("Stream cancelled after {} events", event_count);
187 if added_partial {
188 let last_idx = messages.len() - 1;
189 if let Message::Assistant(ref mut m) = messages[last_idx] {
190 m.stop_reason = StopReason::Aborted;
191 }
192 let last_msg = messages.last().expect("non-empty").clone();
193 emit(super::AgentEvent::MessageEnd {
194 message: last_msg.clone(),
195 });
196 if let Message::Assistant(m) = &last_msg {
197 return StreamOutcome::Cancelled(m.clone());
198 }
199 }
200 return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
201 oxi_ai::Api::OpenAiCompletions,
202 "agent",
203 &loop_ref.config.model_id,
204 ));
205 }
206
207 event_count += 1;
208 match event {
209 ProviderEvent::Start { partial } => {
210 tracing::info!("Stream event #{}: Start", event_count);
211 messages.push(Message::Assistant((*partial).clone()));
212 added_partial = true;
213 emit(super::AgentEvent::MessageStart {
214 message: messages.last().expect("non-empty after push").clone(),
215 });
216 }
217
218 ProviderEvent::FallbackStart {
219 from_model,
220 to_model,
221 ..
222 } => {
223 tracing::info!(
224 "Stream event #{}: Fallback from {} to {}",
225 event_count,
226 from_model,
227 to_model
228 );
229 emit(super::AgentEvent::Fallback {
230 from_model,
231 to_model,
232 });
233 }
234
235 ProviderEvent::FallbackExhausted {
236 models_tried,
237 final_error,
238 } => {
239 tracing::warn!(
240 "Stream event #{}: All fallback models exhausted. Tried: {:?}, error: {}",
241 event_count,
242 models_tried,
243 final_error
244 );
245 if let Some(last_model) = models_tried.last() {
246 emit(super::AgentEvent::Fallback {
247 from_model: last_model.clone(),
248 to_model: "none".to_string(),
249 });
250 }
251 }
252
253 ProviderEvent::TextDelta { delta, partial, .. } => {
254 if added_partial {
255 let last_idx = messages.len() - 1;
256 if let Message::Assistant(ref mut m) = messages[last_idx] {
257 *m = (*partial).clone();
258 }
259 }
260 let last_msg = messages.last().expect("non-empty").clone();
261 let delta_clone = delta.clone();
262 emit(super::AgentEvent::MessageUpdate {
263 message: last_msg,
264 delta: Some(delta),
265 });
266
267 if let Some(engine) = ttsr {
269 let ctx = TtsrMatchContext {
270 source: MatchSource::Text,
271 file_paths: vec![],
272 tool_name: None,
273 };
274 let violations = engine.check_delta(&delta_clone, &ctx);
275 if !violations.is_empty() {
276 let mut partial_msg = messages
277 .last()
278 .and_then(|m| match m {
279 Message::Assistant(a) => Some(a.clone()),
280 _ => None,
281 })
282 .unwrap_or_else(|| {
283 oxi_ai::AssistantMessage::new(
284 oxi_ai::Api::OpenAiCompletions,
285 "agent",
286 &loop_ref.config.model_id,
287 )
288 });
289 partial_msg.stop_reason = StopReason::Aborted;
290 return StreamOutcome::RuleInterrupt {
291 partial: partial_msg,
292 rule: violations.into_iter().next().expect("non-empty"),
293 };
294 }
295 }
296 }
297
298 ProviderEvent::ThinkingStart { partial, .. } if added_partial => {
299 let last_idx = messages.len() - 1;
300 if let Message::Assistant(ref mut m) = messages[last_idx] {
301 *m = (*partial).clone();
302 }
303 emit(super::AgentEvent::Thinking);
304 }
305
306 ProviderEvent::ThinkingDelta { delta, partial, .. } => {
307 if added_partial {
308 let last_idx = messages.len() - 1;
309 if let Message::Assistant(ref mut m) = messages[last_idx] {
310 *m = (*partial).clone();
311 }
312 }
313 let last_msg = messages.last().expect("non-empty").clone();
314 emit(super::AgentEvent::ThinkingDelta {
315 text: delta.clone(),
316 });
317 emit(super::AgentEvent::MessageUpdate {
318 message: last_msg,
319 delta: Some(delta),
320 });
321 }
322
323 ProviderEvent::ToolCallStart { partial, .. } if added_partial => {
324 let last_idx = messages.len() - 1;
325 if let Message::Assistant(ref mut m) = messages[last_idx] {
326 *m = (*partial).clone();
327 }
328 }
329
330 ProviderEvent::ToolCallDelta { partial, .. } if added_partial => {
331 let last_idx = messages.len() - 1;
332 if let Message::Assistant(ref mut m) = messages[last_idx] {
333 *m = (*partial).clone();
334 }
335 }
336
337 ProviderEvent::ToolCallEnd { tool_call, .. } if added_partial => {
338 let last_idx = messages.len() - 1;
339 if let Message::Assistant(ref mut m) = messages[last_idx] {
340 m.content.push(ContentBlock::ToolCall(tool_call));
341 }
342 let last_msg = messages.last().expect("non-empty").clone();
343 emit(super::AgentEvent::MessageUpdate {
344 message: last_msg,
345 delta: None,
346 });
347 }
348
349 ProviderEvent::Done { message, .. } => {
350 loop_ref.circuit_breaker.record_success();
351
352 let (input, output) = (message.usage.input, message.usage.output);
353 if input > 0 || output > 0 {
354 let prompt_len = messages.len().saturating_sub(1);
379 let estimate_at_report = estimate_tokens_from_messages(&messages[..prompt_len]);
380 loop_ref.state.update(|s| {
381 s.record_usage(input, output);
382 s.record_provider_turn(input, estimate_at_report);
383 });
384 emit(super::AgentEvent::Usage {
385 input_tokens: input,
386 output_tokens: output,
387 });
388 }
389
390 tracing::info!(
391 "Stream event #{}: Done (stop_reason={:?})",
392 event_count,
393 message.stop_reason
394 );
395
396 if added_partial {
397 let last_idx = messages.len() - 1;
398 if let Message::Assistant(ref mut m) = messages[last_idx] {
399 let mut seen_ids: HashSet<String> = message
400 .content
401 .iter()
402 .filter_map(|b| match b {
403 ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
404 _ => None,
405 })
406 .collect();
407
408 let extra_tool_calls: Vec<ContentBlock> = m
409 .content
410 .iter()
411 .filter(|b| match b {
412 ContentBlock::ToolCall(tc) => seen_ids.insert(tc.id.clone()),
413 _ => false,
414 })
415 .cloned()
416 .collect();
417
418 let tc_count = extra_tool_calls.len();
419 *m = message.clone();
420 m.content.extend(extra_tool_calls);
421
422 tracing::info!(
423 "Done: merged {} extra tool_calls, final has {} content blocks, stop_reason={:?}",
424 tc_count,
425 m.content.len(),
426 m.stop_reason
427 );
428 }
429 } else {
430 messages.push(Message::Assistant(message.clone()));
431 }
432 let last_msg = messages.last().expect("non-empty").clone();
433 emit(super::AgentEvent::MessageEnd {
434 message: last_msg.clone(),
435 });
436 if let Message::Assistant(m) = &last_msg {
437 return StreamOutcome::Complete(m.clone());
438 } else {
439 return StreamOutcome::Complete(message);
440 }
441 }
442
443 ProviderEvent::Error { mut error, .. } => {
444 loop_ref.circuit_breaker.record_failure();
445
446 tracing::info!("Stream event #{}: Error", event_count);
447 let raw_msg = error.text_content();
448 let friendly = if raw_msg.is_empty() {
449 "Unknown provider error".to_string()
450 } else {
451 raw_msg
452 };
453 tracing::error!(
454 session_id = ?loop_ref.session_id,
455 "Provider stream error: {}", friendly
456 );
457
458 error.stop_reason = StopReason::Error;
459
460 if added_partial {
461 let last_idx = messages.len() - 1;
462 if let Message::Assistant(ref mut m) = messages[last_idx] {
463 *m = error.clone();
464 }
465 } else {
466 messages.push(Message::Assistant(error.clone()));
467 }
468
469 emit(super::AgentEvent::MessageEnd {
470 message: Message::Assistant(error.clone()),
471 });
472 emit(super::AgentEvent::Error {
473 message: format!("⚠ {}", friendly),
474 session_id: loop_ref.session_id.clone(),
475 });
476
477 return StreamOutcome::Error {
478 message: error,
479 detail: format!("⚠ {}", friendly),
480 };
481 }
482
483 _ => {}
484 }
485 }
486
487 tracing::info!("Stream ended after {} events", event_count);
488
489 let final_message = match messages.last().and_then(|m| match m {
490 Message::Assistant(a) => Some(a.clone()),
491 _ => None,
492 }) {
493 Some(m) => m,
494 None => {
495 return StreamOutcome::Error {
496 message: oxi_ai::AssistantMessage::new(
497 oxi_ai::Api::OpenAiCompletions,
498 "agent",
499 &loop_ref.config.model_id,
500 ),
501 detail: "No final assistant message in stream".to_string(),
502 };
503 }
504 };
505
506 if !added_partial {
507 tracing::warn!("Stream ended without Start event, emitting synthetic MessageStart");
508 emit(super::AgentEvent::MessageStart {
509 message: Message::Assistant(final_message.clone()),
510 });
511 }
512
513 emit(super::AgentEvent::MessageEnd {
514 message: Message::Assistant(final_message.clone()),
515 });
516 StreamOutcome::Complete(final_message)
517}
518
519fn estimate_tokens_from_messages(messages: &[Message]) -> usize {
534 let json = serde_json::to_string(messages).unwrap_or_default();
535 json.len() / 4
536}