Skip to main content

oxi_agent/agent_loop/
streaming.rs

1/// Streaming implementation for agent loop.
2///
3/// pi-mono pattern: the provider accumulates content into a single `output`
4/// message. Each event carries a snapshot (`partial`) of this message.
5/// Done carries the complete accumulated message.
6///
7/// TTSR integration: when a [`TtsrEngine`](super::ttsr::TtsrEngine) is
8/// provided, every [`ProviderEvent::TextDelta`] is checked against
9/// registered rules. A match aborts the stream and returns
10/// [`StreamOutcome::RuleInterrupt`].
11use futures::StreamExt;
12use oxi_ai::{
13    ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
14};
15use std::collections::HashSet;
16
17use super::helpers::sanitize_orphaned_tool_results;
18use super::stream_outcome::StreamOutcome;
19use super::ttsr::{MatchSource, TtsrEngine, TtsrMatchContext};
20
21pub(crate) async fn stream_assistant_response(
22    loop_ref: &super::AgentLoop,
23    messages: &mut Vec<Message>,
24    emit: &super::EmitFn,
25    ttsr: Option<&TtsrEngine>,
26) -> StreamOutcome {
27    let model = match loop_ref.resolve_model() {
28        Ok(m) => m,
29        Err(_) => {
30            return StreamOutcome::Error {
31                message: oxi_ai::AssistantMessage::new(
32                    oxi_ai::Api::OpenAiCompletions,
33                    "agent",
34                    &loop_ref.config.model_id,
35                ),
36                detail: "Failed to resolve model".to_string(),
37            };
38        }
39    };
40
41    // Proactively sanitize orphaned tool results to prevent provider
42    // errors like "Messages with role 'tool' must be a response to a
43    // preceding message with 'tool_calls'".
44    let removed = sanitize_orphaned_tool_results(messages);
45    if removed > 0 {
46        tracing::warn!(
47            session_id = ?loop_ref.session_id,
48            removed,
49            "Sanitized orphaned tool results before streaming"
50        );
51    }
52
53    let mut context = Context::new();
54
55    if let Some(ref system_prompt) = loop_ref.config.system_prompt {
56        context.set_system_prompt(system_prompt.clone());
57    }
58
59    for msg in messages.iter() {
60        context.add_message(msg.clone());
61    }
62
63    let tool_defs = loop_ref.tools.definitions();
64    if !tool_defs.is_empty() {
65        let mut oxi_tools = Vec::with_capacity(tool_defs.len());
66        for def in &tool_defs {
67            let schema = serde_json::to_value(&def.input_schema)
68                .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
69            oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
70        }
71        context.set_tools(oxi_tools);
72    }
73
74    let stream_options = StreamOptions {
75        temperature: Some(loop_ref.config.temperature as f64),
76        max_tokens: Some(loop_ref.config.max_tokens as usize),
77        api_key: loop_ref.config.api_key.clone(),
78        provider_options: loop_ref.config.provider_options.clone(),
79        ..Default::default()
80    };
81
82    let stream = match super::retry::stream_with_retry(
83        loop_ref,
84        &model,
85        &context,
86        Some(stream_options),
87        emit,
88    )
89    .await
90    {
91        Ok(s) => s,
92        Err(e) => {
93            return StreamOutcome::Error {
94                message: oxi_ai::AssistantMessage::new(
95                    oxi_ai::Api::OpenAiCompletions,
96                    "agent",
97                    &loop_ref.config.model_id,
98                ),
99                detail: e.to_string(),
100            };
101        }
102    };
103
104    let mut added_partial = false;
105    let mut event_count = 0u32;
106
107    let mut rx = stream;
108    let stream_idle_timeout = std::time::Duration::from_secs(30);
109    let cancel_check_interval = std::time::Duration::from_millis(500);
110    let mut last_event_at = std::time::Instant::now();
111
112    loop {
113        let next_event = tokio::select! {
114            event = rx.next() => event,
115            _ = tokio::time::sleep(cancel_check_interval) => {
116                if loop_ref.is_cancelled() {
117                    tracing::info!(
118                        "Stream cancelled (detected in periodic check)"
119                    );
120                    if added_partial {
121                        let last_idx = messages.len() - 1;
122                        if let Message::Assistant(ref mut m) = messages[last_idx] {
123                            m.stop_reason = StopReason::Aborted;
124                        }
125                        let last_msg = messages.last().expect("non-empty").clone();
126                        emit(super::AgentEvent::MessageEnd {
127                            message: last_msg.clone(),
128                        });
129                        if let Message::Assistant(m) = &last_msg {
130                            return StreamOutcome::Cancelled(m.clone());
131                        }
132                    }
133                    return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
134                        oxi_ai::Api::OpenAiCompletions,
135                        "agent",
136                        &loop_ref.config.model_id,
137                    ));
138                }
139
140                if last_event_at.elapsed() >= stream_idle_timeout {
141                    tracing::warn!(
142                        "Stream idle timeout ({:?}) reached after {} events",
143                        stream_idle_timeout, event_count
144                    );
145                    let mut err_asst = oxi_ai::AssistantMessage::new(
146                        oxi_ai::Api::OpenAiCompletions,
147                        "agent",
148                        &loop_ref.config.model_id,
149                    );
150                    err_asst.stop_reason = StopReason::Error;
151                    err_asst.error_message = Some(format!(
152                        "Stream timed out after {:?} of inactivity",
153                        stream_idle_timeout
154                    ));
155                    if added_partial {
156                        let last_idx = messages.len() - 1;
157                        if let Message::Assistant(ref mut m) = messages[last_idx] {
158                            m.stop_reason = StopReason::Error;
159                        }
160                    }
161                    emit(super::AgentEvent::MessageEnd {
162                        message: Message::Assistant(err_asst.clone()),
163                    });
164                    emit(super::AgentEvent::Error {
165                        message: format!(
166                            "Stream timed out after {:?} of inactivity",
167                            stream_idle_timeout
168                        ),
169                        session_id: loop_ref.session_id.clone(),
170                    });
171                    return StreamOutcome::Error { message: err_asst, detail: format!("Stream timed out after {:?} of inactivity", stream_idle_timeout) };
172                }
173
174                continue;
175            }
176        };
177
178        let event = match next_event {
179            Some(e) => e,
180            None => break,
181        };
182
183        last_event_at = std::time::Instant::now();
184
185        if loop_ref.is_cancelled() {
186            tracing::info!("Stream cancelled after {} events", event_count);
187            if added_partial {
188                let last_idx = messages.len() - 1;
189                if let Message::Assistant(ref mut m) = messages[last_idx] {
190                    m.stop_reason = StopReason::Aborted;
191                }
192                let last_msg = messages.last().expect("non-empty").clone();
193                emit(super::AgentEvent::MessageEnd {
194                    message: last_msg.clone(),
195                });
196                if let Message::Assistant(m) = &last_msg {
197                    return StreamOutcome::Cancelled(m.clone());
198                }
199            }
200            return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
201                oxi_ai::Api::OpenAiCompletions,
202                "agent",
203                &loop_ref.config.model_id,
204            ));
205        }
206
207        event_count += 1;
208        match event {
209            ProviderEvent::Start { partial } => {
210                tracing::info!("Stream event #{}: Start", event_count);
211                messages.push(Message::Assistant((*partial).clone()));
212                added_partial = true;
213                emit(super::AgentEvent::MessageStart {
214                    message: messages.last().expect("non-empty after push").clone(),
215                });
216            }
217
218            ProviderEvent::FallbackStart {
219                from_model,
220                to_model,
221                ..
222            } => {
223                tracing::info!(
224                    "Stream event #{}: Fallback from {} to {}",
225                    event_count,
226                    from_model,
227                    to_model
228                );
229                emit(super::AgentEvent::Fallback {
230                    from_model,
231                    to_model,
232                });
233            }
234
235            ProviderEvent::FallbackExhausted {
236                models_tried,
237                final_error,
238            } => {
239                tracing::warn!(
240                    "Stream event #{}: All fallback models exhausted. Tried: {:?}, error: {}",
241                    event_count,
242                    models_tried,
243                    final_error
244                );
245                if let Some(last_model) = models_tried.last() {
246                    emit(super::AgentEvent::Fallback {
247                        from_model: last_model.clone(),
248                        to_model: "none".to_string(),
249                    });
250                }
251            }
252
253            ProviderEvent::TextDelta { delta, partial, .. } => {
254                if added_partial {
255                    let last_idx = messages.len() - 1;
256                    if let Message::Assistant(ref mut m) = messages[last_idx] {
257                        *m = (*partial).clone();
258                    }
259                }
260                let last_msg = messages.last().expect("non-empty").clone();
261                let delta_clone = delta.clone();
262                emit(super::AgentEvent::MessageUpdate {
263                    message: last_msg,
264                    delta: Some(delta),
265                });
266
267                // ── TTSR check ──
268                if let Some(engine) = ttsr {
269                    let ctx = TtsrMatchContext {
270                        source: MatchSource::Text,
271                        file_paths: vec![],
272                        tool_name: None,
273                    };
274                    let violations = engine.check_delta(&delta_clone, &ctx);
275                    if !violations.is_empty() {
276                        let mut partial_msg = messages
277                            .last()
278                            .and_then(|m| match m {
279                                Message::Assistant(a) => Some(a.clone()),
280                                _ => None,
281                            })
282                            .unwrap_or_else(|| {
283                                oxi_ai::AssistantMessage::new(
284                                    oxi_ai::Api::OpenAiCompletions,
285                                    "agent",
286                                    &loop_ref.config.model_id,
287                                )
288                            });
289                        partial_msg.stop_reason = StopReason::Aborted;
290                        return StreamOutcome::RuleInterrupt {
291                            partial: partial_msg,
292                            rule: violations.into_iter().next().expect("non-empty"),
293                        };
294                    }
295                }
296            }
297
298            ProviderEvent::ThinkingStart { partial, .. } if added_partial => {
299                let last_idx = messages.len() - 1;
300                if let Message::Assistant(ref mut m) = messages[last_idx] {
301                    *m = (*partial).clone();
302                }
303            }
304
305            ProviderEvent::ThinkingDelta { delta, partial, .. } => {
306                if added_partial {
307                    let last_idx = messages.len() - 1;
308                    if let Message::Assistant(ref mut m) = messages[last_idx] {
309                        *m = (*partial).clone();
310                    }
311                }
312                let last_msg = messages.last().expect("non-empty").clone();
313                emit(super::AgentEvent::MessageUpdate {
314                    message: last_msg,
315                    delta: Some(delta),
316                });
317            }
318
319            ProviderEvent::ToolCallStart { partial, .. } if added_partial => {
320                let last_idx = messages.len() - 1;
321                if let Message::Assistant(ref mut m) = messages[last_idx] {
322                    *m = (*partial).clone();
323                }
324            }
325
326            ProviderEvent::ToolCallDelta { partial, .. } if added_partial => {
327                let last_idx = messages.len() - 1;
328                if let Message::Assistant(ref mut m) = messages[last_idx] {
329                    *m = (*partial).clone();
330                }
331            }
332
333            ProviderEvent::ToolCallEnd { tool_call, .. } if added_partial => {
334                let last_idx = messages.len() - 1;
335                if let Message::Assistant(ref mut m) = messages[last_idx] {
336                    m.content.push(ContentBlock::ToolCall(tool_call));
337                }
338                let last_msg = messages.last().expect("non-empty").clone();
339                emit(super::AgentEvent::MessageUpdate {
340                    message: last_msg,
341                    delta: None,
342                });
343            }
344
345            ProviderEvent::Done { message, .. } => {
346                loop_ref.circuit_breaker.record_success();
347
348                let (input, output) = (message.usage.input, message.usage.output);
349                if input > 0 || output > 0 {
350                    loop_ref.state.update(|s| {
351                        s.record_usage(input, output);
352                    });
353                    emit(super::AgentEvent::Usage {
354                        input_tokens: input,
355                        output_tokens: output,
356                    });
357                }
358
359                tracing::info!(
360                    "Stream event #{}: Done (stop_reason={:?})",
361                    event_count,
362                    message.stop_reason
363                );
364
365                if added_partial {
366                    let last_idx = messages.len() - 1;
367                    if let Message::Assistant(ref mut m) = messages[last_idx] {
368                        let mut seen_ids: HashSet<String> = message
369                            .content
370                            .iter()
371                            .filter_map(|b| match b {
372                                ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
373                                _ => None,
374                            })
375                            .collect();
376
377                        let extra_tool_calls: Vec<ContentBlock> = m
378                            .content
379                            .iter()
380                            .filter(|b| match b {
381                                ContentBlock::ToolCall(tc) => seen_ids.insert(tc.id.clone()),
382                                _ => false,
383                            })
384                            .cloned()
385                            .collect();
386
387                        let tc_count = extra_tool_calls.len();
388                        *m = message.clone();
389                        m.content.extend(extra_tool_calls);
390
391                        tracing::info!(
392                            "Done: merged {} extra tool_calls, final has {} content blocks, stop_reason={:?}",
393                            tc_count,
394                            m.content.len(),
395                            m.stop_reason
396                        );
397                    }
398                } else {
399                    messages.push(Message::Assistant(message.clone()));
400                }
401                let last_msg = messages.last().expect("non-empty").clone();
402                emit(super::AgentEvent::MessageEnd {
403                    message: last_msg.clone(),
404                });
405                if let Message::Assistant(m) = &last_msg {
406                    return StreamOutcome::Complete(m.clone());
407                } else {
408                    return StreamOutcome::Complete(message);
409                }
410            }
411
412            ProviderEvent::Error { mut error, .. } => {
413                loop_ref.circuit_breaker.record_failure();
414
415                tracing::info!("Stream event #{}: Error", event_count);
416                let raw_msg = error.text_content();
417                let friendly = if raw_msg.is_empty() {
418                    "Unknown provider error".to_string()
419                } else {
420                    raw_msg
421                };
422                tracing::error!(
423                    session_id = ?loop_ref.session_id,
424                    "Provider stream error: {}", friendly
425                );
426
427                error.stop_reason = StopReason::Error;
428
429                if added_partial {
430                    let last_idx = messages.len() - 1;
431                    if let Message::Assistant(ref mut m) = messages[last_idx] {
432                        *m = error.clone();
433                    }
434                } else {
435                    messages.push(Message::Assistant(error.clone()));
436                }
437
438                emit(super::AgentEvent::MessageEnd {
439                    message: Message::Assistant(error.clone()),
440                });
441                emit(super::AgentEvent::Error {
442                    message: format!("⚠ {}", friendly),
443                    session_id: loop_ref.session_id.clone(),
444                });
445
446                return StreamOutcome::Error {
447                    message: error,
448                    detail: format!("⚠ {}", friendly),
449                };
450            }
451
452            _ => {}
453        }
454    }
455
456    tracing::info!("Stream ended after {} events", event_count);
457
458    let final_message = match messages.last().and_then(|m| match m {
459        Message::Assistant(a) => Some(a.clone()),
460        _ => None,
461    }) {
462        Some(m) => m,
463        None => {
464            return StreamOutcome::Error {
465                message: oxi_ai::AssistantMessage::new(
466                    oxi_ai::Api::OpenAiCompletions,
467                    "agent",
468                    &loop_ref.config.model_id,
469                ),
470                detail: "No final assistant message in stream".to_string(),
471            };
472        }
473    };
474
475    if !added_partial {
476        tracing::warn!("Stream ended without Start event, emitting synthetic MessageStart");
477        emit(super::AgentEvent::MessageStart {
478            message: Message::Assistant(final_message.clone()),
479        });
480    }
481
482    emit(super::AgentEvent::MessageEnd {
483        message: Message::Assistant(final_message.clone()),
484    });
485    StreamOutcome::Complete(final_message)
486}