use super::{ChatRunner, StateManager};
use crate::agent::core::{
Agent, AgentState, DynToolObj, EvictionHandler, FormatterFactory, ToolCallParserObj,
};
use crate::config::EvictionStrategy;
use crate::error::{AmbiError, Result};
use crate::llm::LLMEngine;
use crate::runtime::spawn;
use crate::types::{ChatTemplate, Message, ToolDefinition};
use crate::ContentPart;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
/// How `ChatRunner::run_loop` delivers model output to the caller.
pub(crate) enum ExecutionMode<'a> {
    /// Buffer the whole formatted response and return it from `run_loop`.
    Sync,
    /// Forward output incrementally: chunks are formatted with a formatter
    /// built by `formatter_factory` and sent through `tx_out` as they arrive.
    Stream {
        tx_out: &'a Sender<Result<String>>,
        formatter_factory: &'a FormatterFactory,
    },
}
/// Per-turn configuration shared by every iteration of the agent loop.
pub(crate) struct LoopConfig<'a> {
    /// Chat template used to render the prompt sent to the LLM.
    pub template: &'a ChatTemplate,
    /// Hard cap on model/tool round-trips before the loop aborts with
    /// `AmbiError::MaxIterationsReached`.
    pub max_iterations: usize,
    /// System prompt included in each LLM request.
    pub system_prompt: &'a str,
    /// Strategy applied when history must be evicted to fit the context.
    pub eviction_strategy: EvictionStrategy,
}
/// Tool-related collaborators the agent loop borrows from the `Agent`.
pub(crate) struct LoopTooling<'a> {
    /// Tool definitions advertised to the model in the request.
    pub tools_def: &'a Arc<Vec<ToolDefinition>>,
    /// Pre-rendered tool prompt cached on the agent to avoid re-rendering.
    pub cached_tool_prompt: &'a str,
    /// Name → tool lookup used when executing parsed tool calls.
    pub tool_map: &'a Arc<HashMap<String, Arc<DynToolObj>>>,
    /// Parser that extracts tool calls from raw model output.
    pub tool_parser: &'a Arc<ToolCallParserObj>,
    /// Builds a fresh output formatter for each model response.
    pub formatter_factory: &'a FormatterFactory,
}
/// Everything `run_loop` needs for one chat turn, borrowed from the agent.
pub(crate) struct RunCtx<'a> {
    pub loop_config: LoopConfig<'a>,
    pub loop_tooling: LoopTooling<'a>,
    /// `Some` in streaming mode: channel on which chunks and errors are
    /// emitted to the caller. `None` in sync mode, where output is
    /// accumulated locally and returned from `run_loop`.
    pub tx_out: Option<&'a Sender<Result<String>>>,
    /// Optional callback invoked when messages are evicted from state.
    pub evict_handler: &'a Option<EvictionHandler>,
}
impl ChatRunner {
pub(crate) async fn chat_multimodal(
agent: &Agent,
state: &Arc<RwLock<AgentState>>,
parts: Vec<ContentPart>,
) -> Result<String> {
let tokens = agent.llm_engine.count_tokens(
&Message::User {
content: parts.clone(),
}
.to_string(),
)?;
let accessor = StateManager(state);
accessor.push_user_message(parts, tokens).await?;
let ctx = RunCtx {
loop_config: LoopConfig {
template: &agent.config.template,
max_iterations: agent.config.max_iterations,
system_prompt: &agent.config.system_prompt,
eviction_strategy: agent.config.eviction_strategy.clone(),
},
loop_tooling: LoopTooling {
tools_def: &agent.tools_def,
cached_tool_prompt: &agent.cached_tool_prompt,
tool_map: &agent.tool_map,
tool_parser: &agent.tool_parser,
formatter_factory: &agent.formatter_factory,
},
tx_out: None,
evict_handler: &agent.on_evict_handler,
};
Self::run_loop(&ctx, &agent.llm_engine, &accessor, ExecutionMode::Sync).await
}
pub(crate) async fn chat_multimodal_stream(
agent: &Agent,
state: &Arc<RwLock<AgentState>>,
parts: Vec<ContentPart>,
) -> Result<Pin<Box<ReceiverStream<Result<String>>>>> {
let (tx_out, rx_out) = channel::<Result<String>>(1024);
let tx_out_for_panic = tx_out.clone();
#[cfg(target_arch = "wasm32")]
let _tx_out_for_panic = &tx_out_for_panic;
let agent_clone = agent.clone();
let state_clone = Arc::clone(state);
let tokens = agent_clone.llm_engine.count_tokens(
&Message::User {
content: parts.clone(),
}
.to_string(),
)?;
let handle = spawn(async move {
let tx_out_clone = tx_out.clone();
let accessor = StateManager(&state_clone);
if let Err(e) = accessor.push_user_message(parts, tokens).await {
let _ = tx_out_clone.send(Err(e)).await;
return;
}
let ctx = RunCtx {
loop_config: LoopConfig {
template: &agent_clone.config.template,
max_iterations: agent_clone.config.max_iterations,
system_prompt: &agent_clone.config.system_prompt,
eviction_strategy: agent_clone.config.eviction_strategy,
},
loop_tooling: LoopTooling {
tools_def: &agent_clone.tools_def,
cached_tool_prompt: &agent_clone.cached_tool_prompt,
tool_map: &agent_clone.tool_map,
tool_parser: &agent_clone.tool_parser,
formatter_factory: &agent_clone.formatter_factory,
},
tx_out: Some(&tx_out_clone),
evict_handler: &agent_clone.on_evict_handler,
};
let mode = ExecutionMode::Stream {
tx_out: &tx_out_clone,
formatter_factory: &agent_clone.formatter_factory,
};
if let Err(e) = Self::run_loop(&ctx, &agent_clone.llm_engine, &accessor, mode).await {
let _ = tx_out_clone.send(Err(e)).await;
}
});
#[cfg(not(target_arch = "wasm32"))]
spawn(async move {
if let Err(join_err) = handle.await {
if join_err.is_panic() {
log::error!("CRITICAL: Pipeline streaming task panicked internally.");
let _ = tx_out_for_panic
.send(Err(AmbiError::PipelineError(
"Internal framework panic caught, avoiding process crash.".into(),
)))
.await;
}
}
});
#[cfg(target_arch = "wasm32")]
drop(handle);
Ok(Box::pin(ReceiverStream::new(rx_out)))
}
pub(crate) async fn run_loop(
ctx: &RunCtx<'_>,
engine: &LLMEngine,
accessor: &StateManager<'_>,
mode: ExecutionMode<'_>,
) -> Result<String> {
let mut final_formatted_output = if ctx.tx_out.is_none() {
String::with_capacity(2048)
} else {
String::new()
};
let mut iteration_count = 0;
let mut snapshot_len = accessor.get_snapshot_len().await?;
loop {
if iteration_count >= ctx.loop_config.max_iterations {
accessor.truncate(snapshot_len).await?;
let err = AmbiError::MaxIterationsReached(ctx.loop_config.max_iterations);
return if let Some(tx) = ctx.tx_out {
let _ = tx.send(Err(err)).await;
Ok(String::new())
} else {
Err(err)
};
}
let req_data = accessor
.get_llm_request(
ctx.loop_config.system_prompt,
ctx.loop_config.template,
ctx.loop_tooling.tools_def,
ctx.loop_tooling.cached_tool_prompt,
ctx.loop_tooling.tool_parser.get_tags(),
)
.await?;
let (full_output, has_error) = match &mode {
ExecutionMode::Sync => match engine.chat(req_data).await {
Ok(res) => (res, false),
Err(e) => {
accessor.truncate(snapshot_len).await?;
return Err(e);
}
},
ExecutionMode::Stream {
tx_out,
formatter_factory,
} => {
let (tx_llm, rx_llm) = channel::<Result<String>>(1024);
let process_future =
Self::process_llm_stream(rx_llm, tx_out, formatter_factory);
let engine_future = engine.chat_stream(req_data, tx_llm);
let ((), (full_output, stream_error)) =
tokio::join!(engine_future, process_future);
if let Some(err) = stream_error {
accessor.truncate(snapshot_len).await?;
return Err(err);
}
(full_output, false)
}
};
if has_error {
accessor.truncate(snapshot_len).await?;
break;
}
let parsed_tool_calls = ctx.loop_tooling.tool_parser.parse(&full_output);
let tool_calls_with_ids: Vec<_> = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(i, (name, args))| {
let id = format!(
"call_ambi_{}_{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis(),
i
);
(name, args, id)
})
.collect();
let asst_msg = Message::Assistant {
content: full_output.clone(),
tool_calls: tool_calls_with_ids.clone(),
};
let tokens = engine.count_tokens(&asst_msg.to_string())?;
let dynamic_system_overhead = accessor.get_system_overhead().await?;
let prompt_overhead = engine.count_tokens(ctx.loop_config.system_prompt)?
+ engine.count_tokens(ctx.loop_tooling.cached_tool_prompt)?
+ dynamic_system_overhead;
let evicted_count = accessor
.append_assistant_message_and_evict(
full_output.clone(),
tool_calls_with_ids.clone(),
tokens,
ctx.evict_handler,
&ctx.loop_config.eviction_strategy,
prompt_overhead,
)
.await?;
snapshot_len = snapshot_len.saturating_sub(evicted_count);
if ctx.tx_out.is_none() {
let mut formatter = (ctx.loop_tooling.formatter_factory)();
final_formatted_output.push_str(&formatter.push(&full_output));
final_formatted_output.push_str(&formatter.flush());
}
let tool_calls = match Self::handle_tool_calls(
accessor,
engine,
Arc::clone(ctx.loop_tooling.tool_map),
tool_calls_with_ids,
ctx.tx_out.cloned(),
)
.await
{
Ok(calls) => calls,
Err(e) => {
accessor.truncate(snapshot_len).await?;
return if let Some(tx) = ctx.tx_out {
let _ = tx.send(Err(AmbiError::ToolError(e.to_string()))).await;
Ok(String::new())
} else {
Err(AmbiError::ToolError(e.to_string()))
};
}
};
if tool_calls.is_empty() {
break;
}
if let Some(tx) = ctx.tx_out {
let mut formatted_tools = String::with_capacity(1024);
Self::process_tool_calls_output(&tool_calls, &mut formatted_tools);
let _ = tx.send(Ok(formatted_tools)).await;
} else {
Self::process_tool_calls_output(&tool_calls, &mut final_formatted_output);
}
iteration_count += 1;
}
Ok(final_formatted_output.trim().to_string())
}
}