Skip to main content

rig_core/agent/run/
streamed.rs

//! Streamed-turn assembly for [`AgentRun`](super::AgentRun).
//!
//! A streamed model turn arrives as incremental [`StreamedAssistantContent`]
//! items. [`StreamedTurnAssembler`] is the sans-IO accumulator that turns that
//! item stream into the same canonical complete turn the non-streaming path
//! feeds the machine — while telling the driver what to forward to its
//! consumer and surfacing invalid tool calls the moment they appear, so a
//! driver can stop paying for a doomed provider stream early.
//!
//! The protocol, paired with the streamed entry points on
//! [`AgentRun`](super::AgentRun):
//!
//! 1. On [`AgentRunStep::CallModel`](super::AgentRunStep::CallModel), open a
//!    provider stream and create one assembler per turn with the tool names
//!    advertised for that turn.
//! 2. Feed every stream item to [`StreamedTurnAssembler::ingest`] and act on
//!    the returned [`StreamedTurnEvent`]s: forward items to the consumer, and
//!    on [`StreamedTurnEvent::InvalidToolCall`] consult
//!    [`AgentRun::resolve_streamed_invalid_tool_call`](super::AgentRun::resolve_streamed_invalid_tool_call)
//!    and apply its outcome with
//!    [`StreamedTurnAssembler::resolve_pending_invalid`] (the assembler
//!    refuses further items until then).
//!    [`StreamedResolution::Repaired`] continues the same stream;
//!    [`StreamedResolution::TurnAbandoned`] means drain the provider stream
//!    for usage, record the completion call, and re-enter
//!    [`AgentRun::next_step`](super::AgentRun::next_step).
//! 3. When the provider stream ends, call [`StreamedTurnAssembler::finish`]
//!    and feed the result to
//!    [`AgentRun::streamed_turn`](super::AgentRun::streamed_turn); the run
//!    then proceeds exactly like a non-streamed one
//!    ([`CallTools`](super::AgentRunStep::CallTools) /
//!    [`Done`](super::AgentRunStep::Done)).
//!
//! [`crate::streaming::StreamingPrompt::stream_prompt`] drives this protocol
//! internally; hand-driven runs can follow the same protocol to stream any
//! [`AgentRun`](super::AgentRun).
35
36use std::collections::{BTreeSet, HashMap};
37
38use serde::{Deserialize, Serialize};
39
40use crate::{
41    OneOrMany,
42    agent::prompt_request::{TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER, tool_result_user_content},
43    completion::{CompletionError, GetTokenUsage, Message, Usage},
44    json_utils,
45    message::{AssistantContent, Reasoning, ToolCall, ToolFunction, ToolResult},
46    streaming::{StreamedAssistantContent, ToolCallDeltaContent},
47};
48
49/// Merge an incoming reasoning block into the accumulated reasoning,
50/// extending an existing block when provider-assigned IDs match.
51pub(crate) fn merge_reasoning_blocks(
52    accumulated_reasoning: &mut Vec<Reasoning>,
53    incoming: &Reasoning,
54) {
55    let ids_match = |existing: &Reasoning| {
56        matches!(
57            (&existing.id, &incoming.id),
58            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
59        )
60    };
61
62    if let Some(existing) = accumulated_reasoning
63        .iter_mut()
64        .rev()
65        .find(|existing| ids_match(existing))
66    {
67        existing.content.extend(incoming.content.clone());
68    } else {
69        accumulated_reasoning.push(incoming.clone());
70    }
71}
72
73/// Assemble assistant content in canonical replay order: reasoning blocks,
74/// then text, then trailing items (tool calls, images).
75pub(crate) fn ordered_streaming_assistant_content(
76    reasoning_items: impl IntoIterator<Item = Reasoning>,
77    text_items: impl IntoIterator<Item = AssistantContent>,
78    trailing_items: impl IntoIterator<Item = AssistantContent>,
79) -> Option<OneOrMany<AssistantContent>> {
80    let mut content_items = reasoning_items
81        .into_iter()
82        .map(AssistantContent::Reasoning)
83        .collect::<Vec<_>>();
84    content_items.extend(text_items);
85    content_items.extend(trailing_items);
86
87    OneOrMany::from_iter_optional(content_items)
88}
89
90pub(crate) fn assistant_text_items_from_choice(
91    choice: &OneOrMany<AssistantContent>,
92) -> Vec<AssistantContent> {
93    choice
94        .iter()
95        .filter_map(|content| match content {
96            AssistantContent::Text(text) => (!text.text.is_empty()
97                || text.additional_params.is_some())
98            .then(|| AssistantContent::Text(text.clone())),
99            _ => None,
100        })
101        .collect()
102}
103
104/// One invalid tool call surfaced mid-stream, awaiting a resolution from
105/// [`AgentRun::resolve_streamed_invalid_tool_call`](super::AgentRun::resolve_streamed_invalid_tool_call).
106#[derive(Debug, Clone, Serialize, Deserialize)]
107#[non_exhaustive]
108pub struct StreamedInvalidToolCall {
109    /// The rejected tool call. For a name delta this is a diagnostic call
110    /// assembled from the streamed name and any buffered argument deltas.
111    pub tool_call: ToolCall,
112    /// Rig-generated identifier correlating this call's stream items.
113    pub internal_call_id: String,
114    /// Raw argument payload for diagnostics, when available.
115    pub args: Option<String>,
116    /// Executable Rig tools advertised to the provider for this turn.
117    pub executable_tool_names: BTreeSet<String>,
118    /// Tools allowed by the active tool choice for this turn.
119    pub allowed_tool_names: BTreeSet<String>,
120}
121
122/// Snapshot of a streamed turn at the moment an invalid tool call appeared.
123/// Used by the machine to build diagnostics and rollback messages from
124/// exactly what the model has produced so far.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126#[non_exhaustive]
127pub struct PartialStreamedTurn {
128    /// Provider-assigned assistant message ID, when already known.
129    pub message_id: Option<String>,
130    /// Aggregated assistant text, when any text was streamed this turn.
131    pub text: Option<String>,
132    /// Accumulated reasoning, with any pending unsigned delta text assembled
133    /// into a block.
134    pub reasoning: Vec<Reasoning>,
135    /// Tool calls already validated (or repaired) this turn.
136    pub pending_tool_calls: Vec<ToolCall>,
137}
138
139impl PartialStreamedTurn {
140    /// The assistant message representing this partial turn, in canonical
141    /// order, including `current_tool_call` when provided. `None` when the
142    /// turn has produced no representable content.
143    pub(crate) fn assistant_message(&self, current_tool_call: Option<ToolCall>) -> Option<Message> {
144        let text_items = match &self.text {
145            Some(text) if !text.is_empty() => vec![AssistantContent::text(text.clone())],
146            _ => Vec::new(),
147        };
148        let mut tool_items = self
149            .pending_tool_calls
150            .iter()
151            .cloned()
152            .map(AssistantContent::ToolCall)
153            .collect::<Vec<_>>();
154        if let Some(tool_call) = current_tool_call {
155            tool_items.push(AssistantContent::ToolCall(tool_call));
156        }
157
158        let content = ordered_streaming_assistant_content(
159            self.reasoning.iter().cloned(),
160            text_items,
161            tool_items,
162        )?;
163        Some(Message::Assistant {
164            id: self.message_id.clone(),
165            content,
166        })
167    }
168
169    /// Rollback messages for a retried or skipped streamed turn: the partial
170    /// assistant turn plus a user message carrying `feedback` for the invalid
171    /// call and a synthetic "not executed" result for each validated peer.
172    pub(crate) fn rollback_messages(
173        &self,
174        invalid_tool_call: ToolCall,
175        feedback: String,
176    ) -> Option<(Message, Message)> {
177        let assistant_message = self.assistant_message(Some(invalid_tool_call.clone()))?;
178
179        let mut retry_results = self
180            .pending_tool_calls
181            .iter()
182            .map(|tool_call| {
183                tool_result_user_content(
184                    tool_call.id.clone(),
185                    tool_call.call_id.clone(),
186                    TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
187                )
188            })
189            .collect::<Vec<_>>();
190        retry_results.push(tool_result_user_content(
191            invalid_tool_call.id,
192            invalid_tool_call.call_id,
193            feedback,
194        ));
195
196        let user_message = Message::User {
197            content: OneOrMany::from_iter_optional(retry_results)?,
198        };
199
200        Some((assistant_message, user_message))
201    }
202}
203
204/// The assembled streamed turn, fed to
205/// [`AgentRun::streamed_turn`](super::AgentRun::streamed_turn).
206#[derive(Debug, Clone, Serialize, Deserialize)]
207#[non_exhaustive]
208pub struct StreamedTurn {
209    /// Provider-assigned assistant message ID, when available.
210    pub message_id: Option<String>,
211    /// The assistant content to record in history: canonical
212    /// (reasoning → text → tool calls) when the turn produced reasoning or
213    /// tool calls, otherwise the provider's aggregated choice as-is.
214    pub choice: OneOrMany<AssistantContent>,
215    /// Executable Rig tools advertised to the provider for this turn.
216    pub executable_tool_names: BTreeSet<String>,
217    /// Tools allowed by the active tool choice for this turn.
218    pub allowed_tool_names: BTreeSet<String>,
219    /// `(tool_call_id, internal_call_id)` pairs for this turn's tool calls,
220    /// in emission order. Carried into the run state so a resumed process
221    /// keeps the IDs consumers already saw in tool-call deltas.
222    #[serde(default)]
223    pub internal_call_ids: Vec<(String, String)>,
224}
225
226/// What the machine decided about a mid-stream invalid tool call.
227///
228/// Deliberately exhaustive: a driver must handle every resolution, so adding
229/// a variant is a breaking change by design.
230#[derive(Debug)]
231pub enum StreamedResolution {
232    /// The tool name was repaired. Apply it via
233    /// [`StreamedTurnAssembler::resolve_pending_invalid`] and keep consuming
234    /// the provider stream.
235    Repaired {
236        /// The validated replacement tool name.
237        tool_name: String,
238    },
239    /// The turn was rolled back (retry) or the call skipped; corrective
240    /// messages are already in the history. Drain the provider stream for
241    /// usage, record the completion call, then call
242    /// [`AgentRun::next_step`](super::AgentRun::next_step).
243    TurnAbandoned {
244        /// For a skipped call, the synthetic tool result to surface to the
245        /// consumer stream.
246        skipped_tool_result: Option<ToolResult>,
247    },
248}
249
250/// What a driver must do with one ingested stream item.
251///
252/// Deliberately exhaustive: a driver must handle every event, so adding a
253/// variant is a breaking change by design.
254#[derive(Debug, Clone)]
255pub enum StreamedTurnEvent {
256    /// Forward the ingested item to the consumer as-is (text, reasoning, or
257    /// reasoning deltas, after accumulation).
258    EmitIngested,
259    /// Forward this tool-call delta. Argument deltas buffered while the tool
260    /// name awaited validation are replayed through this event.
261    EmitToolCallDelta {
262        /// Provider-supplied tool call ID.
263        id: String,
264        /// Rig-generated identifier correlating this call's stream items.
265        internal_call_id: String,
266        /// The (possibly repaired) name or argument delta.
267        content: ToolCallDeltaContent,
268    },
269    /// The model emitted an unknown or disallowed tool call. Resolve it via
270    /// [`AgentRun::resolve_streamed_invalid_tool_call`](super::AgentRun::resolve_streamed_invalid_tool_call),
271    /// then apply the outcome with
272    /// [`StreamedTurnAssembler::resolve_pending_invalid`].
273    InvalidToolCall(Box<StreamedInvalidToolCall>),
274    /// The provider reported the end of this completion call. Record it (see
275    /// [`AgentRun::record_streamed_completion_call`](super::AgentRun::record_streamed_completion_call));
276    /// when `emit_final` is set, the turn streamed text and the driver should
277    /// run its stream-finish hook and forward the final item.
278    Completed {
279        /// Provider-reported usage for this call. Zero-valued usage means the
280        /// provider reported no usage metrics.
281        usage: Usage,
282        /// Whether the ingested final item should be forwarded to the
283        /// consumer (set when the turn streamed text).
284        emit_final: bool,
285    },
286}
287
/// Validation progress for one streamed tool call, keyed elsewhere by
/// `(id, internal_call_id)`.
#[derive(Default)]
struct ToolCallDeltaState {
    /// Whether a tool name has been accepted (directly or via repair) for
    /// this call; once true, argument deltas are forwarded immediately.
    name_validated: bool,
    /// Argument deltas received before the name validated, in arrival order.
    buffered_arguments: Vec<String>,
}
293
294enum PendingInvalid {
295    /// A complete tool call with a disallowed name.
296    FullCall {
297        tool_call: Box<ToolCall>,
298        internal_call_id: String,
299    },
300    /// A streamed tool-name delta with a disallowed name.
301    NameDelta {
302        id: String,
303        internal_call_id: String,
304    },
305}
306
307/// Sans-IO accumulator that assembles one streamed model turn. See the
308/// [module docs](self) for the driving protocol.
309pub struct StreamedTurnAssembler {
310    executable_tool_names: BTreeSet<String>,
311    allowed_tool_names: BTreeSet<String>,
312    text: String,
313    saw_text: bool,
314    accumulated_reasoning: Vec<Reasoning>,
315    pending_reasoning_delta_text: String,
316    pending_reasoning_delta_id: Option<String>,
317    pending_tool_calls: Vec<(ToolCall, String)>,
318    delta_states: HashMap<(String, String), ToolCallDeltaState>,
319    pending_invalid: Option<PendingInvalid>,
320}
321
322impl StreamedTurnAssembler {
323    /// Create an assembler for one streamed turn with the tool names
324    /// advertised to the provider for that turn.
325    pub fn new(
326        executable_tool_names: BTreeSet<String>,
327        allowed_tool_names: BTreeSet<String>,
328    ) -> Self {
329        Self {
330            executable_tool_names,
331            allowed_tool_names,
332            text: String::new(),
333            saw_text: false,
334            accumulated_reasoning: Vec::new(),
335            pending_reasoning_delta_text: String::new(),
336            pending_reasoning_delta_id: None,
337            pending_tool_calls: Vec::new(),
338            delta_states: HashMap::new(),
339            pending_invalid: None,
340        }
341    }
342
343    /// Aggregated assistant text streamed so far this turn (empty until the
344    /// first text delta).
345    pub fn aggregated_text(&self) -> &str {
346        &self.text
347    }
348
349    /// Ingest one provider stream item and return what the driver must do.
350    ///
351    /// # Errors
352    /// Returns an error when the provider stream is inconsistent (argument
353    /// deltas finishing without a validated tool name) or when an invalid
354    /// tool call is still awaiting resolution.
355    pub fn ingest<R>(
356        &mut self,
357        item: &StreamedAssistantContent<R>,
358    ) -> Result<Vec<StreamedTurnEvent>, CompletionError>
359    where
360        R: Clone + Unpin + GetTokenUsage,
361    {
362        if self.pending_invalid.is_some() {
363            return Err(CompletionError::ResponseError(
364                "streamed turn ingested while an invalid tool call awaits resolution".to_string(),
365            ));
366        }
367
368        match item {
369            StreamedAssistantContent::Text(text) => {
370                if !self.saw_text {
371                    self.text.clear();
372                    self.saw_text = true;
373                }
374                self.text.push_str(&text.text);
375                Ok(vec![StreamedTurnEvent::EmitIngested])
376            }
377            StreamedAssistantContent::Reasoning(reasoning) => {
378                merge_reasoning_blocks(&mut self.accumulated_reasoning, reasoning);
379                Ok(vec![StreamedTurnEvent::EmitIngested])
380            }
381            StreamedAssistantContent::ReasoningDelta { reasoning, id } => {
382                // Deltas lack signatures/encrypted content that full blocks
383                // carry; mixing them into accumulated reasoning causes
384                // providers like Anthropic to reject with "signature required",
385                // so they are kept aside until the turn ends.
386                self.pending_reasoning_delta_text.push_str(reasoning);
387                if self.pending_reasoning_delta_id.is_none() {
388                    self.pending_reasoning_delta_id = id.clone();
389                }
390                Ok(vec![StreamedTurnEvent::EmitIngested])
391            }
392            StreamedAssistantContent::ToolCall {
393                tool_call,
394                internal_call_id,
395            } => {
396                if !self.allowed_tool_names.contains(&tool_call.function.name) {
397                    let invalid = StreamedInvalidToolCall {
398                        tool_call: tool_call.clone(),
399                        internal_call_id: internal_call_id.clone(),
400                        args: Some(json_utils::value_to_json_string(
401                            &tool_call.function.arguments,
402                        )),
403                        executable_tool_names: self.executable_tool_names.clone(),
404                        allowed_tool_names: self.allowed_tool_names.clone(),
405                    };
406                    self.pending_invalid = Some(PendingInvalid::FullCall {
407                        tool_call: Box::new(tool_call.clone()),
408                        internal_call_id: internal_call_id.clone(),
409                    });
410                    return Ok(vec![StreamedTurnEvent::InvalidToolCall(Box::new(invalid))]);
411                }
412
413                self.pending_tool_calls
414                    .push((tool_call.clone(), internal_call_id.clone()));
415                Ok(Vec::new())
416            }
417            StreamedAssistantContent::ToolCallDelta {
418                id,
419                internal_call_id,
420                content,
421            } => {
422                let key = (id.clone(), internal_call_id.clone());
423                match content {
424                    ToolCallDeltaContent::Name(name) => {
425                        if !self.allowed_tool_names.contains(name) {
426                            let buffered_args = self
427                                .delta_states
428                                .get(&key)
429                                .map(|state| state.buffered_arguments.join(""))
430                                .unwrap_or_default();
431                            let invalid = StreamedInvalidToolCall {
432                                tool_call: self.name_delta_diagnostic_tool_call(
433                                    id,
434                                    name,
435                                    &buffered_args,
436                                ),
437                                internal_call_id: internal_call_id.clone(),
438                                args: Some(buffered_args),
439                                executable_tool_names: self.executable_tool_names.clone(),
440                                allowed_tool_names: self.allowed_tool_names.clone(),
441                            };
442                            self.pending_invalid = Some(PendingInvalid::NameDelta {
443                                id: id.clone(),
444                                internal_call_id: internal_call_id.clone(),
445                            });
446                            return Ok(vec![StreamedTurnEvent::InvalidToolCall(Box::new(invalid))]);
447                        }
448
449                        Ok(self.validate_delta_name(&key, name.clone()))
450                    }
451                    ToolCallDeltaContent::Delta(arguments) => {
452                        let state = self.delta_states.entry(key.clone()).or_default();
453                        if state.name_validated {
454                            Ok(vec![StreamedTurnEvent::EmitToolCallDelta {
455                                id: id.clone(),
456                                internal_call_id: internal_call_id.clone(),
457                                content: ToolCallDeltaContent::Delta(arguments.clone()),
458                            }])
459                        } else {
460                            state.buffered_arguments.push(arguments.clone());
461                            Ok(Vec::new())
462                        }
463                    }
464                }
465            }
466            StreamedAssistantContent::Final(final_response) => {
467                if let Some(err) = self.pending_delta_error() {
468                    return Err(err);
469                }
470
471                let usage = final_response.token_usage();
472                let emit_final = self.saw_text;
473                self.saw_text = false;
474                Ok(vec![StreamedTurnEvent::Completed { usage, emit_final }])
475            }
476        }
477    }
478
479    /// Apply the machine's resolution for the invalid tool call surfaced by
480    /// the last [`StreamedTurnEvent::InvalidToolCall`]. For a repaired name
481    /// this returns the deltas to forward (the repaired name plus any
482    /// buffered argument deltas).
483    pub fn resolve_pending_invalid(
484        &mut self,
485        resolution: &StreamedResolution,
486    ) -> Vec<StreamedTurnEvent> {
487        let Some(pending) = self.pending_invalid.take() else {
488            return Vec::new();
489        };
490
491        match (resolution, pending) {
492            (
493                StreamedResolution::Repaired { tool_name },
494                PendingInvalid::FullCall {
495                    mut tool_call,
496                    internal_call_id,
497                },
498            ) => {
499                tool_call.function.name = tool_name.clone();
500                self.pending_tool_calls.push((*tool_call, internal_call_id));
501                Vec::new()
502            }
503            (
504                StreamedResolution::Repaired { tool_name },
505                PendingInvalid::NameDelta {
506                    id,
507                    internal_call_id,
508                },
509            ) => {
510                let key = (id, internal_call_id);
511                self.validate_delta_name(&key, tool_name.clone())
512            }
513            (
514                StreamedResolution::TurnAbandoned { .. },
515                PendingInvalid::NameDelta {
516                    id,
517                    internal_call_id,
518                },
519            ) => {
520                // The abandoned call's buffered state must not trip the
521                // pending-delta consistency check while usage is drained.
522                self.delta_states.remove(&(id, internal_call_id));
523                Vec::new()
524            }
525            (StreamedResolution::TurnAbandoned { .. }, PendingInvalid::FullCall { .. }) => {
526                Vec::new()
527            }
528        }
529    }
530
531    /// Error when argument deltas were buffered for a tool call whose name
532    /// never validated — a provider-stream consistency violation.
533    pub fn pending_delta_error(&self) -> Option<CompletionError> {
534        self.delta_states
535            .iter()
536            .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
537            .map(|((id, internal_call_id), state)| {
538                CompletionError::ResponseError(format!(
539                    "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
540                    state.buffered_arguments.len()
541                ))
542            })
543    }
544
545    /// Snapshot of the turn so far, for diagnostics and rollback messages.
546    pub fn partial_turn(&self, message_id: Option<String>) -> PartialStreamedTurn {
547        let mut reasoning = self.accumulated_reasoning.clone();
548        if reasoning.is_empty() && !self.pending_reasoning_delta_text.is_empty() {
549            let mut assembled = Reasoning::new(&self.pending_reasoning_delta_text);
550            if let Some(id) = self.pending_reasoning_delta_id.clone() {
551                assembled = assembled.with_id(id);
552            }
553            reasoning.push(assembled);
554        }
555
556        PartialStreamedTurn {
557            message_id,
558            text: self.saw_text.then(|| self.text.clone()),
559            reasoning,
560            pending_tool_calls: self
561                .pending_tool_calls
562                .iter()
563                .map(|(tool_call, _)| tool_call.clone())
564                .collect(),
565        }
566    }
567
568    /// Assemble the completed turn. `final_choice` is the provider's
569    /// aggregated choice for the turn
570    /// ([`crate::streaming::StreamingCompletionResponse::choice`]).
571    pub fn finish(
572        mut self,
573        message_id: Option<String>,
574        final_choice: &OneOrMany<AssistantContent>,
575    ) -> StreamedTurn {
576        let internal_call_ids: Vec<(String, String)> = self
577            .pending_tool_calls
578            .iter()
579            .map(|(tool_call, internal_call_id)| (tool_call.id.clone(), internal_call_id.clone()))
580            .collect();
581        // Providers like Gemini emit thinking as incremental deltas without
582        // signatures; assemble them into a single block so reasoning survives
583        // into the next turn's chat history.
584        if self.accumulated_reasoning.is_empty() && !self.pending_reasoning_delta_text.is_empty() {
585            let mut assembled = Reasoning::new(&self.pending_reasoning_delta_text);
586            if let Some(id) = self.pending_reasoning_delta_id.take() {
587                assembled = assembled.with_id(id);
588            }
589            self.accumulated_reasoning.push(assembled);
590        }
591
592        // Canonical replay order when the turn produced reasoning or tool
593        // calls; otherwise the provider's aggregated choice is recorded as-is.
594        let choice =
595            if !self.pending_tool_calls.is_empty() || !self.accumulated_reasoning.is_empty() {
596                let text_items = assistant_text_items_from_choice(final_choice);
597                let tool_items = self
598                    .pending_tool_calls
599                    .iter()
600                    .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone()))
601                    .collect::<Vec<_>>();
602                ordered_streaming_assistant_content(
603                    self.accumulated_reasoning.drain(..),
604                    text_items,
605                    tool_items,
606                )
607                .unwrap_or_else(|| final_choice.clone())
608            } else {
609                final_choice.clone()
610            };
611
612        StreamedTurn {
613            message_id,
614            choice,
615            executable_tool_names: self.executable_tool_names,
616            allowed_tool_names: self.allowed_tool_names,
617            internal_call_ids,
618        }
619    }
620
621    fn name_delta_diagnostic_tool_call(
622        &self,
623        id: &str,
624        name: &str,
625        buffered_args: &str,
626    ) -> ToolCall {
627        let diagnostic_args = if buffered_args.trim().is_empty() {
628            serde_json::Value::Null
629        } else {
630            serde_json::from_str(buffered_args).unwrap_or(serde_json::Value::Null)
631        };
632        ToolCall::new(
633            id.to_string(),
634            ToolFunction::new(name.to_string(), diagnostic_args),
635        )
636    }
637
638    fn validate_delta_name(
639        &mut self,
640        key: &(String, String),
641        name: String,
642    ) -> Vec<StreamedTurnEvent> {
643        let state = self.delta_states.entry(key.clone()).or_default();
644        state.name_validated = true;
645        let buffered_arguments = std::mem::take(&mut state.buffered_arguments);
646
647        let mut events = vec![StreamedTurnEvent::EmitToolCallDelta {
648            id: key.0.clone(),
649            internal_call_id: key.1.clone(),
650            content: ToolCallDeltaContent::Name(name),
651        }];
652        events.extend(buffered_arguments.into_iter().map(|arguments| {
653            StreamedTurnEvent::EmitToolCallDelta {
654                id: key.0.clone(),
655                internal_call_id: key.1.clone(),
656                content: ToolCallDeltaContent::Delta(arguments),
657            }
658        }));
659        events
660    }
661}
662
663#[cfg(test)]
664mod tests {
665    use super::*;
666    use crate::agent::prompt_request::hooks::InvalidToolCallHookAction;
667    use crate::agent::run::{AgentRun, AgentRunStep};
668    use crate::completion::PromptError;
669    use crate::message::{Text, ToolResultContent, UserContent};
670    use crate::test_utils::MockResponse;
671    use serde_json::json;
672
673    fn tool_names(names: &[&str]) -> BTreeSet<String> {
674        names.iter().map(|name| (*name).to_string()).collect()
675    }
676
677    fn assembler() -> StreamedTurnAssembler {
678        StreamedTurnAssembler::new(tool_names(&["add"]), tool_names(&["add"]))
679    }
680
681    fn text_item(text: &str) -> StreamedAssistantContent<MockResponse> {
682        StreamedAssistantContent::Text(Text::new(text.to_string()))
683    }
684
685    fn tool_call(id: &str, name: &str) -> ToolCall {
686        ToolCall::new(
687            id.to_string(),
688            ToolFunction::new(name.to_string(), json!({"x": 1})),
689        )
690    }
691
692    fn tool_call_item(id: &str, name: &str) -> StreamedAssistantContent<MockResponse> {
693        StreamedAssistantContent::ToolCall {
694            tool_call: tool_call(id, name),
695            internal_call_id: format!("internal_{id}"),
696        }
697    }
698
699    fn final_item() -> StreamedAssistantContent<MockResponse> {
700        StreamedAssistantContent::Final(MockResponse::with_usage(Usage::new()))
701    }
702
703    fn name_delta(id: &str, name: &str) -> StreamedAssistantContent<MockResponse> {
704        StreamedAssistantContent::ToolCallDelta {
705            id: id.to_string(),
706            internal_call_id: format!("internal_{id}"),
707            content: ToolCallDeltaContent::Name(name.to_string()),
708        }
709    }
710
711    fn args_delta(id: &str, arguments: &str) -> StreamedAssistantContent<MockResponse> {
712        StreamedAssistantContent::ToolCallDelta {
713            id: id.to_string(),
714            internal_call_id: format!("internal_{id}"),
715            content: ToolCallDeltaContent::Delta(arguments.to_string()),
716        }
717    }
718
719    fn expect_invalid(events: Vec<StreamedTurnEvent>) -> StreamedInvalidToolCall {
720        match events.into_iter().next() {
721            Some(StreamedTurnEvent::InvalidToolCall(invalid)) => *invalid,
722            other => panic!("expected InvalidToolCall, got {other:?}"),
723        }
724    }
725
726    #[test]
727    fn text_accumulates_and_emits() {
728        let mut asm = assembler();
729        let events = asm
730            .ingest(&text_item("hel"))
731            .expect("ingest should succeed");
732        assert!(matches!(
733            events.as_slice(),
734            [StreamedTurnEvent::EmitIngested]
735        ));
736        asm.ingest(&text_item("lo")).expect("ingest should succeed");
737        assert_eq!(asm.aggregated_text(), "hello");
738    }
739
740    #[test]
741    fn argument_deltas_buffer_until_name_validates() {
742        let mut asm = assembler();
743
744        let events = asm
745            .ingest(&args_delta("tc_1", "{\"x\""))
746            .expect("ingest should succeed");
747        assert!(events.is_empty(), "arguments must buffer before the name");
748
749        let events = asm
750            .ingest(&name_delta("tc_1", "add"))
751            .expect("ingest should succeed");
752        let contents: Vec<_> = events
753            .iter()
754            .map(|event| match event {
755                StreamedTurnEvent::EmitToolCallDelta { content, .. } => content.clone(),
756                other => panic!("expected EmitToolCallDelta, got {other:?}"),
757            })
758            .collect();
759        assert_eq!(
760            contents,
761            vec![
762                ToolCallDeltaContent::Name("add".to_string()),
763                ToolCallDeltaContent::Delta("{\"x\"".to_string()),
764            ]
765        );
766
767        // Subsequent argument deltas now pass straight through.
768        let events = asm
769            .ingest(&args_delta("tc_1", ":1}"))
770            .expect("ingest should succeed");
771        assert_eq!(events.len(), 1);
772    }
773
774    #[test]
775    fn buffered_arguments_without_validated_name_error_at_final() {
776        let mut asm = assembler();
777        asm.ingest(&args_delta("tc_1", "{\"x\":1}"))
778            .expect("ingest should succeed");
779
780        assert!(asm.pending_delta_error().is_some());
781        assert!(asm.ingest(&final_item()).is_err());
782    }
783
784    #[test]
785    fn finish_orders_reasoning_text_then_tool_calls() {
786        let mut asm = assembler();
787        asm.ingest(&StreamedAssistantContent::<MockResponse>::ReasoningDelta {
788            id: Some("rs_1".to_string()),
789            reasoning: "think".to_string(),
790        })
791        .expect("ingest should succeed");
792        asm.ingest(&tool_call_item("tc_1", "add"))
793            .expect("ingest should succeed");
794
795        // Provider aggregation order differs deliberately.
796        let final_choice = OneOrMany::many(vec![
797            AssistantContent::text("answer"),
798            AssistantContent::ToolCall(tool_call("tc_1", "add")),
799        ])
800        .expect("two items");
801
802        let turn = asm.finish(Some("msg_1".to_string()), &final_choice);
803        let kinds: Vec<&'static str> = turn
804            .choice
805            .iter()
806            .map(|item| match item {
807                AssistantContent::Reasoning(_) => "reasoning",
808                AssistantContent::Text(_) => "text",
809                AssistantContent::ToolCall(_) => "tool_call",
810                _ => "other",
811            })
812            .collect();
813        assert_eq!(kinds, vec!["reasoning", "text", "tool_call"]);
814    }
815
816    #[test]
817    fn finish_passes_raw_choice_through_for_plain_text_turns() {
818        let mut asm = assembler();
819        asm.ingest(&text_item("hi")).expect("ingest should succeed");
820
821        let final_choice = OneOrMany::one(AssistantContent::text("hi"));
822        let turn = asm.finish(None, &final_choice);
823        assert_eq!(
824            serde_json::to_value(&turn.choice).expect("serialize"),
825            serde_json::to_value(&final_choice).expect("serialize"),
826        );
827    }
828
    /// End-to-end streamed run over two turns:
    ///
    /// 1. A turn that streams one `add` tool call.
    /// 2. The tool round-trip for that call.
    /// 3. A plain-text turn that finishes the run.
    ///
    /// The test also checks the per-call usage bookkeeping and the final
    /// message count.
    #[test]
    fn streamed_run_completes_a_tool_roundtrip() {
        let mut run = AgentRun::new("add things").max_turns(2);

        // Turn 1: the model streams one tool call.
        let AgentRunStep::CallModel { .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel");
        };
        let mut asm = assembler();
        // Ingesting a complete, valid `add` tool call yields no events.
        assert!(
            asm.ingest(&tool_call_item("tc_1", "add"))
                .expect("ingest should succeed")
                .is_empty()
        );
        let usage = Usage {
            input_tokens: 5,
            output_tokens: 7,
            total_tokens: 12,
            ..Usage::new()
        };
        // The driver records this turn's usage before delivering the turn.
        run.record_streamed_completion_call(usage)
            .expect("record should succeed");
        let final_choice = OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "add")));
        run.streamed_turn(asm.finish(Some("msg_1".to_string()), &final_choice))
            .expect("streamed_turn should succeed");

        // The streamed call surfaces as CallTools and keeps the internal id
        // that came with the stream item.
        let AgentRunStep::CallTools { calls } = run.next_step().expect("next_step") else {
            panic!("expected CallTools");
        };
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].internal_call_id.as_deref(), Some("internal_tc_1"));
        run.tool_results(vec![UserContent::tool_result(
            "tc_1".to_string(),
            ToolResultContent::from_tool_output("2".to_string()),
        )])
        .expect("tool_results should succeed");

        // Turn 2: plain text finishes the run.
        let AgentRunStep::CallModel { .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel");
        };
        let asm = assembler();
        run.record_streamed_completion_call(Usage::new())
            .expect("record should succeed");
        let final_choice = OneOrMany::one(AssistantContent::text("done"));
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");

        // The run's usage equals turn 1's, because turn 2 recorded an empty
        // `Usage`. There is one completion call per model call.
        let AgentRunStep::Done(response) = run.next_step().expect("next_step") else {
            panic!("expected Done");
        };
        assert_eq!(response.output, "done");
        assert_eq!(response.usage, usage);
        assert_eq!(response.completion_calls.len(), 2);
        assert_eq!(response.completion_calls[0].usage, usage);
        assert_eq!(response.completion_calls[1].usage, Usage::new());
        // prompt, assistant tool call, tool result, final assistant text
        assert_eq!(
            response
                .messages
                .expect("messages should be recorded")
                .len(),
            4
        );
    }
