use super::tool_exec::ParallelToolExecutor;
use super::AgentLoopError;
use crate::contracts::runtime::behavior::NoOpBehavior;
use crate::contracts::runtime::tool_call::{Tool, ToolDescriptor};
use crate::contracts::runtime::AgentBehavior;
use crate::contracts::runtime::ToolExecutor;
use crate::contracts::RunContext;
use async_trait::async_trait;
use genai::chat::ChatOptions;
use genai::Client;
use std::collections::HashMap;
use std::sync::Arc;
/// Retry and backoff settings for LLM requests.
///
/// NOTE(review): the per-field descriptions below are read from the field names and
/// the defaults in this file; the code that consumes this policy is not visible
/// here — confirm the exact semantics against it.
#[derive(Debug, Clone)]
pub struct LlmRetryPolicy {
    /// Maximum number of attempts made against a single model.
    pub max_attempts_per_model: usize,
    /// Backoff delay before the first retry, in milliseconds.
    pub initial_backoff_ms: u64,
    /// Upper bound on a single backoff delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// Jitter applied to each backoff delay, expressed as a percentage.
    pub backoff_jitter_percent: u8,
    /// Overall retry budget in milliseconds; `None` presumably means unbounded.
    pub max_retry_window_ms: Option<u64>,
    /// Whether a failure while opening a stream may be retried.
    pub retry_stream_start: bool,
    /// Maximum number of retries for errors surfaced while a stream is running.
    pub max_stream_event_retries: usize,
    /// Stream-error count after which a fallback (presumably the next model) is used.
    pub stream_error_fallback_threshold: usize,
}

impl Default for LlmRetryPolicy {
    /// Two attempts per model, a 250 ms initial backoff capped at 2 s with 20% jitter,
    /// a 10 s retry window, stream-start retries enabled, and a value of two for both
    /// mid-stream retries and the stream-error fallback threshold.
    fn default() -> Self {
        const INITIAL_BACKOFF_MS: u64 = 250;
        const MAX_BACKOFF_MS: u64 = 2_000;
        const RETRY_WINDOW_MS: u64 = 10_000;

        Self {
            max_attempts_per_model: 2,
            initial_backoff_ms: INITIAL_BACKOFF_MS,
            max_backoff_ms: MAX_BACKOFF_MS,
            backoff_jitter_percent: 20,
            max_retry_window_ms: Some(RETRY_WINDOW_MS),
            retry_stream_start: true,
            max_stream_event_retries: 2,
            stream_error_fallback_threshold: 2,
        }
    }
}
/// Input passed to a [`StepToolProvider`] when it is asked for a step's tools.
pub struct StepToolInput<'a> {
    /// Borrowed run context the provider may consult when choosing tools.
    pub state: &'a RunContext,
}
/// The tools offered for one step, together with their descriptors.
#[derive(Clone, Default)]
pub struct StepToolSnapshot {
    /// Tools keyed by name (presumably the tool id — confirm against callers).
    pub tools: HashMap<String, Arc<dyn Tool>>,
    /// One descriptor per tool when built via `StepToolSnapshot::from_tools`.
    pub descriptors: Vec<ToolDescriptor>,
}
impl StepToolSnapshot {
    /// Builds a snapshot from a name → tool map, deriving one descriptor per tool.
    ///
    /// Descriptors are emitted in ascending order of their map key. `HashMap`
    /// iteration order is randomized per process, so collecting straight from
    /// `tools.values()` would yield a different descriptor order on every run.
    pub fn from_tools(tools: HashMap<String, Arc<dyn Tool>>) -> Self {
        let mut entries: Vec<(&String, &Arc<dyn Tool>)> = tools.iter().collect();
        // Keys are unique, so an unstable sort is fully deterministic here.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let descriptors = entries
            .into_iter()
            .map(|(_, tool)| tool.descriptor().clone())
            .collect();
        Self { tools, descriptors }
    }
}
/// Source of the tool set offered for an agent step.
///
/// Implementations receive a [`StepToolInput`] and return a [`StepToolSnapshot`],
/// or an [`AgentLoopError`] when the tools cannot be produced.
#[async_trait]
pub trait StepToolProvider: Send + Sync {
    /// Produces the tool snapshot for the step described by `input`.
    async fn provide(
        &self,
        input: StepToolInput<'_>,
    ) -> Result<StepToolSnapshot, AgentLoopError>;
}
pub type LlmEventStream = std::pin::Pin<
Box<dyn futures::Stream<Item = Result<genai::chat::ChatStreamEvent, genai::Error>> + Send>,
>;
#[async_trait]
pub trait LlmExecutor: Send + Sync {
async fn exec_chat_response(
&self,
model: &str,
chat_req: genai::chat::ChatRequest,
options: Option<&genai::chat::ChatOptions>,
) -> genai::Result<genai::chat::ChatResponse>;
async fn exec_chat_stream_events(
&self,
model: &str,
chat_req: genai::chat::ChatRequest,
options: Option<&genai::chat::ChatOptions>,
) -> genai::Result<LlmEventStream>;
fn name(&self) -> &'static str;
}
/// [`LlmExecutor`] that forwards every request to a `genai::Client`.
#[derive(Clone)]
pub struct GenaiLlmExecutor {
    /// Client that receives every request.
    client: Client,
}
impl GenaiLlmExecutor {
    /// Wraps `client`; requests are passed to it without modification.
    pub fn new(client: Client) -> Self {
        Self { client }
    }
}
impl std::fmt::Debug for GenaiLlmExecutor {
    /// Writes only the type name; the wrapped client is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // An empty `debug_struct(..).finish()` emits exactly the name, in both
        // normal and alternate (`{:#?}`) mode, so writing it directly is equivalent.
        f.write_str("GenaiLlmExecutor")
    }
}
#[async_trait]
impl LlmExecutor for GenaiLlmExecutor {
    /// Delegates to `Client::exec_chat`.
    async fn exec_chat_response(
        &self,
        model: &str,
        chat_req: genai::chat::ChatRequest,
        options: Option<&ChatOptions>,
    ) -> genai::Result<genai::chat::ChatResponse> {
        let client = &self.client;
        client.exec_chat(model, chat_req, options).await
    }

