Skip to main content

zag_agent/
builder.rs

//! High-level builder API for driving agents programmatically.
//!
//! Instead of shelling out to the `agent` CLI binary, Rust programs can
//! use `AgentBuilder` to configure and execute agent sessions directly.
//!
//! # Examples
//!
//! ```no_run
//! use zag_agent::builder::AgentBuilder;
//!
//! # async fn example() -> anyhow::Result<()> {
//! // Non-interactive exec — returns structured output
//! let output = AgentBuilder::new()
//!     .provider("claude")
//!     .model("sonnet")
//!     .auto_approve(true)
//!     .exec("write a hello world program")
//!     .await?;
//!
//! println!("{}", output.result.unwrap_or_default());
//!
//! // Interactive session
//! AgentBuilder::new()
//!     .provider("claude")
//!     .run(Some("initial prompt"))
//!     .await?;
//! # Ok(())
//! # }
//! ```

31use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44
45/// Builder for configuring and running agent sessions.
46///
47/// Use the builder pattern to set options, then call a terminal method
48/// (`exec`, `run`, `resume`, `continue_last`) to execute.
49pub struct AgentBuilder {
50    provider: Option<String>,
51    model: Option<String>,
52    system_prompt: Option<String>,
53    root: Option<String>,
54    auto_approve: bool,
55    add_dirs: Vec<String>,
56    worktree: Option<Option<String>>,
57    sandbox: Option<Option<String>>,
58    size: Option<String>,
59    json_mode: bool,
60    json_schema: Option<serde_json::Value>,
61    json_stream: bool,
62    session_id: Option<String>,
63    output_format: Option<String>,
64    input_format: Option<String>,
65    replay_user_messages: bool,
66    include_partial_messages: bool,
67    verbose: bool,
68    quiet: bool,
69    show_usage: bool,
70    max_turns: Option<u32>,
71    progress: Box<dyn ProgressHandler>,
72}
73
74impl Default for AgentBuilder {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl AgentBuilder {
81    /// Create a new builder with default settings.
82    pub fn new() -> Self {
83        Self {
84            provider: None,
85            model: None,
86            system_prompt: None,
87            root: None,
88            auto_approve: false,
89            add_dirs: Vec::new(),
90            worktree: None,
91            sandbox: None,
92            size: None,
93            json_mode: false,
94            json_schema: None,
95            json_stream: false,
96            session_id: None,
97            output_format: None,
98            input_format: None,
99            replay_user_messages: false,
100            include_partial_messages: false,
101            verbose: false,
102            quiet: false,
103            show_usage: false,
104            max_turns: None,
105            progress: Box::new(SilentProgress),
106        }
107    }
108
109    /// Set the provider (e.g., "claude", "codex", "gemini", "copilot", "ollama").
110    pub fn provider(mut self, provider: &str) -> Self {
111        self.provider = Some(provider.to_string());
112        self
113    }
114
115    /// Set the model (e.g., "sonnet", "opus", "small", "large").
116    pub fn model(mut self, model: &str) -> Self {
117        self.model = Some(model.to_string());
118        self
119    }
120
121    /// Set a system prompt to configure agent behavior.
122    pub fn system_prompt(mut self, prompt: &str) -> Self {
123        self.system_prompt = Some(prompt.to_string());
124        self
125    }
126
127    /// Set the root directory for the agent to operate in.
128    pub fn root(mut self, root: &str) -> Self {
129        self.root = Some(root.to_string());
130        self
131    }
132
133    /// Enable auto-approve mode (skip permission prompts).
134    pub fn auto_approve(mut self, approve: bool) -> Self {
135        self.auto_approve = approve;
136        self
137    }
138
139    /// Add an additional directory for the agent to include.
140    pub fn add_dir(mut self, dir: &str) -> Self {
141        self.add_dirs.push(dir.to_string());
142        self
143    }
144
145    /// Enable worktree mode with an optional name.
146    pub fn worktree(mut self, name: Option<&str>) -> Self {
147        self.worktree = Some(name.map(String::from));
148        self
149    }
150
151    /// Enable sandbox mode with an optional name.
152    pub fn sandbox(mut self, name: Option<&str>) -> Self {
153        self.sandbox = Some(name.map(String::from));
154        self
155    }
156
157    /// Set the Ollama parameter size (e.g., "2b", "9b", "35b").
158    pub fn size(mut self, size: &str) -> Self {
159        self.size = Some(size.to_string());
160        self
161    }
162
163    /// Request JSON output from the agent.
164    pub fn json(mut self) -> Self {
165        self.json_mode = true;
166        self
167    }
168
169    /// Set a JSON schema for structured output validation.
170    /// Implies `json()`.
171    pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
172        self.json_schema = Some(schema);
173        self.json_mode = true;
174        self
175    }
176
177    /// Enable streaming JSON output (NDJSON format).
178    pub fn json_stream(mut self) -> Self {
179        self.json_stream = true;
180        self
181    }
182
183    /// Set a specific session ID (UUID).
184    pub fn session_id(mut self, id: &str) -> Self {
185        self.session_id = Some(id.to_string());
186        self
187    }
188
189    /// Set the output format (e.g., "text", "json", "json-pretty", "stream-json").
190    pub fn output_format(mut self, format: &str) -> Self {
191        self.output_format = Some(format.to_string());
192        self
193    }
194
195    /// Set the input format (Claude only, e.g., "text", "stream-json").
196    pub fn input_format(mut self, format: &str) -> Self {
197        self.input_format = Some(format.to_string());
198        self
199    }
200
201    /// Re-emit user messages from stdin on stdout (Claude only).
202    ///
203    /// Only works with `--input-format stream-json` and `--output-format stream-json`.
204    pub fn replay_user_messages(mut self, replay: bool) -> Self {
205        self.replay_user_messages = replay;
206        self
207    }
208
209    /// Include partial message chunks in streaming output (Claude only).
210    ///
211    /// Only works with `--output-format stream-json`.
212    pub fn include_partial_messages(mut self, include: bool) -> Self {
213        self.include_partial_messages = include;
214        self
215    }
216
217    /// Enable verbose output.
218    pub fn verbose(mut self, v: bool) -> Self {
219        self.verbose = v;
220        self
221    }
222
223    /// Enable quiet mode (suppress all non-essential output).
224    pub fn quiet(mut self, q: bool) -> Self {
225        self.quiet = q;
226        self
227    }
228
229    /// Show token usage statistics.
230    pub fn show_usage(mut self, show: bool) -> Self {
231        self.show_usage = show;
232        self
233    }
234
235    /// Set the maximum number of agentic turns.
236    pub fn max_turns(mut self, turns: u32) -> Self {
237        self.max_turns = Some(turns);
238        self
239    }
240
241    /// Set a custom progress handler for status reporting.
242    pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
243        self.progress = handler;
244        self
245    }
246
247    /// Resolve the effective provider name.
248    fn resolve_provider(&self) -> Result<String> {
249        if let Some(ref p) = self.provider {
250            let p = p.to_lowercase();
251            if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
252                bail!(
253                    "Invalid provider '{}'. Available: {}",
254                    p,
255                    Config::VALID_PROVIDERS.join(", ")
256                );
257            }
258            return Ok(p);
259        }
260        let config = Config::load(self.root.as_deref()).unwrap_or_default();
261        if let Some(p) = config.provider() {
262            return Ok(p.to_string());
263        }
264        Ok("claude".to_string())
265    }
266
267    /// Create and configure the agent.
268    fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
269        // Apply system_prompt config fallback
270        let base_system_prompt = self.system_prompt.clone().or_else(|| {
271            Config::load(self.root.as_deref())
272                .unwrap_or_default()
273                .system_prompt()
274                .map(String::from)
275        });
276
277        // Augment system prompt with JSON instructions for non-Claude agents
278        let system_prompt = if self.json_mode && provider != "claude" {
279            let mut prompt = base_system_prompt.unwrap_or_default();
280            if let Some(ref schema) = self.json_schema {
281                let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
282                prompt.push_str(&format!(
283                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
284                     Your response must conform to this JSON schema:\n{}",
285                    schema_str
286                ));
287            } else {
288                prompt.push_str(
289                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
290                );
291            }
292            Some(prompt)
293        } else {
294            base_system_prompt
295        };
296
297        self.progress
298            .on_spinner_start(&format!("Initializing {} agent", provider));
299
300        let mut agent = AgentFactory::create(
301            provider,
302            system_prompt,
303            self.model.clone(),
304            self.root.clone(),
305            self.auto_approve,
306            self.add_dirs.clone(),
307        )?;
308
309        // Apply max_turns: explicit > config > none
310        let effective_max_turns = self.max_turns.or_else(|| {
311            Config::load(self.root.as_deref())
312                .unwrap_or_default()
313                .max_turns()
314        });
315        if let Some(turns) = effective_max_turns {
316            agent.set_max_turns(turns);
317        }
318
319        // Set output format
320        let mut output_format = self.output_format.clone();
321        if self.json_mode && output_format.is_none() {
322            output_format = Some("json".to_string());
323            if provider != "claude" {
324                agent.set_capture_output(true);
325            }
326        }
327        if self.json_stream && output_format.is_none() {
328            output_format = Some("stream-json".to_string());
329        }
330        agent.set_output_format(output_format);
331
332        // Configure Claude-specific options
333        if provider == "claude"
334            && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
335        {
336            claude_agent.set_verbose(self.verbose);
337            if let Some(ref session_id) = self.session_id {
338                claude_agent.set_session_id(session_id.clone());
339            }
340            if let Some(ref input_fmt) = self.input_format {
341                claude_agent.set_input_format(Some(input_fmt.clone()));
342            }
343            if self.replay_user_messages {
344                claude_agent.set_replay_user_messages(true);
345            }
346            if self.include_partial_messages {
347                claude_agent.set_include_partial_messages(true);
348            }
349            if self.json_mode
350                && let Some(ref schema) = self.json_schema
351            {
352                let schema_str = serde_json::to_string(schema).unwrap_or_default();
353                claude_agent.set_json_schema(Some(schema_str));
354            }
355        }
356
357        // Configure Ollama-specific options
358        if provider == "ollama"
359            && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
360        {
361            let config = Config::load(self.root.as_deref()).unwrap_or_default();
362            if let Some(ref size) = self.size {
363                let resolved = config.ollama_size_for(size);
364                ollama_agent.set_size(resolved.to_string());
365            }
366        }
367
368        // Configure sandbox
369        if let Some(ref sandbox_opt) = self.sandbox {
370            let sandbox_name = sandbox_opt
371                .as_deref()
372                .map(String::from)
373                .unwrap_or_else(crate::sandbox::generate_name);
374            let template = crate::sandbox::template_for_provider(provider);
375            let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
376            agent.set_sandbox(SandboxConfig {
377                name: sandbox_name,
378                template: template.to_string(),
379                workspace,
380            });
381        }
382
383        self.progress.on_spinner_finish();
384        self.progress.on_success(&format!(
385            "{} initialized with model {}",
386            provider,
387            agent.get_model()
388        ));
389
390        Ok(agent)
391    }
392
393    /// Run the agent non-interactively and return structured output.
394    ///
395    /// This is the primary entry point for programmatic use.
396    pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
397        let provider = self.resolve_provider()?;
398        debug!("exec: provider={}", provider);
399
400        // Set up worktree if requested
401        let effective_root = if let Some(ref wt_opt) = self.worktree {
402            let wt_name = wt_opt
403                .as_deref()
404                .map(String::from)
405                .unwrap_or_else(worktree::generate_name);
406            let repo_root = worktree::git_repo_root(self.root.as_deref())?;
407            let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
408            self.progress
409                .on_success(&format!("Worktree created at {}", wt_path.display()));
410            Some(wt_path.to_string_lossy().to_string())
411        } else {
412            self.root.clone()
413        };
414
415        let mut builder = self;
416        if effective_root.is_some() {
417            builder.root = effective_root;
418        }
419
420        let agent = builder.create_agent(&provider)?;
421
422        // Handle JSON mode with prompt wrapping for non-Claude agents
423        let effective_prompt = if builder.json_mode && provider != "claude" {
424            let wrapped = format!(
425                "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
426                prompt
427            );
428            wrapped
429        } else {
430            prompt.to_string()
431        };
432
433        let result = agent.run(Some(&effective_prompt)).await?;
434
435        // Clean up
436        agent.cleanup().await?;
437
438        if let Some(output) = result {
439            // Validate JSON output if schema is provided
440            if let Some(ref schema) = builder.json_schema {
441                if !builder.json_mode {
442                    warn!(
443                        "json_schema is set but json_mode is false — \
444                         schema will not be sent to the agent, only used for output validation"
445                    );
446                }
447                if let Some(ref result_text) = output.result {
448                    debug!(
449                        "exec: validating result ({} bytes): {:.300}",
450                        result_text.len(),
451                        result_text
452                    );
453                    if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
454                    {
455                        let preview = if result_text.len() > 500 {
456                            &result_text[..500]
457                        } else {
458                            result_text.as_str()
459                        };
460                        bail!(
461                            "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
462                            errors.join("; "),
463                            result_text.len(),
464                            preview
465                        );
466                    }
467                }
468            }
469            Ok(output)
470        } else {
471            // Agent returned no structured output — create a minimal one
472            Ok(AgentOutput::from_text(&provider, ""))
473        }
474    }
475
476    /// Run the agent with streaming input and output (Claude only).
477    ///
478    /// Returns a `StreamingSession` that allows sending NDJSON messages to
479    /// the agent's stdin and reading events from stdout. Automatically
480    /// configures `--input-format stream-json` and `--replay-user-messages`.
481    ///
482    /// # Examples
483    ///
484    /// ```no_run
485    /// use zag_agent::builder::AgentBuilder;
486    ///
487    /// # async fn example() -> anyhow::Result<()> {
488    /// let mut session = AgentBuilder::new()
489    ///     .provider("claude")
490    ///     .exec_streaming("initial prompt")
491    ///     .await?;
492    ///
493    /// session.send_user_message("do something").await?;
494    ///
495    /// while let Some(event) = session.next_event().await? {
496    ///     println!("{:?}", event);
497    /// }
498    ///
499    /// session.wait().await?;
500    /// # Ok(())
501    /// # }
502    /// ```
503    pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
504        let provider = self.resolve_provider()?;
505        debug!("exec_streaming: provider={}", provider);
506
507        if provider != "claude" {
508            bail!("Streaming input is only supported by the Claude provider");
509        }
510
511        let agent = self.create_agent(&provider)?;
512
513        // Downcast to Claude to call execute_streaming
514        let claude_agent = agent
515            .as_any_ref()
516            .downcast_ref::<Claude>()
517            .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
518
519        claude_agent.execute_streaming(Some(prompt))
520    }
521
522    /// Start an interactive agent session.
523    ///
524    /// This takes over stdin/stdout for the duration of the session.
525    pub async fn run(self, prompt: Option<&str>) -> Result<()> {
526        let provider = self.resolve_provider()?;
527        debug!("run: provider={}", provider);
528
529        let agent = self.create_agent(&provider)?;
530        agent.run_interactive(prompt).await?;
531        agent.cleanup().await?;
532        Ok(())
533    }
534
535    /// Resume a previous session by ID.
536    pub async fn resume(self, session_id: &str) -> Result<()> {
537        let provider = self.resolve_provider()?;
538        debug!("resume: provider={}, session={}", provider, session_id);
539
540        let agent = self.create_agent(&provider)?;
541        agent.run_resume(Some(session_id), false).await?;
542        agent.cleanup().await?;
543        Ok(())
544    }
545
546    /// Resume the most recent session.
547    pub async fn continue_last(self) -> Result<()> {
548        let provider = self.resolve_provider()?;
549        debug!("continue_last: provider={}", provider);
550
551        let agent = self.create_agent(&provider)?;
552        agent.run_resume(None, true).await?;
553        agent.cleanup().await?;
554        Ok(())
555    }
556}
557
558#[cfg(test)]
559#[path = "builder_tests.rs"]
560mod tests;