use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::core::providers::base::BaseConfig;
use crate::core::traits::provider::ProviderConfig;
/// Configuration for an OpenAI-compatible ("openai_like") provider.
///
/// Wraps the shared [`BaseConfig`] (flattened into this struct's serialized
/// form) and adds options specific to self-hosted or third-party
/// OpenAI-compatible endpoints: extra headers, optional API-key skipping,
/// a model-name prefix, and parameter pass-through.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAILikeConfig {
    /// Shared provider settings (api_key, api_base, timeout, retries, ...),
    /// flattened — serialized fields appear at this struct's top level.
    #[serde(flatten)]
    pub base: BaseConfig,
    /// Provider display name; defaults to "openai_like".
    #[serde(default = "default_provider_name")]
    pub provider_name: String,
    /// Extra HTTP headers to send with requests (empty by default).
    #[serde(default)]
    pub custom_headers: HashMap<String, String>,
    /// When true, `validate()` does not require an API key — useful for
    /// local/unauthenticated endpoints. Defaults to false.
    #[serde(default)]
    pub skip_api_key: bool,
    /// Optional prefix stripped from model names by `get_effective_model`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_prefix: Option<String>,
    /// Optional model to use when a request does not specify one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_model: Option<String>,
    /// Defaults to true. Presumably controls forwarding of unrecognized
    /// request parameters to the backend — confirm against the provider
    /// implementation that reads this flag.
    #[serde(default = "default_pass_through")]
    pub pass_through_params: bool,
}
/// Serde default for `provider_name`.
fn default_provider_name() -> String {
    String::from("openai_like")
}
/// Serde default for `pass_through_params`: pass-through is on by default.
fn default_pass_through() -> bool {
    true
}
impl Default for OpenAILikeConfig {
fn default() -> Self {
Self {
base: BaseConfig {
api_key: None,
api_base: None, timeout: 60,
max_retries: 3,
headers: HashMap::new(),
organization: None,
api_version: None,
},
provider_name: default_provider_name(),
custom_headers: HashMap::new(),
skip_api_key: false,
model_prefix: None,
default_model: None,
pass_through_params: true,
}
}
}
impl OpenAILikeConfig {
    /// Creates a config pointing at `api_base` with all other fields defaulted.
    pub fn new(api_base: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.base.api_base = Some(api_base.into());
        config
    }

    /// Creates a config with both an endpoint and an API key.
    pub fn with_api_key(api_base: impl Into<String>, api_key: impl Into<String>) -> Self {
        let mut config = Self::new(api_base);
        config.base.api_key = Some(api_key.into());
        config
    }

