mod aliyun;
mod zhipu;
use async_trait::async_trait;
use std::time::Duration;
use crate::client::HttpClient;
use crate::config::Provider;
use crate::config::ProviderConfig;
use crate::error::{Error, Result};
/// Fallback timeout (60 seconds) handed to `HttpClient::new` when the provider
/// configuration leaves `timeout` unset.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
/// One entry of a rerank result, as returned by [`RerankProvider::rerank`].
///
/// Derives `PartialEq` so results can be compared directly (for example with
/// `assert_eq!` in tests). `Eq` is intentionally not derived because `score`
/// is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankItem {
    /// Index reported for this entry.
    ///
    /// NOTE(review): presumably the position of the document in the `documents`
    /// slice passed to `rerank` — confirm against the provider implementations.
    pub index: usize,
    /// Score reported for this entry. The scale and range are provider-defined
    /// and not established in this module.
    pub score: f64,
}
/// Interface implemented by every backend that can rerank documents for a query.
///
/// Implementors must be `Send + Sync`; the `create` factory in this module
/// returns them as `Box<dyn RerankProvider>`.
#[async_trait]
pub trait RerankProvider: Send + Sync {
/// Reranks `documents` with respect to `query` and returns the resulting items.
///
/// NOTE(review): `top_n` presumably caps the number of returned items, with
/// `None` deferring to the provider's default — confirm against the
/// implementations. The ordering of the returned vector is likewise
/// implementation-defined from the point of view of this trait.
///
/// # Errors
///
/// Returns the crate-level `Error` (through `Result`) when the implementation
/// cannot produce a result.
async fn rerank(
&self,
query: &str,
documents: &[&str],
top_n: Option<usize>,
) -> Result<Vec<RerankItem>>;
}
/// Builds the `HttpClient` used by a rerank provider.
///
/// The timeout comes from `config.timeout` when it is set and falls back to
/// `DEFAULT_TIMEOUT` otherwise.
///
/// # Errors
///
/// Returns whatever error `HttpClient::new` reports.
fn http_client(config: &ProviderConfig) -> Result<HttpClient> {
    let timeout = match config.timeout {
        Some(configured) => configured,
        None => DEFAULT_TIMEOUT,
    };
    HttpClient::new(timeout)
}
pub(crate) fn create(config: &ProviderConfig) -> Result<Box<dyn RerankProvider>> {
match config.provider {
#[cfg(all(feature = "aliyun", feature = "rerank"))]
Provider::Aliyun => Ok(Box::new(aliyun::AliyunRerank::new(
config,
http_client(config)?,
))),
#[cfg(all(feature = "zhipu", feature = "rerank"))]
Provider::Zhipu => Ok(Box::new(zhipu::ZhipuRerank::new(
config,
http_client(config)?,
))),
#[cfg(not(feature = "aliyun"))]
Provider::Aliyun => Err(Error::ProviderDisabled("aliyun".to_string())),
#[cfg(not(feature = "zhipu"))]
Provider::Zhipu => Err(Error::ProviderDisabled("zhipu".to_string())),
#[cfg(feature = "anthropic")]
Provider::Anthropic => Err(Error::Unsupported {
provider: config.provider.to_string(),
capability: "rerank",
}),
#[cfg(not(feature = "anthropic"))]
Provider::Anthropic => Err(Error::ProviderDisabled("anthropic".to_string())),
#[cfg(feature = "google")]
Provider::Google => Err(Error::Unsupported {
provider: config.provider.to_string(),
capability: "rerank",
}),
#[cfg(not(feature = "google"))]
Provider::Google => Err(Error::ProviderDisabled("google".to_string())),
Provider::OpenAI | Provider::Ollama => Err(Error::Unsupported {
provider: config.provider.to_string(),
capability: "rerank",
}),
}
}
#[cfg(test)]
mod factory_tests {
    use super::create;
    use crate::config::{Provider, ProviderConfig};
    use crate::error::Error;

    /// Asserts that `outcome` is `Error::Unsupported` naming `expected_provider`
    /// and the `"rerank"` capability.
    // Depending on the enabled features, every caller of this helper may be
    // compiled out, so silence the resulting dead-code warning.
    #[allow(dead_code)]
    #[track_caller]
    fn assert_unsupported<T>(outcome: Result<T, Error>, expected_provider: &str) {
        match outcome {
            Err(Error::Unsupported {
                provider,
                capability,
            }) => {
                assert_eq!(provider, expected_provider);
                assert_eq!(capability, "rerank");
            }
            Ok(_) => panic!("expected error"),
            Err(e) => panic!("expected Unsupported, got {:?}", e),
        }
    }

    /// Asserts that `outcome` is `Error::ProviderDisabled` carrying `expected_provider`.
    // Unused when every provider feature is enabled; see the note on
    // `assert_unsupported`.
    #[allow(dead_code)]
    #[track_caller]
    fn assert_disabled<T>(outcome: Result<T, Error>, expected_provider: &str) {
        match outcome {
            Err(Error::ProviderDisabled(s)) => assert_eq!(s, expected_provider),
            Ok(_) => panic!("expected error"),
            Err(e) => panic!("expected ProviderDisabled, got {:?}", e),
        }
    }

    #[cfg(feature = "openai")]
    #[test]
    fn openai_is_unsupported() {
        let cfg = ProviderConfig::new(Provider::OpenAI, "k", "https://x/v1", "m");
        assert_unsupported(create(&cfg), "openai");
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn ollama_is_unsupported() {
        let cfg = ProviderConfig::new(Provider::Ollama, "k", "http://localhost/v1", "m");
        assert_unsupported(create(&cfg), "ollama");
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn anthropic_is_unsupported() {
        let cfg = ProviderConfig::new(Provider::Anthropic, "k", "https://x/v1", "m");
        assert_unsupported(create(&cfg), "anthropic");
    }

    #[cfg(not(feature = "anthropic"))]
    #[test]
    fn anthropic_disabled_without_anthropic_feature() {
        let cfg = ProviderConfig::new(Provider::Anthropic, "k", "https://x/v1", "m");
        assert_disabled(create(&cfg), "anthropic");
    }

    #[cfg(feature = "google")]
    #[test]
    fn google_is_unsupported() {
        let cfg = ProviderConfig::new(Provider::Google, "k", "https://x/v1", "m");
        assert_unsupported(create(&cfg), "google");
    }

    #[cfg(not(feature = "google"))]
    #[test]
    fn google_disabled_without_google_feature() {
        let cfg = ProviderConfig::new(Provider::Google, "k", "https://x/v1", "m");
        assert_disabled(create(&cfg), "google");
    }

    #[cfg(not(feature = "aliyun"))]
    #[test]
    fn aliyun_disabled_without_aliyun_feature() {
        let cfg = ProviderConfig::new(Provider::Aliyun, "k", "https://x/v1", "m");
        assert_disabled(create(&cfg), "aliyun");
    }

    #[cfg(not(feature = "zhipu"))]
    #[test]
    fn zhipu_disabled_without_zhipu_feature() {
        let cfg = ProviderConfig::new(Provider::Zhipu, "k", "https://x/v1", "m");
        assert_disabled(create(&cfg), "zhipu");
    }
}