mod common;
use common::create_config_with_providers;
use lc::config::Config;
use std::collections::HashMap;
#[cfg(test)]
mod config_set_tests {
    //! "Set" behaviour: each optional `Config` field can be assigned on a config
    //! built by `create_config_with_providers` and then read back unchanged.
    //! The numeric fields are first produced by `Config::parse_max_tokens` /
    //! `Config::parse_temperature`, and the parsed value is checked as well.
    use super::*;

    #[test]
    fn test_config_set_default_provider() {
        let mut cfg = create_config_with_providers();
        cfg.default_provider = Some(String::from("anthropic"));
        assert_eq!(cfg.default_provider.as_deref(), Some("anthropic"));
    }

    #[test]
    fn test_config_set_default_model() {
        let mut cfg = create_config_with_providers();
        cfg.default_model = Some(String::from("gpt-4"));
        assert_eq!(cfg.default_model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn test_config_set_system_prompt() {
        let mut cfg = create_config_with_providers();
        cfg.system_prompt = Some(String::from("You are a helpful assistant."));
        assert_eq!(
            cfg.system_prompt.as_deref(),
            Some("You are a helpful assistant.")
        );
    }

    #[test]
    fn test_config_set_max_tokens() {
        // "2k" uses the shorthand suffix; it must expand to 2000.
        let tokens = Config::parse_max_tokens("2k").unwrap();
        assert_eq!(tokens, 2000);

        let mut cfg = create_config_with_providers();
        cfg.max_tokens = Some(tokens);
        assert_eq!(cfg.max_tokens, Some(2000));
    }

    #[test]
    fn test_config_set_temperature() {
        let temperature = Config::parse_temperature("0.7").unwrap();
        assert_eq!(temperature, 0.7);

        let mut cfg = create_config_with_providers();
        cfg.temperature = Some(temperature);
        assert_eq!(cfg.temperature, Some(0.7));
    }
}
#[cfg(test)]
mod config_get_tests {
    //! "Get" behaviour: values stored on a `Config` are observable through its
    //! fields, and a config built with every option left empty reports `None`
    //! for each of them.
    use super::*;

    #[test]
    fn test_config_get_existing_values() {
        let mut cfg = create_config_with_providers();

        cfg.default_provider = Some("openai".into());
        cfg.default_model = Some("gpt-4".into());
        cfg.system_prompt = Some("Test prompt".into());
        cfg.max_tokens = Some(1000);
        cfg.temperature = Some(0.5);

        assert_eq!(cfg.default_provider.as_deref(), Some("openai"));
        assert_eq!(cfg.default_model.as_deref(), Some("gpt-4"));
        assert_eq!(cfg.system_prompt.as_deref(), Some("Test prompt"));
        assert_eq!(cfg.max_tokens, Some(1000));
        assert_eq!(cfg.temperature, Some(0.5));
    }

    #[test]
    fn test_config_get_unset_values() {
        // Every field is spelled out, so this literal stops compiling if a
        // field is added to `Config`, forcing this test to be revisited.
        let cfg = Config {
            providers: HashMap::new(),
            aliases: HashMap::new(),
            templates: HashMap::new(),
            default_provider: None,
            default_model: None,
            system_prompt: None,
            max_tokens: None,
            temperature: None,
            stream: None,
        };

        assert_eq!(cfg.default_provider, None);
        assert_eq!(cfg.default_model, None);
        assert_eq!(cfg.system_prompt, None);
        assert_eq!(cfg.max_tokens, None);
        assert_eq!(cfg.temperature, None);
    }
}
#[cfg(test)]
mod config_delete_tests {
    //! "Delete" behaviour: resetting a populated optional field to `None`
    //! leaves it unset.
    use super::*;

    #[test]
    fn test_config_delete_values() {
        let mut cfg = create_config_with_providers();

        // Populate every optional field first, so the resets below actually
        // clear something.
        cfg.default_provider = Some("openai".into());
        cfg.default_model = Some("gpt-4".into());
        cfg.system_prompt = Some("Test prompt".into());
        cfg.max_tokens = Some(1000);
        cfg.temperature = Some(0.5);

        // Delete: reset each one back to `None`.
        cfg.default_provider = None;
        cfg.default_model = None;
        cfg.system_prompt = None;
        cfg.max_tokens = None;
        cfg.temperature = None;

        assert_eq!(cfg.default_provider, None);
        assert_eq!(cfg.default_model, None);
        assert_eq!(cfg.system_prompt, None);
        assert_eq!(cfg.max_tokens, None);
        assert_eq!(cfg.temperature, None);
    }
}
#[cfg(test)]
mod config_validation_tests {
    //! Input parsing (`parse_max_tokens`, `parse_temperature`) and
    //! `t:`-prefixed template resolution, expressed as input/expected tables.
    use super::*;

    #[test]
    fn test_max_tokens_parsing() {
        // Plain integers plus the "k" shorthand, including a fractional one.
        for (input, expected) in [("1000", 1000), ("2k", 2000), ("1.5k", 1500)] {
            assert_eq!(Config::parse_max_tokens(input).unwrap(), expected);
        }
        for input in ["invalid", ""] {
            assert!(Config::parse_max_tokens(input).is_err());
        }
    }

    #[test]
    fn test_temperature_parsing() {
        for (input, expected) in [("0.0", 0.0), ("0.7", 0.7), ("1.0", 1.0), ("2.0", 2.0)] {
            assert_eq!(Config::parse_temperature(input).unwrap(), expected);
        }
        for input in ["invalid", ""] {
            assert!(Config::parse_temperature(input).is_err());
        }
    }

    #[test]
    fn test_template_resolution() {
        let mut config = create_config_with_providers();
        config.templates.insert(
            "helpful".to_string(),
            "You are a helpful assistant.".to_string(),
        );

        // A known `t:` name expands to its template. Plain text, and a `t:`
        // name with no matching template, come back unchanged.
        let cases = [
            ("t:helpful", "You are a helpful assistant."),
            ("Regular prompt", "Regular prompt"),
            ("t:nonexistent", "t:nonexistent"),
        ];
        for (input, expected) in cases {
            assert_eq!(config.resolve_template_or_prompt(input), expected);
        }
    }
}
#[cfg(test)]
mod provider_config_url_tests {
    //! `ProviderConfig::get_chat_url`: placeholder substitution when
    //! `chat_path` is a full URL, and endpoint + path joining when it is not.
    use lc::config::ProviderConfig;
    use std::collections::HashMap;

    /// Builds a `ProviderConfig` with the given endpoint and paths and every
    /// other field empty (`None`, empty `Vec`/`HashMap`). Each test then sets
    /// only the fields it cares about.
    fn bare_provider(endpoint: &str, models_path: &str, chat_path: &str) -> ProviderConfig {
        ProviderConfig {
            endpoint: endpoint.to_string(),
            api_key: None,
            models: vec![],
            models_path: models_path.to_string(),
            chat_path: chat_path.to_string(),
            images_path: None,
            embeddings_path: None,
            headers: HashMap::new(),
            token_url: None,
            cached_token: None,
            auth_type: None,
            vars: HashMap::new(),
            chat_templates: None,
            images_templates: None,
            embeddings_templates: None,
            models_templates: None,
            audio_path: None,
            speech_path: None,
            audio_templates: None,
            speech_templates: None,
        }
    }

    #[test]
    fn test_get_chat_url_full_url_with_model_and_vars() {
        let mut pc = bare_provider(
            "https://aiplatform.googleapis.com",
            "/v1beta1/{project}/locations/{location}/models",
            "https://aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent",
        );
        pc.token_url = Some("https://oauth2.googleapis.com/token".to_string());
        pc.auth_type = Some("google_sa_jwt".to_string());
        pc.vars.extend([
            ("project".to_string(), "my-proj".to_string()),
            ("location".to_string(), "us-central1".to_string()),
        ]);

        // `{model}` placeholder.
        assert_eq!(
            pc.get_chat_url("gemini-1.5-pro"),
            "https://aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
        );

        // `{model_name}` placeholder, on the same provider and vars.
        pc.chat_path = "https://aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/models/{model_name}:generateContent".to_string();
        assert_eq!(
            pc.get_chat_url("gemini-1.5-flash"),
            "https://aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/models/gemini-1.5-flash:generateContent"
        );
    }

    #[test]
    fn test_get_chat_url_non_full_path_concatenation() {
        let pc = bare_provider(
            "https://api.openai.com",
            "/v1/models",
            "/v1/chat/completions",
        );
        assert_eq!(
            pc.get_chat_url("gpt-4o"),
            "https://api.openai.com/v1/chat/completions"
        );
    }
}