#[cfg(feature = "brave")]
use crate::brave::BraveSearchProvider;
#[cfg(feature = "feed")]
use crate::feed::{FeedFetchBackend, HttpFeedProvider};
#[cfg(feature = "_http")]
use crate::fetch::HttpFetchProvider;
use crate::search::{WebFetchBackend, WebSearchBackend};
#[cfg(feature = "tavily")]
use crate::tavily::TavilySearchProvider;
use converge_core::capability::{
CapabilityKind, CapabilityMetadata, Embedding, GraphRecall, Modality, Reranking, VectorRecall,
};
use converge_provider_api::selection::DataSovereignty;
use std::collections::HashMap;
use std::sync::Arc;
/// Constraints used when picking a capability provider from the registry.
///
/// Start from a preset ([`Self::completion`], [`Self::embedding`], [`Self::reranking`],
/// [`Self::vector_recall`], [`Self::graph_recall`]) and refine it with the builder methods.
#[derive(Debug, Clone)]
pub struct CapabilityRequirements {
    /// Capability the provider must advertise.
    pub capability: CapabilityKind,
    /// Every modality listed here must be supported by the provider.
    pub modalities: Vec<Modality>,
    /// When `true`, local providers get a scoring bonus (locality is preferred, not required).
    pub prefer_local: bool,
    /// Requested data-sovereignty policy.
    /// NOTE(review): the registry's provider matching does not currently enforce this.
    pub data_sovereignty: DataSovereignty,
    /// Providers whose typical latency exceeds this many milliseconds are rejected.
    pub max_latency_ms: u32,
}

impl CapabilityRequirements {
    /// Common constructor behind every preset; sovereignty always starts as `Any`.
    fn preset(
        capability: CapabilityKind,
        modalities: Vec<Modality>,
        prefer_local: bool,
        max_latency_ms: u32,
    ) -> Self {
        Self {
            capability,
            modalities,
            prefer_local,
            data_sovereignty: DataSovereignty::Any,
            max_latency_ms,
        }
    }

    /// Text completion, no local preference, 30 s latency budget.
    #[must_use]
    pub fn completion() -> Self {
        Self::preset(CapabilityKind::Completion, vec![Modality::Text], false, 30_000)
    }

    /// Text embedding, no local preference, 5 s latency budget.
    #[must_use]
    pub fn embedding() -> Self {
        Self::preset(CapabilityKind::Embedding, vec![Modality::Text], false, 5_000)
    }

    /// Text reranking, no local preference, 5 s latency budget.
    #[must_use]
    pub fn reranking() -> Self {
        Self::preset(CapabilityKind::Reranking, vec![Modality::Text], false, 5_000)
    }

    /// Vector recall with no modality constraint, preferring local stores, 100 ms budget.
    #[must_use]
    pub fn vector_recall() -> Self {
        Self::preset(CapabilityKind::VectorRecall, Vec::new(), true, 100)
    }

    /// Graph recall with no modality constraint, preferring local stores, 100 ms budget.
    #[must_use]
    pub fn graph_recall() -> Self {
        Self::preset(CapabilityKind::GraphRecall, Vec::new(), true, 100)
    }

    /// Adds `modality` to the required set; a modality already present is not duplicated.
    #[must_use]
    pub fn with_modality(mut self, modality: Modality) -> Self {
        if self.modalities.contains(&modality) {
            return self;
        }
        self.modalities.push(modality);
        self
    }

    /// Sets whether local providers should be favoured when scoring.
    #[must_use]
    pub fn prefer_local(self, prefer: bool) -> Self {
        Self {
            prefer_local: prefer,
            ..self
        }
    }

    /// Replaces the data-sovereignty policy.
    #[must_use]
    pub fn with_data_sovereignty(self, sovereignty: DataSovereignty) -> Self {
        Self {
            data_sovereignty: sovereignty,
            ..self
        }
    }

