use crate::agents::{
callbacks::{AgentCallback, CallbackAction},
criteria::CompletionCriteria,
error::AgentError,
memory::{Memory, MemoryTurn},
run_context::{AgentRunConfig, AgentRunContext, ResumeContext},
runner::{AgentRunOutcome, AgentRunResult, AgentRunner},
session::{SessionSnapshot, SessionState},
store::{
app_state_store::AppStateStore, persistent_memory::PersistentMemory,
session_store::SessionStore, user_state_store::UserStateStore,
},
task::Task,
tool_ext::AgentTool,
types::{AgentResponse, PyAgentResponse},
};
use async_trait::async_trait;
use potato_provider::providers::anthropic::client::AnthropicClient;
use potato_provider::providers::types::ServiceType;
use potato_provider::GeminiClient;
use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
use potato_state::block_on;
use potato_type::prompt::Prompt;
use potato_type::prompt::{MessageNum, Role};
use potato_type::Provider;
use potato_type::{
prompt::extract_system_instructions,
tools::{Tool, ToolRegistry},
};
use potato_util::create_uuid7;
use pyo3::prelude::*;
use pyo3::types::PyList;
use serde::{
de::{self, MapAccess, Visitor},
ser::SerializeStruct,
Deserializer, Serializer,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use tracing::{debug, instrument, warn};
/// An LLM agent bound to a single provider client.
///
/// Holds the provider client, the agent-level system instructions, the tool
/// registry, and the optional run-time wiring (completion criteria, callbacks,
/// memory and state stores) consulted by the `AgentRunner::run` loop.
#[derive(Debug)]
pub struct Agent {
/// Agent identifier; `Agent::new` fills it with `create_uuid7()`.
pub id: String,
/// Provider client used for every `generate_content` call. Deserialized
/// agents carry `GenAiClient::Undefined` here.
client: Arc<GenAiClient>,
/// Provider this agent was created for; selects the client and the message type.
pub provider: Provider,
/// System instructions prepended to prompts before they are sent to the provider.
pub system_instruction: Vec<MessageNum>,
/// Shared tool registry; tool definitions are attached to the prompt in `run`.
pub tools: Arc<RwLock<ToolRegistry>>,
/// Iteration cap for `run` when `run_config` is `None` (`Agent::new` sets 10).
pub max_iterations: u32,
/// Optional run configuration; its `max_iterations` takes precedence in `run`.
pub run_config: Option<AgentRunConfig>,
/// Model name used when building a prompt from raw input; `build_input_prompt`
/// returns an error when this is `None`.
pub model_override: Option<String>,
/// Completion criteria checked in `run` after each plain-text model response.
pub criteria: Vec<Box<dyn CompletionCriteria>>,
/// Callbacks fired before/after model calls and tool calls in `run`.
pub callbacks: Vec<Arc<dyn AgentCallback>>,
/// Optional conversation memory; its messages are inserted into the prompt in `run`.
pub memory: Option<Arc<tokio::sync::Mutex<Box<dyn Memory>>>>,
/// Application name used as a store key; `run` falls back to `"default"`.
pub app_name: Option<String>,
/// User id used as a store key; `run` falls back to `"default"`.
pub user_id: Option<String>,
/// Session id; session state is only loaded/saved when this and `session_store` are set.
pub session_id: Option<String>,
/// Store the session snapshot is loaded from at the start of `run` and saved to on completion.
pub session_store: Option<Arc<dyn SessionStore>>,
/// Store whose per-user snapshot is merged into the session at the start of `run`.
pub user_state_store: Option<Arc<dyn UserStateStore>>,
/// Store whose per-app snapshot is merged into the session at the start of `run`.
pub app_state_store: Option<Arc<dyn AppStateStore>>,
}
impl Agent {
/// Returns a copy of this agent backed by a freshly constructed provider client.
///
/// The id, provider, system instructions, tool registry handle and
/// `max_iterations` are carried over. Run configuration, model override,
/// criteria, callbacks, memory and all store/session fields are reset to
/// their empty values.
///
/// # Errors
/// Returns `AgentError::MissingProviderError` for a provider without a
/// client mapping, or whatever error the client constructor reports.
#[instrument(skip_all)]
pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
    let fresh_client = match self.provider {
        Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
        Provider::Anthropic => {
            GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
        }
        // `Google` is served by the Gemini client, exactly like `Gemini`.
        Provider::Gemini | Provider::Google => {
            GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
        }
        Provider::Vertex => {
            GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
        }
        _ => return Err(AgentError::MissingProviderError),
    };

    let rebuilt = Self {
        // Carried over from `self`.
        id: self.id.clone(),
        provider: self.provider.clone(),
        system_instruction: self.system_instruction.clone(),
        tools: self.tools.clone(),
        max_iterations: self.max_iterations,
        // Replaced.
        client: Arc::new(fresh_client),
        // Reset.
        run_config: None,
        model_override: None,
        criteria: Vec::new(),
        callbacks: Vec::new(),
        memory: None,
        app_name: None,
        user_id: None,
        session_id: None,
        session_store: None,
        user_state_store: None,
        app_state_store: None,
    };
    Ok(rebuilt)
}
/// Creates an agent for `provider` with optional system instructions.
///
/// The agent gets a new id from `create_uuid7()`, an empty tool registry and
/// a default of 10 iterations; every optional field starts out unset.
///
/// # Errors
/// Returns `AgentError::MissingProviderError` for a provider without a
/// client mapping, or whatever error the client constructor reports.
pub async fn new(
    provider: Provider,
    system_instruction: Option<Vec<MessageNum>>,
) -> Result<Self, AgentError> {
    let genai_client = match provider {
        Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
        Provider::Anthropic => {
            GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
        }
        // `Google` is served by the Gemini client, exactly like `Gemini`.
        Provider::Gemini | Provider::Google => {
            GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
        }
        Provider::Vertex => {
            GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
        }
        _ => return Err(AgentError::MissingProviderError),
    };

    let agent = Self {
        id: create_uuid7(),
        client: Arc::new(genai_client),
        provider,
        system_instruction: system_instruction.unwrap_or_default(),
        tools: Arc::new(RwLock::new(ToolRegistry::new())),
        max_iterations: 10,
        run_config: None,
        model_override: None,
        criteria: Vec::new(),
        callbacks: Vec::new(),
        memory: None,
        app_name: None,
        user_id: None,
        session_id: None,
        session_store: None,
        user_state_store: None,
        app_state_store: None,
    };
    Ok(agent)
}
/// Adds `tool` to this agent's tool registry.
///
/// A poisoned registry lock is recovered rather than treated as fatal.
pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
    let mut registry = match self.tools.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    registry.register_tool(tool);
}
/// Injects the messages produced by a task's dependencies into its prompt.
///
/// Messages are gathered in `task.dependencies` order (dependencies missing
/// from `context_messages` are skipped) and placed immediately before the
/// first non-system message, keeping their original relative order. When the
/// prompt contains no non-system message they are appended to the end
/// instead. Does nothing when the task has no dependencies.
#[instrument(skip_all)]
fn append_task_with_message_dependency_context(
    &self,
    task: &mut Task,
    context_messages: &HashMap<String, Vec<MessageNum>>,
) {
    debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
    if task.dependencies.is_empty() {
        return;
    }

    // Gather the dependency output once, in dependency order.
    let mut dependency_messages: Vec<MessageNum> = Vec::new();
    for dep_id in &task.dependencies {
        if let Some(messages) = context_messages.get(dep_id) {
            debug!(
                "Adding {} messages from dependency {}",
                messages.len(),
                dep_id
            );
            dependency_messages.extend(messages.iter().cloned());
        }
    }
    let message_count = dependency_messages.len();

    // Only a read is needed to locate the insertion point.
    let first_non_system_idx = task
        .prompt
        .request
        .messages()
        .iter()
        .position(|msg| !msg.is_system_message());

    match first_non_system_idx {
        Some(insert_idx) => {
            // Advance the index per message: inserting every message at the
            // same index would leave them in reverse order.
            for (offset, message) in dependency_messages.into_iter().enumerate() {
                task.prompt
                    .request
                    .insert_message(message, Some(insert_idx + offset));
            }
            debug!(
                "Inserted {} dependency messages before user message at index {}",
                message_count, insert_idx
            );
        }
        None => {
            warn!(
                "No user message found in task {}, appending dependency context to end",
                task.id
            );
            for message in dependency_messages {
                task.prompt.request.push_message(message);
            }
        }
    }
}
/// Binds prompt parameters into the user messages of `prompt`.
///
/// For every declared parameter, a value found in `parameter_context` is
/// bound first, then a value found in `global_context` (when present).
/// Only messages whose role equals `Role::User` are touched; values are
/// rendered with `Value::to_string`.
///
/// # Errors
/// Propagates any error returned by `bind_mut`.
#[instrument(skip_all)]
fn bind_context(
    &self,
    prompt: &mut Prompt,
    parameter_context: &Value,
    global_context: &Option<Arc<Value>>,
) -> Result<(), AgentError> {
    if prompt.parameters.is_empty() {
        return Ok(());
    }

    for param in &prompt.parameters {
        // Task-level value.
        if let Some(value) = parameter_context.get(param) {
            for message in prompt.request.messages_mut() {
                if message.role() != Role::User.as_str() {
                    continue;
                }
                debug!("Binding parameter: {} with value: {}", param, value);
                message.bind_mut(param, &value.to_string())?;
            }
        }

        // Global value, applied after the task-level one.
        let global_value = global_context.as_ref().and_then(|ctx| ctx.get(param));
        if let Some(value) = global_value {
            for message in prompt.request.messages_mut() {
                if message.role() != Role::User.as_str() {
                    continue;
                }
                debug!("Binding global parameter: {} with value: {}", param, value);
                message.bind_mut(param, &value.to_string())?;
            }
        }
    }

    Ok(())
}
fn prepend_system_instructions(&self, prompt: &mut Prompt) -> Result<(), AgentError> {
if !self.system_instruction.is_empty() {
prompt
.request
.prepend_system_instructions(self.system_instruction.clone())
.map_err(|e| AgentError::Error(e.to_string()))?;
}
Ok(())
}
/// Runs a single model call for `task` and tags the response with the task id.
///
/// The task's prompt is cloned, so the task itself is left untouched; the
/// agent's system instructions are prepended to the clone.
pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
    debug!("Executing task: {}, count: {}", task.id, task.retry_count);
    let mut task_prompt = task.prompt.clone();
    self.prepend_system_instructions(&mut task_prompt)?;
    let response = self.client.generate_content(&task_prompt).await?;
    let task_id = task.id.clone();
    Ok(AgentResponse::new(task_id, response))
}
/// Runs a single model call for `prompt` and tags the response with the
/// provider response id.
///
/// Works on a clone of `prompt` with the agent's system instructions
/// prepended; the caller's prompt is not modified.
#[instrument(skip_all)]
pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
    debug!("Executing prompt");
    let mut request_prompt = prompt.clone();
    self.prepend_system_instructions(&mut request_prompt)?;
    let response = self.client.generate_content(&request_prompt).await?;
    let response_id = response.id();
    Ok(AgentResponse::new(response_id, response))
}
/// Runs a single model call for `task` after binding `context` into its prompt.
///
/// The prompt is cloned out of the task under a short read lock, so the
/// shared task is not modified and no lock is held across the model call.
/// A poisoned task lock is recovered instead of panicking, matching how the
/// tool registry lock is handled elsewhere in this file.
///
/// # Errors
/// Propagates binding, system-instruction and provider errors.
#[instrument(skip_all)]
pub async fn execute_task_with_context(
    &self,
    task: &Arc<RwLock<Task>>,
    context: &Value,
) -> Result<AgentResponse, AgentError> {
    let (mut prompt, task_id) = {
        let task = task.read().unwrap_or_else(|e| e.into_inner());
        (task.prompt.clone(), task.id.clone())
    };
    self.bind_context(&mut prompt, context, &None)?;
    self.prepend_system_instructions(&mut prompt)?;
    let chat_response = self.client.generate_content(&prompt).await?;
    Ok(AgentResponse::new(task_id, chat_response))
}
/// Runs a single model call for `task` after enriching its prompt in place.
///
/// Under the task's write lock this, in order:
/// 1. injects the messages of the task's dependencies,
/// 2. binds `parameter_context` and `global_context` into the user messages,
/// 3. prepends the agent's system instructions.
///
/// The enriched prompt is stored on the task itself; a clone of it is sent
/// to the provider after the lock has been released. A poisoned task lock is
/// recovered instead of panicking, matching how the tool registry lock is
/// handled elsewhere in this file.
///
/// # Errors
/// Propagates binding, system-instruction and provider errors.
pub async fn execute_task_with_context_message(
    &self,
    task: &Arc<RwLock<Task>>,
    context_messages: HashMap<String, Vec<MessageNum>>,
    parameter_context: Value,
    global_context: Option<Arc<Value>>,
) -> Result<AgentResponse, AgentError> {
    let (prompt, task_id) = {
        let mut task = task.write().unwrap_or_else(|e| e.into_inner());
        self.append_task_with_message_dependency_context(&mut task, &context_messages);
        self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
        self.prepend_system_instructions(&mut task.prompt)?;
        (task.prompt.clone(), task.id.clone())
    };
    let chat_response = self.client.generate_content(&prompt).await?;
    Ok(AgentResponse::new(task_id, chat_response))
}
pub fn client_provider(&self) -> &Provider {
self.client.provider()
}
fn build_input_prompt(&self, input: &str) -> Result<Prompt, AgentError> {
use potato_type::prompt::builder::to_provider_request;
use potato_type::prompt::settings::ModelSettings;
use potato_type::prompt::types::ResponseType;
let msg = {
use potato_type::traits::MessageFactory;
match self.provider {
Provider::OpenAI => {
use potato_type::openai::v1::chat::request::ChatMessage;
ChatMessage::from_text(input.to_string(), "user")
.map(MessageNum::OpenAIMessageV1)?
}
Provider::Anthropic => {
use potato_type::anthropic::v1::request::MessageParam;
MessageParam::from_text(input.to_string(), "user")
.map(MessageNum::AnthropicMessageV1)?
}
Provider::Gemini | Provider::Google | Provider::Vertex => {
use potato_type::google::v1::generate::request::GeminiContent;
GeminiContent::from_text(input.to_string(), "user")
.map(MessageNum::GeminiContentV1)?
}
_ => {
return Err(AgentError::MissingProviderError);
}
}
};
let model = self.model_override.clone().ok_or_else(|| {
AgentError::Error("model must be set explicitly via AgentBuilder::model()".into())
})?;
let settings = ModelSettings::provider_default_settings(&self.provider);
let request = to_provider_request(
vec![msg],
self.system_instruction.clone(),
model.clone(),
settings,
None,
)?;
Ok(Prompt {
request,
model,
provider: self.provider.clone(),
version: env!("CARGO_PKG_VERSION").to_string(),
parameters: Vec::new(),
media_parameters: Vec::new(),
response_type: ResponseType::Null,
})
}
/// Invokes every callback's `before_model_call` in registration order.
///
/// Stops at the first callback that answers `Abort` and returns it as
/// `AgentError::CallbackAbort`; any other action is ignored.
fn fire_before_model(&self, ctx: &AgentRunContext, prompt: &Prompt) -> Result<(), AgentError> {
    let abort_message = self
        .callbacks
        .iter()
        .find_map(|callback| match callback.before_model_call(ctx, prompt) {
            CallbackAction::Abort(msg) => Some(msg),
            _ => None,
        });
    match abort_message {
        Some(msg) => Err(AgentError::CallbackAbort(msg)),
        None => Ok(()),
    }
}
/// Invokes every callback's `after_model_call` in registration order.
///
/// The first callback that does not answer `Continue` decides the outcome:
/// `Abort` becomes `AgentError::CallbackAbort`, `OverrideResponse` yields
/// `Ok(Some(text))`. Later callbacks are not called. Returns `Ok(None)` when
/// all callbacks continue.
fn fire_after_model(
    &self,
    ctx: &AgentRunContext,
    response: &AgentResponse,
) -> Result<Option<String>, AgentError> {
    self.callbacks
        .iter()
        .find_map(|callback| match callback.after_model_call(ctx, response) {
            CallbackAction::Continue => None,
            CallbackAction::OverrideResponse(text) => Some(Ok(Some(text))),
            CallbackAction::Abort(msg) => Some(Err(AgentError::CallbackAbort(msg))),
        })
        .unwrap_or(Ok(None))
}
/// Invokes every callback's `before_tool_call` in registration order.
///
/// Stops at the first callback that answers `Abort` and returns it as
/// `AgentError::CallbackAbort`; any other action is ignored.
fn fire_before_tool(
    &self,
    ctx: &AgentRunContext,
    call: &potato_type::tools::ToolCall,
) -> Result<(), AgentError> {
    let abort_message = self
        .callbacks
        .iter()
        .find_map(|callback| match callback.before_tool_call(ctx, call) {
            CallbackAction::Abort(msg) => Some(msg),
            _ => None,
        });
    match abort_message {
        Some(msg) => Err(AgentError::CallbackAbort(msg)),
        None => Ok(()),
    }
}
/// Invokes every callback's `after_tool_call` in registration order with the
/// tool's `result`.
///
/// Stops at the first callback that answers `Abort` and returns it as
/// `AgentError::CallbackAbort`; any other action is ignored.
fn fire_after_tool(
    &self,
    ctx: &AgentRunContext,
    call: &potato_type::tools::ToolCall,
    result: &serde_json::Value,
) -> Result<(), AgentError> {
    let abort_message = self
        .callbacks
        .iter()
        .find_map(|callback| match callback.after_tool_call(ctx, call, result) {
            CallbackAction::Abort(msg) => Some(msg),
            _ => None,
        });
    match abort_message {
        Some(msg) => Err(AgentError::CallbackAbort(msg)),
        None => Ok(()),
    }
}
}
#[async_trait]
impl AgentRunner for Agent {
/// Returns this agent's identifier.
fn id(&self) -> &str {
    self.id.as_str()
}
async fn run(
&self,
input: &str,
session: &mut SessionState,
) -> Result<AgentRunOutcome, AgentError> {
let max_iter = self
.run_config
.as_ref()
.map(|c| c.max_iterations)
.unwrap_or(self.max_iterations);
let mut run_ctx = AgentRunContext::new(self.id.clone(), max_iter);
let app = self.app_name.as_deref().unwrap_or("default");
let uid = self.user_id.as_deref().unwrap_or("default");
if let Some(store) = &self.app_state_store {
if let Some(snapshot) = store.load(app).await? {
session.merge(snapshot.0);
}
}
if let Some(store) = &self.user_state_store {
if let Some(snapshot) = store.load(app, uid).await? {
session.merge(snapshot.0);
}
}
if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
if let Some(snapshot) = store.load(app, uid, sid).await? {
session.merge(snapshot.0);
}
}
let mut prompt = self.build_input_prompt(input)?;
if let Some(mem_lock) = &self.memory {
let mut mem = mem_lock.lock().await;
if let Some(pm) = mem
.as_any_mut()
.and_then(|a| a.downcast_mut::<PersistentMemory>())
{
pm.hydrate().await?;
}
}
if let Some(mem_lock) = &self.memory {
let mem = mem_lock.lock().await;
let history = mem.messages();
if !history.is_empty() {
let insert_at = prompt
.request
.messages()
.iter()
.position(|m| !m.is_system_message())
.unwrap_or(0);
for (i, msg) in history.into_iter().enumerate() {
prompt.request.insert_message(msg, Some(insert_at + i));
}
}
}
{
let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
let defs = registry.get_all_definitions();
if !defs.is_empty() {
prompt.request.add_tools(defs)?;
}
}
let mut last_user_msg: Option<MessageNum> = None;
if let Some(msg) = prompt.request.messages().last().cloned() {
last_user_msg = Some(msg);
}
loop {
if run_ctx.iteration >= max_iter {
break;
}
self.fire_before_model(&run_ctx, &prompt)?;
let chat_response = self.client.generate_content(&prompt).await?;
let agent_response = AgentResponse::new(chat_response.id(), chat_response.clone());
if let Some(override_text) = self.fire_after_model(&run_ctx, &agent_response)? {
run_ctx.push_response(override_text.clone());
return Ok(AgentRunOutcome::complete(AgentRunResult {
final_response: agent_response,
iterations: run_ctx.iteration,
completion_reason: format!("callback override: {}", override_text),
combined_text: None,
}));
}
if let Some(tool_calls) = chat_response.extract_tool_calls() {
let assistant_msgs = chat_response.to_message_num(&self.provider)?;
for msg in assistant_msgs {
prompt.request.push_message(msg);
}
for call in &tool_calls {
self.fire_before_tool(&run_ctx, call)?;
let result = {
let async_tool = {
let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
registry.get_async_tool(&call.tool_name)
};
if let Some(tool) = async_tool {
if let Some(agent_tool) =
tool.as_any().and_then(|a| a.downcast_ref::<AgentTool>())
{
agent_tool
.dispatch(call.arguments.clone(), session)
.await
.map_err(|e| {
AgentError::Error(format!(
"Tool '{}' failed: {}",
call.tool_name, e
))
})?
} else {
tool.execute(call.arguments.clone()).await.map_err(|e| {
AgentError::Error(format!(
"Tool '{}' failed: {}",
call.tool_name, e
))
})?
}
} else {
let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
registry.execute(call).map_err(|e| {
AgentError::Error(format!(
"Tool '{}' failed: {}",
call.tool_name, e
))
})?
}
};
self.fire_after_tool(&run_ctx, call, &result)?;
prompt.request.add_tool_result(call, &result)?;
}
run_ctx.increment();
continue;
}
let text = chat_response.response_text();
if text.trim().starts_with("__ask_user__:") {
let question = text.trim_start_matches("__ask_user__:").trim().to_string();
let resume_ctx = ResumeContext {
agent_id: self.id.clone(),
iteration: run_ctx.iteration,
session_snapshot: session.snapshot(),
};
return Ok(AgentRunOutcome::NeedsInput {
question,
resume_context: resume_ctx,
});
}
run_ctx.push_response(text);
let met = self.criteria.iter().any(|c| c.is_complete(&run_ctx));
let reason = if met {
self.criteria
.iter()
.find(|c| c.is_complete(&run_ctx))
.map(|c| c.completion_reason(&run_ctx))
.unwrap_or_else(|| "criteria met".into())
} else {
String::new()
};
if met || run_ctx.iteration + 1 >= max_iter {
if let Some(mem_lock) = &self.memory {
let mut mem = mem_lock.lock().await;
if let Some(user_msg) = last_user_msg.take() {
let assistant_msgs = chat_response.to_message_num(&self.provider)?;
if let Some(asst_msg) = assistant_msgs.into_iter().next() {
let turn = MemoryTurn {
user: user_msg,
assistant: asst_msg,
};
if let Some(pm) = mem
.as_any_mut()
.and_then(|a| a.downcast_mut::<PersistentMemory>())
{
pm.push_turn_async(turn).await?;
} else {
mem.push_turn(turn);
}
}
}
}
if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
let snapshot = SessionSnapshot::from(&*session);
store.save(app, uid, sid, &snapshot).await?;
}
return Ok(AgentRunOutcome::complete(AgentRunResult {
final_response: agent_response,
iterations: run_ctx.iteration,
completion_reason: if met {
reason
} else {
format!("max iterations ({}) reached", max_iter)
},
combined_text: None,
}));
}
let assistant_msgs = chat_response.to_message_num(&self.provider)?;
for msg in assistant_msgs {
prompt.request.push_message(msg);
}
run_ctx.increment();
}
Err(AgentError::MaxIterationsExceeded(max_iter))
}
async fn resume(
&self,
user_answer: &str,
ctx: ResumeContext,
session: &mut SessionState,
) -> Result<AgentRunOutcome, AgentError> {
session.merge(ctx.session_snapshot);
self.run(user_answer, session).await
}
}
impl Clone for Agent {
fn clone(&self) -> Self {
Self {
id: self.id.clone(),
client: self.client.clone(),
provider: self.provider.clone(),
system_instruction: self.system_instruction.clone(),
tools: self.tools.clone(),
max_iterations: self.max_iterations,
run_config: self.run_config.clone(),
model_override: self.model_override.clone(),
criteria: Vec::new(),
callbacks: Vec::new(),
memory: None,
app_name: None,
user_id: None,
session_id: None,
session_store: None,
user_state_store: None,
app_state_store: None,
}
}
}
/// Two agents are equal when id, provider, system instructions,
/// `max_iterations` and client all match; other fields are not compared.
impl PartialEq for Agent {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (
            &self.id,
            &self.provider,
            &self.system_instruction,
            self.max_iterations,
            &self.client,
        );
        let rhs = (
            &other.id,
            &other.provider,
            &other.system_instruction,
            other.max_iterations,
            &other.client,
        );
        lhs == rhs
    }
}
/// Serializes only `id`, `provider` and `system_instruction`; the client,
/// tools and run-time wiring are not part of the serialized form.
impl Serialize for Agent {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Number of fields written below.
        const FIELD_COUNT: usize = 3;

