potato_agent/agents/
agent.rs

1use crate::agents::provider::gemini::GeminiClient;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::{
4    agents::client::GenAiClient,
5    agents::error::AgentError,
6    agents::task::Task,
7    agents::types::{AgentResponse, PyAgentResponse},
8};
9use potato_prompt::prompt::settings::ModelSettings;
10use potato_prompt::{
11    parse_response_to_json, prompt::parse_prompt, prompt::types::Message, Prompt, Role,
12};
13use potato_type::Provider;
14use potato_util::create_uuid7;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use serde::{
17    de::{self, MapAccess, Visitor},
18    ser::SerializeStruct,
19    Deserializer, Serializer,
20};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::sync::RwLock;
26use tracing::{debug, error, instrument, warn};
27
28#[derive(Debug, Clone, PartialEq)]
29pub struct Agent {
30    pub id: String,
31
32    client: GenAiClient,
33
34    pub system_instruction: Vec<Message>,
35}
36
37/// Rust method implementation of the Agent
38impl Agent {
39    pub fn new(
40        provider: Provider,
41        system_instruction: Option<Vec<Message>>,
42    ) -> Result<Self, AgentError> {
43        let client = match provider {
44            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
45            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
46            _ => {
47                let msg = "No provider specified in ModelSettings";
48                error!("{}", msg);
49                return Err(AgentError::UndefinedError(msg.to_string()));
50            } // Add other providers here as needed
51        };
52
53        let system_instruction = system_instruction.unwrap_or_default();
54
55        Ok(Self {
56            client,
57            id: create_uuid7(),
58            system_instruction,
59        })
60    }
61
62    #[instrument(skip_all)]
63    fn append_task_with_message_context(
64        &self,
65        task: &mut Task,
66        context_messages: &HashMap<String, Vec<Message>>,
67    ) {
68        //
69        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
70        if !task.dependencies.is_empty() {
71            for dep in &task.dependencies {
72                if let Some(messages) = context_messages.get(dep) {
73                    for message in messages {
74                        // prepend the messages from dependencies
75                        task.prompt.message.insert(0, message.clone());
76                    }
77                }
78            }
79        }
80    }
81
82    /// This function will bind dependency-specific context and global context if provided to the user prompt.
83    ///
84    /// # Arguments:
85    /// * `prompt` - The prompt to bind parameters to.
86    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
87    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
88    ///
89    /// # Returns:
90    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
91    #[instrument(skip_all)]
92    fn bind_context(
93        &self,
94        prompt: &mut Prompt,
95        parameter_context: &Value,
96        global_context: &Option<Value>,
97    ) -> Result<(), AgentError> {
98        // print user messages
99        if !prompt.parameters.is_empty() {
100            for param in &prompt.parameters {
101                // Bind parameter context to the user message
102                if let Some(value) = parameter_context.get(param) {
103                    for message in &mut prompt.message {
104                        if message.role == "user" {
105                            debug!("Binding parameter: {} with value: {}", param, value);
106                            message.bind_mut(param, &value.to_string())?;
107                        }
108                    }
109                }
110
111                // If global context is provided, bind it to the user message
112                if let Some(global_value) = global_context {
113                    if let Some(value) = global_value.get(param) {
114                        for message in &mut prompt.message {
115                            if message.role == "user" {
116                                debug!("Binding global parameter: {} with value: {}", param, value);
117                                message.bind_mut(param, &value.to_string())?;
118                            }
119                        }
120                    }
121                }
122            }
123        }
124        Ok(())
125    }
126
127    fn append_system_instructions(&self, prompt: &mut Prompt) {
128        if !self.system_instruction.is_empty() {
129            let mut combined_messages = self.system_instruction.clone();
130            combined_messages.extend(prompt.system_instruction.clone());
131            prompt.system_instruction = combined_messages;
132        }
133    }
134    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
135        // Extract the prompt from the task
136        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
137        let mut prompt = task.prompt.clone();
138        self.append_system_instructions(&mut prompt);
139
140        // Use the client to execute the task
141        let chat_response = self.client.execute(&prompt).await?;
142
143        Ok(AgentResponse::new(task.id.clone(), chat_response))
144    }
145
146    #[instrument(skip_all)]
147    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
148        // Extract the prompt from the task
149        debug!("Executing prompt");
150        let mut prompt = prompt.clone();
151        self.append_system_instructions(&mut prompt);
152
153        // Use the client to execute the task
154        let chat_response = self.client.execute(&prompt).await?;
155
156        Ok(AgentResponse::new(chat_response.id(), chat_response))
157    }
158
159    pub async fn execute_task_with_context(
160        &self,
161        task: &Arc<RwLock<Task>>,
162        context_messages: HashMap<String, Vec<Message>>,
163        parameter_context: Value,
164        global_context: Option<Value>,
165    ) -> Result<AgentResponse, AgentError> {
166        // Prepare prompt and context before await
167        let (prompt, task_id) = {
168            let mut task = task.write().unwrap();
169            self.append_task_with_message_context(&mut task, &context_messages);
170            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
171
172            self.append_system_instructions(&mut task.prompt);
173            (task.prompt.clone(), task.id.clone())
174        };
175
176        // Now do the async work without holding the lock
177        let chat_response = self.client.execute(&prompt).await?;
178
179        Ok(AgentResponse::new(task_id, chat_response))
180    }
181
182    pub fn provider(&self) -> &Provider {
183        self.client.provider()
184    }
185
186    pub fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
187        let provider = model_settings.provider();
188        let client = match provider {
189            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
190            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
191            Provider::Undefined => {
192                let msg = "No provider specified in ModelSettings";
193                error!("{}", msg);
194                return Err(AgentError::UndefinedError(msg.to_string()));
195            }
196        };
197
198        Ok(Self {
199            client,
200            id: create_uuid7(),
201            system_instruction: Vec::new(),
202        })
203    }
204}
205
206impl Serialize for Agent {
207    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208    where
209        S: Serializer,
210    {
211        let mut state = serializer.serialize_struct("Agent", 3)?;
212        state.serialize_field("id", &self.id)?;
213        state.serialize_field("provider", &self.client.provider())?;
214        state.serialize_field("system_instruction", &self.system_instruction)?;
215        state.end()
216    }
217}
218
219/// Allows for deserialization of the Agent, re-initializing the client.
220impl<'de> Deserialize<'de> for Agent {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        #[derive(Deserialize)]
226        #[serde(field_identifier, rename_all = "snake_case")]
227        enum Field {
228            Id,
229            Provider,
230            SystemInstruction,
231        }
232
233        struct AgentVisitor;
234
235        impl<'de> Visitor<'de> for AgentVisitor {
236            type Value = Agent;
237
238            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
239                formatter.write_str("struct Agent")
240            }
241
242            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
243            where
244                V: MapAccess<'de>,
245            {
246                let mut id = None;
247                let mut provider = None;
248                let mut system_instruction = None;
249
250                while let Some(key) = map.next_key()? {
251                    match key {
252                        Field::Id => {
253                            id = Some(map.next_value()?);
254                        }
255                        Field::Provider => {
256                            provider = Some(map.next_value()?);
257                        }
258                        Field::SystemInstruction => {
259                            system_instruction = Some(map.next_value()?);
260                        }
261                    }
262                }
263
264                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
265                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
266                let system_instruction = system_instruction
267                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
268
269                // Re-initialize the client based on the provider
270                let client = match provider {
271                    Provider::OpenAI => {
272                        GenAiClient::OpenAI(OpenAIClient::new(None, None, None).map_err(|e| {
273                            de::Error::custom(format!("Failed to initialize OpenAIClient: {e}"))
274                        })?)
275                    }
276                    Provider::Gemini => {
277                        GenAiClient::Gemini(GeminiClient::new(None, None, None).map_err(|e| {
278                            de::Error::custom(format!("Failed to initialize GeminiClient: {e}"))
279                        })?)
280                    }
281
282                    Provider::Undefined => {
283                        let msg = "No provider specified in ModelSettings";
284                        error!("{}", msg);
285                        return Err(de::Error::custom(msg));
286                    }
287                };
288
289                Ok(Agent {
290                    id,
291                    client,
292                    system_instruction,
293                })
294            }
295        }
296
297        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
298        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
299    }
300}
301
302#[pyclass(name = "Agent")]
303#[derive(Debug, Clone)]
304pub struct PyAgent {
305    pub agent: Arc<Agent>,
306    pub runtime: Arc<tokio::runtime::Runtime>,
307}
308
309#[pymethods]
310impl PyAgent {
311    #[new]
312    #[pyo3(signature = (provider, system_instruction = None))]
313    /// Creates a new Agent instance.
314    ///
315    /// # Arguments:
316    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
317    /// that can be mapped to a provider variant
318    ///
319    pub fn new(
320        provider: &Bound<'_, PyAny>,
321        system_instruction: Option<&Bound<'_, PyAny>>,
322    ) -> Result<Self, AgentError> {
323        let provider = Provider::extract_provider(provider)?;
324
325        let system_instruction = if let Some(system_instruction) = system_instruction {
326            Some(
327                parse_prompt(system_instruction)?
328                    .into_iter()
329                    .map(|mut msg| {
330                        msg.role = Role::Developer.to_string();
331                        msg
332                    })
333                    .collect::<Vec<Message>>(),
334            )
335        } else {
336            None
337        };
338
339        let agent = Agent::new(provider, system_instruction)?;
340
341        Ok(Self {
342            agent: Arc::new(agent),
343            runtime: Arc::new(
344                tokio::runtime::Runtime::new()
345                    .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
346            ),
347        })
348    }
349
350    #[pyo3(signature = (task, output_type=None, model=None))]
351    pub fn execute_task(
352        &self,
353        py: Python<'_>,
354        task: &mut Task,
355        output_type: Option<Bound<'_, PyAny>>,
356        model: Option<&str>,
357    ) -> Result<PyAgentResponse, AgentError> {
358        // Extract the prompt from the task
359        debug!("Executing task");
360
361        // if output_type is not None,  mutate task prompt
362        if let Some(output_type) = &output_type {
363            match parse_response_to_json(py, output_type) {
364                Ok((response_type, response_format)) => {
365                    task.prompt.response_type = response_type;
366                    task.prompt.response_json_schema = response_format;
367                }
368                Err(_) => {
369                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
370                }
371            }
372        }
373
374        // if model is not None, set task prompt model (this will override the task prompt model)
375        if let Some(model) = model {
376            task.prompt.model = model.to_string();
377        }
378
379        // agent provider and task.prompt provider must match
380        if task.prompt.provider != *self.agent.provider() {
381            return Err(AgentError::ProviderMismatch(
382                task.prompt.provider.to_string(),
383                self.agent.provider().as_str().to_string(),
384            ));
385        }
386
387        debug!(
388            "Task prompt model identifier: {}",
389            task.prompt.model_identifier()
390        );
391
392        let chat_response = self
393            .runtime
394            .block_on(async { self.agent.execute_task(task).await })?;
395
396        debug!("Task executed successfully");
397        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
398        let response = PyAgentResponse::new(chat_response, output);
399
400        Ok(response)
401    }
402
403    #[pyo3(signature = (prompt, output_type=None, model=None))]
404    pub fn execute_prompt(
405        &self,
406        py: Python<'_>,
407        prompt: &mut Prompt,
408        output_type: Option<Bound<'_, PyAny>>,
409        model: Option<&str>,
410    ) -> Result<PyAgentResponse, AgentError> {
411        // Extract the prompt from the task
412        debug!("Executing task");
413        // if output_type is not None,  mutate task prompt
414        if let Some(output_type) = &output_type {
415            match parse_response_to_json(py, output_type) {
416                Ok((response_type, response_format)) => {
417                    prompt.response_type = response_type;
418                    prompt.response_json_schema = response_format;
419                }
420                Err(_) => {
421                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
422                }
423            }
424        }
425
426        // if model is not None, set task prompt model (this will override the task prompt model)
427        if let Some(model) = model {
428            prompt.model = model.to_string();
429        }
430
431        // agent provider and task.prompt provider must match
432        if prompt.provider != *self.agent.provider() {
433            return Err(AgentError::ProviderMismatch(
434                prompt.provider.to_string(),
435                self.agent.provider().as_str().to_string(),
436            ));
437        }
438
439        let chat_response = self
440            .runtime
441            .block_on(async { self.agent.execute_prompt(prompt).await })?;
442
443        debug!("Task executed successfully");
444        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
445        let response = PyAgentResponse::new(chat_response, output);
446
447        Ok(response)
448    }
449
450    #[getter]
451    pub fn system_instruction<'py>(
452        &self,
453        py: Python<'py>,
454    ) -> Result<Bound<'py, PyAny>, AgentError> {
455        Ok(self
456            .agent
457            .system_instruction
458            .clone()
459            .into_bound_py_any(py)?)
460    }
461
462    #[getter]
463    pub fn id(&self) -> &str {
464        self.agent.id.as_str()
465    }
466}