potato_agent/agents/
agent.rs

1use crate::agents::provider::gemini::GeminiClient;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::{
4    agents::client::GenAiClient,
5    agents::error::AgentError,
6    agents::task::Task,
7    agents::types::{AgentResponse, PyAgentResponse},
8};
9use potato_prompt::{
10    parse_response_to_json, prompt::parse_prompt, prompt::types::Message, ModelSettings, Prompt,
11    Role,
12};
13use potato_type::Model;
14use potato_type::Provider;
15use potato_util::create_uuid7;
16use pyo3::{prelude::*, IntoPyObjectExt};
17use serde::{
18    de::{self, MapAccess, Visitor},
19    ser::SerializeStruct,
20    Deserializer, Serializer,
21};
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24use std::collections::HashMap;
25use std::sync::Arc;
26use std::sync::RwLock;
27use tracing::{debug, error, instrument, warn};
28
29#[derive(Debug, Clone, PartialEq)]
30pub struct Agent {
31    pub id: String,
32
33    client: GenAiClient,
34
35    pub system_instruction: Vec<Message>,
36}
37
38/// Rust method implementation of the Agent
39impl Agent {
40    pub fn new(
41        provider: Provider,
42        system_instruction: Option<Vec<Message>>,
43    ) -> Result<Self, AgentError> {
44        let client = match provider {
45            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
46            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
47            _ => {
48                let msg = "No provider specified in ModelSettings";
49                error!("{}", msg);
50                return Err(AgentError::UndefinedError(msg.to_string()));
51            } // Add other providers here as needed
52        };
53
54        let system_instruction = system_instruction.unwrap_or_default();
55
56        Ok(Self {
57            client,
58            id: create_uuid7(),
59            system_instruction,
60        })
61    }
62
63    #[instrument(skip_all)]
64    fn append_task_with_message_context(
65        &self,
66        task: &mut Task,
67        context_messages: &HashMap<String, Vec<Message>>,
68    ) {
69        //
70        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
71        if !task.dependencies.is_empty() {
72            for dep in &task.dependencies {
73                if let Some(messages) = context_messages.get(dep) {
74                    for message in messages {
75                        // prepend the messages from dependencies
76                        task.prompt.message.insert(0, message.clone());
77                    }
78                }
79            }
80        }
81    }
82
83    /// This function will bind dependency-specific context and global context if provided to the user prompt.
84    ///
85    /// # Arguments:
86    /// * `prompt` - The prompt to bind parameters to.
87    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
88    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
89    ///
90    /// # Returns:
91    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
92    #[instrument(skip_all)]
93    fn bind_context(
94        &self,
95        prompt: &mut Prompt,
96        parameter_context: &Value,
97        global_context: &Option<Value>,
98    ) -> Result<(), AgentError> {
99        // print user messages
100        if !prompt.parameters.is_empty() {
101            for param in &prompt.parameters {
102                // Bind parameter context to the user message
103                if let Some(value) = parameter_context.get(param) {
104                    for message in &mut prompt.message {
105                        if message.role == "user" {
106                            debug!("Binding parameter: {} with value: {}", param, value);
107                            message.bind_mut(param, &value.to_string())?;
108                        }
109                    }
110                }
111
112                // If global context is provided, bind it to the user message
113                if let Some(global_value) = global_context {
114                    if let Some(value) = global_value.get(param) {
115                        for message in &mut prompt.message {
116                            if message.role == "user" {
117                                debug!("Binding global parameter: {} with value: {}", param, value);
118                                message.bind_mut(param, &value.to_string())?;
119                            }
120                        }
121                    }
122                }
123            }
124        }
125        Ok(())
126    }
127
128    fn append_system_instructions(&self, prompt: &mut Prompt) {
129        if !self.system_instruction.is_empty() {
130            let mut combined_messages = self.system_instruction.clone();
131            combined_messages.extend(prompt.system_instruction.clone());
132            prompt.system_instruction = combined_messages;
133        }
134    }
135    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
136        // Extract the prompt from the task
137        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
138        let mut prompt = task.prompt.clone();
139        self.append_system_instructions(&mut prompt);
140
141        // Use the client to execute the task
142        let chat_response = self.client.execute(&prompt).await?;
143
144        Ok(AgentResponse::new(task.id.clone(), chat_response))
145    }
146
147    #[instrument(skip_all)]
148    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
149        // Extract the prompt from the task
150        debug!("Executing prompt");
151        let mut prompt = prompt.clone();
152        self.append_system_instructions(&mut prompt);
153
154        // Use the client to execute the task
155        let chat_response = self.client.execute(&prompt).await?;
156
157        Ok(AgentResponse::new(chat_response.id(), chat_response))
158    }
159
160    pub async fn execute_task_with_context(
161        &self,
162        task: &Arc<RwLock<Task>>,
163        context_messages: HashMap<String, Vec<Message>>,
164        parameter_context: Value,
165        global_context: Option<Value>,
166    ) -> Result<AgentResponse, AgentError> {
167        // Prepare prompt and context before await
168        let (prompt, task_id) = {
169            let mut task = task.write().unwrap();
170            self.append_task_with_message_context(&mut task, &context_messages);
171            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
172
173            self.append_system_instructions(&mut task.prompt);
174            (task.prompt.clone(), task.id.clone())
175        };
176
177        // Now do the async work without holding the lock
178        let chat_response = self.client.execute(&prompt).await?;
179
180        Ok(AgentResponse::new(task_id, chat_response))
181    }
182
183    pub fn provider(&self) -> &Provider {
184        self.client.provider()
185    }
186
187    pub fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
188        let provider = Provider::from_string(&model_settings.provider)?;
189        let client = match provider {
190            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
191            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
192            Provider::Undefined => {
193                let msg = "No provider specified in ModelSettings";
194                error!("{}", msg);
195                return Err(AgentError::UndefinedError(msg.to_string()));
196            }
197        };
198
199        Ok(Self {
200            client,
201            id: create_uuid7(),
202            system_instruction: Vec::new(),
203        })
204    }
205}
206
207impl Serialize for Agent {
208    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
209    where
210        S: Serializer,
211    {
212        let mut state = serializer.serialize_struct("Agent", 3)?;
213        state.serialize_field("id", &self.id)?;
214        state.serialize_field("provider", &self.client.provider())?;
215        state.serialize_field("system_instruction", &self.system_instruction)?;
216        state.end()
217    }
218}
219
220/// Allows for deserialization of the Agent, re-initializing the client.
221impl<'de> Deserialize<'de> for Agent {
222    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223    where
224        D: Deserializer<'de>,
225    {
226        #[derive(Deserialize)]
227        #[serde(field_identifier, rename_all = "snake_case")]
228        enum Field {
229            Id,
230            Provider,
231            SystemInstruction,
232        }
233
234        struct AgentVisitor;
235
236        impl<'de> Visitor<'de> for AgentVisitor {
237            type Value = Agent;
238
239            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
240                formatter.write_str("struct Agent")
241            }
242
243            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
244            where
245                V: MapAccess<'de>,
246            {
247                let mut id = None;
248                let mut provider = None;
249                let mut system_instruction = None;
250
251                while let Some(key) = map.next_key()? {
252                    match key {
253                        Field::Id => {
254                            id = Some(map.next_value()?);
255                        }
256                        Field::Provider => {
257                            provider = Some(map.next_value()?);
258                        }
259                        Field::SystemInstruction => {
260                            system_instruction = Some(map.next_value()?);
261                        }
262                    }
263                }
264
265                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
266                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
267                let system_instruction = system_instruction
268                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
269
270                // Re-initialize the client based on the provider
271                let client = match provider {
272                    Provider::OpenAI => {
273                        GenAiClient::OpenAI(OpenAIClient::new(None, None, None).map_err(|e| {
274                            de::Error::custom(format!("Failed to initialize OpenAIClient: {e}"))
275                        })?)
276                    }
277                    Provider::Gemini => {
278                        GenAiClient::Gemini(GeminiClient::new(None, None, None).map_err(|e| {
279                            de::Error::custom(format!("Failed to initialize GeminiClient: {e}"))
280                        })?)
281                    }
282
283                    Provider::Undefined => {
284                        let msg = "No provider specified in ModelSettings";
285                        error!("{}", msg);
286                        return Err(de::Error::custom(msg));
287                    }
288                };
289
290                Ok(Agent {
291                    id,
292                    client,
293                    system_instruction,
294                })
295            }
296        }
297
298        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
299        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
300    }
301}
302
303#[pyclass(name = "Agent")]
304#[derive(Debug, Clone)]
305pub struct PyAgent {
306    pub agent: Arc<Agent>,
307    pub runtime: Arc<tokio::runtime::Runtime>,
308}
309
310#[pymethods]
311impl PyAgent {
312    #[new]
313    #[pyo3(signature = (provider, system_instruction = None))]
314    /// Creates a new Agent instance.
315    ///
316    /// # Arguments:
317    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
318    /// that can be mapped to a provider variant
319    ///
320    pub fn new(
321        provider: &Bound<'_, PyAny>,
322        system_instruction: Option<&Bound<'_, PyAny>>,
323    ) -> Result<Self, AgentError> {
324        let provider = Provider::extract_provider(provider)?;
325
326        let system_instruction = if let Some(system_instruction) = system_instruction {
327            Some(
328                parse_prompt(system_instruction)?
329                    .into_iter()
330                    .map(|mut msg| {
331                        msg.role = Role::Developer.to_string();
332                        msg
333                    })
334                    .collect::<Vec<Message>>(),
335            )
336        } else {
337            None
338        };
339
340        let agent = Agent::new(provider, system_instruction)?;
341
342        Ok(Self {
343            agent: Arc::new(agent),
344            runtime: Arc::new(
345                tokio::runtime::Runtime::new()
346                    .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
347            ),
348        })
349    }
350
351    #[pyo3(signature = (task, output_type=None, model=None))]
352    pub fn execute_task(
353        &self,
354        py: Python<'_>,
355        task: &mut Task,
356        output_type: Option<Bound<'_, PyAny>>,
357        model: Option<&str>,
358    ) -> Result<PyAgentResponse, AgentError> {
359        // Extract the prompt from the task
360        debug!("Executing task");
361
362        // if output_type is not None,  mutate task prompt
363        if let Some(output_type) = &output_type {
364            match parse_response_to_json(py, output_type) {
365                Ok((response_type, response_format)) => {
366                    task.prompt.response_type = response_type;
367                    task.prompt.response_json_schema = response_format;
368                }
369                Err(_) => {
370                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
371                }
372            }
373        }
374
375        // if model is none and task.prompt.model is None, fail
376        if model.is_none() && task.prompt.model() == Model::Undefined.as_str() {
377            return Err(AgentError::UndefinedError(
378                "Model must be specified either as an argument or in the Task prompt".to_string(),
379            ));
380        }
381
382        // if model is not None, set task prompt model (this will override the task prompt model)
383        if let Some(model) = model {
384            task.prompt.set_model(model);
385        }
386
387        // check if prompt.provider is None, if so, set it to the agent provider
388        if task.prompt.provider() == Model::Undefined.as_str() {
389            task.prompt.set_provider(self.agent.provider().as_str());
390        }
391
392        println!("Task prompt model: {}", task.prompt.model());
393        println!("Task prompt provider: {}", task.prompt.provider());
394
395        let chat_response = self
396            .runtime
397            .block_on(async { self.agent.execute_task(task).await })?;
398
399        debug!("Task executed successfully");
400        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
401        let response = PyAgentResponse::new(chat_response, output);
402
403        Ok(response)
404    }
405
406    #[pyo3(signature = (prompt, output_type=None, model=None))]
407    pub fn execute_prompt(
408        &self,
409        py: Python<'_>,
410        prompt: &mut Prompt,
411        output_type: Option<Bound<'_, PyAny>>,
412        model: Option<&str>,
413    ) -> Result<PyAgentResponse, AgentError> {
414        // Extract the prompt from the task
415        debug!("Executing task");
416        // if output_type is not None,  mutate task prompt
417        if let Some(output_type) = &output_type {
418            match parse_response_to_json(py, output_type) {
419                Ok((response_type, response_format)) => {
420                    prompt.response_type = response_type;
421                    prompt.response_json_schema = response_format;
422                }
423                Err(_) => {
424                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
425                }
426            }
427        }
428
429        // if model is none and task.prompt.model is None, fail
430        if model.is_none() && prompt.model() == Model::Undefined.as_str() {
431            return Err(AgentError::UndefinedError(
432                "Model must be specified either as an argument or in the Prompt".to_string(),
433            ));
434        }
435
436        // if model is not None, set task prompt model (this will override the task prompt model)
437        if let Some(model) = model {
438            prompt.set_model(model);
439        }
440
441        if prompt.provider() == Model::Undefined.as_str() {
442            prompt.set_provider(self.agent.provider().as_str());
443        }
444
445        let chat_response = self
446            .runtime
447            .block_on(async { self.agent.execute_prompt(prompt).await })?;
448
449        debug!("Task executed successfully");
450        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
451        let response = PyAgentResponse::new(chat_response, output);
452
453        Ok(response)
454    }
455
456    #[getter]
457    pub fn system_instruction<'py>(
458        &self,
459        py: Python<'py>,
460    ) -> Result<Bound<'py, PyAny>, AgentError> {
461        Ok(self
462            .agent
463            .system_instruction
464            .clone()
465            .into_bound_py_any(py)?)
466    }
467
468    #[getter]
469    pub fn id(&self) -> &str {
470        self.agent.id.as_str()
471    }
472}