use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::typed_id::{ModelId, ModelRouterId};
#[cfg(feature = "openapi")]
use utoipa::ToSchema;
/// Status of a [`ModelRouter`] (its `status` field).
///
/// Serialized in lowercase: `"active"`, `"archived"`, `"deleted"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ModelRouterStatus {
    Active,
    Archived,
    Deleted,
}
impl std::fmt::Display for ModelRouterStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelRouterStatus::Active => write!(f, "active"),
ModelRouterStatus::Archived => write!(f, "archived"),
ModelRouterStatus::Deleted => write!(f, "deleted"),
}
}
}
impl From<&str> for ModelRouterStatus {
fn from(s: &str) -> Self {
match s {
"archived" => ModelRouterStatus::Archived,
"deleted" => ModelRouterStatus::Deleted,
_ => ModelRouterStatus::Active,
}
}
}
/// Candidate-selection strategy of a [`ModelRouterRoute`] (its `strategy` field).
///
/// Serialized in snake_case: `"single"`, `"ordered_fallback"`, `"weighted"`,
/// `"rules"`, `"custom"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum ModelRouterStrategy {
    /// The route must have exactly one candidate (enforced by `validate_route_shape`).
    Single,
    OrderedFallback,
    Weighted,
    /// Every candidate must carry a `rules` document (enforced by `validate_candidate_shape`).
    Rules,
    Custom,
}
impl std::fmt::Display for ModelRouterStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelRouterStrategy::Single => write!(f, "single"),
ModelRouterStrategy::OrderedFallback => write!(f, "ordered_fallback"),
ModelRouterStrategy::Weighted => write!(f, "weighted"),
ModelRouterStrategy::Rules => write!(f, "rules"),
ModelRouterStrategy::Custom => write!(f, "custom"),
}
}
}
impl ModelRouterStrategy {
    /// Parses a strategy from its snake_case name. These are the same strings
    /// produced by the `Display` impl.
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the accepted
    /// values when `s` is not one of `single`, `ordered_fallback`,
    /// `weighted`, `rules`, or `custom`.
    pub fn parse(s: &str) -> Result<Self, String> {
        match s {
            "single" => Ok(ModelRouterStrategy::Single),
            "ordered_fallback" => Ok(ModelRouterStrategy::OrderedFallback),
            "weighted" => Ok(ModelRouterStrategy::Weighted),
            "rules" => Ok(ModelRouterStrategy::Rules),
            "custom" => Ok(ModelRouterStrategy::Custom),
            other => Err(format!(
                "unknown model router strategy '{other}'; expected one of single, ordered_fallback, weighted, rules, custom"
            )),
        }
    }
}

/// Enables `"weighted".parse::<ModelRouterStrategy>()` and use in generic
/// `FromStr` contexts. Delegates to [`ModelRouterStrategy::parse`], so
/// accepted inputs and error messages are identical.
impl std::str::FromStr for ModelRouterStrategy {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
/// A model router: a named collection of [`ModelRouterRoute`]s together with
/// its status and timestamps.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ModelRouter {
    /// Public, prefixed identifier (serialized as `"id"`).
    #[serde(rename = "id")]
    #[cfg_attr(
        feature = "openapi",
        schema(value_type = String, example = "mrtr_01933b5a000070008000000000000001")
    )]
    pub public_id: ModelRouterId,
    /// Internal UUID. Skipped by serde in both directions; after
    /// deserialization it is always `Uuid::nil()` and must be filled in
    /// separately if needed.
    #[serde(skip, default = "Uuid::nil")]
    pub internal_id: Uuid,
    /// Router name.
    pub name: String,
    /// Optional description; `None` when absent from input, omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON document; defaults to an empty object `{}` when absent from input.
    // NOTE(review): presumably a JSON Schema describing router parameters — confirm.
    #[serde(default = "default_empty_object")]
    pub param_schema: serde_json::Value,
    /// Current status of the router.
    pub status: ModelRouterStatus,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Archive timestamp; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub archived_at: Option<DateTime<Utc>>,
    /// Deletion timestamp; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deleted_at: Option<DateTime<Utc>>,
    /// Routes of this router; empty when absent from input, omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub routes: Vec<ModelRouterRoute>,
}
/// One keyed route of a [`ModelRouter`], with its candidate-selection
/// strategy and candidate models.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ModelRouterRoute {
    pub id: Uuid,
    /// Route key: 1 to 64 lowercase ASCII letters, digits, or hyphens, not
    /// starting or ending with a hyphen.
    // Checked by `validate_route_key` (called from `validate_route_shape`), not by serde.
    pub key: String,
    /// Free-text purpose of the route.
    pub purpose: String,
    /// Free-text guidance on when this route applies.
    pub when_to_use: String,
    /// How candidates are selected; a `single` route must have exactly one candidate.
    pub strategy: ModelRouterStrategy,
    /// Ordering position of the route.
    // NOTE(review): uniqueness/contiguity of positions is not checked in this module.
    pub position: i32,
    /// Candidate models; empty when absent from input, omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub candidates: Vec<ModelRouterCandidate>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}
