1#[cfg(feature = "brave")]
30use crate::brave::BraveSearchProvider;
31use crate::provider_api::DataSovereignty;
32use crate::provider_api::LlmProvider;
33use converge_core::capability::{
34 CapabilityKind, CapabilityMetadata, Embedding, GraphRecall, Modality, Reranking, VectorRecall,
35};
36use std::collections::HashMap;
37use std::sync::Arc;
38
/// Requirements used by [`CapabilityRegistry`] to select a provider for one capability.
///
/// Start from a preset constructor (e.g. [`CapabilityRequirements::completion`]) and
/// refine it with the builder methods.
#[derive(Debug, Clone)]
pub struct CapabilityRequirements {
    /// Capability the provider's metadata must list.
    pub capability: CapabilityKind,
    /// Modalities the provider must support; every entry must appear in the
    /// provider's metadata modalities.
    pub modalities: Vec<Modality>,
    /// When `true`, providers whose metadata is marked local receive a scoring bonus.
    pub prefer_local: bool,
    /// Requested data-sovereignty constraint.
    /// NOTE(review): registry matching currently accepts providers regardless of
    /// this value — confirm before relying on it for compliance.
    pub data_sovereignty: DataSovereignty,
    /// Upper bound, in milliseconds, on the provider's advertised typical latency;
    /// providers above it are excluded.
    pub max_latency_ms: u32,
}
53
54impl CapabilityRequirements {
55 #[must_use]
57 pub fn completion() -> Self {
58 Self {
59 capability: CapabilityKind::Completion,
60 modalities: vec![Modality::Text],
61 prefer_local: false,
62 data_sovereignty: DataSovereignty::Any,
63 max_latency_ms: 30_000,
64 }
65 }
66
67 #[must_use]
69 pub fn embedding() -> Self {
70 Self {
71 capability: CapabilityKind::Embedding,
72 modalities: vec![Modality::Text],
73 prefer_local: false,
74 data_sovereignty: DataSovereignty::Any,
75 max_latency_ms: 5_000,
76 }
77 }
78
79 #[must_use]
81 pub fn reranking() -> Self {
82 Self {
83 capability: CapabilityKind::Reranking,
84 modalities: vec![Modality::Text],
85 prefer_local: false,
86 data_sovereignty: DataSovereignty::Any,
87 max_latency_ms: 5_000,
88 }
89 }
90
91 #[must_use]
93 pub fn vector_recall() -> Self {
94 Self {
95 capability: CapabilityKind::VectorRecall,
96 modalities: vec![],
97 prefer_local: true,
98 data_sovereignty: DataSovereignty::Any,
99 max_latency_ms: 100,
100 }
101 }
102
103 #[must_use]
105 pub fn graph_recall() -> Self {
106 Self {
107 capability: CapabilityKind::GraphRecall,
108 modalities: vec![],
109 prefer_local: true,
110 data_sovereignty: DataSovereignty::Any,
111 max_latency_ms: 100,
112 }
113 }
114
115 #[must_use]
117 pub fn with_modality(mut self, modality: Modality) -> Self {
118 if !self.modalities.contains(&modality) {
119 self.modalities.push(modality);
120 }
121 self
122 }
123
124 #[must_use]
126 pub fn prefer_local(mut self, prefer: bool) -> Self {
127 self.prefer_local = prefer;
128 self
129 }
130
131 #[must_use]
133 pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
134 self.data_sovereignty = sovereignty;
135 self
136 }
137
138 #[must_use]
140 pub fn with_max_latency_ms(mut self, ms: u32) -> Self {
141 self.max_latency_ms = ms;
142 self
143 }
144}
145
/// One named provider entry: the metadata used for selection plus whichever
/// capability handles have been registered under that name.
struct RegisteredProvider {
    /// Metadata consulted by requirement matching and scoring.
    metadata: CapabilityMetadata,
    /// LLM handle, set by `CapabilityRegistry::add_llm_provider`.
    llm: Option<Arc<dyn LlmProvider>>,
    /// Embedding handle, set by `CapabilityRegistry::add_embedder`.
    embedder: Option<Arc<dyn Embedding>>,
    /// Reranking handle, set by `CapabilityRegistry::add_reranker`.
    reranker: Option<Arc<dyn Reranking>>,
}
157
/// Descriptive metadata for a registered web-search provider, used by
/// [`CapabilityRegistry::select_search_provider`] to filter and score providers.
#[derive(Debug, Clone)]
pub struct SearchProviderMeta {
    /// Provider name (e.g. `"brave"`).
    pub name: String,
    /// Providers with `available == false` are never selected.
    pub available: bool,
    /// Advertised typical latency in milliseconds.
    pub typical_latency_ms: u32,
    /// Whether the provider can return an AI-generated summary.
    pub supports_ai_summary: bool,
    /// Whether the provider supports news search.
    pub supports_news: bool,
    /// Whether the provider supports image search.
    pub supports_images: bool,
    /// Matched against `WebSearchRequirements::requires_local`.
    pub supports_local: bool,
}
176
/// Requirements for choosing a web-search provider via
/// [`CapabilityRegistry::select_search_provider`].
///
/// Each `requires_*` flag, when `true`, excludes providers whose matching
/// `supports_*` flag in [`SearchProviderMeta`] is `false`.
#[derive(Debug, Clone)]
pub struct WebSearchRequirements {
    /// Providers whose `typical_latency_ms` exceeds this are excluded.
    pub max_latency_ms: u32,
    /// Require `supports_ai_summary`.
    pub requires_ai_summary: bool,
    /// Require `supports_news`.
    pub requires_news: bool,
    /// Require `supports_images`.
    pub requires_images: bool,
    /// Require `supports_local`.
    pub requires_local: bool,
    /// Requested data-sovereignty constraint.
    /// NOTE(review): provider selection does not currently consult this field.
    pub data_sovereignty: DataSovereignty,
}
196
197impl Default for WebSearchRequirements {
198 fn default() -> Self {
199 Self {
200 max_latency_ms: 10_000,
201 requires_ai_summary: false,
202 requires_news: false,
203 requires_images: false,
204 requires_local: false,
205 data_sovereignty: DataSovereignty::Any,
206 }
207 }
208}
209
210impl WebSearchRequirements {
211 #[must_use]
213 pub fn web_search() -> Self {
214 Self::default()
215 }
216
217 #[must_use]
219 pub fn grounded() -> Self {
220 Self {
221 max_latency_ms: 15_000,
222 requires_ai_summary: true,
223 ..Self::default()
224 }
225 }
226
227 #[must_use]
229 pub fn news() -> Self {
230 Self {
231 requires_news: true,
232 ..Self::default()
233 }
234 }
235
236 #[must_use]
238 pub fn with_max_latency_ms(mut self, ms: u32) -> Self {
239 self.max_latency_ms = ms;
240 self
241 }
242
243 #[must_use]
245 pub fn with_ai_summary(mut self, required: bool) -> Self {
246 self.requires_ai_summary = required;
247 self
248 }
249
250 #[must_use]
252 pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
253 self.data_sovereignty = sovereignty;
254 self
255 }
256}
257
/// Registry of capability providers (LLMs, embedders, rerankers), recall stores,
/// and web-search providers, with requirement-based selection.
pub struct CapabilityRegistry {
    /// Providers keyed by name; one entry may hold several capability handles.
    providers: HashMap<String, RegisteredProvider>,
    /// Vector stores keyed by name; `"default"` backs `default_vector_store`.
    vector_stores: HashMap<String, Arc<dyn VectorRecall>>,
    /// Graph stores keyed by name; `"default"` backs `default_graph_store`.
    graph_stores: HashMap<String, Arc<dyn GraphRecall>>,
    /// Metadata of registered web-search providers, keyed by name.
    search_providers: HashMap<String, SearchProviderMeta>,
    /// Brave search client, once configured.
    #[cfg(feature = "brave")]
    brave_provider: Option<BraveSearchProvider>,
}
274
275impl Default for CapabilityRegistry {
276 fn default() -> Self {
277 Self::new()
278 }
279}
280
281impl CapabilityRegistry {
282 #[must_use]
284 pub fn new() -> Self {
285 Self {
286 providers: HashMap::new(),
287 vector_stores: HashMap::new(),
288 graph_stores: HashMap::new(),
289 search_providers: HashMap::new(),
290 #[cfg(feature = "brave")]
291 brave_provider: None,
292 }
293 }
294
295 #[must_use]
303 pub fn with_local_defaults() -> Self {
304 let mut registry = Self::new();
305
306 registry.add_vector_store(
308 "default",
309 Arc::new(crate::vector::InMemoryVectorStore::new()),
310 );
311
312 registry.add_graph_store("default", Arc::new(crate::graph::InMemoryGraphStore::new()));
314
315 registry.try_add_brave_from_env();
317
318 registry
319 }
320
321 pub fn try_add_brave_from_env(&mut self) -> bool {
325 #[cfg(feature = "brave")]
326 if let Ok(provider) = BraveSearchProvider::from_env() {
327 self.brave_provider = Some(provider);
328 self.search_providers.insert(
329 "brave".to_string(),
330 SearchProviderMeta {
331 name: "brave".to_string(),
332 available: true,
333 typical_latency_ms: 500,
334 supports_ai_summary: false, supports_news: true,
336 supports_images: true,
337 supports_local: true,
338 },
339 );
340 return true;
341 }
342 false
343 }
344
345 #[cfg(feature = "brave")]
347 pub fn add_brave(&mut self, api_key: impl Into<String>) {
348 self.brave_provider = Some(BraveSearchProvider::new(api_key));
349 self.search_providers.insert(
350 "brave".to_string(),
351 SearchProviderMeta {
352 name: "brave".to_string(),
353 available: true,
354 typical_latency_ms: 500,
355 supports_ai_summary: false,
356 supports_news: true,
357 supports_images: true,
358 supports_local: true,
359 },
360 );
361 }
362
363 #[cfg(feature = "brave")]
365 #[must_use]
366 pub fn brave(&self) -> Option<&BraveSearchProvider> {
367 self.brave_provider.as_ref()
368 }
369
370 #[must_use]
372 pub fn has_web_search(&self) -> bool {
373 !self.search_providers.is_empty()
374 }
375
376 #[must_use]
378 pub fn search_providers(&self) -> Vec<&SearchProviderMeta> {
379 self.search_providers.values().collect()
380 }
381
382 #[must_use]
386 pub fn select_search_provider(
387 &self,
388 requirements: &WebSearchRequirements,
389 ) -> Option<&SearchProviderMeta> {
390 self.search_providers
391 .values()
392 .filter(|p| {
393 if !p.available || p.typical_latency_ms > requirements.max_latency_ms {
395 return false;
396 }
397 if requirements.requires_ai_summary && !p.supports_ai_summary {
399 return false;
400 }
401 if requirements.requires_news && !p.supports_news {
402 return false;
403 }
404 if requirements.requires_images && !p.supports_images {
405 return false;
406 }
407 if requirements.requires_local && !p.supports_local {
408 return false;
409 }
410 true
411 })
412 .max_by_key(|p| {
413 let mut score = 0i32;
415 if p.supports_ai_summary {
416 score += 100;
417 }
418 if p.supports_news {
419 score += 20;
420 }
421 if p.supports_images {
422 score += 20;
423 }
424 if p.supports_local {
425 score += 10;
426 }
427 score -= (p.typical_latency_ms / 100) as i32;
429 score
430 })
431 }
432
433 pub fn add_llm_provider(
435 &mut self,
436 name: &str,
437 provider: Arc<dyn LlmProvider>,
438 metadata: CapabilityMetadata,
439 ) {
440 let entry = self
441 .providers
442 .entry(name.to_string())
443 .or_insert_with(|| RegisteredProvider {
444 metadata: metadata.clone(),
445 llm: None,
446 embedder: None,
447 reranker: None,
448 });
449 entry.llm = Some(provider);
450 entry.metadata = metadata;
451 }
452
453 #[allow(clippy::needless_pass_by_value)]
455 pub fn add_embedder(
456 &mut self,
457 name: &str,
458 provider: Arc<dyn Embedding>,
459 metadata: CapabilityMetadata,
460 ) {
461 let entry = self
462 .providers
463 .entry(name.to_string())
464 .or_insert_with(|| RegisteredProvider {
465 metadata: metadata.clone(),
466 llm: None,
467 embedder: None,
468 reranker: None,
469 });
470 entry.embedder = Some(provider);
471 for cap in &metadata.capabilities {
473 if !entry.metadata.capabilities.contains(cap) {
474 entry.metadata.capabilities.push(*cap);
475 }
476 }
477 }
478
479 #[allow(clippy::needless_pass_by_value)]
481 pub fn add_reranker(
482 &mut self,
483 name: &str,
484 provider: Arc<dyn Reranking>,
485 metadata: CapabilityMetadata,
486 ) {
487 let entry = self
488 .providers
489 .entry(name.to_string())
490 .or_insert_with(|| RegisteredProvider {
491 metadata: metadata.clone(),
492 llm: None,
493 embedder: None,
494 reranker: None,
495 });
496 entry.reranker = Some(provider);
497 for cap in &metadata.capabilities {
499 if !entry.metadata.capabilities.contains(cap) {
500 entry.metadata.capabilities.push(*cap);
501 }
502 }
503 }
504
505 pub fn add_vector_store(&mut self, name: &str, store: Arc<dyn VectorRecall>) {
507 self.vector_stores.insert(name.to_string(), store);
508 }
509
510 pub fn add_graph_store(&mut self, name: &str, store: Arc<dyn GraphRecall>) {
512 self.graph_stores.insert(name.to_string(), store);
513 }
514
515 #[must_use]
517 pub fn select_llm(
518 &self,
519 requirements: &CapabilityRequirements,
520 ) -> Option<Arc<dyn LlmProvider>> {
521 self.providers
522 .values()
523 .filter(|p| p.llm.is_some() && self.matches_requirements(&p.metadata, requirements))
524 .max_by_key(|p| self.score_provider(&p.metadata, requirements))
525 .and_then(|p| p.llm.clone())
526 }
527
528 #[must_use]
530 pub fn select_embedder(
531 &self,
532 requirements: &CapabilityRequirements,
533 ) -> Option<Arc<dyn Embedding>> {
534 self.providers
535 .values()
536 .filter(|p| {
537 p.embedder.is_some() && self.matches_requirements(&p.metadata, requirements)
538 })
539 .max_by_key(|p| self.score_provider(&p.metadata, requirements))
540 .and_then(|p| p.embedder.clone())
541 }
542
543 #[must_use]
545 pub fn select_reranker(
546 &self,
547 requirements: &CapabilityRequirements,
548 ) -> Option<Arc<dyn Reranking>> {
549 self.providers
550 .values()
551 .filter(|p| {
552 p.reranker.is_some() && self.matches_requirements(&p.metadata, requirements)
553 })
554 .max_by_key(|p| self.score_provider(&p.metadata, requirements))
555 .and_then(|p| p.reranker.clone())
556 }
557
558 #[must_use]
560 pub fn get_vector_store(&self, name: &str) -> Option<Arc<dyn VectorRecall>> {
561 self.vector_stores.get(name).cloned()
562 }
563
564 #[must_use]
566 pub fn get_graph_store(&self, name: &str) -> Option<Arc<dyn GraphRecall>> {
567 self.graph_stores.get(name).cloned()
568 }
569
570 #[must_use]
572 pub fn default_vector_store(&self) -> Option<Arc<dyn VectorRecall>> {
573 self.get_vector_store("default")
574 }
575
576 #[must_use]
578 pub fn default_graph_store(&self) -> Option<Arc<dyn GraphRecall>> {
579 self.get_graph_store("default")
580 }
581
582 #[must_use]
584 pub fn provider_names(&self) -> Vec<&str> {
585 self.providers.keys().map(String::as_str).collect()
586 }
587
588 #[must_use]
590 pub fn vector_store_names(&self) -> Vec<&str> {
591 self.vector_stores.keys().map(String::as_str).collect()
592 }
593
594 #[must_use]
596 pub fn graph_store_names(&self) -> Vec<&str> {
597 self.graph_stores.keys().map(String::as_str).collect()
598 }
599
600 #[allow(clippy::unused_self)]
602 fn matches_requirements(
603 &self,
604 metadata: &CapabilityMetadata,
605 requirements: &CapabilityRequirements,
606 ) -> bool {
607 if !metadata.capabilities.contains(&requirements.capability) {
609 return false;
610 }
611
612 for modality in &requirements.modalities {
614 if !metadata.modalities.contains(modality) {
615 return false;
616 }
617 }
618
619 #[allow(clippy::match_same_arms)]
621 match (&requirements.data_sovereignty, metadata.is_local) {
622 (DataSovereignty::Any, _) | (_, true) => {} _ => {} }
625
626 if metadata.typical_latency_ms > requirements.max_latency_ms {
628 return false;
629 }
630
631 true
632 }
633
634 #[allow(clippy::unused_self, clippy::cast_possible_wrap)]
636 fn score_provider(
637 &self,
638 metadata: &CapabilityMetadata,
639 requirements: &CapabilityRequirements,
640 ) -> i32 {
641 let mut score = 0;
642
643 if requirements.prefer_local && metadata.is_local {
645 score += 100;
646 }
647
648 if metadata.typical_latency_ms < requirements.max_latency_ms / 2 {
650 score += 50;
651 }
652
653 score += (metadata.modalities.len() * 10) as i32;
655
656 score
657 }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::InMemoryGraphStore;
    use crate::vector::InMemoryVectorStore;

    /// `with_local_defaults` always registers both default stores.
    #[test]
    fn registry_with_local_defaults() {
        let registry = CapabilityRegistry::with_local_defaults();

        assert!(registry.default_vector_store().is_some());
        assert!(registry.default_graph_store().is_some());
    }

    /// Stores are retrievable by name; unknown names yield `None`.
    #[test]
    fn add_and_get_stores() {
        let mut registry = CapabilityRegistry::new();

        registry.add_vector_store("test", Arc::new(InMemoryVectorStore::new()));
        registry.add_graph_store("test", Arc::new(InMemoryGraphStore::new()));

        assert!(registry.get_vector_store("test").is_some());
        assert!(registry.get_graph_store("test").is_some());
        assert!(registry.get_vector_store("missing").is_none());
        assert!(registry.get_graph_store("missing").is_none());
    }

    /// Default store names appear in the name listings.
    #[test]
    fn list_registered_stores() {
        let registry = CapabilityRegistry::with_local_defaults();

        let vector_stores = registry.vector_store_names();
        assert!(vector_stores.contains(&"default"));

        let graph_stores = registry.graph_store_names();
        assert!(graph_stores.contains(&"default"));
    }

    /// Builder methods update the corresponding fields.
    #[test]
    fn capability_requirements_builder() {
        let reqs = CapabilityRequirements::embedding()
            .with_modality(Modality::Image)
            .prefer_local(true)
            .with_max_latency_ms(1000);

        assert_eq!(reqs.capability, CapabilityKind::Embedding);
        assert!(reqs.modalities.contains(&Modality::Image));
        assert!(reqs.prefer_local);
        assert_eq!(reqs.max_latency_ms, 1000);
    }

    /// Adding an already-required modality must not duplicate it.
    #[test]
    fn with_modality_does_not_duplicate() {
        let reqs = CapabilityRequirements::completion().with_modality(Modality::Text);
        assert_eq!(reqs.modalities, vec![Modality::Text]);
    }

    /// An empty registry offers no web search and selects nothing.
    #[test]
    fn empty_registry_selects_nothing() {
        let registry = CapabilityRegistry::new();

        assert!(!registry.has_web_search());
        assert!(registry.search_providers().is_empty());
        assert!(registry
            .select_search_provider(&WebSearchRequirements::default())
            .is_none());
        assert!(registry
            .select_llm(&CapabilityRequirements::completion())
            .is_none());
    }

    /// Web-search presets set the documented flags.
    #[test]
    fn web_search_presets() {
        let grounded = WebSearchRequirements::grounded();
        assert!(grounded.requires_ai_summary);
        assert_eq!(grounded.max_latency_ms, 15_000);

        assert!(WebSearchRequirements::news().requires_news);
        assert!(!WebSearchRequirements::web_search().requires_news);
    }
}