use super::deployment::DeploymentId;
use std::collections::HashMap;
use std::sync::RwLock;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FallbackType {
General,
ContextWindow,
ContentPolicy,
RateLimit,
}
#[derive(Debug, Clone)]
pub struct ExecutionResult<T> {
pub result: T,
pub deployment_id: DeploymentId,
pub attempts: u32,
pub model_used: String,
pub used_fallback: bool,
pub latency_us: u64,
}
#[derive(Debug, Default)]
pub struct FallbackConfig {
general: RwLock<HashMap<String, Vec<String>>>,
context_window: RwLock<HashMap<String, Vec<String>>>,
content_policy: RwLock<HashMap<String, Vec<String>>>,
rate_limit: RwLock<HashMap<String, Vec<String>>>,
}
impl FallbackConfig {
pub fn new() -> Self {
Self {
general: RwLock::new(HashMap::new()),
context_window: RwLock::new(HashMap::new()),
content_policy: RwLock::new(HashMap::new()),
rate_limit: RwLock::new(HashMap::new()),
}
}
pub fn add_general(self, model: &str, fallbacks: Vec<String>) -> Self {
self.general
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(model.to_string(), fallbacks);
self
}
pub fn add_context_window(self, model: &str, fallbacks: Vec<String>) -> Self {
self.context_window
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(model.to_string(), fallbacks);
self
}
pub fn add_content_policy(self, model: &str, fallbacks: Vec<String>) -> Self {
self.content_policy
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(model.to_string(), fallbacks);
self
}
pub fn add_rate_limit(self, model: &str, fallbacks: Vec<String>) -> Self {
self.rate_limit
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(model.to_string(), fallbacks);
self
}
#[allow(clippy::type_complexity)]
pub fn validate(&self) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
let maps: [(&str, &RwLock<HashMap<String, Vec<String>>>); 4] = [
("general", &self.general),
("context_window", &self.context_window),
("content_policy", &self.content_policy),
("rate_limit", &self.rate_limit),
];
for (label, lock) in &maps {
let map = lock.read().unwrap_or_else(|e| e.into_inner());
for start in map.keys() {
let mut visited = std::collections::HashSet::new();
visited.insert(start.clone());
let mut stack = vec![start.clone()];
while let Some(node) = stack.pop() {
if let Some(targets) = map.get(&node) {
for target in targets {
if !visited.insert(target.clone()) {
errors.push(format!(
"{label}: cycle involving '{start}' -> '{target}'"
));
} else {
stack.push(target.clone());
}
}
}
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
pub fn get_fallbacks_for_type(
&self,
model_name: &str,
fallback_type: FallbackType,
) -> Vec<String> {
let lock = match fallback_type {
FallbackType::General => &self.general,
FallbackType::ContextWindow => &self.context_window,
FallbackType::ContentPolicy => &self.content_policy,
FallbackType::RateLimit => &self.rate_limit,
};
lock.read()
.unwrap_or_else(|e| e.into_inner())
.get(model_name)
.cloned()
.unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fallback_type_debug() {
assert!(format!("{:?}", FallbackType::General).contains("General"));
assert!(format!("{:?}", FallbackType::ContextWindow).contains("ContextWindow"));
assert!(format!("{:?}", FallbackType::ContentPolicy).contains("ContentPolicy"));
assert!(format!("{:?}", FallbackType::RateLimit).contains("RateLimit"));
}
#[test]
fn test_fallback_type_clone() {
let t = FallbackType::ContextWindow;
let cloned = t;
assert_eq!(cloned, FallbackType::ContextWindow);
}
#[test]
fn test_fallback_type_copy() {
let t = FallbackType::RateLimit;
let copied = t;
assert_eq!(t, copied);
}
#[test]
fn test_fallback_type_eq() {
assert_eq!(FallbackType::General, FallbackType::General);
assert_ne!(FallbackType::General, FallbackType::ContextWindow);
assert_ne!(FallbackType::ContextWindow, FallbackType::ContentPolicy);
assert_ne!(FallbackType::ContentPolicy, FallbackType::RateLimit);
}
#[test]
fn test_execution_result_creation() {
let result = ExecutionResult {
result: "success".to_string(),
deployment_id: "openai-gpt-4".to_string(),
attempts: 1,
model_used: "gpt-4".to_string(),
used_fallback: false,
latency_us: 1500,
};
assert_eq!(result.result, "success");
assert_eq!(result.attempts, 1);
assert_eq!(result.model_used, "gpt-4");
assert!(!result.used_fallback);
assert_eq!(result.latency_us, 1500);
}
#[test]
fn test_execution_result_with_fallback() {
let result = ExecutionResult {
result: 42,
deployment_id: "anthropic-claude-3".to_string(),
attempts: 3,
model_used: "claude-3".to_string(),
used_fallback: true,
latency_us: 5000,
};
assert_eq!(result.result, 42);
assert_eq!(result.attempts, 3);
assert!(result.used_fallback);
}
#[test]
fn test_execution_result_debug() {
let result = ExecutionResult {
result: "test",
deployment_id: "openai-gpt-4".to_string(),
attempts: 1,
model_used: "gpt-4".to_string(),
used_fallback: false,
latency_us: 100,
};
let debug = format!("{:?}", result);
assert!(debug.contains("ExecutionResult"));
assert!(debug.contains("attempts"));
}
#[test]
fn test_execution_result_clone() {
let result = ExecutionResult {
result: vec![1, 2, 3],
deployment_id: "openai-gpt-4".to_string(),
attempts: 2,
model_used: "gpt-4".to_string(),
used_fallback: true,
latency_us: 2000,
};
let cloned = result.clone();
assert_eq!(cloned.result, vec![1, 2, 3]);
assert_eq!(cloned.attempts, 2);
assert!(cloned.used_fallback);
}
#[test]
fn test_fallback_config_new() {
let config = FallbackConfig::new();
assert!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::General)
.is_empty()
);
}
#[test]
fn test_fallback_config_default() {
let config = FallbackConfig::default();
assert!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::General)
.is_empty()
);
}
#[test]
fn test_fallback_config_debug() {
let config = FallbackConfig::new();
let debug = format!("{:?}", config);
assert!(debug.contains("FallbackConfig"));
}
#[test]
fn test_add_general_fallback() {
let config = FallbackConfig::new().add_general(
"gpt-4",
vec!["gpt-3.5-turbo".to_string(), "claude-3".to_string()],
);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
assert_eq!(fallbacks.len(), 2);
assert!(fallbacks.contains(&"gpt-3.5-turbo".to_string()));
assert!(fallbacks.contains(&"claude-3".to_string()));
}
#[test]
fn test_add_context_window_fallback() {
let config =
FallbackConfig::new().add_context_window("gpt-4", vec!["gpt-4-32k".to_string()]);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
assert_eq!(fallbacks.len(), 1);
assert_eq!(fallbacks[0], "gpt-4-32k");
}
#[test]
fn test_add_content_policy_fallback() {
let config =
FallbackConfig::new().add_content_policy("gpt-4", vec!["claude-3".to_string()]);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContentPolicy);
assert_eq!(fallbacks.len(), 1);
assert_eq!(fallbacks[0], "claude-3");
}
#[test]
fn test_add_rate_limit_fallback() {
let config = FallbackConfig::new().add_rate_limit(
"gpt-4",
vec!["gpt-3.5-turbo".to_string(), "gpt-4-turbo".to_string()],
);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::RateLimit);
assert_eq!(fallbacks.len(), 2);
}
#[test]
fn test_builder_pattern_chaining() {
let config = FallbackConfig::new()
.add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
.add_context_window("gpt-4", vec!["gpt-4-32k".to_string()])
.add_content_policy("gpt-4", vec!["claude-3".to_string()])
.add_rate_limit("gpt-4", vec!["gemini".to_string()]);
assert_eq!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::General)
.len(),
1
);
assert_eq!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow)
.len(),
1
);
assert_eq!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::ContentPolicy)
.len(),
1
);
assert_eq!(
config
.get_fallbacks_for_type("gpt-4", FallbackType::RateLimit)
.len(),
1
);
}
#[test]
fn test_multiple_models() {
let config = FallbackConfig::new()
.add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
.add_general("claude-3", vec!["gemini".to_string()]);
let gpt4_fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
let claude_fallbacks = config.get_fallbacks_for_type("claude-3", FallbackType::General);
assert_eq!(gpt4_fallbacks, vec!["gpt-3.5-turbo".to_string()]);
assert_eq!(claude_fallbacks, vec!["gemini".to_string()]);
}
#[test]
fn test_get_fallbacks_nonexistent_model() {
let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);
let fallbacks = config.get_fallbacks_for_type("nonexistent", FallbackType::General);
assert!(fallbacks.is_empty());
}
#[test]
fn test_get_fallbacks_wrong_type() {
let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
assert!(fallbacks.is_empty());
}
#[test]
fn test_empty_fallback_list() {
let config = FallbackConfig::new().add_general("gpt-4", vec![]);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
assert!(fallbacks.is_empty());
}
#[test]
fn test_override_fallback() {
let config = FallbackConfig::new()
.add_general("gpt-4", vec!["first".to_string()])
.add_general("gpt-4", vec!["second".to_string()]);
let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
assert_eq!(fallbacks, vec!["second".to_string()]);
}
#[test]
fn test_concurrent_reads() {
use std::sync::Arc;
use std::thread;
let config =
Arc::new(FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5".to_string()]));
let mut handles = vec![];
for _ in 0..10 {
let c = config.clone();
let handle = thread::spawn(move || {
for _ in 0..100 {
let _ = c.get_fallbacks_for_type("gpt-4", FallbackType::General);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_special_characters_in_model_name() {
let config =
FallbackConfig::new().add_general("model/v2.0:latest", vec!["backup".to_string()]);
let fallbacks = config.get_fallbacks_for_type("model/v2.0:latest", FallbackType::General);
assert_eq!(fallbacks.len(), 1);
}
#[test]
fn test_unicode_in_model_name() {
let config = FallbackConfig::new().add_general("模型", vec!["备份".to_string()]);
let fallbacks = config.get_fallbacks_for_type("模型", FallbackType::General);
assert_eq!(fallbacks, vec!["备份".to_string()]);
}
#[test]
fn test_empty_model_name() {
let config = FallbackConfig::new().add_general("", vec!["fallback".to_string()]);
let fallbacks = config.get_fallbacks_for_type("", FallbackType::General);
assert_eq!(fallbacks, vec!["fallback".to_string()]);
}
#[test]
fn test_many_fallbacks() {
let fallbacks: Vec<String> = (0..100).map(|i| format!("model_{}", i)).collect();
let config = FallbackConfig::new().add_general("primary", fallbacks.clone());
let result = config.get_fallbacks_for_type("primary", FallbackType::General);
assert_eq!(result.len(), 100);
assert_eq!(result[0], "model_0");
assert_eq!(result[99], "model_99");
}
#[test]
fn test_fallback_type_all_variants() {
let config = FallbackConfig::new()
.add_general("model", vec!["g".to_string()])
.add_context_window("model", vec!["cw".to_string()])
.add_content_policy("model", vec!["cp".to_string()])
.add_rate_limit("model", vec!["rl".to_string()]);
assert_eq!(
config.get_fallbacks_for_type("model", FallbackType::General),
vec!["g"]
);
assert_eq!(
config.get_fallbacks_for_type("model", FallbackType::ContextWindow),
vec!["cw"]
);
assert_eq!(
config.get_fallbacks_for_type("model", FallbackType::ContentPolicy),
vec!["cp"]
);
assert_eq!(
config.get_fallbacks_for_type("model", FallbackType::RateLimit),
vec!["rl"]
);
}
}