894
    /// A retry resolution for an invalid streamed tool call abandons the turn.
    /// It keeps the text streamed so far as a partial assistant turn and sends
    /// the run back to `CallModel` for turn 2.
    #[test]
    fn streamed_invalid_tool_call_retry_rolls_back_with_partial_turn() {
        let mut run = AgentRun::new("use the tool")
            .max_turns(2)
            .max_invalid_tool_call_retries(1);
        run.next_step().expect("next_step");

        // Some text streams in, then a call to `default_api`, which the
        // assembler flags as invalid.
        let mut asm = assembler();
        asm.ingest(&text_item("thinking ")).expect("ingest");
        let invalid = expect_invalid(
            asm.ingest(&tool_call_item("tc_1", "default_api"))
                .expect("ingest should succeed"),
        );
        // The partial turn captures the text streamed before the invalid call.
        let partial = asm.partial_turn(Some("msg_1".to_string()));
        assert_eq!(partial.text.as_deref(), Some("thinking "));

        // The hook context is marked as streaming and identifies the
        // offending call.
        let context = run.streamed_invalid_tool_call_context(&partial, &invalid);
        assert!(context.is_streaming);
        assert_eq!(context.tool_name, "default_api");
        assert_eq!(context.internal_call_id.as_deref(), Some("internal_tc_1"));

        // Retrying abandons the turn without producing a synthetic tool result.
        let resolution = run
            .resolve_streamed_invalid_tool_call(
                &partial,
                &invalid,
                InvalidToolCallHookAction::retry("use add instead"),
            )
            .expect("retry should be accepted");
        assert!(matches!(
            resolution,
            StreamedResolution::TurnAbandoned {
                skipped_tool_result: None
            }
        ));
        // Inform the assembler how its pending invalid call was resolved.
        asm.resolve_pending_invalid(&resolution);

        // Usage from the drained stream is recorded after the rollback.
        run.record_streamed_completion_call(Usage::new())
            .expect("record after rollback should succeed");

        // The rollback appended the partial assistant turn and feedback.
        assert_eq!(run.messages().len(), 3);
        let AgentRunStep::CallModel { turn, .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel retry");
        };
        assert_eq!(turn, 2);
    }
