1use crate::agents::{
2 callbacks::{AgentCallback, CallbackAction},
3 criteria::CompletionCriteria,
4 error::AgentError,
5 memory::{Memory, MemoryTurn},
6 run_context::{AgentRunConfig, AgentRunContext, ResumeContext},
7 runner::{AgentRunOutcome, AgentRunResult, AgentRunner},
8 session::{SessionSnapshot, SessionState},
9 store::{
10 app_state_store::AppStateStore, persistent_memory::PersistentMemory,
11 session_store::SessionStore, user_state_store::UserStateStore,
12 },
13 task::Task,
14 tool_ext::AgentTool,
15 types::{AgentResponse, PyAgentResponse},
16};
17use async_trait::async_trait;
18use potato_provider::providers::anthropic::client::AnthropicClient;
19use potato_provider::providers::types::ServiceType;
20use potato_provider::GeminiClient;
21use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
22use potato_state::block_on;
23use potato_type::prompt::Prompt;
24use potato_type::prompt::{MessageNum, Role};
25use potato_type::Provider;
26use potato_type::{
27 prompt::extract_system_instructions,
28 tools::{Tool, ToolRegistry},
29};
30use potato_util::create_uuid7;
31use pyo3::prelude::*;
32use pyo3::types::PyList;
33use serde::{
34 de::{self, MapAccess, Visitor},
35 ser::SerializeStruct,
36 Deserializer, Serializer,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::Arc;
42use std::sync::RwLock;
43use tracing::{debug, instrument, warn};
44
45#[derive(Debug)]
46pub struct Agent {
47 pub id: String,
48 client: Arc<GenAiClient>,
49 pub provider: Provider,
50 pub system_instruction: Vec<MessageNum>,
51 pub tools: Arc<RwLock<ToolRegistry>>,
52 pub max_iterations: u32,
53 pub run_config: Option<AgentRunConfig>,
55 pub model_override: Option<String>,
57 pub criteria: Vec<Box<dyn CompletionCriteria>>,
58 pub callbacks: Vec<Arc<dyn AgentCallback>>,
59 pub memory: Option<Arc<tokio::sync::Mutex<Box<dyn Memory>>>>,
60 pub app_name: Option<String>,
62 pub user_id: Option<String>,
64 pub session_id: Option<String>,
66 pub session_store: Option<Arc<dyn SessionStore>>,
68 pub user_state_store: Option<Arc<dyn UserStateStore>>,
70 pub app_state_store: Option<Arc<dyn AppStateStore>>,
72}
73
74impl Agent {
76 #[instrument(skip_all)]
78 pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
79 let client = match self.provider {
80 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
81 Provider::Gemini => {
82 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
83 }
84 Provider::Vertex => {
85 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
86 }
87 Provider::Anthropic => {
88 GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
89 }
90 Provider::Google => {
91 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
92 }
93 _ => {
94 return Err(AgentError::MissingProviderError);
95 } };
97
98 Ok(Self {
99 id: self.id.clone(),
100 client: Arc::new(client),
101 system_instruction: self.system_instruction.clone(),
102 provider: self.provider.clone(),
103 tools: self.tools.clone(),
104 max_iterations: self.max_iterations,
105 run_config: None,
106 model_override: None,
107 criteria: Vec::new(),
108 callbacks: Vec::new(),
109 memory: None,
110 app_name: None,
111 user_id: None,
112 session_id: None,
113 session_store: None,
114 user_state_store: None,
115 app_state_store: None,
116 })
117 }
118 pub async fn new(
119 provider: Provider,
120 system_instruction: Option<Vec<MessageNum>>,
121 ) -> Result<Self, AgentError> {
122 let client = match provider {
123 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
124 Provider::Gemini => {
125 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
126 }
127 Provider::Vertex => {
128 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
129 }
130 Provider::Anthropic => {
131 GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
132 }
133 Provider::Google => {
134 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
135 }
136 _ => {
137 return Err(AgentError::MissingProviderError);
138 } };
140
141 Ok(Self {
142 client: Arc::new(client),
143 id: create_uuid7(),
144 system_instruction: system_instruction.unwrap_or_default(),
145 provider,
146 tools: Arc::new(RwLock::new(ToolRegistry::new())),
147 max_iterations: 10,
148 run_config: None,
149 model_override: None,
150 criteria: Vec::new(),
151 callbacks: Vec::new(),
152 memory: None,
153 app_name: None,
154 user_id: None,
155 session_id: None,
156 session_store: None,
157 user_state_store: None,
158 app_state_store: None,
159 })
160 }
161
162 pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
163 self.tools
164 .write()
165 .unwrap_or_else(|e| e.into_inner())
166 .register_tool(tool);
167 }
168
169 #[instrument(skip_all)]
218 fn append_task_with_message_dependency_context(
219 &self,
220 task: &mut Task,
221 context_messages: &HashMap<String, Vec<MessageNum>>,
222 ) {
223 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
225
226 if task.dependencies.is_empty() {
227 return;
228 }
229
230 let messages = task.prompt.request.messages_mut();
231 let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
232
233 match first_user_idx {
234 Some(insert_idx) => {
235 let mut dependency_messages = Vec::new();
237
238 for dep_id in &task.dependencies {
239 if let Some(messages) = context_messages.get(dep_id) {
240 debug!(
241 "Adding {} messages from dependency {}",
242 messages.len(),
243 dep_id
244 );
245 dependency_messages.extend(messages.iter().cloned());
246 }
247 }
248
249 for message in dependency_messages.into_iter() {
251 task.prompt
252 .request
253 .insert_message(message, Some(insert_idx))
254 }
255
256 debug!(
257 "Inserted {} dependency messages before user message at index {}",
258 task.dependencies.len(),
259 insert_idx
260 );
261 }
262 None => {
263 warn!(
264 "No user message found in task {}, appending dependency context to end",
265 task.id
266 );
267
268 for dep_id in &task.dependencies {
269 if let Some(messages) = context_messages.get(dep_id) {
270 for message in messages {
271 task.prompt.request.push_message(message.clone());
272 }
273 }
274 }
275 }
276 }
277 }
278
279 #[instrument(skip_all)]
289 fn bind_context(
290 &self,
291 prompt: &mut Prompt,
292 parameter_context: &Value,
293 global_context: &Option<Arc<Value>>,
294 ) -> Result<(), AgentError> {
295 if !prompt.parameters.is_empty() {
297 for param in &prompt.parameters {
298 if let Some(value) = parameter_context.get(param) {
300 for message in prompt.request.messages_mut() {
301 if message.role() == Role::User.as_str() {
302 debug!("Binding parameter: {} with value: {}", param, value);
303 message.bind_mut(param, &value.to_string())?;
304 }
305 }
306 }
307
308 if let Some(global_value) = global_context {
310 if let Some(value) = global_value.get(param) {
311 for message in prompt.request.messages_mut() {
312 if message.role() == Role::User.as_str() {
313 debug!("Binding global parameter: {} with value: {}", param, value);
314 message.bind_mut(param, &value.to_string())?;
315 }
316 }
317 }
318 }
319 }
320 }
321 Ok(())
322 }
323
324 fn prepend_system_instructions(&self, prompt: &mut Prompt) -> Result<(), AgentError> {
328 if !self.system_instruction.is_empty() {
329 prompt
330 .request
331 .prepend_system_instructions(self.system_instruction.clone())
332 .map_err(|e| AgentError::Error(e.to_string()))?;
333 }
334 Ok(())
335 }
336 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
337 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
339 let mut prompt = task.prompt.clone();
340 self.prepend_system_instructions(&mut prompt)?;
341
342 let chat_response = self.client.generate_content(&prompt).await?;
344
345 Ok(AgentResponse::new(task.id.clone(), chat_response))
346 }
347
348 #[instrument(skip_all)]
349 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
350 debug!("Executing prompt");
352 let mut prompt = prompt.clone();
353 self.prepend_system_instructions(&mut prompt)?;
354
355 let chat_response = self.client.generate_content(&prompt).await?;
357
358 Ok(AgentResponse::new(chat_response.id(), chat_response))
359 }
360
361 #[instrument(skip_all)]
364 pub async fn execute_task_with_context(
365 &self,
366 task: &Arc<RwLock<Task>>,
367 context: &Value,
368 ) -> Result<AgentResponse, AgentError> {
369 let (mut prompt, task_id) = {
371 let task = task.read().unwrap();
372 (task.prompt.clone(), task.id.clone())
373 };
374
375 self.bind_context(&mut prompt, context, &None)?;
376 self.prepend_system_instructions(&mut prompt)?;
377
378 let chat_response = self.client.generate_content(&prompt).await?;
379 Ok(AgentResponse::new(task_id, chat_response))
380 }
381
382 pub async fn execute_task_with_context_message(
383 &self,
384 task: &Arc<RwLock<Task>>,
385 context_messages: HashMap<String, Vec<MessageNum>>,
386 parameter_context: Value,
387 global_context: Option<Arc<Value>>,
388 ) -> Result<AgentResponse, AgentError> {
389 let (prompt, task_id) = {
391 let mut task = task.write().unwrap();
392 self.append_task_with_message_dependency_context(&mut task, &context_messages);
394 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
396 self.prepend_system_instructions(&mut task.prompt)?;
398 (task.prompt.clone(), task.id.clone())
399 };
400
401 let chat_response = self.client.generate_content(&prompt).await?;
403 Ok(AgentResponse::new(task_id, chat_response))
404 }
405
406 pub fn client_provider(&self) -> &Provider {
407 self.client.provider()
408 }
409
410 fn build_input_prompt(&self, input: &str) -> Result<Prompt, AgentError> {
414 use potato_type::prompt::builder::to_provider_request;
415 use potato_type::prompt::settings::ModelSettings;
416 use potato_type::prompt::types::ResponseType;
417
418 let msg = {
419 use potato_type::traits::MessageFactory;
420 match self.provider {
421 Provider::OpenAI => {
422 use potato_type::openai::v1::chat::request::ChatMessage;
423 ChatMessage::from_text(input.to_string(), "user")
424 .map(MessageNum::OpenAIMessageV1)?
425 }
426 Provider::Anthropic => {
427 use potato_type::anthropic::v1::request::MessageParam;
428 MessageParam::from_text(input.to_string(), "user")
429 .map(MessageNum::AnthropicMessageV1)?
430 }
431 Provider::Gemini | Provider::Google | Provider::Vertex => {
432 use potato_type::google::v1::generate::request::GeminiContent;
433 GeminiContent::from_text(input.to_string(), "user")
434 .map(MessageNum::GeminiContentV1)?
435 }
436 _ => {
437 return Err(AgentError::MissingProviderError);
438 }
439 }
440 };
441
442 let model = self.model_override.clone().ok_or_else(|| {
443 AgentError::Error("model must be set explicitly via AgentBuilder::model()".into())
444 })?;
445
446 let settings = ModelSettings::provider_default_settings(&self.provider);
447
448 let request = to_provider_request(
449 vec![msg],
450 self.system_instruction.clone(),
451 model.clone(),
452 settings,
453 None,
454 )?;
455
456 Ok(Prompt {
457 request,
458 model,
459 provider: self.provider.clone(),
460 version: env!("CARGO_PKG_VERSION").to_string(),
461 parameters: Vec::new(),
462 response_type: ResponseType::Null,
463 })
464 }
465
466 fn fire_before_model(&self, ctx: &AgentRunContext, prompt: &Prompt) -> Result<(), AgentError> {
468 for cb in &self.callbacks {
469 if let CallbackAction::Abort(msg) = cb.before_model_call(ctx, prompt) {
470 return Err(AgentError::CallbackAbort(msg));
471 }
472 }
473 Ok(())
474 }
475
476 fn fire_after_model(
478 &self,
479 ctx: &AgentRunContext,
480 response: &AgentResponse,
481 ) -> Result<Option<String>, AgentError> {
482 for cb in &self.callbacks {
483 match cb.after_model_call(ctx, response) {
484 CallbackAction::Abort(msg) => return Err(AgentError::CallbackAbort(msg)),
485 CallbackAction::OverrideResponse(text) => return Ok(Some(text)),
486 CallbackAction::Continue => {}
487 }
488 }
489 Ok(None)
490 }
491
492 fn fire_before_tool(
494 &self,
495 ctx: &AgentRunContext,
496 call: &potato_type::tools::ToolCall,
497 ) -> Result<(), AgentError> {
498 for cb in &self.callbacks {
499 if let CallbackAction::Abort(msg) = cb.before_tool_call(ctx, call) {
500 return Err(AgentError::CallbackAbort(msg));
501 }
502 }
503 Ok(())
504 }
505
506 fn fire_after_tool(
508 &self,
509 ctx: &AgentRunContext,
510 call: &potato_type::tools::ToolCall,
511 result: &serde_json::Value,
512 ) -> Result<(), AgentError> {
513 for cb in &self.callbacks {
514 if let CallbackAction::Abort(msg) = cb.after_tool_call(ctx, call, result) {
515 return Err(AgentError::CallbackAbort(msg));
516 }
517 }
518 Ok(())
519 }
520}
521
522#[async_trait]
523impl AgentRunner for Agent {
524 fn id(&self) -> &str {
525 &self.id
526 }
527
528 async fn run(
529 &self,
530 input: &str,
531 session: &mut SessionState,
532 ) -> Result<AgentRunOutcome, AgentError> {
533 let max_iter = self
534 .run_config
535 .as_ref()
536 .map(|c| c.max_iterations)
537 .unwrap_or(self.max_iterations);
538
539 let mut run_ctx = AgentRunContext::new(self.id.clone(), max_iter);
540
541 let app = self.app_name.as_deref().unwrap_or("default");
542 let uid = self.user_id.as_deref().unwrap_or("default");
543
544 if let Some(store) = &self.app_state_store {
546 if let Some(snapshot) = store.load(app).await? {
547 session.merge(snapshot.0);
548 }
549 }
550
551 if let Some(store) = &self.user_state_store {
553 if let Some(snapshot) = store.load(app, uid).await? {
554 session.merge(snapshot.0);
555 }
556 }
557
558 if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
560 if let Some(snapshot) = store.load(app, uid, sid).await? {
561 session.merge(snapshot.0);
562 }
563 }
564
565 let mut prompt = self.build_input_prompt(input)?;
567
568 if let Some(mem_lock) = &self.memory {
570 let mut mem = mem_lock.lock().await;
571 if let Some(pm) = mem
572 .as_any_mut()
573 .and_then(|a| a.downcast_mut::<PersistentMemory>())
574 {
575 pm.hydrate().await?;
576 }
577 }
578
579 if let Some(mem_lock) = &self.memory {
581 let mem = mem_lock.lock().await;
582 let history = mem.messages();
583 if !history.is_empty() {
584 let insert_at = prompt
586 .request
587 .messages()
588 .iter()
589 .position(|m| !m.is_system_message())
590 .unwrap_or(0);
591 for (i, msg) in history.into_iter().enumerate() {
592 prompt.request.insert_message(msg, Some(insert_at + i));
593 }
594 }
595 }
596
597 {
599 let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
600 let defs = registry.get_all_definitions();
601 if !defs.is_empty() {
602 prompt.request.add_tools(defs)?;
603 }
604 }
605
606 let mut last_user_msg: Option<MessageNum> = None;
607 if let Some(msg) = prompt.request.messages().last().cloned() {
609 last_user_msg = Some(msg);
610 }
611
612 loop {
613 if run_ctx.iteration >= max_iter {
615 break;
616 }
617
618 self.fire_before_model(&run_ctx, &prompt)?;
620
621 let chat_response = self.client.generate_content(&prompt).await?;
623 let agent_response = AgentResponse::new(chat_response.id(), chat_response.clone());
624
625 if let Some(override_text) = self.fire_after_model(&run_ctx, &agent_response)? {
627 run_ctx.push_response(override_text.clone());
628 return Ok(AgentRunOutcome::complete(AgentRunResult {
629 final_response: agent_response,
630 iterations: run_ctx.iteration,
631 completion_reason: format!("callback override: {}", override_text),
632 combined_text: None,
633 }));
634 }
635
636 if let Some(tool_calls) = chat_response.extract_tool_calls() {
638 let assistant_msgs = chat_response.to_message_num(&self.provider)?;
640 for msg in assistant_msgs {
641 prompt.request.push_message(msg);
642 }
643
644 for call in &tool_calls {
645 self.fire_before_tool(&run_ctx, call)?;
646
647 let result = {
649 let async_tool = {
650 let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
651 registry.get_async_tool(&call.tool_name)
652 };
653 if let Some(tool) = async_tool {
654 if let Some(agent_tool) =
655 tool.as_any().and_then(|a| a.downcast_ref::<AgentTool>())
656 {
657 agent_tool
659 .dispatch(call.arguments.clone(), session)
660 .await
661 .map_err(|e| {
662 AgentError::Error(format!(
663 "Tool '{}' failed: {}",
664 call.tool_name, e
665 ))
666 })?
667 } else {
668 tool.execute(call.arguments.clone()).await.map_err(|e| {
669 AgentError::Error(format!(
670 "Tool '{}' failed: {}",
671 call.tool_name, e
672 ))
673 })?
674 }
675 } else {
676 let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
677 registry.execute(call).map_err(|e| {
678 AgentError::Error(format!(
679 "Tool '{}' failed: {}",
680 call.tool_name, e
681 ))
682 })?
683 }
684 };
685
686 self.fire_after_tool(&run_ctx, call, &result)?;
687 prompt.request.add_tool_result(call, &result)?;
688 }
689
690 run_ctx.increment();
691 continue;
692 }
693
694 let text = chat_response.response_text();
696
697 if text.trim().starts_with("__ask_user__:") {
699 let question = text.trim_start_matches("__ask_user__:").trim().to_string();
700 let resume_ctx = ResumeContext {
701 agent_id: self.id.clone(),
702 iteration: run_ctx.iteration,
703 session_snapshot: session.snapshot(),
704 };
705 return Ok(AgentRunOutcome::NeedsInput {
706 question,
707 resume_context: resume_ctx,
708 });
709 }
710
711 run_ctx.push_response(text);
712
713 let met = self.criteria.iter().any(|c| c.is_complete(&run_ctx));
715 let reason = if met {
716 self.criteria
717 .iter()
718 .find(|c| c.is_complete(&run_ctx))
719 .map(|c| c.completion_reason(&run_ctx))
720 .unwrap_or_else(|| "criteria met".into())
721 } else {
722 String::new()
723 };
724
725 if met || run_ctx.iteration + 1 >= max_iter {
726 if let Some(mem_lock) = &self.memory {
728 let mut mem = mem_lock.lock().await;
729 if let Some(user_msg) = last_user_msg.take() {
730 let assistant_msgs = chat_response.to_message_num(&self.provider)?;
731 if let Some(asst_msg) = assistant_msgs.into_iter().next() {
732 let turn = MemoryTurn {
733 user: user_msg,
734 assistant: asst_msg,
735 };
736 if let Some(pm) = mem
738 .as_any_mut()
739 .and_then(|a| a.downcast_mut::<PersistentMemory>())
740 {
741 pm.push_turn_async(turn).await?;
742 } else {
743 mem.push_turn(turn);
744 }
745 }
746 }
747 }
748
749 if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
751 let snapshot = SessionSnapshot::from(&*session);
752 store.save(app, uid, sid, &snapshot).await?;
753 }
754
755 return Ok(AgentRunOutcome::complete(AgentRunResult {
756 final_response: agent_response,
757 iterations: run_ctx.iteration,
758 completion_reason: if met {
759 reason
760 } else {
761 format!("max iterations ({}) reached", max_iter)
762 },
763 combined_text: None,
764 }));
765 }
766
767 let assistant_msgs = chat_response.to_message_num(&self.provider)?;
769 for msg in assistant_msgs {
770 prompt.request.push_message(msg);
771 }
772
773 run_ctx.increment();
774 }
775
776 Err(AgentError::MaxIterationsExceeded(max_iter))
778 }
779
780 async fn resume(
781 &self,
782 user_answer: &str,
783 ctx: ResumeContext,
784 session: &mut SessionState,
785 ) -> Result<AgentRunOutcome, AgentError> {
786 session.merge(ctx.session_snapshot);
788 self.run(user_answer, session).await
790 }
791}
792
793impl Clone for Agent {
796 fn clone(&self) -> Self {
797 Self {
798 id: self.id.clone(),
799 client: self.client.clone(),
800 provider: self.provider.clone(),
801 system_instruction: self.system_instruction.clone(),
802 tools: self.tools.clone(),
803 max_iterations: self.max_iterations,
804 run_config: self.run_config.clone(),
805 model_override: self.model_override.clone(),
806 criteria: Vec::new(),
808 callbacks: Vec::new(),
809 memory: None,
810 app_name: None,
811 user_id: None,
812 session_id: None,
813 session_store: None,
814 user_state_store: None,
815 app_state_store: None,
816 }
817 }
818}
819
820impl PartialEq for Agent {
821 fn eq(&self, other: &Self) -> bool {
822 self.id == other.id
823 && self.provider == other.provider
824 && self.system_instruction == other.system_instruction
825 && self.max_iterations == other.max_iterations
826 && self.client == other.client
827 }
829}
830
831impl Serialize for Agent {
832 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
833 where
834 S: Serializer,
835 {
836 let mut state = serializer.serialize_struct("Agent", 3)?;
837 state.serialize_field("id", &self.id)?;
838 state.serialize_field("provider", &self.provider)?;
839 state.serialize_field("system_instruction", &self.system_instruction)?;
840 state.end()
841 }
842}
843
844impl<'de> Deserialize<'de> for Agent {
846 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
847 where
848 D: Deserializer<'de>,
849 {
850 #[derive(Deserialize)]
851 #[serde(field_identifier, rename_all = "snake_case")]
852 enum Field {
853 Id,
854 Provider,
855 SystemInstruction,
856 }
857
858 struct AgentVisitor;
859
860 impl<'de> Visitor<'de> for AgentVisitor {
861 type Value = Agent;
862
863 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
864 formatter.write_str("struct Agent")
865 }
866
867 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
868 where
869 V: MapAccess<'de>,
870 {
871 let mut id = None;
872 let mut provider = None;
873 let mut system_instruction = None;
874
875 while let Some(key) = map.next_key()? {
876 match key {
877 Field::Id => {
878 id = Some(map.next_value()?);
879 }
880 Field::Provider => {
881 provider = Some(map.next_value()?);
882 }
883 Field::SystemInstruction => {
884 system_instruction = Some(map.next_value()?);
885 }
886 }
887 }
888
889 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
890 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
891 let system_instruction = system_instruction
892 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
893
894 let client = GenAiClient::Undefined;
897 Ok(Agent {
898 id,
899 client: Arc::new(client),
900 system_instruction,
901 provider,
902 tools: Arc::new(RwLock::new(ToolRegistry::new())),
903 max_iterations: 10,
904 run_config: None,
905 model_override: None,
906 criteria: Vec::new(),
907 callbacks: Vec::new(),
908 memory: None,
909 app_name: None,
910 user_id: None,
911 session_id: None,
912 session_store: None,
913 user_state_store: None,
914 app_state_store: None,
915 })
916 }
917 }
918
919 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
920 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
921 }
922}
923
924#[pyclass(name = "Agent")]
925#[derive(Debug, Clone)]
926pub struct PyAgent {
927 pub agent: Arc<Agent>,
928}
929
930#[pymethods]
931impl PyAgent {
932 #[new]
933 #[pyo3(signature = (provider, system_instruction = None))]
934 pub fn new(
941 provider: &Bound<'_, PyAny>,
942 system_instruction: Option<&Bound<'_, PyAny>>,
943 ) -> Result<Self, AgentError> {
944 let provider = Provider::extract_provider(provider)?;
945 let system_instructions = extract_system_instructions(system_instruction, &provider)?;
946 let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
947
948 Ok(Self {
949 agent: Arc::new(agent),
950 })
951 }
952
953 #[pyo3(signature = (task, output_type=None))]
954 pub fn execute_task(
955 &self,
956 task: &mut Task,
957 output_type: Option<Bound<'_, PyAny>>,
958 ) -> Result<PyAgentResponse, AgentError> {
959 debug!("Executing task");
961
962 if task.prompt.provider != *self.agent.client_provider() {
964 return Err(AgentError::ProviderMismatch(
965 task.prompt.provider.to_string(),
966 self.agent.client_provider().as_str().to_string(),
967 ));
968 }
969
970 debug!(
971 "Task prompt model identifier: {}",
972 task.prompt.model_identifier()
973 );
974
975 let chat_response = block_on(async { self.agent.execute_task(task).await })?;
976
977 debug!("Task executed successfully");
978 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
979 let response = PyAgentResponse::new(chat_response, output);
980
981 Ok(response)
982 }
983
984 #[pyo3(signature = (prompt, output_type=None))]
992 pub fn execute_prompt(
993 &self,
994 prompt: &mut Prompt,
995 output_type: Option<Bound<'_, PyAny>>,
996 ) -> Result<PyAgentResponse, AgentError> {
997 debug!("Executing task");
999
1000 if prompt.provider != *self.agent.client_provider() {
1002 return Err(AgentError::ProviderMismatch(
1003 prompt.provider.to_string(),
1004 self.agent.client_provider().as_str().to_string(),
1005 ));
1006 }
1007
1008 let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
1009
1010 debug!("Task executed successfully");
1011 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
1012 let response = PyAgentResponse::new(chat_response, output);
1013
1014 Ok(response)
1015 }
1016
1017 #[getter]
1018 pub fn system_instruction<'py>(
1019 &self,
1020 py: Python<'py>,
1021 ) -> Result<Bound<'py, PyList>, AgentError> {
1022 let instructions = self
1023 .agent
1024 .system_instruction
1025 .iter()
1026 .map(|msg_num| msg_num.to_bound_py_object(py))
1027 .collect::<Result<Vec<_>, _>>()
1028 .map(|instructions| PyList::new(py, &instructions))?;
1029
1030 Ok(instructions?)
1031 }
1032
1033 #[getter]
1034 pub fn id(&self) -> &str {
1035 self.agent.id.as_str()
1036 }
1037}