// potato_agent/agents/agent.rs
1use crate::agents::{
2    callbacks::{AgentCallback, CallbackAction},
3    criteria::CompletionCriteria,
4    error::AgentError,
5    memory::{Memory, MemoryTurn},
6    run_context::{AgentRunConfig, AgentRunContext, ResumeContext},
7    runner::{AgentRunOutcome, AgentRunResult, AgentRunner},
8    session::{SessionSnapshot, SessionState},
9    store::{
10        app_state_store::AppStateStore, persistent_memory::PersistentMemory,
11        session_store::SessionStore, user_state_store::UserStateStore,
12    },
13    task::Task,
14    tool_ext::AgentTool,
15    types::{AgentResponse, PyAgentResponse},
16};
17use async_trait::async_trait;
18use potato_provider::providers::anthropic::client::AnthropicClient;
19use potato_provider::providers::types::ServiceType;
20use potato_provider::GeminiClient;
21use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
22use potato_state::block_on;
23use potato_type::prompt::Prompt;
24use potato_type::prompt::{MessageNum, Role};
25use potato_type::Provider;
26use potato_type::{
27    prompt::extract_system_instructions,
28    tools::{Tool, ToolRegistry},
29};
30use potato_util::create_uuid7;
31use pyo3::prelude::*;
32use pyo3::types::PyList;
33use serde::{
34    de::{self, MapAccess, Visitor},
35    ser::SerializeStruct,
36    Deserializer, Serializer,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::Arc;
42use std::sync::RwLock;
43use tracing::{debug, instrument, warn};
44
/// Core agent: a provider-specific LLM client plus everything needed to run
/// the agentic loop — tools, callbacks, completion criteria, memory, and
/// durable app/user/session state stores.
#[derive(Debug)]
pub struct Agent {
    /// Unique agent identifier (UUIDv7, assigned in `new`).
    pub id: String,
    /// Provider-backed client used for all model calls. Private; rebuilt via
    /// `rebuild_client` after deserialization.
    client: Arc<GenAiClient>,
    /// Which LLM provider this agent talks to.
    pub provider: Provider,
    /// System instructions prepended to every prompt this agent executes.
    pub system_instruction: Vec<MessageNum>,
    /// Shared, lock-guarded registry of tools available to the agent.
    pub tools: Arc<RwLock<ToolRegistry>>,
    /// Iteration cap used when no `run_config` overrides it (defaults to 10).
    pub max_iterations: u32,
    // --- new agentic-loop fields ---
    /// Optional per-run configuration; its `max_iterations` takes precedence.
    pub run_config: Option<AgentRunConfig>,
    /// If set, overrides the model in any Prompt built by AgentBuilder::run().
    pub model_override: Option<String>,
    /// Completion criteria — the run loop stops when ANY criterion is met.
    pub criteria: Vec<Box<dyn CompletionCriteria>>,
    /// Lifecycle callbacks fired before/after model and tool calls.
    pub callbacks: Vec<Arc<dyn AgentCallback>>,
    /// Optional conversation memory. Uses tokio's Mutex because the guard is
    /// held across `.await` points in the run loop.
    pub memory: Option<Arc<tokio::sync::Mutex<Box<dyn Memory>>>>,
    /// Application name used for store scoping.
    pub app_name: Option<String>,
    /// User identifier used for store scoping.
    pub user_id: Option<String>,
    /// Session identifier used to key store lookups.
    pub session_id: Option<String>,
    /// Optional durable session state store.
    pub session_store: Option<Arc<dyn SessionStore>>,
    /// Optional per-user state store.
    pub user_state_store: Option<Arc<dyn UserStateStore>>,
    /// Optional app-level state store.
    pub app_state_store: Option<Arc<dyn AppStateStore>>,
}
73
74/// Rust method implementation of the Agent
75impl Agent {
76    /// Helper method to rebuild the client, useful for deserialization
77    #[instrument(skip_all)]
78    pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
79        let client = match self.provider {
80            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
81            Provider::Gemini => {
82                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
83            }
84            Provider::Vertex => {
85                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
86            }
87            Provider::Anthropic => {
88                GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
89            }
90            Provider::Google => {
91                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
92            }
93            _ => {
94                return Err(AgentError::MissingProviderError);
95            } // Add other providers here as needed
96        };
97
98        Ok(Self {
99            id: self.id.clone(),
100            client: Arc::new(client),
101            system_instruction: self.system_instruction.clone(),
102            provider: self.provider.clone(),
103            tools: self.tools.clone(),
104            max_iterations: self.max_iterations,
105            run_config: None,
106            model_override: None,
107            criteria: Vec::new(),
108            callbacks: Vec::new(),
109            memory: None,
110            app_name: None,
111            user_id: None,
112            session_id: None,
113            session_store: None,
114            user_state_store: None,
115            app_state_store: None,
116        })
117    }
118    pub async fn new(
119        provider: Provider,
120        system_instruction: Option<Vec<MessageNum>>,
121    ) -> Result<Self, AgentError> {
122        let client = match provider {
123            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
124            Provider::Gemini => {
125                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
126            }
127            Provider::Vertex => {
128                GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
129            }
130            Provider::Anthropic => {
131                GenAiClient::Anthropic(AnthropicClient::new(ServiceType::Generate)?)
132            }
133            Provider::Google => {
134                GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
135            }
136            _ => {
137                return Err(AgentError::MissingProviderError);
138            } // Add other providers here as needed
139        };
140
141        Ok(Self {
142            client: Arc::new(client),
143            id: create_uuid7(),
144            system_instruction: system_instruction.unwrap_or_default(),
145            provider,
146            tools: Arc::new(RwLock::new(ToolRegistry::new())),
147            max_iterations: 10,
148            run_config: None,
149            model_override: None,
150            criteria: Vec::new(),
151            callbacks: Vec::new(),
152            memory: None,
153            app_name: None,
154            user_id: None,
155            session_id: None,
156            session_store: None,
157            user_state_store: None,
158            app_state_store: None,
159        })
160    }
161
162    pub fn register_tool(&self, tool: Box<dyn Tool + Send + Sync>) {
163        self.tools
164            .write()
165            .unwrap_or_else(|e| e.into_inner())
166            .register_tool(tool);
167    }
168
169    //TODO: add back later
170    /// Execute task with agentic reasoning loop
171    //pub async fn execute_agentic_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
172    //    let mut prompt = task.prompt.clone();
173    //    self.prepend_system_instructions(&mut prompt);
174
175    //    // Add tool definitions to prompt if tools are registered
176    //    let tool_definitions = self.tools.read().unwrap().get_definitions();
177    //    if !tool_definitions.is_empty() {
178    //        // Convert tools to provider-specific format and add to prompt
179    //        prompt.add_tools(tool_definitions)?;
180    //    }
181
182    //    let mut iteration = 0;
183    //    let mut conversation_history = Vec::new();
184
185    //    loop {
186    //        if iteration >= self.max_iterations {
187    //            return Err(AgentError::Error("Max iterations reached".to_string()));
188    //        }
189
190    //        // Generate response
191    //        let response = self.client.generate_content(&prompt).await?;
192
193    //        // Check if response contains tool calls
194    //        if let Some(tool_calls) = response.extract_tool_calls() {
195    //            debug!("Agent requesting {} tool calls", tool_calls.len());
196
197    //            // Execute all requested tools
198    //            let mut tool_results = Vec::new();
199    //            for tool_call in tool_calls {
200    //                let result = self.tools.read().unwrap().execute(&tool_call)?;
201    //                tool_results.push((tool_call.tool_name.clone(), result));
202    //            }
203
204    //            // Add tool results back to conversation
205    //            conversation_history.push(response.clone());
206    //            prompt.add_tool_results(tool_results)?;
207
208    //            iteration += 1;
209    //            continue;
210    //        }
211
212    //        // No tool calls - agent has final answer
213    //        return Ok(AgentResponse::new(task.id.clone(), response));
214    //    }
215    //}
216
217    #[instrument(skip_all)]
218    fn append_task_with_message_dependency_context(
219        &self,
220        task: &mut Task,
221        context_messages: &HashMap<String, Vec<MessageNum>>,
222    ) {
223        //
224        debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
225
226        if task.dependencies.is_empty() {
227            return;
228        }
229
230        let messages = task.prompt.request.messages_mut();
231        let first_user_idx = messages.iter().position(|msg| !msg.is_system_message());
232
233        match first_user_idx {
234            Some(insert_idx) => {
235                // Collect all dependency messages to insert
236                let mut dependency_messages = Vec::new();
237
238                for dep_id in &task.dependencies {
239                    if let Some(messages) = context_messages.get(dep_id) {
240                        debug!(
241                            "Adding {} messages from dependency {}",
242                            messages.len(),
243                            dep_id
244                        );
245                        dependency_messages.extend(messages.iter().cloned());
246                    }
247                }
248
249                // Always insert at same index to keep pushing user message forward
250                for message in dependency_messages.into_iter() {
251                    task.prompt
252                        .request
253                        .insert_message(message, Some(insert_idx))
254                }
255
256                debug!(
257                    "Inserted {} dependency messages before user message at index {}",
258                    task.dependencies.len(),
259                    insert_idx
260                );
261            }
262            None => {
263                warn!(
264                    "No user message found in task {}, appending dependency context to end",
265                    task.id
266                );
267
268                for dep_id in &task.dependencies {
269                    if let Some(messages) = context_messages.get(dep_id) {
270                        for message in messages {
271                            task.prompt.request.push_message(message.clone());
272                        }
273                    }
274                }
275            }
276        }
277    }
278
279    /// This function will bind dependency-specific context and global context if provided to the user prompt.
280    ///
281    /// # Arguments:
282    /// * `prompt` - The prompt to bind parameters to.
283    /// * `parameter_context` - A serde_json::Value containing the parameters to bind.
284    /// * `global_context` - An optional serde_json::Value containing global parameters to bind.
285    ///
286    /// # Returns:
287    /// * `Result<(), AgentError>` - Returns Ok(()) if successful, or an `AgentError` if there was an issue binding the parameters.
288    #[instrument(skip_all)]
289    fn bind_context(
290        &self,
291        prompt: &mut Prompt,
292        parameter_context: &Value,
293        global_context: &Option<Arc<Value>>,
294    ) -> Result<(), AgentError> {
295        // print user messages
296        if !prompt.parameters.is_empty() {
297            for param in &prompt.parameters {
298                // Bind parameter context to the user message
299                if let Some(value) = parameter_context.get(param) {
300                    for message in prompt.request.messages_mut() {
301                        if message.role() == Role::User.as_str() {
302                            debug!("Binding parameter: {} with value: {}", param, value);
303                            message.bind_mut(param, &value.to_string())?;
304                        }
305                    }
306                }
307
308                // If global context is provided, bind it to the user message
309                if let Some(global_value) = global_context {
310                    if let Some(value) = global_value.get(param) {
311                        for message in prompt.request.messages_mut() {
312                            if message.role() == Role::User.as_str() {
313                                debug!("Binding global parameter: {} with value: {}", param, value);
314                                message.bind_mut(param, &value.to_string())?;
315                            }
316                        }
317                    }
318                }
319            }
320        }
321        Ok(())
322    }
323
324    /// If system instructions are set on the agent, prepend them to the prompt.
325    /// Agent system instructions take precedence over task system instructions.
326    /// If a user wishes to be more dynamic based on the task, they should set system instructions on the task/prompt
327    fn prepend_system_instructions(&self, prompt: &mut Prompt) -> Result<(), AgentError> {
328        if !self.system_instruction.is_empty() {
329            prompt
330                .request
331                .prepend_system_instructions(self.system_instruction.clone())
332                .map_err(|e| AgentError::Error(e.to_string()))?;
333        }
334        Ok(())
335    }
336    pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
337        // Extract the prompt from the task
338        debug!("Executing task: {}, count: {}", task.id, task.retry_count);
339        let mut prompt = task.prompt.clone();
340        self.prepend_system_instructions(&mut prompt)?;
341
342        // Use the client to execute the task
343        let chat_response = self.client.generate_content(&prompt).await?;
344
345        Ok(AgentResponse::new(task.id.clone(), chat_response))
346    }
347
348    #[instrument(skip_all)]
349    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
350        // Extract the prompt from the task
351        debug!("Executing prompt");
352        let mut prompt = prompt.clone();
353        self.prepend_system_instructions(&mut prompt)?;
354
355        // Use the client to execute the task
356        let chat_response = self.client.generate_content(&prompt).await?;
357
358        Ok(AgentResponse::new(chat_response.id(), chat_response))
359    }
360
361    /// Execute task with context without mutating the original task
362    /// This method is used by the workflow executor to run individual tasks with context
363    #[instrument(skip_all)]
364    pub async fn execute_task_with_context(
365        &self,
366        task: &Arc<RwLock<Task>>,
367        context: &Value,
368    ) -> Result<AgentResponse, AgentError> {
369        // Clone prompt and task_id without holding lock across await
370        let (mut prompt, task_id) = {
371            let task = task.read().unwrap();
372            (task.prompt.clone(), task.id.clone())
373        };
374
375        self.bind_context(&mut prompt, context, &None)?;
376        self.prepend_system_instructions(&mut prompt)?;
377
378        let chat_response = self.client.generate_content(&prompt).await?;
379        Ok(AgentResponse::new(task_id, chat_response))
380    }
381
382    pub async fn execute_task_with_context_message(
383        &self,
384        task: &Arc<RwLock<Task>>,
385        context_messages: HashMap<String, Vec<MessageNum>>,
386        parameter_context: Value,
387        global_context: Option<Arc<Value>>,
388    ) -> Result<AgentResponse, AgentError> {
389        // Prepare prompt and context before await
390        let (prompt, task_id) = {
391            let mut task = task.write().unwrap();
392            // 1. Add dependency context (should come after system instructions, before user message)
393            self.append_task_with_message_dependency_context(&mut task, &context_messages);
394            // 2. Bind parameters
395            self.bind_context(&mut task.prompt, &parameter_context, &global_context)?;
396            // 3. Prepend agent system instructions (add to front)
397            self.prepend_system_instructions(&mut task.prompt)?;
398            (task.prompt.clone(), task.id.clone())
399        };
400
401        // Now do the async work without holding the lock
402        let chat_response = self.client.generate_content(&prompt).await?;
403        Ok(AgentResponse::new(task_id, chat_response))
404    }
405
406    pub fn client_provider(&self) -> &Provider {
407        self.client.provider()
408    }
409
410    // ── Agentic loop helper ─────────────────────────────────────────────────
411
412    /// Build a minimal one-turn Prompt from a plain input string.
413    fn build_input_prompt(&self, input: &str) -> Result<Prompt, AgentError> {
414        use potato_type::prompt::builder::to_provider_request;
415        use potato_type::prompt::settings::ModelSettings;
416        use potato_type::prompt::types::ResponseType;
417
418        let msg = {
419            use potato_type::traits::MessageFactory;
420            match self.provider {
421                Provider::OpenAI => {
422                    use potato_type::openai::v1::chat::request::ChatMessage;
423                    ChatMessage::from_text(input.to_string(), "user")
424                        .map(MessageNum::OpenAIMessageV1)?
425                }
426                Provider::Anthropic => {
427                    use potato_type::anthropic::v1::request::MessageParam;
428                    MessageParam::from_text(input.to_string(), "user")
429                        .map(MessageNum::AnthropicMessageV1)?
430                }
431                Provider::Gemini | Provider::Google | Provider::Vertex => {
432                    use potato_type::google::v1::generate::request::GeminiContent;
433                    GeminiContent::from_text(input.to_string(), "user")
434                        .map(MessageNum::GeminiContentV1)?
435                }
436                _ => {
437                    return Err(AgentError::MissingProviderError);
438                }
439            }
440        };
441
442        let model = self.model_override.clone().ok_or_else(|| {
443            AgentError::Error("model must be set explicitly via AgentBuilder::model()".into())
444        })?;
445
446        let settings = ModelSettings::provider_default_settings(&self.provider);
447
448        let request = to_provider_request(
449            vec![msg],
450            self.system_instruction.clone(),
451            model.clone(),
452            settings,
453            None,
454        )?;
455
456        Ok(Prompt {
457            request,
458            model,
459            provider: self.provider.clone(),
460            version: env!("CARGO_PKG_VERSION").to_string(),
461            parameters: Vec::new(),
462            response_type: ResponseType::Null,
463        })
464    }
465
466    /// Fire `before_model_call` callbacks. Returns an error if any callback aborts.
467    fn fire_before_model(&self, ctx: &AgentRunContext, prompt: &Prompt) -> Result<(), AgentError> {
468        for cb in &self.callbacks {
469            if let CallbackAction::Abort(msg) = cb.before_model_call(ctx, prompt) {
470                return Err(AgentError::CallbackAbort(msg));
471            }
472        }
473        Ok(())
474    }
475
476    /// Fire `after_model_call` callbacks. Returns `Some(override_text)` or None.
477    fn fire_after_model(
478        &self,
479        ctx: &AgentRunContext,
480        response: &AgentResponse,
481    ) -> Result<Option<String>, AgentError> {
482        for cb in &self.callbacks {
483            match cb.after_model_call(ctx, response) {
484                CallbackAction::Abort(msg) => return Err(AgentError::CallbackAbort(msg)),
485                CallbackAction::OverrideResponse(text) => return Ok(Some(text)),
486                CallbackAction::Continue => {}
487            }
488        }
489        Ok(None)
490    }
491
492    /// Fire `before_tool_call` callbacks.
493    fn fire_before_tool(
494        &self,
495        ctx: &AgentRunContext,
496        call: &potato_type::tools::ToolCall,
497    ) -> Result<(), AgentError> {
498        for cb in &self.callbacks {
499            if let CallbackAction::Abort(msg) = cb.before_tool_call(ctx, call) {
500                return Err(AgentError::CallbackAbort(msg));
501            }
502        }
503        Ok(())
504    }
505
506    /// Fire `after_tool_call` callbacks.
507    fn fire_after_tool(
508        &self,
509        ctx: &AgentRunContext,
510        call: &potato_type::tools::ToolCall,
511        result: &serde_json::Value,
512    ) -> Result<(), AgentError> {
513        for cb in &self.callbacks {
514            if let CallbackAction::Abort(msg) = cb.after_tool_call(ctx, call, result) {
515                return Err(AgentError::CallbackAbort(msg));
516            }
517        }
518        Ok(())
519    }
520}
521
#[async_trait]
impl AgentRunner for Agent {
    /// Stable identifier for this agent within a runner graph.
    fn id(&self) -> &str {
        &self.id
    }

