use std::sync::Arc;
use crate::chat::{ChatMessage, ContentBlock, ToolCall, ToolResult};
use crate::provider::{ChatParams, DynProvider};
use crate::usage::Usage;
use super::LoopDepth;
use super::ToolRegistry;
use super::config::{LoopEvent, ToolLoopConfig, ToolLoopResult};
use super::loop_core::{CompletedData, ErrorData, IterationOutcome, LoopCore};
use super::loop_resumable::{
Completed, LoopCommand, TurnError, impl_yielded_methods, outcome_to_turn_result,
};
/// Value returned by [`OwnedToolLoopHandle::next_turn`] for a single turn.
///
/// The `'h` lifetime ties the `Yielded` variant to the mutable borrow of the
/// handle that produced it, so the handle cannot be used directly while a
/// `Yielded` value is still alive.
#[must_use = "an OwnedTurnResult must be matched — Yielded requires resume() to continue"]
pub enum OwnedTurnResult<'h, Ctx: LoopDepth + Send + Sync + 'static> {
/// The turn yielded control back to the caller; the payload keeps a mutable
/// borrow of the handle and must be resumed for the loop to continue.
Yielded(OwnedYielded<'h, Ctx>),
/// Wraps a [`Completed`] value from `loop_resumable`.
// NOTE(review): by its name this signals that the loop ended normally —
// confirm against `outcome_to_turn_result!`.
Completed(Completed),
/// Wraps a [`TurnError`] value from `loop_resumable`.
// NOTE(review): by its name this signals that the turn failed — confirm
// against `outcome_to_turn_result!`.
Error(TurnError),
}
/// Payload of [`OwnedTurnResult::Yielded`].
///
/// Holds a mutable borrow of the originating [`OwnedToolLoopHandle`] for
/// `'h` together with data describing the turn. The `#[must_use]` message
/// lists the calls a caller is expected to make on this value.
// NOTE(review): the field descriptions below are inferred from field names and
// types only; the values are filled in by `outcome_to_turn_result!`, which is
// defined outside this file — confirm there.
#[must_use = "must call .resume(), .continue_loop(), .inject_and_continue(), or .stop() to continue"]
pub struct OwnedYielded<'h, Ctx: LoopDepth + Send + Sync + 'static> {
// Exclusive borrow of the handle this value came from; private, so only code
// in this module (including macro expansions here) can reach it.
handle: &'h mut OwnedToolLoopHandle<Ctx>,
/// Tool calls reported for this turn (presumably those requested by the model).
pub tool_calls: Vec<ToolCall>,
/// Tool results reported for this turn (presumably one per executed call).
pub results: Vec<ToolResult>,
/// Content blocks of the assistant message for this turn (presumed).
pub assistant_content: Vec<ContentBlock>,
/// Iteration counter value at the time of the yield (presumed).
pub iteration: u32,
/// Usage total at the time of the yield (presumably accumulated over the loop).
pub total_usage: Usage,
/// Loop events attached to this turn (presumed).
pub events: Vec<LoopEvent>,
}
// Expands `impl_yielded_methods!` (imported from `loop_resumable`) for
// `OwnedYielded<'h>`.
// NOTE(review): this presumably generates the `.resume()`, `.continue_loop()`,
// `.inject_and_continue()` and `.stop()` methods named in the `#[must_use]`
// message on `OwnedYielded` — confirm against the macro definition.
impl_yielded_methods!(OwnedYielded<'h>);
/// Tool-loop handle that holds its provider and registry behind [`Arc`]s
/// instead of borrowing them, alongside the [`LoopCore`] that carries the
/// loop state.
///
/// Turns are driven with `next_turn`; the remaining methods forward to the
/// inner `LoopCore`.
pub struct OwnedToolLoopHandle<Ctx: LoopDepth + Send + Sync + 'static> {
// Shared, type-erased provider handed to `LoopCore::do_iteration` each turn.
provider: Arc<dyn DynProvider>,
// Shared tool registry handed to `LoopCore::do_iteration` each turn.
registry: Arc<ToolRegistry<Ctx>>,
// Loop state; every accessor on the handle delegates to it.
core: LoopCore<Ctx>,
}
impl<Ctx: LoopDepth + Send + Sync + 'static> OwnedToolLoopHandle<Ctx> {
    /// Creates a handle around a freshly constructed [`LoopCore`].
    ///
    /// `params`, `config` and `ctx` are forwarded unchanged to
    /// [`LoopCore::new`]; `provider` and `registry` are stored as given.
    pub fn new(
        provider: Arc<dyn DynProvider>,
        registry: Arc<ToolRegistry<Ctx>>,
        params: ChatParams,
        config: ToolLoopConfig,
        ctx: &Ctx,
    ) -> Self {
        let core = LoopCore::new(params, config, ctx);
        Self::from_core(provider, registry, core)
    }

    /// Assembles a handle from an already-built [`LoopCore`].
    pub(crate) fn from_core(
        provider: Arc<dyn DynProvider>,
        registry: Arc<ToolRegistry<Ctx>>,
        core: LoopCore<Ctx>,
    ) -> Self {
        Self { provider, registry, core }
    }

    /// Runs one iteration of the inner core and converts its outcome into an
    /// [`OwnedTurnResult`] that may borrow this handle.
    pub async fn next_turn(&mut self) -> OwnedTurnResult<'_, Ctx> {
        // `provider`, `registry` and `core` are distinct fields, so the two
        // shared borrows can coexist with the call on `self.core`.
        let provider = &*self.provider;
        let registry = &self.registry;
        let outcome = self.core.do_iteration(provider, registry).await;
        outcome_to_turn_result!(outcome, self, OwnedTurnResult, OwnedYielded)
    }

    /// Forwards `command` to [`LoopCore::resume`].
    pub fn resume(&mut self, command: LoopCommand) {
        self.core.resume(command);
    }

    /// Shared view of the messages held by the inner core.
    pub fn messages(&self) -> &[ChatMessage] {
        self.core.messages()
    }

    /// Mutable access to the message vector held by the inner core.
    pub fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
        self.core.messages_mut()
    }

    /// Usage total as reported by the inner core.
    pub fn total_usage(&self) -> &Usage {
        self.core.total_usage()
    }

    /// Iteration count as reported by the inner core.
    pub fn iterations(&self) -> u32 {
        self.core.iterations()
    }

    /// Whether the inner core reports the loop as finished.
    pub fn is_finished(&self) -> bool {
        self.core.is_finished()
    }

    /// Returns the events handed back by [`LoopCore::drain_events`].
    pub fn drain_events(&mut self) -> Vec<LoopEvent> {
        self.core.drain_events()
    }

    /// Consumes the handle and converts the inner core into a
    /// [`ToolLoopResult`].
    pub fn into_result(self) -> ToolLoopResult {
        // Only `core` is moved out; `provider` and `registry` are dropped
        // when this function returns.
        let Self { core, .. } = self;
        core.into_result()
    }
}
/// Debug output shows only the `core` field; `provider` and `registry` are
/// left out and the struct is rendered as non-exhaustive.
impl<Ctx: LoopDepth + Send + Sync + 'static> std::fmt::Debug for OwnedToolLoopHandle<Ctx> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OwnedToolLoopHandle");
        builder.field("core", &self.core);
        // Marks the omitted fields in the rendered output.
        builder.finish_non_exhaustive()
    }
}