    /// Replaces the maximum acceptable typical latency, in milliseconds.
    #[must_use]
    pub fn with_max_latency_ms(self, ms: u32) -> Self {
        Self {
            max_latency_ms: ms,
            ..self
        }
    }
}
/// A named provider entry: the role implementations registered under one name, plus the
/// metadata used to match and score it (capability lists are merged across registrations).
struct RegisteredProvider {
    /// Implementation returned by `select_embedder`, if one was registered.
    embedder: Option<Arc<dyn Embedding>>,
    /// Implementation returned by `select_reranker`, if one was registered.
    reranker: Option<Arc<dyn Reranking>>,
    /// Capabilities, modalities, locality and latency used for matching and scoring.
    metadata: CapabilityMetadata,
}
/// Static description of a registered web search provider.
///
/// `CapabilityRegistry::select_search_provider` uses these fields to filter and rank
/// candidates; `name` is also the key used to look up the provider's backend.
#[derive(Debug, Clone)]
pub struct SearchProviderMeta {
/// Registry key of the provider (e.g. `"brave"`, `"tavily"`).
pub name: String,
/// Providers with `available == false` are never selected.
pub available: bool,
/// Typical request latency in milliseconds; compared against the requested latency budget
/// and also lowers the ranking score (one point per 100 ms).
pub typical_latency_ms: u32,
/// Whether the provider satisfies `WebSearchRequirements::requires_ai_summary`.
pub supports_ai_summary: bool,
/// Whether the provider satisfies `WebSearchRequirements::requires_news`.
pub supports_news: bool,
/// Whether the provider satisfies `WebSearchRequirements::requires_images`.
pub supports_images: bool,
/// Whether the provider satisfies `WebSearchRequirements::requires_local`.
pub supports_local: bool,
}
/// Constraints used to pick a web search provider via
/// `CapabilityRegistry::select_search_provider`.
#[derive(Debug, Clone)]
pub struct WebSearchRequirements {
    /// Providers whose typical latency exceeds this many milliseconds are rejected.
    pub max_latency_ms: u32,
    /// Only providers with `supports_ai_summary` qualify when set.
    pub requires_ai_summary: bool,
    /// Only providers with `supports_news` qualify when set.
    pub requires_news: bool,
    /// Only providers with `supports_images` qualify when set.
    pub requires_images: bool,
    /// Only providers with `supports_local` qualify when set.
    pub requires_local: bool,
    /// Requested data-sovereignty policy.
    /// NOTE(review): not consulted by `select_search_provider`.
    pub data_sovereignty: DataSovereignty,
}

impl Default for WebSearchRequirements {
    /// No feature requirements, any sovereignty, and a 10 s latency budget.
    fn default() -> Self {
        Self {
            data_sovereignty: DataSovereignty::Any,
            requires_local: false,
            requires_images: false,
            requires_news: false,
            requires_ai_summary: false,
            max_latency_ms: 10_000,
        }
    }
}

impl WebSearchRequirements {
    /// Plain web search: exactly the [`Default`] requirements.
    #[must_use]
    pub fn web_search() -> Self {
        Self::default()
    }

    /// Grounded search: requires an AI summary and allows up to 15 s.
    #[must_use]
    pub fn grounded() -> Self {
        Self::default()
            .with_ai_summary(true)
            .with_max_latency_ms(15_000)
    }

    /// News search: requires news support, otherwise default requirements.
    #[must_use]
    pub fn news() -> Self {
        Self::default().with_news(true)
    }

    /// Replaces the latency budget, in milliseconds.
    #[must_use]
    pub fn with_max_latency_ms(self, ms: u32) -> Self {
        Self {
            max_latency_ms: ms,
            ..self
        }
    }

    /// Sets whether an AI summary is required.
    #[must_use]
    pub fn with_ai_summary(self, required: bool) -> Self {
        Self {
            requires_ai_summary: required,
            ..self
        }
    }

