//! Agent builder pattern.
//!
//! The builder provides a fluent interface for configuring agents.
//!
//! # Examples
//!
//! ## Using a model spec string (simplest)
//!
//! ```ignore
//! use serdes_ai_agent::AgentBuilder;
//!
//! // Uses environment variables for API keys
//! let agent = AgentBuilder::from_model("openai:gpt-4o")?
//!     .system_prompt("You are helpful.")
//!     .build();
//! ```
//!
//! ## With explicit API key
//!
//! ```ignore
//! use serdes_ai_agent::{AgentBuilder, ModelConfig};
//!
//! let config = ModelConfig::new("openai:gpt-4o")
//!     .with_api_key("sk-your-api-key");
//!
//! let agent = AgentBuilder::from_config(config)?
//!     .system_prompt("You are helpful.")
//!     .build();
//! ```
//!
//! ## With concrete model type (most control)
//!
//! ```ignore
//! use serdes_ai_agent::AgentBuilder;
//! use serdes_ai_models::openai::OpenAIChatModel;
//!
//! let model = OpenAIChatModel::new("gpt-4o", "sk-your-api-key")
//!     .with_base_url("https://custom-endpoint.com/v1");
//!
//! let agent = AgentBuilder::new(model)
//!     .system_prompt("You are helpful.")
//!     .build();
//! ```

use crate::agent::{Agent, EndStrategy, InstrumentationSettings, RegisteredTool, ToolExecutor};
use crate::context::{RunContext, UsageLimits};
use crate::errors::OutputValidationError;
use crate::history::HistoryProcessor;
use crate::instructions::{
    AsyncInstructionFn, AsyncSystemPromptFn, InstructionFn, SyncInstructionFn, SyncSystemPromptFn,
    SystemPromptFn,
};
use crate::output::{
    DefaultOutputSchema, JsonOutputSchema, OutputSchema, OutputValidator, SyncValidator,
    ToolOutputSchema,
};
use serde::de::DeserializeOwned;
use serde_json::Value as JsonValue;
use serdes_ai_core::ModelSettings;
use serdes_ai_models::{Model, ModelError};
use serdes_ai_tools::{ToolDefinition, ToolError, ToolReturn};
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;

// ============================================================================
// Model Configuration
// ============================================================================

/// Configuration for creating a model from a string spec.
///
/// This allows specifying a model using the standard `provider:model` format
/// while also providing custom API keys, base URLs, and other options.
///
/// # Examples
///
/// ```ignore
/// use serdes_ai_agent::ModelConfig;
///
/// // Simple: just a model spec (uses env vars for keys)
/// let config = ModelConfig::new("openai:gpt-4o");
///
/// // With explicit API key
/// let config = ModelConfig::new("anthropic:claude-3-5-sonnet-20241022")
///     .with_api_key("sk-ant-your-key");
///
/// // With custom base URL (for proxies or compatible APIs)
/// let config = ModelConfig::new("openai:gpt-4o")
///     .with_api_key("your-key")
///     .with_base_url("https://your-proxy.com/v1");
/// ```
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Model spec in `provider:model` format (e.g., "openai:gpt-4o")
    pub spec: String,
    /// Optional API key (overrides environment variable)
    pub api_key: Option<String>,
    /// Optional base URL (for custom endpoints)
    pub base_url: Option<String>,
    /// Optional request timeout
    pub timeout: Option<Duration>,
}

impl ModelConfig {
    /// Create a new model config from a spec string.
    ///
    /// The spec should be in `provider:model` format, e.g.:
    /// - `"openai:gpt-4o"`
    /// - `"anthropic:claude-3-5-sonnet-20241022"`
    /// - `"groq:llama-3.1-70b-versatile"`
    /// - `"ollama:llama3.1"`
    ///
    /// If no provider prefix is given, OpenAI is assumed.
    #[must_use]
    pub fn new(spec: impl Into<String>) -> Self {
        Self {
            spec: spec.into(),
            api_key: None,
            base_url: None,
            timeout: None,
        }
    }

    /// Set an explicit API key (overrides environment variable).
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Set a custom base URL (for proxies or compatible APIs).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Set a request timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Parse the provider and model name from the spec.
    fn parse_spec(&self) -> (&str, &str) {
        if self.spec.contains(':') {
            let parts: Vec<&str> = self.spec.splitn(2, ':').collect();
            (parts[0], parts[1])
        } else {
            ("openai", self.spec.as_str())
        }
    }

    /// Build a model from this configuration.
    ///
    /// This creates the appropriate model type based on the provider,
    /// applying any custom API key, base URL, or timeout settings.
    ///
    /// # Note
    ///
    /// This method delegates to `serdes_ai_models::infer_model` when only a spec is
    /// given (no custom API key, base URL, or timeout), and to
    /// `serdes_ai_models::build_model_with_config` when custom configuration is provided.
    ///
    /// The available providers depend on the features enabled in `serdes-ai-models`:
    /// - `openai` (default) - OpenAI models (gpt-4o, gpt-4, etc.)
    /// - `anthropic` - Anthropic models (claude-3-5-sonnet, etc.)
    /// - `groq` - Groq models
    /// - `mistral` - Mistral models
    /// - `ollama` - Local Ollama models
    /// - `google` - Google/Gemini models
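    ///
    /// # Example
    ///
    /// A sketch of building a local Ollama model through a config; the base URL
    /// below is only a placeholder for your local endpoint:
    ///
    /// ```ignore
    /// use serdes_ai_agent::ModelConfig;
    ///
    /// let config = ModelConfig::new("ollama:llama3.1")
    ///     .with_base_url("http://localhost:11434/v1");
    /// let model = config.build_model()?;
    /// ```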
    pub fn build_model(&self) -> Result<Arc<dyn Model>, ModelError> {
        // If no custom config, use infer_model which handles feature flags
        if self.api_key.is_none() && self.base_url.is_none() && self.timeout.is_none() {
            return serdes_ai_models::infer_model(&self.spec);
        }

        // Custom config requires building the model directly
        let (provider, model_name) = self.parse_spec();

        // We need to build the model with custom settings
        // This requires the concrete model types which are behind feature flags
        // in serdes-ai-models. We use a helper function pattern.
        self.build_model_with_config(provider, model_name)
    }

    fn build_model_with_config(
        &self,
        provider: &str,
        model_name: &str,
    ) -> Result<Arc<dyn Model>, ModelError> {
        // Use serdes_ai_models to build models - it has the feature flags
        serdes_ai_models::build_model_with_config(
            provider,
            model_name,
            self.api_key.as_deref(),
            self.base_url.as_deref(),
            self.timeout,
        )
    }
}

/// Builder for creating agents.
pub struct AgentBuilder<Deps = (), Output = String> {
    model: Arc<dyn Model>,
    name: Option<String>,
    model_settings: ModelSettings,
    instructions: Vec<String>,
    instruction_fns: Vec<Box<dyn InstructionFn<Deps>>>,
    system_prompts: Vec<String>,
    system_prompt_fns: Vec<Box<dyn SystemPromptFn<Deps>>>,
    tools: Vec<RegisteredTool<Deps>>,
    output_schema: Option<Box<dyn OutputSchema<Output>>>,
    output_validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
    end_strategy: EndStrategy,
    max_output_retries: u32,
    max_tool_retries: u32,
    usage_limits: Option<UsageLimits>,
    history_processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
    instrument: Option<InstrumentationSettings>,
    parallel_tool_calls: bool,
    max_concurrent_tools: Option<usize>,
    _phantom: PhantomData<(Deps, Output)>,
}

impl<Deps, Output> AgentBuilder<Deps, Output>
where
    Deps: Send + Sync + 'static,
    Output: Send + Sync + 'static,
{
    /// Create a new agent builder with the given model.
    ///
    /// This is the most flexible constructor, accepting any type that implements
    /// the `Model` trait. Use this when you need full control over model configuration.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use serdes_ai_agent::AgentBuilder;
    /// use serdes_ai_models::openai::OpenAIChatModel;
    ///
    /// let model = OpenAIChatModel::new("gpt-4o", "sk-your-api-key");
    /// let agent = AgentBuilder::new(model)
    ///     .system_prompt("You are helpful.")
    ///     .build();
    /// ```
    pub fn new<M: Model + 'static>(model: M) -> Self {
        Self::from_arc(Arc::new(model))
    }

    /// Create a new agent builder from an `Arc<dyn Model>`.
    ///
    /// This is useful when you already have a model wrapped in an Arc,
    /// such as from `infer_model()`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use serdes_ai_agent::AgentBuilder;
    /// use serdes_ai_models::infer_model;
    ///
    /// let model = infer_model("openai:gpt-4o")?;
    /// let agent = AgentBuilder::from_arc(model)
    ///     .system_prompt("You are helpful.")
    ///     .build();
    /// ```
    pub fn from_arc(model: Arc<dyn Model>) -> Self {
        Self {
            model,
            name: None,
            model_settings: ModelSettings::default(),
            instructions: Vec::new(),
            instruction_fns: Vec::new(),
            system_prompts: Vec::new(),
            system_prompt_fns: Vec::new(),
            tools: Vec::new(),
            output_schema: None,
            output_validators: Vec::new(),
            end_strategy: EndStrategy::Early,
            max_output_retries: 3,
            max_tool_retries: 3,
            usage_limits: None,
            history_processors: Vec::new(),
            instrument: None,
            parallel_tool_calls: true,
            max_concurrent_tools: None,
            _phantom: PhantomData,
        }
    }

    /// Create a new agent builder from a model spec string.
    ///
    /// This is the simplest way to create an agent when you just need to specify
    /// the model. API keys are read from environment variables.
    ///
    /// # Model Spec Format
    ///
    /// The spec should be in `provider:model` format:
    /// - `"openai:gpt-4o"` - OpenAI GPT-4o
    /// - `"anthropic:claude-3-5-sonnet-20241022"` - Anthropic Claude
    /// - `"groq:llama-3.1-70b-versatile"` - Groq
    /// - `"ollama:llama3.1"` - Local Ollama
    ///
    /// If no provider prefix is given, OpenAI is assumed.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use serdes_ai_agent::AgentBuilder;
    ///
    /// let agent = AgentBuilder::from_model("openai:gpt-4o")?
    ///     .system_prompt("You are helpful.")
    ///     .build();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be created (e.g., missing API key,
    /// unsupported provider, or disabled feature).
    pub fn from_model(spec: impl Into<String>) -> Result<Self, ModelError> {
        let config = ModelConfig::new(spec);
        Self::from_config(config)
    }

    /// Create a new agent builder from a model configuration.
    ///
    /// This allows specifying custom API keys, base URLs, and other options
    /// while still using the convenient string-based model spec.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use serdes_ai_agent::{AgentBuilder, ModelConfig};
    ///
    /// let config = ModelConfig::new("openai:gpt-4o")
    ///     .with_api_key("sk-your-api-key")
    ///     .with_base_url("https://your-proxy.com/v1");
    ///
    /// let agent = AgentBuilder::from_config(config)?
    ///     .system_prompt("You are helpful.")
    ///     .build();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be created.
    pub fn from_config(config: ModelConfig) -> Result<Self, ModelError> {
        let model = config.build_model()?;
        Ok(Self::from_arc(model))
    }

    /// Set agent name.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Set model settings.
    #[must_use]
    pub fn model_settings(mut self, settings: ModelSettings) -> Self {
        self.model_settings = settings;
        self
    }

    /// Set temperature.
    #[must_use]
    pub fn temperature(mut self, temp: f64) -> Self {
        self.model_settings = self.model_settings.temperature(temp);
        self
    }

    /// Set max tokens.
    #[must_use]
    pub fn max_tokens(mut self, tokens: u64) -> Self {
        self.model_settings = self.model_settings.max_tokens(tokens);
        self
    }

    /// Set top-p.
    #[must_use]
    pub fn top_p(mut self, p: f64) -> Self {
        self.model_settings = self.model_settings.top_p(p);
        self
    }

    /// Add static instructions.
    #[must_use]
    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
        self.instructions.push(instructions.into());
        self
    }

    /// Add dynamic instructions function (async).
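    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `model` implements `Model`); the async block must
    /// not borrow from the context so that the returned future is `'static`:
    ///
    /// ```ignore
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .instructions_fn(|_ctx: &RunContext<()>| async { Some("Be concise.".to_string()) })
    ///     .build();
    /// ```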
    #[must_use]
    pub fn instructions_fn<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(&RunContext<Deps>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Option<String>> + Send + 'static,
    {
        self.instruction_fns
            .push(Box::new(AsyncInstructionFn::new(f)));
        self
    }

    /// Add dynamic instructions function (sync).
    #[must_use]
    pub fn instructions_fn_sync<F>(mut self, f: F) -> Self
    where
        F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync + 'static,
    {
        self.instruction_fns
            .push(Box::new(SyncInstructionFn::new(f)));
        self
    }

    /// Add system prompt.
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompts.push(prompt.into());
        self
    }

    /// Add dynamic system prompt function (async).
    #[must_use]
    pub fn system_prompt_fn<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(&RunContext<Deps>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Option<String>> + Send + 'static,
    {
        self.system_prompt_fns
            .push(Box::new(AsyncSystemPromptFn::new(f)));
        self
    }

    /// Add dynamic system prompt function (sync).
    #[must_use]
    pub fn system_prompt_fn_sync<F>(mut self, f: F) -> Self
    where
        F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync + 'static,
    {
        self.system_prompt_fns
            .push(Box::new(SyncSystemPromptFn::new(f)));
        self
    }

    /// Add a tool with a custom executor.
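    ///
    /// # Example
    ///
    /// A minimal sketch of a hand-written executor (assuming `model` implements
    /// `Model`); this toy executor just echoes its arguments back:
    ///
    /// ```ignore
    /// struct EchoExecutor;
    ///
    /// #[async_trait::async_trait]
    /// impl ToolExecutor<()> for EchoExecutor {
    ///     async fn execute(
    ///         &self,
    ///         args: serde_json::Value,
    ///         _ctx: &RunContext<()>,
    ///     ) -> Result<ToolReturn, ToolError> {
    ///         Ok(ToolReturn::text(args.to_string()))
    ///     }
    /// }
    ///
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .tool_with_executor(
    ///         ToolDefinition::new("echo".to_string(), "Echo the arguments".to_string()),
    ///         EchoExecutor,
    ///     )
    ///     .build();
    /// ```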
    #[must_use]
    pub fn tool_with_executor<E>(mut self, definition: ToolDefinition, executor: E) -> Self
    where
        E: ToolExecutor<Deps> + 'static,
    {
        self.tools.push(RegisteredTool {
            definition,
            executor: Arc::new(executor),
            max_retries: self.max_tool_retries,
        });
        self
    }

    /// Add a tool from a sync function.
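    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `model` implements `Model`); the closure receives
    /// the deserialized arguments:
    ///
    /// ```ignore
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .tool_fn(
    ///         "greet",
    ///         "Greet someone",
    ///         |_ctx: &RunContext<()>, args: serde_json::Value| {
    ///             let name = args["name"].as_str().unwrap_or("World");
    ///             Ok(ToolReturn::text(format!("Hello, {}!", name)))
    ///         },
    ///     )
    ///     .build();
    /// ```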
    #[must_use]
    pub fn tool_fn<F, Args>(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        f: F,
    ) -> Self
    where
        F: Fn(&RunContext<Deps>, Args) -> Result<ToolReturn, ToolError> + Send + Sync + 'static,
        Args: DeserializeOwned + Send + 'static,
    {
        let tool_name = name.into();
        let definition = ToolDefinition::new(tool_name.clone(), description.into());

        let executor = SyncFnExecutor {
            func: Arc::new(move |ctx, args: JsonValue| {
                let parsed: Args = serde_json::from_value(args)
                    .map_err(|e| ToolError::invalid_arguments(tool_name.clone(), e.to_string()))?;
                f(ctx, parsed)
            }),
            _phantom: PhantomData,
        };

        self.tools.push(RegisteredTool {
            definition,
            executor: Arc::new(executor),
            max_retries: self.max_tool_retries,
        });
        self
    }

    /// Add a tool from an async function.
    #[must_use]
    pub fn tool_fn_async<F, Fut, Args>(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        f: F,
    ) -> Self
    where
        F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ToolReturn, ToolError>> + Send + Sync + 'static,
        Args: DeserializeOwned + Send + Sync + 'static,
    {
        let tool_name = name.into();
        let definition = ToolDefinition::new(tool_name.clone(), description.into());

        let executor = AsyncFnExecutor {
            func: Arc::new(f),
            tool_name,
            _phantom: PhantomData,
        };

        self.tools.push(RegisteredTool {
            definition,
            executor: Arc::new(executor),
            max_retries: self.max_tool_retries,
        });
        self
    }

    /// Set custom output schema.
    #[must_use]
    pub fn output_schema<S: OutputSchema<Output> + 'static>(mut self, schema: S) -> Self {
        self.output_schema = Some(Box::new(schema));
        self
    }

    /// Add output validator.
    #[must_use]
    pub fn output_validator<V: OutputValidator<Output, Deps> + 'static>(
        mut self,
        validator: V,
    ) -> Self {
        self.output_validators.push(Box::new(validator));
        self
    }

    /// Add output validator from sync function.
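    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `model` implements `Model`) that normalizes the
    /// output instead of rejecting it:
    ///
    /// ```ignore
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .output_validator_fn(|output: String, _ctx: &RunContext<()>| {
    ///         Ok(output.trim().to_string())
    ///     })
    ///     .build();
    /// ```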
    #[must_use]
    pub fn output_validator_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError>
            + Send
            + Sync
            + 'static,
    {
        self.output_validators.push(Box::new(SyncValidator::new(f)));
        self
    }

    /// Set end strategy.
    #[must_use]
    pub fn end_strategy(mut self, strategy: EndStrategy) -> Self {
        self.end_strategy = strategy;
        self
    }

    /// Set max output retries.
    #[must_use]
    pub fn max_output_retries(mut self, retries: u32) -> Self {
        self.max_output_retries = retries;
        self
    }

    /// Set max tool retries.
    #[must_use]
    pub fn max_tool_retries(mut self, retries: u32) -> Self {
        self.max_tool_retries = retries;
        self
    }

    /// Set usage limits.
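    ///
    /// # Example
    ///
    /// For example, capping a run at 1000 total tokens and 10 requests (assuming
    /// `model` implements `Model`):
    ///
    /// ```ignore
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .usage_limits(UsageLimits::new().total_tokens(1000).requests(10))
    ///     .build();
    /// ```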
    #[must_use]
    pub fn usage_limits(mut self, limits: UsageLimits) -> Self {
        self.usage_limits = Some(limits);
        self
    }

    /// Add history processor.
    #[must_use]
    pub fn history_processor<P: HistoryProcessor<Deps> + 'static>(mut self, processor: P) -> Self {
        self.history_processors.push(Box::new(processor));
        self
    }

    /// Enable instrumentation.
    #[must_use]
    pub fn instrument(mut self, settings: InstrumentationSettings) -> Self {
        self.instrument = Some(settings);
        self
    }

    /// Enable or disable parallel tool execution.
    ///
    /// When enabled (default), multiple tool calls from the model will be
    /// executed concurrently using `futures::future::join_all`.
    ///
    /// When disabled, tools are executed sequentially in order.
    #[must_use]
    pub fn parallel_tool_calls(mut self, enabled: bool) -> Self {
        self.parallel_tool_calls = enabled;
        self
    }

    /// Set the maximum number of concurrent tool calls.
    ///
    /// When set, limits the number of tools that can execute simultaneously
    /// using a semaphore. This is useful for rate-limiting or resource control.
    ///
    /// Only applies when `parallel_tool_calls` is enabled.
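    ///
    /// # Example
    ///
    /// For example, keeping parallel execution but capping it at two tools at a
    /// time (assuming `model` implements `Model`):
    ///
    /// ```ignore
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .parallel_tool_calls(true)
    ///     .max_concurrent_tools(2)
    ///     .build();
    /// ```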
    #[must_use]
    pub fn max_concurrent_tools(mut self, max: usize) -> Self {
        self.max_concurrent_tools = Some(max);
        self
    }

    /// Build the agent.
    pub fn build(self) -> Agent<Deps, Output>
    where
        Output: serde::de::DeserializeOwned,
    {
        let output_schema = self
            .output_schema
            .unwrap_or_else(|| Box::new(DefaultOutputSchema::<Output>::new()));

        // Pre-join static system prompts and instructions at build time.
        // This avoids cloning these strings on every run.
        let static_system_prompt = {
            let mut parts = Vec::new();

            // Static system prompts first
            for prompt in &self.system_prompts {
                if !prompt.is_empty() {
                    parts.push(prompt.as_str());
                }
            }

            // Then static instructions
            for instruction in &self.instructions {
                if !instruction.is_empty() {
                    parts.push(instruction.as_str());
                }
            }

            Arc::from(parts.join("\n\n"))
        };

        // Pre-compute tool definitions at build time.
        // This avoids cloning tool definitions on every agent step.
        let cached_tool_defs = Arc::new(
            self.tools
                .iter()
                .map(|t| t.definition.clone())
                .collect::<Vec<_>>(),
        );

        Agent {
            model: self.model,
            name: self.name,
            model_settings: self.model_settings,
            static_system_prompt,
            instruction_fns: self.instruction_fns,
            system_prompt_fns: self.system_prompt_fns,
            tools: self.tools,
            cached_tool_defs,
            output_schema,
            output_validators: self.output_validators,
            end_strategy: self.end_strategy,
            max_output_retries: self.max_output_retries,
            max_tool_retries: self.max_tool_retries,
            usage_limits: self.usage_limits,
            history_processors: self.history_processors,
            instrument: self.instrument,
            parallel_tool_calls: self.parallel_tool_calls,
            max_concurrent_tools: self.max_concurrent_tools,
            _phantom: PhantomData,
        }
    }
}

// Specialized builders for output types

impl<Deps: Send + Sync + 'static> AgentBuilder<Deps, String> {
    /// Change output type to a JSON-parsed type.
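    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `model` implements `Model` and serde's derive
    /// feature is enabled):
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Answer {
    ///     value: i64,
    /// }
    ///
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .output_type::<Answer>()
    ///     .build();
    /// ```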
    #[must_use]
    pub fn output_type<T: DeserializeOwned + Send + Sync + 'static>(self) -> AgentBuilder<Deps, T> {
        AgentBuilder {
            model: self.model,
            name: self.name,
            model_settings: self.model_settings,
            instructions: self.instructions,
            instruction_fns: self.instruction_fns,
            system_prompts: self.system_prompts,
            system_prompt_fns: self.system_prompt_fns,
            tools: self.tools,
            output_schema: Some(Box::new(JsonOutputSchema::<T>::new())),
            output_validators: Vec::new(),
            end_strategy: self.end_strategy,
            max_output_retries: self.max_output_retries,
            max_tool_retries: self.max_tool_retries,
            usage_limits: self.usage_limits,
            history_processors: self.history_processors,
            instrument: self.instrument,
            parallel_tool_calls: self.parallel_tool_calls,
            max_concurrent_tools: self.max_concurrent_tools,
            _phantom: PhantomData,
        }
    }

    /// Change output type with JSON schema.
    #[must_use]
    pub fn output_type_with_schema<T: DeserializeOwned + Send + Sync + 'static>(
        self,
        schema: JsonValue,
    ) -> AgentBuilder<Deps, T> {
        AgentBuilder {
            model: self.model,
            name: self.name,
            model_settings: self.model_settings,
            instructions: self.instructions,
            instruction_fns: self.instruction_fns,
            system_prompts: self.system_prompts,
            system_prompt_fns: self.system_prompt_fns,
            tools: self.tools,
            output_schema: Some(Box::new(JsonOutputSchema::<T>::new().with_schema(schema))),
            output_validators: Vec::new(),
            end_strategy: self.end_strategy,
            max_output_retries: self.max_output_retries,
            max_tool_retries: self.max_tool_retries,
            usage_limits: self.usage_limits,
            history_processors: self.history_processors,
            instrument: self.instrument,
            parallel_tool_calls: self.parallel_tool_calls,
            max_concurrent_tools: self.max_concurrent_tools,
            _phantom: PhantomData,
        }
    }

    /// Use tool-based output.
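    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `model` implements `Model`); the tool name and
    /// schema below are placeholders:
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Location {
    ///     city: String,
    /// }
    ///
    /// let schema = serde_json::json!({
    ///     "type": "object",
    ///     "properties": { "city": { "type": "string" } },
    ///     "required": ["city"]
    /// });
    ///
    /// let agent = AgentBuilder::<(), String>::new(model)
    ///     .output_tool::<Location>("final_result", schema)
    ///     .build();
    /// ```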
    #[must_use]
    pub fn output_tool<T: DeserializeOwned + Send + Sync + 'static>(
        self,
        tool_name: impl Into<String>,
        schema: JsonValue,
    ) -> AgentBuilder<Deps, T> {
        AgentBuilder {
            model: self.model,
            name: self.name,
            model_settings: self.model_settings,
            instructions: self.instructions,
            instruction_fns: self.instruction_fns,
            system_prompts: self.system_prompts,
            system_prompt_fns: self.system_prompt_fns,
            tools: self.tools,
            output_schema: Some(Box::new(
                ToolOutputSchema::<T>::new(tool_name).with_schema(schema),
            )),
            output_validators: Vec::new(),
            end_strategy: self.end_strategy,
            max_output_retries: self.max_output_retries,
            max_tool_retries: self.max_tool_retries,
            usage_limits: self.usage_limits,
            history_processors: self.history_processors,
            instrument: self.instrument,
            parallel_tool_calls: self.parallel_tool_calls,
            max_concurrent_tools: self.max_concurrent_tools,
            _phantom: PhantomData,
        }
    }
}

// ============================================================================
// Tool Executors
// ============================================================================

/// Sync function executor.
#[allow(clippy::type_complexity)]
struct SyncFnExecutor<Deps> {
    func: Arc<dyn Fn(&RunContext<Deps>, JsonValue) -> Result<ToolReturn, ToolError> + Send + Sync>,
    _phantom: PhantomData<Deps>,
}

#[async_trait::async_trait]
impl<Deps: Send + Sync> ToolExecutor<Deps> for SyncFnExecutor<Deps> {
    async fn execute(
        &self,
        args: JsonValue,
        ctx: &RunContext<Deps>,
    ) -> Result<ToolReturn, ToolError> {
        (self.func)(ctx, args)
    }
}

/// Async function executor.
struct AsyncFnExecutor<F, Deps, Args, Fut>
where
    F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync,
    Fut: Future<Output = Result<ToolReturn, ToolError>> + Send,
    Args: DeserializeOwned + Send,
{
    func: Arc<F>,
    tool_name: String,
    _phantom: PhantomData<(Deps, Args, Fut)>,
}

#[async_trait::async_trait]
impl<F, Deps, Args, Fut> ToolExecutor<Deps> for AsyncFnExecutor<F, Deps, Args, Fut>
where
    F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync,
    Fut: Future<Output = Result<ToolReturn, ToolError>> + Send + Sync,
    Args: DeserializeOwned + Send + Sync,
    Deps: Send + Sync,
{
    async fn execute(
        &self,
        args: JsonValue,
        ctx: &RunContext<Deps>,
    ) -> Result<ToolReturn, ToolError> {
        let parsed: Args = serde_json::from_value(args)
            .map_err(|e| ToolError::invalid_arguments(self.tool_name.clone(), e.to_string()))?;
        (self.func)(ctx, parsed).await
    }
}

/// Convenience function to create a builder.
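///
/// # Example
///
/// A one-liner equivalent to `AgentBuilder::new(model)`, assuming `model`
/// implements `Model` (mirroring the tests below):
///
/// ```ignore
/// let agent = agent(model).name("quick-agent").build();
/// ```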
pub fn agent<M: Model + 'static>(model: M) -> AgentBuilder<(), String> {
    AgentBuilder::new(model)
}

/// Convenience function to create a builder with dependencies.
pub fn agent_with_deps<Deps: Send + Sync + 'static, M: Model + 'static>(
    model: M,
) -> AgentBuilder<Deps, String> {
    AgentBuilder::new(model)
}

#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_models::MockModel;

    fn create_mock_model() -> MockModel {
        MockModel::new("test-model")
    }

    #[test]
    fn test_builder_basic() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .name("test-agent")
            .temperature(0.7)
            .build();

        assert_eq!(agent.name(), Some("test-agent"));
        assert_eq!(agent.model_settings().temperature, Some(0.7));
    }

    #[test]
    fn test_builder_with_instructions() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .system_prompt("You are helpful.")
            .instructions("Be concise.")
            .build();

        // Static prompts are now pre-joined at build time
        assert!(agent.static_system_prompt.contains("You are helpful."));
        assert!(agent.static_system_prompt.contains("Be concise."));
    }

    #[test]
    fn test_builder_with_tool() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .tool_fn(
                "greet",
                "Greet someone",
                |_ctx: &RunContext<()>, args: serde_json::Value| {
                    let name = args["name"].as_str().unwrap_or("World");
                    Ok(ToolReturn::text(format!("Hello, {}!", name)))
                },
            )
            .build();

        assert_eq!(agent.tools.len(), 1);
        assert_eq!(agent.tools[0].definition.name, "greet");
    }

    #[test]
    fn test_builder_usage_limits() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .usage_limits(UsageLimits::new().total_tokens(1000).requests(10))
            .build();

        let limits = agent.usage_limits().unwrap();
        assert_eq!(limits.max_total_tokens, Some(1000));
        assert_eq!(limits.max_requests, Some(10));
    }

    #[test]
    fn test_builder_end_strategy() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .end_strategy(EndStrategy::Exhaustive)
            .build();

        assert_eq!(agent.end_strategy, EndStrategy::Exhaustive);
    }

    #[test]
    fn test_agent_convenience() {
        let model = create_mock_model();
        let agent = agent(model).name("quick-agent").build();

        assert_eq!(agent.name(), Some("quick-agent"));
    }

    #[test]
    fn test_builder_parallel_tool_calls_default() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model).build();

        // Default should be true (parallel enabled)
        assert!(agent.parallel_tool_calls());
        assert!(agent.max_concurrent_tools().is_none());
    }

    #[test]
    fn test_builder_parallel_tool_calls_disabled() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .parallel_tool_calls(false)
            .build();

        assert!(!agent.parallel_tool_calls());
    }

    #[test]
    fn test_builder_max_concurrent_tools() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .max_concurrent_tools(4)
            .build();

        assert!(agent.parallel_tool_calls());
        assert_eq!(agent.max_concurrent_tools(), Some(4));
    }

    #[test]
    fn test_builder_parallel_config_preserved_on_output_type() {
        let model = create_mock_model();
        let agent: Agent<(), serde_json::Value> = AgentBuilder::<(), String>::new(model)
            .parallel_tool_calls(false)
            .max_concurrent_tools(2)
            .output_type()
            .build();

        // Config should be preserved when changing output type
        assert!(!agent.parallel_tool_calls());
        assert_eq!(agent.max_concurrent_tools(), Some(2));
    }

    #[test]
    fn test_builder_from_arc() {
        let model = create_mock_model();
        let arc_model: Arc<dyn Model> = Arc::new(model);
        let agent = AgentBuilder::<(), String>::from_arc(arc_model)
            .name("arc-agent")
            .build();

        assert_eq!(agent.name(), Some("arc-agent"));
    }

    #[test]
    fn test_model_config_basic() {
        let config = ModelConfig::new("openai:gpt-4o");
        assert_eq!(config.spec, "openai:gpt-4o");
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
        assert!(config.timeout.is_none());
    }

    #[test]
    fn test_model_config_with_options() {
        let config = ModelConfig::new("anthropic:claude-3-5-sonnet-20241022")
            .with_api_key("sk-test-key")
            .with_base_url("https://custom.api.com")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.spec, "anthropic:claude-3-5-sonnet-20241022");
        assert_eq!(config.api_key, Some("sk-test-key".to_string()));
        assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_model_config_parse_spec_with_provider() {
        let config = ModelConfig::new("openai:gpt-4o");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn test_model_config_parse_spec_without_provider() {
        let config = ModelConfig::new("gpt-4o");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn test_model_config_parse_spec_anthropic() {
        let config = ModelConfig::new("anthropic:claude-3-5-sonnet-20241022");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "anthropic");
        assert_eq!(model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_model_config_unknown_provider() {
        let config = ModelConfig::new("unknown:some-model");
        let result = config.build_model();
        assert!(result.is_err());
        // Can't use unwrap_err because Arc<dyn Model> doesn't impl Debug
        match result {
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("Unknown") || msg.contains("unsupported"),
                    "Expected error about unknown provider, got: {}",
                    msg
                );
            }
            Ok(_) => panic!("Expected error for unknown provider"),
        }
    }
}