    /// Execute the full agentic loop for a single user `input`.
    ///
    /// Sequence: load app → user → session state (increasing precedence, each
    /// merge overwrites the last), build the one-turn prompt, hydrate and
    /// inject memory history, attach tool definitions, then loop:
    /// before-model callbacks → model call → after-model callbacks → either
    /// execute tool calls and continue, or evaluate completion criteria.
    ///
    /// Returns `AgentRunOutcome::NeedsInput` when the model emits the
    /// `__ask_user__:` sentinel; memory and session state are persisted only
    /// on a completed run.
    async fn run(
        &self,
        input: &str,
        session: &mut SessionState,
    ) -> Result<AgentRunOutcome, AgentError> {
        // Per-run config wins over the agent-level iteration cap.
        let max_iter = self
            .run_config
            .as_ref()
            .map(|c| c.max_iterations)
            .unwrap_or(self.max_iterations);

        let mut run_ctx = AgentRunContext::new(self.id.clone(), max_iter);

        // Store scoping falls back to "default" when app/user are unset.
        let app = self.app_name.as_deref().unwrap_or("default");
        let uid = self.user_id.as_deref().unwrap_or("default");

        // Load app-level state (lowest precedence — overwritten by later loads).
        if let Some(store) = &self.app_state_store {
            if let Some(snapshot) = store.load(app).await? {
                session.merge(snapshot.0);
            }
        }

        // Load user-level state (medium precedence).
        if let Some(store) = &self.user_state_store {
            if let Some(snapshot) = store.load(app, uid).await? {
                session.merge(snapshot.0);
            }
        }

        // Load session state (highest precedence — wins over user and app).
        if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
            if let Some(snapshot) = store.load(app, uid, sid).await? {
                session.merge(snapshot.0);
            }
        }