    /// Replaces the data-sovereignty policy.
    #[must_use]
    pub fn with_data_sovereignty(self, sovereignty: DataSovereignty) -> Self {
        Self {
            data_sovereignty: sovereignty,
            ..self
        }
    }

    /// Sets whether news support is required.
    #[must_use]
    pub fn with_news(self, required: bool) -> Self {
        Self {
            requires_news: required,
            ..self
        }
    }

    /// Sets whether image support is required.
    #[must_use]
    pub fn with_images(self, required: bool) -> Self {
        Self {
            requires_images: required,
            ..self
        }
    }

    /// Sets whether local-search support is required.
    #[must_use]
    pub fn with_local(self, required: bool) -> Self {
        Self {
            requires_local: required,
            ..self
        }
    }
}
/// Central lookup for capability providers.
///
/// Holds embedders/rerankers keyed by provider name, named vector and graph stores, web
/// search providers (selection metadata plus backend), and optional fetch/feed backends.
pub struct CapabilityRegistry {
    /// Embedding/reranking providers keyed by provider name.
    providers: HashMap<String, RegisteredProvider>,
    /// Vector stores keyed by name; `"default"` is returned by `default_vector_store`.
    vector_stores: HashMap<String, Arc<dyn VectorRecall>>,
    /// Graph stores keyed by name; `"default"` is returned by `default_graph_store`.
    graph_stores: HashMap<String, Arc<dyn GraphRecall>>,
    /// Selection metadata for web search providers, keyed by provider name.
    search_providers: HashMap<String, SearchProviderMeta>,
    /// Web search backends, keyed by the same names as `search_providers`.
    search_backends: HashMap<String, Arc<dyn WebSearchBackend>>,
    /// Backend for fetching web pages, if one is configured.
    fetch_backend: Option<Arc<dyn WebFetchBackend>>,
    /// Typed Brave handle, kept so `brave()` can return the concrete provider.
    #[cfg(feature = "brave")]
    brave_provider: Option<Arc<BraveSearchProvider>>,
    /// Typed Tavily handle, kept so `tavily()` can return the concrete provider.
    #[cfg(feature = "tavily")]
    tavily_provider: Option<Arc<TavilySearchProvider>>,
    /// Backend for fetching feeds, if one is configured.
    #[cfg(feature = "feed")]
    feed_backend: Option<Arc<dyn FeedFetchBackend>>,
}

