Skip to main content

zag_agent/
capability.rs

1use anyhow::{Result, bail};
2use serde::{Deserialize, Serialize};
3
4/// A feature that can be either natively supported by the provider or implemented by the wrapper.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct FeatureSupport {
7    pub supported: bool,
8    pub native: bool,
9}
10
11/// Session log support with completeness level.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SessionLogSupport {
14    pub supported: bool,
15    pub native: bool,
16    /// Completeness level: "full", "partial", or absent when unsupported.
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub completeness: Option<String>,
19}
20
21/// Streaming input support with mid-turn injection semantics.
22///
23/// Describes what happens when `StreamingSession::send_user_message` is called
24/// while the agent is already producing a response on the current turn.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct StreamingInputSupport {
27    pub supported: bool,
28    pub native: bool,
29    /// Mid-turn semantics when `send_user_message` is called while the agent
30    /// is already producing a response. One of:
31    /// - `"queue"` — message is buffered and delivered at the next turn boundary
32    ///   (the current turn runs to completion before the new message is processed).
33    /// - `"interrupt"` — message cancels the current turn and starts a new one
34    ///   with the new input.
35    /// - `"between-turns-only"` — calling mid-turn is an error or no-op; callers
36    ///   must wait for the current turn to finish before sending.
37    ///
38    /// Absent when `supported == false`.
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub semantics: Option<String>,
41}
42
43/// Size alias mappings for a provider.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SizeMappings {
46    pub small: String,
47    pub medium: String,
48    pub large: String,
49}
50
51/// All feature flags for a provider.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Features {
54    pub interactive: FeatureSupport,
55    pub non_interactive: FeatureSupport,
56    pub resume: FeatureSupport,
57    pub resume_with_prompt: FeatureSupport,
58    pub session_logs: SessionLogSupport,
59    pub json_output: FeatureSupport,
60    pub stream_json: FeatureSupport,
61    pub json_schema: FeatureSupport,
62    pub input_format: FeatureSupport,
63    pub streaming_input: StreamingInputSupport,
64    pub worktree: FeatureSupport,
65    pub sandbox: FeatureSupport,
66    pub system_prompt: FeatureSupport,
67    pub auto_approve: FeatureSupport,
68    pub review: FeatureSupport,
69    pub add_dirs: FeatureSupport,
70    pub max_turns: FeatureSupport,
71}
72
73/// Full capability declaration for a provider.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ProviderCapability {
76    pub provider: String,
77    pub default_model: String,
78    pub available_models: Vec<String>,
79    pub size_mappings: SizeMappings,
80    pub features: Features,
81}
82
83impl FeatureSupport {
84    pub fn native() -> Self {
85        Self {
86            supported: true,
87            native: true,
88        }
89    }
90
91    pub fn wrapper() -> Self {
92        Self {
93            supported: true,
94            native: false,
95        }
96    }
97
98    pub fn unsupported() -> Self {
99        Self {
100            supported: false,
101            native: false,
102        }
103    }
104}
105
106impl SessionLogSupport {
107    pub fn full() -> Self {
108        Self {
109            supported: true,
110            native: true,
111            completeness: Some("full".to_string()),
112        }
113    }
114
115    pub fn partial() -> Self {
116        Self {
117            supported: true,
118            native: true,
119            completeness: Some("partial".to_string()),
120        }
121    }
122
123    pub fn unsupported() -> Self {
124        Self {
125            supported: false,
126            native: false,
127            completeness: None,
128        }
129    }
130}
131
132impl StreamingInputSupport {
133    /// Mid-turn messages are queued and delivered at the next turn boundary.
134    /// The currently running turn is not interrupted.
135    pub fn queue() -> Self {
136        Self {
137            supported: true,
138            native: true,
139            semantics: Some("queue".to_string()),
140        }
141    }
142
143    /// Mid-turn messages cancel the current turn and start a new one.
144    pub fn interrupt() -> Self {
145        Self {
146            supported: true,
147            native: true,
148            semantics: Some("interrupt".to_string()),
149        }
150    }
151
152    /// Messages may only be sent between turns; mid-turn sends are an error.
153    pub fn between_turns_only() -> Self {
154        Self {
155            supported: true,
156            native: true,
157            semantics: Some("between-turns-only".to_string()),
158        }
159    }
160
161    /// The provider does not support streaming input at all.
162    pub fn unsupported() -> Self {
163        Self {
164            supported: false,
165            native: false,
166            semantics: None,
167        }
168    }
169}
170
171/// Get capability declarations for a provider.
172pub fn get_capability(provider: &str) -> Result<ProviderCapability> {
173    use crate::agent::{Agent, ModelSize};
174
175    match provider {
176        "claude" => {
177            use crate::providers::claude::{self, Claude};
178            Ok(ProviderCapability {
179                provider: "claude".to_string(),
180                default_model: claude::DEFAULT_MODEL.to_string(),
181                available_models: models_to_vec(claude::AVAILABLE_MODELS),
182                size_mappings: SizeMappings {
183                    small: Claude::model_for_size(ModelSize::Small).to_string(),
184                    medium: Claude::model_for_size(ModelSize::Medium).to_string(),
185                    large: Claude::model_for_size(ModelSize::Large).to_string(),
186                },
187                features: Features {
188                    interactive: FeatureSupport::native(),
189                    non_interactive: FeatureSupport::native(),
190                    resume: FeatureSupport::native(),
191                    resume_with_prompt: FeatureSupport::native(),
192                    session_logs: SessionLogSupport::full(),
193                    json_output: FeatureSupport::native(),
194                    stream_json: FeatureSupport::native(),
195                    json_schema: FeatureSupport::native(),
196                    input_format: FeatureSupport::native(),
197                    streaming_input: StreamingInputSupport::queue(),
198                    worktree: FeatureSupport::wrapper(),
199                    sandbox: FeatureSupport::wrapper(),
200                    system_prompt: FeatureSupport::native(),
201                    auto_approve: FeatureSupport::native(),
202                    review: FeatureSupport::unsupported(),
203                    add_dirs: FeatureSupport::native(),
204                    max_turns: FeatureSupport::native(),
205                },
206            })
207        }
208        "codex" => {
209            use crate::providers::codex::{self, Codex};
210            Ok(ProviderCapability {
211                provider: "codex".to_string(),
212                default_model: codex::DEFAULT_MODEL.to_string(),
213                available_models: models_to_vec(codex::AVAILABLE_MODELS),
214                size_mappings: SizeMappings {
215                    small: Codex::model_for_size(ModelSize::Small).to_string(),
216                    medium: Codex::model_for_size(ModelSize::Medium).to_string(),
217                    large: Codex::model_for_size(ModelSize::Large).to_string(),
218                },
219                features: Features {
220                    interactive: FeatureSupport::native(),
221                    non_interactive: FeatureSupport::native(),
222                    resume: FeatureSupport::native(),
223                    resume_with_prompt: FeatureSupport::native(),
224                    session_logs: SessionLogSupport::partial(),
225                    json_output: FeatureSupport::native(),
226                    stream_json: FeatureSupport::unsupported(),
227                    json_schema: FeatureSupport::wrapper(),
228                    input_format: FeatureSupport::unsupported(),
229                    streaming_input: StreamingInputSupport::unsupported(),
230                    worktree: FeatureSupport::wrapper(),
231                    sandbox: FeatureSupport::wrapper(),
232                    system_prompt: FeatureSupport::wrapper(),
233                    auto_approve: FeatureSupport::native(),
234                    review: FeatureSupport::native(),
235                    add_dirs: FeatureSupport::native(),
236                    max_turns: FeatureSupport::native(),
237                },
238            })
239        }
240        "gemini" => {
241            use crate::providers::gemini::{self, Gemini};
242            Ok(ProviderCapability {
243                provider: "gemini".to_string(),
244                default_model: gemini::DEFAULT_MODEL.to_string(),
245                available_models: models_to_vec(gemini::AVAILABLE_MODELS),
246                size_mappings: SizeMappings {
247                    small: Gemini::model_for_size(ModelSize::Small).to_string(),
248                    medium: Gemini::model_for_size(ModelSize::Medium).to_string(),
249                    large: Gemini::model_for_size(ModelSize::Large).to_string(),
250                },
251                features: Features {
252                    interactive: FeatureSupport::native(),
253                    non_interactive: FeatureSupport::native(),
254                    resume: FeatureSupport::native(),
255                    resume_with_prompt: FeatureSupport::unsupported(),
256                    session_logs: SessionLogSupport::full(),
257                    json_output: FeatureSupport::wrapper(),
258                    stream_json: FeatureSupport::unsupported(),
259                    json_schema: FeatureSupport::wrapper(),
260                    input_format: FeatureSupport::unsupported(),
261                    streaming_input: StreamingInputSupport::unsupported(),
262                    worktree: FeatureSupport::wrapper(),
263                    sandbox: FeatureSupport::wrapper(),
264                    system_prompt: FeatureSupport::wrapper(),
265                    auto_approve: FeatureSupport::native(),
266                    review: FeatureSupport::unsupported(),
267                    add_dirs: FeatureSupport::native(),
268                    max_turns: FeatureSupport::native(),
269                },
270            })
271        }
272        "copilot" => {
273            use crate::providers::copilot::{self, Copilot};
274            Ok(ProviderCapability {
275                provider: "copilot".to_string(),
276                default_model: copilot::DEFAULT_MODEL.to_string(),
277                available_models: models_to_vec(copilot::AVAILABLE_MODELS),
278                size_mappings: SizeMappings {
279                    small: Copilot::model_for_size(ModelSize::Small).to_string(),
280                    medium: Copilot::model_for_size(ModelSize::Medium).to_string(),
281                    large: Copilot::model_for_size(ModelSize::Large).to_string(),
282                },
283                features: Features {
284                    interactive: FeatureSupport::native(),
285                    non_interactive: FeatureSupport::native(),
286                    resume: FeatureSupport::native(),
287                    resume_with_prompt: FeatureSupport::unsupported(),
288                    session_logs: SessionLogSupport::full(),
289                    json_output: FeatureSupport::unsupported(),
290                    stream_json: FeatureSupport::unsupported(),
291                    json_schema: FeatureSupport::unsupported(),
292                    input_format: FeatureSupport::unsupported(),
293                    streaming_input: StreamingInputSupport::unsupported(),
294                    worktree: FeatureSupport::wrapper(),
295                    sandbox: FeatureSupport::wrapper(),
296                    system_prompt: FeatureSupport::wrapper(),
297                    auto_approve: FeatureSupport::native(),
298                    review: FeatureSupport::unsupported(),
299                    add_dirs: FeatureSupport::native(),
300                    max_turns: FeatureSupport::native(),
301                },
302            })
303        }
304        "ollama" => {
305            use crate::providers::ollama;
306            Ok(ProviderCapability {
307                provider: "ollama".to_string(),
308                default_model: ollama::DEFAULT_MODEL.to_string(),
309                available_models: models_to_vec(ollama::AVAILABLE_SIZES),
310                size_mappings: SizeMappings {
311                    small: "2b".to_string(),
312                    medium: "9b".to_string(),
313                    large: "35b".to_string(),
314                },
315                features: Features {
316                    interactive: FeatureSupport::native(),
317                    non_interactive: FeatureSupport::native(),
318                    resume: FeatureSupport::unsupported(),
319                    resume_with_prompt: FeatureSupport::unsupported(),
320                    session_logs: SessionLogSupport::unsupported(),
321                    json_output: FeatureSupport::wrapper(),
322                    stream_json: FeatureSupport::unsupported(),
323                    json_schema: FeatureSupport::wrapper(),
324                    input_format: FeatureSupport::unsupported(),
325                    streaming_input: StreamingInputSupport::unsupported(),
326                    worktree: FeatureSupport::wrapper(),
327                    sandbox: FeatureSupport::wrapper(),
328                    system_prompt: FeatureSupport::wrapper(),
329                    auto_approve: FeatureSupport::native(),
330                    review: FeatureSupport::unsupported(),
331                    add_dirs: FeatureSupport::unsupported(),
332                    max_turns: FeatureSupport::unsupported(),
333                },
334            })
335        }
336        _ => bail!(
337            "No capabilities defined for provider '{}'. Available: claude, codex, gemini, copilot, ollama",
338            provider
339        ),
340    }
341}
342
343/// Format a capability struct into the requested output format.
344pub fn format_capability(cap: &ProviderCapability, format: &str, pretty: bool) -> Result<String> {
345    match format {
346        "json" => {
347            if pretty {
348                Ok(serde_json::to_string_pretty(cap)?)
349            } else {
350                Ok(serde_json::to_string(cap)?)
351            }
352        }
353        "yaml" => Ok(serde_yaml::to_string(cap)?),
354        "toml" => Ok(toml::to_string_pretty(cap)?),
355        _ => bail!(
356            "Unsupported format '{}'. Available: json, yaml, toml",
357            format
358        ),
359    }
360}
361
/// Canonical list of provider names (excludes "auto" and "mock").
///
/// Its order is the order in which `list_providers` and
/// `get_all_capabilities` report providers.
pub const PROVIDERS: &[&str] = &["claude", "codex", "gemini", "copilot", "ollama"];
364
365/// List all available provider names.
366pub fn list_providers() -> Vec<String> {
367    PROVIDERS.iter().map(|s| s.to_string()).collect()
368}
369
370/// Get capabilities for all providers.
371pub fn get_all_capabilities() -> Vec<ProviderCapability> {
372    PROVIDERS
373        .iter()
374        .filter_map(|p| get_capability(p).ok())
375        .collect()
376}
377
378/// Result of resolving a model alias.
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct ResolvedModel {
381    pub input: String,
382    pub resolved: String,
383    pub is_alias: bool,
384    pub provider: String,
385}
386
387/// Resolve a model name or alias for a given provider.
388///
389/// Size aliases (`small`/`s`, `medium`/`m`/`default`, `large`/`l`/`max`) are
390/// resolved to the provider-specific model. Non-alias names pass through unchanged.
391pub fn resolve_model(provider: &str, model_input: &str) -> Result<ResolvedModel> {
392    use crate::agent::Agent;
393    use crate::providers::{
394        claude::Claude, codex::Codex, copilot::Copilot, gemini::Gemini, ollama::Ollama,
395    };
396
397    let resolved = match provider {
398        "claude" => Claude::resolve_model(model_input),
399        "codex" => Codex::resolve_model(model_input),
400        "gemini" => Gemini::resolve_model(model_input),
401        "copilot" => Copilot::resolve_model(model_input),
402        "ollama" => Ollama::resolve_model(model_input),
403        _ => bail!(
404            "Unknown provider '{}'. Available: {}",
405            provider,
406            PROVIDERS.join(", ")
407        ),
408    };
409
410    Ok(ResolvedModel {
411        input: model_input.to_string(),
412        is_alias: resolved != model_input,
413        resolved,
414        provider: provider.to_string(),
415    })
416}
417
418/// Format a resolved model into the requested output format.
419pub fn format_resolved_model(rm: &ResolvedModel, format: &str, pretty: bool) -> Result<String> {
420    match format {
421        "json" => {
422            if pretty {
423                Ok(serde_json::to_string_pretty(rm)?)
424            } else {
425                Ok(serde_json::to_string(rm)?)
426            }
427        }
428        "yaml" => Ok(serde_yaml::to_string(rm)?),
429        "toml" => Ok(toml::to_string_pretty(rm)?),
430        _ => bail!(
431            "Unsupported format '{}'. Available: json, yaml, toml",
432            format
433        ),
434    }
435}
436
437/// Format a list of capabilities into the requested output format.
438pub fn format_capabilities(
439    caps: &[ProviderCapability],
440    format: &str,
441    pretty: bool,
442) -> Result<String> {
443    match format {
444        "json" => {
445            if pretty {
446                Ok(serde_json::to_string_pretty(caps)?)
447            } else {
448                Ok(serde_json::to_string(caps)?)
449            }
450        }
451        "yaml" => Ok(serde_yaml::to_string(caps)?),
452        "toml" => {
453            #[derive(Serialize)]
454            struct Wrapper<'a> {
455                providers: &'a [ProviderCapability],
456            }
457            Ok(toml::to_string_pretty(&Wrapper { providers: caps })?)
458        }
459        _ => bail!(
460            "Unsupported format '{}'. Available: json, yaml, toml",
461            format
462        ),
463    }
464}
465
466/// Format a models listing into the requested output format.
467pub fn format_models(caps: &[ProviderCapability], format: &str, pretty: bool) -> Result<String> {
468    #[derive(Serialize)]
469    struct ModelEntry {
470        provider: String,
471        default_model: String,
472        models: Vec<String>,
473    }
474
475    let entries: Vec<ModelEntry> = caps
476        .iter()
477        .map(|c| ModelEntry {
478            provider: c.provider.clone(),
479            default_model: c.default_model.clone(),
480            models: c.available_models.clone(),
481        })
482        .collect();
483
484    match format {
485        "json" => {
486            if pretty {
487                Ok(serde_json::to_string_pretty(&entries)?)
488            } else {
489                Ok(serde_json::to_string(&entries)?)
490            }
491        }
492        "yaml" => Ok(serde_yaml::to_string(&entries)?),
493        "toml" => bail!("TOML does not support top-level arrays. Use json or yaml"),
494        _ => bail!(
495            "Unsupported format '{}'. Available: json, yaml, toml",
496            format
497        ),
498    }
499}
500
/// Copy a slice of borrowed model names into owned `String`s, preserving order.
pub fn models_to_vec(models: &[&str]) -> Vec<String> {
    models.iter().copied().map(String::from).collect()
}
505
// Unit tests live in the sibling file `capability_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "capability_tests.rs"]
mod tests;