use serde::{Deserialize, Serialize};
/// Configuration for the Hugging Face embedder.
///
/// Every field except `api_key` falls back to a built-in default when it is missing
/// from the serialized config. `Debug` is implemented by hand so the API key is never
/// printed in logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct HuggingFaceEmbedderConfig {
    /// Hugging Face model id; defaults to `BAAI/bge-small-en-v1.5` when absent.
    #[serde(default = "default_model")]
    pub model: String,
    /// Explicit API key. When unset, `get_api_key` falls back to the
    /// `HUGGINGFACE_API_KEY` environment variable.
    pub api_key: Option<String>,
    /// Expected embedding dimensionality; defaults to 384 when absent.
    #[serde(default = "default_embedding_dims")]
    pub embedding_dims: usize,
    /// Base URL of the inference API; defaults to
    /// `https://api-inference.huggingface.co/models` when absent.
    #[serde(default = "default_base_url")]
    pub base_url: String,
}

impl std::fmt::Debug for HuggingFaceEmbedderConfig {
    /// Formats the config like a derived `Debug`, but replaces the API key with
    /// `<redacted>` so secrets do not end up in logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HuggingFaceEmbedderConfig")
            .field("model", &self.model)
            // Only reveal whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("embedding_dims", &self.embedding_dims)
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Serde/`Default` fallback for `model`: the Hugging Face model id used when none is configured.
fn default_model() -> String {
    String::from("BAAI/bge-small-en-v1.5")
}
/// Serde/`Default` fallback for `embedding_dims`.
///
/// This should agree with the output width of the default model
/// (`BAAI/bge-small-en-v1.5`); keep the two in sync if either changes.
fn default_embedding_dims() -> usize {
    const DEFAULT_DIMS: usize = 384;
    DEFAULT_DIMS
}
/// Serde/`Default` fallback for `base_url`: the hosted inference API root.
///
/// NOTE(review): presumably callers append the model id to this URL; confirm at the call site.
fn default_base_url() -> String {
    "https://api-inference.huggingface.co/models".to_owned()
}
impl Default for HuggingFaceEmbedderConfig {
    /// Builds a config from the same fallback functions serde uses for missing fields.
    ///
    /// `api_key` is left unset, so `get_api_key` will consult the environment instead.
    fn default() -> Self {
        let model = default_model();
        let embedding_dims = default_embedding_dims();
        let base_url = default_base_url();
        Self {
            model,
            api_key: None,
            embedding_dims,
            base_url,
        }
    }
}
impl HuggingFaceEmbedderConfig {
    /// Resolves the API key to use for requests.
    ///
    /// An explicitly configured `api_key` takes precedence. Otherwise the
    /// `HUGGINGFACE_API_KEY` environment variable is consulted.
    ///
    /// Blank values (empty or whitespace-only) from either source are treated as
    /// absent. An empty `api_key` entry in a config file therefore does not shadow
    /// the environment, and an exported-but-empty variable does not produce an
    /// unusable credential.
    ///
    /// Returns `None` when neither source provides a non-blank key.
    pub fn get_api_key(&self) -> Option<String> {
        let is_present = |key: &String| !key.trim().is_empty();
        self.api_key
            .clone()
            .filter(is_present)
            .or_else(|| std::env::var("HUGGINGFACE_API_KEY").ok().filter(is_present))
    }
}