impl Default for CapabilityRegistry {
    /// Same as [`CapabilityRegistry::new`]: an empty registry.
    fn default() -> Self {
        CapabilityRegistry::new()
    }
}
impl CapabilityRegistry {
/// Creates an empty registry: no providers, stores, search backends, or fetch backends.
#[must_use]
pub fn new() -> Self {
    Self {
        providers: HashMap::default(),
        vector_stores: HashMap::default(),
        graph_stores: HashMap::default(),
        search_providers: HashMap::default(),
        search_backends: HashMap::default(),
        fetch_backend: None,
        #[cfg(feature = "brave")]
        brave_provider: None,
        #[cfg(feature = "tavily")]
        tavily_provider: None,
        #[cfg(feature = "feed")]
        feed_backend: None,
    }
}

/// Builds a registry pre-populated with everything that needs no explicit configuration.
///
/// Specifically:
/// - an in-memory vector store registered as `"default"`;
/// - Brave and/or Tavily web search, each only if its `from_env` constructor succeeds;
/// - an HTTP fetch backend (feature `_http`) and an HTTP feed backend (feature `feed`).
///
/// No graph store is registered.
#[must_use]
pub fn with_local_defaults() -> Self {
    let mut registry = Self::new();
    registry.add_vector_store(
        "default",
        Arc::new(crate::vector::InMemoryVectorStore::new()),
    );
    // Best-effort probes: when credentials are missing the provider is simply not registered.
    let _ = registry.try_add_brave_from_env();
    let _ = registry.try_add_tavily_from_env();
    #[cfg(feature = "_http")]
    registry.set_fetch_backend(Arc::new(HttpFetchProvider::new()));
    #[cfg(feature = "feed")]
    registry.set_feed_backend(Arc::new(HttpFeedProvider::new()));
    registry
}
/// Registers Brave Search using credentials from the environment, if available.
///
/// Returns `true` when `BraveSearchProvider::from_env()` succeeded and the provider was
/// registered under `"brave"`. The construction error is deliberately discarded: this is a
/// best-effort probe (used by `with_local_defaults`). Always returns `false` when the
/// `brave` feature is disabled.
pub fn try_add_brave_from_env(&mut self) -> bool {
    #[cfg(feature = "brave")]
    if let Ok(provider) = BraveSearchProvider::from_env() {
        self.register_brave(provider);
        return true;
    }
    false
}

/// Registers Tavily search using credentials from the environment, if available.
///
/// Returns `true` when `TavilySearchProvider::from_env()` succeeded and the provider was
/// registered under `"tavily"`. The construction error is deliberately discarded: this is a
/// best-effort probe (used by `with_local_defaults`). Always returns `false` when the
/// `tavily` feature is disabled.
pub fn try_add_tavily_from_env(&mut self) -> bool {
    #[cfg(feature = "tavily")]
    if let Ok(provider) = TavilySearchProvider::from_env() {
        self.register_tavily(provider);
        return true;
    }
    false
}

/// Registers Brave Search with an explicit API key, replacing any existing `"brave"` entry.
#[cfg(feature = "brave")]
pub fn add_brave(&mut self, api_key: impl Into<String>) {
    self.register_brave(BraveSearchProvider::new(api_key));
}

/// Registers Tavily search with an explicit API key, replacing any existing `"tavily"` entry.
#[cfg(feature = "tavily")]
pub fn add_tavily(&mut self, api_key: impl Into<String>) {
    self.register_tavily(TavilySearchProvider::new(api_key));
}

/// Wires a Brave provider into the registry.
///
/// Stores the typed handle and the backend, and records Brave's selection metadata. This is
/// the single place that metadata is written, so the env-based and explicit-key paths cannot
/// drift apart.
#[cfg(feature = "brave")]
fn register_brave(&mut self, provider: BraveSearchProvider) {
    let provider = Arc::new(provider);
    self.brave_provider = Some(Arc::clone(&provider));
    self.register_search_backend(
        provider,
        SearchProviderMeta {
            name: "brave".to_string(),
            available: true,
            typical_latency_ms: 500,
            supports_ai_summary: false,
            supports_news: true,
            supports_images: true,
            supports_local: true,
        },
    );
}

/// Wires a Tavily provider into the registry.
///
/// Stores the typed handle and the backend, and records Tavily's selection metadata. This is
/// the single place that metadata is written.
#[cfg(feature = "tavily")]
fn register_tavily(&mut self, provider: TavilySearchProvider) {
    let provider = Arc::new(provider);
    self.tavily_provider = Some(Arc::clone(&provider));
    self.register_search_backend(
        provider,
        SearchProviderMeta {
            name: "tavily".to_string(),
            available: true,
            typical_latency_ms: 1200,
            supports_ai_summary: true,
            supports_news: true,
            supports_images: true,
            supports_local: false,
        },
    );
}

/// Stores `backend` and its selection metadata under `meta.name`.
///
/// Any previous registration with the same name is replaced.
#[cfg(any(feature = "brave", feature = "tavily"))]
fn register_search_backend(&mut self, backend: Arc<dyn WebSearchBackend>, meta: SearchProviderMeta) {
    self.search_backends.insert(meta.name.clone(), backend);
    self.search_providers.insert(meta.name.clone(), meta);
}
/// Typed handle to the registered Brave provider, or `None` if Brave is not registered.
#[cfg(feature = "brave")]
#[must_use]
pub fn brave(&self) -> Option<&BraveSearchProvider> {
    self.brave_provider.as_deref()
}

/// Typed handle to the registered Tavily provider, or `None` if Tavily is not registered.
#[cfg(feature = "tavily")]
#[must_use]
pub fn tavily(&self) -> Option<&TavilySearchProvider> {
    self.tavily_provider.as_deref()
}

/// Installs (or replaces) the web page fetch backend.
pub fn set_fetch_backend(&mut self, backend: Arc<dyn WebFetchBackend>) {
    self.fetch_backend = Some(backend);
}

/// `true` when a web page fetch backend is configured.
#[must_use]
pub fn has_web_fetch(&self) -> bool {
    self.fetch_backend.is_some()
}

/// Shared handle to the web page fetch backend, if configured.
#[must_use]
pub fn fetch_backend(&self) -> Option<Arc<dyn WebFetchBackend>> {
    let backend = self.fetch_backend.clone();
    backend
}

/// Installs (or replaces) the feed fetch backend.
#[cfg(feature = "feed")]
pub fn set_feed_backend(&mut self, backend: Arc<dyn FeedFetchBackend>) {
    self.feed_backend = Some(backend);
}

/// `true` when a feed fetch backend is configured.
#[cfg(feature = "feed")]
#[must_use]
pub fn has_feed_fetch(&self) -> bool {
    self.feed_backend.is_some()
}

/// Shared handle to the feed fetch backend, if configured.
#[cfg(feature = "feed")]
#[must_use]
pub fn feed_backend(&self) -> Option<Arc<dyn FeedFetchBackend>> {
    let backend = self.feed_backend.clone();
    backend
}
#[must_use]
pub fn has_web_search(&self) -> bool {
!self.search_providers.is_empty()
}
#[must_use]
pub fn search_providers(&self) -> Vec<&SearchProviderMeta> {
self.search_providers.values().collect()
}
#[must_use]
pub fn select_search_provider(
&self,
requirements: &WebSearchRequirements,
) -> Option<&SearchProviderMeta> {
self.search_providers
.values()
.filter(|p| {
if !p.available || p.typical_latency_ms > requirements.max_latency_ms {
return false;
}
if requirements.requires_ai_summary && !p.supports_ai_summary {
return false;
}
if requirements.requires_news && !p.supports_news {
return false;
}
if requirements.requires_images && !p.supports_images {
return false;
}
if requirements.requires_local && !p.supports_local {
return false;
}
true
})
.max_by_key(|p| {
let mut score = 0i32;
if requirements.requires_ai_summary && p.supports_ai_summary {
score += 100;
}
if requirements.requires_news && p.supports_news {
score += 30;
}
if requirements.requires_images && p.supports_images {
score += 30;
}
if requirements.requires_local && p.supports_local {
score += 20;
}
score -= (p.typical_latency_ms / 100) as i32;
score
})
}
#[must_use]
pub fn select_search_backend(
&self,
requirements: &WebSearchRequirements,
) -> Option<Arc<dyn WebSearchBackend>> {
self.select_search_provider(requirements)
.and_then(|meta| self.search_backends.get(&meta.name).cloned())
}
#[allow(clippy::needless_pass_by_value)]
pub fn add_embedder(
&mut self,
name: &str,
provider: Arc<dyn Embedding>,
metadata: CapabilityMetadata,
) {
let entry = self
.providers
.entry(name.to_string())
.or_insert_with(|| RegisteredProvider {
metadata: metadata.clone(),
embedder: None,
reranker: None,
});
entry.embedder = Some(provider);
for cap in &metadata.capabilities {
if !entry.metadata.capabilities.contains(cap) {
entry.metadata.capabilities.push(*cap);
}
}
}
#[allow(clippy::needless_pass_by_value)]
pub fn add_reranker(
&mut self,
name: &str,
provider: Arc<dyn Reranking>,
metadata: CapabilityMetadata,
) {
let entry = self
.providers
.entry(name.to_string())
.or_insert_with(|| RegisteredProvider {
metadata: metadata.clone(),
embedder: None,
reranker: None,
});
entry.reranker = Some(provider);
for cap in &metadata.capabilities {
if !entry.metadata.capabilities.contains(cap) {
entry.metadata.capabilities.push(*cap);
}
}
}
pub fn add_vector_store(&mut self, name: &str, store: Arc<dyn VectorRecall>) {
self.vector_stores.insert(name.to_string(), store);
}
pub fn add_graph_store(&mut self, name: &str, store: Arc<dyn GraphRecall>) {
self.graph_stores.insert(name.to_string(), store);
}
#[must_use]
pub fn select_embedder(
&self,
requirements: &CapabilityRequirements,
) -> Option<Arc<dyn Embedding>> {
self.providers
.values()
.filter(|p| {
p.embedder.is_some() && self.matches_requirements(&p.metadata, requirements)
})
.max_by_key(|p| self.score_provider(&p.metadata, requirements))
.and_then(|p| p.embedder.clone())
}
#[must_use]
pub fn select_reranker(
&self,
requirements: &CapabilityRequirements,
) -> Option<Arc<dyn Reranking>> {
self.providers
.values()
.filter(|p| {
p.reranker.is_some() && self.matches_requirements(&p.metadata, requirements)
})
.max_by_key(|p| self.score_provider(&p.metadata, requirements))
.and_then(|p| p.reranker.clone())
}
#[must_use]
pub fn get_vector_store(&self, name: &str) -> Option<Arc<dyn VectorRecall>> {
self.vector_stores.get(name).cloned()
}
#[must_use]
pub fn get_graph_store(&self, name: &str) -> Option<Arc<dyn GraphRecall>> {
self.graph_stores.get(name).cloned()
}
#[must_use]
pub fn default_vector_store(&self) -> Option<Arc<dyn VectorRecall>> {
self.get_vector_store("default")
}
#[must_use]
pub fn default_graph_store(&self) -> Option<Arc<dyn GraphRecall>> {
self.get_graph_store("default")
}
#[must_use]
pub fn provider_names(&self) -> Vec<&str> {
self.providers.keys().map(String::as_str).collect()
}
#[must_use]
pub fn vector_store_names(&self) -> Vec<&str> {
self.vector_stores.keys().map(String::as_str).collect()
}
#[must_use]
pub fn graph_store_names(&self) -> Vec<&str> {
self.graph_stores.keys().map(String::as_str).collect()
}
#[allow(clippy::unused_self)]
fn matches_requirements(
&self,
metadata: &CapabilityMetadata,
requirements: &CapabilityRequirements,
) -> bool {
if !metadata.capabilities.contains(&requirements.capability) {
return false;
}
for modality in &requirements.modalities {
if !metadata.modalities.contains(modality) {
return false;
}
}
#[allow(clippy::match_same_arms)]
match (&requirements.data_sovereignty, metadata.is_local) {
(DataSovereignty::Any, _) | (_, true) => {} _ => {} }
if metadata.typical_latency_ms > requirements.max_latency_ms {
return false;
}
true
}
#[allow(clippy::unused_self, clippy::cast_possible_wrap)]
fn score_provider(
&self,
metadata: &CapabilityMetadata,
requirements: &CapabilityRequirements,
) -> i32 {
let mut score = 0;
if requirements.prefer_local && metadata.is_local {
score += 100;
}
if metadata.typical_latency_ms < requirements.max_latency_ms / 2 {
score += 50;
}
score += (metadata.modalities.len() * 10) as i32;
score
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::InMemoryVectorStore;
    use converge_core::capability::{
        CapabilityError, GraphEdge, GraphNode, GraphQuery, GraphRecall, GraphResult,
    };

    /// Graph store stub: every write succeeds and every query returns nothing.
    struct TestGraphStore;

    impl TestGraphStore {
        fn new() -> Self {
            TestGraphStore
        }
    }

    impl GraphRecall for TestGraphStore {
        fn name(&self) -> &str {
            "test-graph"
        }

        fn add_node(&self, _node: &GraphNode) -> Result<(), CapabilityError> {
            Ok(())
        }

        fn add_edge(&self, _edge: &GraphEdge) -> Result<(), CapabilityError> {
            Ok(())
        }

        fn traverse(&self, _query: &GraphQuery) -> Result<GraphResult, CapabilityError> {
            let empty = GraphResult {
                nodes: vec![],
                edges: vec![],
            };
            Ok(empty)
        }

        fn find_nodes(
            &self,
            _label: &str,
            _properties: Option<&serde_json::Value>,
        ) -> Result<Vec<GraphNode>, CapabilityError> {
            Ok(vec![])
        }

        fn get_node(&self, _id: &str) -> Result<Option<GraphNode>, CapabilityError> {
            Ok(None)
        }

        fn delete_node(&self, _id: &str) -> Result<(), CapabilityError> {
            Ok(())
        }

        fn clear(&self) -> Result<(), CapabilityError> {
            Ok(())
        }
    }

    #[test]
    fn registry_with_local_defaults() {
        let registry = CapabilityRegistry::with_local_defaults();
        assert!(
            registry.default_vector_store().is_some(),
            "local defaults must include a default vector store"
        );
        assert!(
            registry.default_graph_store().is_none(),
            "local defaults must not register a graph store"
        );
    }

    #[test]
    fn add_and_get_stores() {
        let mut registry = CapabilityRegistry::new();
        registry.add_vector_store("test", Arc::new(InMemoryVectorStore::new()));
        registry.add_graph_store("test", Arc::new(TestGraphStore::new()));

        assert!(registry.get_vector_store("test").is_some());
        assert!(registry.get_graph_store("test").is_some());
        assert!(registry.get_vector_store("missing").is_none());
    }

    #[test]
    fn list_registered_stores() {
        let registry = CapabilityRegistry::with_local_defaults();
        assert!(registry.vector_store_names().contains(&"default"));
        assert!(registry.graph_store_names().is_empty());
    }

    #[test]
    fn capability_requirements_builder() {
        let requirements = CapabilityRequirements::embedding()
            .with_modality(Modality::Image)
            .prefer_local(true)
            .with_max_latency_ms(1000);

        assert_eq!(requirements.capability, CapabilityKind::Embedding);
        assert!(requirements.modalities.contains(&Modality::Image));
        assert!(requirements.prefer_local);
        assert_eq!(requirements.max_latency_ms, 1000);
    }

    #[cfg(all(feature = "brave", feature = "tavily"))]
    #[test]
    fn search_provider_selection_prefers_tavily_for_ai_summary() {
        let mut registry = CapabilityRegistry::new();
        registry.add_brave("brave-key");
        registry.add_tavily("tavily-key");
        let grounded = WebSearchRequirements::grounded();

        let chosen = registry
            .select_search_provider(&grounded)
            .expect("a provider supporting AI summaries should be selected");
        assert_eq!(chosen.name, "tavily");

        let backend = registry
            .select_search_backend(&grounded)
            .expect("the selected provider should have a backend");
        assert_eq!(backend.provider_name(), "tavily");
    }

    #[cfg(all(feature = "brave", feature = "tavily"))]
    #[test]
    fn search_provider_selection_prefers_brave_for_local_search() {
        let mut registry = CapabilityRegistry::new();
        registry.add_brave("brave-key");
        registry.add_tavily("tavily-key");
        let local = WebSearchRequirements::web_search().with_local(true);

        let chosen = registry
            .select_search_provider(&local)
            .expect("a provider supporting local search should be selected");
        assert_eq!(chosen.name, "brave");
    }
}