// llm_stack/tool/loop_stream.rs

use std::collections::VecDeque;
14use std::sync::Arc;
15
16use futures::StreamExt;
17
18use crate::chat::{ChatResponse, ContentBlock, StopReason, ToolCall};
19use crate::error::LlmError;
20use crate::provider::{ChatParams, DynProvider};
21use crate::stream::{ChatStream, StreamEvent};
22use crate::usage::Usage;
23
24use super::LoopDepth;
25use super::ToolRegistry;
26use super::config::{LoopEvent, LoopStream, ToolLoopConfig};
27use super::loop_core::{IterationOutcome, LoopCore, StartOutcome};
28
/// Runs the tool-calling loop as a [`LoopStream`] of incremental [`LoopEvent`]s.
///
/// The loop alternates between two kinds of work until a terminal state is
/// reached:
/// 1. start a provider chat call and forward its stream events, accumulating
///    text, tool calls, and usage along the way;
/// 2. once the provider signals `Done`, hand the assembled response to
///    `LoopCore::finish_iteration`, which may execute tools and trigger
///    another iteration.
///
/// Loop-level events buffered by `LoopCore` (drained via `load_core_events`)
/// are interleaved into the output ahead of further stream events.
///
/// # Parameters
/// - `provider`: the LLM backend the loop calls each iteration.
/// - `registry`: tool implementations available for execution.
/// - `params`: chat parameters for the provider calls.
/// - `config`: loop behavior configuration (see [`ToolLoopConfig`]).
/// - `ctx`: caller context; `LoopDepth` presumably tracks nesting — confirm
///   against `LoopCore::new`.
// Arc params are moved into the unfold state below, so pass-by-value is intentional.
#[allow(clippy::needless_pass_by_value)]
pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
    provider: Arc<dyn DynProvider>,
    registry: Arc<ToolRegistry<Ctx>>,
    params: ChatParams,
    config: ToolLoopConfig,
    ctx: Arc<Ctx>,
) -> LoopStream {
    let core = LoopCore::new(params, config, &*ctx);

    let state = UnfoldState {
        core,
        provider,
        registry,
        phase: StreamPhase::StartIteration,
        current_text: String::new(),
        current_tool_calls: Vec::new(),
        current_usage: Usage::default(),
        pending_events: VecDeque::new(),
    };

    let stream = futures::stream::unfold(state, |mut state| async move {
        // Inner loop: advance the state machine until an event can be yielded
        // (or the stream is finished). Phases that produce no event fall
        // through to the next loop pass.
        loop {
            // Buffered events always drain first, preserving their order
            // relative to freshly produced stream events.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((event, state));
            }

            // Take ownership of the phase, parking `Done` in its place.
            // Any branch that does not explicitly set a new phase therefore
            // leaves the machine terminal — this is relied on below.
            match std::mem::replace(&mut state.phase, StreamPhase::Done) {
                StreamPhase::Done => return None,

                StreamPhase::StartIteration => {
                    match state.core.start_iteration(&*state.provider).await {
                        StartOutcome::Stream(s) => {
                            // Fresh iteration: reset the per-iteration accumulators.
                            state.current_text.clear();
                            state.current_tool_calls.clear();
                            state.current_usage = Usage::default();
                            state.load_core_events();
                            state.phase = StreamPhase::Streaming(s);
                        }
                        StartOutcome::Terminal(outcome) => {
                            state.load_core_events();
                            if let Some(event) = outcome_to_error(*outcome) {
                                state.phase = StreamPhase::Done;
                                state.pending_events.push_back(event);
                            }
                            // On a non-error terminal outcome the phase stays
                            // `Done` (installed by mem::replace above): any
                            // drained core events flush, then the stream ends.
                        }
                    }
                }

                StreamPhase::Streaming(mut stream) => match stream.next().await {
                    Some(Ok(event)) => {
                        // Accumulate the pieces needed to rebuild the full
                        // response once the provider stream completes.
                        if let StreamEvent::TextDelta(ref t) = event {
                            state.current_text.push_str(t);
                        }
                        if let StreamEvent::ToolCallComplete { ref call, .. } = event {
                            state.current_tool_calls.push(call.clone());
                        }
                        if let StreamEvent::Usage(ref u) = event {
                            state.current_usage += u;
                        }

                        // Check `Done` before `translate_stream_event`
                        // consumes the event.
                        let is_done = matches!(&event, StreamEvent::Done { .. });
                        let loop_event = translate_stream_event(event);

                        if is_done {
                            // Provider finished; tool execution comes next.
                            state.phase = StreamPhase::ExecutingTools;
                        } else {
                            // Re-arm the streaming phase with the same stream.
                            state.phase = StreamPhase::Streaming(stream);
                        }

                        // `Done` translates to None, so we loop around into
                        // ExecutingTools without yielding anything here.
                        if let Some(le) = loop_event {
                            return Some((Ok(le), state));
                        }
                    }
                    Some(Err(e)) => {
                        // Provider stream error is fatal: surface it and stop.
                        state.phase = StreamPhase::Done;
                        return Some((Err(e), state));
                    }
                    None => {
                        // Stream ended without a `Done` event.
                        // NOTE(review): this silently ends the loop stream and
                        // drops any accumulated text/tool calls without running
                        // finish_iteration — confirm providers always emit
                        // `Done` before closing.
                        return None;
                    }
                },

                StreamPhase::ExecutingTools => {
                    // Rebuild a ChatResponse from the accumulated pieces and
                    // let the core run any requested tools.
                    let response = build_response(
                        &state.current_text,
                        &state.current_tool_calls,
                        std::mem::take(&mut state.current_usage),
                    );
                    let outcome = state.core.finish_iteration(response, &state.registry).await;

                    state.load_core_events();

                    match outcome {
                        IterationOutcome::ToolsExecuted { .. } => {
                            // Tools ran; go around for another provider call.
                            state.phase = StreamPhase::StartIteration;
                        }
                        IterationOutcome::Completed(_) => {
                            // Model ended its turn; drain remaining events, then stop.
                            state.phase = StreamPhase::Done;
                        }
                        IterationOutcome::Error(data) => {
                            // Queue the error so buffered core events flush first.
                            state.phase = StreamPhase::Done;
                            state.pending_events.push_back(Err(data.error));
                        }
                    }
                }
            }
        }
    });

    Box::pin(stream)
}
170
/// Phase of the state machine driving the `unfold` in [`tool_loop_stream`].
enum StreamPhase {
    /// Ask `LoopCore` to begin the next provider call.
    StartIteration,
    /// Forwarding events from an in-flight provider stream.
    Streaming(ChatStream),
    /// Provider stream finished; assemble the response and run tools.
    ExecutingTools,
    /// Terminal: only queued `pending_events` remain, then the stream ends.
    Done,
}
178
/// Mutable state threaded through `futures::stream::unfold` by
/// [`tool_loop_stream`]; one instance lives for the whole loop.
struct UnfoldState<Ctx: LoopDepth + Send + Sync + 'static> {
    /// Loop orchestration: iteration lifecycle and buffered loop events.
    core: LoopCore<Ctx>,
    /// Provider called at the start of each iteration.
    provider: Arc<dyn DynProvider>,
    /// Tools available to `finish_iteration`.
    registry: Arc<ToolRegistry<Ctx>>,
    /// Current position in the state machine.
    phase: StreamPhase,
    /// Text accumulated from `TextDelta` events this iteration.
    current_text: String,
    /// Tool calls accumulated from `ToolCallComplete` events this iteration.
    current_tool_calls: Vec<ToolCall>,
    /// Usage summed from `Usage` events this iteration.
    current_usage: Usage,
    /// Events queued for emission ahead of new stream output.
    pending_events: VecDeque<Result<LoopEvent, LlmError>>,
}
191
192impl<Ctx: LoopDepth + Send + Sync + 'static> UnfoldState<Ctx> {
193 fn load_core_events(&mut self) {
195 for event in self.core.drain_events() {
196 self.pending_events.push_back(Ok(event));
197 }
198 }
199}
200
201fn translate_stream_event(event: StreamEvent) -> Option<LoopEvent> {
206 match event {
207 StreamEvent::TextDelta(t) => Some(LoopEvent::TextDelta(t)),
208 StreamEvent::ReasoningDelta(t) => Some(LoopEvent::ReasoningDelta(t)),
209 StreamEvent::ToolCallStart { index, id, name } => {
210 Some(LoopEvent::ToolCallStart { index, id, name })
211 }
212 StreamEvent::ToolCallDelta { index, json_chunk } => {
213 Some(LoopEvent::ToolCallDelta { index, json_chunk })
214 }
215 StreamEvent::ToolCallComplete { index, call } => {
216 Some(LoopEvent::ToolCallComplete { index, call })
217 }
218 StreamEvent::Usage(u) => Some(LoopEvent::Usage(u)),
219 StreamEvent::Done { .. } => None, }
221}
222
223fn build_response(text: &str, tool_calls: &[ToolCall], usage: Usage) -> ChatResponse {
225 let mut content = Vec::new();
226 if !text.is_empty() {
227 content.push(ContentBlock::Text(text.to_owned()));
228 }
229 for call in tool_calls {
230 content.push(ContentBlock::ToolCall(call.clone()));
231 }
232
233 let stop_reason = if tool_calls.is_empty() {
234 StopReason::EndTurn
235 } else {
236 StopReason::ToolUse
237 };
238
239 ChatResponse {
240 content,
241 usage,
242 stop_reason,
243 model: String::new(),
244 metadata: std::collections::HashMap::new(),
245 }
246}
247
248fn outcome_to_error(outcome: IterationOutcome) -> Option<Result<LoopEvent, LlmError>> {
254 match outcome {
255 IterationOutcome::Error(data) => Some(Err(data.error)),
256 IterationOutcome::Completed(_) | IterationOutcome::ToolsExecuted { .. } => None,
258 }
259}