942
943    #[test]
944    fn streamed_invalid_tool_call_skip_returns_synthetic_result() {
945        let mut run = AgentRun::new("use the tool").max_turns(2);
946        run.next_step().expect("next_step");
947
948        let mut asm = assembler();
949        let invalid = expect_invalid(
950            asm.ingest(&tool_call_item("tc_1", "default_api"))
951                .expect("ingest should succeed"),
952        );
953        let partial = asm.partial_turn(None);
954
955        let resolution = run
956            .resolve_streamed_invalid_tool_call(
957                &partial,
958                &invalid,
959                InvalidToolCallHookAction::skip("not available"),
960            )
961            .expect("skip should be accepted");
962        let StreamedResolution::TurnAbandoned {
963            skipped_tool_result: Some(tool_result),
964        } = &resolution
965        else {
966            panic!("expected skipped tool result");
967        };
968        assert_eq!(tool_result.id, "tc_1");
969    }
970
    /// An invalid tool name that arrives as a delta, after arguments were
    /// buffered, can be repaired in place. The repaired name and the buffered
    /// arguments are then replayed as deltas on the same stream.
    #[test]
    fn streamed_invalid_name_delta_repair_replays_buffered_arguments() {
        let mut run = AgentRun::new("use the tool").max_turns(2);
        run.next_step().expect("next_step");

        // Arguments first (buffered), then an invalid name.
        let mut asm = assembler();
        asm.ingest(&args_delta("tc_1", "{\"x\":1}"))
            .expect("ingest should succeed");
        let invalid = expect_invalid(
            asm.ingest(&name_delta("tc_1", "default_api"))
                .expect("ingest should succeed"),
        );
        // The invalid-call report carries the arguments buffered so far.
        assert_eq!(invalid.args.as_deref(), Some("{\"x\":1}"));

        let partial = asm.partial_turn(None);
        let resolution = run
            .resolve_streamed_invalid_tool_call(
                &partial,
                &invalid,
                InvalidToolCallHookAction::repair("add"),
            )
            .expect("repair should be accepted");
        assert!(matches!(
            resolution,
            StreamedResolution::Repaired { ref tool_name } if tool_name == "add"
        ));

        // Applying the repair emits the new name, then the buffered arguments.
        let events = asm.resolve_pending_invalid(&resolution);
        let contents: Vec<_> = events
            .iter()
            .map(|event| match event {
                StreamedTurnEvent::EmitToolCallDelta { content, .. } => content.clone(),
                other => panic!("expected EmitToolCallDelta, got {other:?}"),
            })
            .collect();
        assert_eq!(
            contents,
            vec![
                ToolCallDeltaContent::Name("add".to_string()),
                ToolCallDeltaContent::Delta("{\"x\":1}".to_string()),
            ]
        );
    }