        let mut agent_state = serializer.serialize_struct("Agent", FIELD_COUNT)?;
        agent_state.serialize_field("id", &self.id)?;
        agent_state.serialize_field("provider", &self.provider)?;
        agent_state.serialize_field("system_instruction", &self.system_instruction)?;
        agent_state.end()
    }
}
impl<'de> Deserialize<'de> for Agent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Id,
Provider,
SystemInstruction,
}
struct AgentVisitor;
impl<'de> Visitor<'de> for AgentVisitor {
type Value = Agent;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct Agent")
}
fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
where
V: MapAccess<'de>,
{
let mut id = None;
let mut provider = None;
let mut system_instruction = None;
while let Some(key) = map.next_key()? {
match key {
Field::Id => {
id = Some(map.next_value()?);
}
Field::Provider => {
provider = Some(map.next_value()?);
}
Field::SystemInstruction => {
system_instruction = Some(map.next_value()?);
}
}
}
let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
let system_instruction = system_instruction
.ok_or_else(|| de::Error::missing_field("system_instruction"))?;
let client = GenAiClient::Undefined;
Ok(Agent {
id,
client: Arc::new(client),
system_instruction,
provider,
tools: Arc::new(RwLock::new(ToolRegistry::new())),
max_iterations: 10,
run_config: None,
model_override: None,
criteria: Vec::new(),
callbacks: Vec::new(),
memory: None,
app_name: None,
user_id: None,
session_id: None,
session_store: None,
user_state_store: None,
app_state_store: None,
})
}
}
const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
}
}
/// Python-facing wrapper around [`Agent`], exposed to Python under the class
/// name `Agent`.
#[pyclass(from_py_object, name = "Agent")]
#[derive(Debug, Clone)]
pub struct PyAgent {
/// Shared handle to the wrapped agent; cloning `PyAgent` clones this `Arc`.
pub agent: Arc<Agent>,
}
#[pymethods]
impl PyAgent {
/// Creates an agent from Python.
///
/// `provider` is resolved through `Provider::resolve_from_py`;
/// `system_instruction` is converted with `extract_system_instructions` for
/// that provider. The async constructor is driven to completion with
/// `block_on`.
#[new]
#[pyo3(signature = (provider=None, system_instruction = None))]
pub fn new(
    provider: Option<&Bound<'_, PyAny>>,
    system_instruction: Option<&Bound<'_, PyAny>>,
) -> Result<Self, AgentError> {
    let resolved_provider = Provider::resolve_from_py(provider)?;
    let instructions = extract_system_instructions(system_instruction, &resolved_provider)?;
    let inner = block_on(async { Agent::new(resolved_provider, instructions).await })?;
    let agent = Arc::new(inner);
    Ok(Self { agent })
}
/// Executes `task` synchronously and returns the response.
///
/// Fails with `AgentError::ProviderMismatch` when the task prompt's provider
/// differs from the agent client's provider. `output_type`, when given, is
/// handed to the response object unchanged.
#[pyo3(signature = (task, output_type=None))]
pub fn execute_task(
    &self,
    task: &mut Task,
    output_type: Option<Bound<'_, PyAny>>,
) -> Result<PyAgentResponse, AgentError> {
    debug!("Executing task");

    let client_provider = self.agent.client_provider();
    if task.prompt.provider != *client_provider {
        let expected = task.prompt.provider.to_string();
        let actual = client_provider.as_str().to_string();
        return Err(AgentError::ProviderMismatch(expected, actual));
    }

    debug!(
        "Task prompt model identifier: {}",
        task.prompt.model_identifier()
    );

    let agent_response = block_on(async { self.agent.execute_task(task).await })?;
    debug!("Task executed successfully");

    let output = output_type.map(|obj| obj.unbind());
    Ok(PyAgentResponse::new(agent_response, output))
}
/// Executes `prompt` synchronously and returns the response.
///
/// Fails with `AgentError::ProviderMismatch` when the prompt's provider
/// differs from the agent client's provider. `output_type`, when given, is
/// handed to the response object unchanged.
#[pyo3(signature = (prompt, output_type=None))]
pub fn execute_prompt(
    &self,
    prompt: &mut Prompt,
    output_type: Option<Bound<'_, PyAny>>,
) -> Result<PyAgentResponse, AgentError> {
    // Log messages name the prompt (they previously said "task", copied from
    // `execute_task`).
    debug!("Executing prompt");
    if prompt.provider != *self.agent.client_provider() {
        return Err(AgentError::ProviderMismatch(
            prompt.provider.to_string(),
            self.agent.client_provider().as_str().to_string(),
        ));
    }
    let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
    debug!("Prompt executed successfully");
    let output = output_type.as_ref().map(|obj| obj.clone().unbind());
    let response = PyAgentResponse::new(chat_response, output);
    Ok(response)
}
/// Returns the agent's system instructions as a Python list, one Python
/// object per instruction message.
#[getter]
pub fn system_instruction<'py>(
    &self,
    py: Python<'py>,
) -> Result<Bound<'py, PyList>, AgentError> {
    let py_messages = self
        .agent
        .system_instruction
        .iter()
        .map(|message| message.to_bound_py_object(py))
        .collect::<Result<Vec<_>, _>>()?;
    let list = PyList::new(py, &py_messages)?;
    Ok(list)
}
/// Returns the wrapped agent's identifier.
#[getter]
pub fn id(&self) -> &str {
    &self.agent.id
}
}