use crate::chat::{ChatMessage, ChatResponse, ContentBlock, ToolCall, ToolResult};
use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider};
use crate::usage::Usage;
use super::LoopDepth;
use super::ToolRegistry;
use super::config::{LoopEvent, TerminationReason, ToolLoopConfig, ToolLoopResult};
use super::loop_core::{CompletedData, ErrorData, IterationOutcome, LoopCore};
/// Generates the shared inherent API of a "yielded" (paused) turn type.
///
/// Invocation: `impl_yielded_methods!(TypeName<'lt, ...>)`, listing the type's
/// lifetime parameters (the list may be empty). The generated `impl` is also
/// generic over a trailing `Ctx: LoopDepth + Send + Sync + 'static` parameter.
///
/// The target type needs a `handle` field offering `resume`, `messages` and
/// `messages_mut`, plus an `assistant_content` field iterable over
/// `ContentBlock`s. `LoopCommand`, `ChatMessage`, `ContentBlock` and
/// `LoopDepth` are resolved at the invocation site, so they must be in scope there.
macro_rules! impl_yielded_methods {
    ($target:ident < $($life:lifetime),* >) => {
        impl<$($life,)* Ctx> $target<$($life,)* Ctx>
        where
            Ctx: LoopDepth + Send + Sync + 'static,
        {
            /// Consumes this turn and forwards `command` to the underlying handle.
            pub fn resume(self, command: LoopCommand) {
                self.handle.resume(command);
            }

            /// Shorthand for `resume(LoopCommand::Continue)`.
            pub fn continue_loop(self) {
                self.resume(LoopCommand::Continue)
            }

            /// Shorthand for `resume(LoopCommand::InjectMessages(messages))`.
            pub fn inject_and_continue(self, messages: Vec<ChatMessage>) {
                self.resume(LoopCommand::InjectMessages(messages))
            }

            /// Shorthand for `resume(LoopCommand::Stop(reason))`.
            pub fn stop(self, reason: Option<String>) {
                self.resume(LoopCommand::Stop(reason))
            }

            /// Joins every `ContentBlock::Text` in `assistant_content` with `'\n'`,
            /// skipping all other block kinds.
            ///
            /// Returns `None` when the joined string is empty (for example, when
            /// there are no text blocks at all).
            pub fn assistant_text(&self) -> Option<String> {
                let mut joined = String::new();
                // Tracks whether a text block was seen, so empty text blocks
                // still contribute a separator, as with `join`.
                let mut wrote_any = false;
                for block in self.assistant_content.iter() {
                    if let ContentBlock::Text(fragment) = block {
                        if wrote_any {
                            joined.push('\n');
                        }
                        joined.push_str(fragment.as_str());
                        wrote_any = true;
                    }
                }
                Some(joined).filter(|text| !text.is_empty())
            }

            /// Read-only view of the conversation held by the handle.
            pub fn messages(&self) -> &[ChatMessage] {
                self.handle.messages()
            }

            /// Mutable access to the conversation held by the handle.
            pub fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
                self.handle.messages_mut()
            }
        }
    };
}
pub(crate) use impl_yielded_methods;
/// Converts an `IterationOutcome` into the matching turn-result variant.
///
/// Arguments, in order:
/// - `$outcome`: the `IterationOutcome` produced by one loop iteration.
/// - `$handle`: the loop handle. It must expose a `core` field with
///   `drain_events()`, and it is moved into the `Yielded` payload.
/// - `$turn_ty`: the enum to build (it has `Yielded` / `Completed` / `Error` variants).
/// - `$yielded_ty`: the struct carried by `$turn_ty::Yielded`.
///
/// Every variant carries the events drained from the handle's core. Each
/// argument expression is evaluated exactly once: `$outcome` first, then
/// `$handle`, and both before the drain.
///
/// `IterationOutcome`, `CompletedData`, `ErrorData`, `Completed` and
/// `TurnError` are resolved at the call site and must be in scope there.
macro_rules! outcome_to_turn_result {
    ($outcome:expr, $handle:expr, $turn_ty:ident, $yielded_ty:ident) => {{
        // Evaluate the outcome before draining. If a caller passes the
        // iteration call itself as `$outcome`, the events it records are then
        // included here instead of being left behind in the core.
        let outcome = $outcome;
        // Bind the handle once so `$handle` is expanded (and evaluated) a single time.
        let turn_handle = $handle;
        let events = turn_handle.core.drain_events();
        match outcome {
            IterationOutcome::ToolsExecuted {
                tool_calls,
                results,
                assistant_content,
                iteration,
                total_usage,
            } => $turn_ty::Yielded($yielded_ty {
                handle: turn_handle,
                tool_calls,
                results,
                assistant_content,
                iteration,
                total_usage,
                events,
            }),
            IterationOutcome::Completed(CompletedData {
                response,
                termination_reason,
                iterations,
                total_usage,
            }) => $turn_ty::Completed(Completed {
                response,
                termination_reason,
                iterations,
                total_usage,
                events,
            }),
            IterationOutcome::Error(ErrorData {
                error,
                iterations,
                total_usage,
            }) => $turn_ty::Error(TurnError {
                error,
                iterations,
                total_usage,
                events,
            }),
        }
    }};
}
pub(crate) use outcome_to_turn_result;
/// Outcome of a single [`ToolLoopHandle::next_turn`] call.
///
/// Each variant carries the loop events drained while producing it.
#[must_use = "a TurnResult must be matched — Yielded requires resume() to continue"]
pub enum TurnResult<'a, 'h, Ctx>
where
    Ctx: LoopDepth + Send + Sync + 'static,
{
    /// Tools were executed and the loop is paused. The payload tells the loop how to proceed.
    Yielded(Yielded<'a, 'h, Ctx>),
    /// The loop terminated with a final response.
    Completed(Completed),
    /// The iteration ended in an error.
    Error(TurnError),
}
/// A paused turn: the iteration executed tool calls and now waits for the caller.
///
/// Consume it with `resume`, `continue_loop`, `inject_and_continue` or `stop`
/// to tell the loop how to proceed. While it is alive it mutably borrows the
/// [`ToolLoopHandle`], and it exposes the conversation through `messages` /
/// `messages_mut`.
#[must_use = "must call .resume(), .continue_loop(), .inject_and_continue(), or .stop() to continue"]
pub struct Yielded<'a, 'h, Ctx>
where
    Ctx: LoopDepth + Send + Sync + 'static,
{
    /// Borrowed handle that the resume methods forward commands to.
    handle: &'h mut ToolLoopHandle<'a, Ctx>,
    /// Tool calls executed in this iteration.
    pub tool_calls: Vec<ToolCall>,
    /// Tool results reported for this iteration.
    pub results: Vec<ToolResult>,
    /// Content blocks of the assistant message for this iteration.
    pub assistant_content: Vec<ContentBlock>,
    /// Iteration number reported by the loop core.
    pub iteration: u32,
    /// Total usage reported by the loop core with this outcome.
    pub total_usage: Usage,
    /// Events drained from the loop core when this turn was produced.
    pub events: Vec<LoopEvent>,
}

