1use std::collections::{HashMap, HashSet};
21
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum ProviderFeature {
35 ExplicitCacheBreakpoints,
37 AutomaticPrefixCaching,
39 RetentionTiers,
41 PriorityScheduling,
43 ModelRouting,
45 DeferredToolLoading,
47 FileReferences,
49 StructuredOutput,
51 PrefixAffinityHints,
53 StreamingTokenCounts,
55}
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
63pub struct CacheEconomics {
64 pub write_short_multiplier: f64,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 #[serde(default)]
69 pub write_long_multiplier: Option<f64>,
70 pub read_multiplier: f64,
72}
73
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub struct ModelFamilyCapabilities {
84 pub model_family: String,
86 pub supported_features: HashSet<ProviderFeature>,
88 #[serde(skip_serializing_if = "Option::is_none")]
90 #[serde(default)]
91 pub max_cache_breakpoints: Option<u32>,
92 #[serde(skip_serializing_if = "Option::is_none")]
94 #[serde(default)]
95 pub min_cacheable_tokens: Option<u32>,
96 #[serde(skip_serializing_if = "Option::is_none")]
98 #[serde(default)]
99 pub cache_economics: Option<CacheEconomics>,
100}
101
102impl ModelFamilyCapabilities {
103 pub fn supports(&self, feature: ProviderFeature) -> bool {
105 self.supported_features.contains(&feature)
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
118pub struct BackendCapabilities {
119 pub backend_id: String,
121 pub supported_features: HashSet<ProviderFeature>,
123 pub model_families: HashMap<String, ModelFamilyCapabilities>,
125}
126
127impl BackendCapabilities {
128 pub fn none(backend_id: &str) -> Self {
130 Self {
131 backend_id: backend_id.to_string(),
132 supported_features: HashSet::new(),
133 model_families: HashMap::new(),
134 }
135 }
136
137 pub fn supports(&self, feature: ProviderFeature) -> bool {
139 self.supported_features.contains(&feature)
140 }
141
142 pub fn model_supports(&self, model_family: &str, feature: ProviderFeature) -> bool {
146 if let Some(family_caps) = self.model_families.get(model_family) {
147 family_caps.supports(feature)
148 } else {
149 self.supports(feature)
150 }
151 }
152
153 pub fn add_model_family(&mut self, caps: ModelFamilyCapabilities) {
155 self.model_families.insert(caps.model_family.clone(), caps);
156 }
157}
158
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
168pub struct CapabilityRegistry {
169 backends: HashMap<String, BackendCapabilities>,
170}
171
172impl CapabilityRegistry {
173 pub fn new() -> Self {
175 Self {
176 backends: HashMap::new(),
177 }
178 }
179
180 pub fn with_defaults() -> Self {
182 let mut registry = Self::new();
183
184 let anthropic_features: HashSet<ProviderFeature> = [
188 ProviderFeature::ExplicitCacheBreakpoints,
189 ProviderFeature::RetentionTiers,
190 ProviderFeature::StreamingTokenCounts,
191 ]
192 .into_iter()
193 .collect();
194
195 let mut anthropic = BackendCapabilities {
196 backend_id: "anthropic".to_string(),
197 supported_features: anthropic_features.clone(),
198 model_families: HashMap::new(),
199 };
200
201 anthropic.add_model_family(ModelFamilyCapabilities {
203 model_family: "claude-opus-4.6".to_string(),
204 supported_features: anthropic_features.clone(),
205 max_cache_breakpoints: Some(4),
206 min_cacheable_tokens: Some(4096),
207 cache_economics: Some(CacheEconomics {
208 write_short_multiplier: 1.25,
209 write_long_multiplier: Some(2.0),
210 read_multiplier: 0.1,
211 }),
212 });
213
214 anthropic.add_model_family(ModelFamilyCapabilities {
215 model_family: "claude-opus-4.5".to_string(),
216 supported_features: anthropic_features.clone(),
217 max_cache_breakpoints: Some(4),
218 min_cacheable_tokens: Some(4096),
219 cache_economics: Some(CacheEconomics {
220 write_short_multiplier: 1.25,
221 write_long_multiplier: Some(2.0),
222 read_multiplier: 0.1,
223 }),
224 });
225
226 anthropic.add_model_family(ModelFamilyCapabilities {
227 model_family: "claude-opus-4.1".to_string(),
228 supported_features: anthropic_features.clone(),
229 max_cache_breakpoints: Some(4),
230 min_cacheable_tokens: Some(1024),
231 cache_economics: Some(CacheEconomics {
232 write_short_multiplier: 1.25,
233 write_long_multiplier: Some(2.0),
234 read_multiplier: 0.1,
235 }),
236 });
237
238 anthropic.add_model_family(ModelFamilyCapabilities {
239 model_family: "claude-opus-4".to_string(),
240 supported_features: anthropic_features.clone(),
241 max_cache_breakpoints: Some(4),
242 min_cacheable_tokens: Some(1024),
243 cache_economics: Some(CacheEconomics {
244 write_short_multiplier: 1.25,
245 write_long_multiplier: Some(2.0),
246 read_multiplier: 0.1,
247 }),
248 });
249
250 anthropic.add_model_family(ModelFamilyCapabilities {
251 model_family: "claude-sonnet-4.6".to_string(),
252 supported_features: anthropic_features.clone(),
253 max_cache_breakpoints: Some(4),
254 min_cacheable_tokens: Some(2048),
255 cache_economics: Some(CacheEconomics {
256 write_short_multiplier: 1.25,
257 write_long_multiplier: Some(2.0),
258 read_multiplier: 0.1,
259 }),
260 });
261
262 anthropic.add_model_family(ModelFamilyCapabilities {
263 model_family: "claude-sonnet-4.5".to_string(),
264 supported_features: anthropic_features.clone(),
265 max_cache_breakpoints: Some(4),
266 min_cacheable_tokens: Some(1024),
267 cache_economics: Some(CacheEconomics {
268 write_short_multiplier: 1.25,
269 write_long_multiplier: Some(2.0),
270 read_multiplier: 0.1,
271 }),
272 });
273
274 anthropic.add_model_family(ModelFamilyCapabilities {
275 model_family: "claude-sonnet-4".to_string(),
276 supported_features: anthropic_features.clone(),
277 max_cache_breakpoints: Some(4),
278 min_cacheable_tokens: Some(1024),
279 cache_economics: Some(CacheEconomics {
280 write_short_multiplier: 1.25,
281 write_long_multiplier: Some(2.0),
282 read_multiplier: 0.1,
283 }),
284 });
285
286 anthropic.add_model_family(ModelFamilyCapabilities {
287 model_family: "claude-haiku-4.5".to_string(),
288 supported_features: anthropic_features.clone(),
289 max_cache_breakpoints: Some(4),
290 min_cacheable_tokens: Some(4096),
291 cache_economics: Some(CacheEconomics {
292 write_short_multiplier: 1.25,
293 write_long_multiplier: Some(2.0),
294 read_multiplier: 0.1,
295 }),
296 });
297
298 anthropic.add_model_family(ModelFamilyCapabilities {
299 model_family: "claude-haiku-3.5".to_string(),
300 supported_features: anthropic_features.clone(),
301 max_cache_breakpoints: Some(4),
302 min_cacheable_tokens: Some(2048),
303 cache_economics: Some(CacheEconomics {
304 write_short_multiplier: 1.25,
305 write_long_multiplier: Some(2.0),
306 read_multiplier: 0.1,
307 }),
308 });
309
310 anthropic.add_model_family(ModelFamilyCapabilities {
312 model_family: "claude-3.5-sonnet".to_string(),
313 supported_features: anthropic_features.clone(),
314 max_cache_breakpoints: Some(4),
315 min_cacheable_tokens: Some(1024),
316 cache_economics: Some(CacheEconomics {
317 write_short_multiplier: 1.25,
318 write_long_multiplier: Some(2.0),
319 read_multiplier: 0.1,
320 }),
321 });
322
323 anthropic.add_model_family(ModelFamilyCapabilities {
324 model_family: "claude-3-opus".to_string(),
325 supported_features: anthropic_features.clone(),
326 max_cache_breakpoints: Some(4),
327 min_cacheable_tokens: Some(2048),
328 cache_economics: Some(CacheEconomics {
329 write_short_multiplier: 1.25,
330 write_long_multiplier: Some(2.0),
331 read_multiplier: 0.1,
332 }),
333 });
334
335 anthropic.add_model_family(ModelFamilyCapabilities {
336 model_family: "claude-3-haiku".to_string(),
337 supported_features: anthropic_features,
338 max_cache_breakpoints: Some(4),
339 min_cacheable_tokens: Some(1024),
340 cache_economics: Some(CacheEconomics {
341 write_short_multiplier: 1.25,
342 write_long_multiplier: Some(2.0),
343 read_multiplier: 0.1,
344 }),
345 });
346
347 registry.register_backend(anthropic);
348
349 let openai_features: HashSet<ProviderFeature> = [
353 ProviderFeature::AutomaticPrefixCaching,
354 ProviderFeature::StreamingTokenCounts,
355 ProviderFeature::StructuredOutput,
356 ]
357 .into_iter()
358 .collect();
359
360 let mut openai = BackendCapabilities {
361 backend_id: "openai".to_string(),
362 supported_features: openai_features.clone(),
363 model_families: HashMap::new(),
364 };
365
366 openai.add_model_family(ModelFamilyCapabilities {
367 model_family: "gpt-4o".to_string(),
368 supported_features: openai_features.clone(),
369 max_cache_breakpoints: None,
370 min_cacheable_tokens: None,
371 cache_economics: None,
372 });
373
374 openai.add_model_family(ModelFamilyCapabilities {
375 model_family: "gpt-4o-mini".to_string(),
376 supported_features: openai_features,
377 max_cache_breakpoints: None,
378 min_cacheable_tokens: None,
379 cache_economics: None,
380 });
381
382 let o1_features: HashSet<ProviderFeature> = [ProviderFeature::StreamingTokenCounts]
384 .into_iter()
385 .collect();
386
387 openai.add_model_family(ModelFamilyCapabilities {
388 model_family: "o1".to_string(),
389 supported_features: o1_features,
390 max_cache_breakpoints: None,
391 min_cacheable_tokens: None,
392 cache_economics: None,
393 });
394
395 registry.register_backend(openai);
396
397 registry
398 }
399
400 pub fn register_backend(&mut self, caps: BackendCapabilities) {
402 self.backends.insert(caps.backend_id.clone(), caps);
403 }
404
405 pub fn get_backend(&self, backend_id: &str) -> Option<&BackendCapabilities> {
407 self.backends.get(backend_id)
408 }
409
410 pub fn supports_feature(&self, backend_id: &str, feature: ProviderFeature) -> bool {
412 self.backends
413 .get(backend_id)
414 .is_some_and(|b| b.supports(feature))
415 }
416
417 pub fn model_supports_feature(
421 &self,
422 backend_id: &str,
423 model_family: &str,
424 feature: ProviderFeature,
425 ) -> bool {
426 self.backends
427 .get(backend_id)
428 .is_some_and(|b| b.model_supports(model_family, feature))
429 }
430
431 pub fn list_backend_ids(&self) -> Vec<String> {
433 let mut ids: Vec<String> = self.backends.keys().cloned().collect();
434 ids.sort();
435 ids
436 }
437}
438
// Unit tests live in an external file, pulled in via `#[path]` (resolved
// relative to this source file's directory) and compiled only for test builds.
#[cfg(test)]
#[path = "../../tests/unit/acg/capability_tests.rs"]
mod tests;