1pub mod api_key;
2pub mod config;
3pub mod error;
4pub mod http;
5pub mod local;
6pub mod model_files;
7pub mod model_hashes;
8pub mod openai;
9pub mod tokenize;
10pub mod voyage;
11
12use std::sync::Arc;
13
14pub use error::Result;
15
16#[async_trait::async_trait]
17pub trait Embedder: Send + Sync + 'static {
18 fn dimension(&self) -> usize;
19
20 fn model_id(&self) -> &str;
21
22 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
23
24 async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
25 Ok(self.embed(&[query.to_string()]).await?.remove(0))
26 }
27}
28
/// Embedding backends that [`build`] knows how to construct.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
    /// In-process model, constructed asynchronously by `local::LocalEmbedder::new`.
    Local,
    /// OpenAI embeddings, constructed by `openai::OpenAiEmbedder::new`.
    OpenAi,
    /// Voyage embeddings, constructed by `voyage::VoyageEmbedder::new`.
    Voyage,
}

impl Provider {
    /// Default concurrency level for this provider.
    ///
    /// NOTE(review): presumably the number of embedding requests to keep in
    /// flight at once — confirm against callers.
    ///
    /// - `Local` uses the host's available parallelism, falling back to 4 when
    ///   it cannot be determined.
    /// - `OpenAi` uses a fixed 8.
    /// - `Voyage` uses a fixed 4.
    pub fn default_concurrency(&self) -> usize {
        match self {
            Provider::Local => std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4),
            Provider::OpenAi => 8,
            Provider::Voyage => 4,
        }
    }
}
46
47pub fn build(provider: Provider, config: config::EmbedConfig) -> Result<Arc<dyn Embedder>> {
48 match provider {
49 Provider::Local => {
50 let embedder = tokio::runtime::Handle::try_current()
51 .map(|handle| handle.block_on(local::LocalEmbedder::new(config.clone())))
52 .unwrap_or_else(|_| {
53 tokio::runtime::Builder::new_current_thread()
54 .enable_all()
55 .build()
56 .map_err(|e| {
57 error::EmbedError::Config(format!("failed to create runtime: {e}"))
58 })?
59 .block_on(local::LocalEmbedder::new(config))
60 })?;
61 Ok(Arc::new(embedder))
62 }
63 Provider::OpenAi => Ok(Arc::new(openai::OpenAiEmbedder::new(config)?)),
64 Provider::Voyage => Ok(Arc::new(voyage::VoyageEmbedder::new(config)?)),
65 }
66}
67
/// Placeholder embedder that maps every input to the one-element vector `[0.0]`.
///
/// Useful when embeddings are disabled but an [`Embedder`] is still required.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullEmbedder;

impl NullEmbedder {
    /// Create a `NullEmbedder`; equivalent to `NullEmbedder::default()`.
    pub const fn new() -> Self {
        Self
    }
}
81
82#[async_trait::async_trait]
83impl Embedder for NullEmbedder {
84 fn dimension(&self) -> usize {
85 1
86 }
87
88 fn model_id(&self) -> &str {
89 "null"
90 }
91
92 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
93 Ok(texts.iter().map(|_| vec![0.0]).collect())
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::EmbedConfig;

    /// Assert that building `provider` fails with a `Config` error naming
    /// `env_var`.
    ///
    /// The check is skipped when `env_var` is set, because the build could then
    /// legitimately succeed.
    fn assert_missing_api_key(provider: Provider, env_var: &str) {
        if std::env::var(env_var).is_ok() {
            return;
        }
        match build(provider, EmbedConfig::default()) {
            // `Arc<dyn Embedder>` is not `Debug`, so the success case cannot be printed.
            Ok(_) => panic!("expected Config error about API key, got a working embedder"),
            Err(error::EmbedError::Config(msg)) => assert!(
                msg.contains(env_var),
                "Config error should mention {env_var}, got: {msg}"
            ),
            Err(other) => panic!("expected Config error about API key, got: {other:?}"),
        }
    }

    #[test]
    fn build_local_not_implemented_or_fails_with_config_err() {
        match build(Provider::Local, EmbedConfig::default()) {
            Err(error::EmbedError::Config(_)) => {}
            Ok(_) => {
                eprintln!("local provider succeeded (model was cached)");
            }
            Err(other) => panic!("expected Config error, got: {other:?}"),
        }
    }

    #[test]
    fn build_voyage_fails_without_api_key() {
        assert_missing_api_key(Provider::Voyage, "VOYAGE_API_KEY");
    }

    #[test]
    fn build_openai_fails_without_api_key() {
        assert_missing_api_key(Provider::OpenAi, "OPENAI_API_KEY");
    }
}