    /// Delegates to `Client::exec_chat_stream` and boxes the response's `stream`.
    async fn exec_chat_stream_events(
        &self,
        model: &str,
        chat_req: genai::chat::ChatRequest,
        options: Option<&ChatOptions>,
    ) -> genai::Result<LlmEventStream> {
        let response = self
            .client
            .exec_chat_stream(model, chat_req, options)
            .await?;
        let events: LlmEventStream = Box::pin(response.stream);
        Ok(events)
    }

    fn name(&self) -> &'static str {
        "genai_client"
    }
}
/// [`StepToolProvider`] that hands out the same fixed tool set on every step.
#[derive(Clone, Default)]
pub struct StaticStepToolProvider {
    /// Tools returned on every call, keyed by name.
    tools: HashMap<String, Arc<dyn Tool>>,
}
impl StaticStepToolProvider {
    /// Creates a provider that always serves `tools`.
    pub fn new(tools: HashMap<String, Arc<dyn Tool>>) -> Self {
        Self { tools }
    }
}
#[async_trait]
impl StepToolProvider for StaticStepToolProvider {
    /// Returns a snapshot of the fixed tool set; the step input is not consulted.
    async fn provide(&self, _input: StepToolInput<'_>) -> Result<StepToolSnapshot, AgentLoopError> {
        let tools = self.tools.clone();
        let snapshot = StepToolSnapshot::from_tools(tools);
        Ok(snapshot)
    }
}
/// Read-only view of an agent's configuration.
pub trait Agent: Send + Sync {
    /// Identifier of this agent.
    fn id(&self) -> &str;
    /// Name of the primary model.
    fn model(&self) -> &str;
    /// System prompt text.
    fn system_prompt(&self) -> &str;
    /// Round limit (presumably the maximum number of loop iterations — confirm).
    fn max_rounds(&self) -> usize;
    /// Chat options to send with requests, if any.
    fn chat_options(&self) -> Option<&ChatOptions>;
    /// Alternate model names (presumably tried when the primary fails — confirm).
    fn fallback_models(&self) -> &[String];
    /// Retry and backoff settings for LLM calls.
    fn llm_retry_policy(&self) -> &LlmRetryPolicy;
    /// Executor used to run tool calls.
    fn tool_executor(&self) -> Arc<dyn ToolExecutor>;

