Skip to main content

converge_provider/
capability_registry.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3// See LICENSE file in the project root for full license information.
4
5//! Unified capability registry for Converge providers.
6//!
7//! The capability registry provides a single point for discovering and
8//! selecting providers based on their capabilities. This supports the
9//! Converge principle that different models excel at different tasks.
10//!
11//! # Example
12//!
13//! ```ignore
14//! use converge_provider::{CapabilityRegistry, CapabilityRequirements};
15//! use converge_core::capability::{CapabilityKind, Modality};
16//!
//! let registry = CapabilityRegistry::with_local_defaults();
18//!
19//! // Find an embedder that supports images
20//! let requirements = CapabilityRequirements::embedding()
21//!     .with_modality(Modality::Image)
22//!     .prefer_local(true);
23//!
24//! if let Some(embedder) = registry.select_embedder(&requirements) {
25//!     // Use the embedder
26//! }
27//! ```
28
29#[cfg(feature = "brave")]
30use crate::brave::BraveSearchProvider;
31use crate::provider_api::DataSovereignty;
32use crate::provider_api::LlmProvider;
33use converge_core::capability::{
34    CapabilityKind, CapabilityMetadata, Embedding, GraphRecall, Modality, Reranking, VectorRecall,
35};
36use std::collections::HashMap;
37use std::sync::Arc;
38
39/// Requirements for capability selection.
40#[derive(Debug, Clone)]
41pub struct CapabilityRequirements {
42    /// Required capability kind.
43    pub capability: CapabilityKind,
44    /// Required modalities (for embedding/reranking).
45    pub modalities: Vec<Modality>,
46    /// Prefer local providers (data sovereignty).
47    pub prefer_local: bool,
48    /// Required data sovereignty level.
49    pub data_sovereignty: DataSovereignty,
50    /// Maximum acceptable latency in milliseconds.
51    pub max_latency_ms: u32,
52}
53
54impl CapabilityRequirements {
55    /// Requirements for LLM completion.
56    #[must_use]
57    pub fn completion() -> Self {
58        Self {
59            capability: CapabilityKind::Completion,
60            modalities: vec![Modality::Text],
61            prefer_local: false,
62            data_sovereignty: DataSovereignty::Any,
63            max_latency_ms: 30_000,
64        }
65    }
66
67    /// Requirements for embedding.
68    #[must_use]
69    pub fn embedding() -> Self {
70        Self {
71            capability: CapabilityKind::Embedding,
72            modalities: vec![Modality::Text],
73            prefer_local: false,
74            data_sovereignty: DataSovereignty::Any,
75            max_latency_ms: 5_000,
76        }
77    }
78
79    /// Requirements for reranking.
80    #[must_use]
81    pub fn reranking() -> Self {
82        Self {
83            capability: CapabilityKind::Reranking,
84            modalities: vec![Modality::Text],
85            prefer_local: false,
86            data_sovereignty: DataSovereignty::Any,
87            max_latency_ms: 5_000,
88        }
89    }
90
91    /// Requirements for vector recall.
92    #[must_use]
93    pub fn vector_recall() -> Self {
94        Self {
95            capability: CapabilityKind::VectorRecall,
96            modalities: vec![],
97            prefer_local: true,
98            data_sovereignty: DataSovereignty::Any,
99            max_latency_ms: 100,
100        }
101    }
102
103    /// Requirements for graph recall.
104    #[must_use]
105    pub fn graph_recall() -> Self {
106        Self {
107            capability: CapabilityKind::GraphRecall,
108            modalities: vec![],
109            prefer_local: true,
110            data_sovereignty: DataSovereignty::Any,
111            max_latency_ms: 100,
112        }
113    }
114
115    /// Add required modality.
116    #[must_use]
117    pub fn with_modality(mut self, modality: Modality) -> Self {
118        if !self.modalities.contains(&modality) {
119            self.modalities.push(modality);
120        }
121        self
122    }
123
124    /// Set local preference.
125    #[must_use]
126    pub fn prefer_local(mut self, prefer: bool) -> Self {
127        self.prefer_local = prefer;
128        self
129    }
130
131    /// Set data sovereignty requirement.
132    #[must_use]
133    pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
134        self.data_sovereignty = sovereignty;
135        self
136    }
137
138    /// Set maximum latency.
139    #[must_use]
140    pub fn with_max_latency_ms(mut self, ms: u32) -> Self {
141        self.max_latency_ms = ms;
142        self
143    }
144}
145
146/// Registered provider with its capabilities.
147struct RegisteredProvider {
148    /// Provider metadata.
149    metadata: CapabilityMetadata,
150    /// LLM provider instance (if applicable).
151    llm: Option<Arc<dyn LlmProvider>>,
152    /// Embedding provider instance (if applicable).
153    embedder: Option<Arc<dyn Embedding>>,
154    /// Reranker provider instance (if applicable).
155    reranker: Option<Arc<dyn Reranking>>,
156}
157
/// Describes a web search provider so agents can choose between providers.
#[derive(Debug, Clone)]
pub struct SearchProviderMeta {
    /// Short identifier of the provider, such as `"brave"` or `"perplexity"`.
    pub name: String,
    /// `true` when the provider is usable (for example, its API key is set).
    pub available: bool,
    /// Typical response time, in milliseconds.
    pub typical_latency_ms: u32,
    /// The provider can return AI-generated summaries.
    pub supports_ai_summary: bool,
    /// The provider can search news.
    pub supports_news: bool,
    /// The provider can search images.
    pub supports_images: bool,
    /// The provider can perform local / point-of-interest search.
    pub supports_local: bool,
}
176
177/// Requirements for selecting a web search provider.
178///
179/// Unlike LLM requirements, web search requirements focus on
180/// search-specific capabilities like news, images, and AI summaries.
181#[derive(Debug, Clone)]
182pub struct WebSearchRequirements {
183    /// Maximum latency in milliseconds.
184    pub max_latency_ms: u32,
185    /// Whether AI-powered summaries are required.
186    pub requires_ai_summary: bool,
187    /// Whether news search is required.
188    pub requires_news: bool,
189    /// Whether image search is required.
190    pub requires_images: bool,
191    /// Whether local/POI search is required.
192    pub requires_local: bool,
193    /// Data sovereignty requirement.
194    pub data_sovereignty: DataSovereignty,
195}
196
197impl Default for WebSearchRequirements {
198    fn default() -> Self {
199        Self {
200            max_latency_ms: 10_000,
201            requires_ai_summary: false,
202            requires_news: false,
203            requires_images: false,
204            requires_local: false,
205            data_sovereignty: DataSovereignty::Any,
206        }
207    }
208}
209
210impl WebSearchRequirements {
211    /// Creates default requirements for general web search.
212    #[must_use]
213    pub fn web_search() -> Self {
214        Self::default()
215    }
216
217    /// Creates requirements for AI-grounded search (RAG).
218    #[must_use]
219    pub fn grounded() -> Self {
220        Self {
221            max_latency_ms: 15_000,
222            requires_ai_summary: true,
223            ..Self::default()
224        }
225    }
226
227    /// Creates requirements for news search.
228    #[must_use]
229    pub fn news() -> Self {
230        Self {
231            requires_news: true,
232            ..Self::default()
233        }
234    }
235
236    /// Sets the maximum latency.
237    #[must_use]
238    pub fn with_max_latency_ms(mut self, ms: u32) -> Self {
239        self.max_latency_ms = ms;
240        self
241    }
242
243    /// Requires AI-powered summaries.
244    #[must_use]
245    pub fn with_ai_summary(mut self, required: bool) -> Self {
246        self.requires_ai_summary = required;
247        self
248    }
249
250    /// Sets data sovereignty requirement.
251    #[must_use]
252    pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
253        self.data_sovereignty = sovereignty;
254        self
255    }
256}
257
258/// Unified capability registry.
259///
260/// Discovers and manages all available capability providers.
261pub struct CapabilityRegistry {
262    /// Registered providers by name.
263    providers: HashMap<String, RegisteredProvider>,
264    /// Vector stores by name.
265    vector_stores: HashMap<String, Arc<dyn VectorRecall>>,
266    /// Graph stores by name.
267    graph_stores: HashMap<String, Arc<dyn GraphRecall>>,
268    /// Web search providers by name.
269    search_providers: HashMap<String, SearchProviderMeta>,
270    /// Brave search provider instance (if available).
271    #[cfg(feature = "brave")]
272    brave_provider: Option<BraveSearchProvider>,
273}
274
275impl Default for CapabilityRegistry {
276    fn default() -> Self {
277        Self::new()
278    }
279}
280
281impl CapabilityRegistry {
282    /// Creates an empty capability registry.
283    #[must_use]
284    pub fn new() -> Self {
285        Self {
286            providers: HashMap::new(),
287            vector_stores: HashMap::new(),
288            graph_stores: HashMap::new(),
289            search_providers: HashMap::new(),
290            #[cfg(feature = "brave")]
291            brave_provider: None,
292        }
293    }
294
295    /// Creates a registry with auto-detected providers from environment.
296    ///
297    /// This checks for:
298    /// - Ollama (local LLM and embedding)
299    /// - In-memory vector store (always available)
300    /// - In-memory graph store (always available)
301    /// - Brave Search (if `BRAVE_API_KEY` is set)
302    #[must_use]
303    pub fn with_local_defaults() -> Self {
304        let mut registry = Self::new();
305
306        // Add in-memory vector store
307        registry.add_vector_store(
308            "default",
309            Arc::new(crate::vector::InMemoryVectorStore::new()),
310        );
311
312        // Add in-memory graph store
313        registry.add_graph_store("default", Arc::new(crate::graph::InMemoryGraphStore::new()));
314
315        // Try to add Brave Search if available
316        registry.try_add_brave_from_env();
317
318        registry
319    }
320
321    /// Attempts to add Brave Search provider from environment.
322    ///
323    /// Returns `true` if Brave Search was added successfully.
324    pub fn try_add_brave_from_env(&mut self) -> bool {
325        #[cfg(feature = "brave")]
326        if let Ok(provider) = BraveSearchProvider::from_env() {
327            self.brave_provider = Some(provider);
328            self.search_providers.insert(
329                "brave".to_string(),
330                SearchProviderMeta {
331                    name: "brave".to_string(),
332                    available: true,
333                    typical_latency_ms: 500,
334                    supports_ai_summary: false, // Requires Pro plan
335                    supports_news: true,
336                    supports_images: true,
337                    supports_local: true,
338                },
339            );
340            return true;
341        }
342        false
343    }
344
345    /// Adds Brave Search provider with a specific API key.
346    #[cfg(feature = "brave")]
347    pub fn add_brave(&mut self, api_key: impl Into<String>) {
348        self.brave_provider = Some(BraveSearchProvider::new(api_key));
349        self.search_providers.insert(
350            "brave".to_string(),
351            SearchProviderMeta {
352                name: "brave".to_string(),
353                available: true,
354                typical_latency_ms: 500,
355                supports_ai_summary: false,
356                supports_news: true,
357                supports_images: true,
358                supports_local: true,
359            },
360        );
361    }
362
363    /// Gets the Brave Search provider if available.
364    #[cfg(feature = "brave")]
365    #[must_use]
366    pub fn brave(&self) -> Option<&BraveSearchProvider> {
367        self.brave_provider.as_ref()
368    }
369
370    /// Checks if web search capability is available.
371    #[must_use]
372    pub fn has_web_search(&self) -> bool {
373        !self.search_providers.is_empty()
374    }
375
376    /// Gets metadata for all available search providers.
377    #[must_use]
378    pub fn search_providers(&self) -> Vec<&SearchProviderMeta> {
379        self.search_providers.values().collect()
380    }
381
382    /// Selects the best search provider based on requirements.
383    ///
384    /// Currently returns Brave if available, as it's the primary search provider.
385    #[must_use]
386    pub fn select_search_provider(
387        &self,
388        requirements: &WebSearchRequirements,
389    ) -> Option<&SearchProviderMeta> {
390        self.search_providers
391            .values()
392            .filter(|p| {
393                // Basic availability and latency check
394                if !p.available || p.typical_latency_ms > requirements.max_latency_ms {
395                    return false;
396                }
397                // Check required capabilities
398                if requirements.requires_ai_summary && !p.supports_ai_summary {
399                    return false;
400                }
401                if requirements.requires_news && !p.supports_news {
402                    return false;
403                }
404                if requirements.requires_images && !p.supports_images {
405                    return false;
406                }
407                if requirements.requires_local && !p.supports_local {
408                    return false;
409                }
410                true
411            })
412            .max_by_key(|p| {
413                // Score providers by their capabilities
414                let mut score = 0i32;
415                if p.supports_ai_summary {
416                    score += 100;
417                }
418                if p.supports_news {
419                    score += 20;
420                }
421                if p.supports_images {
422                    score += 20;
423                }
424                if p.supports_local {
425                    score += 10;
426                }
427                // Prefer lower latency
428                score -= (p.typical_latency_ms / 100) as i32;
429                score
430            })
431    }
432
433    /// Registers an LLM provider.
434    pub fn add_llm_provider(
435        &mut self,
436        name: &str,
437        provider: Arc<dyn LlmProvider>,
438        metadata: CapabilityMetadata,
439    ) {
440        let entry = self
441            .providers
442            .entry(name.to_string())
443            .or_insert_with(|| RegisteredProvider {
444                metadata: metadata.clone(),
445                llm: None,
446                embedder: None,
447                reranker: None,
448            });
449        entry.llm = Some(provider);
450        entry.metadata = metadata;
451    }
452
453    /// Registers an embedding provider.
454    #[allow(clippy::needless_pass_by_value)]
455    pub fn add_embedder(
456        &mut self,
457        name: &str,
458        provider: Arc<dyn Embedding>,
459        metadata: CapabilityMetadata,
460    ) {
461        let entry = self
462            .providers
463            .entry(name.to_string())
464            .or_insert_with(|| RegisteredProvider {
465                metadata: metadata.clone(),
466                llm: None,
467                embedder: None,
468                reranker: None,
469            });
470        entry.embedder = Some(provider);
471        // Merge capabilities
472        for cap in &metadata.capabilities {
473            if !entry.metadata.capabilities.contains(cap) {
474                entry.metadata.capabilities.push(*cap);
475            }
476        }
477    }
478
479    /// Registers a reranker provider.
480    #[allow(clippy::needless_pass_by_value)]
481    pub fn add_reranker(
482        &mut self,
483        name: &str,
484        provider: Arc<dyn Reranking>,
485        metadata: CapabilityMetadata,
486    ) {
487        let entry = self
488            .providers
489            .entry(name.to_string())
490            .or_insert_with(|| RegisteredProvider {
491                metadata: metadata.clone(),
492                llm: None,
493                embedder: None,
494                reranker: None,
495            });
496        entry.reranker = Some(provider);
497        // Merge capabilities
498        for cap in &metadata.capabilities {
499            if !entry.metadata.capabilities.contains(cap) {
500                entry.metadata.capabilities.push(*cap);
501            }
502        }
503    }
504
505    /// Registers a vector store.
506    pub fn add_vector_store(&mut self, name: &str, store: Arc<dyn VectorRecall>) {
507        self.vector_stores.insert(name.to_string(), store);
508    }
509
510    /// Registers a graph store.
511    pub fn add_graph_store(&mut self, name: &str, store: Arc<dyn GraphRecall>) {
512        self.graph_stores.insert(name.to_string(), store);
513    }
514
515    /// Selects an LLM provider matching requirements.
516    #[must_use]
517    pub fn select_llm(
518        &self,
519        requirements: &CapabilityRequirements,
520    ) -> Option<Arc<dyn LlmProvider>> {
521        self.providers
522            .values()
523            .filter(|p| p.llm.is_some() && self.matches_requirements(&p.metadata, requirements))
524            .max_by_key(|p| self.score_provider(&p.metadata, requirements))
525            .and_then(|p| p.llm.clone())
526    }
527
528    /// Selects an embedding provider matching requirements.
529    #[must_use]
530    pub fn select_embedder(
531        &self,
532        requirements: &CapabilityRequirements,
533    ) -> Option<Arc<dyn Embedding>> {
534        self.providers
535            .values()
536            .filter(|p| {
537                p.embedder.is_some() && self.matches_requirements(&p.metadata, requirements)
538            })
539            .max_by_key(|p| self.score_provider(&p.metadata, requirements))
540            .and_then(|p| p.embedder.clone())
541    }
542
543    /// Selects a reranker provider matching requirements.
544    #[must_use]
545    pub fn select_reranker(
546        &self,
547        requirements: &CapabilityRequirements,
548    ) -> Option<Arc<dyn Reranking>> {
549        self.providers
550            .values()
551            .filter(|p| {
552                p.reranker.is_some() && self.matches_requirements(&p.metadata, requirements)
553            })
554            .max_by_key(|p| self.score_provider(&p.metadata, requirements))
555            .and_then(|p| p.reranker.clone())
556    }
557
558    /// Gets the default vector store.
559    #[must_use]
560    pub fn get_vector_store(&self, name: &str) -> Option<Arc<dyn VectorRecall>> {
561        self.vector_stores.get(name).cloned()
562    }
563
564    /// Gets the default graph store.
565    #[must_use]
566    pub fn get_graph_store(&self, name: &str) -> Option<Arc<dyn GraphRecall>> {
567        self.graph_stores.get(name).cloned()
568    }
569
570    /// Gets the default vector store (named "default").
571    #[must_use]
572    pub fn default_vector_store(&self) -> Option<Arc<dyn VectorRecall>> {
573        self.get_vector_store("default")
574    }
575
576    /// Gets the default graph store (named "default").
577    #[must_use]
578    pub fn default_graph_store(&self) -> Option<Arc<dyn GraphRecall>> {
579        self.get_graph_store("default")
580    }
581
582    /// Lists all registered provider names.
583    #[must_use]
584    pub fn provider_names(&self) -> Vec<&str> {
585        self.providers.keys().map(String::as_str).collect()
586    }
587
588    /// Lists all registered vector store names.
589    #[must_use]
590    pub fn vector_store_names(&self) -> Vec<&str> {
591        self.vector_stores.keys().map(String::as_str).collect()
592    }
593
594    /// Lists all registered graph store names.
595    #[must_use]
596    pub fn graph_store_names(&self) -> Vec<&str> {
597        self.graph_stores.keys().map(String::as_str).collect()
598    }
599
600    /// Checks if a provider matches the requirements.
601    #[allow(clippy::unused_self)]
602    fn matches_requirements(
603        &self,
604        metadata: &CapabilityMetadata,
605        requirements: &CapabilityRequirements,
606    ) -> bool {
607        // Check capability
608        if !metadata.capabilities.contains(&requirements.capability) {
609            return false;
610        }
611
612        // Check modalities
613        for modality in &requirements.modalities {
614            if !metadata.modalities.contains(modality) {
615                return false;
616            }
617        }
618
619        // Check data sovereignty - local providers satisfy all requirements
620        #[allow(clippy::match_same_arms)]
621        match (&requirements.data_sovereignty, metadata.is_local) {
622            (DataSovereignty::Any, _) | (_, true) => {} // Always OK or local
623            _ => {} // Remote providers must match specific sovereignty
624        }
625
626        // Check latency
627        if metadata.typical_latency_ms > requirements.max_latency_ms {
628            return false;
629        }
630
631        true
632    }
633
634    /// Scores a provider for selection (higher = better).
635    #[allow(clippy::unused_self, clippy::cast_possible_wrap)]
636    fn score_provider(
637        &self,
638        metadata: &CapabilityMetadata,
639        requirements: &CapabilityRequirements,
640    ) -> i32 {
641        let mut score = 0;
642
643        // Prefer local if requested
644        if requirements.prefer_local && metadata.is_local {
645            score += 100;
646        }
647
648        // Lower latency is better
649        if metadata.typical_latency_ms < requirements.max_latency_ms / 2 {
650            score += 50;
651        }
652
653        // More modalities is better
654        score += (metadata.modalities.len() * 10) as i32;
655
656        score
657    }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::InMemoryGraphStore;
    use crate::vector::InMemoryVectorStore;

    /// Builds available search metadata with every feature flag disabled.
    fn search_meta(name: &str, typical_latency_ms: u32) -> SearchProviderMeta {
        SearchProviderMeta {
            name: name.to_string(),
            available: true,
            typical_latency_ms,
            supports_ai_summary: false,
            supports_news: false,
            supports_images: false,
            supports_local: false,
        }
    }

    #[test]
    fn registry_with_local_defaults() {
        let registry = CapabilityRegistry::with_local_defaults();

        assert!(registry.default_vector_store().is_some());
        assert!(registry.default_graph_store().is_some());
    }

    #[test]
    fn add_and_get_stores() {
        let mut registry = CapabilityRegistry::new();

        registry.add_vector_store("test", Arc::new(InMemoryVectorStore::new()));
        registry.add_graph_store("test", Arc::new(InMemoryGraphStore::new()));

        assert!(registry.get_vector_store("test").is_some());
        assert!(registry.get_graph_store("test").is_some());
        assert!(registry.get_vector_store("missing").is_none());
    }

    #[test]
    fn list_registered_stores() {
        let registry = CapabilityRegistry::with_local_defaults();

        let vector_stores = registry.vector_store_names();
        assert!(vector_stores.contains(&"default"));

        let graph_stores = registry.graph_store_names();
        assert!(graph_stores.contains(&"default"));
    }

    #[test]
    fn capability_requirements_builder() {
        let reqs = CapabilityRequirements::embedding()
            .with_modality(Modality::Image)
            .prefer_local(true)
            .with_max_latency_ms(1000);

        assert_eq!(reqs.capability, CapabilityKind::Embedding);
        assert!(reqs.modalities.contains(&Modality::Image));
        assert!(reqs.prefer_local);
        assert_eq!(reqs.max_latency_ms, 1000);
    }

    #[test]
    fn with_modality_does_not_duplicate() {
        // `embedding()` already requires text, so adding it again is a no-op.
        let reqs = CapabilityRequirements::embedding().with_modality(Modality::Text);
        assert_eq!(reqs.modalities.len(), 1);
    }

    #[test]
    fn empty_registry_has_no_web_search() {
        let registry = CapabilityRegistry::new();

        assert!(!registry.has_web_search());
        assert!(registry
            .select_search_provider(&WebSearchRequirements::web_search())
            .is_none());
    }

    #[test]
    fn select_search_provider_filters_and_scores() {
        let mut registry = CapabilityRegistry::new();
        registry
            .search_providers
            .insert("fast".to_string(), search_meta("fast", 200));

        let mut news = search_meta("news", 5_000);
        news.supports_news = true;
        registry.search_providers.insert("news".to_string(), news);

        let mut offline = search_meta("offline", 100);
        offline.available = false;
        offline.supports_ai_summary = true;
        registry.search_providers.insert("offline".to_string(), offline);

        assert!(registry.has_web_search());

        // Default requirements: "fast" scores -2 and "news" scores 20 - 50 = -30.
        let picked = registry
            .select_search_provider(&WebSearchRequirements::web_search())
            .expect("a provider should match default requirements");
        assert_eq!(picked.name, "fast");

        // Only "news" supports news search.
        let picked = registry
            .select_search_provider(&WebSearchRequirements::news())
            .expect("the news provider should match");
        assert_eq!(picked.name, "news");

        // "news" is too slow for a 1 s budget.
        assert!(registry
            .select_search_provider(&WebSearchRequirements::news().with_max_latency_ms(1_000))
            .is_none());

        // The only AI-summary provider is unavailable.
        assert!(registry
            .select_search_provider(&WebSearchRequirements::grounded())
            .is_none());
    }
}