Skip to main content

llm_stack/tool/
loop_channel.rs

1//! Channel-based tool loop with backpressure support.
2
3use std::sync::Arc;
4
5use futures::StreamExt;
6use tokio::sync::mpsc;
7use tokio::task::JoinHandle;
8
9use crate::error::LlmError;
10use crate::provider::{ChatParams, DynProvider};
11use crate::stream::StreamEvent;
12
13use super::LoopDepth;
14use super::ToolRegistry;
15use super::config::{ToolLoopConfig, ToolLoopResult};
16use super::loop_stream::tool_loop_stream;
17
18/// Channel-based tool loop with bounded buffer for backpressure.
19///
20/// Unlike [`tool_loop_stream`], this function spawns an internal task and
21/// sends events through a bounded channel. This provides natural backpressure:
22/// if the consumer is slow, the producer blocks when the buffer is full,
23/// preventing unbounded memory growth.
24///
25/// Returns a tuple of:
26/// - `Receiver<Result<StreamEvent, LlmError>>` - events from the stream
27/// - `JoinHandle<ToolLoopResult>` - the final result (join to get it)
28///
29/// # Backpressure
30///
31/// The `buffer_size` parameter controls how many events can be buffered before
32/// the producer blocks. Choose based on your use case:
33/// - Small (4-16): Tight backpressure, minimal memory
34/// - Medium (32-64): Balance between latency and memory
35/// - Large (128+): More latency tolerance, higher memory
36///
37/// # Consumer Drop
38///
39/// If the receiver is dropped before the stream completes, the internal task
40/// will detect this (send returns error) and terminate gracefully. The join
41/// handle will still return a `ToolLoopResult`, though it may indicate partial
42/// completion.
43///
44/// # Example
45///
46/// ```rust,no_run
47/// use std::sync::Arc;
48/// use llm_stack::{ChatParams, ChatMessage, ToolLoopConfig, ToolRegistry, StreamEvent};
49/// use llm_stack::tool::tool_loop_channel;
50///
51/// async fn example(
52///     provider: Arc<dyn llm_stack::DynProvider>,
53///     registry: Arc<ToolRegistry<()>>,
54/// ) -> Result<(), Box<dyn std::error::Error>> {
55///     let params = ChatParams {
56///         messages: vec![ChatMessage::user("Hello")],
57///         ..Default::default()
58///     };
59///
60///     let (mut rx, handle) = tool_loop_channel(
61///         provider,
62///         registry,
63///         params,
64///         ToolLoopConfig::default(),
65///         Arc::new(()),
66///         32, // buffer size
67///     );
68///
69///     while let Some(event) = rx.recv().await {
70///         match event? {
71///             StreamEvent::TextDelta(text) => print!("{text}"),
72///             StreamEvent::Done { .. } => break,
73///             _ => {}
74///         }
75///     }
76///
77///     let result = handle.await?;
78///     println!("\nCompleted in {} iterations", result.iterations);
79///     Ok(())
80/// }
81/// ```
82pub fn tool_loop_channel<Ctx: LoopDepth + Send + Sync + 'static>(
83    provider: Arc<dyn DynProvider>,
84    registry: Arc<ToolRegistry<Ctx>>,
85    params: ChatParams,
86    config: ToolLoopConfig,
87    ctx: Arc<Ctx>,
88    buffer_size: usize,
89) -> (
90    mpsc::Receiver<Result<StreamEvent, LlmError>>,
91    JoinHandle<ToolLoopResult>,
92) {
93    let (tx, rx) = mpsc::channel(buffer_size);
94
95    let handle = tokio::spawn(async move {
96        let mut stream = tool_loop_stream(provider, registry, params, config, ctx);
97
98        while let Some(event) = stream.next().await {
99            let is_done = matches!(&event, Ok(StreamEvent::Done { .. }));
100
101            // Try to send the event
102            if tx.send(event).await.is_err() {
103                // Consumer dropped - break out of loop
104                break;
105            }
106
107            if is_done {
108                // Stream is done, we can exit
109                break;
110            }
111        }
112
113        // Return a minimal result indicating completion.
114        // Note: The actual iteration/usage tracking would need deeper
115        // integration with tool_loop_stream's internal state. For now,
116        // we return a placeholder result. The events sent through the
117        // channel contain the actual usage data via StreamEvent::Usage.
118        ToolLoopResult {
119            response: crate::chat::ChatResponse::empty(),
120            iterations: 0,
121            total_usage: crate::usage::Usage::default(),
122            termination_reason: super::config::TerminationReason::Complete,
123        }
124    });
125
126    (rx, handle)
127}
128
#[cfg(test)]
mod tests {
    use super::super::config::TerminationReason;
    use super::*;
    use crate::chat::{ChatMessage, StopReason};
    use crate::test_helpers::mock_for;

    /// Queues `script` as one streamed mock response, then starts
    /// `tool_loop_channel` over it. Uses an empty `()`-context registry, a
    /// single "Hello" user message, the default config and the given capacity.
    ///
    /// Must be called from inside a Tokio runtime, because the loop task is
    /// spawned.
    fn start_loop(
        script: Vec<StreamEvent>,
        capacity: usize,
    ) -> (
        mpsc::Receiver<Result<StreamEvent, LlmError>>,
        JoinHandle<ToolLoopResult>,
    ) {
        let provider = Arc::new(mock_for("test", "test-model"));
        // tool_loop_stream drives the provider in streaming mode, so the
        // scripted response has to be queued as a stream.
        provider.queue_stream(script);

        let tools: ToolRegistry<()> = ToolRegistry::new();
        let request = ChatParams {
            messages: vec![ChatMessage::user("Hello")],
            ..Default::default()
        };

        tool_loop_channel(
            provider,
            Arc::new(tools),
            request,
            ToolLoopConfig::default(),
            Arc::new(()),
            capacity,
        )
    }

    /// Terminal event used by every scripted response in this module.
    fn end_turn() -> StreamEvent {
        StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_basic() {
        let (mut rx, handle) = start_loop(
            vec![
                StreamEvent::TextDelta("Hello ".into()),
                StreamEvent::TextDelta("from channel!".into()),
                end_turn(),
            ],
            16,
        );

        let mut received = Vec::new();
        while let Some(item) = rx.recv().await {
            received.push(item);
        }

        // Something must have been forwarded through the channel.
        assert!(!received.is_empty());

        // The forwarding task finishes and yields its result.
        let outcome = handle.await.unwrap();
        assert!(matches!(
            outcome.termination_reason,
            TerminationReason::Complete
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_consumer_drop() {
        let (rx, handle) = start_loop(
            vec![StreamEvent::TextDelta("Hello".into()), end_turn()],
            2, // Small buffer
        );

        // Hang up before reading anything.
        drop(rx);

        // The task must still wind down on its own; only completion is checked.
        let _outcome = handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_backpressure() {
        // A capacity of 1 forces the producer to wait on every send.
        let (mut rx, handle) = start_loop(
            vec![StreamEvent::TextDelta("Response".into()), end_turn()],
            1,
        );

        // Drain the channel completely.
        while rx.recv().await.is_some() {}

        let outcome = handle.await.unwrap();
        assert!(matches!(
            outcome.termination_reason,
            TerminationReason::Complete
        ));
    }
}