use std::sync::Arc;
use autoagents_llm::LLMProvider;
use serde_json::Value;
use temporalio_macros::activities;
use temporalio_sdk::activities::{ActivityContext, ActivityError};
use crate::error::AgentError;
use crate::llm;
use crate::state::{LlmChatInput, LlmResponse, ToolCall, ToolResult};
use crate::tool::ToolRegistry;
/// State shared by the agent's activity implementations.
///
/// Holds the LLM provider used by the `llm_chat` activity and the tool
/// registry consulted by the `execute_tool` activity.
#[derive(Clone)]
pub struct AgentActivities {
/// LLM provider handle; passed by reference to `llm::chat` in `llm_chat`.
pub llm: Arc<dyn LLMProvider>,
/// Registry that `execute_tool` queries by tool name.
pub tools: ToolRegistry,
}
impl AgentActivities {
pub fn new(llm: Arc<dyn LLMProvider>, tools: ToolRegistry) -> Self {
Self { llm, tools }
}
}
#[activities]
impl AgentActivities {
#[activity]
pub async fn llm_chat(
self: Arc<Self>,
_ctx: ActivityContext,
input: LlmChatInput,
) -> Result<LlmResponse, ActivityError> {
tracing::debug!(
messages = input.messages.len(),
tools = input.tools.len(),
"llm_chat: invoking LLM"
);
let response = llm::chat(&self.llm, &input.messages, &input.tools)
.await
.map_err(agent_err_to_activity_err)?;
Ok(response)
}
#[activity]
pub async fn execute_tool(
self: Arc<Self>,
_ctx: ActivityContext,
call: ToolCall,
) -> Result<ToolResult, ActivityError> {
let tool = self.tools.get(&call.name).ok_or_else(|| {
agent_err_to_activity_err(AgentError::ToolNotFound(call.name.clone()))
})?;
tracing::debug!(name = %call.name, id = %call.id, "execute_tool: dispatching");
match tool.execute(call.args.clone()).await {
Ok(output) => Ok(ToolResult {
call_id: call.id,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
call_id: call.id,
output: Value::Null,
error: Some(e.to_string()),
}),
}
}
}
fn agent_err_to_activity_err(e: AgentError) -> ActivityError {
ActivityError::from(anyhow::Error::from(e))
}