ai_lib_rust/registry/
mod.rs1use std::collections::{HashMap, HashSet};
8
9use crate::protocol::v2::capabilities::Capability;
10
11#[derive(Debug, Clone)]
14pub struct CapabilityRegistry {
15 required: HashSet<Capability>,
17 optional: HashSet<Capability>,
19 available: HashSet<Capability>,
21}
22
23impl CapabilityRegistry {
24 pub fn from_capabilities(caps: &crate::protocol::v2::capabilities::CapabilitiesV2) -> Self {
26 let required: HashSet<_> = caps.required_capabilities().into_iter().collect();
27 let all: HashSet<_> = caps.all_capabilities().into_iter().collect();
28 let optional: HashSet<_> = all.difference(&required).cloned().collect();
29
30 let available = Self::detect_available_capabilities();
31
32 Self {
33 required,
34 optional,
35 available,
36 }
37 }
38
39 fn detect_available_capabilities() -> HashSet<Capability> {
41 let mut caps = HashSet::new();
42
43 caps.insert(Capability::Text);
45 caps.insert(Capability::Streaming);
46 caps.insert(Capability::Tools);
47 caps.insert(Capability::ParallelTools);
48
49 #[cfg(feature = "embeddings")]
51 caps.insert(Capability::Embeddings);
52
53 #[cfg(feature = "batch")]
54 caps.insert(Capability::Batch);
55
56 #[cfg(feature = "mcp")]
57 caps.insert(Capability::McpClient);
58
59 #[cfg(feature = "mcp")]
60 caps.insert(Capability::McpServer);
61
62 #[cfg(feature = "computer_use")]
63 caps.insert(Capability::ComputerUse);
64
65 #[cfg(feature = "multimodal")]
66 {
67 caps.insert(Capability::Audio);
68 caps.insert(Capability::Video);
69 caps.insert(Capability::Vision);
70 }
71
72 #[cfg(feature = "reasoning")]
73 caps.insert(Capability::Reasoning);
74
75 #[cfg(not(feature = "multimodal"))]
77 caps.insert(Capability::Vision);
78
79 caps.insert(Capability::Agentic);
80 caps.insert(Capability::StructuredOutput);
81
82 caps
83 }
84
85 pub fn validate_requirements(&self) -> Result<(), Vec<CapabilityGap>> {
87 let mut gaps = Vec::new();
88
89 for cap in &self.required {
90 if !self.available.contains(cap) {
91 gaps.push(CapabilityGap {
92 capability: *cap,
93 required: true,
94 feature_flag: cap.feature_flag().map(String::from),
95 });
96 }
97 }
98
99 if gaps.is_empty() {
100 Ok(())
101 } else {
102 Err(gaps)
103 }
104 }
105
106 pub fn active_capabilities(&self) -> HashSet<Capability> {
108 let declared: HashSet<_> = self.required.union(&self.optional).cloned().collect();
109 declared.intersection(&self.available).cloned().collect()
110 }
111
112 pub fn is_active(&self, cap: Capability) -> bool {
114 (self.required.contains(&cap) || self.optional.contains(&cap))
115 && self.available.contains(&cap)
116 }
117
118 pub fn status_report(&self) -> HashMap<Capability, CapabilityStatus> {
120 let mut report = HashMap::new();
121 let all_declared: HashSet<_> = self.required.union(&self.optional).cloned().collect();
122
123 for cap in &all_declared {
124 let status = if self.available.contains(cap) {
125 if self.required.contains(cap) {
126 CapabilityStatus::ActiveRequired
127 } else {
128 CapabilityStatus::ActiveOptional
129 }
130 } else {
131 CapabilityStatus::Unavailable {
132 feature_flag: cap.feature_flag().map(String::from),
133 }
134 };
135 report.insert(*cap, status);
136 }
137
138 report
139 }
140}
141
142#[derive(Debug, Clone)]
144pub struct CapabilityGap {
145 pub capability: Capability,
146 pub required: bool,
147 pub feature_flag: Option<String>,
148}
149
150impl std::fmt::Display for CapabilityGap {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 if let Some(flag) = &self.feature_flag {
153 write!(
154 f,
155 "Capability {:?} is required but not available. Enable with: cargo feature '{}'",
156 self.capability, flag
157 )
158 } else {
159 write!(
160 f,
161 "Capability {:?} is required but not available",
162 self.capability
163 )
164 }
165 }
166}
167
/// Status of one declared capability, as reported by
/// `CapabilityRegistry::status_report`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CapabilityStatus {
    /// Declared as required and available in this build.
    ActiveRequired,
    /// Declared as optional and available in this build.
    ActiveOptional,
    /// Declared (required or optional) but not available in this build.
    Unavailable {
        /// Cargo feature that would enable the capability, if one is known.
        feature_flag: Option<String>,
    },
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v2::capabilities::{CapabilitiesV2, FeatureFlags};

    /// Builds a `Structured` capability declaration with default feature flags.
    fn structured(required: Vec<Capability>, optional: Vec<Capability>) -> CapabilitiesV2 {
        CapabilitiesV2::Structured {
            required,
            optional,
            feature_flags: FeatureFlags::default(),
        }
    }

    #[test]
    fn test_registry_from_capabilities() {
        let caps = structured(
            vec![Capability::Text, Capability::Streaming],
            vec![Capability::Vision, Capability::Tools],
        );
        let registry = CapabilityRegistry::from_capabilities(&caps);

        assert!(registry.is_active(Capability::Text));
        assert!(registry.is_active(Capability::Streaming));
        assert!(registry.is_active(Capability::Vision));
        assert!(registry.is_active(Capability::Tools));
    }

    #[test]
    fn test_validate_requirements_pass() {
        let caps = structured(vec![Capability::Text, Capability::Streaming], vec![]);
        let registry = CapabilityRegistry::from_capabilities(&caps);
        assert!(registry.validate_requirements().is_ok());
    }

    /// `McpClient` is only compiled in with the `mcp` feature, so requiring it
    /// in a build without that feature must be reported as a required gap.
    #[cfg(not(feature = "mcp"))]
    #[test]
    fn test_validate_requirements_reports_gap() {
        let caps = structured(vec![Capability::Text, Capability::McpClient], vec![]);
        let registry = CapabilityRegistry::from_capabilities(&caps);
        let gaps = registry
            .validate_requirements()
            .expect_err("McpClient is unavailable without the `mcp` feature");
        assert!(gaps
            .iter()
            .any(|gap| gap.capability == Capability::McpClient && gap.required));
        // Text is always compiled in, so it must never appear as a gap.
        assert!(gaps.iter().all(|gap| gap.capability != Capability::Text));
    }

    #[test]
    fn test_active_capabilities() {
        let caps = structured(
            vec![Capability::Text],
            vec![Capability::Vision, Capability::McpClient],
        );
        let registry = CapabilityRegistry::from_capabilities(&caps);
        let active = registry.active_capabilities();
        assert!(active.contains(&Capability::Text));
        assert!(active.contains(&Capability::Vision));
        // Declared optional but feature-gated: active exactly when `mcp` is on.
        assert_eq!(active.contains(&Capability::McpClient), cfg!(feature = "mcp"));
        // Compiled into every build, but not declared, so not active.
        assert!(!active.contains(&Capability::Streaming));
    }

    #[test]
    fn test_status_report() {
        let caps = structured(vec![Capability::Text], vec![Capability::Vision]);
        let registry = CapabilityRegistry::from_capabilities(&caps);
        let report = registry.status_report();
        assert!(matches!(
            report.get(&Capability::Text),
            Some(CapabilityStatus::ActiveRequired)
        ));
        assert!(matches!(
            report.get(&Capability::Vision),
            Some(CapabilityStatus::ActiveOptional)
        ));
    }
}