Skip to main content

oxi_agent/agent_loop/
streaming.rs

//! Streaming implementation for the agent loop.
//!
//! pi-mono pattern: the provider accumulates content into a single `output`
//! message. Each event carries a snapshot (`partial`) of this message, and
//! `Done` carries the complete accumulated message.
//!
//! TTSR integration: when a [`TtsrEngine`](super::ttsr::TtsrEngine) is
//! provided, every [`ProviderEvent::TextDelta`] is checked against the
//! registered rules. A match aborts the stream and returns
//! [`StreamOutcome::RuleInterrupt`].
11use futures::StreamExt;
12use oxi_ai::{
13    ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
14};
15use std::collections::HashSet;
16
17use super::helpers::sanitize_orphaned_tool_results;
18use super::stream_outcome::StreamOutcome;
19use super::ttsr::{MatchSource, TtsrEngine, TtsrMatchContext};
20
21pub(crate) async fn stream_assistant_response(
22    loop_ref: &super::AgentLoop,
23    messages: &mut Vec<Message>,
24    emit: &super::EmitFn,
25    ttsr: Option<&TtsrEngine>,
26) -> StreamOutcome {
27    let model = match loop_ref.resolve_model() {
28        Ok(m) => m,
29        Err(_) => {
30            return StreamOutcome::Error {
31                message: oxi_ai::AssistantMessage::new(
32                    oxi_ai::Api::OpenAiCompletions,
33                    "agent",
34                    &loop_ref.config.model_id,
35                ),
36                detail: "Failed to resolve model".to_string(),
37            };
38        }
39    };
40
41    // Proactively sanitize orphaned tool results to prevent provider
42    // errors like "Messages with role 'tool' must be a response to a
43    // preceding message with 'tool_calls'".
44    let removed = sanitize_orphaned_tool_results(messages);
45    if removed > 0 {
46        tracing::warn!(
47            session_id = ?loop_ref.session_id,
48            removed,
49            "Sanitized orphaned tool results before streaming"
50        );
51    }
52
53    let mut context = Context::new();
54
55    if let Some(ref system_prompt) = loop_ref.config.system_prompt {
56        context.set_system_prompt(system_prompt.clone());
57    }
58
59    for msg in messages.iter() {
60        context.add_message(msg.clone());
61    }
62
63    let tool_defs = loop_ref.tools.definitions();
64    if !tool_defs.is_empty() {
65        let mut oxi_tools = Vec::with_capacity(tool_defs.len());
66        for def in &tool_defs {
67            let schema = serde_json::to_value(&def.input_schema)
68                .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
69            oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
70        }
71        context.set_tools(oxi_tools);
72    }
73
74    let stream_options = StreamOptions {
75        temperature: Some(loop_ref.config.temperature as f64),
76        max_tokens: Some(loop_ref.config.max_tokens as usize),
77        api_key: loop_ref.config.api_key.clone(),
78        provider_options: loop_ref.config.provider_options.clone(),
79        ..Default::default()
80    };
81
82    let stream = match super::retry::stream_with_retry(
83        loop_ref,
84        &model,
85        &context,
86        Some(stream_options),
87        emit,
88    )
89    .await
90    {
91        Ok(s) => s,
92        Err(e) => {
93            return StreamOutcome::Error {
94                message: oxi_ai::AssistantMessage::new(
95                    oxi_ai::Api::OpenAiCompletions,
96                    "agent",
97                    &loop_ref.config.model_id,
98                ),
99                detail: e.to_string(),
100            };
101        }
102    };
103
104    let mut added_partial = false;
105    let mut event_count = 0u32;
106
107    let mut rx = stream;
108    let stream_idle_timeout = std::time::Duration::from_secs(30);
109    let cancel_check_interval = std::time::Duration::from_millis(500);
110    let mut last_event_at = std::time::Instant::now();
111
112    loop {
113        let next_event = tokio::select! {
114            event = rx.next() => event,
115            _ = tokio::time::sleep(cancel_check_interval) => {
116                if loop_ref.is_cancelled() {
117                    tracing::info!(
118                        "Stream cancelled (detected in periodic check)"
119                    );
120                    if added_partial {
121                        let last_idx = messages.len() - 1;
122                        if let Message::Assistant(ref mut m) = messages[last_idx] {
123                            m.stop_reason = StopReason::Aborted;
124                        }
125                        let last_msg = messages.last().expect("non-empty").clone();
126                        emit(super::AgentEvent::MessageEnd {
127                            message: last_msg.clone(),
128                        });
129                        if let Message::Assistant(m) = &last_msg {
130                            return StreamOutcome::Cancelled(m.clone());
131                        }
132                    }
133                    return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
134                        oxi_ai::Api::OpenAiCompletions,
135                        "agent",
136                        &loop_ref.config.model_id,
137                    ));
138                }
139
140                if last_event_at.elapsed() >= stream_idle_timeout {
141                    tracing::warn!(
142                        "Stream idle timeout ({:?}) reached after {} events",
143                        stream_idle_timeout, event_count
144                    );
145                    let mut err_asst = oxi_ai::AssistantMessage::new(
146                        oxi_ai::Api::OpenAiCompletions,
147                        "agent",
148                        &loop_ref.config.model_id,
149                    );
150                    err_asst.stop_reason = StopReason::Error;
151                    err_asst.error_message = Some(format!(
152                        "Stream timed out after {:?} of inactivity",
153                        stream_idle_timeout
154                    ));
155                    if added_partial {
156                        let last_idx = messages.len() - 1;
157                        if let Message::Assistant(ref mut m) = messages[last_idx] {
158                            m.stop_reason = StopReason::Error;
159                        }
160                    }
161                    emit(super::AgentEvent::MessageEnd {
162                        message: Message::Assistant(err_asst.clone()),
163                    });
164                    emit(super::AgentEvent::Error {
165                        message: format!(
166                            "Stream timed out after {:?} of inactivity",
167                            stream_idle_timeout
168                        ),
169                        session_id: loop_ref.session_id.clone(),
170                    });
171                    return StreamOutcome::Error { message: err_asst, detail: format!("Stream timed out after {:?} of inactivity", stream_idle_timeout) };
172                }
173
174                continue;
175            }
176        };
177
178        let event = match next_event {
179            Some(e) => e,
180            None => break,
181        };
182
183        last_event_at = std::time::Instant::now();
184
185        if loop_ref.is_cancelled() {
186            tracing::info!("Stream cancelled after {} events", event_count);
187            if added_partial {
188                let last_idx = messages.len() - 1;
189                if let Message::Assistant(ref mut m) = messages[last_idx] {
190                    m.stop_reason = StopReason::Aborted;
191                }
192                let last_msg = messages.last().expect("non-empty").clone();
193                emit(super::AgentEvent::MessageEnd {
194                    message: last_msg.clone(),
195                });
196                if let Message::Assistant(m) = &last_msg {
197                    return StreamOutcome::Cancelled(m.clone());
198                }
199            }
200            return StreamOutcome::Cancelled(oxi_ai::AssistantMessage::new(
201                oxi_ai::Api::OpenAiCompletions,
202                "agent",
203                &loop_ref.config.model_id,
204            ));
205        }
206
207        event_count += 1;
208        match event {
209            ProviderEvent::Start { partial } => {
210                tracing::info!("Stream event #{}: Start", event_count);
211                messages.push(Message::Assistant((*partial).clone()));
212                added_partial = true;
213                emit(super::AgentEvent::MessageStart {
214                    message: messages.last().expect("non-empty after push").clone(),
215                });
216            }
217
218            ProviderEvent::FallbackStart {
219                from_model,
220                to_model,
221                ..
222            } => {
223                tracing::info!(
224                    "Stream event #{}: Fallback from {} to {}",
225                    event_count,
226                    from_model,
227                    to_model
228                );
229                emit(super::AgentEvent::Fallback {
230                    from_model,
231                    to_model,
232                });
233            }
234
235            ProviderEvent::FallbackExhausted {
236                models_tried,
237                final_error,
238            } => {
239                tracing::warn!(
240                    "Stream event #{}: All fallback models exhausted. Tried: {:?}, error: {}",
241                    event_count,
242                    models_tried,
243                    final_error
244                );
245                if let Some(last_model) = models_tried.last() {
246                    emit(super::AgentEvent::Fallback {
247                        from_model: last_model.clone(),
248                        to_model: "none".to_string(),
249                    });
250                }
251            }
252
253            ProviderEvent::TextDelta { delta, partial, .. } => {
254                if added_partial {
255                    let last_idx = messages.len() - 1;
256                    if let Message::Assistant(ref mut m) = messages[last_idx] {
257                        *m = (*partial).clone();
258                    }
259                }
260                let last_msg = messages.last().expect("non-empty").clone();
261                let delta_clone = delta.clone();
262                emit(super::AgentEvent::MessageUpdate {
263                    message: last_msg,
264                    delta: Some(delta),
265                });
266
267                // ── TTSR check ──
268                if let Some(engine) = ttsr {
269                    let ctx = TtsrMatchContext {
270                        source: MatchSource::Text,
271                        file_paths: vec![],
272                        tool_name: None,
273                    };
274                    let violations = engine.check_delta(&delta_clone, &ctx);
275                    if !violations.is_empty() {
276                        let mut partial_msg = messages
277                            .last()
278                            .and_then(|m| match m {
279                                Message::Assistant(a) => Some(a.clone()),
280                                _ => None,
281                            })
282                            .unwrap_or_else(|| {
283                                oxi_ai::AssistantMessage::new(
284                                    oxi_ai::Api::OpenAiCompletions,
285                                    "agent",
286                                    &loop_ref.config.model_id,
287                                )
288                            });
289                        partial_msg.stop_reason = StopReason::Aborted;
290                        return StreamOutcome::RuleInterrupt {
291                            partial: partial_msg,
292                            rule: violations.into_iter().next().expect("non-empty"),
293                        };
294                    }
295                }
296            }
297
298            ProviderEvent::ThinkingStart { partial, .. } if added_partial => {
299                let last_idx = messages.len() - 1;
300                if let Message::Assistant(ref mut m) = messages[last_idx] {
301                    *m = (*partial).clone();
302                }
303                emit(super::AgentEvent::Thinking);
304            }
305
306            ProviderEvent::ThinkingDelta { delta, partial, .. } => {
307                if added_partial {
308                    let last_idx = messages.len() - 1;
309                    if let Message::Assistant(ref mut m) = messages[last_idx] {
310                        *m = (*partial).clone();
311                    }
312                }
313                let last_msg = messages.last().expect("non-empty").clone();
314                emit(super::AgentEvent::ThinkingDelta {
315                    text: delta.clone(),
316                });
317                emit(super::AgentEvent::MessageUpdate {
318                    message: last_msg,
319                    delta: Some(delta),
320                });
321            }
322
323            ProviderEvent::ToolCallStart { partial, .. } if added_partial => {
324                let last_idx = messages.len() - 1;
325                if let Message::Assistant(ref mut m) = messages[last_idx] {
326                    *m = (*partial).clone();
327                }
328            }
329
330            ProviderEvent::ToolCallDelta { partial, .. } if added_partial => {
331                let last_idx = messages.len() - 1;
332                if let Message::Assistant(ref mut m) = messages[last_idx] {
333                    *m = (*partial).clone();
334                }
335            }
336
337            ProviderEvent::ToolCallEnd { tool_call, .. } if added_partial => {
338                let last_idx = messages.len() - 1;
339                if let Message::Assistant(ref mut m) = messages[last_idx] {
340                    m.content.push(ContentBlock::ToolCall(tool_call));
341                }
342                let last_msg = messages.last().expect("non-empty").clone();
343                emit(super::AgentEvent::MessageUpdate {
344                    message: last_msg,
345                    delta: None,
346                });
347            }
348
349            ProviderEvent::Done { message, .. } => {
350                loop_ref.circuit_breaker.record_success();
351
352                let (input, output) = (message.usage.input, message.usage.output);
353                if input > 0 || output > 0 {
354                    // Snapshot the heuristic estimate of what was *just
355                    // sent* so we can compare it to the provider's
356                    // reported input_tokens on the same snapshot. This
357                    // is the drift metric referenced by issue #28:
358                    // `bytes/4` can undercount by 3-4× on token-dense
359                    // content (base64, JSON, CJK), and the legacy
360                    // compaction path used that heuristic directly.
361                    //
362                    // The slice we estimate over is the *prompt* the
363                    // provider tokenized, NOT the prompt + the
364                    // assistant turn we just streamed. At
365                    // `ProviderEvent::Done`, `messages` ends with the
366                    // just-completed assistant message (pushed on
367                    // `Start` at the start of the stream, or — in the
368                    // no-partial-Start path — on `Done` after this
369                    // branch; we record before that push). We slice
370                    // off the trailing assistant message so the
371                    // heuristic matches what `usage.input` actually
372                    // covers; otherwise the drift metric would
373                    // *understate* #28's bytes/4 underestimate.
374                    //
375                    // The compaction decision itself is unaffected —
376                    // it reads `Real(last_input_tokens)` =
377                    // `usage.input`, which is correct.
378                    let prompt_len = messages.len().saturating_sub(1);
379                    let estimate_at_report = estimate_tokens_from_messages(&messages[..prompt_len]);
380                    loop_ref.state.update(|s| {
381                        s.record_usage(input, output);
382                        s.record_provider_turn(input, estimate_at_report);
383                    });
384                    emit(super::AgentEvent::Usage {
385                        input_tokens: input,
386                        output_tokens: output,
387                    });
388                }
389
390                tracing::info!(
391                    "Stream event #{}: Done (stop_reason={:?})",
392                    event_count,
393                    message.stop_reason
394                );
395
396                if added_partial {
397                    let last_idx = messages.len() - 1;
398                    if let Message::Assistant(ref mut m) = messages[last_idx] {
399                        let mut seen_ids: HashSet<String> = message
400                            .content
401                            .iter()
402                            .filter_map(|b| match b {
403                                ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
404                                _ => None,
405                            })
406                            .collect();
407
408                        let extra_tool_calls: Vec<ContentBlock> = m
409                            .content
410                            .iter()
411                            .filter(|b| match b {
412                                ContentBlock::ToolCall(tc) => seen_ids.insert(tc.id.clone()),
413                                _ => false,
414                            })
415                            .cloned()
416                            .collect();
417
418                        let tc_count = extra_tool_calls.len();
419                        *m = message.clone();
420                        m.content.extend(extra_tool_calls);
421
422                        tracing::info!(
423                            "Done: merged {} extra tool_calls, final has {} content blocks, stop_reason={:?}",
424                            tc_count,
425                            m.content.len(),
426                            m.stop_reason
427                        );
428                    }
429                } else {
430                    messages.push(Message::Assistant(message.clone()));
431                }
432                let last_msg = messages.last().expect("non-empty").clone();
433                emit(super::AgentEvent::MessageEnd {
434                    message: last_msg.clone(),
435                });
436                if let Message::Assistant(m) = &last_msg {
437                    return StreamOutcome::Complete(m.clone());
438                } else {
439                    return StreamOutcome::Complete(message);
440                }
441            }
442
443            ProviderEvent::Error { mut error, .. } => {
444                loop_ref.circuit_breaker.record_failure();
445
446                tracing::info!("Stream event #{}: Error", event_count);
447                let raw_msg = error.text_content();
448                let friendly = if raw_msg.is_empty() {
449                    "Unknown provider error".to_string()
450                } else {
451                    raw_msg
452                };
453                tracing::error!(
454                    session_id = ?loop_ref.session_id,
455                    "Provider stream error: {}", friendly
456                );
457
458                error.stop_reason = StopReason::Error;
459
460                if added_partial {
461                    let last_idx = messages.len() - 1;
462                    if let Message::Assistant(ref mut m) = messages[last_idx] {
463                        *m = error.clone();
464                    }
465                } else {
466                    messages.push(Message::Assistant(error.clone()));
467                }
468
469                emit(super::AgentEvent::MessageEnd {
470                    message: Message::Assistant(error.clone()),
471                });
472                emit(super::AgentEvent::Error {
473                    message: format!("⚠ {}", friendly),
474                    session_id: loop_ref.session_id.clone(),
475                });
476
477                return StreamOutcome::Error {
478                    message: error,
479                    detail: format!("⚠ {}", friendly),
480                };
481            }
482
483            _ => {}
484        }
485    }
486
487    tracing::info!("Stream ended after {} events", event_count);
488
489    let final_message = match messages.last().and_then(|m| match m {
490        Message::Assistant(a) => Some(a.clone()),
491        _ => None,
492    }) {
493        Some(m) => m,
494        None => {
495            return StreamOutcome::Error {
496                message: oxi_ai::AssistantMessage::new(
497                    oxi_ai::Api::OpenAiCompletions,
498                    "agent",
499                    &loop_ref.config.model_id,
500                ),
501                detail: "No final assistant message in stream".to_string(),
502            };
503        }
504    };
505
506    if !added_partial {
507        tracing::warn!("Stream ended without Start event, emitting synthetic MessageStart");
508        emit(super::AgentEvent::MessageStart {
509            message: Message::Assistant(final_message.clone()),
510        });
511    }
512
513    emit(super::AgentEvent::MessageEnd {
514        message: Message::Assistant(final_message.clone()),
515    });
516    StreamOutcome::Complete(final_message)
517}
518
519/// Heuristic token estimate for a messages slice, mirroring
520/// `AgentState::estimate_tokens` (serialized JSON length / 4).
521///
522/// Used in [`stream_assistant_response`] at the moment the provider
523/// reports `usage.input_tokens` to record the divergence between
524/// the legacy heuristic and the ground-truth provider count (see
525/// issue #28 gap 2). The result is cached on
526/// `AgentState::last_estimate_at_report` / `last_estimate_divergence`
527/// so the operator can see how badly `bytes/4` is undercounting on
528/// token-dense content.
529///
530/// Kept local (not a method on `AgentState`) so we can call it with
531/// a borrowed slice of the loop's working `messages` buffer without
532/// cloning the whole history.
533fn estimate_tokens_from_messages(messages: &[Message]) -> usize {
534    let json = serde_json::to_string(messages).unwrap_or_default();
535    json.len() / 4
536}