use serde::{Deserialize, Serialize};
use std::time::Duration;
use crate::core::traits::provider::ProviderConfig;
/// Cohere REST API version, used to build versioned endpoint paths.
///
/// `V2` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CohereApiVersion {
    /// Version 1 of the API (`/v1/...` paths).
    V1,
    /// Version 2 of the API (`/v2/...` paths); the default.
    #[default]
    V2,
}

impl CohereApiVersion {
    /// Returns the bare URL path segment for this version (`"v1"` or `"v2"`),
    /// with no leading or trailing slash.
    pub fn as_path(&self) -> &'static str {
        match *self {
            Self::V1 => "v1",
            Self::V2 => "v2",
        }
    }
}
/// Connection settings for the Cohere provider.
///
/// `Debug` is implemented by hand (below) so that formatting a config with `{:?}`
/// never reveals the API key. `Serialize` still writes `api_key` verbatim, so
/// serialized output must be treated as secret.
#[derive(Clone, Serialize, Deserialize)]
pub struct CohereConfig {
    /// Secret API key, sent as `Authorization: Bearer <key>` by `create_headers`.
    pub api_key: String,
    /// Base URL of the API; trailing `/` characters are stripped by the endpoint helpers.
    pub api_base: String,
    /// Version segment used when building the chat endpoint.
    pub api_version: CohereApiVersion,
    /// Request timeout in whole seconds (exposed as a `Duration` via `ProviderConfig::timeout`).
    pub timeout_seconds: u64,
    /// Retry limit reported through `ProviderConfig::max_retries`.
    pub max_retries: u32,
    /// Default embedding `input_type` (e.g. `"search_document"`).
    /// NOTE(review): presumably used when a request does not specify one — confirm at call sites.
    pub default_embedding_input_type: String,
}

impl std::fmt::Debug for CohereConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a key is configured without leaking any of its characters.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("CohereConfig")
            .field("api_key", &api_key)
            .field("api_base", &self.api_base)
            .field("api_version", &self.api_version)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .field(
                "default_embedding_input_type",
                &self.default_embedding_input_type,
            )
            .finish()
    }
}
impl Default for CohereConfig {
fn default() -> Self {
Self {
api_key: String::new(),
api_base: "https://api.cohere.ai".to_string(),
api_version: CohereApiVersion::V2,
timeout_seconds: 60,
max_retries: 3,
default_embedding_input_type: "search_document".to_string(),
}
}
}
impl CohereConfig {
    /// Creates a configuration for `api_key`, taking every other field from
    /// [`CohereConfig::default`].
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            ..Default::default()
        }
    }

    /// Returns the config with `api_version` replaced.
    ///
    /// Only [`CohereConfig::chat_endpoint`] follows this setting; the embed, rerank
    /// and models endpoints are pinned to fixed versions.
    #[must_use]
    pub fn with_api_version(mut self, version: CohereApiVersion) -> Self {
        self.api_version = version;
        self
    }

    /// Returns the config with `api_base` replaced. Trailing slashes are allowed
    /// because the endpoint helpers strip them.
    #[must_use]
    pub fn with_api_base(mut self, base: impl Into<String>) -> Self {
        self.api_base = base.into();
        self
    }

    /// Returns the config with the request timeout set to `seconds`.
    #[must_use]
    pub fn with_timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = seconds;
        self
    }

    /// Returns the config with `max_retries` set to `retries`.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns the config with the default embedding `input_type` replaced
    /// (the default is `"search_document"`).
    #[must_use]
    pub fn with_default_embedding_input_type(mut self, input_type: impl Into<String>) -> Self {
        self.default_embedding_input_type = input_type.into();
        self
    }

    /// `api_base` with every trailing `/` removed, so the joined URL has no `//`
    /// before the version segment.
    fn trimmed_api_base(&self) -> &str {
        self.api_base.trim_end_matches('/')
    }

    /// Chat endpoint `{api_base}/{v1|v2}/chat`, following `api_version`.
    pub fn chat_endpoint(&self) -> String {
        format!(
            "{}/{}/chat",
            self.trimmed_api_base(),
            self.api_version.as_path()
        )
    }

    /// Embed endpoint `{api_base}/v2/embed`; always v2, regardless of `api_version`.
    pub fn embed_endpoint(&self) -> String {
        format!("{}/v2/embed", self.trimmed_api_base())
    }

    /// Rerank endpoint `{api_base}/v1/rerank`; always v1, regardless of `api_version`.
    pub fn rerank_endpoint(&self) -> String {
        format!("{}/v1/rerank", self.trimmed_api_base())
    }

    /// Model listing endpoint `{api_base}/v1/models`; always v1, regardless of `api_version`.
    pub fn models_endpoint(&self) -> String {
        format!("{}/v1/models", self.trimmed_api_base())
    }

    /// Builds the JSON request headers, including a bearer `Authorization` header.
    ///
    /// `api_key` is not checked here; an empty key produces `"Bearer "`.
    pub fn create_headers(&self) -> std::collections::HashMap<String, String> {
        std::collections::HashMap::from([
            (
                "Authorization".to_string(),
                format!("Bearer {}", self.api_key),
            ),
            ("Content-Type".to_string(), "application/json".to_string()),
            ("Accept".to_string(), "application/json".to_string()),
            ("Request-Source".to_string(), "litellm-rs".to_string()),
        ])
    }
}
impl ProviderConfig for CohereConfig {
    /// Runs the trait's shared `validate_standard` checks, passing `"Cohere"` as
    /// the provider name.
    fn validate(&self) -> Result<(), String> {
        let provider_name = "Cohere";
        self.validate_standard(provider_name)
    }