    /// Builds a config from environment variables.
    ///
    /// `OPENAI_LIKE_*` variables take precedence over the generic
    /// `OPENAI_*` fallbacks (base URL and key only). Unparsable values
    /// (e.g. a non-numeric `OPENAI_LIKE_TIMEOUT`) are silently ignored,
    /// leaving the default in place.
    pub fn from_env() -> Self {
        let mut config = Self::default();
        if let Ok(api_base) = std::env::var("OPENAI_LIKE_API_BASE") {
            config.base.api_base = Some(api_base);
        } else if let Ok(api_base) = std::env::var("OPENAI_API_BASE") {
            config.base.api_base = Some(api_base);
        }
        if let Ok(api_key) = std::env::var("OPENAI_LIKE_API_KEY") {
            config.base.api_key = Some(api_key);
        } else if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            config.base.api_key = Some(api_key);
        }
        if let Ok(timeout_str) = std::env::var("OPENAI_LIKE_TIMEOUT")
            && let Ok(timeout) = timeout_str.parse::<u64>()
        {
            config.base.timeout = timeout;
        }
        if let Ok(skip) = std::env::var("OPENAI_LIKE_SKIP_API_KEY") {
            // Accept "true" in any case, or "1", as truthy.
            // eq_ignore_ascii_case avoids allocating a lowercased copy.
            config.skip_api_key = skip.eq_ignore_ascii_case("true") || skip == "1";
        }
        if let Ok(name) = std::env::var("OPENAI_LIKE_PROVIDER_NAME") {
            config.provider_name = name;
        }
        config
    }

    /// Sets the provider display name (builder style).
    pub fn with_provider_name(mut self, name: impl Into<String>) -> Self {
        self.provider_name = name.into();
        self
    }

    /// Adds a single custom header, replacing any existing value for `key`.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.custom_headers.insert(key.into(), value.into());
        self
    }

    /// Merges `headers` into the custom headers (existing keys are overwritten).
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.custom_headers.extend(headers);
        self
    }

    /// Sets whether validation may skip the API-key requirement.
    pub fn with_skip_api_key(mut self, skip: bool) -> Self {
        self.skip_api_key = skip;
        self
    }

    /// Sets the model-name prefix stripped by [`Self::get_effective_model`].
    pub fn with_model_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.model_prefix = Some(prefix.into());
        self
    }

    /// Sets the model used when a request does not specify one.
    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
        self.default_model = Some(model.into());
        self
    }

    /// Sets the request timeout in seconds.
    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
        self.base.timeout = timeout_secs;
        self
    }

    /// Checks the configuration for use with the openai_like provider.
    ///
    /// # Errors
    /// Returns a message when `api_base` is missing, when an API key is
    /// missing and `skip_api_key` is false, when the timeout is zero, or
    /// when `max_retries` exceeds 10.
    pub fn validate(&self) -> Result<(), String> {
        if self.base.api_base.is_none() {
            return Err("api_base is required for openai_like provider".to_string());
        }
        if !self.skip_api_key && self.base.api_key.is_none() {
            return Err(
                "api_key is required for openai_like provider (set skip_api_key=true to skip)"
                    .to_string(),
            );
        }
        if self.base.timeout == 0 {
            return Err("Timeout must be greater than 0".to_string());
        }
        if self.base.max_retries > 10 {
            return Err("Max retries should not exceed 10".to_string());
        }
        Ok(())
    }

    /// Returns the configured API base, falling back to a local default.
    pub fn get_api_base(&self) -> String {
        self.base
            .api_base
            .clone()
            .unwrap_or_else(|| "http://localhost:8000/v1".to_string())
    }

    /// Returns `model` with the configured `model_prefix` removed, if the
    /// prefix is set and `model` starts with it; otherwise `model` unchanged.
    pub fn get_effective_model(&self, model: &str) -> String {
        if let Some(prefix) = &self.model_prefix
            // strip_prefix replaces the manual starts_with + byte-index
            // slice: one check, and no panic path on the slice boundary.
            && let Some(stripped) = model.strip_prefix(prefix.as_str())
        {
            return stripped.to_string();
        }
        model.to_string()
    }
}
/// Adapts [`OpenAILikeConfig`] to the generic [`ProviderConfig`] trait by
/// delegating to the `base` fields and the inherent methods.
impl ProviderConfig for OpenAILikeConfig {
    fn validate(&self) -> Result<(), String> {
        // Resolves to the inherent `validate` (inherent methods take
        // precedence over trait methods), so this does not recurse.
        self.validate()
    }
    fn api_key(&self) -> Option<&str> {
        self.base.api_key.as_deref()
    }
    fn api_base(&self) -> Option<&str> {
        self.base.api_base.as_deref()
    }
    fn timeout(&self) -> Duration {
        // `base.timeout` is stored as whole seconds.
        Duration::from_secs(self.base.timeout)
    }
    fn max_retries(&self) -> u32 {
        self.base.max_retries
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by most tests below.
    const BASE: &str = "http://localhost:8000/v1";

    #[test]
    fn test_default_config() {
        let cfg = OpenAILikeConfig::default();
        assert!(cfg.base.api_base.is_none());
        assert!(cfg.base.api_key.is_none());
        assert!(!cfg.skip_api_key);
        assert!(cfg.pass_through_params);
    }

    #[test]
    fn test_new_with_api_base() {
        let cfg = OpenAILikeConfig::new(BASE);
        assert_eq!(cfg.base.api_base.as_deref(), Some(BASE));
    }

    #[test]
    fn test_with_api_key() {
        let cfg = OpenAILikeConfig::with_api_key(BASE, "sk-test123");
        assert_eq!(cfg.base.api_base.as_deref(), Some(BASE));
        assert_eq!(cfg.base.api_key.as_deref(), Some("sk-test123"));
    }

    #[test]
    fn test_validation_missing_api_base() {
        // Default config has no api_base, so validation must fail.
        assert!(OpenAILikeConfig::default().validate().is_err());
    }

    #[test]
    fn test_validation_missing_api_key() {
        // api_base alone is not enough unless skip_api_key is set.
        assert!(OpenAILikeConfig::new(BASE).validate().is_err());
    }

    #[test]
    fn test_validation_skip_api_key() {
        let cfg = OpenAILikeConfig::new(BASE).with_skip_api_key(true);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_validation_with_api_key() {
        let cfg = OpenAILikeConfig::with_api_key(BASE, "sk-test123");
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_get_effective_model_no_prefix() {
        // Without a configured prefix the model passes through unchanged.
        assert_eq!(
            OpenAILikeConfig::new(BASE).get_effective_model("gpt-4"),
            "gpt-4"
        );
    }

    #[test]
    fn test_get_effective_model_with_prefix() {
        let cfg = OpenAILikeConfig::new(BASE).with_model_prefix("custom/");
        // Prefixed names are stripped; unprefixed names pass through.
        assert_eq!(cfg.get_effective_model("custom/gpt-4"), "gpt-4");
        assert_eq!(cfg.get_effective_model("gpt-4"), "gpt-4");
    }

    #[test]
    fn test_custom_headers() {
        let cfg = OpenAILikeConfig::new(BASE)
            .with_header("X-Custom-Header", "value1")
            .with_header("X-Another-Header", "value2");
        assert_eq!(cfg.custom_headers.len(), 2);
        assert_eq!(
            cfg.custom_headers.get("X-Custom-Header").map(String::as_str),
            Some("value1")
        );
    }

    #[test]
    fn test_builder_chain() {
        let cfg = OpenAILikeConfig::new(BASE)
            .with_provider_name("my-provider")
            .with_timeout(120)
            .with_default_model("llama-2-70b")
            .with_skip_api_key(true);
        assert_eq!(cfg.provider_name, "my-provider");
        assert_eq!(cfg.base.timeout, 120);
        assert_eq!(cfg.default_model.as_deref(), Some("llama-2-70b"));
        assert!(cfg.skip_api_key);
    }
}