// potato_agent/agents/agent.rs

use std::collections::HashMap;

use potato_prompt::prompt::types::Message;
use potato_prompt::Prompt;
use potato_util::create_uuid7;
use serde::{Deserialize, Serialize};
use tracing::debug;

use crate::agents::{
    client::GenAiClient,
    error::AgentError,
    provider::{openai::OpenAIClient, types::Provider},
    task::Task,
    types::AgentResponse,
};

15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Agent {
17    pub id: String,
18
19    client: GenAiClient,
20
21    pub system_message: Vec<Message>,
22}
23
24/// Rust method implementation of the Agent
25impl Agent {
26    pub fn new(
27        provider: Provider,
28        system_message: Option<Vec<Message>>,
29    ) -> Result<Self, AgentError> {
30        let client = match provider {
31            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
32            // Add other providers here as needed
33        };
34
35        let system_message = system_message.unwrap_or_default();
36
37        Ok(Self {
38            client,
39            id: create_uuid7(),
40            system_message,
41        })
42    }
43
44    fn get_task_with_context(
45        &self,
46        task: &Task,
47        context_messages: &HashMap<String, Vec<Message>>,
48    ) -> Task {
49        let mut cloned_task = task.clone();
50
51        if !cloned_task.dependencies.is_empty() {
52            for dep in &cloned_task.dependencies {
53                if let Some(messages) = context_messages.get(dep) {
54                    for message in messages {
55                        // prepend the messages from dependencies
56                        cloned_task.prompt.user_message.insert(0, message.clone());
57                    }
58                }
59            }
60        }
61
62        cloned_task
63    }
64
65    fn append_system_messages(&self, prompt: &mut Prompt) {
66        if !self.system_message.is_empty() {
67            let mut combined_messages = self.system_message.clone();
68            combined_messages.extend(prompt.system_message.clone());
69            prompt.system_message = combined_messages;
70        }
71    }
72    pub async fn execute_async_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
73        // Extract the prompt from the task
74        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
75        let mut prompt = task.prompt.clone();
76        self.append_system_messages(&mut prompt);
77
78        // Use the client to execute the task
79        let chat_response = self.client.execute(&prompt).await?;
80
81        Ok(AgentResponse::new(task.id.clone(), chat_response))
82    }
83
84    pub async fn execute_async_task_with_context(
85        &self,
86        task: &Task,
87        context_messages: HashMap<String, Vec<Message>>,
88    ) -> Result<AgentResponse, AgentError> {
89        // Extract the prompt from the task
90        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
91        let mut prompt = self.get_task_with_context(task, &context_messages).prompt;
92        self.append_system_messages(&mut prompt);
93
94        // Use the client to execute the task
95        let chat_response = self.client.execute(&prompt).await?;
96
97        Ok(AgentResponse::new(task.id.clone(), chat_response))
98    }
99
100    pub fn provider(&self) -> &Provider {
101        self.client.provider()
102    }
103}