    /// Always `Some`, even when the stored key is empty.
    fn api_key(&self) -> Option<&str> {
        Some(self.api_key.as_str())
    }

    /// Always `Some`; the stored base URL is returned untrimmed.
    fn api_base(&self) -> Option<&str> {
        Some(self.api_base.as_str())
    }

    /// Converts `timeout_seconds` into a `Duration`.
    fn timeout(&self) -> Duration {
        let secs = self.timeout_seconds;
        Duration::from_secs(secs)
    }

    /// Returns the configured retry limit unchanged.
    fn max_retries(&self) -> u32 {
        let retries = self.max_retries;
        retries
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = CohereConfig::default();
        assert_eq!(config.api_key, "");
        assert_eq!(config.api_base, "https://api.cohere.ai");
        assert_eq!(config.api_version, CohereApiVersion::V2);
        assert_eq!(config.timeout_seconds, 60);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.default_embedding_input_type, "search_document");
    }

    #[test]
    fn test_api_version_paths() {
        assert_eq!(CohereApiVersion::V1.as_path(), "v1");
        assert_eq!(CohereApiVersion::V2.as_path(), "v2");
        assert_eq!(CohereApiVersion::default(), CohereApiVersion::V2);
    }

    #[test]
    fn test_config_with_api_key() {
        let config = CohereConfig::new("test-key");
        assert_eq!(config.api_key, "test-key");
    }

    #[test]
    fn test_config_validation() {
        let config = CohereConfig::default();
        assert!(config.validate().is_err());
        let config = CohereConfig::new("test-key");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_endpoints() {
        let config = CohereConfig::new("test-key");
        assert_eq!(config.chat_endpoint(), "https://api.cohere.ai/v2/chat");
        assert_eq!(config.embed_endpoint(), "https://api.cohere.ai/v2/embed");
        assert_eq!(config.rerank_endpoint(), "https://api.cohere.ai/v1/rerank");
        assert_eq!(config.models_endpoint(), "https://api.cohere.ai/v1/models");
    }

    #[test]
    fn test_v1_chat_endpoint() {
        let config = CohereConfig::new("test-key").with_api_version(CohereApiVersion::V1);
        assert_eq!(config.chat_endpoint(), "https://api.cohere.ai/v1/chat");
    }

    #[test]
    fn test_non_chat_endpoints_ignore_api_version() {
        let config = CohereConfig::new("test-key").with_api_version(CohereApiVersion::V1);
        assert_eq!(config.embed_endpoint(), "https://api.cohere.ai/v2/embed");
        assert_eq!(config.rerank_endpoint(), "https://api.cohere.ai/v1/rerank");
        assert_eq!(config.models_endpoint(), "https://api.cohere.ai/v1/models");
    }

    #[test]
    fn test_endpoints_trim_trailing_slashes() {
        let config = CohereConfig::new("test-key").with_api_base("https://example.com//");
        assert_eq!(config.chat_endpoint(), "https://example.com/v2/chat");
        assert_eq!(config.embed_endpoint(), "https://example.com/v2/embed");
        assert_eq!(config.rerank_endpoint(), "https://example.com/v1/rerank");
        assert_eq!(config.models_endpoint(), "https://example.com/v1/models");
    }

    #[test]
    fn test_provider_config_accessors() {
        let config = CohereConfig::new("test-key").with_timeout(15);
        assert_eq!(ProviderConfig::api_key(&config), Some("test-key"));
        assert_eq!(
            ProviderConfig::api_base(&config),
            Some("https://api.cohere.ai")
        );
        assert_eq!(ProviderConfig::timeout(&config), Duration::from_secs(15));
        assert_eq!(ProviderConfig::max_retries(&config), 3);
    }

    #[test]
    fn test_create_headers() {
        let config = CohereConfig::new("test-key");
        let headers = config.create_headers();
        assert_eq!(headers.len(), 4);
        assert_eq!(headers.get("Authorization").unwrap(), "Bearer test-key");
        assert_eq!(headers.get("Content-Type").unwrap(), "application/json");
        assert_eq!(headers.get("Accept").unwrap(), "application/json");
        assert_eq!(headers.get("Request-Source").unwrap(), "litellm-rs");
    }
}