1014
1015    #[test]
1016    fn streamed_turn_rejects_unknown_tool_calls_fail_fast() {
1017        let mut run = AgentRun::new("use the tool");
1018        run.next_step().expect("next_step");
1019
1020        let turn = StreamedTurn {
1021            message_id: None,
1022            choice: OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "unknown"))),
1023            executable_tool_names: tool_names(&["add"]),
1024            allowed_tool_names: tool_names(&["add"]),
1025            internal_call_ids: Vec::new(),
1026        };
1027        let err = run
1028            .streamed_turn(turn)
1029            .expect_err("unknown tool should fail fast");
1030        assert!(matches!(
1031            err,
1032            PromptError::UnknownToolCall { tool_name, .. } if tool_name == "unknown"
1033        ));
1034    }
1035
    /// `record_streamed_completion_call` is rejected until a `CallModel` step
    /// has been emitted. The rejection leaves the run usable.
    #[test]
    fn streamed_completion_call_record_requires_a_model_call() {
        // A fresh run has emitted no CallModel: recording must be rejected
        // even though the machine is in its initial PreparingRequest state.
        let mut run = AgentRun::new("hello");
        let err = run
            .record_streamed_completion_call(Usage::new())
            .expect_err("recording before any model call must be rejected");
        assert!(matches!(err, PromptError::PromptCancelled { .. }));

        // The run stays drivable.
        run.next_step().expect("next_step should still succeed");
        run.record_streamed_completion_call(Usage::new())
            .expect("recording during a pending model call succeeds");
    }
