// potato_agent/agents/agent.rs

1use crate::{
2    agents::error::AgentError,
3    agents::task::Task,
4    agents::types::{AgentResponse, PyAgentResponse},
5};
6use potato_prompt::prompt::settings::ModelSettings;
7use potato_prompt::{
8    parse_response_to_json, prompt::parse_prompt, prompt::types::Message, Prompt, Role,
9};
10
11use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
12
13use potato_provider::providers::types::ServiceType;
14use potato_provider::GeminiClient;
15use potato_type::Provider;
16use potato_util::create_uuid7;
17use pyo3::{prelude::*, IntoPyObjectExt};
18use serde::{
19    de::{self, MapAccess, Visitor},
20    ser::SerializeStruct,
21    Deserializer, Serializer,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::sync::RwLock;
28use tracing::{debug, instrument, warn};
29
30#[derive(Debug, PartialEq, Clone)]
31pub struct Agent {
32    pub id: String,
33    client: Arc<GenAiClient>,
34    pub provider: Provider,
35    pub system_instruction: Vec<Message>,
36}
37
38/// Rust method implementation of the Agent
39impl Agent {
40    /// Helper method to rebuild the client, useful for deserialization
41    #[instrument(skip_all)]
42    pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
43        let client = match self.provider {
44            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
45            Provider::Gemini => {
46                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
47            }
48            Provider::Vertex => {
49                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
50            }
51            _ => {
52                return Err(AgentError::MissingProviderError);
53            } // Add other providers here as needed
54        };
55
56        Ok(Self {
57            id: self.id.clone(),
58            client: Arc::new(client),
59            system_instruction: self.system_instruction.clone(),
60            provider: self.provider.clone(),
61        })
62    }
63    pub async fn new(
64        provider: Provider,
65        system_instruction: Option<Vec<Message>>,
66    ) -> Result<Self, AgentError> {
67        let client = match provider {
68            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
69            Provider::Gemini => {
70                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
71            }
72            Provider::Vertex => {
73                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
74            }
75            _ => {
76                return Err(AgentError::MissingProviderError);
77            } // Add other providers here as needed
78        };
79
80        let system_instruction = system_instruction.unwrap_or_default();
81
82        Ok(Self {
83            client: Arc::new(client),
84            id: create_uuid7(),
85            system_instruction,
86            provider,
87        })
88    }
89
90    #[instrument(skip_all)]
91    fn append_task_with_message_context(
92        &self,
93        task: &mut Task,
94        context_messages: &HashMap<String, Vec<Message>>,
95    ) {
96        //
97        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
98        if !task.dependencies.is_empty() {
99            for dep in &task.dependencies {
100                if let Some(messages) = context_messages.get(dep) {
101                    for message in messages {
102                        // prepend the messages from dependencies
103                        task.prompt.message.insert(0, message.clone());
104                    }
105                }
106            }
107        }
108    }
109
110    /// This function will bind dependency-specific context and global context if provided to the user prompt.
111    ///
112    /// # Arguments:
113    /// * `prompt` - The prompt to bind parameters to.
114    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
115    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
116    ///
117    /// # Returns:
118    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
119    #[instrument(skip_all)]
120    fn bind_context(
121        &self,
122        prompt: &mut Prompt,
123        parameter_context: &Value,
124        global_context: &Option<Value>,
125    ) -> Result<(), AgentError> {
126        // print user messages
127        if !prompt.parameters.is_empty() {
128            for param in &prompt.parameters {
129                // Bind parameter context to the user message
130                if let Some(value) = parameter_context.get(param) {
131                    for message in &mut prompt.message {
132                        if message.role == "user" {
133                            debug!("Binding parameter: {} with value: {}", param, value);
134                            message.bind_mut(param, &value.to_string())?;
135                        }
136                    }
137                }
138
139                // If global context is provided, bind it to the user message
140                if let Some(global_value) = global_context {
141                    if let Some(value) = global_value.get(param) {
142                        for message in &mut prompt.message {
143                            if message.role == "user" {
144                                debug!("Binding global parameter: {} with value: {}", param, value);
145                                message.bind_mut(param, &value.to_string())?;
146                            }
147                        }
148                    }
149                }
150            }
151        }
152        Ok(())
153    }
154
155    fn append_system_instructions(&self, prompt: &mut Prompt) {
156        if !self.system_instruction.is_empty() {
157            let mut combined_messages = self.system_instruction.clone();
158            combined_messages.extend(prompt.system_instruction.clone());
159            prompt.system_instruction = combined_messages;
160        }
161    }
162    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
163        // Extract the prompt from the task
164        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
165        let mut prompt = task.prompt.clone();
166        self.append_system_instructions(&mut prompt);
167
168        // Use the client to execute the task
169        let chat_response = self.client.generate_content(&prompt).await?;
170
171        Ok(AgentResponse::new(task.id.clone(), chat_response))
172    }
173
174    #[instrument(skip_all)]
175    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
176        // Extract the prompt from the task
177        debug!("Executing prompt");
178        let mut prompt = prompt.clone();
179        self.append_system_instructions(&mut prompt);
180
181        // Use the client to execute the task
182        let chat_response = self.client.generate_content(&prompt).await?;
183
184        Ok(AgentResponse::new(chat_response.id(), chat_response))
185    }
186
187    pub async fn execute_task_with_context(
188        &self,
189        task: &Arc<RwLock<Task>>,
190        context_messages: HashMap<String, Vec<Message>>,
191        parameter_context: Value,
192        global_context: Option<Value>,
193    ) -> Result<AgentResponse, AgentError> {
194        // Prepare prompt and context before await
195        let (prompt, task_id) = {
196            let mut task = task.write().unwrap();
197            self.append_task_with_message_context(&mut task, &context_messages);
198            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
199
200            self.append_system_instructions(&mut task.prompt);
201            (task.prompt.clone(), task.id.clone())
202        };
203
204        // Now do the async work without holding the lock
205        let chat_response = self.client.generate_content(&prompt).await?;
206
207        Ok(AgentResponse::new(task_id, chat_response))
208    }
209
210    pub fn client_provider(&self) -> &Provider {
211        self.client.provider()
212    }
213
214    pub async fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
215        let provider = model_settings.provider();
216        let client = match provider {
217            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
218            Provider::Gemini => {
219                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
220            }
221            Provider::Vertex => {
222                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
223            }
224            Provider::Undefined => {
225                return Err(AgentError::MissingProviderError);
226            }
227        };
228
229        Ok(Self {
230            client: Arc::new(client),
231            id: create_uuid7(),
232            system_instruction: Vec::new(),
233            provider,
234        })
235    }
236}
237
238impl Serialize for Agent {
239    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
240    where
241        S: Serializer,
242    {
243        let mut state = serializer.serialize_struct("Agent", 3)?;
244        state.serialize_field("id", &self.id)?;
245        state.serialize_field("provider", &self.provider)?;
246        state.serialize_field("system_instruction", &self.system_instruction)?;
247        state.end()
248    }
249}
250
251/// Allows for deserialization of the Agent, re-initializing the client.
252impl<'de> Deserialize<'de> for Agent {
253    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
254    where
255        D: Deserializer<'de>,
256    {
257        #[derive(Deserialize)]
258        #[serde(field_identifier, rename_all = "snake_case")]
259        enum Field {
260            Id,
261            Provider,
262            SystemInstruction,
263        }
264
265        struct AgentVisitor;
266
267        impl<'de> Visitor<'de> for AgentVisitor {
268            type Value = Agent;
269
270            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
271                formatter.write_str("struct Agent")
272            }
273
274            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
275            where
276                V: MapAccess<'de>,
277            {
278                let mut id = None;
279                let mut provider = None;
280                let mut system_instruction = None;
281
282                while let Some(key) = map.next_key()? {
283                    match key {
284                        Field::Id => {
285                            id = Some(map.next_value()?);
286                        }
287                        Field::Provider => {
288                            provider = Some(map.next_value()?);
289                        }
290                        Field::SystemInstruction => {
291                            system_instruction = Some(map.next_value()?);
292                        }
293                    }
294                }
295
296                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
297                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
298                let system_instruction = system_instruction
299                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
300
301                // Deserialize is a sync op, so we can't await here (gemini requires async to init)
302                // After deserialization, we re-initialize the client based on the provider
303                let client = GenAiClient::Undefined;
304                Ok(Agent {
305                    id,
306                    client: Arc::new(client),
307                    system_instruction,
308                    provider,
309                })
310            }
311        }
312
313        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
314        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
315    }
316}
317
318#[pyclass(name = "Agent")]
319#[derive(Debug, Clone)]
320pub struct PyAgent {
321    pub agent: Arc<Agent>,
322    pub runtime: Arc<tokio::runtime::Runtime>,
323}
324
325#[pymethods]
326impl PyAgent {
327    #[new]
328    #[pyo3(signature = (provider, system_instruction = None))]
329    /// Creates a new Agent instance.
330    ///
331    /// # Arguments:
332    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
333    /// that can be mapped to a provider variant
334    ///
335    pub fn new(
336        provider: &Bound<'_, PyAny>,
337        system_instruction: Option<&Bound<'_, PyAny>>,
338    ) -> Result<Self, AgentError> {
339        let provider = Provider::extract_provider(provider)?;
340
341        let system_instruction = if let Some(system_instruction) = system_instruction {
342            Some(
343                parse_prompt(system_instruction)?
344                    .into_iter()
345                    .map(|mut msg| {
346                        msg.role = Role::Developer.to_string();
347                        msg
348                    })
349                    .collect::<Vec<Message>>(),
350            )
351        } else {
352            None
353        };
354
355        let runtime = tokio::runtime::Runtime::new()?;
356
357        let agent = runtime.block_on(async { Agent::new(provider, system_instruction).await })?;
358
359        Ok(Self {
360            agent: Arc::new(agent),
361            runtime: Arc::new(runtime),
362        })
363    }
364
365    #[pyo3(signature = (task, output_type=None, model=None))]
366    pub fn execute_task(
367        &self,
368        py: Python<'_>,
369        task: &mut Task,
370        output_type: Option<Bound<'_, PyAny>>,
371        model: Option<&str>,
372    ) -> Result<PyAgentResponse, AgentError> {
373        // Extract the prompt from the task
374        debug!("Executing task");
375
376        // if output_type is not None,  mutate task prompt
377        if let Some(output_type) = &output_type {
378            match parse_response_to_json(py, output_type) {
379                Ok((response_type, response_format)) => {
380                    task.prompt.response_type = response_type;
381                    task.prompt.response_json_schema = response_format;
382                }
383                Err(_) => {
384                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
385                }
386            }
387        }
388
389        // if model is not None, set task prompt model (this will override the task prompt model)
390        if let Some(model) = model {
391            task.prompt.model = model.to_string();
392        }
393
394        // agent provider and task.prompt provider must match
395        if task.prompt.provider != *self.agent.client_provider() {
396            return Err(AgentError::ProviderMismatch(
397                task.prompt.provider.to_string(),
398                self.agent.client_provider().as_str().to_string(),
399            ));
400        }
401
402        debug!(
403            "Task prompt model identifier: {}",
404            task.prompt.model_identifier()
405        );
406
407        let chat_response = self
408            .runtime
409            .block_on(async { self.agent.execute_task(task).await })?;
410
411        debug!("Task executed successfully");
412        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
413        let response = PyAgentResponse::new(chat_response, output);
414
415        Ok(response)
416    }
417
418    #[pyo3(signature = (prompt, output_type=None, model=None))]
419    pub fn execute_prompt(
420        &self,
421        py: Python<'_>,
422        prompt: &mut Prompt,
423        output_type: Option<Bound<'_, PyAny>>,
424        model: Option<&str>,
425    ) -> Result<PyAgentResponse, AgentError> {
426        // Extract the prompt from the task
427        debug!("Executing task");
428        // if output_type is not None,  mutate task prompt
429        if let Some(output_type) = &output_type {
430            match parse_response_to_json(py, output_type) {
431                Ok((response_type, response_format)) => {
432                    prompt.response_type = response_type;
433                    prompt.response_json_schema = response_format;
434                }
435                Err(_) => {
436                    return Err(AgentError::InvalidOutputType(output_type.to_string()));
437                }
438            }
439        }
440
441        // if model is not None, set task prompt model (this will override the task prompt model)
442        if let Some(model) = model {
443            prompt.model = model.to_string();
444        }
445
446        // agent provider and task.prompt provider must match
447        if prompt.provider != *self.agent.client_provider() {
448            return Err(AgentError::ProviderMismatch(
449                prompt.provider.to_string(),
450                self.agent.client_provider().as_str().to_string(),
451            ));
452        }
453
454        let chat_response = self
455            .runtime
456            .block_on(async { self.agent.execute_prompt(prompt).await })?;
457
458        debug!("Task executed successfully");
459        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
460        let response = PyAgentResponse::new(chat_response, output);
461
462        Ok(response)
463    }
464
465    #[getter]
466    pub fn system_instruction<'py>(
467        &self,
468        py: Python<'py>,
469    ) -> Result<Bound<'py, PyAny>, AgentError> {
470        Ok(self
471            .agent
472            .system_instruction
473            .clone()
474            .into_bound_py_any(py)?)
475    }
476
477    #[getter]
478    pub fn id(&self) -> &str {
479        self.agent.id.as_str()
480    }
481}