        // Build the prompt from the input string
        let mut prompt = self.build_input_prompt(input)?;

        // Hydrate PersistentMemory from the backing store (lazy, idempotent).
        // Must happen before the injection block below reads `mem.messages()`.
        if let Some(mem_lock) = &self.memory {
            let mut mem = mem_lock.lock().await;
            if let Some(pm) = mem
                .as_any_mut()
                .and_then(|a| a.downcast_mut::<PersistentMemory>())
            {
                pm.hydrate().await?;
            }
        }

        // Inject memory history in chronological order, after any system messages
        if let Some(mem_lock) = &self.memory {
            let mem = mem_lock.lock().await;
            let history = mem.messages();
            if !history.is_empty() {
                // Find the first non-system message position; insert history before it
                let insert_at = prompt
                    .request
                    .messages()
                    .iter()
                    .position(|m| !m.is_system_message())
                    .unwrap_or(0);
                for (i, msg) in history.into_iter().enumerate() {
                    prompt.request.insert_message(msg, Some(insert_at + i));
                }
            }
        }

        // Attach tool definitions
        {
            // Recover from lock poisoning rather than panicking.
            let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
            let defs = registry.get_all_definitions();
            if !defs.is_empty() {
                prompt.request.add_tools(defs)?;
            }
        }

