1use crate::agents::{
2 error::AgentError,
3 task::Task,
4 types::{AgentResponse, PyAgentResponse},
5};
6use potato_provider::providers::anthropic::client::AnthropicClient;
7use potato_provider::providers::types::ServiceType;
8use potato_provider::GeminiClient;
9use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
10use potato_state::block_on;
11use potato_type::prompt::Prompt;
12use potato_type::prompt::{MessageNum, Role};
13use potato_type::Provider;
14use potato_type::{
15 prompt::extract_system_instructions,
16 tools::{Tool, ToolRegistry},
17};
18use potato_util::create_uuid7;
19use pyo3::prelude::*;
20use pyo3::types::PyList;
21use serde::{
22 de::{self, MapAccess, Visitor},
23 ser::SerializeStruct,
24 Deserializer, Serializer,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::sync::RwLock;
31use tracing::{debug, instrument, warn};
32
33#[derive(Debug, Clone)]
34pub struct Agent {
35 pub id: String,
36 client: Arc<GenAiClient>,
37 pub provider: Provider,
38 pub system_instruction: Vec<MessageNum>,
39 pub tools: Arc<RwLock<ToolRegistry>>, pub max_iterations: u32,
41}
42
43impl Agent {
45 #[instrument(skip_all)]
47 pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
48 let client = match self.provider {
49 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
50 Provider::Gemini => {
51 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
52 }
53 Provider::Vertex => {
54 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
55 }
56 _ => {
57 return Err(AgentError::MissingProviderError);
58 } };
60
61 Ok(Self {
62 id: self.id.clone(),
63 client: Arc::new(client),
64 system_instruction: self.system_instruction.clone(),
65 provider: self.provider.clone(),
66 tools: self.tools.clone(),
67 max_iterations: self.max_iterations,
68 })
69 }
70 pub async fn new(
71 provider: Provider,
72 system_instruction: Option<Vec<MessageNum>>,
73 ) -> Result<Self, AgentError> {
74 let client = match provider {
75 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
76 Provider::Gemini => {
77 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
78 }
79 Provider::Vertex => {
80 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
81 }
82 Provider::Anthropic => {
83 GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
84 }
85 _ => {
86 return Err(AgentError::MissingProviderError);
87 } };
89
90 Ok(Self {
91 client: Arc::new(client),
92 id: create_uuid7(),
93 system_instruction: system_instruction.unwrap_or_default(),
94 provider,
95 tools: Arc::new(RwLock::new(ToolRegistry::new())),
96 max_iterations: 10,
97 })
98 }
99
100 pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
101 self.tools.write().unwrap().register_tool(tool);
102 }
103
104 #[instrument(skip_all)]
153 fn append_task_with_message_dependency_context(
154 &self,
155 task: &mut Task,
156 context_messages: &HashMap<String, Vec<MessageNum>>,
157 ) {
158 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
160
161 if task.dependencies.is_empty() {
162 return;
163 }
164
165 let messages = task.prompt.request.messages_mut();
166 let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
167
168 match first_user_idx {
169 Some(insert_idx) => {
170 let mut dependency_messages = Vec::new();
172
173 for dep_id in &task.dependencies {
174 if let Some(messages) = context_messages.get(dep_id) {
175 debug!(
176 "Adding {} messages from dependency {}",
177 messages.len(),
178 dep_id
179 );
180 dependency_messages.extend(messages.iter().cloned());
181 }
182 }
183
184 for message in dependency_messages.into_iter() {
186 task.prompt
187 .request
188 .insert_message(message, Some(insert_idx))
189 }
190
191 debug!(
192 "Inserted {} dependency messages before user message at index {}",
193 task.dependencies.len(),
194 insert_idx
195 );
196 }
197 None => {
198 warn!(
199 "No user message found in task {}, appending dependency context to end",
200 task.id
201 );
202
203 for dep_id in &task.dependencies {
204 if let Some(messages) = context_messages.get(dep_id) {
205 for message in messages {
206 task.prompt.request.push_message(message.clone());
207 }
208 }
209 }
210 }
211 }
212 }
213
214 #[instrument(skip_all)]
224 fn bind_context(
225 &self,
226 prompt: &mut Prompt,
227 parameter_context: &Value,
228 global_context: &Option<Arc<Value>>,
229 ) -> Result<(), AgentError> {
230 if !prompt.parameters.is_empty() {
232 for param in &prompt.parameters {
233 if let Some(value) = parameter_context.get(param) {
235 for message in prompt.request.messages_mut() {
236 if message.role() == Role::User.as_str() {
237 debug!("Binding parameter: {} with value: {}", param, value);
238 message.bind_mut(param, &value.to_string())?;
239 }
240 }
241 }
242
243 if let Some(global_value) = global_context {
245 if let Some(value) = global_value.get(param) {
246 for message in prompt.request.messages_mut() {
247 if message.role() == Role::User.as_str() {
248 debug!("Binding global parameter: {} with value: {}", param, value);
249 message.bind_mut(param, &value.to_string())?;
250 }
251 }
252 }
253 }
254 }
255 }
256 Ok(())
257 }
258
259 fn prepend_system_instructions(&self, prompt: &mut Prompt) {
263 if !self.system_instruction.is_empty() {
264 prompt
265 .request
266 .prepend_system_instructions(self.system_instruction.clone())
267 .unwrap();
268 }
269 }
270 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
271 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
273 let mut prompt = task.prompt.clone();
274 self.prepend_system_instructions(&mut prompt);
275
276 let chat_response = self.client.generate_content(&prompt).await?;
278
279 Ok(AgentResponse::new(task.id.clone(), chat_response))
280 }
281
282 #[instrument(skip_all)]
283 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
284 debug!("Executing prompt");
286 let mut prompt = prompt.clone();
287 self.prepend_system_instructions(&mut prompt);
288
289 let chat_response = self.client.generate_content(&prompt).await?;
291
292 Ok(AgentResponse::new(chat_response.id(), chat_response))
293 }
294
295 #[instrument(skip_all)]
298 pub async fn execute_task_with_context(
299 &self,
300 task: &Arc<RwLock<Task>>,
301 context: &Value,
302 ) -> Result<AgentResponse, AgentError> {
303 let (mut prompt, task_id) = {
305 let task = task.read().unwrap();
306 (task.prompt.clone(), task.id.clone())
307 };
308
309 self.bind_context(&mut prompt, context, &None)?;
310 self.prepend_system_instructions(&mut prompt);
311
312 let chat_response = self.client.generate_content(&prompt).await?;
313 Ok(AgentResponse::new(task_id, chat_response))
314 }
315
316 pub async fn execute_task_with_context_message(
317 &self,
318 task: &Arc<RwLock<Task>>,
319 context_messages: HashMap<String, Vec<MessageNum>>,
320 parameter_context: Value,
321 global_context: Option<Arc<Value>>,
322 ) -> Result<AgentResponse, AgentError> {
323 let (prompt, task_id) = {
325 let mut task = task.write().unwrap();
326 self.append_task_with_message_dependency_context(&mut task, &context_messages);
328 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
330 self.prepend_system_instructions(&mut task.prompt);
332 (task.prompt.clone(), task.id.clone())
333 };
334
335 let chat_response = self.client.generate_content(&prompt).await?;
337 Ok(AgentResponse::new(task_id, chat_response))
338 }
339
340 pub fn client_provider(&self) -> &Provider {
341 self.client.provider()
342 }
343}
344
345impl PartialEq for Agent {
346 fn eq(&self, other: &Self) -> bool {
347 self.id == other.id
348 && self.provider == other.provider
349 && self.system_instruction == other.system_instruction
350 && self.max_iterations == other.max_iterations
351 && self.client == other.client
352 }
353}
354
355impl Serialize for Agent {
356 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
357 where
358 S: Serializer,
359 {
360 let mut state = serializer.serialize_struct("Agent", 3)?;
361 state.serialize_field("id", &self.id)?;
362 state.serialize_field("provider", &self.provider)?;
363 state.serialize_field("system_instruction", &self.system_instruction)?;
364 state.end()
365 }
366}
367
368impl<'de> Deserialize<'de> for Agent {
370 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
371 where
372 D: Deserializer<'de>,
373 {
374 #[derive(Deserialize)]
375 #[serde(field_identifier, rename_all = "snake_case")]
376 enum Field {
377 Id,
378 Provider,
379 SystemInstruction,
380 }
381
382 struct AgentVisitor;
383
384 impl<'de> Visitor<'de> for AgentVisitor {
385 type Value = Agent;
386
387 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
388 formatter.write_str("struct Agent")
389 }
390
391 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
392 where
393 V: MapAccess<'de>,
394 {
395 let mut id = None;
396 let mut provider = None;
397 let mut system_instruction = None;
398
399 while let Some(key) = map.next_key()? {
400 match key {
401 Field::Id => {
402 id = Some(map.next_value()?);
403 }
404 Field::Provider => {
405 provider = Some(map.next_value()?);
406 }
407 Field::SystemInstruction => {
408 system_instruction = Some(map.next_value()?);
409 }
410 }
411 }
412
413 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
414 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
415 let system_instruction = system_instruction
416 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
417
418 let client = GenAiClient::Undefined;
421 Ok(Agent {
422 id,
423 client: Arc::new(client),
424 system_instruction,
425 provider,
426 tools: Arc::new(RwLock::new(ToolRegistry::new())),
427 max_iterations: 10,
428 })
429 }
430 }
431
432 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
433 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
434 }
435}
436
437#[pyclass(name = "Agent")]
438#[derive(Debug, Clone)]
439pub struct PyAgent {
440 pub agent: Arc<Agent>,
441}
442
443#[pymethods]
444impl PyAgent {
445 #[new]
446 #[pyo3(signature = (provider, system_instruction = None))]
447 pub fn new(
454 provider: &Bound<'_, PyAny>,
455 system_instruction: Option<&Bound<'_, PyAny>>,
456 ) -> Result<Self, AgentError> {
457 let provider = Provider::extract_provider(provider)?;
458 let system_instructions = extract_system_instructions(system_instruction, &provider)?;
459 let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
460
461 Ok(Self {
462 agent: Arc::new(agent),
463 })
464 }
465
466 #[pyo3(signature = (task, output_type=None))]
467 pub fn execute_task(
468 &self,
469 task: &mut Task,
470 output_type: Option<Bound<'_, PyAny>>,
471 ) -> Result<PyAgentResponse, AgentError> {
472 debug!("Executing task");
474
475 if task.prompt.provider != *self.agent.client_provider() {
477 return Err(AgentError::ProviderMismatch(
478 task.prompt.provider.to_string(),
479 self.agent.client_provider().as_str().to_string(),
480 ));
481 }
482
483 debug!(
484 "Task prompt model identifier: {}",
485 task.prompt.model_identifier()
486 );
487
488 let chat_response = block_on(async { self.agent.execute_task(task).await })?;
489
490 debug!("Task executed successfully");
491 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
492 let response = PyAgentResponse::new(chat_response, output);
493
494 Ok(response)
495 }
496
497 #[pyo3(signature = (prompt, output_type=None))]
505 pub fn execute_prompt(
506 &self,
507 prompt: &mut Prompt,
508 output_type: Option<Bound<'_, PyAny>>,
509 ) -> Result<PyAgentResponse, AgentError> {
510 debug!("Executing task");
512
513 if prompt.provider != *self.agent.client_provider() {
515 return Err(AgentError::ProviderMismatch(
516 prompt.provider.to_string(),
517 self.agent.client_provider().as_str().to_string(),
518 ));
519 }
520
521 let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
522
523 debug!("Task executed successfully");
524 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
525 let response = PyAgentResponse::new(chat_response, output);
526
527 Ok(response)
528 }
529
530 #[getter]
531 pub fn system_instruction<'py>(
532 &self,
533 py: Python<'py>,
534 ) -> Result<Bound<'py, PyList>, AgentError> {
535 let instructions = self
536 .agent
537 .system_instruction
538 .iter()
539 .map(|msg_num| msg_num.to_bound_py_object(py))
540 .collect::<Result<Vec<_>, _>>()
541 .map(|instructions| PyList::new(py, &instructions))?;
542
543 Ok(instructions?)
544 }
545
546 #[getter]
547 pub fn id(&self) -> &str {
548 self.agent.id.as_str()
549 }
550}