use crate::core::traits::provider::ProviderConfig;
use serde::{Deserialize, Serialize};
/// Configuration for a Triton server provider.
///
/// Build it with [`TritonConfig::new`] / [`TritonConfig::with_model`] plus the
/// chained setters, or deserialize it with serde. When deserializing,
/// `timeout_ms`, `max_retries`, `debug` and `headers` fall back to their
/// defaults if the key is omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TritonConfig {
/// Server URL. When `None`, [`TritonConfig::get_server_url`] falls back to the
/// `TRITON_SERVER_URL` environment variable and then to `http://localhost:8000`.
pub server_url: Option<String>,
/// Model name. When `None`, [`TritonConfig::get_model_name`] falls back to the
/// `TRITON_MODEL_NAME` environment variable.
pub model_name: Option<String>,
/// Model version. When `None`, [`TritonConfig::get_model_version`] falls back to
/// the `TRITON_MODEL_VERSION` environment variable.
pub model_version: Option<String>,
/// Timeout in milliseconds; defaults to `default_timeout_ms()` (30000) and is
/// exposed as a `Duration` through `ProviderConfig::timeout`.
#[serde(default = "default_timeout_ms")]
pub timeout_ms: u64,
/// Retry budget; defaults to `default_max_retries()` (3).
#[serde(default = "default_max_retries")]
pub max_retries: u32,
/// Debug flag, `false` by default. Nothing in this file reads it —
/// NOTE(review): presumably toggles verbose logging in the provider; confirm.
#[serde(default)]
pub debug: bool,
/// Extra key/value pairs, empty by default and filled via [`TritonConfig::header`].
/// NOTE(review): presumably sent as HTTP headers on each request; confirm
/// against the code that consumes this config.
#[serde(default)]
pub headers: std::collections::HashMap<String, String>,
}
impl Default for TritonConfig {
    /// Returns a configuration with no server URL, model name or model version,
    /// the default timeout and retry budget, debugging off and no extra headers.
    fn default() -> Self {
        // Same helpers serde uses for omitted fields, so both construction paths agree.
        let timeout_ms = default_timeout_ms();
        let max_retries = default_max_retries();

        Self {
            server_url: None,
            model_name: None,
            model_version: None,
            timeout_ms,
            max_retries,
            debug: Default::default(),
            headers: Default::default(),
        }
    }
}
impl ProviderConfig for TritonConfig {
    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns an error message when:
    /// - neither `server_url` nor the `TRITON_SERVER_URL` environment variable
    ///   provides a URL,
    /// - the URL that would be used is empty or whitespace-only,
    /// - `timeout_ms` is `0`.
    fn validate(&self) -> Result<(), String> {
        // Resolve the URL with the same precedence as `get_server_url` (explicit
        // field first, then the environment), but without the localhost fallback,
        // so a blank value is rejected instead of slipping through.
        let resolved_url = self
            .server_url
            .clone()
            .or_else(|| std::env::var("TRITON_SERVER_URL").ok());

        match resolved_url {
            None => {
                return Err(
                    "Triton server URL not provided and TRITON_SERVER_URL environment variable not set"
                        .to_string(),
                );
            }
            Some(url) if url.trim().is_empty() => {
                return Err("Triton server URL must not be empty".to_string());
            }
            Some(_) => {}
        }

        if self.timeout_ms == 0 {
            return Err("Timeout must be greater than 0".to_string());
        }

        Ok(())
    }

    /// Always `None`: this configuration has no API key field.
    fn api_key(&self) -> Option<&str> {
        None
    }

    /// Returns the explicitly configured server URL, if any (no environment fallback).
    fn api_base(&self) -> Option<&str> {
        self.server_url.as_deref()
    }

    /// Returns `timeout_ms` as a [`std::time::Duration`].
    fn timeout(&self) -> std::time::Duration {
        std::time::Duration::from_millis(self.timeout_ms)
    }