        let mut last_user_msg: Option<MessageNum> = None;
        // Capture the user message for memory storage later
        // NOTE(review): assumes the last message is the current user input
        // (history was inserted *before* it above) — confirm this holds for
        // all providers' request layouts.
        if let Some(msg) = prompt.request.messages().last().cloned() {
            last_user_msg = Some(msg);
        }

        loop {
            // Check max iterations
            if run_ctx.iteration >= max_iter {
                break;
            }

            // Before-model callbacks
            self.fire_before_model(&run_ctx, &prompt)?;

            // Call the LLM
            let chat_response = self.client.generate_content(&prompt).await?;
            let agent_response = AgentResponse::new(chat_response.id(), chat_response.clone());

            // After-model callbacks
            // NOTE(review): on override, `final_response` still carries the raw
            // model response; the override text appears only in the run context
            // and completion_reason — confirm callers expect this.
            if let Some(override_text) = self.fire_after_model(&run_ctx, &agent_response)? {
                run_ctx.push_response(override_text.clone());
                return Ok(AgentRunOutcome::complete(AgentRunResult {
                    final_response: agent_response,
                    iterations: run_ctx.iteration,
                    completion_reason: format!("callback override: {}", override_text),
                    combined_text: None,
                }));
            }

            // Check for tool calls
            if let Some(tool_calls) = chat_response.extract_tool_calls() {
                // Append assistant message with tool calls to the prompt
                let assistant_msgs = chat_response.to_message_num(&self.provider)?;
                for msg in assistant_msgs {
                    prompt.request.push_message(msg);
                }

                for call in &tool_calls {
                    self.fire_before_tool(&run_ctx, call)?;

                    // Try async tool first, then sync
                    let result = {
                        // Fetch under a short-lived lock; do not hold the
                        // registry guard across the tool's `.await`.
                        let async_tool = {
                            let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
                            registry.get_async_tool(&call.tool_name)
                        };
                        if let Some(tool) = async_tool {
                            if let Some(agent_tool) =
                                tool.as_any().and_then(|a| a.downcast_ref::<AgentTool>())
                            {
                                // Route AgentTool through dispatch() to propagate ancestor tracking.
                                agent_tool
                                    .dispatch(call.arguments.clone(), session)
                                    .await
                                    .map_err(|e| {
                                        AgentError::Error(format!(
                                            "Tool '{}' failed: {}",
                                            call.tool_name, e
                                        ))
                                    })?
                            } else {
                                tool.execute(call.arguments.clone()).await.map_err(|e| {
                                    AgentError::Error(format!(
                                        "Tool '{}' failed: {}",
                                        call.tool_name, e
                                    ))
                                })?
                            }
                        } else {
                            // Fall back to the synchronous registry path.
                            let registry = self.tools.read().unwrap_or_else(|e| e.into_inner());
                            registry.execute(call).map_err(|e| {
                                AgentError::Error(format!(
                                    "Tool '{}' failed: {}",
                                    call.tool_name, e
                                ))
                            })?
                        }
                    };

                    self.fire_after_tool(&run_ctx, call, &result)?;
                    prompt.request.add_tool_result(call, &result)?;
                }

                run_ctx.increment();
                continue;
            }

            // No tool calls — this is a candidate final response
            let text = chat_response.response_text();

            // Check for ask_user tool pattern (special built-in)
            if text.trim().starts_with("__ask_user__:") {
                let question = text.trim_start_matches("__ask_user__:").trim().to_string();
                let resume_ctx = ResumeContext {
                    agent_id: self.id.clone(),
                    iteration: run_ctx.iteration,
                    session_snapshot: session.snapshot(),
                };
                return Ok(AgentRunOutcome::NeedsInput {
                    question,
                    resume_context: resume_ctx,
                });
            }

            run_ctx.push_response(text);

            // Check completion criteria (any = stop)
            // NOTE(review): `is_complete` is evaluated twice per criterion
            // (any + find); harmless if criteria are pure — confirm.
            let met = self.criteria.iter().any(|c| c.is_complete(&run_ctx));
            let reason = if met {
                self.criteria
                    .iter()
                    .find(|c| c.is_complete(&run_ctx))
                    .map(|c| c.completion_reason(&run_ctx))
                    .unwrap_or_else(|| "criteria met".into())
            } else {
                String::new()
            };

            // Finish when criteria are met, or when this was the last allowed
            // iteration (iteration counts from 0, hence the +1).
            if met || run_ctx.iteration + 1 >= max_iter {
                // Store memory turn
                if let Some(mem_lock) = &self.memory {
                    let mut mem = mem_lock.lock().await;
                    if let Some(user_msg) = last_user_msg.take() {
                        // Persist only the first assistant message of the turn.
                        let assistant_msgs = chat_response.to_message_num(&self.provider)?;
                        if let Some(asst_msg) = assistant_msgs.into_iter().next() {
                            let turn = MemoryTurn {
                                user: user_msg,
                                assistant: asst_msg,
                            };
                            // Use write-through async path for PersistentMemory.
                            if let Some(pm) = mem
                                .as_any_mut()
                                .and_then(|a| a.downcast_mut::<PersistentMemory>())
                            {
                                pm.push_turn_async(turn).await?;
                            } else {
                                mem.push_turn(turn);
                            }
                        }
                    }
                }

                // Persist session state to backing store.
                if let (Some(sid), Some(store)) = (&self.session_id, &self.session_store) {
                    let snapshot = SessionSnapshot::from(&*session);
                    store.save(app, uid, sid, &snapshot).await?;
                }

                return Ok(AgentRunOutcome::complete(AgentRunResult {
                    final_response: agent_response,
                    iterations: run_ctx.iteration,
                    completion_reason: if met {
                        reason
                    } else {
                        format!("max iterations ({}) reached", max_iter)
                    },
                    combined_text: None,
                }));
            }

            // Not complete yet — append assistant message and continue
            let assistant_msgs = chat_response.to_message_num(&self.provider)?;
            for msg in assistant_msgs {
                prompt.request.push_message(msg);
            }

            run_ctx.increment();
        }

