potato_agent/agents/
agent.rs1use crate::agents::provider::openai::OpenAIClient;
2use crate::agents::provider::types::Provider;
3use potato_prompt::prompt::types::Message;
4
5use crate::{
6 agents::client::GenAiClient, agents::error::AgentError, agents::task::Task,
7 agents::types::AgentResponse,
8};
9use potato_prompt::Prompt;
10use potato_util::create_uuid7;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tracing::debug;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Agent {
17 pub id: String,
18
19 client: GenAiClient,
20
21 pub system_message: Vec<Message>,
22}
23
24impl Agent {
26 pub fn new(
27 provider: Provider,
28 system_message: Option<Vec<Message>>,
29 ) -> Result<Self, AgentError> {
30 let client = match provider {
31 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
32 };
34
35 let system_message = system_message.unwrap_or_default();
36
37 Ok(Self {
38 client,
39 id: create_uuid7(),
40 system_message,
41 })
42 }
43
44 fn get_task_with_context(
45 &self,
46 task: &Task,
47 context_messages: &HashMap<String, Vec<Message>>,
48 ) -> Task {
49 let mut cloned_task = task.clone();
50
51 if !cloned_task.dependencies.is_empty() {
52 for dep in &cloned_task.dependencies {
53 if let Some(messages) = context_messages.get(dep) {
54 for message in messages {
55 cloned_task.prompt.user_message.insert(0, message.clone());
57 }
58 }
59 }
60 }
61
62 cloned_task
63 }
64
65 fn append_system_messages(&self, prompt: &mut Prompt) {
66 if !self.system_message.is_empty() {
67 let mut combined_messages = self.system_message.clone();
68 combined_messages.extend(prompt.system_message.clone());
69 prompt.system_message = combined_messages;
70 }
71 }
72 pub async fn execute_async_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
73 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
75 let mut prompt = task.prompt.clone();
76 self.append_system_messages(&mut prompt);
77
78 let chat_response = self.client.execute(&prompt).await?;
80
81 Ok(AgentResponse::new(task.id.clone(), chat_response))
82 }
83
84 pub async fn execute_async_task_with_context(
85 &self,
86 task: &Task,
87 context_messages: HashMap<String, Vec<Message>>,
88 ) -> Result<AgentResponse, AgentError> {
89 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
91 let mut prompt = self.get_task_with_context(task, &context_messages).prompt;
92 self.append_system_messages(&mut prompt);
93
94 let chat_response = self.client.execute(&prompt).await?;
96
97 Ok(AgentResponse::new(task.id.clone(), chat_response))
98 }
99
100 pub fn provider(&self) -> &Provider {
101 self.client.provider()
102 }
103}