1use std::collections::HashMap;
13use std::sync::Arc;
14
15use parking_lot::RwLock;
16
17use hirn_core::embed::{Embedder, LlmProvider, Reranker};
18use hirn_core::tokenizer::{EstimatingTokenizer, Tokenizer};
19use hirn_core::{HirnError, HirnResult};
20
/// Registry names currently selected as the default provider for each capability.
///
/// A `None` slot means no default has been chosen for that capability.
#[derive(Clone, Debug, Default)]
pub struct ProviderDefaults {
    /// Registry name of the default embedder.
    pub embedder: Option<String>,
    /// Registry name of the default tokenizer.
    pub tokenizer: Option<String>,
    /// Registry name of the default reranker.
    pub reranker: Option<String>,
    /// Registry name of the default LLM provider.
    pub llm: Option<String>,
}
29
30#[derive(Debug, Clone, serde::Deserialize, PartialEq)]
41#[serde(untagged)]
42pub enum ApiKeySource {
43 Env {
45 env: String,
47 },
48 Literal(String),
50}
51
52impl ApiKeySource {
53 pub fn resolve(&self) -> HirnResult<String> {
58 match self {
59 Self::Literal(key) => Ok(key.clone()),
60 Self::Env { env } => std::env::var(env).map_err(|_| {
61 HirnError::config(format!(
62 "environment variable '{env}' not set (required by provider config)"
63 ))
64 }),
65 }
66 }
67}
68
69#[derive(Debug, Clone, serde::Deserialize)]
71pub struct EmbedderConfig {
72 pub r#type: String,
74 pub model: Option<String>,
76 pub dimensions: Option<usize>,
78 pub api_key: Option<ApiKeySource>,
80 pub base_url: Option<String>,
82}
83
84#[derive(Debug, Clone, serde::Deserialize)]
86pub struct LlmConfig {
87 pub r#type: String,
89 pub model: Option<String>,
91 pub api_key: Option<ApiKeySource>,
93 pub base_url: Option<String>,
95}
96
97#[derive(Debug, Clone, serde::Deserialize)]
99pub struct RerankerConfig {
100 pub r#type: String,
102 pub model: Option<String>,
104 pub api_key: Option<ApiKeySource>,
106 pub base_url: Option<String>,
108}
109
110#[derive(Debug, Clone, serde::Deserialize)]
112pub struct TokenizerConfig {
113 pub r#type: String,
115 pub model: Option<String>,
117 pub max_length: Option<usize>,
119}
120
121#[derive(Debug, Clone, Default, serde::Deserialize)]
123pub struct DefaultsConfig {
124 pub embedder: Option<String>,
125 pub tokenizer: Option<String>,
126 pub reranker: Option<String>,
127 pub llm: Option<String>,
128}
129
130#[derive(Debug, Clone, Default, serde::Deserialize)]
161pub struct ProviderConfig {
162 #[serde(default)]
164 pub providers: ProvidersSection,
165 #[serde(default)]
167 pub defaults: DefaultsConfig,
168}
169
170#[derive(Debug, Clone, Default, serde::Deserialize)]
172pub struct ProvidersSection {
173 #[serde(default)]
174 pub embedder: HashMap<String, EmbedderConfig>,
175 #[serde(default)]
176 pub llm: HashMap<String, LlmConfig>,
177 #[serde(default)]
178 pub reranker: HashMap<String, RerankerConfig>,
179 #[serde(default)]
180 pub tokenizer: HashMap<String, TokenizerConfig>,
181}
182
183pub struct ProviderRegistry {
201 embedders: RwLock<HashMap<String, Arc<dyn Embedder>>>,
202 tokenizers: RwLock<HashMap<String, Arc<dyn Tokenizer>>>,
203 rerankers: RwLock<HashMap<String, Arc<dyn Reranker>>>,
204 llms: RwLock<HashMap<String, Arc<dyn LlmProvider>>>,
205 defaults: RwLock<ProviderDefaults>,
206}
207
208impl std::fmt::Debug for ProviderRegistry {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 let defaults = self.defaults.read();
211 f.debug_struct("ProviderRegistry")
212 .field(
213 "embedders",
214 &self.embedders.read().keys().collect::<Vec<_>>(),
215 )
216 .field(
217 "tokenizers",
218 &self.tokenizers.read().keys().collect::<Vec<_>>(),
219 )
220 .field(
221 "rerankers",
222 &self.rerankers.read().keys().collect::<Vec<_>>(),
223 )
224 .field("llms", &self.llms.read().keys().collect::<Vec<_>>())
225 .field("defaults", &*defaults)
226 .finish()
227 }
228}
229
230impl ProviderRegistry {
231 pub fn new() -> Self {
233 Self {
234 embedders: RwLock::new(HashMap::new()),
235 tokenizers: RwLock::new(HashMap::new()),
236 rerankers: RwLock::new(HashMap::new()),
237 llms: RwLock::new(HashMap::new()),
238 defaults: RwLock::new(ProviderDefaults::default()),
239 }
240 }
241
242 fn with_fallbacks() -> Self {
243 let reg = Self::new();
244
245 reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(384)));
246 reg.register_tokenizer("estimating", Arc::new(EstimatingTokenizer));
247 reg.register_reranker("noop", Arc::new(hirn_core::embed::NoopReranker));
248 reg.register_llm(
249 "mock",
250 Arc::new(hirn_provider::MockLlmProvider::new("mock")),
251 );
252
253 let _ = reg.set_default_embedder("pseudo");
254 let _ = reg.set_default_tokenizer("estimating");
255 let _ = reg.set_default_reranker("noop");
256 let _ = reg.set_default_llm("mock");
257
258 #[cfg(feature = "tiktoken")]
259 if let Ok(tokenizer) = hirn_provider::build_tokenizer("tiktoken", Some("cl100k_base"), None)
260 {
261 reg.register_tokenizer("tiktoken", tokenizer);
262 let _ = reg.set_default_tokenizer("tiktoken");
263 }
264
265 reg
266 }
267
268 #[allow(dead_code)]
269 fn default_embedder_is_unset_or_fallback(&self) -> bool {
270 self.defaults
271 .read()
272 .embedder
273 .as_deref()
274 .is_none_or(|name| name == "pseudo")
275 }
276
277 #[allow(dead_code)]
278 fn default_reranker_is_unset_or_fallback(&self) -> bool {
279 self.defaults
280 .read()
281 .reranker
282 .as_deref()
283 .is_none_or(|name| name == "noop")
284 }
285
286 #[allow(dead_code)]
287 fn default_llm_is_unset_or_fallback(&self) -> bool {
288 self.defaults
289 .read()
290 .llm
291 .as_deref()
292 .is_none_or(|name| name == "mock")
293 }
294
295 #[allow(unused_variables)]
296 fn populate_from_env(reg: &Self) {
297 #[cfg(feature = "openai")]
299 if let Ok(key) = std::env::var("OPENAI_API_KEY") {
300 Self::register_openai_from_key(
301 reg,
302 key,
303 |api_key| {
304 hirn_provider::OpenAIEmbedder::new(api_key, "text-embedding-3-small", 1536)
305 .map(|embedder| Arc::new(embedder) as Arc<dyn Embedder>)
306 },
307 |api_key| {
308 hirn_provider::OpenAILlmProvider::new(api_key, "gpt-4o-mini")
309 .map(|provider| Arc::new(provider) as Arc<dyn LlmProvider>)
310 },
311 );
312 }
313
314 #[cfg(feature = "ollama")]
315 {
316 let host = std::env::var("OLLAMA_HOST")
317 .unwrap_or_else(|_| "http://localhost:11434".to_owned());
318 if std::env::var("OLLAMA_HOST").is_ok() {
319 match hirn_provider::OllamaEmbedder::new("nomic-embed-text", 768) {
320 Ok(embedder) => match embedder.with_host(&host) {
321 Ok(embedder) => {
322 reg.register_embedder("ollama", Arc::new(embedder));
323 if reg.defaults.read().embedder.as_deref() != Some("openai") {
324 let _ = reg.set_default_embedder("ollama");
325 }
326 }
327 Err(err) => {
328 tracing::warn!(error = %err, provider = "ollama", "failed to validate optional ollama embedder host from environment");
329 }
330 },
331 Err(err) => {
332 tracing::warn!(error = %err, provider = "ollama", "failed to initialize optional ollama embedder from environment");
333 }
334 }
335
336 match hirn_provider::OllamaLlmProvider::new("llama3.1") {
337 Ok(provider) => match provider.with_host(&host) {
338 Ok(provider) => {
339 reg.register_llm("ollama", Arc::new(provider));
340 if reg.defaults.read().llm.as_deref() != Some("openai") {
341 let _ = reg.set_default_llm("ollama");
342 }
343 }
344 Err(err) => {
345 tracing::warn!(error = %err, provider = "ollama", "failed to validate optional ollama llm host from environment");
346 }
347 },
348 Err(err) => {
349 tracing::warn!(error = %err, provider = "ollama", "failed to initialize optional ollama llm from environment");
350 }
351 }
352 }
353 }
354
355 #[cfg(feature = "cohere")]
356 match hirn_provider::CohereReranker::from_env() {
357 Ok(Some(cohere_reranker)) => {
358 reg.register_reranker("cohere", Arc::new(cohere_reranker));
359 let _ = reg.set_default_reranker("cohere");
360 }
361 Ok(None) => {}
362 Err(err) => {
363 tracing::warn!(error = %err, provider = "cohere", "failed to initialize optional cohere reranker from environment");
364 }
365 }
366
367 #[cfg(feature = "cohere")]
368 match hirn_provider::CohereEmbedder::from_env() {
369 Ok(Some(cohere_embedder)) => {
370 reg.register_embedder("cohere", Arc::new(cohere_embedder));
371 if reg.default_embedder_is_unset_or_fallback() {
372 let _ = reg.set_default_embedder("cohere");
373 }
374 }
375 Ok(None) => {}
376 Err(err) => {
377 tracing::warn!(error = %err, provider = "cohere", "failed to initialize optional cohere embedder from environment");
378 }
379 }
380
381 #[cfg(feature = "voyage")]
382 match hirn_provider::VoyageEmbedder::from_env() {
383 Ok(Some(voyage_embedder)) => {
384 reg.register_embedder("voyage", Arc::new(voyage_embedder));
385 if reg.default_embedder_is_unset_or_fallback() {
386 let _ = reg.set_default_embedder("voyage");
387 }
388 }
389 Ok(None) => {}
390 Err(err) => {
391 tracing::warn!(error = %err, provider = "voyage", "failed to initialize optional voyage embedder from environment");
392 }
393 }
394
395 #[cfg(feature = "cross-encoder")]
396 if let Ok(cross_encoder) = hirn_provider::CrossEncoderReranker::default_model() {
397 reg.register_reranker("cross-encoder", Arc::new(cross_encoder));
398 if reg.default_reranker_is_unset_or_fallback() {
399 let _ = reg.set_default_reranker("cross-encoder");
400 }
401 }
402
403 #[cfg(feature = "anthropic")]
404 if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
405 match hirn_provider::AnthropicProvider::new(key) {
406 Ok(provider) => {
407 reg.register_llm("anthropic", Arc::new(provider));
408 if reg.default_llm_is_unset_or_fallback() {
409 let _ = reg.set_default_llm("anthropic");
410 }
411 }
412 Err(err) => {
413 tracing::warn!(error = %err, provider = "anthropic", "failed to initialize optional anthropic llm from environment");
414 }
415 }
416 }
417
418 #[cfg(feature = "hf-tokenizer")]
419 if let Ok(model_id) = std::env::var("HF_TOKENIZER_MODEL") {
420 if let Ok(hf_tok) = hirn_provider::HuggingFaceTokenizer::from_pretrained(&model_id) {
421 reg.register_tokenizer("huggingface", Arc::new(hf_tok));
422 let _ = reg.set_default_tokenizer("huggingface");
423 }
424 }
425 }
426
427 #[cfg(feature = "openai")]
428 fn register_openai_from_key<FEmbed, FLlm>(
429 reg: &Self,
430 key: String,
431 make_embedder: FEmbed,
432 make_llm: FLlm,
433 ) where
434 FEmbed: FnOnce(String) -> HirnResult<Arc<dyn Embedder>>,
435 FLlm: FnOnce(String) -> HirnResult<Arc<dyn LlmProvider>>,
436 {
437 match make_embedder(key.clone()) {
438 Ok(embedder) => {
439 reg.register_embedder("openai", embedder);
440 let _ = reg.set_default_embedder("openai");
441 }
442 Err(err) => {
443 tracing::warn!(error = %err, provider = "openai", "failed to initialize optional openai embedder from environment");
444 }
445 }
446
447 match make_llm(key) {
448 Ok(provider) => {
449 reg.register_llm("openai", provider);
450 let _ = reg.set_default_llm("openai");
451 }
452 Err(err) => {
453 tracing::warn!(error = %err, provider = "openai", "failed to initialize optional openai llm from environment");
454 }
455 }
456 }
457
458 pub fn register_embedder(&self, name: &str, embedder: Arc<dyn Embedder>) {
462 self.embedders.write().insert(name.to_owned(), embedder);
463 }
464
465 pub fn set_default_embedder(&self, name: &str) -> HirnResult<()> {
467 if !self.embedders.read().contains_key(name) {
468 return Err(HirnError::config(format!(
469 "embedder '{name}' not registered"
470 )));
471 }
472 self.defaults.write().embedder = Some(name.to_owned());
473 Ok(())
474 }
475
476 pub fn embedder(&self) -> Option<Arc<dyn Embedder>> {
478 let defaults = self.defaults.read();
479 let name = defaults.embedder.as_deref()?;
480 self.embedders.read().get(name).cloned()
481 }
482
483 pub fn embedder_by_name(&self, name: &str) -> Option<Arc<dyn Embedder>> {
485 self.embedders.read().get(name).cloned()
486 }
487
488 pub fn register_tokenizer(&self, name: &str, tokenizer: Arc<dyn Tokenizer>) {
492 self.tokenizers.write().insert(name.to_owned(), tokenizer);
493 }
494
495 pub fn set_default_tokenizer(&self, name: &str) -> HirnResult<()> {
497 if !self.tokenizers.read().contains_key(name) {
498 return Err(HirnError::config(format!(
499 "tokenizer '{name}' not registered"
500 )));
501 }
502 self.defaults.write().tokenizer = Some(name.to_owned());
503 Ok(())
504 }
505
506 pub fn tokenizer(&self) -> Option<Arc<dyn Tokenizer>> {
508 let defaults = self.defaults.read();
509 let name = defaults.tokenizer.as_deref()?;
510 self.tokenizers.read().get(name).cloned()
511 }
512
513 pub fn tokenizer_by_name(&self, name: &str) -> Option<Arc<dyn Tokenizer>> {
515 self.tokenizers.read().get(name).cloned()
516 }
517
518 pub fn register_reranker(&self, name: &str, reranker: Arc<dyn Reranker>) {
522 self.rerankers.write().insert(name.to_owned(), reranker);
523 }
524
525 pub fn set_default_reranker(&self, name: &str) -> HirnResult<()> {
527 if !self.rerankers.read().contains_key(name) {
528 return Err(HirnError::config(format!(
529 "reranker '{name}' not registered"
530 )));
531 }
532 self.defaults.write().reranker = Some(name.to_owned());
533 Ok(())
534 }
535
536 pub fn reranker(&self) -> Option<Arc<dyn Reranker>> {
538 let defaults = self.defaults.read();
539 let name = defaults.reranker.as_deref()?;
540 self.rerankers.read().get(name).cloned()
541 }
542
543 pub fn reranker_by_name(&self, name: &str) -> Option<Arc<dyn Reranker>> {
545 self.rerankers.read().get(name).cloned()
546 }
547
548 pub fn register_llm(&self, name: &str, llm: Arc<dyn LlmProvider>) {
552 self.llms.write().insert(name.to_owned(), llm);
553 }
554
555 pub fn set_default_llm(&self, name: &str) -> HirnResult<()> {
557 if !self.llms.read().contains_key(name) {
558 return Err(HirnError::config(format!(
559 "llm provider '{name}' not registered"
560 )));
561 }
562 self.defaults.write().llm = Some(name.to_owned());
563 Ok(())
564 }
565
566 pub fn llm(&self) -> Option<Arc<dyn LlmProvider>> {
568 let defaults = self.defaults.read();
569 let name = defaults.llm.as_deref()?;
570 self.llms.read().get(name).cloned()
571 }
572
573 pub fn llm_by_name(&self, name: &str) -> Option<Arc<dyn LlmProvider>> {
575 self.llms.read().get(name).cloned()
576 }
577
578 pub fn from_env() -> Self {
590 let reg = Self::with_fallbacks();
591 Self::populate_from_env(®);
592
593 reg
594 }
595
596 pub fn from_env_strict() -> Self {
599 let reg = Self::new();
600 Self::populate_from_env(®);
601
602 reg
603 }
604
605 pub fn from_toml(toml_str: &str) -> HirnResult<Self> {
617 let config: ProviderConfig = toml::from_str(toml_str)
618 .map_err(|e| HirnError::config(format!("invalid provider TOML: {e}")))?;
619 Self::from_config(&config)
620 }
621
622 pub fn from_config(config: &ProviderConfig) -> HirnResult<Self> {
637 let reg = Self::with_fallbacks();
638
639 for (name, cfg) in &config.providers.embedder {
641 let embedder: Arc<dyn Embedder> = Self::build_embedder(name, cfg)?;
642 reg.register_embedder(name, embedder);
643 }
644
645 for (name, cfg) in &config.providers.llm {
647 let llm: Arc<dyn LlmProvider> = Self::build_llm(name, cfg)?;
648 reg.register_llm(name, llm);
649 }
650
651 for (name, cfg) in &config.providers.reranker {
653 let reranker: Arc<dyn Reranker> = Self::build_reranker(name, cfg)?;
654 reg.register_reranker(name, reranker);
655 }
656
657 for (name, cfg) in &config.providers.tokenizer {
659 let tokenizer: Arc<dyn Tokenizer> = Self::build_tokenizer(name, cfg)?;
660 reg.register_tokenizer(name, tokenizer);
661 }
662
663 if let Some(ref name) = config.defaults.embedder {
665 reg.set_default_embedder(name)?;
666 }
667 if let Some(ref name) = config.defaults.tokenizer {
668 reg.set_default_tokenizer(name)?;
669 }
670 if let Some(ref name) = config.defaults.reranker {
671 reg.set_default_reranker(name)?;
672 }
673 if let Some(ref name) = config.defaults.llm {
674 reg.set_default_llm(name)?;
675 }
676
677 Ok(reg)
678 }
679
680 pub fn apply_config(&self, config: &ProviderConfig) -> HirnResult<()> {
685 for (name, cfg) in &config.providers.embedder {
686 self.register_embedder(name, Self::build_embedder(name, cfg)?);
687 }
688 for (name, cfg) in &config.providers.llm {
689 self.register_llm(name, Self::build_llm(name, cfg)?);
690 }
691 for (name, cfg) in &config.providers.reranker {
692 self.register_reranker(name, Self::build_reranker(name, cfg)?);
693 }
694 for (name, cfg) in &config.providers.tokenizer {
695 self.register_tokenizer(name, Self::build_tokenizer(name, cfg)?);
696 }
697 if let Some(ref name) = config.defaults.embedder {
698 self.set_default_embedder(name)?;
699 }
700 if let Some(ref name) = config.defaults.tokenizer {
701 self.set_default_tokenizer(name)?;
702 }
703 if let Some(ref name) = config.defaults.reranker {
704 self.set_default_reranker(name)?;
705 }
706 if let Some(ref name) = config.defaults.llm {
707 self.set_default_llm(name)?;
708 }
709 Ok(())
710 }
711
712 #[cfg(feature = "openai")]
715 fn build_openai_embedder_with<F>(
716 name: &str,
717 cfg: &EmbedderConfig,
718 constructor: F,
719 ) -> HirnResult<Arc<dyn Embedder>>
720 where
721 F: FnOnce(String, &str, usize) -> HirnResult<hirn_provider::OpenAIEmbedder>,
722 {
723 let api_key = cfg
724 .api_key
725 .as_ref()
726 .ok_or_else(|| {
727 HirnError::config(format!("embedder '{name}': 'api_key' required for openai"))
728 })?
729 .resolve()?;
730 let model = cfg.model.as_deref().unwrap_or("text-embedding-3-small");
731 let dims = cfg.dimensions.unwrap_or(1536);
732 let mut embedder = constructor(api_key, model, dims).map_err(|err| {
733 HirnError::config(format!(
734 "embedder '{name}': failed to initialize openai client: {err}"
735 ))
736 })?;
737 if let Some(ref url) = cfg.base_url {
738 embedder = embedder.with_base_url(url).map_err(|err| {
739 HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
740 })?;
741 }
742 Ok(Arc::new(embedder))
743 }
744
745 fn build_embedder(name: &str, cfg: &EmbedderConfig) -> HirnResult<Arc<dyn Embedder>> {
746 match cfg.r#type.as_str() {
747 "pseudo" => {
748 let dims = cfg.dimensions.unwrap_or(384);
749 Ok(Arc::new(hirn_provider::PseudoEmbedder::new(dims)))
750 }
751 #[cfg(feature = "openai")]
752 "openai" => Self::build_openai_embedder_with(name, cfg, |api_key, model, dims| {
753 hirn_provider::OpenAIEmbedder::new(api_key, model, dims)
754 }),
755 #[cfg(feature = "ollama")]
756 "ollama" => {
757 let model = cfg.model.as_deref().unwrap_or("nomic-embed-text");
758 let dims = cfg.dimensions.unwrap_or(768);
759 let mut embedder =
760 hirn_provider::OllamaEmbedder::new(model, dims).map_err(|err| {
761 HirnError::config(format!(
762 "embedder '{name}': failed to initialize ollama client: {err}"
763 ))
764 })?;
765 if let Some(ref url) = cfg.base_url {
766 embedder = embedder.with_host(url).map_err(|err| {
767 HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
768 })?;
769 }
770 Ok(Arc::new(embedder))
771 }
772 #[cfg(feature = "cohere")]
773 "cohere" => {
774 let api_key = cfg
775 .api_key
776 .as_ref()
777 .ok_or_else(|| {
778 HirnError::config(format!(
779 "embedder '{name}': 'api_key' required for cohere"
780 ))
781 })?
782 .resolve()?;
783 let model = cfg.model.as_deref().unwrap_or("embed-english-v3.0");
784 let dims = cfg.dimensions.unwrap_or(1024);
785 let mut embedder = hirn_provider::CohereEmbedder::new(api_key, model, dims)
786 .map_err(|err| {
787 HirnError::config(format!(
788 "embedder '{name}': failed to initialize cohere client: {err}"
789 ))
790 })?;
791 if let Some(ref url) = cfg.base_url {
792 embedder = embedder.with_base_url(url).map_err(|err| {
793 HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
794 })?;
795 }
796 Ok(Arc::new(embedder))
797 }
798 #[cfg(feature = "voyage")]
799 "voyage" => {
800 let api_key = cfg
801 .api_key
802 .as_ref()
803 .ok_or_else(|| {
804 HirnError::config(format!(
805 "embedder '{name}': 'api_key' required for voyage"
806 ))
807 })?
808 .resolve()?;
809 let model = cfg.model.as_deref().unwrap_or("voyage-3");
810 let dims = cfg.dimensions.unwrap_or(1024);
811 let mut embedder = hirn_provider::VoyageEmbedder::new(api_key, model, dims)
812 .map_err(|err| {
813 HirnError::config(format!(
814 "embedder '{name}': failed to initialize voyage client: {err}"
815 ))
816 })?;
817 if let Some(ref url) = cfg.base_url {
818 embedder = embedder.with_base_url(url).map_err(|err| {
819 HirnError::config(format!("embedder '{name}': invalid base_url: {err}"))
820 })?;
821 }
822 Ok(Arc::new(embedder))
823 }
824 other => Err(HirnError::config(format!(
825 "embedder '{name}': unknown type '{other}'"
826 ))),
827 }
828 }
829
830 fn build_llm(name: &str, cfg: &LlmConfig) -> HirnResult<Arc<dyn LlmProvider>> {
831 match cfg.r#type.as_str() {
832 "mock" => Ok(Arc::new(hirn_provider::MockLlmProvider::new(name))),
833 #[cfg(feature = "openai")]
834 "openai" => {
835 let api_key = cfg
836 .api_key
837 .as_ref()
838 .ok_or_else(|| {
839 HirnError::config(format!("llm '{name}': 'api_key' required for openai"))
840 })?
841 .resolve()?;
842 let model = cfg.model.as_deref().unwrap_or("gpt-4o-mini");
843 let mut provider =
844 hirn_provider::OpenAILlmProvider::new(api_key, model).map_err(|err| {
845 HirnError::config(format!(
846 "llm '{name}': failed to initialize openai client: {err}"
847 ))
848 })?;
849 if let Some(ref url) = cfg.base_url {
850 provider = provider.with_base_url(url).map_err(|err| {
851 HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
852 })?;
853 }
854 Ok(Arc::new(provider))
855 }
856 #[cfg(feature = "ollama")]
857 "ollama" => {
858 let model = cfg.model.as_deref().unwrap_or("llama3.1");
859 let mut provider = hirn_provider::OllamaLlmProvider::new(model).map_err(|err| {
860 HirnError::config(format!(
861 "llm '{name}': failed to initialize ollama client: {err}"
862 ))
863 })?;
864 if let Some(ref url) = cfg.base_url {
865 provider = provider.with_host(url).map_err(|err| {
866 HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
867 })?;
868 }
869 Ok(Arc::new(provider))
870 }
871 #[cfg(feature = "anthropic")]
872 "anthropic" => {
873 let api_key = cfg
874 .api_key
875 .as_ref()
876 .ok_or_else(|| {
877 HirnError::config(format!("llm '{name}': 'api_key' required for anthropic"))
878 })?
879 .resolve()?;
880 let mut provider =
881 hirn_provider::AnthropicProvider::new(api_key).map_err(|err| {
882 HirnError::config(format!(
883 "llm '{name}': failed to initialize anthropic client: {err}"
884 ))
885 })?;
886 if let Some(ref model) = cfg.model {
887 provider = provider.with_model(model);
888 }
889 if let Some(ref url) = cfg.base_url {
890 provider = provider.with_base_url(url).map_err(|err| {
891 HirnError::config(format!("llm '{name}': invalid base_url: {err}"))
892 })?;
893 }
894 Ok(Arc::new(provider))
895 }
896 other => Err(HirnError::config(format!(
897 "llm '{name}': unknown type '{other}'"
898 ))),
899 }
900 }
901
902 fn build_reranker(name: &str, cfg: &RerankerConfig) -> HirnResult<Arc<dyn Reranker>> {
903 match cfg.r#type.as_str() {
904 "noop" => Ok(Arc::new(hirn_core::embed::NoopReranker)),
905 #[cfg(feature = "cohere")]
906 "cohere" => {
907 let api_key = cfg
908 .api_key
909 .as_ref()
910 .ok_or_else(|| {
911 HirnError::config(format!(
912 "reranker '{name}': 'api_key' required for cohere"
913 ))
914 })?
915 .resolve()?;
916 let mut reranker = hirn_provider::CohereReranker::new(api_key).map_err(|err| {
917 HirnError::config(format!(
918 "reranker '{name}': failed to initialize cohere client: {err}"
919 ))
920 })?;
921 if let Some(ref model) = cfg.model {
922 reranker = reranker.with_model(model);
923 }
924 if let Some(ref url) = cfg.base_url {
925 reranker = reranker.with_base_url(url).map_err(|err| {
926 HirnError::config(format!("reranker '{name}': invalid base_url: {err}"))
927 })?;
928 }
929 Ok(Arc::new(reranker))
930 }
931 #[cfg(feature = "cross-encoder")]
932 "cross-encoder" => {
933 let reranker =
934 hirn_provider::CrossEncoderReranker::default_model().map_err(|e| {
935 HirnError::config(format!(
936 "reranker '{name}': failed to load cross-encoder: {e}"
937 ))
938 })?;
939 Ok(Arc::new(reranker))
940 }
941 other => Err(HirnError::config(format!(
942 "reranker '{name}': unknown type '{other}'"
943 ))),
944 }
945 }
946
947 fn build_tokenizer(name: &str, cfg: &TokenizerConfig) -> HirnResult<Arc<dyn Tokenizer>> {
948 hirn_provider::build_tokenizer(&cfg.r#type, cfg.model.as_deref(), cfg.max_length)
949 .map_err(|e| HirnError::config(format!("tokenizer '{name}': {e}")))
950 }
951}
952
953impl Default for ProviderRegistry {
954 fn default() -> Self {
955 Self::new()
956 }
957}
958
959#[cfg(test)]
963mod tests {
964 use super::*;
965
966 #[test]
967 fn register_and_lookup_embedder() {
968 let reg = ProviderRegistry::new();
969 reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(64)));
970 assert!(reg.embedder_by_name("pseudo").is_some());
971 assert!(reg.embedder_by_name("unknown").is_none());
972 }
973
974 #[test]
975 fn default_embedder_requires_registration() {
976 let reg = ProviderRegistry::new();
977 assert!(reg.set_default_embedder("missing").is_err());
978 }
979
980 #[test]
981 fn default_embedder_lookup() {
982 let reg = ProviderRegistry::new();
983 reg.register_embedder("pseudo", Arc::new(hirn_provider::PseudoEmbedder::new(64)));
984 reg.set_default_embedder("pseudo").unwrap();
985 assert!(reg.embedder().is_some());
986 }
987
988 #[test]
989 fn no_default_embedder_returns_none() {
990 let reg = ProviderRegistry::new();
991 assert!(reg.embedder().is_none());
992 }
993
994 #[test]
995 fn register_and_lookup_llm() {
996 let reg = ProviderRegistry::new();
997 reg.register_llm(
998 "mock",
999 Arc::new(hirn_provider::MockLlmProvider::new("test")),
1000 );
1001 assert!(reg.llm_by_name("mock").is_some());
1002 }
1003
1004 #[test]
1005 fn hot_swap_embedder() {
1006 let reg = ProviderRegistry::new();
1007 let e1 = Arc::new(hirn_provider::PseudoEmbedder::new(64));
1008 let e2 = Arc::new(hirn_provider::PseudoEmbedder::new(128));
1009 reg.register_embedder("e", e1);
1010 reg.set_default_embedder("e").unwrap();
1011 assert_eq!(reg.embedder().unwrap().dimensions(), 64);
1012 reg.register_embedder("e", e2);
1014 assert_eq!(reg.embedder().unwrap().dimensions(), 128);
1015 }
1016
1017 #[test]
1018 fn from_env_creates_fallbacks() {
1019 let reg = ProviderRegistry::from_env();
1021 assert!(reg.embedder().is_some());
1022 assert!(reg.tokenizer().is_some());
1023 assert!(reg.reranker().is_some());
1024 assert!(reg.llm().is_some());
1025 }
1026
1027 #[test]
1028 fn from_env_strict_omits_fallback_embedder_when_no_real_embedder_is_configured() {
1029 if [
1030 "OPENAI_API_KEY",
1031 "OLLAMA_HOST",
1032 "COHERE_API_KEY",
1033 "VOYAGE_API_KEY",
1034 ]
1035 .iter()
1036 .any(|key| std::env::var(key).is_ok())
1037 {
1038 return;
1039 }
1040
1041 let reg = ProviderRegistry::from_env_strict();
1042 assert!(reg.embedder().is_none());
1043 }
1044
1045 #[test]
1046 fn registry_is_send_sync() {
1047 fn assert_send_sync<T: Send + Sync>() {}
1048 assert_send_sync::<ProviderRegistry>();
1049 }
1050
1051 #[cfg(feature = "openai")]
1052 #[test]
1053 fn openai_auto_discovery_continues_when_embedder_init_fails() {
1054 let reg = ProviderRegistry::with_fallbacks();
1055
1056 ProviderRegistry::register_openai_from_key(
1057 ®,
1058 "sk-test".into(),
1059 |_api_key| Err(HirnError::provider("synthetic openai embedder failure")),
1060 |_api_key| Ok(Arc::new(hirn_provider::MockLlmProvider::new("openai"))),
1061 );
1062
1063 assert_eq!(reg.defaults.read().embedder.as_deref(), Some("pseudo"));
1064 assert_eq!(reg.embedder().unwrap().dimensions(), 384);
1065 assert!(reg.embedder_by_name("openai").is_none());
1066 assert_eq!(reg.defaults.read().llm.as_deref(), Some("openai"));
1067 assert!(reg.llm_by_name("openai").is_some());
1068 }
1069
1070 #[cfg(feature = "openai")]
1071 #[test]
1072 fn openai_config_constructor_failure_returns_structured_error() {
1073 let cfg = EmbedderConfig {
1074 r#type: "openai".into(),
1075 model: Some("text-embedding-3-small".into()),
1076 dimensions: Some(1536),
1077 api_key: Some(ApiKeySource::Literal("sk-test".into())),
1078 base_url: None,
1079 };
1080
1081 let err = ProviderRegistry::build_openai_embedder_with(
1082 "broken-openai",
1083 &cfg,
1084 |_api_key, _model, _dims| Err(HirnError::provider("synthetic constructor failure")),
1085 );
1086
1087 let err = match err {
1088 Ok(_) => panic!("expected constructor failure"),
1089 Err(err) => err,
1090 };
1091
1092 match err {
1093 HirnError::InvalidInput(message) => {
1094 assert!(message.contains("embedder 'broken-openai'"));
1095 assert!(message.contains("failed to initialize openai client"));
1096 assert!(message.contains("synthetic constructor failure"));
1097 }
1098 other => panic!("expected invalid input, got {other:?}"),
1099 }
1100 }
1101
1102 #[test]
1103 fn register_and_lookup_reranker() {
1104 let reg = ProviderRegistry::new();
1105 reg.register_reranker("noop", Arc::new(hirn_core::embed::NoopReranker));
1106 reg.set_default_reranker("noop").unwrap();
1107 assert!(reg.reranker().is_some());
1108 }
1109
1110 #[test]
1111 fn register_and_lookup_tokenizer() {
1112 let reg = ProviderRegistry::new();
1113 reg.register_tokenizer("est", Arc::new(EstimatingTokenizer));
1114 reg.set_default_tokenizer("est").unwrap();
1115 assert!(reg.tokenizer().is_some());
1116 }
1117
1118 #[test]
1121 fn from_toml_pseudo_and_estimating() {
1122 let toml = r#"
1123[providers.embedder.my_embed]
1124type = "pseudo"
1125dimensions = 256
1126
1127[providers.tokenizer.my_tok]
1128type = "estimating"
1129
1130[providers.llm.my_llm]
1131type = "mock"
1132
1133[providers.reranker.my_reranker]
1134type = "noop"
1135
1136[defaults]
1137embedder = "my_embed"
1138tokenizer = "my_tok"
1139llm = "my_llm"
1140reranker = "my_reranker"
1141"#;
1142 let reg = ProviderRegistry::from_toml(toml).unwrap();
1143 assert_eq!(reg.embedder().unwrap().dimensions(), 256);
1144 assert!(reg.tokenizer().is_some());
1145 assert!(reg.llm().is_some());
1146 assert!(reg.reranker().is_some());
1147 }
1148
1149 #[test]
1150 fn from_toml_unknown_embedder_type_error() {
1151 let toml = r#"
1152[providers.embedder.bad]
1153type = "nonexistent_provider"
1154"#;
1155 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1156 let msg = err.to_string();
1157 assert!(
1158 msg.contains("unknown type") && msg.contains("nonexistent_provider"),
1159 "should mention unknown type: {msg}"
1160 );
1161 }
1162
1163 #[test]
1164 fn from_toml_unknown_llm_type_error() {
1165 let toml = r#"
1166[providers.llm.bad]
1167type = "gpt-magic"
1168"#;
1169 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1170 assert!(err.to_string().contains("unknown type"));
1171 }
1172
1173 #[test]
1174 fn from_toml_unknown_reranker_type_error() {
1175 let toml = r#"
1176[providers.reranker.bad]
1177type = "magic-reranker"
1178"#;
1179 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1180 assert!(err.to_string().contains("unknown type"));
1181 }
1182
1183 #[test]
1184 fn from_toml_unknown_tokenizer_type_error() {
1185 let toml = r#"
1186[providers.tokenizer.bad]
1187type = "magic-tokenizer"
1188"#;
1189 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1190 assert!(err.to_string().contains("unknown tokenizer type"));
1191 }
1192
1193 #[test]
1194 fn from_toml_invalid_toml_syntax_error() {
1195 let toml = "this is not [valid toml";
1196 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1197 assert!(
1198 err.to_string().contains("invalid provider TOML"),
1199 "error: {}",
1200 err,
1201 );
1202 }
1203
1204 #[test]
1205 fn from_toml_env_var_literal_key() {
1206 let toml = r#"
1208[providers.embedder.pseudo_env]
1209type = "pseudo"
1210dimensions = 128
1211"#;
1212 let reg = ProviderRegistry::from_toml(toml).unwrap();
1213 assert!(reg.embedder_by_name("pseudo_env").is_some());
1214 }
1215
1216 #[test]
1217 fn missing_env_var_error() {
1218 let source = ApiKeySource::Env {
1220 env: "HIRN_NONEXISTENT_VAR_42_TEST".into(),
1221 };
1222 let err = source.resolve().unwrap_err();
1223 assert!(
1224 err.to_string().contains("HIRN_NONEXISTENT_VAR_42_TEST"),
1225 "error should name the variable: {err}"
1226 );
1227 }
1228
1229 #[test]
1230 fn api_key_source_literal_resolves() {
1231 let source = ApiKeySource::Literal("my-key".into());
1232 assert_eq!(source.resolve().unwrap(), "my-key");
1233 }
1234
1235 #[test]
1236 fn api_key_source_env_resolves() {
1237 let source = ApiKeySource::Env { env: "HOME".into() };
1239 let resolved = source.resolve().unwrap();
1240 assert!(
1241 !resolved.is_empty(),
1242 "HOME should resolve to a non-empty string"
1243 );
1244 }
1245
1246 #[test]
1247 fn api_key_source_deserialize_literal() {
1248 #[derive(serde::Deserialize)]
1249 struct W {
1250 key: ApiKeySource,
1251 }
1252 let w: W = toml::from_str(r#"key = "my-literal-key""#).unwrap();
1253 assert_eq!(w.key, ApiKeySource::Literal("my-literal-key".into()));
1254 }
1255
1256 #[test]
1257 fn api_key_source_deserialize_env() {
1258 #[derive(serde::Deserialize)]
1259 struct W {
1260 key: ApiKeySource,
1261 }
1262 let w: W = toml::from_str(r#"key = { env = "MY_VAR" }"#).unwrap();
1263 assert_eq!(
1264 w.key,
1265 ApiKeySource::Env {
1266 env: "MY_VAR".into()
1267 }
1268 );
1269 }
1270
1271 #[test]
1272 fn from_toml_default_references_unregistered_provider_error() {
1273 let toml = r#"
1274[defaults]
1275embedder = "nonexistent"
1276"#;
1277 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1278 assert!(err.to_string().contains("not registered"), "error: {}", err);
1279 }
1280
1281 #[cfg(feature = "tiktoken")]
1282 #[test]
1283 fn from_toml_tiktoken_tokenizer() {
1284 let toml = r#"
1285[providers.tokenizer.tiktoken]
1286type = "tiktoken"
1287model = "cl100k_base"
1288
1289[defaults]
1290tokenizer = "tiktoken"
1291"#;
1292 let reg = ProviderRegistry::from_toml(toml).unwrap();
1293 let tok = reg.tokenizer().unwrap();
1294 assert!(tok.count_tokens("hello world") > 0);
1295 }
1296
1297 #[cfg(feature = "tiktoken")]
1298 #[test]
1299 fn from_toml_tiktoken_invalid_model_error() {
1300 let toml = r#"
1301[providers.tokenizer.bad]
1302type = "tiktoken"
1303model = "gpt-99-turbo"
1304"#;
1305 let err = ProviderRegistry::from_toml(toml).unwrap_err();
1306 assert!(err.to_string().contains("unknown tiktoken model"));
1307 }
1308
1309 #[test]
1310 fn from_toml_empty_config_uses_fallbacks() {
1311 let reg = ProviderRegistry::from_toml("").unwrap();
1312 assert!(reg.embedder().is_some());
1314 assert!(reg.tokenizer().is_some());
1315 assert!(reg.reranker().is_some());
1316 assert!(reg.llm().is_some());
1317 }
1318
1319 #[test]
1320 fn from_config_and_from_env_combined() {
1321 let reg = ProviderRegistry::from_env();
1323 assert!(reg.embedder().is_some());
1324
1325 let config = ProviderConfig {
1327 providers: ProvidersSection {
1328 embedder: {
1329 let mut m = HashMap::new();
1330 m.insert(
1331 "custom".into(),
1332 EmbedderConfig {
1333 r#type: "pseudo".into(),
1334 model: None,
1335 dimensions: Some(999),
1336 api_key: None,
1337 base_url: None,
1338 },
1339 );
1340 m
1341 },
1342 ..Default::default()
1343 },
1344 defaults: DefaultsConfig {
1345 embedder: Some("custom".into()),
1346 ..Default::default()
1347 },
1348 };
1349 reg.apply_config(&config).unwrap();
1350 assert_eq!(reg.embedder().unwrap().dimensions(), 999);
1351 }
1352
1353 #[test]
1354 fn from_toml_multiple_embedders() {
1355 let toml = r#"
1356[providers.embedder.small]
1357type = "pseudo"
1358dimensions = 128
1359
1360[providers.embedder.large]
1361type = "pseudo"
1362dimensions = 2048
1363
1364[defaults]
1365embedder = "large"
1366"#;
1367 let reg = ProviderRegistry::from_toml(toml).unwrap();
1368 assert_eq!(reg.embedder().unwrap().dimensions(), 2048);
1369 assert_eq!(reg.embedder_by_name("small").unwrap().dimensions(), 128);
1370 }
1371
1372 #[test]
1373 fn provider_config_deserialize_full_example() {
1374 let toml = r#"
1375[providers.embedder.openai]
1376type = "openai"
1377model = "text-embedding-3-small"
1378api_key = { env = "OPENAI_API_KEY" }
1379dimensions = 1536
1380
1381[providers.embedder.local]
1382type = "pseudo"
1383dimensions = 384
1384
1385[providers.llm.claude]
1386type = "anthropic"
1387model = "claude-sonnet-4-20250514"
1388api_key = { env = "ANTHROPIC_API_KEY" }
1389
1390[providers.llm.fallback]
1391type = "mock"
1392
1393[providers.reranker.noop]
1394type = "noop"
1395
1396[providers.tokenizer.default]
1397type = "estimating"
1398
1399[providers.tokenizer.tiktoken]
1400type = "tiktoken"
1401model = "cl100k_base"
1402
1403[defaults]
1404embedder = "local"
1405llm = "fallback"
1406reranker = "noop"
1407tokenizer = "default"
1408"#;
1409 let config: ProviderConfig = toml::from_str(toml).unwrap();
1411 assert_eq!(config.providers.embedder.len(), 2);
1412 assert_eq!(config.providers.llm.len(), 2);
1413 assert_eq!(config.providers.reranker.len(), 1);
1414 assert_eq!(config.providers.tokenizer.len(), 2);
1415 assert_eq!(config.defaults.embedder.as_deref(), Some("local"));
1416 assert_eq!(config.defaults.llm.as_deref(), Some("fallback"));
1417 }
1418}