Skip to main content

nemo_flow_adaptive/acg/
capability.rs

// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Capability registry tracking per-backend and per-model-family supported features.
//!
//! The registry provides feature discovery so the policy engine knows which
//! optimization intents can be expressed on which backend/model combinations.
//!
//! # Two-Level Feature Lookup
//!
//! [`BackendCapabilities`] stores backend-level defaults plus per-model-family
//! overrides via [`ModelFamilyCapabilities`]. Feature lookups check the model
//! family first and fall back to the backend-level set only when the family is
//! not registered; a registered family's feature set replaces (rather than
//! extends) the backend defaults.
//!
//! # Built-in Defaults
//!
//! [`CapabilityRegistry::with_defaults()`] returns a registry pre-populated
//! with known Anthropic and OpenAI capabilities.
19
20use std::collections::{HashMap, HashSet};
21
22use serde::{Deserialize, Serialize};
23
24// ===================================================================
25// ProviderFeature enum
26// ===================================================================
27
28/// Feature that a backend or model family may support.
29///
30/// Used by the capability registry and policy engine to determine
31/// which optimization intents can be expressed for a given target.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum ProviderFeature {
35    /// Backend supports explicit cache control breakpoints (e.g., Anthropic).
36    ExplicitCacheBreakpoints,
37    /// Backend uses automatic prefix caching (e.g., OpenAI).
38    AutomaticPrefixCaching,
39    /// Backend supports retention tier control.
40    RetentionTiers,
41    /// Backend supports priority-based scheduling.
42    PriorityScheduling,
43    /// Backend supports model routing/selection.
44    ModelRouting,
45    /// Backend supports deferred tool loading.
46    DeferredToolLoading,
47    /// Backend supports file/artifact references in prompts.
48    FileReferences,
49    /// Backend supports structured output schemas.
50    StructuredOutput,
51    /// Backend supports prefix-affinity routing hints.
52    PrefixAffinityHints,
53    /// Backend reports per-chunk token counts in streaming responses.
54    StreamingTokenCounts,
55}
56
57/// Provider/model-specific cache economics used by the internal planner.
58///
59/// These values are kept on the capability surface so the core planner can
60/// stay provider-agnostic while concrete plugins/model families supply the
61/// pricing model that makes a cache write profitable.
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
63pub struct CacheEconomics {
64    /// Input cost multiplier for creating a short-lived cache entry.
65    pub write_short_multiplier: f64,
66    /// Optional input cost multiplier for creating a longer-lived cache entry.
67    #[serde(skip_serializing_if = "Option::is_none")]
68    #[serde(default)]
69    pub write_long_multiplier: Option<f64>,
70    /// Input cost multiplier for reading from cache.
71    pub read_multiplier: f64,
72}
73
74// ===================================================================
75// ModelFamilyCapabilities
76// ===================================================================
77
78/// Per-model-family capability overrides within a backend.
79///
80/// Some features vary by model within the same backend (e.g., Claude 3.5
81/// Sonnet supports 4 cache breakpoints while older models support fewer).
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub struct ModelFamilyCapabilities {
84    /// Model family identifier (e.g., "claude-3.5-sonnet", "gpt-4o").
85    pub model_family: String,
86    /// Features supported by this model family.
87    pub supported_features: HashSet<ProviderFeature>,
88    /// Maximum number of cache breakpoints (if applicable).
89    #[serde(skip_serializing_if = "Option::is_none")]
90    #[serde(default)]
91    pub max_cache_breakpoints: Option<u32>,
92    /// Minimum tokens required for a block to be cacheable.
93    #[serde(skip_serializing_if = "Option::is_none")]
94    #[serde(default)]
95    pub min_cacheable_tokens: Option<u32>,
96    /// Provider/model-specific cache economics for explicit cache planning.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    #[serde(default)]
99    pub cache_economics: Option<CacheEconomics>,
100}
101
102impl ModelFamilyCapabilities {
103    /// Check if this model family supports a specific feature.
104    pub fn supports(&self, feature: ProviderFeature) -> bool {
105        self.supported_features.contains(&feature)
106    }
107}
108
109// ===================================================================
110// BackendCapabilities
111// ===================================================================
112
113/// Capabilities of a specific backend provider.
114///
115/// Two-level model: backend-level defaults plus per-model-family overrides.
116/// Feature lookup checks model-family first, falls back to backend-level.
117#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
118pub struct BackendCapabilities {
119    /// Backend identifier (e.g., "anthropic", "openai", "passthrough").
120    pub backend_id: String,
121    /// Backend-level supported features (default for all models).
122    pub supported_features: HashSet<ProviderFeature>,
123    /// Per-model-family capability overrides.
124    pub model_families: HashMap<String, ModelFamilyCapabilities>,
125}
126
127impl BackendCapabilities {
128    /// Create capabilities with no features (used by passthrough plugin).
129    pub fn none(backend_id: &str) -> Self {
130        Self {
131            backend_id: backend_id.to_string(),
132            supported_features: HashSet::new(),
133            model_families: HashMap::new(),
134        }
135    }
136
137    /// Check if the backend supports a feature at the backend level.
138    pub fn supports(&self, feature: ProviderFeature) -> bool {
139        self.supported_features.contains(&feature)
140    }
141
142    /// Check if a specific model family supports a feature.
143    ///
144    /// Falls back to backend-level if the model family is not registered.
145    pub fn model_supports(&self, model_family: &str, feature: ProviderFeature) -> bool {
146        if let Some(family_caps) = self.model_families.get(model_family) {
147            family_caps.supports(feature)
148        } else {
149            self.supports(feature)
150        }
151    }
152
153    /// Add a model family capability override.
154    pub fn add_model_family(&mut self, caps: ModelFamilyCapabilities) {
155        self.model_families.insert(caps.model_family.clone(), caps);
156    }
157}
158
159// ===================================================================
160// CapabilityRegistry
161// ===================================================================
162
163/// Registry holding capabilities for all known backends.
164///
165/// Provides feature discovery so the policy engine and validation
166/// framework know which intents can be expressed on which targets.
167#[derive(Debug, Clone, Default, Serialize, Deserialize)]
168pub struct CapabilityRegistry {
169    backends: HashMap<String, BackendCapabilities>,
170}
171
172impl CapabilityRegistry {
173    /// Create a new empty capability registry.
174    pub fn new() -> Self {
175        Self {
176            backends: HashMap::new(),
177        }
178    }
179
180    /// Create a registry pre-populated with known Anthropic and OpenAI capabilities.
181    pub fn with_defaults() -> Self {
182        let mut registry = Self::new();
183
184        // -----------------------------------------------------------------
185        // Anthropic backend
186        // -----------------------------------------------------------------
187        let anthropic_features: HashSet<ProviderFeature> = [
188            ProviderFeature::ExplicitCacheBreakpoints,
189            ProviderFeature::RetentionTiers,
190            ProviderFeature::StreamingTokenCounts,
191        ]
192        .into_iter()
193        .collect();
194
195        let mut anthropic = BackendCapabilities {
196            backend_id: "anthropic".to_string(),
197            supported_features: anthropic_features.clone(),
198            model_families: HashMap::new(),
199        };
200
201        // Current model families (2026)
202        anthropic.add_model_family(ModelFamilyCapabilities {
203            model_family: "claude-opus-4.6".to_string(),
204            supported_features: anthropic_features.clone(),
205            max_cache_breakpoints: Some(4),
206            min_cacheable_tokens: Some(4096),
207            cache_economics: Some(CacheEconomics {
208                write_short_multiplier: 1.25,
209                write_long_multiplier: Some(2.0),
210                read_multiplier: 0.1,
211            }),
212        });
213
214        anthropic.add_model_family(ModelFamilyCapabilities {
215            model_family: "claude-opus-4.5".to_string(),
216            supported_features: anthropic_features.clone(),
217            max_cache_breakpoints: Some(4),
218            min_cacheable_tokens: Some(4096),
219            cache_economics: Some(CacheEconomics {
220                write_short_multiplier: 1.25,
221                write_long_multiplier: Some(2.0),
222                read_multiplier: 0.1,
223            }),
224        });
225
226        anthropic.add_model_family(ModelFamilyCapabilities {
227            model_family: "claude-opus-4.1".to_string(),
228            supported_features: anthropic_features.clone(),
229            max_cache_breakpoints: Some(4),
230            min_cacheable_tokens: Some(1024),
231            cache_economics: Some(CacheEconomics {
232                write_short_multiplier: 1.25,
233                write_long_multiplier: Some(2.0),
234                read_multiplier: 0.1,
235            }),
236        });
237
238        anthropic.add_model_family(ModelFamilyCapabilities {
239            model_family: "claude-opus-4".to_string(),
240            supported_features: anthropic_features.clone(),
241            max_cache_breakpoints: Some(4),
242            min_cacheable_tokens: Some(1024),
243            cache_economics: Some(CacheEconomics {
244                write_short_multiplier: 1.25,
245                write_long_multiplier: Some(2.0),
246                read_multiplier: 0.1,
247            }),
248        });
249
250        anthropic.add_model_family(ModelFamilyCapabilities {
251            model_family: "claude-sonnet-4.6".to_string(),
252            supported_features: anthropic_features.clone(),
253            max_cache_breakpoints: Some(4),
254            min_cacheable_tokens: Some(2048),
255            cache_economics: Some(CacheEconomics {
256                write_short_multiplier: 1.25,
257                write_long_multiplier: Some(2.0),
258                read_multiplier: 0.1,
259            }),
260        });
261
262        anthropic.add_model_family(ModelFamilyCapabilities {
263            model_family: "claude-sonnet-4.5".to_string(),
264            supported_features: anthropic_features.clone(),
265            max_cache_breakpoints: Some(4),
266            min_cacheable_tokens: Some(1024),
267            cache_economics: Some(CacheEconomics {
268                write_short_multiplier: 1.25,
269                write_long_multiplier: Some(2.0),
270                read_multiplier: 0.1,
271            }),
272        });
273
274        anthropic.add_model_family(ModelFamilyCapabilities {
275            model_family: "claude-sonnet-4".to_string(),
276            supported_features: anthropic_features.clone(),
277            max_cache_breakpoints: Some(4),
278            min_cacheable_tokens: Some(1024),
279            cache_economics: Some(CacheEconomics {
280                write_short_multiplier: 1.25,
281                write_long_multiplier: Some(2.0),
282                read_multiplier: 0.1,
283            }),
284        });
285
286        anthropic.add_model_family(ModelFamilyCapabilities {
287            model_family: "claude-haiku-4.5".to_string(),
288            supported_features: anthropic_features.clone(),
289            max_cache_breakpoints: Some(4),
290            min_cacheable_tokens: Some(4096),
291            cache_economics: Some(CacheEconomics {
292                write_short_multiplier: 1.25,
293                write_long_multiplier: Some(2.0),
294                read_multiplier: 0.1,
295            }),
296        });
297
298        anthropic.add_model_family(ModelFamilyCapabilities {
299            model_family: "claude-haiku-3.5".to_string(),
300            supported_features: anthropic_features.clone(),
301            max_cache_breakpoints: Some(4),
302            min_cacheable_tokens: Some(2048),
303            cache_economics: Some(CacheEconomics {
304                write_short_multiplier: 1.25,
305                write_long_multiplier: Some(2.0),
306                read_multiplier: 0.1,
307            }),
308        });
309
310        // Legacy model families (backward compatibility)
311        anthropic.add_model_family(ModelFamilyCapabilities {
312            model_family: "claude-3.5-sonnet".to_string(),
313            supported_features: anthropic_features.clone(),
314            max_cache_breakpoints: Some(4),
315            min_cacheable_tokens: Some(1024),
316            cache_economics: Some(CacheEconomics {
317                write_short_multiplier: 1.25,
318                write_long_multiplier: Some(2.0),
319                read_multiplier: 0.1,
320            }),
321        });
322
323        anthropic.add_model_family(ModelFamilyCapabilities {
324            model_family: "claude-3-opus".to_string(),
325            supported_features: anthropic_features.clone(),
326            max_cache_breakpoints: Some(4),
327            min_cacheable_tokens: Some(2048),
328            cache_economics: Some(CacheEconomics {
329                write_short_multiplier: 1.25,
330                write_long_multiplier: Some(2.0),
331                read_multiplier: 0.1,
332            }),
333        });
334
335        anthropic.add_model_family(ModelFamilyCapabilities {
336            model_family: "claude-3-haiku".to_string(),
337            supported_features: anthropic_features,
338            max_cache_breakpoints: Some(4),
339            min_cacheable_tokens: Some(1024),
340            cache_economics: Some(CacheEconomics {
341                write_short_multiplier: 1.25,
342                write_long_multiplier: Some(2.0),
343                read_multiplier: 0.1,
344            }),
345        });
346
347        registry.register_backend(anthropic);
348
349        // -----------------------------------------------------------------
350        // OpenAI backend
351        // -----------------------------------------------------------------
352        let openai_features: HashSet<ProviderFeature> = [
353            ProviderFeature::AutomaticPrefixCaching,
354            ProviderFeature::StreamingTokenCounts,
355            ProviderFeature::StructuredOutput,
356        ]
357        .into_iter()
358        .collect();
359
360        let mut openai = BackendCapabilities {
361            backend_id: "openai".to_string(),
362            supported_features: openai_features.clone(),
363            model_families: HashMap::new(),
364        };
365
366        openai.add_model_family(ModelFamilyCapabilities {
367            model_family: "gpt-4o".to_string(),
368            supported_features: openai_features.clone(),
369            max_cache_breakpoints: None,
370            min_cacheable_tokens: None,
371            cache_economics: None,
372        });
373
374        openai.add_model_family(ModelFamilyCapabilities {
375            model_family: "gpt-4o-mini".to_string(),
376            supported_features: openai_features,
377            max_cache_breakpoints: None,
378            min_cacheable_tokens: None,
379            cache_economics: None,
380        });
381
382        // o1 reasoning models: only streaming token counts (no prefix caching)
383        let o1_features: HashSet<ProviderFeature> = [ProviderFeature::StreamingTokenCounts]
384            .into_iter()
385            .collect();
386
387        openai.add_model_family(ModelFamilyCapabilities {
388            model_family: "o1".to_string(),
389            supported_features: o1_features,
390            max_cache_breakpoints: None,
391            min_cacheable_tokens: None,
392            cache_economics: None,
393        });
394
395        registry.register_backend(openai);
396
397        registry
398    }
399
400    /// Register a backend's capabilities in the registry.
401    pub fn register_backend(&mut self, caps: BackendCapabilities) {
402        self.backends.insert(caps.backend_id.clone(), caps);
403    }
404
405    /// Retrieve a backend's capabilities by ID.
406    pub fn get_backend(&self, backend_id: &str) -> Option<&BackendCapabilities> {
407        self.backends.get(backend_id)
408    }
409
410    /// Check if a backend supports a feature at the backend level.
411    pub fn supports_feature(&self, backend_id: &str, feature: ProviderFeature) -> bool {
412        self.backends
413            .get(backend_id)
414            .is_some_and(|b| b.supports(feature))
415    }
416
417    /// Check if a specific model family on a backend supports a feature.
418    ///
419    /// Falls back to backend-level if the model family is not registered.
420    pub fn model_supports_feature(
421        &self,
422        backend_id: &str,
423        model_family: &str,
424        feature: ProviderFeature,
425    ) -> bool {
426        self.backends
427            .get(backend_id)
428            .is_some_and(|b| b.model_supports(model_family, feature))
429    }
430
431    /// Return a sorted list of all registered backend IDs.
432    pub fn list_backend_ids(&self) -> Vec<String> {
433        let mut ids: Vec<String> = self.backends.keys().cloned().collect();
434        ids.sort();
435        ids
436    }
437}
438
// Unit tests are kept out of line, in the crate's tests/unit/acg/ directory,
// and compiled only for test builds.
#[cfg(test)]
#[path = "../../tests/unit/acg/capability_tests.rs"]
mod tests;