Skip to main content

oxi_agent/agent_loop/
streaming.rs

/// Streaming implementation for agent loop.
///
/// pi-mono pattern: the provider accumulates content into a single `output`
/// message. Each event carries a snapshot (`partial`) of this message.
/// Done carries the complete accumulated message.
///
/// This module forwards events to the agent loop emit function and keeps the
/// trailing assistant message in `messages` in sync with the latest snapshot.
8use anyhow::{Error, Result};
9use futures::StreamExt;
10use oxi_ai::{
11    ContentBlock, Context, Message, ProviderEvent, StopReason, StreamOptions, Tool as OxTool,
12};
13use std::collections::HashSet;
14
15pub(crate) async fn stream_assistant_response(
16    loop_ref: &super::AgentLoop,
17    messages: &mut Vec<Message>,
18    emit: &super::EmitFn,
19) -> Result<oxi_ai::AssistantMessage> {
20    let model = loop_ref.resolve_model()?;
21
22    let mut context = Context::new();
23
24    if let Some(ref system_prompt) = loop_ref.config.system_prompt {
25        context.set_system_prompt(system_prompt.clone());
26    }
27
28    for msg in messages.iter() {
29        context.add_message(msg.clone());
30    }
31
32    // Cache tool definitions serialization once to avoid repeated serde work.
33    let tool_defs = loop_ref.tools.definitions();
34    if !tool_defs.is_empty() {
35        let mut oxi_tools = Vec::with_capacity(tool_defs.len());
36        for def in &tool_defs {
37            let schema = serde_json::to_value(&def.input_schema)
38                .unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));
39            oxi_tools.push(OxTool::new(&def.name, &def.description, schema));
40        }
41        context.set_tools(oxi_tools);
42    }
43
44    let stream_options = StreamOptions {
45        temperature: Some(loop_ref.config.temperature as f64),
46        max_tokens: Some(loop_ref.config.max_tokens as usize),
47        api_key: loop_ref.config.api_key.clone(),
48        ..Default::default()
49    };
50
51    let stream =
52        super::retry::stream_with_retry(loop_ref, &model, &context, Some(stream_options), emit)
53            .await?;
54
55    // pi-mono pattern: track whether we've emitted MessageStart.
56    // Start event initializes the stream. Subsequent deltas carry
57    // accumulated partial messages (content grows in-place at the provider).
58    let mut added_partial = false;
59    let mut event_count = 0u32;
60
61    let mut rx = stream;
62    while let Some(event) = rx.next().await {
63        event_count += 1;
64        match event {
65            ProviderEvent::Start { partial } => {
66                tracing::info!("Stream event #{}: Start", event_count);
67                messages.push(Message::Assistant(partial));
68                added_partial = true;
69                emit(super::AgentEvent::MessageStart {
70                    message: messages.last().expect("non-empty after push").clone(),
71                });
72            }
73
74            ProviderEvent::TextDelta { delta, partial, .. } => {
75                // Replace the last assistant message with the provider's
76                // accumulated snapshot (pi-mono: content grows in partial).
77                if added_partial {
78                    let last_idx = messages.len() - 1;
79                    if let Message::Assistant(ref mut m) = messages[last_idx] {
80                        *m = partial;
81                    }
82                }
83                let last_msg = messages.last().expect("non-empty").clone();
84                emit(super::AgentEvent::MessageUpdate {
85                    message: last_msg,
86                    delta: Some(delta),
87                });
88            }
89
90            ProviderEvent::ThinkingStart { partial, .. }
91                // ThinkingStart arrives before ThinkingDelta.
92                // Update the snapshot.
93                if added_partial => {
94                    let last_idx = messages.len() - 1;
95                    if let Message::Assistant(ref mut m) = messages[last_idx] {
96                        *m = partial;
97                    }
98                }
99
100            ProviderEvent::ThinkingDelta { delta, partial, .. } => {
101                if added_partial {
102                    let last_idx = messages.len() - 1;
103                    if let Message::Assistant(ref mut m) = messages[last_idx] {
104                        *m = partial;
105                    }
106                }
107                let last_msg = messages.last().expect("non-empty").clone();
108                emit(super::AgentEvent::MessageUpdate {
109                    message: last_msg,
110                    delta: Some(delta),
111                });
112            }
113
114            ProviderEvent::ToolCallStart { partial, .. }
115                if added_partial => {
116                    let last_idx = messages.len() - 1;
117                    if let Message::Assistant(ref mut m) = messages[last_idx] {
118                        *m = partial;
119                    }
120                }
121
122            ProviderEvent::ToolCallDelta { partial, .. }
123                if added_partial => {
124                    let last_idx = messages.len() - 1;
125                    if let Message::Assistant(ref mut m) = messages[last_idx] {
126                        *m = partial;
127                    }
128                }
129
130            ProviderEvent::ToolCallEnd { tool_call, .. }
131                // Add the tool call directly to our tracked message.
132                if added_partial => {
133                    let last_idx = messages.len() - 1;
134                    if let Message::Assistant(ref mut m) = messages[last_idx] {
135                        m.content.push(ContentBlock::ToolCall(tool_call));
136                    }
137                    // CRITICAL: emit MessageUpdate so the TUI sees the ToolCall block.
138                    // Without this, tool calls are never rendered (matching pi's behavior
139                    // where toolcall_end emits message_update).
140                    let last_msg = messages.last().expect("non-empty").clone();
141                    emit(super::AgentEvent::MessageUpdate {
142                        message: last_msg,
143                        delta: None,
144                    });
145                }
146
147            ProviderEvent::Done { message, .. } => {
148                tracing::info!(
149                    "Stream event #{}: Done (stop_reason={:?})",
150                    event_count,
151                    message.stop_reason
152                );
153                if added_partial {
154                    let last_idx = messages.len() - 1;
155                    if let Message::Assistant(ref mut m) = messages[last_idx] {
156                        // Preserve tool calls we may have injected via ToolCallEnd.
157                        // Some providers also include ToolCall blocks in the final Done message,
158                        // so dedupe by tool_call_id to avoid executing the same tool twice.
159                        let mut preserved_tool_calls: Vec<ContentBlock> = m
160                            .content
161                            .drain(..)
162                            .filter(|b| matches!(b, ContentBlock::ToolCall(_)))
163                            .collect();
164
165                        let mut seen: HashSet<String> = message
166                            .content
167                            .iter()
168                            .filter_map(|b| match b {
169                                ContentBlock::ToolCall(tc) => Some(tc.id.clone()),
170                                _ => None,
171                            })
172                            .collect();
173
174                        preserved_tool_calls.retain(|b| match b {
175                            ContentBlock::ToolCall(tc) => seen.insert(tc.id.clone()),
176                            _ => true,
177                        });
178
179                        tracing::info!(
180                            "Done: preserving {} tool_calls (deduped), Done message has {} content blocks",
181                            preserved_tool_calls.len(),
182                            message.content.len()
183                        );
184
185                        *m = message.clone();
186                        m.content.extend(preserved_tool_calls);
187                        tracing::info!(
188                            "Done: final message has {} content blocks, stop_reason={:?}",
189                            m.content.len(),
190                            m.stop_reason
191                        );
192                    }
193                } else {
194                    messages.push(Message::Assistant(message.clone()));
195                }
196                let last_msg = messages.last().expect("non-empty").clone();
197                emit(super::AgentEvent::MessageEnd {
198                    message: last_msg.clone(),
199                });
200                // Return the message we actually stored (with tool calls preserved)
201                if let Message::Assistant(m) = &last_msg {
202                    return Ok(m.clone());
203                } else {
204                    return Ok(message);
205                }
206            }
207
208            ProviderEvent::Error { mut error, .. } => {
209                tracing::info!("Stream event #{}: Error", event_count);
210                let raw_msg = error.text_content();
211                let friendly = if raw_msg.is_empty() {
212                    "Unknown provider error".to_string()
213                } else {
214                    raw_msg
215                };
216                tracing::error!(session_id = ?loop_ref.session_id, "Provider stream error: {}", friendly);
217
218                error.stop_reason = StopReason::Error;
219
220                if added_partial {
221                    let last_idx = messages.len() - 1;
222                    if let Message::Assistant(ref mut m) = messages[last_idx] {
223                        *m = error.clone();
224                    }
225                } else {
226                    messages.push(Message::Assistant(error.clone()));
227                }
228
229                emit(super::AgentEvent::MessageEnd {
230                    message: Message::Assistant(error.clone()),
231                });
232                emit(super::AgentEvent::Error {
233                    message: format!("⚠ {}", friendly),
234                    session_id: loop_ref.session_id.clone(),
235                });
236
237                return Ok(error);
238            }
239
240            _ => {}
241        }
242    }
243
244    tracing::info!("Stream ended after {} events", event_count);
245
246    let final_message = messages
247        .last()
248        .and_then(|m| match m {
249            Message::Assistant(a) => Some(a.clone()),
250            _ => None,
251        })
252        .ok_or_else(|| Error::msg("No assistant message in context"))?;
253
254    emit(super::AgentEvent::MessageEnd {
255        message: Message::Assistant(final_message.clone()),
256    });
257    Ok(final_message)
258}