use anyhow::Result;
use serde::{Deserialize, Serialize};
/// How an embedding input should be treated by providers that distinguish
/// retrieval queries from documents.
///
/// `Default` is derived with `None` as the default variant, replacing the
/// previous hand-written (derivable) `impl Default`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum InputType {
    /// No input-type hint and no text prefix (the default).
    #[default]
    None,
    /// The text is a search query.
    Query,
    /// The text is a document to be indexed.
    Document,
}
impl InputType {
    /// String value expected by provider APIs for this input type, or
    /// `None` when no input-type hint should be sent.
    pub fn as_api_str(&self) -> Option<&'static str> {
        let api_value = match self {
            InputType::None => return None,
            InputType::Query => "query",
            InputType::Document => "document",
        };
        Some(api_value)
    }

    /// Static prefix to prepend to text of this input type, if any.
    pub fn get_prefix(&self) -> Option<&'static str> {
        match self {
            InputType::Query => Some(crate::constants::QUERY_PREFIX),
            InputType::Document => Some(crate::constants::DOCUMENT_PREFIX),
            InputType::None => None,
        }
    }

    /// Returns `text` with this input type's prefix prepended, or an owned
    /// copy of `text` unchanged when there is no prefix.
    pub fn apply_prefix(&self, text: &str) -> String {
        self.get_prefix()
            .map_or_else(|| text.to_string(), |prefix| format!("{prefix}{text}"))
    }
}
/// Identifies which embedding backend a model string refers to.
///
/// Serialized by serde in all-lowercase form (e.g. `FastEmbed` <->
/// `"fastembed"`, `HuggingFace` <-> `"huggingface"`) per the
/// `rename_all = "lowercase"` attribute below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProviderType {
    /// Local in-process embeddings (available behind the `fastembed` feature).
    FastEmbed,
    /// Jina AI remote embedding API.
    Jina,
    /// Voyage AI remote embedding API.
    Voyage,
    /// Google remote embedding API.
    Google,
    /// HuggingFace-hosted models.
    HuggingFace,
    /// OpenAI remote embedding API.
    OpenAI,
}
impl Default for EmbeddingProviderType {
    /// Build-dependent default: local FastEmbed when the `fastembed`
    /// feature is compiled in, otherwise the Voyage remote provider.
    fn default() -> Self {
        // `cfg!` is equivalent to the `#[cfg]` block pair here because both
        // arms are plain variant expressions that compile in either build.
        if cfg!(feature = "fastembed") {
            Self::FastEmbed
        } else {
            Self::Voyage
        }
    }
}
/// Embedding model configuration.
///
/// Both fields hold `"provider:model"` strings as understood by
/// `parse_provider_model` (e.g. `"voyage:voyage-code-3"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    // Model used for embedding source code.
    pub code_model: String,
    // Model used for embedding natural-language text.
    pub text_model: String,
}
impl Default for EmbeddingConfig {
    /// Build-dependent defaults: local FastEmbed models when the
    /// `fastembed` feature is compiled in, Voyage remote models otherwise
    /// (matching the cfg logic in `EmbeddingProviderType::default`).
    fn default() -> Self {
        let (code_model, text_model) = if cfg!(feature = "fastembed") {
            (
                "fastembed:jinaai/jina-embeddings-v2-base-code",
                "fastembed:sentence-transformers/all-MiniLM-L6-v2-quantized",
            )
        } else {
            ("voyage:voyage-code-3", "voyage:voyage-3.5-lite")
        };
        Self {
            code_model: code_model.to_string(),
            text_model: text_model.to_string(),
        }
    }
}
/// Splits a `"provider:model"` string into its provider and model parts.
///
/// The provider prefix is matched case-insensitively and accepts common
/// aliases (`"jinaai"`/`"jina"`, `"voyageai"`/`"voyage"`, `"hf"` for
/// HuggingFace). An unknown prefix falls back to the build's default
/// provider with the prefix stripped; an input with no `:` separator is
/// treated entirely as a model name for the default provider.
///
/// The fallback delegates to `EmbeddingProviderType::default()`, which
/// holds the exact cfg-gated FastEmbed/Voyage choice this function
/// previously duplicated twice inline.
pub fn parse_provider_model(input: &str) -> (EmbeddingProviderType, String) {
    match input.split_once(':') {
        Some((provider_str, model)) => {
            let provider = match provider_str.to_lowercase().as_str() {
                "fastembed" => EmbeddingProviderType::FastEmbed,
                "jinaai" | "jina" => EmbeddingProviderType::Jina,
                "voyageai" | "voyage" => EmbeddingProviderType::Voyage,
                "google" => EmbeddingProviderType::Google,
                "huggingface" | "hf" => EmbeddingProviderType::HuggingFace,
                "openai" => EmbeddingProviderType::OpenAI,
                // Unknown prefix: best-effort fallback rather than an error,
                // preserving the original lenient behavior.
                _ => EmbeddingProviderType::default(),
            };
            (provider, model.to_string())
        }
        None => (EmbeddingProviderType::default(), input.to_string()),
    }
}
impl EmbeddingConfig {
    /// Provider inferred from the configured `code_model` string
    /// (the part before the `:` separator).
    pub fn get_active_provider(&self) -> EmbeddingProviderType {
        let (provider, _) = parse_provider_model(&self.code_model);
        provider
    }

    /// Reads the API key for `provider` from its conventional environment
    /// variable, or `None` when the variable is unset/invalid or the
    /// provider takes no key here. (FastEmbed runs locally; HuggingFace and
    /// OpenAI keys are not read by this method — NOTE(review): confirm that
    /// is intentional.)
    pub fn get_api_key(&self, provider: &EmbeddingProviderType) -> Option<String> {
        match provider {
            EmbeddingProviderType::Jina => std::env::var("JINA_API_KEY").ok(),
            EmbeddingProviderType::Voyage => std::env::var("VOYAGE_API_KEY").ok(),
            EmbeddingProviderType::Google => std::env::var("GOOGLE_API_KEY").ok(),
            // Exhaustive on purpose (no `_` arm): adding a provider variant
            // forces a decision here instead of silently yielding `None`.
            EmbeddingProviderType::FastEmbed
            | EmbeddingProviderType::HuggingFace
            | EmbeddingProviderType::OpenAI => None,
        }
    }

    /// Returns the embedding vector dimension reported by the provider
    /// implementation for `model`.
    ///
    /// # Panics
    ///
    /// Panics if the provider implementation cannot be constructed. (The
    /// previous panic message falsely claimed a fallback dimension would be
    /// used; no fallback exists, so the message now matches the behavior.)
    pub async fn get_vector_dimension(
        &self,
        provider: &EmbeddingProviderType,
        model: &str,
    ) -> usize {
        match crate::embedding::provider::create_embedding_provider_from_parts(provider, model)
            .await
        {
            Ok(provider_impl) => provider_impl.get_dimension(),
            Err(e) => panic!(
                "Failed to create embedding provider for {:?}:{}: {}",
                provider, model, e
            ),
        }
    }

    /// Ensures `model` is supported by `provider`.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider implementation cannot be
    /// constructed, or if it reports the model as unsupported.
    pub async fn validate_model(
        &self,
        provider: &EmbeddingProviderType,
        model: &str,
    ) -> Result<()> {
        let provider_impl =
            crate::embedding::provider::create_embedding_provider_from_parts(provider, model)
                .await?;
        if !provider_impl.is_model_supported() {
            return Err(anyhow::anyhow!(
                "Model {} is not supported by provider {:?}",
                model,
                provider
            ));
        }
        Ok(())
    }
}