1051
    /// Two streamed tool calls share the provider id `tc_1` but carry
    /// different internal ids. Both internal ids must reach `CallTools`, in
    /// order, even after a serde round trip of the run.
    #[test]
    fn duplicate_tool_call_ids_keep_distinct_internal_ids_through_the_run() {
        let mut run = AgentRun::new("do both").max_turns(2);
        run.next_step().expect("next_step");

        // Same provider id, distinct internal ids.
        let mut asm = assembler();
        asm.ingest(&StreamedAssistantContent::<MockResponse>::ToolCall {
            tool_call: tool_call("tc_1", "add"),
            internal_call_id: "internal_a".to_string(),
        })
        .expect("ingest should succeed");
        asm.ingest(&StreamedAssistantContent::<MockResponse>::ToolCall {
            tool_call: tool_call("tc_1", "add"),
            internal_call_id: "internal_b".to_string(),
        })
        .expect("ingest should succeed");
        run.record_streamed_completion_call(Usage::new())
            .expect("record should succeed");

        let final_choice = OneOrMany::many(vec![
            AssistantContent::ToolCall(tool_call("tc_1", "add")),
            AssistantContent::ToolCall(tool_call("tc_1", "add")),
        ])
        .expect("two items");
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");

        // The internal IDs survive in the run state itself: a serde round
        // trip must keep both calls distinguishable.
        let serialized = serde_json::to_string(&run).expect("serialize");
        let mut restored: AgentRun = serde_json::from_str(&serialized).expect("deserialize");
        let AgentRunStep::CallTools { calls } = restored.next_step().expect("next_step") else {
            panic!("expected CallTools");
        };
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].internal_call_id.as_deref(), Some("internal_a"));
        assert_eq!(calls[1].internal_call_id.as_deref(), Some("internal_b"));
    }
