use super::Agent;
use crate::agent::tool::{DefaultToolParser, Tool, ToolCallParser, ToolDefinition};
use crate::error::{AmbiError, Result};
use crate::llm::{ChatTemplate, LLMEngine, LLMEngineConfig, LLMEngineTrait};
use crate::types::message::Message;
use crate::types::AgentConfig;
use std::collections::HashMap;
use std::sync::Arc;
impl Agent {
    /// Loads an [`LLMEngine`] from `engine_cfg` and builds an agent around it.
    ///
    /// `LLMEngine::load` is synchronous, so it is run through
    /// [`tokio::task::spawn_blocking`] to keep it off the async executor thread.
    ///
    /// # Errors
    ///
    /// Returns [`AmbiError::EngineError`] if the blocking task does not finish
    /// (Tokio yields a `JoinError`, e.g. when the task panicked or was cancelled),
    /// and propagates any error returned by `LLMEngine::load` itself.
    pub async fn make(engine_cfg: LLMEngineConfig) -> Result<Self> {
        // First `?`: the blocking task failed to complete. Second `?`: the load failed.
        let engine = tokio::task::spawn_blocking(move || LLMEngine::load(engine_cfg))
            .await
            .map_err(|e| {
                AmbiError::EngineError(format!("Failed to spawn blocking task: {}", e))
            })??;
        Ok(Self::init_agent(engine))
    }

    /// Builds an agent around a caller-supplied backend implementing [`LLMEngineTrait`].
    ///
    /// # Errors
    ///
    /// None at present: this constructor always returns `Ok`.
    pub fn with_custom_engine(custom_backend: Box<dyn LLMEngineTrait>) -> Result<Self> {
        let engine = LLMEngine::from_custom(custom_backend);
        Ok(Self::init_agent(engine))
    }

    /// Shared constructor tail used by [`Agent::make`] and [`Agent::with_custom_engine`].
    ///
    /// Wraps `engine` in an [`Arc`] and gives every other field its starting value:
    /// default [`AgentConfig`], no tools, a [`DefaultToolParser`], no eviction
    /// handler, and an empty cached tool prompt.
    pub(super) fn init_agent(engine: LLMEngine) -> Self {
        Self {
            llm_engine: Arc::new(engine),
            config: AgentConfig::default(),
            tools_def: Arc::new(Vec::new()),
            tool_map: Arc::new(HashMap::new()),
            tool_parser: Arc::new(DefaultToolParser::make()),
            on_evict_handler: None,
            // An empty prompt matches the empty tool list (see `update_cached_tool_prompt`).
            cached_tool_prompt: String::new(),
        }
    }

    /// Sets `config.enable_formatting` to `enable` and returns the agent.
    pub fn enable_formatting(mut self, enable: bool) -> Self {
        self.config.enable_formatting = enable;
        self
    }

    /// Stores the eviction parameters in the config as the tuple
    /// `(keep_head, keep_tail, max_safe_tokens)`.
    ///
    /// NOTE(review): from the names, these are presumably the number of leading and
    /// trailing history entries to preserve and a token budget. They are interpreted
    /// outside this file; verify against the eviction logic.
    pub fn with_eviction_strategy(
        mut self,
        keep_head: usize,
        keep_tail: usize,
        max_safe_tokens: usize,
    ) -> Self {
        self.config.eviction_strategy = (keep_head, keep_tail, max_safe_tokens);
        self
    }

    /// Sets the system prompt (`config.system_prompt`), copying `system_prompt`.
    pub fn preamble(mut self, system_prompt: &str) -> Self {
        self.config.system_prompt = system_prompt.to_string();
        self
    }

    /// Sets the chat template from any value convertible into a [`ChatTemplate`].
    pub fn template<T: Into<ChatTemplate>>(mut self, template_source: T) -> Self {
        self.config.template = template_source.into();
        self
    }

    /// Registers `tool` and refreshes the cached tool prompt.
    ///
    /// The first registration of a name wins. If a tool with the same
    /// `definition().name` is already registered, `tool` is dropped and the existing
    /// entry is kept.
    ///
    /// The tool collections are `Arc`-shared. They are modified copy-on-write via
    /// [`Arc::make_mut`], so clones of this agent taken earlier keep seeing the old
    /// tool set.
    ///
    /// # Errors
    ///
    /// None at present: this method always returns `Ok`.
    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Result<Self> {
        let def = tool.definition();
        // Check for a duplicate before `make_mut`, so a rejected tool never forces a
        // clone of collections that are still shared with other owners.
        if !self.tools_def.iter().any(|t| t.name == def.name) {
            Arc::make_mut(&mut self.tools_def).push(ToolDefinition {
                name: def.name.clone(),
                description: def.description,
                parameters: def.parameters,
                timeout_secs: def.timeout_secs,
                max_retries: def.max_retries,
                is_idempotent: def.is_idempotent,
            });
            Arc::make_mut(&mut self.tool_map).insert(def.name, Arc::new(tool));
        }
        self.update_cached_tool_prompt();
        Ok(self)
    }

    /// Replaces the tool-call parser.
    ///
    /// Also regenerates the cached tool prompt, because that prompt is produced by
    /// the parser's `format_instruction`.
    pub fn with_tool_parser<P: ToolCallParser + 'static>(mut self, parser: P) -> Self {
        self.tool_parser = Arc::new(parser);
        self.update_cached_tool_prompt();
        self
    }

    /// Registers a callback that receives evicted messages as `Vec<Arc<Message>>`.
    ///
    /// It replaces any previously registered handler. The callback is stored in
    /// `on_evict_handler`; when it is invoked is decided outside this file
    /// (presumably by the history-eviction logic — verify).
    pub fn on_evict<F>(mut self, handler: F) -> Self
    where
        F: Fn(Vec<Arc<Message>>) + Send + Sync + 'static,
    {
        self.on_evict_handler = Some(Arc::new(handler));
        self
    }

    /// Recomputes `cached_tool_prompt` from the current tool definitions and parser.
    ///
    /// With no tools the prompt is empty. Otherwise the definitions are serialized to
    /// JSON and passed to the parser's `format_instruction`. Call this after any
    /// change to `tools_def` or `tool_parser` so the cache stays in sync.
    fn update_cached_tool_prompt(&mut self) {
        self.cached_tool_prompt = if self.tools_def.is_empty() {
            String::new()
        } else {
            // NOTE(review): a serialization failure is silently replaced by an empty
            // string here; consider surfacing it instead.
            let tools_json = serde_json::to_string(&*self.tools_def).unwrap_or_default();
            self.tool_parser.format_instruction(&tools_json)
        };
    }
}