    /// Provider of per-step tools. The default implementation returns `None`.
    fn step_tool_provider(&self) -> Option<Arc<dyn StepToolProvider>> {
        None
    }

    /// Custom LLM executor. The default implementation returns `None`.
    fn llm_executor(&self) -> Option<Arc<dyn LlmExecutor>> {
        None
    }

    /// Behavior attached to this agent.
    fn behavior(&self) -> &dyn AgentBehavior;

    /// Registry used to deserialize state actions.
    ///
    /// The default implementation constructs a new registry on every call.
    fn state_action_deserializer_registry(
        &self,
    ) -> Arc<tirea_contract::runtime::state::StateActionDeserializerRegistry> {
        use tirea_contract::runtime::state::StateActionDeserializerRegistry;
        Arc::new(StateActionDeserializerRegistry::new())
    }
}
impl std::fmt::Debug for dyn Agent {
    /// Prints the agent's id, model, round limit and behavior id; the remaining
    /// trait accessors are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Agent");
        out.field("id", &self.id());
        out.field("model", &self.model());
        out.field("max_rounds", &self.max_rounds());
        out.field("behavior", &self.behavior().id());
        out.finish()
    }
}
/// Concrete, field-backed implementation of [`Agent`].
#[derive(Clone)]
pub struct BaseAgent {
    /// Value returned by `Agent::id`.
    pub id: String,
    /// Value returned by `Agent::model`.
    pub model: String,
    /// Value returned by `Agent::system_prompt`.
    pub system_prompt: String,
    /// Value returned by `Agent::max_rounds`.
    pub max_rounds: usize,
    /// Executor used to run tool calls.
    pub tool_executor: Arc<dyn ToolExecutor>,
    /// Chat options exposed via `Agent::chat_options`; `None` means no options.
    pub chat_options: Option<ChatOptions>,
    /// Alternate model names exposed via `Agent::fallback_models`.
    pub fallback_models: Vec<String>,
    /// Retry and backoff settings for LLM calls.
    pub llm_retry_policy: LlmRetryPolicy,
    /// Behavior attached to this agent.
    pub behavior: Arc<dyn AgentBehavior>,
    /// Lattice registry; not exposed through the `Agent` trait.
    pub lattice_registry: Arc<tirea_state::LatticeRegistry>,
    /// State scope registry; not exposed through the `Agent` trait.
    pub state_scope_registry: Arc<tirea_contract::runtime::state::StateScopeRegistry>,
    /// Provider of per-step tools, if any.
    pub step_tool_provider: Option<Arc<dyn StepToolProvider>>,
    /// Custom LLM executor. NOTE(review): `None` presumably selects a default
    /// genai-backed executor — confirm with the caller.
    pub llm_executor: Option<Arc<dyn LlmExecutor>>,
    /// Registry used to deserialize state actions.
    pub state_action_deserializer_registry:
        Arc<tirea_contract::runtime::state::StateActionDeserializerRegistry>,
}
impl Default for BaseAgent {
    /// Builds an agent with id `"default"`, model `"gpt-4o-mini"`, an empty system
    /// prompt, a 10-round limit, a streaming `ParallelToolExecutor`, `NoOpBehavior`,
    /// freshly constructed registries, and no step tool provider or LLM executor.
    fn default() -> Self {
        use tirea_contract::runtime::state::{
            StateActionDeserializerRegistry, StateScopeRegistry,
        };

        // Enable capture of usage, reasoning content and tool calls.
        let chat_options = ChatOptions::default()
            .with_capture_usage(true)
            .with_capture_reasoning_content(true)
            .with_capture_tool_calls(true);

        Self {
            id: String::from("default"),
            model: String::from("gpt-4o-mini"),
            system_prompt: String::new(),
            max_rounds: 10,
            tool_executor: Arc::new(ParallelToolExecutor::streaming()),
            chat_options: Some(chat_options),
            fallback_models: Vec::new(),
            llm_retry_policy: LlmRetryPolicy::default(),
            behavior: Arc::new(NoOpBehavior),
            lattice_registry: Arc::new(tirea_state::LatticeRegistry::new()),
            state_scope_registry: Arc::new(StateScopeRegistry::new()),
            step_tool_provider: None,
            llm_executor: None,
            state_action_deserializer_registry: Arc::new(
                StateActionDeserializerRegistry::new(),
            ),
        }
    }
}
impl std::fmt::Debug for BaseAgent {
    /// Formats the agent's configuration for diagnostics.
    ///
    /// Opaque or bulky fields are summarized: the system prompt appears only as its
    /// length in characters, trait objects by their `name()`/`id()` (or `"<set>"`
    /// for the step tool provider), and the registries are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BaseAgent")
            .field("id", &self.id)
            .field("model", &self.model)
            .field(
                "system_prompt",
                // `String::len` counts UTF-8 bytes; count `char`s so the "chars"
                // label is accurate for non-ASCII prompts.
                &format!("[{} chars]", self.system_prompt.chars().count()),
            )
            .field("max_rounds", &self.max_rounds)
            .field("tool_executor", &self.tool_executor.name())
            .field("chat_options", &self.chat_options)
            .field("fallback_models", &self.fallback_models)
            .field("llm_retry_policy", &self.llm_retry_policy)
            .field("behavior", &self.behavior.id())
            .field(
                "step_tool_provider",
                &self.step_tool_provider.as_ref().map(|_| "<set>"),
            )
            .field(
                "llm_executor",
                &self
                    .llm_executor
                    .as_ref()
                    .map(|executor| executor.name())
                    .unwrap_or("genai_client(default)"),
            )
            .finish()
    }
}
impl Agent for BaseAgent {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    fn model(&self) -> &str {
        self.model.as_str()
    }