impl_yielded_methods!(Yielded<'a, 'h>);
/// Final result of a loop iteration that terminated the loop.
pub struct Completed {
    /// Response carried by the completed outcome.
    pub response: ChatResponse,
    /// Why the loop terminated.
    pub termination_reason: TerminationReason,
    /// Iteration count reported by the loop core.
    pub iterations: u32,
    /// Total usage reported by the loop core.
    pub total_usage: Usage,
    /// Events drained from the loop core when this result was produced.
    pub events: Vec<LoopEvent>,
}
/// A loop iteration that ended in an error.
pub struct TurnError {
    /// The error reported by the loop core.
    pub error: LlmError,
    /// Iteration count reported by the loop core.
    pub iterations: u32,
    /// Total usage reported by the loop core.
    pub total_usage: Usage,
    /// Events drained from the loop core when this error was produced.
    pub events: Vec<LoopEvent>,
}
/// Instruction handed back to a paused loop via `resume`.
#[derive(Debug)]
pub enum LoopCommand {
    /// Proceed with the next iteration.
    Continue,
    /// Inject these messages into the conversation, then continue
    /// (the exact handling lives in `LoopCore::resume`).
    InjectMessages(Vec<ChatMessage>),
    /// Stop the loop, with an optional caller-supplied reason.
    Stop(Option<String>),
}
/// Step-wise driver for a tool-calling loop that borrows its provider and registry.
///
/// Call `next_turn` to run iterations one at a time. Use `into_owned` to switch
/// to an `Arc`-based handle, or `into_result` to finish.
pub struct ToolLoopHandle<'a, Ctx>
where
    Ctx: LoopDepth + Send + Sync + 'static,
{
    /// Provider passed to each iteration.
    provider: &'a dyn DynProvider,
    /// Tool registry passed to each iteration.
    registry: &'a ToolRegistry<Ctx>,
    /// Loop state. Every accessor on this handle delegates to it.
    core: LoopCore<Ctx>,
}
impl<'a, Ctx: LoopDepth + Send + Sync + 'static> ToolLoopHandle<'a, Ctx> {
pub fn new(
provider: &'a dyn DynProvider,
registry: &'a ToolRegistry<Ctx>,
params: ChatParams,
config: ToolLoopConfig,
ctx: &Ctx,
) -> Self {
Self {
provider,
registry,
core: LoopCore::new(params, config, ctx),
}
}
pub async fn next_turn(&mut self) -> TurnResult<'a, '_, Ctx> {
let outcome = self.core.do_iteration(self.provider, self.registry).await;
outcome_to_turn_result!(outcome, self, TurnResult, Yielded)
}
pub fn resume(&mut self, command: LoopCommand) {
self.core.resume(command);
}
pub fn messages(&self) -> &[ChatMessage] {
self.core.messages()
}
pub fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
self.core.messages_mut()
}
pub fn total_usage(&self) -> &Usage {
self.core.total_usage()
}
pub fn iterations(&self) -> u32 {
self.core.iterations()
}
pub fn is_finished(&self) -> bool {
self.core.is_finished()
}
pub fn drain_events(&mut self) -> Vec<LoopEvent> {
self.core.drain_events()
}
pub fn into_result(self) -> ToolLoopResult {
self.core.into_result()
}
pub fn into_owned(
self,
provider: std::sync::Arc<dyn DynProvider>,
registry: std::sync::Arc<ToolRegistry<Ctx>>,
) -> super::OwnedToolLoopHandle<Ctx> {
super::OwnedToolLoopHandle::from_core(provider, registry, self.core)
}
}
impl<Ctx> std::fmt::Debug for ToolLoopHandle<'_, Ctx>
where
    Ctx: LoopDepth + Send + Sync + 'static,
{
    /// Shows only the loop core; the borrowed provider and registry are elided
    /// and marked with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ToolLoopHandle");
        builder.field("core", &self.core);
        builder.finish_non_exhaustive()
    }
}