Skip to main content

potato_agent/agents/
agent.rs

1use crate::agents::{
2    error::AgentError,
3    task::Task,
4    types::{AgentResponse, PyAgentResponse},
5};
6use potato_provider::providers::anthropic::client::AnthropicClient;
7use potato_provider::providers::types::ServiceType;
8use potato_provider::GeminiClient;
9use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
10use potato_state::block_on;
11use potato_type::prompt::Prompt;
12use potato_type::prompt::{MessageNum, Role};
13use potato_type::Provider;
14use potato_type::{
15    prompt::extract_system_instructions,
16    tools::{Tool, ToolRegistry},
17};
18use potato_util::create_uuid7;
19use pyo3::prelude::*;
20use pyo3::types::PyList;
21use serde::{
22    de::{self, MapAccess, Visitor},
23    ser::SerializeStruct,
24    Deserializer, Serializer,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::sync::RwLock;
31use tracing::{debug, instrument, warn};
32
33#[derive(Debug, Clone)]
34pub struct Agent {
35    pub id: String,
36    client: Arc<GenAiClient>,
37    pub provider: Provider,
38    pub system_instruction: Vec<MessageNum>,
39    pub tools: Arc<RwLock<ToolRegistry>>, // Add tool registry
40    pub max_iterations: u32,
41}
42
43/// Rust method implementation of the Agent
44impl Agent {
45    /// Helper method to rebuild the client, useful for deserialization
46    #[instrument(skip_all)]
47    pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
48        let client = match self.provider {
49            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
50            Provider::Gemini => {
51                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
52            }
53            Provider::Vertex => {
54                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
55            }
56            Provider::Anthropic => {
57                GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
58            }
59            Provider::Google => {
60                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
61            }
62            _ => {
63                return Err(AgentError::MissingProviderError);
64            } // Add other providers here as needed
65        };
66
67        Ok(Self {
68            id: self.id.clone(),
69            client: Arc::new(client),
70            system_instruction: self.system_instruction.clone(),
71            provider: self.provider.clone(),
72            tools: self.tools.clone(),
73            max_iterations: self.max_iterations,
74        })
75    }
76    pub async fn new(
77        provider: Provider,
78        system_instruction: Option<Vec<MessageNum>>,
79    ) -> Result<Self, AgentError> {
80        let client = match provider {
81            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
82            Provider::Gemini => {
83                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
84            }
85            Provider::Vertex => {
86                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
87            }
88            Provider::Anthropic => {
89                GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
90            }
91            Provider::Google => {
92                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
93            }
94            _ => {
95                return Err(AgentError::MissingProviderError);
96            } // Add other providers here as needed
97        };
98
99        Ok(Self {
100            client: Arc::new(client),
101            id: create_uuid7(),
102            system_instruction: system_instruction.unwrap_or_default(),
103            provider,
104            tools: Arc::new(RwLock::new(ToolRegistry::new())),
105            max_iterations: 10,
106        })
107    }
108
109    pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
110        self.tools.write().unwrap().register_tool(tool);
111    }
112
    // TODO: add back later
    // Execute task with agentic reasoning loop (kept as a plain `//` comment: a `///`
    // line here would attach as rustdoc to the next real item below)
