Skip to main content

potato_agent/agents/
agent.rs

1use crate::agents::{
2    error::AgentError,
3    task::Task,
4    types::{AgentResponse, PyAgentResponse},
5};
6use potato_provider::providers::anthropic::client::AnthropicClient;
7use potato_provider::providers::types::ServiceType;
8use potato_provider::GeminiClient;
9use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
10use potato_state::block_on;
11use potato_type::prompt::Prompt;
12use potato_type::prompt::{MessageNum, Role};
13use potato_type::Provider;
14use potato_type::{
15    prompt::extract_system_instructions,
16    tools::{Tool, ToolRegistry},
17};
18use potato_util::create_uuid7;
19use pyo3::prelude::*;
20use pyo3::types::PyList;
21use serde::{
22    de::{self, MapAccess, Visitor},
23    ser::SerializeStruct,
24    Deserializer, Serializer,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::sync::RwLock;
31use tracing::{debug, instrument, warn};
32
33#[derive(Debug, Clone)]
34pub struct Agent {
35    pub id: String,
36    client: Arc<GenAiClient>,
37    pub provider: Provider,
38    pub system_instruction: Vec<MessageNum>,
39    pub tools: Arc<RwLock<ToolRegistry>>, // Add tool registry
40    pub max_iterations: u32,
41}
42
43/// Rust method implementation of the Agent
44impl Agent {
45    /// Helper method to rebuild the client, useful for deserialization
46    #[instrument(skip_all)]
47    pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
48        let client = match self.provider {
49            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
50            Provider::Gemini => {
51                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
52            }
53            Provider::Vertex => {
54                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
55            }
56            _ => {
57                return Err(AgentError::MissingProviderError);
58            } // Add other providers here as needed
59        };
60
61        Ok(Self {
62            id: self.id.clone(),
63            client: Arc::new(client),
64            system_instruction: self.system_instruction.clone(),
65            provider: self.provider.clone(),
66            tools: self.tools.clone(),
67            max_iterations: self.max_iterations,
68        })
69    }
70    pub async fn new(
71        provider: Provider,
72        system_instruction: Option<Vec<MessageNum>>,
73    ) -> Result<Self, AgentError> {
74        let client = match provider {
75            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
76            Provider::Gemini => {
77                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
78            }
79            Provider::Vertex => {
80                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
81            }
82            Provider::Anthropic => {
83                GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
84            }
85            _ => {
86                return Err(AgentError::MissingProviderError);
87            } // Add other providers here as needed
88        };
89
90        Ok(Self {
91            client: Arc::new(client),
92            id: create_uuid7(),
93            system_instruction: system_instruction.unwrap_or_default(),
94            provider,
95            tools: Arc::new(RwLock::new(ToolRegistry::new())),
96            max_iterations: 10,
97        })
98    }
99
100    pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
101        self.tools.write().unwrap().register_tool(tool);
102    }
103
104    //TODO: add back later
105    /// Execute task with agentic reasoning loop
106    //pub async fn execute_agentic_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
107    //    let mut prompt = task.prompt.clone();
108    //    self.prepend_system_instructions(&mut prompt);
109
110    //    // Add tool definitions to prompt if tools are registered
111    //    let tool_definitions = self.tools.read().unwrap().get_definitions();
112    //    if !tool_definitions.is_empty() {
113    //        // Convert tools to provider-specific format and add to prompt
114    //        prompt.add_tools(tool_definitions)?;
115    //    }
116
117    //    let mut iteration = 0;
118    //    let mut conversation_history = Vec::new();
119
120    //    loop {
121    //        if iteration >= self.max_iterations {
122    //            return Err(AgentError::Error("Max iterations reached".to_string()));
123    //        }
124
125    //        // Generate response
126    //        let response = self.client.generate_content(&prompt).await?;
127
128    //        // Check if response contains tool calls
129    //        if let Some(tool_calls) = response.extract_tool_calls() {
130    //            debug!("Agent requesting {} tool calls", tool_calls.len());
131
132    //            // Execute all requested tools
133    //            let mut tool_results = Vec::new();
134    //            for tool_call in tool_calls {
135    //                let result = self.tools.read().unwrap().execute(&tool_call)?;
136    //                tool_results.push((tool_call.tool_name.clone(), result));
137    //            }
138
139    //            // Add tool results back to conversation
140    //            conversation_history.push(response.clone());
141    //            prompt.add_tool_results(tool_results)?;
142
143    //            iteration += 1;
144    //            continue;
145    //        }
146
147    //        // No tool calls - agent has final answer
148    //        return Ok(AgentResponse::new(task.id.clone(), response));
149    //    }
150    //}
151
152    #[instrument(skip_all)]
153    fn append_task_with_message_dependency_context(
154        &self,
155        task: &mut Task,
156        context_messages: &HashMap<String, Vec<MessageNum>>,
157    ) {
158        //
159        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
160
161        if task.dependencies.is_empty() {
162            return;
163        }
164
165        let messages = task.prompt.request.messages_mut();
166        let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
167
168        match first_user_idx {
169            Some(insert_idx) => {
170                // Collect all dependency messages to insert
171                let mut dependency_messages = Vec::new();
172
173                for dep_id in &task.dependencies {
174                    if let Some(messages) = context_messages.get(dep_id) {
175                        debug!(
176                            "Adding {} messages from dependency {}",
177                            messages.len(),
178                            dep_id
179                        );
180                        dependency_messages.extend(messages.iter().cloned());
181                    }
182                }
183
184                // Always insert at same index to keep pushing user message forward
185                for message in dependency_messages.into_iter() {
186                    task.prompt
187                        .request
188                        .insert_message(message, Some(insert_idx))
189                }
190
191                debug!(
192                    "Inserted {} dependency messages before user message at index {}",
193                    task.dependencies.len(),
194                    insert_idx
195                );
196            }
197            None => {
198                warn!(
199                    "No user message found in task {}, appending dependency context to end",
200                    task.id
201                );
202
203                for dep_id in &task.dependencies {
204                    if let Some(messages) = context_messages.get(dep_id) {
205                        for message in messages {
206                            task.prompt.request.push_message(message.clone());
207                        }
208                    }
209                }
210            }
211        }
212    }
213
214    /// This function will bind dependency-specific context and global context if provided to the user prompt.
215    ///
216    /// # Arguments:
217    /// * `prompt` - The prompt to bind parameters to.
218    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
219    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
220    ///
221    /// # Returns:
222    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
223    #[instrument(skip_all)]
224    fn bind_context(
225        &self,
226        prompt: &mut Prompt,
227        parameter_context: &Value,
228        global_context: &Option<Arc<Value>>,
229    ) -> Result<(), AgentError> {
230        // print user messages
231        if !prompt.parameters.is_empty() {
232            for param in &prompt.parameters {
233                // Bind parameter context to the user message
234                if let Some(value) = parameter_context.get(param) {
235                    for message in prompt.request.messages_mut() {
236                        if message.role() == Role::User.as_str() {
237                            debug!("Binding parameter: {} with value: {}", param, value);
238                            message.bind_mut(param, &value.to_string())?;
239                        }
240                    }
241                }
242
243                // If global context is provided, bind it to the user message
244                if let Some(global_value) = global_context {
245                    if let Some(value) = global_value.get(param) {
246                        for message in prompt.request.messages_mut() {
247                            if message.role() == Role::User.as_str() {
248                                debug!("Binding global parameter: {} with value: {}", param, value);
249                                message.bind_mut(param, &value.to_string())?;
250                            }
251                        }
252                    }
253                }
254            }
255        }
256        Ok(())
257    }
258
259    /// If system instructions are set on the agent, prepend them to the prompt.
260    /// Agent system instructions take precedence over task system instructions.
261    /// If a user wishes to be more dynamic based on the task, they should set system instructions on the task/prompt
262    fn prepend_system_instructions(&self, prompt: &mut Prompt) {
263        if !self.system_instruction.is_empty() {
264            prompt
265                .request
266                .prepend_system_instructions(self.system_instruction.clone())
267                .unwrap();
268        }
269    }
270    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
271        // Extract the prompt from the task
272        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
273        let mut prompt = task.prompt.clone();
274        self.prepend_system_instructions(&mut prompt);
275
276        // Use the client to execute the task
277        let chat_response = self.client.generate_content(&prompt).await?;
278
279        Ok(AgentResponse::new(task.id.clone(), chat_response))
280    }
281
282    #[instrument(skip_all)]
283    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
284        // Extract the prompt from the task
285        debug!("Executing prompt");
286        let mut prompt = prompt.clone();
287        self.prepend_system_instructions(&mut prompt);
288
289        // Use the client to execute the task
290        let chat_response = self.client.generate_content(&prompt).await?;
291
292        Ok(AgentResponse::new(chat_response.id(), chat_response))
293    }
294
295    /// Execute task with context without mutating the original task
296    /// This method is used by the workflow executor to run individual tasks with context
297    #[instrument(skip_all)]
298    pub async fn execute_task_with_context(
299        &self,
300        task: &Arc<RwLock<Task>>,
301        context: &Value,
302    ) -> Result<AgentResponse, AgentError> {
303        // Clone prompt and task_id without holding lock across await
304        let (mut prompt, task_id) = {
305            let task = task.read().unwrap();
306            (task.prompt.clone(), task.id.clone())
307        };
308
309        self.bind_context(&mut prompt, context, &None)?;
310        self.prepend_system_instructions(&mut prompt);
311
312        let chat_response = self.client.generate_content(&prompt).await?;
313        Ok(AgentResponse::new(task_id, chat_response))
314    }
315
316    pub async fn execute_task_with_context_message(
317        &self,
318        task: &Arc<RwLock<Task>>,
319        context_messages: HashMap<String, Vec<MessageNum>>,
320        parameter_context: Value,
321        global_context: Option<Arc<Value>>,
322    ) -> Result<AgentResponse, AgentError> {
323        // Prepare prompt and context before await
324        let (prompt, task_id) = {
325            let mut task = task.write().unwrap();
326            // 1. Add dependency context (should come after system instructions, before user message)
327            self.append_task_with_message_dependency_context(&mut task, &context_messages);
328            // 2. Bind parameters
329            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
330            // 3. Prepend agent system instructions (add to front)
331            self.prepend_system_instructions(&mut task.prompt);
332            (task.prompt.clone(), task.id.clone())
333        };
334
335        // Now do the async work without holding the lock
336        let chat_response = self.client.generate_content(&prompt).await?;
337        Ok(AgentResponse::new(task_id, chat_response))
338    }
339
340    pub fn client_provider(&self) -> &Provider {
341        self.client.provider()
342    }
343}
344
345impl PartialEq for Agent {
346    fn eq(&self, other: &Self) -> bool {
347        self.id == other.id
348            && self.provider == other.provider
349            && self.system_instruction == other.system_instruction
350            && self.max_iterations == other.max_iterations
351            && self.client == other.client
352    }
353}
354
355impl Serialize for Agent {
356    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
357    where
358        S: Serializer,
359    {
360        let mut state = serializer.serialize_struct("Agent", 3)?;
361        state.serialize_field("id", &self.id)?;
362        state.serialize_field("provider", &self.provider)?;
363        state.serialize_field("system_instruction", &self.system_instruction)?;
364        state.end()
365    }
366}
367
368/// Allows for deserialization of the Agent, re-initializing the client.
369impl<'de> Deserialize<'de> for Agent {
370    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
371    where
372        D: Deserializer<'de>,
373    {
374        #[derive(Deserialize)]
375        #[serde(field_identifier, rename_all = "snake_case")]
376        enum Field {
377            Id,
378            Provider,
379            SystemInstruction,
380        }
381
382        struct AgentVisitor;
383
384        impl<'de> Visitor<'de> for AgentVisitor {
385            type Value = Agent;
386
387            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
388                formatter.write_str("struct Agent")
389            }
390
391            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
392            where
393                V: MapAccess<'de>,
394            {
395                let mut id = None;
396                let mut provider = None;
397                let mut system_instruction = None;
398
399                while let Some(key) = map.next_key()? {
400                    match key {
401                        Field::Id => {
402                            id = Some(map.next_value()?);
403                        }
404                        Field::Provider => {
405                            provider = Some(map.next_value()?);
406                        }
407                        Field::SystemInstruction => {
408                            system_instruction = Some(map.next_value()?);
409                        }
410                    }
411                }
412
413                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
414                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
415                let system_instruction = system_instruction
416                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
417
418                // Deserialize is a sync op, so we can't await here (gemini requires async to init)
419                // After deserialization, we re-initialize the client based on the provider
420                let client = GenAiClient::Undefined;
421                Ok(Agent {
422                    id,
423                    client: Arc::new(client),
424                    system_instruction,
425                    provider,
426                    tools: Arc::new(RwLock::new(ToolRegistry::new())),
427                    max_iterations: 10,
428                })
429            }
430        }
431
432        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
433        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
434    }
435}
436
437#[pyclass(name = "Agent")]
438#[derive(Debug, Clone)]
439pub struct PyAgent {
440    pub agent: Arc<Agent>,
441}
442
443#[pymethods]
444impl PyAgent {
445    #[new]
446    #[pyo3(signature = (provider, system_instruction = None))]
447    /// Creates a new Agent instance.
448    ///
449    /// # Arguments:
450    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
451    /// that can be mapped to a provider variant
452    ///
453    pub fn new(
454        provider: &Bound<'_, PyAny>,
455        system_instruction: Option<&Bound<'_, PyAny>>,
456    ) -> Result<Self, AgentError> {
457        let provider = Provider::extract_provider(provider)?;
458        let system_instructions = extract_system_instructions(system_instruction, &provider)?;
459        let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
460
461        Ok(Self {
462            agent: Arc::new(agent),
463        })
464    }
465
466    #[pyo3(signature = (task, output_type=None))]
467    pub fn execute_task(
468        &self,
469        task: &mut Task,
470        output_type: Option<Bound<'_, PyAny>>,
471    ) -> Result<PyAgentResponse, AgentError> {
472        // Extract the prompt from the task
473        debug!("Executing task");
474
475        // agent provider and task.prompt provider must match
476        if task.prompt.provider != *self.agent.client_provider() {
477            return Err(AgentError::ProviderMismatch(
478                task.prompt.provider.to_string(),
479                self.agent.client_provider().as_str().to_string(),
480            ));
481        }
482
483        debug!(
484            "Task prompt model identifier: {}",
485            task.prompt.model_identifier()
486        );
487
488        let chat_response = block_on(async { self.agent.execute_task(task).await })?;
489
490        debug!("Task executed successfully");
491        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
492        let response = PyAgentResponse::new(chat_response, output);
493
494        Ok(response)
495    }
496
497    /// Executes a prompt directly without a task.
498    /// # Arguments:
499    /// * `prompt` - The prompt to execute.
500    /// * `output_type` - An optional Python type to bind the response to. If
501    /// provide, it is expected that the output type_object matches the response schema defined in the prompt.
502    /// # Returns:
503    /// * `PyAgentResponse` - The response from the agent.
504    #[pyo3(signature = (prompt, output_type=None))]
505    pub fn execute_prompt(
506        &self,
507        prompt: &mut Prompt,
508        output_type: Option<Bound<'_, PyAny>>,
509    ) -> Result<PyAgentResponse, AgentError> {
510        // Extract the prompt from the task
511        debug!("Executing task");
512
513        // agent provider and task.prompt provider must match
514        if prompt.provider != *self.agent.client_provider() {
515            return Err(AgentError::ProviderMismatch(
516                prompt.provider.to_string(),
517                self.agent.client_provider().as_str().to_string(),
518            ));
519        }
520
521        let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
522
523        debug!("Task executed successfully");
524        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
525        let response = PyAgentResponse::new(chat_response, output);
526
527        Ok(response)
528    }
529
530    #[getter]
531    pub fn system_instruction<'py>(
532        &self,
533        py: Python<'py>,
534    ) -> Result<Bound<'py, PyList>, AgentError> {
535        let instructions = self
536            .agent
537            .system_instruction
538            .iter()
539            .map(|msg_num| msg_num.to_bound_py_object(py))
540            .collect::<Result<Vec<_>, _>>()
541            .map(|instructions| PyList::new(py, &instructions))?;
542
543        Ok(instructions?)
544    }
545
546    #[getter]
547    pub fn id(&self) -> &str {
548        self.agent.id.as_str()
549    }
550}