llm_stack_core/tool/loop_channel.rs
//! Channel-based tool loop with backpressure support.

use std::sync::Arc;

use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider};
use crate::stream::StreamEvent;

use super::LoopDepth;
use super::ToolRegistry;
use super::config::{ToolLoopConfig, ToolLoopResult};
use super::loop_stream::tool_loop_stream;
/// Channel-based tool loop with a bounded buffer for backpressure.
///
/// Unlike [`tool_loop_stream`], this function spawns an internal task and
/// sends events through a bounded channel. This provides natural backpressure:
/// if the consumer is slow, the producer waits when the buffer is full,
/// preventing unbounded memory growth.
///
/// Returns a tuple of:
/// - `Receiver<Result<StreamEvent, LlmError>>` - events from the stream
/// - `JoinHandle<ToolLoopResult>` - the final result (join to get it)
///
/// Note that the joined result currently carries placeholder iteration and
/// usage counters; the actual usage data arrives through `StreamEvent::Usage`
/// events on the channel.
///
/// # Backpressure
///
/// The `buffer_size` parameter controls how many events can be buffered before
/// the producer has to wait. Choose based on your use case:
/// - Small (4-16): Tight backpressure, minimal memory
/// - Medium (32-64): Balance between latency and memory
/// - Large (128+): More latency tolerance, higher memory
///
/// # Consumer Drop
///
/// If the receiver is dropped before the stream completes, the internal task
/// detects this (`send` returns an error) and terminates gracefully. The join
/// handle still returns a `ToolLoopResult`, though it may indicate partial
/// completion.
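///
/// A minimal sketch of that shutdown path (same setup as the example below,
/// with the receiver dropped early):
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use llm_stack_core::{ChatParams, ToolLoopConfig, ToolRegistry};
/// use llm_stack_core::tool::tool_loop_channel;
///
/// async fn shutdown_early(
///     provider: Arc<dyn llm_stack_core::DynProvider>,
///     registry: Arc<ToolRegistry<()>>,
/// ) -> Result<(), Box<dyn std::error::Error>> {
///     let (rx, handle) = tool_loop_channel(
///         provider,
///         registry,
///         ChatParams::default(),
///         ToolLoopConfig::default(),
///         Arc::new(()),
///         8,
///     );
///
///     drop(rx); // consumer goes away before the stream finishes
///     let _result = handle.await?; // the task still terminates and yields a result
///     Ok(())
/// }
/// ```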
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use llm_stack_core::{ChatParams, ChatMessage, ToolLoopConfig, ToolRegistry, StreamEvent};
/// use llm_stack_core::tool::tool_loop_channel;
///
/// async fn example(
///     provider: Arc<dyn llm_stack_core::DynProvider>,
///     registry: Arc<ToolRegistry<()>>,
/// ) -> Result<(), Box<dyn std::error::Error>> {
///     let params = ChatParams {
///         messages: vec![ChatMessage::user("Hello")],
///         ..Default::default()
///     };
///
///     let (mut rx, handle) = tool_loop_channel(
///         provider,
///         registry,
///         params,
///         ToolLoopConfig::default(),
///         Arc::new(()),
///         32, // buffer size
///     );
///
///     while let Some(event) = rx.recv().await {
///         match event? {
///             StreamEvent::TextDelta(text) => print!("{text}"),
///             StreamEvent::Done { .. } => break,
///             _ => {}
///         }
///     }
///
///     let result = handle.await?;
///     println!("\nCompleted in {} iterations", result.iterations);
///     Ok(())
/// }
/// ```
pub fn tool_loop_channel<Ctx: LoopDepth + Send + Sync + 'static>(
    provider: Arc<dyn DynProvider>,
    registry: Arc<ToolRegistry<Ctx>>,
    params: ChatParams,
    config: ToolLoopConfig,
    ctx: Arc<Ctx>,
    buffer_size: usize,
) -> (
    mpsc::Receiver<Result<StreamEvent, LlmError>>,
    JoinHandle<ToolLoopResult>,
) {
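    // Bounded channel: once `buffer_size` events are in flight, `send` waits
    // for the consumer to drain the buffer, which is what provides backpressure.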
    let (tx, rx) = mpsc::channel(buffer_size);

    let handle = tokio::spawn(async move {
        let mut stream = tool_loop_stream(provider, registry, params, config, ctx);

        while let Some(event) = stream.next().await {
            let is_done = matches!(&event, Ok(StreamEvent::Done { .. }));

            // Forward the event; this waits whenever the buffer is full.
            if tx.send(event).await.is_err() {
                // Consumer dropped the receiver - stop producing.
                break;
            }

            if is_done {
                // Stream is done, we can exit.
                break;
            }
        }

        // Return a minimal result indicating completion.
        // Note: accurate iteration/usage tracking would need deeper integration
        // with tool_loop_stream's internal state. For now, we return a
        // placeholder result. The events sent through the channel contain the
        // actual usage data via StreamEvent::Usage.
        ToolLoopResult {
            response: crate::chat::ChatResponse::empty(),
            iterations: 0,
            total_usage: crate::usage::Usage::default(),
            termination_reason: super::config::TerminationReason::Complete,
        }
    });

    (rx, handle)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, StopReason};
    use crate::test_helpers::mock_for;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_basic() {
        let mock = Arc::new(mock_for("test", "test-model"));

        // Queue a stream response (tool_loop_stream uses streaming)
        mock.queue_stream(vec![
            StreamEvent::TextDelta("Hello ".into()),
            StreamEvent::TextDelta("from channel!".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let registry: ToolRegistry<()> = ToolRegistry::new();
        let registry = Arc::new(registry);

        let params = ChatParams {
            messages: vec![ChatMessage::user("Hello")],
            ..Default::default()
        };

        let (mut rx, handle) = tool_loop_channel(
            mock,
            registry,
            params,
            ToolLoopConfig::default(),
            Arc::new(()),
            16,
        );

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have received events
        assert!(!events.is_empty());

        // Join handle should complete
        let result = handle.await.unwrap();
        assert!(matches!(
            result.termination_reason,
            super::super::config::TerminationReason::Complete
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_consumer_drop() {
        let mock = Arc::new(mock_for("test", "test-model"));

        // Queue a stream that completes
        mock.queue_stream(vec![
            StreamEvent::TextDelta("Hello".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let registry: ToolRegistry<()> = ToolRegistry::new();
        let registry = Arc::new(registry);

        let params = ChatParams {
            messages: vec![ChatMessage::user("Hello")],
            ..Default::default()
        };

        let (rx, handle) = tool_loop_channel(
            mock,
            registry,
            params,
            ToolLoopConfig::default(),
            Arc::new(()),
            2, // Small buffer
        );

        // Drop the receiver immediately
        drop(rx);

        // Handle should still complete (gracefully)
        let _result = handle.await.unwrap();
        // Result will be a default/empty one since we dropped early - just check it completes
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_backpressure() {
        let mock = Arc::new(mock_for("test", "test-model"));

        // Queue a stream response
        mock.queue_stream(vec![
            StreamEvent::TextDelta("Response".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let registry: ToolRegistry<()> = ToolRegistry::new();
        let registry = Arc::new(registry);

        let params = ChatParams {
            messages: vec![ChatMessage::user("Hello")],
            ..Default::default()
        };

        // Very small buffer to test backpressure behavior
        let (mut rx, handle) = tool_loop_channel(
            mock,
            registry,
            params,
            ToolLoopConfig::default(),
            Arc::new(()),
            1, // Minimal buffer
        );

        // Consume all events
        while let Some(_event) = rx.recv().await {}

        let result = handle.await.unwrap();
        assert!(matches!(
            result.termination_reason,
            super::super::config::TerminationReason::Complete
        ));
    }
}
257}