use std::sync::Arc;
use mnem_core::rerank::{RerankError, Reranker};
use serde::{Deserialize, Serialize};
/// Selects a reranking provider and carries that provider's settings.
///
/// Serialized as an internally tagged map: the `provider` key holds the
/// lowercase variant name (`"cohere"`, `"voyage"` or `"jina"`). The remaining
/// keys are the fields of the matching config struct.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// Settings for the Cohere provider (`provider = "cohere"`).
    Cohere(CohereConfig),
    /// Settings for the Voyage provider (`provider = "voyage"`).
    Voyage(VoyageConfig),
    /// Settings for the Jina provider (`provider = "jina"`).
    Jina(JinaConfig),
}
/// Settings for the Cohere reranker.
///
/// When deserializing, every field except `model` may be omitted. Omitted
/// fields fall back to the same helper values that [`CohereConfig::default`] uses.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CohereConfig {
    /// Provider model identifier; required when deserializing.
    pub model: String,
    /// Name of the environment variable holding the API key (default `COHERE_API_KEY`).
    #[serde(default = "default_cohere_env")]
    pub api_key_env: String,
    /// Base URL of the provider API (default `https://api.cohere.com`).
    #[serde(default = "default_cohere_base")]
    pub base_url: String,
    /// Timeout in seconds (default 30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
impl Default for CohereConfig {
    /// Stock Cohere settings: model `rerank-v3.5`, plus the serde default
    /// helpers for the key variable, base URL and timeout.
    fn default() -> Self {
        let model = String::from("rerank-v3.5");
        Self {
            api_key_env: default_cohere_env(),
            base_url: default_cohere_base(),
            timeout_secs: default_timeout(),
            model,
        }
    }
}
/// Settings for the Voyage reranker.
///
/// When deserializing, every field except `model` may be omitted. Omitted
/// fields fall back to the same helper values that [`VoyageConfig::default`] uses.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct VoyageConfig {
    /// Provider model identifier; required when deserializing.
    pub model: String,
    /// Name of the environment variable holding the API key (default `VOYAGE_API_KEY`).
    #[serde(default = "default_voyage_env")]
    pub api_key_env: String,
    /// Base URL of the provider API (default `https://api.voyageai.com`).
    #[serde(default = "default_voyage_base")]
    pub base_url: String,
    /// Timeout in seconds (default 30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
impl Default for VoyageConfig {
    /// Stock Voyage settings: model `rerank-2.5`, plus the serde default
    /// helpers for the key variable, base URL and timeout.
    fn default() -> Self {
        let model = String::from("rerank-2.5");
        Self {
            api_key_env: default_voyage_env(),
            base_url: default_voyage_base(),
            timeout_secs: default_timeout(),
            model,
        }
    }
}
/// Settings for the Jina reranker.
///
/// When deserializing, every field except `model` may be omitted. Omitted
/// fields fall back to the same helper values that [`JinaConfig::default`] uses.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct JinaConfig {
    /// Provider model identifier; required when deserializing.
    pub model: String,
    /// Name of the environment variable holding the API key (default `JINA_API_KEY`).
    #[serde(default = "default_jina_env")]
    pub api_key_env: String,
    /// Base URL of the provider API (default `https://api.jina.ai`).
    #[serde(default = "default_jina_base")]
    pub base_url: String,
    /// Timeout in seconds (default 30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
