Skip to main content

argyph_embed/
lib.rs

1pub mod api_key;
2pub mod config;
3pub mod error;
4pub mod http;
5pub mod local;
6pub mod model_files;
7pub mod model_hashes;
8pub mod openai;
9pub mod tokenize;
10pub mod voyage;
11
12use std::sync::Arc;
13
14pub use error::Result;
15
16#[async_trait::async_trait]
17pub trait Embedder: Send + Sync + 'static {
18    fn dimension(&self) -> usize;
19
20    fn model_id(&self) -> &str;
21
22    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
23
24    async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
25        Ok(self.embed(&[query.to_string()]).await?.remove(0))
26    }
27}
28
/// The embedding backends that [`build`] can construct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    /// Backed by [`local::LocalEmbedder`].
    Local,
    /// Backed by [`openai::OpenAiEmbedder`]; requires an API key (`OPENAI_API_KEY`).
    OpenAi,
    /// Backed by [`voyage::VoyageEmbedder`]; requires an API key (`VOYAGE_API_KEY`).
    Voyage,
}

impl Provider {
    /// Default concurrency level to use with this provider.
    ///
    /// - `Local`: the number of threads reported by [`std::thread::available_parallelism`],
    ///   or 4 if that cannot be determined.
    /// - `OpenAi`: a fixed 8.
    /// - `Voyage`: a fixed 4.
    ///
    /// NOTE(review): presumably the number of concurrent embed calls a caller should issue —
    /// confirm against call sites.
    pub fn default_concurrency(&self) -> usize {
        match self {
            Provider::Local => std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4),
            Provider::OpenAi => 8,
            Provider::Voyage => 4,
        }
    }
}
46
47pub fn build(provider: Provider, config: config::EmbedConfig) -> Result<Arc<dyn Embedder>> {
48    match provider {
49        Provider::Local => {
50            let embedder = tokio::runtime::Handle::try_current()
51                .map(|handle| handle.block_on(local::LocalEmbedder::new(config.clone())))
52                .unwrap_or_else(|_| {
53                    tokio::runtime::Builder::new_current_thread()
54                        .enable_all()
55                        .build()
56                        .map_err(|e| {
57                            error::EmbedError::Config(format!("failed to create runtime: {e}"))
58                        })?
59                        .block_on(local::LocalEmbedder::new(config))
60                })?;
61            Ok(Arc::new(embedder))
62        }
63        Provider::OpenAi => Ok(Arc::new(openai::OpenAiEmbedder::new(config)?)),
64        Provider::Voyage => Ok(Arc::new(voyage::VoyageEmbedder::new(config)?)),
65    }
66}
67
/// An [`Embedder`] that ignores its input and returns a zero vector for every text.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullEmbedder;

impl NullEmbedder {
    /// Creates a `NullEmbedder`; equivalent to `NullEmbedder::default()`.
    pub fn new() -> Self {
        Self
    }
}
81
82#[async_trait::async_trait]
83impl Embedder for NullEmbedder {
84    fn dimension(&self) -> usize {
85        1
86    }
87
88    fn model_id(&self) -> &str {
89        "null"
90    }
91
92    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
93        Ok(texts.iter().map(|_| vec![0.0]).collect())
94    }
95}
96
#[cfg(test)]
// Unwrapping is acceptable in tests: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::config::EmbedConfig;

    /// Builds `provider` with the default config and expects a `Config` error whose message
    /// mentions `env_var`. The check is skipped when `env_var` is set in the environment.
    fn check_missing_api_key(provider: Provider, env_var: &str) {
        if std::env::var(env_var).is_ok() {
            return;
        }
        let result = build(provider, EmbedConfig::default());
        assert!(result.is_err());
        match result.err().unwrap() {
            error::EmbedError::Config(msg) => assert!(msg.contains(env_var)),
            other => panic!("expected Config error about API key, got: {other:?}"),
        }
    }

    /// The local provider may either build successfully or fail with a `Config` error;
    /// any other error variant fails the test.
    #[test]
    fn build_local_not_implemented_or_fails_with_config_err() {
        match build(Provider::Local, EmbedConfig::default()) {
            Ok(_) => eprintln!("local provider succeeded (model was cached)"),
            Err(error::EmbedError::Config(_)) => {}
            Err(other) => panic!("expected Config error, got: {other:?}"),
        }
    }

    #[test]
    fn build_voyage_fails_without_api_key() {
        check_missing_api_key(Provider::Voyage, "VOYAGE_API_KEY");
    }

    #[test]
    fn build_openai_fails_without_api_key() {
        check_missing_api_key(Provider::OpenAi, "OPENAI_API_KEY");
    }
}