/// A candidate model within a [`ModelRouterRoute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ModelRouterCandidate {
    pub id: Uuid,
    /// Identifier of the model this candidate refers to.
    #[cfg_attr(
        feature = "openapi",
        schema(value_type = String, example = "model_01933b5a000070008000000000000001")
    )]
    pub model_id: ModelId,
    /// JSON document of request overrides; defaults to an empty object `{}`
    /// when absent from input.
    // NOTE(review): presumably merged into the outgoing model request — confirm.
    #[serde(default = "default_empty_object")]
    pub request_overrides: serde_json::Value,
    /// Candidate weight; defaults to `1` when absent from input and must be
    /// non-negative.
    // Non-negativity is checked by `validate_candidate_shape`, not by serde.
    #[serde(default = "default_weight")]
    pub weight: i32,
    /// Rules document; required when the route's strategy is `rules`.
    /// Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rules: Option<serde_json::Value>,
    /// Ordering position within the route.
    pub position: i32,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}
/// Serde default for `ModelRouterCandidate::weight` when the field is absent.
fn default_weight() -> i32 {
    const DEFAULT_CANDIDATE_WEIGHT: i32 = 1;
    DEFAULT_CANDIDATE_WEIGHT
}
/// Serde default for JSON-document fields (`ModelRouter::param_schema`,
/// `ModelRouterCandidate::request_overrides`): an empty JSON object `{}`.
fn default_empty_object() -> serde_json::Value {
    // `Value: From<Map<String, Value>>` yields `Value::Object`.
    serde_json::Value::from(serde_json::Map::new())
}
/// Maximum length of a route key. Route keys are restricted to ASCII, so
/// this is both a byte and a character limit.
pub const MAX_ROUTE_KEY_LEN: usize = 64;

