Skip to main content

defect_agent/llm/
registry.rs

1//! `ProviderRegistry`: catalog of configured providers + their model candidates.
2//!
3//! Exposes the `(provider, model)` candidate list to the ACP layer, and resolves which
4//! real provider should handle the current turn based on the `(vendor, model)` pair. The
5//! registry itself does **not** implement [`LlmProvider`] — it is a read-only directory
6//! assembled at configuration time. The session calls `set_model` / `run_turn` to look up
7//! the corresponding real provider using this pair.
8//!
9//! Design notes:
10//! - Each [`ProviderEntry`] carries an explicit `Vec<ModelInfo>`: during CLI assembly,
11//!   `providers.<p>.default_model` and `providers.<p>.models` are flattened into a model
12//!   table, so that ACP `list_models` does not require a network call to the adapter's
13//!   own `list_models`.
14//! - The selection key is the `(vendor, model id)` pair: the same model id may be
15//!   declared by multiple providers with different vendors (multi-gateway, same model).
16//!   ACP `set_model` switches on this pair.
17//! - Each entry also carries [`SessionCapabilitiesConfig`] — when switching models across
18//!   providers, the session must re-resolve hosted capabilities.
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use super::model::{ModelInfo, ProviderInfo};
24use super::provider::LlmProvider;
25use crate::session::SessionCapabilitiesConfig;
26
27/// A provider, the model IDs it exposes, and its session capability configuration.
28#[derive(Clone)]
29pub struct ProviderEntry {
30    provider: Arc<dyn LlmProvider>,
31    models: Vec<ModelInfo>,
32    capabilities: SessionCapabilitiesConfig,
33}
34
35impl std::fmt::Debug for ProviderEntry {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("ProviderEntry")
38            .field("provider", &self.provider.info())
39            .field("models", &self.models)
40            .field("capabilities", &self.capabilities)
41            .finish()
42    }
43}
44
45impl ProviderEntry {
46    #[must_use]
47    pub fn new(
48        provider: Arc<dyn LlmProvider>,
49        models: Vec<ModelInfo>,
50        capabilities: SessionCapabilitiesConfig,
51    ) -> Self {
52        Self {
53            provider,
54            models,
55            capabilities,
56        }
57    }
58
59    #[must_use]
60    pub fn provider(&self) -> &Arc<dyn LlmProvider> {
61        &self.provider
62    }
63
64    #[must_use]
65    pub fn models(&self) -> &[ModelInfo] {
66        &self.models
67    }
68
69    #[must_use]
70    pub fn capabilities(&self) -> SessionCapabilitiesConfig {
71        self.capabilities
72    }
73}
74
75#[derive(Debug, thiserror::Error)]
76pub enum ProviderRegistryError {
77    #[error("provider registry requires at least one entry")]
78    Empty,
79    #[error(
80        "duplicate model id `{model}` declared twice by provider `{provider}`; \
81         the same (provider, model) pair must be unique within a build"
82    )]
83    DuplicateSelection { provider: String, model: String },
84    #[error(
85        "default model `{model}` is not declared by provider `{provider}`; \
86         add it under that provider, or point `default.provider` at the one that has it"
87    )]
88    UnknownDefaultModel { provider: String, model: String },
89}
90
91/// A "provider directory" that is materialized at assembly time. The session holds an
92/// `Arc<ProviderRegistry>`.
93#[derive(Debug)]
94pub struct ProviderRegistry {
95    entries: Vec<ProviderEntry>,
96    /// (vendor, model id) → entries index. Multiple providers (with different vendors)
97    /// may declare the same model id — the lookup key is the pair (vendor, model), not
98    /// the bare model id.
99    model_index: HashMap<(String, String), usize>,
100    /// Index into `entries` for the default (provider, model), plus index into that
101    /// entry's `models`.
102    default: (usize, usize),
103}
104
105impl ProviderRegistry {
106    /// A convenience constructor for a single provider with a single model.
107    /// Used by tests, `EchoProvider`, and the `provider()` builder entry point.
108    /// This is the minimal form that satisfies the invariants checked by
109    /// `ProviderRegistry::new` (non-empty + `default_model` must belong to an entry).
110    #[must_use]
111    pub fn single(provider: Arc<dyn LlmProvider>, default_model: ModelInfo) -> Arc<Self> {
112        let vendor = provider.info().vendor;
113        let model_id = default_model.id.clone();
114        let entries = vec![ProviderEntry::new(
115            provider,
116            vec![default_model],
117            SessionCapabilitiesConfig::default(),
118        )];
119        Arc::new(
120            Self::new(entries, &vendor, &model_id)
121                .expect("single-entry registry with matching default model is always valid"),
122        )
123    }
124
125    /// Constructs a registry from a list of entries and a default `(provider vendor,
126    /// model id)` pair. The pair must appear in the `(vendor, models)` list of some
127    /// entry.
128    ///
129    /// The same model id may be declared by multiple entries with different vendors
130    /// (multiple gateways sharing a model) — the selection key is `(vendor, model)`. Only
131    /// a duplicate `(vendor, model)` pair is a configuration error.
132    ///
133    /// # Errors
134    ///
135    /// - [`ProviderRegistryError::Empty`]: entries is empty
136    /// - [`ProviderRegistryError::DuplicateSelection`]: the same `(vendor, model)` pair
137    ///   appears twice
138    /// - [`ProviderRegistryError::UnknownDefaultModel`]: the default `(vendor, model)`
139    ///   pair is not present in any entry
140    pub fn new(
141        entries: Vec<ProviderEntry>,
142        default_provider: &str,
143        default_model: &str,
144    ) -> Result<Self, ProviderRegistryError> {
145        if entries.is_empty() {
146            return Err(ProviderRegistryError::Empty);
147        }
148
149        let mut model_index = HashMap::new();
150        let mut default_pos = None;
151        for (entry_idx, entry) in entries.iter().enumerate() {
152            let provider_vendor = entry.provider.info().vendor;
153            let mut seen_in_entry = HashSet::new();
154            for (model_idx, model) in entry.models.iter().enumerate() {
155                if !seen_in_entry.insert(model.id.clone()) {
156                    continue;
157                }
158                let key = (provider_vendor.clone(), model.id.clone());
159                if model_index.insert(key, entry_idx).is_some() {
160                    return Err(ProviderRegistryError::DuplicateSelection {
161                        provider: provider_vendor,
162                        model: model.id.clone(),
163                    });
164                }
165                if provider_vendor == default_provider
166                    && model.id == default_model
167                    && default_pos.is_none()
168                {
169                    default_pos = Some((entry_idx, model_idx));
170                }
171            }
172        }
173
174        let default = default_pos.ok_or_else(|| ProviderRegistryError::UnknownDefaultModel {
175            provider: default_provider.to_string(),
176            model: default_model.to_string(),
177        })?;
178
179        Ok(Self {
180            entries,
181            model_index,
182            default,
183        })
184    }
185
186    /// The default entry used to initialize the current provider/model when a session
187    /// starts.
188    #[must_use]
189    pub fn default_entry(&self) -> &ProviderEntry {
190        let (entry_idx, _) = self.default;
191        self.entries
192            .get(entry_idx)
193            .expect("default index validated in `new`")
194    }
195
196    /// The default model ID.
197    #[must_use]
198    pub fn default_model(&self) -> &str {
199        let (entry_idx, model_idx) = self.default;
200        let entry = self
201            .entries
202            .get(entry_idx)
203            .expect("default index validated in `new`");
204        entry
205            .models
206            .get(model_idx)
207            .map(|m| m.id.as_str())
208            .expect("default model index validated in `new`")
209    }
210
211    /// Look up the entry for a given `(vendor, model id)` pair. Returns `None` if the
212    /// registry does not declare this pair.
213    #[must_use]
214    pub fn entry_for(&self, vendor: &str, model_id: &str) -> Option<&ProviderEntry> {
215        self.model_index
216            .get(&(vendor.to_string(), model_id.to_string()))
217            .and_then(|idx| self.entries.get(*idx))
218    }
219
220    /// Look up the first entry that declares the given bare model ID (in assembly order).
221    /// Used by legacy paths that lack a vendor dimension, such as the `model` field in
222    /// prompt hooks — when there are multiple matches, the first one is returned.
223    #[must_use]
224    pub fn first_entry_for_model(&self, model_id: &str) -> Option<&ProviderEntry> {
225        self.entries
226            .iter()
227            .find(|entry| entry.models.iter().any(|m| m.id == model_id))
228    }
229
230    /// Returns all entries in assembly order.
231    #[must_use]
232    pub fn entries(&self) -> &[ProviderEntry] {
233        &self.entries
234    }
235
236    /// Flatten all (provider_info, model) pairs. ACP `list_models` uses this to build
237    /// `SessionModelState::available_models`.
238    #[must_use]
239    pub fn list_candidates(&self) -> Vec<ModelCandidate> {
240        let mut out = Vec::new();
241        for entry in &self.entries {
242            let info = entry.provider.info();
243            for model in &entry.models {
244                out.push(ModelCandidate {
245                    provider: info.clone(),
246                    model: model.clone(),
247                });
248            }
249        }
250        out
251    }
252
253    /// Look up a candidate by model ID; used by the ACP layer to render the description.
254    #[must_use]
255    pub fn candidate_for(&self, vendor: &str, model_id: &str) -> Option<ModelCandidate> {
256        let entry = self.entry_for(vendor, model_id)?;
257        let model = entry.models.iter().find(|m| m.id == model_id)?.clone();
258        Some(ModelCandidate {
259            provider: entry.provider.info(),
260            model,
261        })
262    }
263}
264
265/// A flattened `(provider, model)` pair — the smallest projection unit of ACP
266/// `list_models`.
267#[derive(Debug, Clone)]
268pub struct ModelCandidate {
269    pub provider: ProviderInfo,
270    pub model: ModelInfo,
271}
272
273#[cfg(test)]
274mod tests;