use std::sync::Arc;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider};
use crate::stream::StreamEvent;
use super::LoopDepth;
use super::ToolRegistry;
use super::config::{ToolLoopConfig, ToolLoopResult};
use super::loop_stream::tool_loop_stream;
/// Runs the tool loop on a background Tokio task and forwards its events
/// through a bounded [`mpsc`] channel.
///
/// This is a channel-based adapter over `tool_loop_stream`. Every item the
/// stream yields, including `Err` items, is sent to the returned receiver
/// in order. The background task stops forwarding when:
///
/// * the stream is exhausted,
/// * an `Ok(StreamEvent::Done { .. })` event has been forwarded, or
/// * the receiver has been dropped (the send fails). In that case the rest of
///   the stream is dropped without being polled further.
///
/// # Parameters
///
/// * `buffer_size` — capacity of the channel. The forwarding task waits
///   (backpressure) while this many events are queued but unread. A value of
///   `0` is treated as `1`, because [`mpsc::channel`] panics on a zero
///   capacity.
///
/// # Returns
///
/// The event receiver and the [`JoinHandle`] of the forwarding task.
///
/// NOTE: the `ToolLoopResult` yielded by the handle is currently a
/// placeholder. It always holds an empty response, `iterations: 0`, default
/// usage and `TerminationReason::Complete`, regardless of what the loop
/// actually did. This holds even if an error was forwarded or the consumer
/// hung up early. Derive real outcome information from the streamed events
/// instead.
pub fn tool_loop_channel<Ctx: LoopDepth + Send + Sync + 'static>(
    provider: Arc<dyn DynProvider>,
    registry: Arc<ToolRegistry<Ctx>>,
    params: ChatParams,
    config: ToolLoopConfig,
    ctx: Arc<Ctx>,
    buffer_size: usize,
) -> (
    mpsc::Receiver<Result<StreamEvent, LlmError>>,
    JoinHandle<ToolLoopResult>,
) {
    // `mpsc::channel` panics on a zero capacity. Clamp it so a caller passing 0
    // gets the tightest possible backpressure instead of a panic.
    let (tx, rx) = mpsc::channel(buffer_size.max(1));
    let handle = tokio::spawn(async move {
        let mut stream = tool_loop_stream(provider, registry, params, config, ctx);
        while let Some(event) = stream.next().await {
            // Inspect the event before `send` takes ownership of it.
            let is_done = matches!(&event, Ok(StreamEvent::Done { .. }));
            if tx.send(event).await.is_err() {
                // The receiver was dropped: nobody is listening, so stop
                // driving the loop.
                break;
            }
            if is_done {
                break;
            }
        }
        // TODO: report the real outcome (response, iteration count, usage,
        // termination reason) once `tool_loop_stream` exposes it. This value
        // is a placeholder.
        ToolLoopResult {
            response: crate::chat::ChatResponse::empty(),
            iterations: 0,
            total_usage: crate::usage::Usage::default(),
            termination_reason: super::config::TerminationReason::Complete,
        }
    });
    (rx, handle)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, StopReason};
    use crate::test_helpers::mock_for;

    /// Sets up a mock provider that replays `script`, an empty tool registry
    /// and a one-message user request. Then starts `tool_loop_channel` with a
    /// channel of the given `capacity`.
    fn start_channel(
        script: Vec<StreamEvent>,
        capacity: usize,
    ) -> (
        mpsc::Receiver<Result<StreamEvent, LlmError>>,
        JoinHandle<ToolLoopResult>,
    ) {
        let provider = Arc::new(mock_for("test", "test-model"));
        provider.queue_stream(script);
        let tools: Arc<ToolRegistry<()>> = Arc::new(ToolRegistry::new());
        let request = ChatParams {
            messages: vec![ChatMessage::user("Hello")],
            ..Default::default()
        };
        tool_loop_channel(
            provider,
            tools,
            request,
            ToolLoopConfig::default(),
            Arc::new(()),
            capacity,
        )
    }

    /// Terminal event used by every scripted provider response.
    fn end_turn() -> StreamEvent {
        StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Asserts that the forwarding task reported a normal completion.
    fn assert_complete(result: &ToolLoopResult) {
        assert!(matches!(
            result.termination_reason,
            super::super::config::TerminationReason::Complete
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_basic() {
        let (mut rx, handle) = start_channel(
            vec![
                StreamEvent::TextDelta("Hello ".into()),
                StreamEvent::TextDelta("from channel!".into()),
                end_turn(),
            ],
            16,
        );

        let mut received = 0usize;
        while rx.recv().await.is_some() {
            received += 1;
        }
        assert!(received > 0);

        assert_complete(&handle.await.unwrap());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_consumer_drop() {
        let (rx, handle) = start_channel(
            vec![StreamEvent::TextDelta("Hello".into()), end_turn()],
            2,
        );

        // Hang up before reading anything. The forwarding task must still
        // finish without panicking.
        drop(rx);
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_tool_loop_channel_backpressure() {
        // Capacity 1: at most one event can sit unread in the channel at a time.
        let (mut rx, handle) = start_channel(
            vec![StreamEvent::TextDelta("Response".into()), end_turn()],
            1,
        );

        while rx.recv().await.is_some() {}

        assert_complete(&handle.await.unwrap());
    }
}