// potato_agent/agents/agent.rs

1use crate::agents::provider::openai::OpenAIClient;
2use crate::agents::provider::types::Provider;
3use potato_prompt::prompt::types::Message;
4
5use crate::{
6    agents::client::GenAiClient, agents::error::AgentError, agents::task::Task,
7    agents::types::AgentResponse,
8};
9use potato_prompt::Prompt;
10use potato_util::create_uuid7;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16
17pub struct Agent {
18    pub id: String,
19
20    client: GenAiClient,
21
22    pub system_message: Vec<Message>,
23}
24
25/// Rust method implementation of the Agent
26impl Agent {
27    pub fn new(
28        provider: Provider,
29        system_message: Option<Vec<Message>>,
30    ) -> Result<Self, AgentError> {
31        let client = match provider {
32            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
33            // Add other providers here as needed
34        };
35
36        let system_message = system_message.unwrap_or_default();
37
38        Ok(Self {
39            client,
40            id: create_uuid7(),
41            system_message,
42        })
43    }
44
45    fn get_task_with_context(
46        &self,
47        task: &Task,
48        context_messages: &HashMap<String, Vec<Message>>,
49    ) -> Task {
50        let mut cloned_task = task.clone();
51
52        if !cloned_task.dependencies.is_empty() {
53            for dep in &cloned_task.dependencies {
54                if let Some(messages) = context_messages.get(dep) {
55                    for message in messages {
56                        // prepend the messages from dependencies
57                        cloned_task.prompt.user_message.insert(0, message.clone());
58                    }
59                }
60            }
61        }
62
63        cloned_task
64    }
65
66    fn append_system_messages(&self, prompt: &mut Prompt) {
67        if !self.system_message.is_empty() {
68            let mut combined_messages = self.system_message.clone();
69            combined_messages.extend(prompt.system_message.clone());
70            prompt.system_message = combined_messages;
71        }
72    }
73    pub async fn execute_async_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
74        // Extract the prompt from the task
75        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
76        let mut prompt = task.prompt.clone();
77        self.append_system_messages(&mut prompt);
78
79        // Use the client to execute the task
80        let chat_response = self.client.execute(&prompt).await?;
81
82        Ok(AgentResponse::new(task.id.clone(), chat_response))
83    }
84
85    pub async fn execute_async_task_with_context(
86        &self,
87        task: &Task,
88        context_messages: HashMap<String, Vec<Message>>,
89    ) -> Result<AgentResponse, AgentError> {
90        // Extract the prompt from the task
91        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
92        let mut prompt = self.get_task_with_context(task, &context_messages).prompt;
93        self.append_system_messages(&mut prompt);
94
95        // Use the client to execute the task
96        let chat_response = self.client.execute(&prompt).await?;
97
98        Ok(AgentResponse::new(task.id.clone(), chat_response))
99    }
100
101    pub fn provider(&self) -> &Provider {
102        self.client.provider()
103    }
104}