potato_agent/agents/
agent.rs1use crate::agents::provider::openai::OpenAIClient;
2use crate::agents::provider::types::Provider;
3use potato_prompt::prompt::types::Message;
4
5use crate::{
6 agents::client::GenAiClient, agents::error::AgentError, agents::task::Task,
7 agents::types::AgentResponse,
8};
9use potato_prompt::Prompt;
10use potato_util::create_uuid7;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16
17pub struct Agent {
18 pub id: String,
19
20 client: GenAiClient,
21
22 pub system_message: Vec<Message>,
23}
24
25impl Agent {
27 pub fn new(
28 provider: Provider,
29 system_message: Option<Vec<Message>>,
30 ) -> Result<Self, AgentError> {
31 let client = match provider {
32 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
33 };
35
36 let system_message = system_message.unwrap_or_default();
37
38 Ok(Self {
39 client,
40 id: create_uuid7(),
41 system_message,
42 })
43 }
44
45 fn get_task_with_context(
46 &self,
47 task: &Task,
48 context_messages: &HashMap<String, Vec<Message>>,
49 ) -> Task {
50 let mut cloned_task = task.clone();
51
52 if !cloned_task.dependencies.is_empty() {
53 for dep in &cloned_task.dependencies {
54 if let Some(messages) = context_messages.get(dep) {
55 for message in messages {
56 cloned_task.prompt.user_message.insert(0, message.clone());
58 }
59 }
60 }
61 }
62
63 cloned_task
64 }
65
66 fn append_system_messages(&self, prompt: &mut Prompt) {
67 if !self.system_message.is_empty() {
68 let mut combined_messages = self.system_message.clone();
69 combined_messages.extend(prompt.system_message.clone());
70 prompt.system_message = combined_messages;
71 }
72 }
73 pub async fn execute_async_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
74 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
76 let mut prompt = task.prompt.clone();
77 self.append_system_messages(&mut prompt);
78
79 let chat_response = self.client.execute(&prompt).await?;
81
82 Ok(AgentResponse::new(task.id.clone(), chat_response))
83 }
84
85 pub async fn execute_async_task_with_context(
86 &self,
87 task: &Task,
88 context_messages: HashMap<String, Vec<Message>>,
89 ) -> Result<AgentResponse, AgentError> {
90 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
92 let mut prompt = self.get_task_with_context(task, &context_messages).prompt;
93 self.append_system_messages(&mut prompt);
94
95 let chat_response = self.client.execute(&prompt).await?;
97
98 Ok(AgentResponse::new(task.id.clone(), chat_response))
99 }
100
101 pub fn provider(&self) -> &Provider {
102 self.client.provider()
103 }
104}