    fn system_prompt(&self) -> &str {
        self.system_prompt.as_str()
    }

    fn max_rounds(&self) -> usize {
        self.max_rounds
    }

    fn chat_options(&self) -> Option<&ChatOptions> {
        Option::as_ref(&self.chat_options)
    }

    fn fallback_models(&self) -> &[String] {
        self.fallback_models.as_slice()
    }

    fn llm_retry_policy(&self) -> &LlmRetryPolicy {
        &self.llm_retry_policy
    }

    /// Returns a new handle to the shared tool executor (reference-count bump only).
    fn tool_executor(&self) -> Arc<dyn ToolExecutor> {
        Arc::clone(&self.tool_executor)
    }

    fn step_tool_provider(&self) -> Option<Arc<dyn StepToolProvider>> {
        self.step_tool_provider.as_ref().map(Arc::clone)
    }

    fn llm_executor(&self) -> Option<Arc<dyn LlmExecutor>> {
        self.llm_executor.as_ref().map(Arc::clone)
    }

    fn behavior(&self) -> &dyn AgentBehavior {
        &*self.behavior
    }

    /// Returns the configured registry rather than the trait's freshly built default.
    fn state_action_deserializer_registry(
        &self,
    ) -> Arc<tirea_contract::runtime::state::StateActionDeserializerRegistry> {
        Arc::clone(&self.state_action_deserializer_registry)
    }
}
impl BaseAgent {
tirea_contract::impl_shared_agent_builder_methods!();
#[must_use]
pub fn with_tool_executor(mut self, executor: Arc<dyn ToolExecutor>) -> Self {
self.tool_executor = executor;
self
}
#[must_use]
pub fn with_tools(self, tools: HashMap<String, Arc<dyn Tool>>) -> Self {
self.with_step_tool_provider(Arc::new(StaticStepToolProvider::new(tools)))
}
#[must_use]
pub fn with_step_tool_provider(mut self, provider: Arc<dyn StepToolProvider>) -> Self {
self.step_tool_provider = Some(provider);
self
}
#[must_use]
pub fn with_llm_executor(mut self, executor: Arc<dyn LlmExecutor>) -> Self {
self.llm_executor = Some(executor);
self
}
#[must_use]
pub fn with_behavior(mut self, behavior: Arc<dyn AgentBehavior>) -> Self {
self.behavior = behavior;
self
}
pub fn has_behavior(&self) -> bool {
self.behavior.id() != "noop"
}
}