    /// Returns the configured retry budget.
    fn max_retries(&self) -> u32 {
        self.max_retries
    }
}
impl TritonConfig {
    /// Creates a configuration for `server_url`, leaving every other field at its default.
    pub fn new(server_url: impl Into<String>) -> Self {
        let server_url = Some(server_url.into());
        Self {
            server_url,
            ..Self::default()
        }
    }

    /// Creates a configuration for `server_url` with `model_name` already set.
    pub fn with_model(server_url: impl Into<String>, model_name: impl Into<String>) -> Self {
        Self::new(server_url).model_name(model_name)
    }

    /// Resolves the server URL: the explicit field, else the `TRITON_SERVER_URL`
    /// environment variable, else `http://localhost:8000`.
    pub fn get_server_url(&self) -> String {
        if let Some(url) = &self.server_url {
            return url.clone();
        }
        match std::env::var("TRITON_SERVER_URL") {
            Ok(url) => url,
            Err(_) => "http://localhost:8000".to_string(),
        }
    }

    /// Resolves the model name: the explicit field, else the `TRITON_MODEL_NAME`
    /// environment variable, else `None`.
    pub fn get_model_name(&self) -> Option<String> {
        match &self.model_name {
            Some(name) => Some(name.clone()),
            None => std::env::var("TRITON_MODEL_NAME").ok(),
        }
    }

    /// Resolves the model version: the explicit field, else the
    /// `TRITON_MODEL_VERSION` environment variable, else `None`.
    pub fn get_model_version(&self) -> Option<String> {
        match &self.model_version {
            Some(version) => Some(version.clone()),
            None => std::env::var("TRITON_MODEL_VERSION").ok(),
        }
    }

    /// Builder-style setter for the model name; consumes and returns the config.
    pub fn model_name(self, name: impl Into<String>) -> Self {
        Self {
            model_name: Some(name.into()),
            ..self
        }
    }

    /// Builder-style setter for the model version; consumes and returns the config.
    pub fn model_version(self, version: impl Into<String>) -> Self {
        Self {
            model_version: Some(version.into()),
            ..self
        }
    }

    /// Builder-style setter for the timeout in milliseconds.
    pub fn timeout_ms(self, timeout: u64) -> Self {
        Self {
            timeout_ms: timeout,
            ..self
        }
    }

