1use crate::agents::{
2 error::AgentError,
3 task::Task,
4 types::{AgentResponse, PyAgentResponse},
5};
6use potato_provider::providers::anthropic::client::AnthropicClient;
7use potato_provider::providers::types::ServiceType;
8use potato_provider::GeminiClient;
9use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
10use potato_state::block_on;
11use potato_type::prompt::Prompt;
12use potato_type::prompt::{MessageNum, Role};
13use potato_type::Provider;
14use potato_type::{
15 prompt::extract_system_instructions,
16 tools::{Tool, ToolRegistry},
17};
18use potato_util::create_uuid7;
19use pyo3::prelude::*;
20use pyo3::types::PyList;
21use serde::{
22 de::{self, MapAccess, Visitor},
23 ser::SerializeStruct,
24 Deserializer, Serializer,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::sync::RwLock;
31use tracing::{debug, instrument, warn};
32
33#[derive(Debug, Clone)]
34pub struct Agent {
35 pub id: String,
36 client: Arc<GenAiClient>,
37 pub provider: Provider,
38 pub system_instruction: Vec<MessageNum>,
39 pub tools: Arc<RwLock<ToolRegistry>>, pub max_iterations: u32,
41}
42
43impl Agent {
45 #[instrument(skip_all)]
47 pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
48 let client = match self.provider {
49 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
50 Provider::Gemini => {
51 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
52 }
53 Provider::Vertex => {
54 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
55 }
56 Provider::Anthropic => {
57 GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
58 }
59 Provider::Google => {
60 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
61 }
62 _ => {
63 return Err(AgentError::MissingProviderError);
64 } };
66
67 Ok(Self {
68 id: self.id.clone(),
69 client: Arc::new(client),
70 system_instruction: self.system_instruction.clone(),
71 provider: self.provider.clone(),
72 tools: self.tools.clone(),
73 max_iterations: self.max_iterations,
74 })
75 }
76 pub async fn new(
77 provider: Provider,
78 system_instruction: Option<Vec<MessageNum>>,
79 ) -> Result<Self, AgentError> {
80 let client = match provider {
81 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
82 Provider::Gemini => {
83 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
84 }
85 Provider::Vertex => {
86 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
87 }
88 Provider::Anthropic => {
89 GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
90 }
91 Provider::Google => {
92 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
93 }
94 _ => {
95 return Err(AgentError::MissingProviderError);
96 } };
98
99 Ok(Self {
100 client: Arc::new(client),
101 id: create_uuid7(),
102 system_instruction: system_instruction.unwrap_or_default(),
103 provider,
104 tools: Arc::new(RwLock::new(ToolRegistry::new())),
105 max_iterations: 10,
106 })
107 }
108
109 pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
110 self.tools.write().unwrap().register_tool(tool);
111 }
112
113 #[instrument(skip_all)]
162 fn append_task_with_message_dependency_context(
163 &self,
164 task: &mut Task,
165 context_messages: &HashMap<String, Vec<MessageNum>>,
166 ) {
167 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
169
170 if task.dependencies.is_empty() {
171 return;
172 }
173
174 let messages = task.prompt.request.messages_mut();
175 let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
176
177 match first_user_idx {
178 Some(insert_idx) => {
179 let mut dependency_messages = Vec::new();
181
182 for dep_id in &task.dependencies {
183 if let Some(messages) = context_messages.get(dep_id) {
184 debug!(
185 "Adding {} messages from dependency {}",
186 messages.len(),
187 dep_id
188 );
189 dependency_messages.extend(messages.iter().cloned());
190 }
191 }
192
193 for message in dependency_messages.into_iter() {
195 task.prompt
196 .request
197 .insert_message(message, Some(insert_idx))
198 }
199
200 debug!(
201 "Inserted {} dependency messages before user message at index {}",
202 task.dependencies.len(),
203 insert_idx
204 );
205 }
206 None => {
207 warn!(
208 "No user message found in task {}, appending dependency context to end",
209 task.id
210 );
211
212 for dep_id in &task.dependencies {
213 if let Some(messages) = context_messages.get(dep_id) {
214 for message in messages {
215 task.prompt.request.push_message(message.clone());
216 }
217 }
218 }
219 }
220 }
221 }
222
223 #[instrument(skip_all)]
233 fn bind_context(
234 &self,
235 prompt: &mut Prompt,
236 parameter_context: &Value,
237 global_context: &Option<Arc<Value>>,
238 ) -> Result<(), AgentError> {
239 if !prompt.parameters.is_empty() {
241 for param in &prompt.parameters {
242 if let Some(value) = parameter_context.get(param) {
244 for message in prompt.request.messages_mut() {
245 if message.role() == Role::User.as_str() {
246 debug!("Binding parameter: {} with value: {}", param, value);
247 message.bind_mut(param, &value.to_string())?;
248 }
249 }
250 }
251
252 if let Some(global_value) = global_context {
254 if let Some(value) = global_value.get(param) {
255 for message in prompt.request.messages_mut() {
256 if message.role() == Role::User.as_str() {
257 debug!("Binding global parameter: {} with value: {}", param, value);
258 message.bind_mut(param, &value.to_string())?;
259 }
260 }
261 }
262 }
263 }
264 }
265 Ok(())
266 }
267
268 fn prepend_system_instructions(&self, prompt: &mut Prompt) {
272 if !self.system_instruction.is_empty() {
273 prompt
274 .request
275 .prepend_system_instructions(self.system_instruction.clone())
276 .unwrap();
277 }
278 }
279 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
280 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
282 let mut prompt = task.prompt.clone();
283 self.prepend_system_instructions(&mut prompt);
284
285 let chat_response = self.client.generate_content(&prompt).await?;
287
288 Ok(AgentResponse::new(task.id.clone(), chat_response))
289 }
290
291 #[instrument(skip_all)]
292 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
293 debug!("Executing prompt");
295 let mut prompt = prompt.clone();
296 self.prepend_system_instructions(&mut prompt);
297
298 let chat_response = self.client.generate_content(&prompt).await?;
300
301 Ok(AgentResponse::new(chat_response.id(), chat_response))
302 }
303
304 #[instrument(skip_all)]
307 pub async fn execute_task_with_context(
308 &self,
309 task: &Arc<RwLock<Task>>,
310 context: &Value,
311 ) -> Result<AgentResponse, AgentError> {
312 let (mut prompt, task_id) = {
314 let task = task.read().unwrap();
315 (task.prompt.clone(), task.id.clone())
316 };
317
318 self.bind_context(&mut prompt, context, &None)?;
319 self.prepend_system_instructions(&mut prompt);
320
321 let chat_response = self.client.generate_content(&prompt).await?;
322 Ok(AgentResponse::new(task_id, chat_response))
323 }
324
325 pub async fn execute_task_with_context_message(
326 &self,
327 task: &Arc<RwLock<Task>>,
328 context_messages: HashMap<String, Vec<MessageNum>>,
329 parameter_context: Value,
330 global_context: Option<Arc<Value>>,
331 ) -> Result<AgentResponse, AgentError> {
332 let (prompt, task_id) = {
334 let mut task = task.write().unwrap();
335 self.append_task_with_message_dependency_context(&mut task, &context_messages);
337 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
339 self.prepend_system_instructions(&mut task.prompt);
341 (task.prompt.clone(), task.id.clone())
342 };
343
344 let chat_response = self.client.generate_content(&prompt).await?;
346 Ok(AgentResponse::new(task_id, chat_response))
347 }
348
349 pub fn client_provider(&self) -> &Provider {
350 self.client.provider()
351 }
352}
353
354impl PartialEq for Agent {
355 fn eq(&self, other: &Self) -> bool {
356 self.id == other.id
357 && self.provider == other.provider
358 && self.system_instruction == other.system_instruction
359 && self.max_iterations == other.max_iterations
360 && self.client == other.client
361 }
362}
363
364impl Serialize for Agent {
365 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
366 where
367 S: Serializer,
368 {
369 let mut state = serializer.serialize_struct("Agent", 3)?;
370 state.serialize_field("id", &self.id)?;
371 state.serialize_field("provider", &self.provider)?;
372 state.serialize_field("system_instruction", &self.system_instruction)?;
373 state.end()
374 }
375}
376
377impl<'de> Deserialize<'de> for Agent {
379 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
380 where
381 D: Deserializer<'de>,
382 {
383 #[derive(Deserialize)]
384 #[serde(field_identifier, rename_all = "snake_case")]
385 enum Field {
386 Id,
387 Provider,
388 SystemInstruction,
389 }
390
391 struct AgentVisitor;
392
393 impl<'de> Visitor<'de> for AgentVisitor {
394 type Value = Agent;
395
396 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
397 formatter.write_str("struct Agent")
398 }
399
400 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
401 where
402 V: MapAccess<'de>,
403 {
404 let mut id = None;
405 let mut provider = None;
406 let mut system_instruction = None;
407
408 while let Some(key) = map.next_key()? {
409 match key {
410 Field::Id => {
411 id = Some(map.next_value()?);
412 }
413 Field::Provider => {
414 provider = Some(map.next_value()?);
415 }
416 Field::SystemInstruction => {
417 system_instruction = Some(map.next_value()?);
418 }
419 }
420 }
421
422 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
423 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
424 let system_instruction = system_instruction
425 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
426
427 let client = GenAiClient::Undefined;
430 Ok(Agent {
431 id,
432 client: Arc::new(client),
433 system_instruction,
434 provider,
435 tools: Arc::new(RwLock::new(ToolRegistry::new())),
436 max_iterations: 10,
437 })
438 }
439 }
440
441 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
442 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
443 }
444}
445
446#[pyclass(name = "Agent")]
447#[derive(Debug, Clone)]
448pub struct PyAgent {
449 pub agent: Arc<Agent>,
450}
451
452#[pymethods]
453impl PyAgent {
454 #[new]
455 #[pyo3(signature = (provider, system_instruction = None))]
456 pub fn new(
463 provider: &Bound<'_, PyAny>,
464 system_instruction: Option<&Bound<'_, PyAny>>,
465 ) -> Result<Self, AgentError> {
466 let provider = Provider::extract_provider(provider)?;
467 let system_instructions = extract_system_instructions(system_instruction, &provider)?;
468 let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
469
470 Ok(Self {
471 agent: Arc::new(agent),
472 })
473 }
474
475 #[pyo3(signature = (task, output_type=None))]
476 pub fn execute_task(
477 &self,
478 task: &mut Task,
479 output_type: Option<Bound<'_, PyAny>>,
480 ) -> Result<PyAgentResponse, AgentError> {
481 debug!("Executing task");
483
484 if task.prompt.provider != *self.agent.client_provider() {
486 return Err(AgentError::ProviderMismatch(
487 task.prompt.provider.to_string(),
488 self.agent.client_provider().as_str().to_string(),
489 ));
490 }
491
492 debug!(
493 "Task prompt model identifier: {}",
494 task.prompt.model_identifier()
495 );
496
497 let chat_response = block_on(async { self.agent.execute_task(task).await })?;
498
499 debug!("Task executed successfully");
500 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
501 let response = PyAgentResponse::new(chat_response, output);
502
503 Ok(response)
504 }
505
506 #[pyo3(signature = (prompt, output_type=None))]
514 pub fn execute_prompt(
515 &self,
516 prompt: &mut Prompt,
517 output_type: Option<Bound<'_, PyAny>>,
518 ) -> Result<PyAgentResponse, AgentError> {
519 debug!("Executing task");
521
522 if prompt.provider != *self.agent.client_provider() {
524 return Err(AgentError::ProviderMismatch(
525 prompt.provider.to_string(),
526 self.agent.client_provider().as_str().to_string(),
527 ));
528 }
529
530 let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
531
532 debug!("Task executed successfully");
533 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
534 let response = PyAgentResponse::new(chat_response, output);
535
536 Ok(response)
537 }
538
539 #[getter]
540 pub fn system_instruction<'py>(
541 &self,
542 py: Python<'py>,
543 ) -> Result<Bound<'py, PyList>, AgentError> {
544 let instructions = self
545 .agent
546 .system_instruction
547 .iter()
548 .map(|msg_num| msg_num.to_bound_py_object(py))
549 .collect::<Result<Vec<_>, _>>()
550 .map(|instructions| PyList::new(py, &instructions))?;
551
552 Ok(instructions?)
553 }
554
555 #[getter]
556 pub fn id(&self) -> &str {
557 self.agent.id.as_str()
558 }
559}