Skip to main content

hirn_engine/
provider_registry.rs

1//! Config-driven provider registry for AI traits.
2//!
3//! The [`ProviderRegistry`] holds named instances of the four core AI
4//! traits — [`Embedder`], [`Tokenizer`], [`Reranker`], and [`LlmProvider`]
5//! — and exposes a default + by-name lookup pattern.
6//!
7//! Providers can be configured via:
8//! - **Programmatic API**: `register_embedder()`, `register_llm()`, etc.
9//! - **Environment variables**: `from_env()` auto-discovers from env vars
10//! - **TOML configuration**: `from_config()` / `from_toml()` for config-driven setup
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use parking_lot::RwLock;
16
17use hirn_core::embed::{Embedder, LlmProvider, Reranker};
18use hirn_core::tokenizer::{EstimatingTokenizer, Tokenizer};
19use hirn_core::{HirnError, HirnResult};
20
/// Name of the currently selected default provider for each category.
///
/// A `None` entry means no default has been chosen for that category.
#[derive(Clone, Debug, Default)]
pub struct ProviderDefaults {
    /// Registered name of the default embedder.
    pub embedder: Option<String>,
    /// Registered name of the default tokenizer.
    pub tokenizer: Option<String>,
    /// Registered name of the default reranker.
    pub reranker: Option<String>,
    /// Registered name of the default LLM provider.
    pub llm: Option<String>,
}
29
30// ── TOML configuration types ─────────────────────────────────────────────
31
32/// How an API key is specified in the TOML config.
33///
34/// Supports either a literal string or an environment variable reference:
35///
36/// ```toml
37/// api_key = "sk-literal-key"          # literal
38/// api_key = { env = "OPENAI_API_KEY" } # env var reference
39/// ```
40#[derive(Debug, Clone, serde::Deserialize, PartialEq)]
41#[serde(untagged)]
42pub enum ApiKeySource {
43    /// Reference to an environment variable.
44    Env {
45        /// Name of the environment variable.
46        env: String,
47    },
48    /// A literal API key string.
49    Literal(String),
50}
51
52impl ApiKeySource {
53    /// Resolve the API key to a string value.
54    ///
55    /// For `Env` variants, the environment variable is read.
56    /// Returns an error if the variable is not set.
57    pub fn resolve(&self) -> HirnResult<String> {
58        match self {
59            Self::Literal(key) => Ok(key.clone()),
60            Self::Env { env } => std::env::var(env).map_err(|_| {
61                HirnError::config(format!(
62                    "environment variable '{env}' not set (required by provider config)"
63                ))
64            }),
65        }
66    }
67}
68
69/// Configuration for a single embedder provider.
70#[derive(Debug, Clone, serde::Deserialize)]
71pub struct EmbedderConfig {
72    /// Provider type: `"openai"`, `"ollama"`, `"pseudo"`.
73    pub r#type: String,
74    /// Model name (e.g. `"text-embedding-3-small"`).
75    pub model: Option<String>,
76    /// Embedding dimensions.
77    pub dimensions: Option<usize>,
78    /// API key (for remote providers).
79    pub api_key: Option<ApiKeySource>,
80    /// Base URL override.
81    pub base_url: Option<String>,
82}
83
84/// Configuration for a single LLM provider.
85#[derive(Debug, Clone, serde::Deserialize)]
86pub struct LlmConfig {
87    /// Provider type: `"openai"`, `"ollama"`, `"anthropic"`, `"mock"`.
88    pub r#type: String,
89    /// Model name.
90    pub model: Option<String>,
91    /// API key (for remote providers).
92    pub api_key: Option<ApiKeySource>,
93    /// Base URL override.
94    pub base_url: Option<String>,
95}
96
97/// Configuration for a single reranker provider.
98#[derive(Debug, Clone, serde::Deserialize)]
99pub struct RerankerConfig {
100    /// Provider type: `"cohere"`, `"cross-encoder"`, `"noop"`.
101    pub r#type: String,
102    /// Model name.
103    pub model: Option<String>,
104    /// API key (for remote providers).
105    pub api_key: Option<ApiKeySource>,
106    /// Base URL override.
107    pub base_url: Option<String>,
108}
109
110/// Configuration for a single tokenizer provider.
111#[derive(Debug, Clone, serde::Deserialize)]
112pub struct TokenizerConfig {
113    /// Provider type: `"tiktoken"`, `"huggingface"`, `"estimating"`.
114    pub r#type: String,
115    /// Model name or identifier.
116    pub model: Option<String>,
117    /// Maximum token length.
118    pub max_length: Option<usize>,
119}
120
121/// Which provider name to use as default for each category.
122#[derive(Debug, Clone, Default, serde::Deserialize)]
123pub struct DefaultsConfig {
124    pub embedder: Option<String>,
125    pub tokenizer: Option<String>,
126    pub reranker: Option<String>,
127    pub llm: Option<String>,
128}
129
130/// Top-level provider configuration, TOML-deserializable.
131///
132/// # Example
133///
134/// ```toml
135/// [providers.embedder.openai]
136/// type = "openai"
137/// model = "text-embedding-3-small"
138/// api_key = { env = "OPENAI_API_KEY" }
139/// dimensions = 1536
140///
141/// [providers.llm.claude]
142/// type = "anthropic"
143/// model = "claude-sonnet-4-20250514"
144/// api_key = { env = "ANTHROPIC_API_KEY" }
145///
146/// [providers.reranker.cohere]
147/// type = "cohere"
148/// model = "rerank-v3.5"
149/// api_key = { env = "COHERE_API_KEY" }
150///
151/// [providers.tokenizer.default]
152/// type = "estimating"
153///
154/// [defaults]
155/// embedder = "openai"
156/// llm = "claude"
157/// reranker = "cohere"
158/// tokenizer = "default"
159/// ```
160#[derive(Debug, Clone, Default, serde::Deserialize)]
161pub struct ProviderConfig {
162    /// Provider definitions grouped by category.
163    #[serde(default)]
164    pub providers: ProvidersSection,
165    /// Which provider name to use as the default for each category.
166    #[serde(default)]
167    pub defaults: DefaultsConfig,
168}
169
170/// The `[providers]` section: maps category → name → config.
171#[derive(Debug, Clone, Default, serde::Deserialize)]
172pub struct ProvidersSection {
173    #[serde(default)]
174    pub embedder: HashMap<String, EmbedderConfig>,
175    #[serde(default)]
176    pub llm: HashMap<String, LlmConfig>,
177    #[serde(default)]
178    pub reranker: HashMap<String, RerankerConfig>,
179    #[serde(default)]
180    pub tokenizer: HashMap<String, TokenizerConfig>,
181}
182
183/// Central registry for AI providers, supporting runtime hot-swap.
184///
185/// Thread-safe: all state is behind `RwLock`/`Arc` so the registry can be
186/// shared across `tokio` tasks.
187///
188/// # Example
189///
190/// ```rust
191/// use hirn_engine::ProviderRegistry;
192/// use hirn_provider::PseudoEmbedder;
193/// use std::sync::Arc;
194///
195/// let mut reg = ProviderRegistry::new();
196/// reg.register_embedder("pseudo", Arc::new(PseudoEmbedder::new(128)));
197/// reg.set_default_embedder("pseudo").unwrap();
198/// assert!(reg.embedder().is_some());
199/// ```
200pub struct ProviderRegistry {
201    embedders: RwLock<HashMap<String, Arc<dyn Embedder>>>,
202    tokenizers: RwLock<HashMap<String, Arc<dyn Tokenizer>>>,
203    rerankers: RwLock<HashMap<String, Arc<dyn Reranker>>>,
204    llms: RwLock<HashMap<String, Arc<dyn LlmProvider>>>,
205    defaults: RwLock<ProviderDefaults>,
206}
207
208impl std::fmt::Debug for ProviderRegistry {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        let defaults = self.defaults.read();
211        f.debug_struct("ProviderRegistry")
212            .field(
213                "embedders",
214                &self.embedders.read().keys().collect::<Vec<_>>(),
215            )
216            .field(
217                "tokenizers",
218                &self.tokenizers.read().keys().collect::<Vec<_>>(),
219            )
220            .field(
221                "rerankers",
222                &self.rerankers.read().keys().collect::<Vec<_>>(),
223            )
224            .field("llms", &self.llms.read().keys().collect::<Vec<_>>())
225            .field("defaults", &*defaults)
226            .finish()
227    }
228}
229
230impl ProviderRegistry {
231    /// Create an empty registry.
232    pub fn new() -> Self {
233        Self {
234            embedders: RwLock::new(HashMap::new()),
235            tokenizers: RwLock::new(HashMap::new()),
236            rerankers: RwLock::new(HashMap::new()),
237            llms: RwLock::new(HashMap::new()),
238            defaults: RwLock::new(ProviderDefaults::default()),
239        }
240    }
241
242    fn with_fallbacks() -> Self {
243        let reg = Self::new();
244
245        reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(384)));
246        reg.register_tokenizer("estimating", Arc::new(EstimatingTokenizer));
247        reg.register_reranker("noop", Arc::new(hirn_core::embed::NoopReranker));
248        reg.register_llm(
249            "mock",
250            Arc::new(hirn_provider::MockLlmProvider::new("mock")),
251        );
252
253        let _ = reg.set_default_embedder("pseudo");
254        let _ = reg.set_default_tokenizer("estimating");
255        let _ = reg.set_default_reranker("noop");
256        let _ = reg.set_default_llm("mock");
257
258        #[cfg(feature = "tiktoken")]
259        if let Ok(tokenizer) = hirn_provider::build_tokenizer("tiktoken", Some("cl100k_base"), None)
260        {
261            reg.register_tokenizer("tiktoken", tokenizer);
262            let _ = reg.set_default_tokenizer("tiktoken");
263        }
264
265        reg
266    }
267
268    #[allow(dead_code)]
269    fn default_embedder_is_unset_or_fallback(&self) -> bool {
270        self.defaults
271            .read()
272            .embedder
273            .as_deref()
274            .is_none_or(|name| name == "pseudo")
275    }
276
277    #[allow(dead_code)]
278    fn default_reranker_is_unset_or_fallback(&self) -> bool {
279        self.defaults
280            .read()
281            .reranker
282            .as_deref()
283            .is_none_or(|name| name == "noop")
284    }
285
286    #[allow(dead_code)]
287    fn default_llm_is_unset_or_fallback(&self) -> bool {
288        self.defaults
289            .read()
290            .llm
291            .as_deref()
292            .is_none_or(|name| name == "mock")
293    }
294
295    #[allow(unused_variables)]
296    fn populate_from_env(reg: &Self) {
297        // Override with real providers based on env vars.
298        #[cfg(feature = "openai")]
299        if let Ok(key) = std::env::var("OPENAI_API_KEY") {
300            Self::register_openai_from_key(
301                reg,
302                key,
303                |api_key| {
304                    hirn_provider::OpenAIEmbedder::new(api_key, "text-embedding-3-small", 1536)
305                        .map(|embedder| Arc::new(embedder) as Arc<dyn Embedder>)
306                },
307                |api_key| {
308                    hirn_provider::OpenAILlmProvider::new(api_key, "gpt-4o-mini")
309                        .map(|provider| Arc::new(provider) as Arc<dyn LlmProvider>)
310                },
311            );
312        }
313
314        #[cfg(feature = "ollama")]
315        {
316            let host = std::env::var("OLLAMA_HOST")
317                .unwrap_or_else(|_| "http://localhost:11434".to_owned());
318            if std::env::var("OLLAMA_HOST").is_ok() {
319                match hirn_provider::OllamaEmbedder::new("nomic-embed-text", 768) {
320                    Ok(embedder) => match embedder.with_host(&host) {
321                        Ok(embedder) => {
322                            reg.register_embedder("ollama", Arc::new(embedder));
323                            if reg.defaults.read().embedder.as_deref() != Some("openai") {
324                                let _ = reg.set_default_embedder("ollama");
325                            }
326                        }
327                        Err(err) => {
328                            tracing::warn!(error = %err, provider = "ollama", "failed to validate optional ollama embedder host from environment");
329                        }
330                    },
331                    Err(err) => {
332                        tracing::warn!(error = %err, provider = "ollama", "failed to initialize optional ollama embedder from environment");
333                    }
334                }
335
336                match hirn_provider::OllamaLlmProvider::new("llama3.1") {
337                    Ok(provider) => match provider.with_host(&host) {
338                        Ok(provider) => {
339                            reg.register_llm("ollama", Arc::new(provider));
340                            if reg.defaults.read().llm.as_deref() != Some("openai") {
341                                let _ = reg.set_default_llm("ollama");
342                            }
343                        }
344                        Err(err) => {
345                            tracing::warn!(error = %err, provider = "ollama", "failed to validate optional ollama llm host from environment");
346                        }
347                    },
348                    Err(err) => {
349                        tracing::warn!(error = %err, provider = "ollama", "failed to initialize optional ollama llm from environment");
350                    }
351                }
352            }
353        }
354
355        #[cfg(feature = "cohere")]
356        match hirn_provider::CohereReranker::from_env() {
357            Ok(Some(cohere_reranker)) => {
358                reg.register_reranker("cohere", Arc::new(cohere_reranker));
359                let _ = reg.set_default_reranker("cohere");
360            }
361            Ok(None) => {}
362            Err(err) => {
363                tracing::warn!(error = %err, provider = "cohere", "failed to initialize optional cohere reranker from environment");
364            }
365        }
366
367        #[cfg(feature = "cohere")]
368        match hirn_provider::CohereEmbedder::from_env() {
369            Ok(Some(cohere_embedder)) => {
370                reg.register_embedder("cohere", Arc::new(cohere_embedder));
371                if reg.default_embedder_is_unset_or_fallback() {
372                    let _ = reg.set_default_embedder("cohere");
373                }
374            }
375            Ok(None) => {}
376            Err(err) => {
377                tracing::warn!(error = %err, provider = "cohere", "failed to initialize optional cohere embedder from environment");
378            }
379        }
380
381        #[cfg(feature = "voyage")]
382        match hirn_provider::VoyageEmbedder::from_env() {
383            Ok(Some(voyage_embedder)) => {
384                reg.register_embedder("voyage", Arc::new(voyage_embedder));
385                if reg.default_embedder_is_unset_or_fallback() {
386                    let _ = reg.set_default_embedder("voyage");
387                }
388            }
389            Ok(None) => {}
390            Err(err) => {
391                tracing::warn!(error = %err, provider = "voyage", "failed to initialize optional voyage embedder from environment");
392            }
393        }
394
395        #[cfg(feature = "cross-encoder")]
396        if let Ok(cross_encoder) = hirn_provider::CrossEncoderReranker::default_model() {
397            reg.register_reranker("cross-encoder", Arc::new(cross_encoder));
398            if reg.default_reranker_is_unset_or_fallback() {
399                let _ = reg.set_default_reranker("cross-encoder");
400            }
401        }
402
403        #[cfg(feature = "anthropic")]
404        if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
405            match hirn_provider::AnthropicProvider::new(key) {
406                Ok(provider) => {
407                    reg.register_llm("anthropic", Arc::new(provider));
408                    if reg.default_llm_is_unset_or_fallback() {
409                        let _ = reg.set_default_llm("anthropic");
410                    }
411                }
412                Err(err) => {
413                    tracing::warn!(error = %err, provider = "anthropic", "failed to initialize optional anthropic llm from environment");
414                }
415            }
416        }
417
418        #[cfg(feature = "hf-tokenizer")]
419        if let Ok(model_id) = std::env::var("HF_TOKENIZER_MODEL") {
420            if let Ok(hf_tok) = hirn_provider::HuggingFaceTokenizer::from_pretrained(&model_id) {
421                reg.register_tokenizer("huggingface", Arc::new(hf_tok));
422                let _ = reg.set_default_tokenizer("huggingface");
423            }
424        }
425    }
426
427    #[cfg(feature = "openai")]
428    fn register_openai_from_key<FEmbed, FLlm>(
429        reg: &Self,
430        key: String,
431        make_embedder: FEmbed,
432        make_llm: FLlm,
433    ) where
434        FEmbed: FnOnce(String) -> HirnResult<Arc<dyn Embedder>>,
435        FLlm: FnOnce(String) -> HirnResult<Arc<dyn LlmProvider>>,
436    {
437        match make_embedder(key.clone()) {
438            Ok(embedder) => {
439                reg.register_embedder("openai", embedder);
440                let _ = reg.set_default_embedder("openai");
441            }
442            Err(err) => {
443                tracing::warn!(error = %err, provider = "openai", "failed to initialize optional openai embedder from environment");
444            }
445        }
446
447        match make_llm(key) {
448            Ok(provider) => {
449                reg.register_llm("openai", provider);
450                let _ = reg.set_default_llm("openai");
451            }
452            Err(err) => {
453                tracing::warn!(error = %err, provider = "openai", "failed to initialize optional openai llm from environment");
454            }
455        }
456    }
457
458    // ── Embedder ─────────────────────────────────────────────────────
459
460    /// Register a named embedder.
461    pub fn register_embedder(&self, name: &str, embedder: Arc<dyn Embedder>) {
462        self.embedders.write().insert(name.to_owned(), embedder);
463    }
464
465    /// Set the default embedder name. Returns error if the name is not registered.
466    pub fn set_default_embedder(&self, name: &str) -> HirnResult<()> {
467        if !self.embedders.read().contains_key(name) {
468            return Err(HirnError::config(format!(
469                "embedder '{name}' not registered"
470            )));
471        }
472        self.defaults.write().embedder = Some(name.to_owned());
473        Ok(())
474    }
475
476    /// Get the default embedder, if one is configured.
477    pub fn embedder(&self) -> Option<Arc<dyn Embedder>> {
478        let defaults = self.defaults.read();
479        let name = defaults.embedder.as_deref()?;
480        self.embedders.read().get(name).cloned()
481    }
482
483    /// Look up an embedder by name.
484    pub fn embedder_by_name(&self, name: &str) -> Option<Arc<dyn Embedder>> {
485        self.embedders.read().get(name).cloned()
486    }
487
488    // ── Tokenizer ────────────────────────────────────────────────────
489
490    /// Register a named tokenizer.
491    pub fn register_tokenizer(&self, name: &str, tokenizer: Arc<dyn Tokenizer>) {
492        self.tokenizers.write().insert(name.to_owned(), tokenizer);
493    }
494
495    /// Set the default tokenizer name.
496    pub fn set_default_tokenizer(&self, name: &str) -> HirnResult<()> {
497        if !self.tokenizers.read().contains_key(name) {
498            return Err(HirnError::config(format!(
499                "tokenizer '{name}' not registered"
500            )));
501        }
502        self.defaults.write().tokenizer = Some(name.to_owned());
503        Ok(())
504    }
505
506    /// Get the default tokenizer.
507    pub fn tokenizer(&self) -> Option<Arc<dyn Tokenizer>> {
508        let defaults = self.defaults.read();
509        let name = defaults.tokenizer.as_deref()?;
510        self.tokenizers.read().get(name).cloned()
511    }
512
513    /// Look up a tokenizer by name.
514    pub fn tokenizer_by_name(&self, name: &str) -> Option<Arc<dyn Tokenizer>> {
515        self.tokenizers.read().get(name).cloned()
516    }
517
518    // ── Reranker ─────────────────────────────────────────────────────
519
520    /// Register a named reranker.
521    pub fn register_reranker(&self, name: &str, reranker: Arc<dyn Reranker>) {
522        self.rerankers.write().insert(name.to_owned(), reranker);
523    }
524
525    /// Set the default reranker name.
526    pub fn set_default_reranker(&self, name: &str) -> HirnResult<()> {
527        if !self.rerankers.read().contains_key(name) {
528            return Err(HirnError::config(format!(
529                "reranker '{name}' not registered"
530            )));
531        }
532        self.defaults.write().reranker = Some(name.to_owned());
533        Ok(())
534    }
535
536    /// Get the default reranker.
537    pub fn reranker(&self) -> Option<Arc<dyn Reranker>> {
538        let defaults = self.defaults.read();
539        let name = defaults.reranker.as_deref()?;
540        self.rerankers.read().get(name).cloned()
541    }
542
543    /// Look up a reranker by name.
544    pub fn reranker_by_name(&self, name: &str) -> Option<Arc<dyn Reranker>> {
545        self.rerankers.read().get(name).cloned()
546    }
547
548    // ── LLM ──────────────────────────────────────────────────────────
549
550    /// Register a named LLM provider.
551    pub fn register_llm(&self, name: &str, llm: Arc<dyn LlmProvider>) {
552        self.llms.write().insert(name.to_owned(), llm);
553    }
554
555    /// Set the default LLM name.
556    pub fn set_default_llm(&self, name: &str) -> HirnResult<()> {
557        if !self.llms.read().contains_key(name) {
558            return Err(HirnError::config(format!(
559                "llm provider '{name}' not registered"
560            )));
561        }
562        self.defaults.write().llm = Some(name.to_owned());
563        Ok(())
564    }
565
566    /// Get the default LLM provider.
567    pub fn llm(&self) -> Option<Arc<dyn LlmProvider>> {
568        let defaults = self.defaults.read();
569        let name = defaults.llm.as_deref()?;
570        self.llms.read().get(name).cloned()
571    }
572
573    /// Look up an LLM provider by name.
574    pub fn llm_by_name(&self, name: &str) -> Option<Arc<dyn LlmProvider>> {
575        self.llms.read().get(name).cloned()
576    }
577
578    // ── Environment discovery ────────────────────────────────────────
579
580    /// Auto-discover providers from environment variables.
581    ///
582    /// Recognized variables:
583    /// - `OPENAI_API_KEY` → registers OpenAI embedder + LLM (if `openai` features enabled)
584    /// - `OLLAMA_HOST` → registers Ollama embedder + LLM (if `ollama` features enabled)
585    /// - `ANTHROPIC_API_KEY` → registers Anthropic LLM (if `anthropic` feature enabled)
586    ///
587    /// Falls back to `PseudoEmbedder` + provider-default tokenizer + `MockLlmProvider`
588    /// when no keys are found.
589    pub fn from_env() -> Self {
590        let reg = Self::with_fallbacks();
591        Self::populate_from_env(&reg);
592
593        reg
594    }
595
596    /// Auto-discover providers from environment variables without registering
597    /// pseudo/mock/noop fallbacks.
598    pub fn from_env_strict() -> Self {
599        let reg = Self::new();
600        Self::populate_from_env(&reg);
601
602        reg
603    }
604
605    // ── Config-driven construction ───────────────────────────────────
606
607    /// Parse a TOML string into a [`ProviderConfig`] and build a registry.
608    ///
609    /// Environment variable references (`{ env = "VAR" }`) are resolved at
610    /// call time.
611    ///
612    /// # Errors
613    ///
614    /// Returns an error if the TOML is invalid, a provider type is unknown,
615    /// or an environment variable reference cannot be resolved.
616    pub fn from_toml(toml_str: &str) -> HirnResult<Self> {
617        let config: ProviderConfig = toml::from_str(toml_str)
618            .map_err(|e| HirnError::config(format!("invalid provider TOML: {e}")))?;
619        Self::from_config(&config)
620    }
621
622    /// Build a registry from a [`ProviderConfig`].
623    ///
624    /// Each provider entry is constructed according to its `type` field.
625    /// Environment variable references are resolved at call time.
626    /// Fallback providers (pseudo, estimating, noop, mock) are always
627    /// registered; config entries override them.
628    ///
629    /// # Errors
630    ///
631    /// Returns an error if:
632    /// - A provider `type` is unknown or not enabled via feature flag
633    /// - A required field is missing (e.g. `api_key` for remote providers)
634    /// - An environment variable reference cannot be resolved
635    /// - A default name references a provider that was not configured
636    pub fn from_config(config: &ProviderConfig) -> HirnResult<Self> {
637        let reg = Self::with_fallbacks();
638
639        // ── Embedders ────────────────────────────────────────────────
640        for (name, cfg) in &config.providers.embedder {
641            let embedder: Arc<dyn Embedder> = Self::build_embedder(name, cfg)?;
642            reg.register_embedder(name, embedder);
643        }
644
645        // ── LLMs ─────────────────────────────────────────────────────
646        for (name, cfg) in &config.providers.llm {
647            let llm: Arc<dyn LlmProvider> = Self::build_llm(name, cfg)?;
648            reg.register_llm(name, llm);
649        }
650
651        // ── Rerankers ────────────────────────────────────────────────
652        for (name, cfg) in &config.providers.reranker {
653            let reranker: Arc<dyn Reranker> = Self::build_reranker(name, cfg)?;
654            reg.register_reranker(name, reranker);
655        }
656
657        // ── Tokenizers ───────────────────────────────────────────────
658        for (name, cfg) in &config.providers.tokenizer {
659            let tokenizer: Arc<dyn Tokenizer> = Self::build_tokenizer(name, cfg)?;
660            reg.register_tokenizer(name, tokenizer);
661        }
662
663        // ── Defaults ─────────────────────────────────────────────────
664        if let Some(ref name) = config.defaults.embedder {
665            reg.set_default_embedder(name)?;
666        }
667        if let Some(ref name) = config.defaults.tokenizer {
668            reg.set_default_tokenizer(name)?;
669        }
670        if let Some(ref name) = config.defaults.reranker {
671            reg.set_default_reranker(name)?;
672        }
673        if let Some(ref name) = config.defaults.llm {
674            reg.set_default_llm(name)?;
675        }
676
677        Ok(reg)
678    }
679
680    /// Apply a [`ProviderConfig`] on top of an existing registry.
681    ///
682    /// Providers from the config are registered (overriding any with the same
683    /// name). Defaults from the config override existing defaults.
684    pub fn apply_config(&self, config: &ProviderConfig) -> HirnResult<()> {
685        for (name, cfg) in &config.providers.embedder {
686            self.register_embedder(name, Self::build_embedder(name, cfg)?);
687        }
688        for (name, cfg) in &config.providers.llm {
689            self.register_llm(name, Self::build_llm(name, cfg)?);
690        }
691        for (name, cfg) in &config.providers.reranker {
692            self.register_reranker(name, Self::build_reranker(name, cfg)?);
693        }
694        for (name, cfg) in &config.providers.tokenizer {
695            self.register_tokenizer(name, Self::build_tokenizer(name, cfg)?);
696        }
697        if let Some(ref name) = config.defaults.embedder {
698            self.set_default_embedder(name)?;
699        }
700        if let Some(ref name) = config.defaults.tokenizer {
701            self.set_default_tokenizer(name)?;
702        }
703        if let Some(ref name) = config.defaults.reranker {
704            self.set_default_reranker(name)?;
705        }
706        if let Some(ref name) = config.defaults.llm {
707            self.set_default_llm(name)?;
708        }
709        Ok(())
710    }
711
712    // ── Provider builders (private) ──────────────────────────────────
713
714    #[cfg(feature = "openai")]
715    fn build_openai_embedder_with<F>(
716        name: &str,
717        cfg: &EmbedderConfig,
718        constructor: F,
719    ) -> HirnResult<Arc<dyn Embedder>>
720    where
721        F: FnOnce(String, &str, usize) -> HirnResult<hirn_provider::OpenAIEmbedder>,
722    {
723        let api_key = cfg
724            .api_key
725            .as_ref()
726            .ok_or_else(|| {
727                HirnError::config(format!("embedder '{name}': 'api_key' required for openai"))
728            })?
729            .resolve()?;
730        let model = cfg.model.as_deref().unwrap_or("text-embedding-3-small");
731        let dims = cfg.dimensions.unwrap_or(1536);
732        let mut embedder = constructor(api_key, model, dims).map_err(|err| {
733            HirnError::config(format!(
734                "embedder '{name}': failed to initialize openai client: {err}"
735            ))
736        })?;
737        if let Some(ref url) = cfg.base_url {
738            embedder = embedder.with_base_url(url).map_err(|err| {
739                HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
740            })?;
741        }
742        Ok(Arc::new(embedder))
743    }
744
745    fn build_embedder(name: &str, cfg: &EmbedderConfig) -> HirnResult<Arc<dyn Embedder>> {
746        match cfg.r#type.as_str() {
747            "pseudo" => {
748                let dims = cfg.dimensions.unwrap_or(384);
749                Ok(Arc::new(hirn_provider::PseudoEmbedder::new(dims)))
750            }
751            #[cfg(feature = "openai")]
752            "openai" => Self::build_openai_embedder_with(name, cfg, |api_key, model, dims| {
753                hirn_provider::OpenAIEmbedder::new(api_key, model, dims)
754            }),
755            #[cfg(feature = "ollama")]
756            "ollama" => {
757                let model = cfg.model.as_deref().unwrap_or("nomic-embed-text");
758                let dims = cfg.dimensions.unwrap_or(768);
759                let mut embedder =
760                    hirn_provider::OllamaEmbedder::new(model, dims).map_err(|err| {
761                        HirnError::config(format!(
762                            "embedder '{name}': failed to initialize ollama client: {err}"
763                        ))
764                    })?;
765                if let Some(ref url) = cfg.base_url {
766                    embedder = embedder.with_host(url).map_err(|err| {
767                        HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
768                    })?;
769                }
770                Ok(Arc::new(embedder))
771            }
772            #[cfg(feature = "cohere")]
773            "cohere" => {
774                let api_key = cfg
775                    .api_key
776                    .as_ref()
777                    .ok_or_else(|| {
778                        HirnError::config(format!(
779                            "embedder '{name}': 'api_key' required for cohere"
780                        ))
781                    })?
782                    .resolve()?;
783                let model = cfg.model.as_deref().unwrap_or("embed-english-v3.0");
784                let dims = cfg.dimensions.unwrap_or(1024);
785                let mut embedder = hirn_provider::CohereEmbedder::new(api_key, model, dims)
786                    .map_err(|err| {
787                        HirnError::config(format!(
788                            "embedder '{name}': failed to initialize cohere client: {err}"
789                        ))
790                    })?;
791                if let Some(ref url) = cfg.base_url {
792                    embedder = embedder.with_base_url(url).map_err(|err| {
793                        HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
794                    })?;
795                }
796                Ok(Arc::new(embedder))
797            }
798            #[cfg(feature = "voyage")]
799            "voyage" => {
800                let api_key = cfg
801                    .api_key
802                    .as_ref()
803                    .ok_or_else(|| {
804                        HirnError::config(format!(
805                            "embedder '{name}': 'api_key' required for voyage"
806                        ))
807                    })?
808                    .resolve()?;
809                let model = cfg.model.as_deref().unwrap_or("voyage-3");
810                let dims = cfg.dimensions.unwrap_or(1024);
811                let mut embedder = hirn_provider::VoyageEmbedder::new(api_key, model, dims)
812                    .map_err(|err| {
813                        HirnError::config(format!(
814                            "embedder '{name}': failed to initialize voyage client: {err}"
815                        ))
816                    })?;
817                if let Some(ref url) = cfg.base_url {
818                    embedder = embedder.with_base_url(url).map_err(|err| {
819                        HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
820                    })?;
821                }
822                Ok(Arc::new(embedder))
823            }
824            other => Err(HirnError::config(format!(
825                "embedder '{name}': unknown type '{other}'"
826            ))),
827        }
828    }
829
830    fn build_llm(name: &str, cfg: &LlmConfig) -> HirnResult<Arc<dyn LlmProvider>> {
831        match cfg.r#type.as_str() {
832            "mock" => Ok(Arc::new(hirn_provider::MockLlmProvider::new(name))),
833            #[cfg(feature = "openai")]
834            "openai" => {
835                let api_key = cfg
836                    .api_key
837                    .as_ref()
838                    .ok_or_else(|| {
839                        HirnError::config(format!("llm '{name}': 'api_key' required for openai"))
840                    })?
841                    .resolve()?;
842                let model = cfg.model.as_deref().unwrap_or("gpt-4o-mini");
843                let mut provider =
844                    hirn_provider::OpenAILlmProvider::new(api_key, model).map_err(|err| {
845                        HirnError::config(format!(
846                            "llm '{name}': failed to initialize openai client: {err}"
847                        ))
848                    })?;
849                if let Some(ref url) = cfg.base_url {
850                    provider = provider.with_base_url(url).map_err(|err| {
851                        HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
852                    })?;
853                }
854                Ok(Arc::new(provider))
855            }
856            #[cfg(feature = "ollama")]
857            "ollama" => {
858                let model = cfg.model.as_deref().unwrap_or("llama3.1");
859                let mut provider = hirn_provider::OllamaLlmProvider::new(model).map_err(|err| {
860                    HirnError::config(format!(
861                        "llm '{name}': failed to initialize ollama client: {err}"
862                    ))
863                })?;
864                if let Some(ref url) = cfg.base_url {
865                    provider = provider.with_host(url).map_err(|err| {
866                        HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
867                    })?;
868                }
869                Ok(Arc::new(provider))
870            }
871            #[cfg(feature = "anthropic")]
872            "anthropic" => {
873                let api_key = cfg
874                    .api_key
875                    .as_ref()
876                    .ok_or_else(|| {
877                        HirnError::config(format!("llm '{name}': 'api_key' required for anthropic"))
878                    })?
879                    .resolve()?;
880                let mut provider =
881                    hirn_provider::AnthropicProvider::new(api_key).map_err(|err| {
882                        HirnError::config(format!(
883                            "llm '{name}': failed to initialize anthropic client: {err}"
884                        ))
885                    })?;
886                if let Some(ref model) = cfg.model {
887                    provider = provider.with_model(model);
888                }
889                if let Some(ref url) = cfg.base_url {
890                    provider = provider.with_base_url(url).map_err(|err| {
891                        HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
892                    })?;
893                }
894                Ok(Arc::new(provider))
895            }
896            other => Err(HirnError::config(format!(
897                "llm '{name}': unknown type '{other}'"
898            ))),
899        }
900    }
901
902    fn build_reranker(name: &str, cfg: &RerankerConfig) -> HirnResult<Arc<dyn Reranker>> {
903        match cfg.r#type.as_str() {
904            "noop" => Ok(Arc::new(hirn_core::embed::NoopReranker)),
905            #[cfg(feature = "cohere")]
906            "cohere" => {
907                let api_key = cfg
908                    .api_key
909                    .as_ref()
910                    .ok_or_else(|| {
911                        HirnError::config(format!(
912                            "reranker '{name}': 'api_key' required for cohere"
913                        ))
914                    })?
915                    .resolve()?;
916                let mut reranker = hirn_provider::CohereReranker::new(api_key).map_err(|err| {
917                    HirnError::config(format!(
918                        "reranker '{name}': failed to initialize cohere client: {err}"
919                    ))
920                })?;
921                if let Some(ref model) = cfg.model {
922                    reranker = reranker.with_model(model);
923                }
924                if let Some(ref url) = cfg.base_url {
925                    reranker = reranker.with_base_url(url).map_err(|err| {
926                        HirnError::config(format!("reranker '{name}': invalid base_url: {err}"))
927                    })?;
928                }
929                Ok(Arc::new(reranker))
930            }
931            #[cfg(feature = "cross-encoder")]
932            "cross-encoder" => {
933                let reranker =
934                    hirn_provider::CrossEncoderReranker::default_model().map_err(|e| {
935                        HirnError::config(format!(
936                            "reranker '{name}': failed to load cross-encoder: {e}"
937                        ))
938                    })?;
939                Ok(Arc::new(reranker))
940            }
941            other => Err(HirnError::config(format!(
942                "reranker '{name}': unknown type '{other}'"
943            ))),
944        }
945    }
946
947    fn build_tokenizer(name: &str, cfg: &TokenizerConfig) -> HirnResult<Arc<dyn Tokenizer>> {
948        hirn_provider::build_tokenizer(&cfg.r#type, cfg.model.as_deref(), cfg.max_length)
949            .map_err(|e| HirnError::config(format!("tokenizer '{name}': {e}")))
950    }
951}
952
953impl Default for ProviderRegistry {
954    fn default() -> Self {
955        Self::new()
956    }
957}
958
// `ProviderRegistry` needs no manual `Send`/`Sync` impl: `parking_lot::RwLock` is
// `Send + Sync`, and the stored trait objects are `Send + Sync` by their trait bounds.
961
962#[cfg(test)]
963mod tests {
964    use super::*;
965
966    #[test]
967    fn register_and_lookup_embedder() {
968        let reg = ProviderRegistry::new();
969        reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(64)));
970        assert!(reg.embedder_by_name("pseudo").is_some());
971        assert!(reg.embedder_by_name("unknown").is_none());
972    }
973
974    #[test]
975    fn default_embedder_requires_registration() {
976        let reg = ProviderRegistry::new();
977        assert!(reg.set_default_embedder("missing").is_err());
978    }
979
980    #[test]
981    fn default_embedder_lookup() {
982        let reg = ProviderRegistry::new();
983        reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(64)));
984        reg.set_default_embedder("pseudo").unwrap();
985        assert!(reg.embedder().is_some());
986    }
987
988    #[test]
989    fn no_default_embedder_returns_none() {
990        let reg = ProviderRegistry::new();
991        assert!(reg.embedder().is_none());
992    }
993
994    #[test]
995    fn register_and_lookup_llm() {
996        let reg = ProviderRegistry::new();
997        reg.register_llm(
998            "mock",
999            Arc::new(hirn_provider::MockLlmProvider::new("test")),
1000        );
1001        assert!(reg.llm_by_name("mock").is_some());
1002    }
1003
1004    #[test]
1005    fn hot_swap_embedder() {
1006        let reg = ProviderRegistry::new();
1007        let e1 = Arc::new(hirn_provider::PseudoEmbedder::new(64));
1008        let e2 = Arc::new(hirn_provider::PseudoEmbedder::new(128));
1009        reg.register_embedder("e", e1);
1010        reg.set_default_embedder("e").unwrap();
1011        assert_eq!(reg.embedder().unwrap().dimensions(), 64);
1012        // Hot-swap
1013        reg.register_embedder("e", e2);
1014        assert_eq!(reg.embedder().unwrap().dimensions(), 128);
1015    }
1016
1017    #[test]
1018    fn from_env_creates_fallbacks() {
1019        // In CI/test without OPENAI_API_KEY, should get fallbacks.
1020        let reg = ProviderRegistry::from_env();
1021        assert!(reg.embedder().is_some());
1022        assert!(reg.tokenizer().is_some());
1023        assert!(reg.reranker().is_some());
1024        assert!(reg.llm().is_some());
1025    }
1026
1027    #[test]
1028    fn from_env_strict_omits_fallback_embedder_when_no_real_embedder_is_configured() {
1029        if [
1030            "OPENAI_API_KEY",
1031            "OLLAMA_HOST",
1032            "COHERE_API_KEY",
1033            "VOYAGE_API_KEY",
1034        ]
1035        .iter()
1036        .any(|key| std::env::var(key).is_ok())
1037        {
1038            return;
1039        }
1040
1041        let reg = ProviderRegistry::from_env_strict();
1042        assert!(reg.embedder().is_none());
1043    }
1044
1045    #[test]
1046    fn registry_is_send_sync() {
1047        fn assert_send_sync<T: Send + Sync>() {}
1048        assert_send_sync::<ProviderRegistry>();
1049    }
1050
1051    #[cfg(feature = "openai")]
1052    #[test]
1053    fn openai_auto_discovery_continues_when_embedder_init_fails() {
1054        let reg = ProviderRegistry::with_fallbacks();
1055
1056        ProviderRegistry::register_openai_from_key(
1057            &reg,
1058            "sk-test".into(),
1059            |_api_key| Err(HirnError::provider("synthetic openai embedder failure")),
1060            |_api_key| Ok(Arc::new(hirn_provider::MockLlmProvider::new("openai"))),
1061        );
1062
1063        assert_eq!(reg.defaults.read().embedder.as_deref(), Some("pseudo"));
1064        assert_eq!(reg.embedder().unwrap().dimensions(), 384);
1065        assert!(reg.embedder_by_name("openai").is_none());
1066        assert_eq!(reg.defaults.read().llm.as_deref(), Some("openai"));
1067        assert!(reg.llm_by_name("openai").is_some());
1068    }
1069
1070    #[cfg(feature = "openai")]
1071    #[test]
1072    fn openai_config_constructor_failure_returns_structured_error() {
1073        let cfg = EmbedderConfig {
1074            r#type: "openai".into(),
1075            model: Some("text-embedding-3-small".into()),
1076            dimensions: Some(1536),
1077            api_key: Some(ApiKeySource::Literal("sk-test".into())),
1078            base_url: None,
1079        };
1080
1081        let err = ProviderRegistry::build_openai_embedder_with(
1082            "broken-openai",
1083            &cfg,
1084            |_api_key, _model, _dims| Err(HirnError::provider("synthetic constructor failure")),
1085        );
1086
1087        let err = match err {
1088            Ok(_) => panic!("expected constructor failure"),
1089            Err(err) => err,
1090        };
1091
1092        match err {
1093            HirnError::InvalidInput(message) => {
1094                assert!(message.contains("embedder 'broken-openai'"));
1095                assert!(message.contains("failed to initialize openai client"));
1096                assert!(message.contains("synthetic constructor failure"));
1097            }
1098            other => panic!("expected invalid input, got {other:?}"),
1099        }
1100    }
1101
1102    #[test]
1103    fn register_and_lookup_reranker() {
1104        let reg = ProviderRegistry::new();
1105        reg.register_reranker("noop", Arc::new(hirn_core::embed::NoopReranker));
1106        reg.set_default_reranker("noop").unwrap();
1107        assert!(reg.reranker().is_some());
1108    }
1109
1110    #[test]
1111    fn register_and_lookup_tokenizer() {
1112        let reg = ProviderRegistry::new();
1113        reg.register_tokenizer("est", Arc::new(EstimatingTokenizer));
1114        reg.set_default_tokenizer("est").unwrap();
1115        assert!(reg.tokenizer().is_some());
1116    }
1117
1118    // ── Config-driven tests ──────────────────────────────────────────
1119
1120    #[test]
1121    fn from_toml_pseudo_and_estimating() {
1122        let toml = r#"
1123[providers.embedder.my_embed]
1124type = "pseudo"
1125dimensions = 256
1126
1127[providers.tokenizer.my_tok]
1128type = "estimating"
1129
1130[providers.llm.my_llm]
1131type = "mock"
1132
1133[providers.reranker.my_reranker]
1134type = "noop"
1135
1136[defaults]
1137embedder = "my_embed"
1138tokenizer = "my_tok"
1139llm = "my_llm"
1140reranker = "my_reranker"
1141"#;
1142        let reg = ProviderRegistry::from_toml(toml).unwrap();
1143        assert_eq!(reg.embedder().unwrap().dimensions(), 256);
1144        assert!(reg.tokenizer().is_some());
1145        assert!(reg.llm().is_some());
1146        assert!(reg.reranker().is_some());
1147    }
1148
1149    #[test]
1150    fn from_toml_unknown_embedder_type_error() {
1151        let toml = r#"
1152[providers.embedder.bad]
1153type = "nonexistent_provider"
1154"#;
1155        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1156        let msg = err.to_string();
1157        assert!(
1158            msg.contains("unknown type") && msg.contains("nonexistent_provider"),
1159            "should mention unknown type: {msg}"
1160        );
1161    }
1162
1163    #[test]
1164    fn from_toml_unknown_llm_type_error() {
1165        let toml = r#"
1166[providers.llm.bad]
1167type = "gpt-magic"
1168"#;
1169        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1170        assert!(err.to_string().contains("unknown type"));
1171    }
1172
1173    #[test]
1174    fn from_toml_unknown_reranker_type_error() {
1175        let toml = r#"
1176[providers.reranker.bad]
1177type = "magic-reranker"
1178"#;
1179        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1180        assert!(err.to_string().contains("unknown type"));
1181    }
1182
1183    #[test]
1184    fn from_toml_unknown_tokenizer_type_error() {
1185        let toml = r#"
1186[providers.tokenizer.bad]
1187type = "magic-tokenizer"
1188"#;
1189        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1190        assert!(err.to_string().contains("unknown tokenizer type"));
1191    }
1192
1193    #[test]
1194    fn from_toml_invalid_toml_syntax_error() {
1195        let toml = "this is not [valid toml";
1196        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1197        assert!(
1198            err.to_string().contains("invalid provider TOML"),
1199            "error: {}",
1200            err,
1201        );
1202    }
1203
1204    #[test]
1205    fn from_toml_env_var_literal_key() {
1206        // Test that literal API keys work in config (no env var needed).
1207        let toml = r#"
1208[providers.embedder.pseudo_env]
1209type = "pseudo"
1210dimensions = 128
1211"#;
1212        let reg = ProviderRegistry::from_toml(toml).unwrap();
1213        assert!(reg.embedder_by_name("pseudo_env").is_some());
1214    }
1215
1216    #[test]
1217    fn missing_env_var_error() {
1218        // Use a var name that is very unlikely to be set.
1219        let source = ApiKeySource::Env {
1220            env: "HIRN_NONEXISTENT_VAR_42_TEST".into(),
1221        };
1222        let err = source.resolve().unwrap_err();
1223        assert!(
1224            err.to_string().contains("HIRN_NONEXISTENT_VAR_42_TEST"),
1225            "error should name the variable: {err}"
1226        );
1227    }
1228
1229    #[test]
1230    fn api_key_source_literal_resolves() {
1231        let source = ApiKeySource::Literal("my-key".into());
1232        assert_eq!(source.resolve().unwrap(), "my-key");
1233    }
1234
1235    #[test]
1236    fn api_key_source_env_resolves() {
1237        // Use HOME which is always set on macOS/Linux.
1238        let source = ApiKeySource::Env { env: "HOME".into() };
1239        let resolved = source.resolve().unwrap();
1240        assert!(
1241            !resolved.is_empty(),
1242            "HOME should resolve to a non-empty string"
1243        );
1244    }
1245
1246    #[test]
1247    fn api_key_source_deserialize_literal() {
1248        #[derive(serde::Deserialize)]
1249        struct W {
1250            key: ApiKeySource,
1251        }
1252        let w: W = toml::from_str(r#"key = "my-literal-key""#).unwrap();
1253        assert_eq!(w.key, ApiKeySource::Literal("my-literal-key".into()));
1254    }
1255
1256    #[test]
1257    fn api_key_source_deserialize_env() {
1258        #[derive(serde::Deserialize)]
1259        struct W {
1260            key: ApiKeySource,
1261        }
1262        let w: W = toml::from_str(r#"key = { env = "MY_VAR" }"#).unwrap();
1263        assert_eq!(
1264            w.key,
1265            ApiKeySource::Env {
1266                env: "MY_VAR".into()
1267            }
1268        );
1269    }
1270
1271    #[test]
1272    fn from_toml_default_references_unregistered_provider_error() {
1273        let toml = r#"
1274[defaults]
1275embedder = "nonexistent"
1276"#;
1277        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1278        assert!(err.to_string().contains("not registered"), "error: {}", err);
1279    }
1280
1281    #[cfg(feature = "tiktoken")]
1282    #[test]
1283    fn from_toml_tiktoken_tokenizer() {
1284        let toml = r#"
1285[providers.tokenizer.tiktoken]
1286type = "tiktoken"
1287model = "cl100k_base"
1288
1289[defaults]
1290tokenizer = "tiktoken"
1291"#;
1292        let reg = ProviderRegistry::from_toml(toml).unwrap();
1293        let tok = reg.tokenizer().unwrap();
1294        assert!(tok.count_tokens("hello world") > 0);
1295    }
1296
1297    #[cfg(feature = "tiktoken")]
1298    #[test]
1299    fn from_toml_tiktoken_invalid_model_error() {
1300        let toml = r#"
1301[providers.tokenizer.bad]
1302type = "tiktoken"
1303model = "gpt-99-turbo"
1304"#;
1305        let err = ProviderRegistry::from_toml(toml).unwrap_err();
1306        assert!(err.to_string().contains("unknown tiktoken model"));
1307    }
1308
1309    #[test]
1310    fn from_toml_empty_config_uses_fallbacks() {
1311        let reg = ProviderRegistry::from_toml("").unwrap();
1312        // Fallbacks should be registered.
1313        assert!(reg.embedder().is_some());
1314        assert!(reg.tokenizer().is_some());
1315        assert!(reg.reranker().is_some());
1316        assert!(reg.llm().is_some());
1317    }
1318
1319    #[test]
1320    fn from_config_and_from_env_combined() {
1321        // from_env creates a registry with fallbacks.
1322        let reg = ProviderRegistry::from_env();
1323        assert!(reg.embedder().is_some());
1324
1325        // Apply config on top — add a custom pseudo embedder.
1326        let config = ProviderConfig {
1327            providers: ProvidersSection {
1328                embedder: {
1329                    let mut m = HashMap::new();
1330                    m.insert(
1331                        "custom".into(),
1332                        EmbedderConfig {
1333                            r#type: "pseudo".into(),
1334                            model: None,
1335                            dimensions: Some(999),
1336                            api_key: None,
1337                            base_url: None,
1338                        },
1339                    );
1340                    m
1341                },
1342                ..Default::default()
1343            },
1344            defaults: DefaultsConfig {
1345                embedder: Some("custom".into()),
1346                ..Default::default()
1347            },
1348        };
1349        reg.apply_config(&config).unwrap();
1350        assert_eq!(reg.embedder().unwrap().dimensions(), 999);
1351    }
1352
1353    #[test]
1354    fn from_toml_multiple_embedders() {
1355        let toml = r#"
1356[providers.embedder.small]
1357type = "pseudo"
1358dimensions = 128
1359
1360[providers.embedder.large]
1361type = "pseudo"
1362dimensions = 2048
1363
1364[defaults]
1365embedder = "large"
1366"#;
1367        let reg = ProviderRegistry::from_toml(toml).unwrap();
1368        assert_eq!(reg.embedder().unwrap().dimensions(), 2048);
1369        assert_eq!(reg.embedder_by_name("small").unwrap().dimensions(), 128);
1370    }
1371
1372    #[test]
1373    fn provider_config_deserialize_full_example() {
1374        let toml = r#"
1375[providers.embedder.openai]
1376type = "openai"
1377model = "text-embedding-3-small"
1378api_key = { env = "OPENAI_API_KEY" }
1379dimensions = 1536
1380
1381[providers.embedder.local]
1382type = "pseudo"
1383dimensions = 384
1384
1385[providers.llm.claude]
1386type = "anthropic"
1387model = "claude-sonnet-4-20250514"
1388api_key = { env = "ANTHROPIC_API_KEY" }
1389
1390[providers.llm.fallback]
1391type = "mock"
1392
1393[providers.reranker.noop]
1394type = "noop"
1395
1396[providers.tokenizer.default]
1397type = "estimating"
1398
1399[providers.tokenizer.tiktoken]
1400type = "tiktoken"
1401model = "cl100k_base"
1402
1403[defaults]
1404embedder = "local"
1405llm = "fallback"
1406reranker = "noop"
1407tokenizer = "default"
1408"#;
1409        // Parse only — don't resolve env vars (they may not be set).
1410        let config: ProviderConfig = toml::from_str(toml).unwrap();
1411        assert_eq!(config.providers.embedder.len(), 2);
1412        assert_eq!(config.providers.llm.len(), 2);
1413        assert_eq!(config.providers.reranker.len(), 1);
1414        assert_eq!(config.providers.tokenizer.len(), 2);
1415        assert_eq!(config.defaults.embedder.as_deref(), Some("local"));
1416        assert_eq!(config.defaults.llm.as_deref(), Some("fallback"));
1417    }
1418}