defect_agent/llm/
registry.rs1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use defect_core::llm::{LlmProvider, ModelInfo, ProviderInfo};
24
25use crate::session::SessionCapabilitiesConfig;
26
27#[derive(Clone)]
29pub struct ProviderEntry {
30 provider: Arc<dyn LlmProvider>,
31 models: Vec<ModelInfo>,
32 capabilities: SessionCapabilitiesConfig,
33}
34
35impl std::fmt::Debug for ProviderEntry {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("ProviderEntry")
38 .field("provider", &self.provider.info())
39 .field("models", &self.models)
40 .field("capabilities", &self.capabilities)
41 .finish()
42 }
43}
44
45impl ProviderEntry {
46 #[must_use]
47 pub fn new(
48 provider: Arc<dyn LlmProvider>,
49 models: Vec<ModelInfo>,
50 capabilities: SessionCapabilitiesConfig,
51 ) -> Self {
52 Self {
53 provider,
54 models,
55 capabilities,
56 }
57 }
58
59 #[must_use]
60 pub fn provider(&self) -> &Arc<dyn LlmProvider> {
61 &self.provider
62 }
63
64 #[must_use]
65 pub fn models(&self) -> &[ModelInfo] {
66 &self.models
67 }
68
69 #[must_use]
70 pub fn capabilities(&self) -> SessionCapabilitiesConfig {
71 self.capabilities
72 }
73}
74
75#[derive(Debug, thiserror::Error)]
76pub enum ProviderRegistryError {
77 #[error("provider registry requires at least one entry")]
78 Empty,
79 #[error(
80 "duplicate model id `{model}` declared twice by provider `{provider}`; \
81 the same (provider, model) pair must be unique within a build"
82 )]
83 DuplicateSelection { provider: String, model: String },
84 #[error(
85 "default model `{model}` is not declared by provider `{provider}`; \
86 add it under that provider, or point `default.provider` at the one that has it"
87 )]
88 UnknownDefaultModel { provider: String, model: String },
89}
90
91#[derive(Debug)]
94pub struct ProviderRegistry {
95 entries: Vec<ProviderEntry>,
96 model_index: HashMap<(String, String), usize>,
100 default: (usize, usize),
103}
104
105impl ProviderRegistry {
106 #[must_use]
111 pub fn single(provider: Arc<dyn LlmProvider>, default_model: ModelInfo) -> Arc<Self> {
112 let vendor = provider.info().vendor;
113 let model_id = default_model.id.clone();
114 let entries = vec![ProviderEntry::new(
115 provider,
116 vec![default_model],
117 SessionCapabilitiesConfig::default(),
118 )];
119 Arc::new(
120 Self::new(entries, &vendor, &model_id)
121 .expect("single-entry registry with matching default model is always valid"),
122 )
123 }
124
125 pub fn new(
141 entries: Vec<ProviderEntry>,
142 default_provider: &str,
143 default_model: &str,
144 ) -> Result<Self, ProviderRegistryError> {
145 if entries.is_empty() {
146 return Err(ProviderRegistryError::Empty);
147 }
148
149 let mut model_index = HashMap::new();
150 let mut default_pos = None;
151 for (entry_idx, entry) in entries.iter().enumerate() {
152 let provider_vendor = entry.provider.info().vendor;
153 let mut seen_in_entry = HashSet::new();
154 for (model_idx, model) in entry.models.iter().enumerate() {
155 if !seen_in_entry.insert(model.id.clone()) {
156 continue;
157 }
158 let key = (provider_vendor.clone(), model.id.clone());
159 if model_index.insert(key, entry_idx).is_some() {
160 return Err(ProviderRegistryError::DuplicateSelection {
161 provider: provider_vendor,
162 model: model.id.clone(),
163 });
164 }
165 if provider_vendor == default_provider
166 && model.id == default_model
167 && default_pos.is_none()
168 {
169 default_pos = Some((entry_idx, model_idx));
170 }
171 }
172 }
173
174 let default = default_pos.ok_or_else(|| ProviderRegistryError::UnknownDefaultModel {
175 provider: default_provider.to_string(),
176 model: default_model.to_string(),
177 })?;
178
179 Ok(Self {
180 entries,
181 model_index,
182 default,
183 })
184 }
185
186 #[must_use]
189 pub fn default_entry(&self) -> &ProviderEntry {
190 let (entry_idx, _) = self.default;
191 self.entries
192 .get(entry_idx)
193 .expect("default index validated in `new`")
194 }
195
196 #[must_use]
198 pub fn default_model(&self) -> &str {
199 let (entry_idx, model_idx) = self.default;
200 let entry = self
201 .entries
202 .get(entry_idx)
203 .expect("default index validated in `new`");
204 entry
205 .models
206 .get(model_idx)
207 .map(|m| m.id.as_str())
208 .expect("default model index validated in `new`")
209 }
210
211 #[must_use]
214 pub fn entry_for(&self, vendor: &str, model_id: &str) -> Option<&ProviderEntry> {
215 self.model_index
216 .get(&(vendor.to_string(), model_id.to_string()))
217 .and_then(|idx| self.entries.get(*idx))
218 }
219
220 #[must_use]
224 pub fn first_entry_for_model(&self, model_id: &str) -> Option<&ProviderEntry> {
225 self.entries
226 .iter()
227 .find(|entry| entry.models.iter().any(|m| m.id == model_id))
228 }
229
230 #[must_use]
232 pub fn entries(&self) -> &[ProviderEntry] {
233 &self.entries
234 }
235
236 #[must_use]
239 pub fn list_candidates(&self) -> Vec<ModelCandidate> {
240 let mut out = Vec::new();
241 for entry in &self.entries {
242 let info = entry.provider.info();
243 for model in &entry.models {
244 out.push(ModelCandidate {
245 provider: info.clone(),
246 model: model.clone(),
247 });
248 }
249 }
250 out
251 }
252
253 #[must_use]
255 pub fn candidate_for(&self, vendor: &str, model_id: &str) -> Option<ModelCandidate> {
256 let entry = self.entry_for(vendor, model_id)?;
257 let model = entry.models.iter().find(|m| m.id == model_id)?.clone();
258 Some(ModelCandidate {
259 provider: entry.provider.info(),
260 model,
261 })
262 }
263}
264
265#[derive(Debug, Clone)]
268pub struct ModelCandidate {
269 pub provider: ProviderInfo,
270 pub model: ModelInfo,
271}
272
273#[cfg(test)]
274mod tests;