115    //pub async fn execute_agentic_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
116    //    let mut prompt = task.prompt.clone();
117    //    self.prepend_system_instructions(&mut prompt);
118
119    //    // Add tool definitions to prompt if tools are registered
120    //    let tool_definitions = self.tools.read().unwrap().get_definitions();
121    //    if !tool_definitions.is_empty() {
122    //        // Convert tools to provider-specific format and add to prompt
123    //        prompt.add_tools(tool_definitions)?;
124    //    }
125
126    //    let mut iteration = 0;
127    //    let mut conversation_history = Vec::new();
128
129    //    loop {
130    //        if iteration >= self.max_iterations {
131    //            return Err(AgentError::Error("Max iterations reached".to_string()));
132    //        }
133
134    //        // Generate response
135    //        let response = self.client.generate_content(&prompt).await?;
136
137    //        // Check if response contains tool calls
138    //        if let Some(tool_calls) = response.extract_tool_calls() {
139    //            debug!("Agent requesting {} tool calls", tool_calls.len());
140
141    //            // Execute all requested tools
142    //            let mut tool_results = Vec::new();
143    //            for tool_call in tool_calls {
144    //                let result = self.tools.read().unwrap().execute(&tool_call)?;
145    //                tool_results.push((tool_call.tool_name.clone(), result));
146    //            }
147
148    //            // Add tool results back to conversation
149    //            conversation_history.push(response.clone());
150    //            prompt.add_tool_results(tool_results)?;
151
152    //            iteration += 1;
153    //            continue;
154    //        }
155
156    //        // No tool calls - agent has final answer
157    //        return Ok(AgentResponse::new(task.id.clone(), response));
158    //    }
159    //}
160
161    #[instrument(skip_all)]
162    fn append_task_with_message_dependency_context(
163        &self,
164        task: &mut Task,
165        context_messages: &HashMap<String, Vec<MessageNum>>,
166    ) {
167        //
168        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
169
170        if task.dependencies.is_empty() {
171            return;
172        }
173
174        let messages = task.prompt.request.messages_mut();
175        let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
176
177        match first_user_idx {
178            Some(insert_idx) => {
179                // Collect all dependency messages to insert
180                let mut dependency_messages = Vec::new();
181
182                for dep_id in &task.dependencies {
183                    if let Some(messages) = context_messages.get(dep_id) {
184                        debug!(
185                            "Adding {} messages from dependency {}",
186                            messages.len(),
187                            dep_id
188                        );
189                        dependency_messages.extend(messages.iter().cloned());
190                    }
191                }
192
193                // Always insert at same index to keep pushing user message forward
194                for message in dependency_messages.into_iter() {
195                    task.prompt
196                        .request
197                        .insert_message(message, Some(insert_idx))
198                }
199
200                debug!(
201                    "Inserted {} dependency messages before user message at index {}",
202                    task.dependencies.len(),
203                    insert_idx
204                );
205            }
206            None => {
207                warn!(
208                    "No user message found in task {}, appending dependency context to end",
209                    task.id
210                );
211
212                for dep_id in &task.dependencies {
213                    if let Some(messages) = context_messages.get(dep_id) {
214                        for message in messages {
215                            task.prompt.request.push_message(message.clone());
216                        }
217                    }
218                }
219            }
220        }
221    }
222
223    /// This function will bind dependency-specific context and global context if provided to the user prompt.
224    ///
225    /// # Arguments:
226    /// * `prompt` - The prompt to bind parameters to.
227    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
228    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
229    ///
230    /// # Returns:
231    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
232    #[instrument(skip_all)]
233    fn bind_context(
234        &self,
235        prompt: &mut Prompt,
236        parameter_context: &Value,
237        global_context: &Option<Arc<Value>>,
238    ) -> Result<(), AgentError> {
239        // print user messages
240        if !prompt.parameters.is_empty() {
241            for param in &prompt.parameters {
242                // Bind parameter context to the user message
243                if let Some(value) = parameter_context.get(param) {
244                    for message in prompt.request.messages_mut() {
245                        if message.role() == Role::User.as_str() {
246                            debug!("Binding parameter: {} with value: {}", param, value);
247                            message.bind_mut(param, &value.to_string())?;
248                        }
249                    }
250                }
251
252                // If global context is provided, bind it to the user message
253                if let Some(global_value) = global_context {
254                    if let Some(value) = global_value.get(param) {
255                        for message in prompt.request.messages_mut() {
256                            if message.role() == Role::User.as_str() {
257                                debug!("Binding global parameter: {} with value: {}", param, value);
258                                message.bind_mut(param, &value.to_string())?;
259                            }
260                        }
261                    }
262                }
263            }
264        }
265        Ok(())
266    }
267
268    /// If system instructions are set on the agent, prepend them to the prompt.
269    /// Agent system instructions take precedence over task system instructions.
270    /// If a user wishes to be more dynamic based on the task, they should set system instructions on the task/prompt
271    fn prepend_system_instructions(&self, prompt: &mut Prompt) {
272        if !self.system_instruction.is_empty() {
273            prompt
274                .request
275                .prepend_system_instructions(self.system_instruction.clone())
276                .unwrap();
277        }
278    }
279    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
280        // Extract the prompt from the task
281        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
282        let mut prompt = task.prompt.clone();
283        self.prepend_system_instructions(&mut prompt);
284
285        // Use the client to execute the task
286        let chat_response = self.client.generate_content(&prompt).await?;
287
288        Ok(AgentResponse::new(task.id.clone(), chat_response))
289    }
290
291    #[instrument(skip_all)]
292    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
293        // Extract the prompt from the task
294        debug!("Executing prompt");
295        let mut prompt = prompt.clone();
296        self.prepend_system_instructions(&mut prompt);
297
298        // Use the client to execute the task
299        let chat_response = self.client.generate_content(&prompt).await?;
300
301        Ok(AgentResponse::new(chat_response.id(), chat_response))
302    }
303
304    /// Execute task with context without mutating the original task
305    /// This method is used by the workflow executor to run individual tasks with context
306    #[instrument(skip_all)]
307    pub async fn execute_task_with_context(
308        &self,
309        task: &Arc<RwLock<Task>>,
310        context: &Value,
311    ) -> Result<AgentResponse, AgentError> {
312        // Clone prompt and task_id without holding lock across await
313        let (mut prompt, task_id) = {
314            let task = task.read().unwrap();
315            (task.prompt.clone(), task.id.clone())
316        };
317
318        self.bind_context(&mut prompt, context, &None)?;
319        self.prepend_system_instructions(&mut prompt);
320
321        let chat_response = self.client.generate_content(&prompt).await?;
322        Ok(AgentResponse::new(task_id, chat_response))
323    }
324
325    pub async fn execute_task_with_context_message(
326        &self,
327        task: &Arc<RwLock<Task>>,
328        context_messages: HashMap<String, Vec<MessageNum>>,
329        parameter_context: Value,
330        global_context: Option<Arc<Value>>,
331    ) -> Result<AgentResponse, AgentError> {
332        // Prepare prompt and context before await
333        let (prompt, task_id) = {
334            let mut task = task.write().unwrap();
335            // 1. Add dependency context (should come after system instructions, before user message)
336            self.append_task_with_message_dependency_context(&mut task, &context_messages);
337            // 2. Bind parameters
338            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
339            // 3. Prepend agent system instructions (add to front)
340            self.prepend_system_instructions(&mut task.prompt);
341            (task.prompt.clone(), task.id.clone())
342        };
343
344        // Now do the async work without holding the lock
345        let chat_response = self.client.generate_content(&prompt).await?;
346        Ok(AgentResponse::new(task_id, chat_response))
347    }
348
349    pub fn client_provider(&self) -> &Provider {
350        self.client.provider()
351    }
352}
353
354impl PartialEq for Agent {
355    fn eq(&self, other: &Self) -> bool {
356        self.id == other.id
357            && self.provider == other.provider
358            && self.system_instruction == other.system_instruction
359            && self.max_iterations == other.max_iterations
360            && self.client == other.client
361    }
362}
363
364impl Serialize for Agent {
365    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
366    where
367        S: Serializer,
368    {
369        let mut state = serializer.serialize_struct("Agent", 3)?;
370        state.serialize_field("id", &self.id)?;
371        state.serialize_field("provider", &self.provider)?;
372        state.serialize_field("system_instruction", &self.system_instruction)?;
373        state.end()
374    }
375}
376
377/// Allows for deserialization of the Agent, re-initializing the client.
378impl<'de> Deserialize<'de> for Agent {
379    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
380    where
381        D: Deserializer<'de>,
382    {
383        #[derive(Deserialize)]
384        #[serde(field_identifier, rename_all = "snake_case")]
385        enum Field {
386            Id,
387            Provider,
388            SystemInstruction,
389        }
390
391        struct AgentVisitor;
392
393        impl<'de> Visitor<'de> for AgentVisitor {
394            type Value = Agent;
395
396            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
397                formatter.write_str("struct Agent")
398            }
399
400            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
401            where
402                V: MapAccess<'de>,
403            {
404                let mut id = None;
405                let mut provider = None;
406                let mut system_instruction = None;
407
408                while let Some(key) = map.next_key()? {
409                    match key {
410                        Field::Id => {
411                            id = Some(map.next_value()?);
412                        }
413                        Field::Provider => {
414                            provider = Some(map.next_value()?);
415                        }
416                        Field::SystemInstruction => {
417                            system_instruction = Some(map.next_value()?);
418                        }
419                    }
420                }
421
422                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
423                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
424                let system_instruction = system_instruction
425                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
426
427                // Deserialize is a sync op, so we can't await here (gemini requires async to init)
428                // After deserialization, we re-initialize the client based on the provider
429                let client = GenAiClient::Undefined;
430                Ok(Agent {
431                    id,
432                    client: Arc::new(client),
433                    system_instruction,
434                    provider,
435                    tools: Arc::new(RwLock::new(ToolRegistry::new())),
436                    max_iterations: 10,
437                })
438            }
439        }
440
441        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
442        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
443    }
444}
445
446#[pyclass(name = "Agent")]
447#[derive(Debug, Clone)]
448pub struct PyAgent {
449    pub agent: Arc<Agent>,
450}
451
452#[pymethods]
453impl PyAgent {
454    #[new]
455    #[pyo3(signature = (provider, system_instruction = None))]
456    /// Creates a new Agent instance.
457    ///
458    /// # Arguments:
459    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
460    /// that can be mapped to a provider variant
461    ///
462    pub fn new(
463        provider: &Bound<'_, PyAny>,
464        system_instruction: Option<&Bound<'_, PyAny>>,
465    ) -> Result<Self, AgentError> {
466        let provider = Provider::extract_provider(provider)?;
467        let system_instructions = extract_system_instructions(system_instruction, &provider)?;
468        let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
469
470        Ok(Self {
471            agent: Arc::new(agent),
472        })
473    }
474
475    #[pyo3(signature = (task, output_type=None))]
476    pub fn execute_task(
477        &self,
478        task: &mut Task,
479        output_type: Option<Bound<'_, PyAny>>,
480    ) -> Result<PyAgentResponse, AgentError> {
481        // Extract the prompt from the task
482        debug!("Executing task");
483
484        // agent provider and task.prompt provider must match
485        if task.prompt.provider != *self.agent.client_provider() {
486            return Err(AgentError::ProviderMismatch(
487                task.prompt.provider.to_string(),
488                self.agent.client_provider().as_str().to_string(),
489            ));
490        }
491
492        debug!(
493            "Task prompt model identifier: {}",
494            task.prompt.model_identifier()
495        );
496
497        let chat_response = block_on(async { self.agent.execute_task(task).await })?;
498
499        debug!("Task executed successfully");
500        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
501        let response = PyAgentResponse::new(chat_response, output);
502
503        Ok(response)
504    }
505
506    /// Executes a prompt directly without a task.
507    /// # Arguments:
508    /// * `prompt` - The prompt to execute.
509    /// * `output_type` - An optional Python type to bind the response to. If
510    /// provide, it is expected that the output type_object matches the response schema defined in the prompt.
511    /// # Returns:
512    /// * `PyAgentResponse` - The response from the agent.
513    #[pyo3(signature = (prompt, output_type=None))]
514    pub fn execute_prompt(
515        &self,
516        prompt: &mut Prompt,
517        output_type: Option<Bound<'_, PyAny>>,
518    ) -> Result<PyAgentResponse, AgentError> {
519        // Extract the prompt from the task
520        debug!("Executing task");
521
522        // agent provider and task.prompt provider must match
523        if prompt.provider != *self.agent.client_provider() {
524            return Err(AgentError::ProviderMismatch(
525                prompt.provider.to_string(),
526                self.agent.client_provider().as_str().to_string(),
527            ));
528        }
529
530        let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
531
532        debug!("Task executed successfully");
533        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
534        let response = PyAgentResponse::new(chat_response, output);
535
536        Ok(response)
537    }
538
539    #[getter]
540    pub fn system_instruction<'py>(
541        &self,
542        py: Python<'py>,
543    ) -> Result<Bound<'py, PyList>, AgentError> {
544        let instructions = self
545            .agent
546            .system_instruction
547            .iter()
548            .map(|msg_num| msg_num.to_bound_py_object(py))
549            .collect::<Result<Vec<_>, _>>()
550            .map(|instructions| PyList::new(py, &instructions))?;
551
552        Ok(instructions?)
553    }
554
555    #[getter]
556    pub fn id(&self) -> &str {
557        self.agent.id.as_str()
558    }
559}