    /// Builder-style setter that adds one header entry, replacing any existing
    /// value stored under the same key.
    pub fn header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        headers.insert(key.into(), value.into());
        Self { headers, ..self }
    }
}
/// Default request timeout in milliseconds (30 seconds); used by serde for an
/// omitted `timeout_ms` and by `TritonConfig::default`.
fn default_timeout_ms() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;
    DEFAULT_TIMEOUT_MS
}
/// Default retry budget; used by serde for an omitted `max_retries` and by
/// `TritonConfig::default`.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TritonConfig`: defaults, constructors, builder setters,
    //! validation, the `ProviderConfig` accessors and serde round-tripping.

    use super::*;

    #[test]
    fn test_triton_config_default() {
        let config = TritonConfig::default();
        assert!(config.server_url.is_none());
        assert!(config.model_name.is_none());
        assert!(config.model_version.is_none());
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.max_retries, 3);
        assert!(!config.debug);
    }

    #[test]
    fn test_triton_config_new() {
        let config = TritonConfig::new("http://localhost:8000");
        assert_eq!(config.server_url, Some("http://localhost:8000".to_string()));
    }

    #[test]
    fn test_triton_config_with_model() {
        let config = TritonConfig::with_model("http://localhost:8000", "my-model");
        assert_eq!(config.server_url, Some("http://localhost:8000".to_string()));
        assert_eq!(config.model_name, Some("my-model".to_string()));
    }

    #[test]
    fn test_triton_config_builder() {
        let config = TritonConfig::new("http://localhost:8000")
            .model_name("llama-7b")
            .model_version("1")
            .timeout_ms(60000)
            .header("Authorization", "Bearer token");
        assert_eq!(config.server_url, Some("http://localhost:8000".to_string()));
        assert_eq!(config.model_name, Some("llama-7b".to_string()));
        assert_eq!(config.model_version, Some("1".to_string()));
        assert_eq!(config.timeout_ms, 60000);
        assert_eq!(
            config.headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
    }

    #[test]
    fn test_triton_config_get_server_url_default() {
        // Clear the variable first: the localhost fallback is only reached when
        // TRITON_SERVER_URL is unset, so an ambient value would fail this test.
        //
        // SAFETY: `remove_var` requires that no other thread accesses the
        // environment except through `std::env`. The tests in this module only
        // use `std::env` — NOTE(review): confirm nothing else in the test binary
        // reads the environment through libc concurrently.
        unsafe { std::env::remove_var("TRITON_SERVER_URL") };
        let config = TritonConfig::default();
        assert_eq!(config.get_server_url(), "http://localhost:8000");
    }

    #[test]
    fn test_triton_config_get_server_url_custom() {
        let config = TritonConfig::new("http://triton.example.com:8000");
        assert_eq!(config.get_server_url(), "http://triton.example.com:8000");
    }

    #[test]
    fn test_triton_config_validation_no_url() {
        // SAFETY: `remove_var` requires that no other thread accesses the
        // environment except through `std::env`. The tests in this module only
        // use `std::env` — NOTE(review): confirm nothing else in the test binary
        // reads the environment through libc concurrently.
        unsafe { std::env::remove_var("TRITON_SERVER_URL") };
        let config = TritonConfig::default();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_triton_config_validation_with_url() {
        let config = TritonConfig::new("http://localhost:8000");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_triton_config_validation_zero_timeout() {
        let config = TritonConfig {
            server_url: Some("http://localhost:8000".to_string()),
            timeout_ms: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_triton_config_provider_config_trait() {
        let config = TritonConfig {
            server_url: Some("http://localhost:8000".to_string()),
            timeout_ms: 60000,
            max_retries: 5,
            ..Default::default()
        };
        assert_eq!(config.api_key(), None);
        assert_eq!(config.api_base(), Some("http://localhost:8000"));
        assert_eq!(config.timeout(), std::time::Duration::from_millis(60000));
        assert_eq!(config.max_retries(), 5);
    }

    #[test]
    fn test_triton_config_serialization() {
        let config = TritonConfig {
            server_url: Some("http://localhost:8000".to_string()),
            model_name: Some("llama-7b".to_string()),
            model_version: Some("1".to_string()),
            timeout_ms: 45000,
            max_retries: 2,
            debug: true,
            headers: std::collections::HashMap::new(),
        };
        let json = serde_json::to_value(&config).unwrap();
        assert_eq!(json["server_url"], "http://localhost:8000");
        assert_eq!(json["model_name"], "llama-7b");
        assert_eq!(json["model_version"], "1");
        assert_eq!(json["timeout_ms"], 45000);
        assert_eq!(json["max_retries"], 2);
        assert_eq!(json["debug"], true);
    }

    #[test]
    fn test_triton_config_deserialization() {
        let json = r#"{
"server_url": "http://triton.local:8000",
"model_name": "gpt-j",
"timeout_ms": 60000,
"debug": true
}"#;
        let config: TritonConfig = serde_json::from_str(json).unwrap();
        assert_eq!(
            config.server_url,
            Some("http://triton.local:8000".to_string())
        );
        assert_eq!(config.model_name, Some("gpt-j".to_string()));
        assert_eq!(config.timeout_ms, 60000);
        assert!(config.debug);
        // Keys omitted from the JSON must come back as their serde defaults.
        assert!(config.model_version.is_none());
        assert_eq!(config.max_retries, 3);
        assert!(config.headers.is_empty());
    }
}