impl Default for JinaConfig {
    /// Stock Jina settings: model `jina-reranker-v3`, plus the serde default
    /// helpers for the key variable, base URL and timeout.
    fn default() -> Self {
        let model = String::from("jina-reranker-v3");
        Self {
            api_key_env: default_jina_env(),
            base_url: default_jina_base(),
            timeout_secs: default_timeout(),
            model,
        }
    }
}
/// Default for `CohereConfig::api_key_env`.
fn default_cohere_env() -> String {
    String::from("COHERE_API_KEY")
}
/// Default for `CohereConfig::base_url`.
fn default_cohere_base() -> String {
    String::from("https://api.cohere.com")
}
/// Default for `VoyageConfig::api_key_env`.
fn default_voyage_env() -> String {
    String::from("VOYAGE_API_KEY")
}
/// Default for `VoyageConfig::base_url`.
fn default_voyage_base() -> String {
    String::from("https://api.voyageai.com")
}
/// Default for `JinaConfig::api_key_env`.
fn default_jina_env() -> String {
    String::from("JINA_API_KEY")
}
/// Default for `JinaConfig::base_url`.
fn default_jina_base() -> String {
    String::from("https://api.jina.ai")
}
/// Default `timeout_secs` shared by every provider config.
const fn default_timeout() -> u64 {
    const SECS: u64 = 30;
    SECS
}
pub fn open(cfg: &ProviderConfig) -> Result<Arc<dyn Reranker>, RerankError> {
match cfg {
#[cfg(feature = "cohere")]
ProviderConfig::Cohere(c) => {
let r = crate::cohere::CohereReranker::from_config(c)?;
Ok(Arc::new(r))
}
#[cfg(not(feature = "cohere"))]
ProviderConfig::Cohere(_) => Err(RerankError::Config(
"this mnem-rerank-providers build was compiled without the `cohere` feature".into(),
)),
#[cfg(feature = "voyage")]
ProviderConfig::Voyage(c) => {
let r = crate::voyage::VoyageReranker::from_config(c)?;
Ok(Arc::new(r))
}
#[cfg(not(feature = "voyage"))]
ProviderConfig::Voyage(_) => Err(RerankError::Config(
"this mnem-rerank-providers build was compiled without the `voyage` feature".into(),
)),
#[cfg(feature = "jina")]
ProviderConfig::Jina(c) => {
let r = crate::jina::JinaReranker::from_config(c)?;
Ok(Arc::new(r))
}
#[cfg(not(feature = "jina"))]
ProviderConfig::Jina(_) => Err(RerankError::Config(
"this mnem-rerank-providers build was compiled without the `jina` feature".into(),
)),
}
}
/// Reads a provider API key from the environment variable named `var`.
///
/// A non-empty value is returned verbatim.
///
/// # Errors
///
/// Returns [`RerankError::Config`] whose message names `var` when the variable:
/// - is unset;
/// - holds data that is not valid Unicode; or
/// - is empty or whitespace-only, so a blank key is never handed to a provider.
pub(crate) fn read_api_key(var: &str) -> Result<String, RerankError> {
    use std::env::VarError;

    let problem = match std::env::var(var) {
        Ok(key) if !key.trim().is_empty() => return Ok(key),
        Ok(_) => "is set but empty",
        Err(VarError::NotPresent) => "is not set",
        // Previously reported as "not set", which misled users whose variable *was* set.
        Err(VarError::NotUnicode(_)) => "is not valid Unicode",
    };
    Err(RerankError::Config(format!("environment variable {var} {problem}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cohere_config_toml_round_trip() {
        let cfg = ProviderConfig::Cohere(CohereConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"cohere\""));
        assert!(s.contains("rerank-v3.5"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn voyage_config_toml_round_trip() {
        let cfg = ProviderConfig::Voyage(VoyageConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"voyage\""));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn jina_config_toml_round_trip() {
        let cfg = ProviderConfig::Jina(JinaConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"jina\""));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    // The round-trip tests always serialize every field, so they never exercise the
    // `#[serde(default = ...)]` attributes. These parse minimal documents instead.
    #[test]
    fn cohere_omitted_fields_use_defaults() {
        let cfg: ProviderConfig =
            toml::from_str("provider = \"cohere\"\nmodel = \"rerank-v3.5\"\n").unwrap();
        assert_eq!(cfg, ProviderConfig::Cohere(CohereConfig::default()));
    }

    #[test]
    fn voyage_omitted_fields_use_defaults_and_overrides_win() {
        let cfg: ProviderConfig = toml::from_str(
            "provider = \"voyage\"\nmodel = \"rerank-2.5\"\ntimeout_secs = 5\n",
        )
        .unwrap();
        let expected = VoyageConfig {
            timeout_secs: 5,
            ..VoyageConfig::default()
        };
        assert_eq!(cfg, ProviderConfig::Voyage(expected));
    }

    #[test]
    fn jina_omitted_fields_use_defaults() {
        let cfg: ProviderConfig =
            toml::from_str("provider = \"jina\"\nmodel = \"jina-reranker-v3\"\n").unwrap();
        assert_eq!(cfg, ProviderConfig::Jina(JinaConfig::default()));
    }

    #[test]
    fn unknown_provider_is_rejected() {
        let parsed: Result<ProviderConfig, _> =
            toml::from_str("provider = \"acme\"\nmodel = \"m\"\n");
        assert!(parsed.is_err());
    }

    #[test]
    fn read_api_key_missing_is_config_error() {
        let var = "MNEM_TEST_RERANK_KEY_NEVER_SET_a9e1c3d5f7b9";
        let e = read_api_key(var).unwrap_err();
        match e {
            RerankError::Config(msg) => assert!(msg.contains(var)),
            other => panic!("expected Config error, got {other:?}"),
        }
    }
}