/// Validates a route key.
///
/// A valid key is non-empty, contains only lowercase ASCII letters, ASCII
/// digits, and `-`, is at most [`MAX_ROUTE_KEY_LEN`] characters long, and
/// neither starts nor ends with `-`.
///
/// # Errors
///
/// Returns a human-readable message for the first rule the key violates.
/// Rules are checked in this order: emptiness, allowed characters, length,
/// leading/trailing hyphen.
pub fn validate_route_key(key: &str) -> Result<(), String> {
    if key.is_empty() {
        return Err("route key must not be empty".into());
    }
    // Check the alphabet before the length. Once the key is known to be pure
    // ASCII, `len()` (bytes) equals the character count. Otherwise a short
    // multi-byte key (e.g. 33 x 'é' = 66 bytes) would be misreported as
    // exceeding the character limit instead of containing invalid characters.
    if !key
        .bytes()
        .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-')
    {
        return Err("route key must contain only lowercase letters, digits, and hyphens".into());
    }
    if key.len() > MAX_ROUTE_KEY_LEN {
        return Err(format!(
            "route key must be at most {MAX_ROUTE_KEY_LEN} characters"
        ));
    }
    if key.starts_with('-') || key.ends_with('-') {
        return Err("route key must not start or end with a hyphen".into());
    }
    Ok(())
}
/// Checks one candidate against the constraints of the strategy of the
/// route it belongs to.
///
/// # Errors
///
/// - `candidate.weight` is negative (checked for every strategy).
/// - `strategy` is [`ModelRouterStrategy::Rules`] and `candidate.rules` is `None`.
pub fn validate_candidate_shape(
    candidate: &ModelRouterCandidate,
    strategy: ModelRouterStrategy,
) -> Result<(), String> {
    let weight = candidate.weight;
    if weight < 0 {
        return Err(format!("candidate.weight must be non-negative, got {weight}"));
    }

    // Exhaustive on purpose: adding a strategy forces a decision here.
    let needs_rules_doc = match strategy {
        ModelRouterStrategy::Rules => true,
        ModelRouterStrategy::Single
        | ModelRouterStrategy::OrderedFallback
        | ModelRouterStrategy::Weighted
        | ModelRouterStrategy::Custom => false,
    };
    if needs_rules_doc && candidate.rules.is_none() {
        return Err(
            "candidates under a 'rules' strategy must have a rules document set".to_string(),
        );
    }

    Ok(())
}
/// Validates a route's structure: its key, the candidate count required by a
/// `single` strategy, and each candidate via [`validate_candidate_shape`].
///
/// # Errors
///
/// Returns the first failure found, checking in this order: the key
/// ([`validate_route_key`]), then "a `single` route must have exactly one
/// candidate", then each candidate in order.
pub fn validate_route_shape(route: &ModelRouterRoute) -> Result<(), String> {
    validate_route_key(&route.key)?;

    let key = &route.key;
    let candidate_count = route.candidates.len();
    if route.strategy == ModelRouterStrategy::Single && candidate_count != 1 {
        return Err(format!(
            "route '{key}' has strategy 'single' but {candidate_count} candidates; single-strategy routes must have exactly one candidate"
        ));
    }

    route
        .candidates
        .iter()
        .try_for_each(|candidate| validate_candidate_shape(candidate, route.strategy))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed timestamp so fixtures are deterministic.
    fn now() -> DateTime<Utc> {
        DateTime::<Utc>::from_timestamp(1_700_000_000, 0).unwrap()
    }

    /// Candidate fixture with the given weight and rules; other fields are placeholders.
    fn candidate(weight: i32, rules: Option<serde_json::Value>) -> ModelRouterCandidate {
        ModelRouterCandidate {
            id: Uuid::nil(),
            model_id: ModelId::from_seed(1),
            request_overrides: serde_json::Value::Null,
            weight,
            rules,
            position: 0,
            created_at: now(),
            updated_at: now(),
        }
    }

    /// Route fixture keyed "base" with the given strategy and candidates.
    fn route(
        strategy: ModelRouterStrategy,
        candidates: Vec<ModelRouterCandidate>,
    ) -> ModelRouterRoute {
        ModelRouterRoute {
            id: Uuid::nil(),
            key: "base".into(),
            purpose: "default route".into(),
            when_to_use: "use this when no specific route fits".into(),
            strategy,
            position: 0,
            candidates,
            created_at: now(),
            updated_at: now(),
        }
    }

    /// Candidate JSON with every serde-defaulted field omitted.
    const MINIMAL_CANDIDATE_JSON: &str = r#"{
        "id": "00000000-0000-0000-0000-000000000000",
        "model_id": "model_00000000000000000000000000000001",
        "position": 0,
        "created_at": "2024-01-01T00:00:00Z",
        "updated_at": "2024-01-01T00:00:00Z"
    }"#;

    #[test]
    fn status_round_trip() {
        assert_eq!(ModelRouterStatus::from("active").to_string(), "active");
        assert_eq!(ModelRouterStatus::from("archived").to_string(), "archived");
        assert_eq!(ModelRouterStatus::from("deleted").to_string(), "deleted");
        assert_eq!(ModelRouterStatus::from("unknown").to_string(), "active");
    }

    /// Serde's `rename_all` names and the hand-written `Display` strings are
    /// maintained separately; pin them together.
    #[test]
    fn status_serde_matches_display() {
        for status in [
            ModelRouterStatus::Active,
            ModelRouterStatus::Archived,
            ModelRouterStatus::Deleted,
        ] {
            let wire = serde_json::to_string(&status).unwrap();
            assert_eq!(wire, format!("\"{status}\""));
            let back: ModelRouterStatus = serde_json::from_str(&wire).unwrap();
            assert_eq!(back, status);
        }
    }

    #[test]
    fn strategy_parse_round_trip() {
        for s in ["single", "ordered_fallback", "weighted", "rules", "custom"] {
            assert_eq!(ModelRouterStrategy::parse(s).unwrap().to_string(), s);
        }
    }

    /// Serde's `rename_all` names must agree with `Display` and `parse`.
    #[test]
    fn strategy_serde_matches_display_and_parse() {
        for strategy in [
            ModelRouterStrategy::Single,
            ModelRouterStrategy::OrderedFallback,
            ModelRouterStrategy::Weighted,
            ModelRouterStrategy::Rules,
            ModelRouterStrategy::Custom,
        ] {
            let wire = serde_json::to_string(&strategy).unwrap();
            assert_eq!(wire, format!("\"{strategy}\""));
            assert_eq!(ModelRouterStrategy::parse(&strategy.to_string()).unwrap(), strategy);
            let back: ModelRouterStrategy = serde_json::from_str(&wire).unwrap();
            assert_eq!(back, strategy);
        }
    }

    #[test]
    fn strategy_parse_rejects_unknown() {
        let err = ModelRouterStrategy::parse("invalid").unwrap_err();
        assert!(err.contains("unknown model router strategy"));
    }

    #[test]
    fn route_key_accepts_canonical_keys() {
        for key in ["base", "utility", "analysis", "review", "fast-path", "v1"] {
            assert!(validate_route_key(key).is_ok(), "should accept key '{key}'");
        }
    }

    #[test]
    fn route_key_rejects_empty() {
        assert!(validate_route_key("").is_err());
    }

    #[test]
    fn route_key_rejects_uppercase() {
        assert!(validate_route_key("Analysis").is_err());
    }

    #[test]
    fn route_key_rejects_underscore() {
        assert!(validate_route_key("fast_path").is_err());
    }

    #[test]
    fn route_key_rejects_leading_hyphen() {
        assert!(validate_route_key("-fast").is_err());
    }

    #[test]
    fn route_key_rejects_trailing_hyphen() {
        assert!(validate_route_key("fast-").is_err());
    }

    #[test]
    fn route_key_rejects_too_long() {
        let key = "a".repeat(MAX_ROUTE_KEY_LEN + 1);
        assert!(validate_route_key(&key).is_err());
    }

    #[test]
    fn candidate_shape_rejects_negative_weight() {
        let cand = candidate(-1, None);
        assert!(validate_candidate_shape(&cand, ModelRouterStrategy::Weighted).is_err());
    }

    #[test]
    fn candidate_shape_rules_strategy_requires_rules_doc() {
        let cand = candidate(1, None);
        let err = validate_candidate_shape(&cand, ModelRouterStrategy::Rules).unwrap_err();
        assert!(err.contains("rules"));
    }

    #[test]
    fn candidate_shape_rules_strategy_accepts_rules_doc() {
        let cand = candidate(1, Some(serde_json::json!({ "if": { "tier": "fast" } })));
        assert!(validate_candidate_shape(&cand, ModelRouterStrategy::Rules).is_ok());
    }

    #[test]
    fn route_shape_rejects_single_with_multiple_candidates() {
        let route = route(
            ModelRouterStrategy::Single,
            vec![candidate(1, None), candidate(1, None)],
        );
        let err = validate_route_shape(&route).unwrap_err();
        assert!(err.contains("single"));
    }

    #[test]
    fn route_shape_rejects_single_with_zero_candidates() {
        let route = route(ModelRouterStrategy::Single, vec![]);
        let err = validate_route_shape(&route).unwrap_err();
        assert!(err.contains("single"));
    }

    #[test]
    fn route_shape_accepts_single_with_exactly_one_candidate() {
        let route = route(ModelRouterStrategy::Single, vec![candidate(1, None)]);
        assert!(validate_route_shape(&route).is_ok());
    }

    #[test]
    fn route_shape_accepts_ordered_fallback_with_multiple_candidates() {
        let route = route(
            ModelRouterStrategy::OrderedFallback,
            vec![candidate(1, None), candidate(1, None)],
        );
        assert!(validate_route_shape(&route).is_ok());
    }

    #[test]
    fn candidate_default_weight_is_one() {
        let cand: ModelRouterCandidate = serde_json::from_str(MINIMAL_CANDIDATE_JSON).unwrap();
        assert_eq!(cand.weight, 1);
    }

    #[test]
    fn candidate_default_request_overrides_is_empty_object() {
        let cand: ModelRouterCandidate = serde_json::from_str(MINIMAL_CANDIDATE_JSON).unwrap();
        assert!(
            cand.request_overrides.is_object(),
            "expected default request_overrides to be a JSON object, got {:?}",
            cand.request_overrides
        );
        assert_eq!(cand.request_overrides.as_object().unwrap().len(), 0);
    }

    #[test]
    fn router_default_param_schema_is_empty_object() {
        let json = r#"{
            "id": "mrtr_00000000000000000000000000000001",
            "name": "default",
            "status": "active",
            "created_at": "2024-01-01T00:00:00Z",
            "updated_at": "2024-01-01T00:00:00Z"
        }"#;
        let router: ModelRouter = serde_json::from_str(json).unwrap();
        assert!(
            router.param_schema.is_object(),
            "expected default param_schema to be a JSON object, got {:?}",
            router.param_schema
        );
        assert_eq!(router.param_schema.as_object().unwrap().len(), 0);
    }
}