1090
1091    #[test]
1092    fn streamed_turn_records_the_completion_call_when_the_driver_did_not() {
1093        let mut run = AgentRun::new("hello");
1094        run.next_step().expect("next_step");
1095
1096        let asm = assembler();
1097        let final_choice = OneOrMany::one(AssistantContent::text("done"));
1098        run.streamed_turn(asm.finish(None, &final_choice))
1099            .expect("streamed_turn should succeed");
1100
1101        // Exactly one CompletionCall per model call, even without an explicit
1102        // record; usage is simply unreported.
1103        assert_eq!(run.completion_calls().len(), 1);
1104        assert_eq!(run.completion_calls()[0].usage, Usage::new());
1105    }
1106
1107    #[test]
1108    fn streamed_completion_call_is_recorded_once_per_turn() {
1109        let mut run = AgentRun::new("hello");
1110        run.next_step().expect("next_step");
1111
1112        run.record_streamed_completion_call(Usage::new())
1113            .expect("first record succeeds");
1114        let err = run
1115            .record_streamed_completion_call(Usage::new())
1116            .expect_err("second record for the same turn must be rejected");
1117        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1118        assert_eq!(run.completion_calls().len(), 1);
1119    }
1120
1121    #[test]
1122    fn streamed_run_serde_round_trips_while_tools_pend() {
1123        let mut run = AgentRun::new("add things").max_turns(2);
1124        run.next_step().expect("next_step");
1125
1126        let mut asm = assembler();
1127        asm.ingest(&tool_call_item("tc_1", "add"))
1128            .expect("ingest should succeed");
1129        run.record_streamed_completion_call(Usage::new())
1130            .expect("record should succeed");
1131        let final_choice = OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "add")));
1132        run.streamed_turn(asm.finish(None, &final_choice))
1133            .expect("streamed_turn should succeed");
1134        run.next_step().expect("CallTools step");
1135
1136        let serialized = serde_json::to_string(&run).expect("serialize mid-run");
1137        let mut restored: AgentRun =
1138            serde_json::from_str(&serialized).expect("deserialize mid-run");
1139        restored
1140            .tool_results(vec![UserContent::tool_result(
1141                "tc_1".to_string(),
1142                ToolResultContent::from_tool_output("2".to_string()),
1143            )])
1144            .expect("tool_results should succeed");
1145        assert!(matches!(
1146            restored.next_step().expect("next turn"),
1147            AgentRunStep::CallModel { turn: 2, .. }
1148        ));
1149    }
1150}