        // Fell out of the loop without a final response — max iterations were all spent on tool calls
        Err(AgentError::MaxIterationsExceeded(max_iter))
    }

    /// Resume a paused run (after `NeedsInput`) with the user's answer.
    ///
    /// Restores session state from the snapshot, then re-enters `run` with
    /// the answer as the new input. Note: the prior prompt transcript is not
    /// replayed — only session state survives the pause.
    async fn resume(
        &self,
        user_answer: &str,
        ctx: ResumeContext,
        session: &mut SessionState,
    ) -> Result<AgentRunOutcome, AgentError> {
        // Restore session from the snapshot in ResumeContext
        session.merge(ctx.session_snapshot);
        // Re-run with the user's answer as new input
        self.run(user_answer, session).await
    }
}
792
793/// Manual Clone: clones the provider-level fields; criteria/callbacks/memory are NOT cloned.
794/// This preserves backward compatibility with the workflow layer which stores `Arc<Agent>`.
795impl Clone for Agent {
796    fn clone(&self) -> Self {
797        Self {
798            id: self.id.clone(),
799            client: self.client.clone(),
800            provider: self.provider.clone(),
801            system_instruction: self.system_instruction.clone(),
802            tools: self.tools.clone(),
803            max_iterations: self.max_iterations,
804            run_config: self.run_config.clone(),
805            model_override: self.model_override.clone(),
806            // Non-clonable fields — intentionally reset on clone
807            criteria: Vec::new(),
808            callbacks: Vec::new(),
809            memory: None,
810            app_name: None,
811            user_id: None,
812            session_id: None,
813            session_store: None,
814            user_state_store: None,
815            app_state_store: None,
816        }
817    }
818}
819
820impl PartialEq for Agent {
821    fn eq(&self, other: &Self) -> bool {
822        self.id == other.id
823            && self.provider == other.provider
824            && self.system_instruction == other.system_instruction
825            && self.max_iterations == other.max_iterations
826            && self.client == other.client
827        // criteria / callbacks / memory intentionally excluded
828    }
829}
830
831impl Serialize for Agent {
832    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
833    where
834        S: Serializer,
835    {
836        let mut state = serializer.serialize_struct("Agent", 3)?;
837        state.serialize_field("id", &self.id)?;
838        state.serialize_field("provider", &self.provider)?;
839        state.serialize_field("system_instruction", &self.system_instruction)?;
840        state.end()
841    }
842}
843
844/// Allows for deserialization of the Agent, re-initializing the client.
845impl<'de> Deserialize<'de> for Agent {
846    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
847    where
848        D: Deserializer<'de>,
849    {
850        #[derive(Deserialize)]
851        #[serde(field_identifier, rename_all = "snake_case")]
852        enum Field {
853            Id,
854            Provider,
855            SystemInstruction,
856        }
857
858        struct AgentVisitor;
859
860        impl<'de> Visitor<'de> for AgentVisitor {
861            type Value = Agent;
862
863            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
864                formatter.write_str("struct Agent")
865            }
866
867            fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
868            where
869                V: MapAccess<'de>,
870            {
871                let mut id = None;
872                let mut provider = None;
873                let mut system_instruction = None;
874
875                while let Some(key) = map.next_key()? {
876                    match key {
877                        Field::Id => {
878                            id = Some(map.next_value()?);
879                        }
880                        Field::Provider => {
881                            provider = Some(map.next_value()?);
882                        }
883                        Field::SystemInstruction => {
884                            system_instruction = Some(map.next_value()?);
885                        }
886                    }
887                }
888
889                let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
890                let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
891                let system_instruction = system_instruction
892                    .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
893
894                // Deserialize is a sync op, so we can't await here (gemini requires async to init)
895                // After deserialization, we re-initialize the client based on the provider
896                let client = GenAiClient::Undefined;
897                Ok(Agent {
898                    id,
899                    client: Arc::new(client),
900                    system_instruction,
901                    provider,
902                    tools: Arc::new(RwLock::new(ToolRegistry::new())),
903                    max_iterations: 10,
904                    run_config: None,
905                    model_override: None,
906                    criteria: Vec::new(),
907                    callbacks: Vec::new(),
908                    memory: None,
909                    app_name: None,
910                    user_id: None,
911                    session_id: None,
912                    session_store: None,
913                    user_state_store: None,
914                    app_state_store: None,
915                })
916            }
917        }
918
919        const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
920        deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
921    }
922}
923
/// Python-facing wrapper exposed to Python as `Agent`.
#[pyclass(name = "Agent")]
#[derive(Debug, Clone)]
pub struct PyAgent {
    // Shared handle to the underlying Rust agent; cloning is a refcount bump.
    pub agent: Arc<Agent>,
}
929
930#[pymethods]
931impl PyAgent {
932    #[new]
933    #[pyo3(signature = (provider, system_instruction = None))]
934    /// Creates a new Agent instance.
935    ///
936    /// # Arguments:
937    /// * `provider` - A Python object representing the provider, expected to be an a variant of Provider or a string
938    /// that can be mapped to a provider variant
939    ///
940    pub fn new(
941        provider: &Bound<'_, PyAny>,
942        system_instruction: Option<&Bound<'_, PyAny>>,
943    ) -> Result<Self, AgentError> {
944        let provider = Provider::extract_provider(provider)?;
945        let system_instructions = extract_system_instructions(system_instruction, &provider)?;
946        let agent = block_on(async { Agent::new(provider, system_instructions).await })?;
947
948        Ok(Self {
949            agent: Arc::new(agent),
950        })
951    }
952
953    #[pyo3(signature = (task, output_type=None))]
954    pub fn execute_task(
955        &self,
956        task: &mut Task,
957        output_type: Option<Bound<'_, PyAny>>,
958    ) -> Result<PyAgentResponse, AgentError> {
959        // Extract the prompt from the task
960        debug!("Executing task");
961
962        // agent provider and task.prompt provider must match
963        if task.prompt.provider != *self.agent.client_provider() {
964            return Err(AgentError::ProviderMismatch(
965                task.prompt.provider.to_string(),
966                self.agent.client_provider().as_str().to_string(),
967            ));
968        }
969
970        debug!(
971            "Task prompt model identifier: {}",
972            task.prompt.model_identifier()
973        );
974
975        let chat_response = block_on(async { self.agent.execute_task(task).await })?;
976
977        debug!("Task executed successfully");
978        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
979        let response = PyAgentResponse::new(chat_response, output);
980
981        Ok(response)
982    }
983
984    /// Executes a prompt directly without a task.
985    /// # Arguments:
986    /// * `prompt` - The prompt to execute.
987    /// * `output_type` - An optional Python type to bind the response to. If
988    /// provide, it is expected that the output type_object matches the response schema defined in the prompt.
989    /// # Returns:
990    /// * `PyAgentResponse` - The response from the agent.
991    #[pyo3(signature = (prompt, output_type=None))]
992    pub fn execute_prompt(
993        &self,
994        prompt: &mut Prompt,
995        output_type: Option<Bound<'_, PyAny>>,
996    ) -> Result<PyAgentResponse, AgentError> {
997        // Extract the prompt from the task
998        debug!("Executing task");
999
1000        // agent provider and task.prompt provider must match
1001        if prompt.provider != *self.agent.client_provider() {
1002            return Err(AgentError::ProviderMismatch(
1003                prompt.provider.to_string(),
1004                self.agent.client_provider().as_str().to_string(),
1005            ));
1006        }
1007
1008        let chat_response = block_on(async { self.agent.execute_prompt(prompt).await })?;
1009
1010        debug!("Task executed successfully");
1011        let output = output_type.as_ref().map(|obj| obj.clone().unbind());
1012        let response = PyAgentResponse::new(chat_response, output);
1013
1014        Ok(response)
1015    }
1016
1017    #[getter]
1018    pub fn system_instruction<'py>(
1019        &self,
1020        py: Python<'py>,
1021    ) -> Result<Bound<'py, PyList>, AgentError> {
1022        let instructions = self
1023            .agent
1024            .system_instruction
1025            .iter()
1026            .map(|msg_num| msg_num.to_bound_py_object(py))
1027            .collect::<Result<Vec<_>, _>>()
1028            .map(|instructions| PyList::new(py, &instructions))?;
1029
1030        Ok(instructions?)
1031    }
1032
1033    #[getter]
1034    pub fn id(&self) -> &str {
1035        self.agent.id.as_str()
1036    }
1037}