#[cfg(test)]
use super::*;
use crate::core::providers::unified_provider::ProviderError;
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
use crate::core::types::model::ProviderCapability;
fn create_test_config() -> WatsonxConfig {
WatsonxConfig {
api_key: Some("test-api-key".to_string()),
api_base: Some("https://us-south.ml.cloud.ibm.com".to_string()),
project_id: Some("test-project-id".to_string()),
..Default::default()
}
}
#[tokio::test]
async fn test_provider_creation() {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await;
assert!(provider.is_ok());
}
#[tokio::test]
async fn test_provider_with_credentials() {
let provider = WatsonxProvider::with_credentials(
"test-key",
"test-project",
Some("https://us-south.ml.cloud.ibm.com".to_string()),
)
.await;
assert!(provider.is_ok());
}
#[test]
fn test_provider_name() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
assert_eq!(provider.name(), "watsonx");
});
}
#[test]
fn test_provider_capabilities() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let capabilities = provider.capabilities();
assert!(capabilities.contains(&ProviderCapability::ChatCompletion));
assert!(capabilities.contains(&ProviderCapability::ChatCompletionStream));
assert!(capabilities.contains(&ProviderCapability::ToolCalling));
});
}
#[test]
fn test_model_info() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let models = provider.models();
assert!(!models.is_empty());
let granite_models: Vec<_> = models.iter().filter(|m| m.id.contains("granite")).collect();
assert!(!granite_models.is_empty());
let llama_models: Vec<_> = models.iter().filter(|m| m.id.contains("llama")).collect();
assert!(!llama_models.is_empty());
});
}
#[test]
fn test_model_pricing() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let models = provider.models();
for model in models {
assert!(model.input_cost_per_1k_tokens.is_some());
assert!(model.output_cost_per_1k_tokens.is_some());
assert_eq!(model.currency, "USD");
}
});
}
#[test]
fn test_supported_openai_params() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let params = provider.get_supported_openai_params("ibm/granite-13b-chat-v2");
assert!(params.contains(&"temperature"));
assert!(params.contains(&"max_tokens"));
assert!(params.contains(&"tools"));
assert!(params.contains(&"response_format"));
let params_no_tools = provider.get_supported_openai_params("ibm/granite-3b-code-instruct");
assert!(params_no_tools.contains(&"temperature"));
assert!(params_no_tools.contains(&"max_tokens"));
assert!(!params_no_tools.contains(&"tools"));
});
}
#[tokio::test]
async fn test_cost_calculation() {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let cost = provider
.calculate_cost("ibm/granite-13b-chat-v2", 1000, 1000)
.await;
assert!(cost.is_ok());
let total_cost = cost.unwrap();
assert!(total_cost > 0.0);
let unknown_cost = provider.calculate_cost("unknown-model", 1000, 1000).await;
assert!(unknown_cost.is_err());
}
#[test]
fn test_error_mapping() {
use crate::core::traits::error_mapper::trait_def::ErrorMapper;
let mapper = error::WatsonxErrorMapper;
let auth_error = mapper.map_http_error(401, "Unauthorized");
assert!(matches!(auth_error, ProviderError::Authentication { .. }));
let rate_error = mapper.map_http_error(429, "Too many requests");
assert!(matches!(rate_error, ProviderError::RateLimit { .. }));
let not_found = mapper.map_http_error(404, "Not found");
assert!(matches!(not_found, ProviderError::ModelNotFound { .. }));
}
#[test]
fn test_error_retryability() {
let rate_error = ProviderError::rate_limit("watsonx", Some(60));
assert!(rate_error.is_retryable());
assert!(rate_error.retry_delay().is_some());
let rate_error_no_delay = ProviderError::rate_limit("watsonx", None);
assert!(rate_error_no_delay.is_retryable());
let service_error = ProviderError::provider_unavailable("watsonx", "Service down");
assert!(service_error.is_retryable());
assert!(service_error.retry_delay().is_some());
let network_error = ProviderError::network("watsonx", "Connection failed");
assert!(network_error.is_retryable());
assert!(network_error.retry_delay().is_some());
let auth_error = ProviderError::authentication("watsonx", "Bad key");
assert!(!auth_error.is_retryable());
assert!(auth_error.retry_delay().is_none());
}
#[test]
fn test_model_capabilities() {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let models = provider.models();
for model in models.iter().filter(|m| m.supports_tools) {
assert!(
model
.capabilities
.contains(&ProviderCapability::ToolCalling),
"Model {} should have ToolCalling capability",
model.id
);
}
let vision_models: Vec<_> = models.iter().filter(|m| m.id.contains("vision")).collect();
for model in vision_models {
assert!(
model.supports_multimodal,
"Vision model {} should support multimodal",
model.id
);
}
});
}
#[test]
fn test_config_build_url() {
let config = WatsonxConfig {
api_base: Some("https://us-south.ml.cloud.ibm.com".to_string()),
api_version: "2024-05-31".to_string(),
..Default::default()
};
let url = config.build_url("/ml/v1/text/chat", false).unwrap();
assert!(url.contains("https://us-south.ml.cloud.ibm.com"));
assert!(url.contains("/ml/v1/text/chat"));
assert!(url.contains("version=2024-05-31"));
let stream_url = config.build_url("/ml/v1/text/chat_stream", true).unwrap();
assert!(stream_url.contains("/ml/v1/text/chat_stream"));
assert!(stream_url.contains("version=2024-05-31"));
}
#[test]
fn test_model_info_accessors() {
let model = model_info::get_model_info("ibm/granite-13b-chat-v2");
assert!(model.is_some());
let models = model_info::get_available_models();
assert!(!models.is_empty());
let ibm_models = model_info::get_models_by_provider("ibm");
assert!(!ibm_models.is_empty());
let meta_models = model_info::get_models_by_provider("meta");
assert!(!meta_models.is_empty());
assert!(model_info::supports_chat("ibm/granite-13b-chat-v2"));
assert!(model_info::supports_tools("ibm/granite-13b-chat-v2"));
assert!(!model_info::supports_tools("ibm/granite-3b-code-instruct"));
}
#[tokio::test]
async fn test_map_openai_params() {
let config = create_test_config();
let provider = WatsonxProvider::new(config).await.unwrap();
let mut params = std::collections::HashMap::new();
params.insert("max_tokens".to_string(), serde_json::json!(100));
params.insert("temperature".to_string(), serde_json::json!(0.7));
params.insert("frequency_penalty".to_string(), serde_json::json!(0.5));
params.insert("stop".to_string(), serde_json::json!(["END"]));
params.insert("seed".to_string(), serde_json::json!(42));
let mapped = provider
.map_openai_params(params, "ibm/granite-13b-chat-v2")
.await
.unwrap();
assert!(mapped.contains_key("max_new_tokens"));
assert!(mapped.contains_key("temperature"));
assert!(mapped.contains_key("repetition_penalty"));
assert!(mapped.contains_key("stop_sequences"));
assert!(mapped.contains_key("random_seed"));
}
#[test]
fn test_config_validation() {
use crate::core::traits::provider::ProviderConfig;
let valid_config = WatsonxConfig {
api_key: Some("test-key".to_string()),
project_id: Some("test-project".to_string()),
..Default::default()
};
assert!(valid_config.validate().is_ok());
let token_config = WatsonxConfig {
token: Some("test-token".to_string()),
project_id: Some("test-project".to_string()),
..Default::default()
};
assert!(token_config.validate().is_ok());
let zen_config = WatsonxConfig {
zen_api_key: Some("zen-key".to_string()),
project_id: Some("test-project".to_string()),
..Default::default()
};
assert!(zen_config.validate().is_ok());
let space_config = WatsonxConfig {
api_key: Some("test-key".to_string()),
space_id: Some("test-space".to_string()),
..Default::default()
};
assert!(space_config.validate().is_ok());
let invalid_timeout = WatsonxConfig {
api_key: Some("test-key".to_string()),
project_id: Some("test-project".to_string()),
timeout: 0,
..Default::default()
};
assert!(invalid_timeout.validate().is_err());
}