pub mod cache;
pub mod store;
pub mod validate;
pub use cache::CachingRoutingStore;
#[cfg(feature = "postgres")]
pub use store::PostgresRoutingStore;
pub use store::{InMemoryRoutingStore, NewRoute, RoutingStore, RoutingStoreError};
pub use validate::{validate_capability, ValidationError};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use tt_shared::{ChatCompletionRequest, RequestContext};
/// A single routing rule: a request that satisfies `when` is handled per `then`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Route {
/// Identifier of the route.
pub id: Uuid,
/// Route name; `RoutingEngine::find_by_name` compares it by exact equality.
pub name: String,
/// Evaluation rank: the engine sorts routes so that higher values are tried first.
pub priority: u32,
/// Disabled routes are skipped by both evaluation and name lookup.
pub enabled: bool,
/// Conditions a request must satisfy for this route to match.
pub when: RouteConditions,
/// Action associated with a matching request.
pub then: RouteAction,
}
/// Match conditions of a route.
///
/// Every condition that is set must hold for the route to match (logical AND).
/// A condition left at its default (`None` / empty list) imposes no constraint,
/// so a default-constructed value matches every request.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RouteConditions {
/// When non-empty, the request's model must equal one of these entries exactly.
#[serde(default)]
pub model_in: Vec<String>,
/// Matches only when the input token estimate is strictly less than this value.
#[serde(default)]
pub input_tokens_lt: Option<u32>,
/// Matches only when the input token estimate is strictly greater than this value.
#[serde(default)]
pub input_tokens_gt: Option<u32>,
/// Matches only when the request context carries exactly this tag.
#[serde(default)]
pub tag_equals: Option<String>,
/// Matches only when `capability_check::request_has_images` returns this value.
/// Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub has_images: Option<bool>,
/// Matches only when `capability_check::request_has_audio` returns this value.
/// Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub has_audio: Option<bool>,
/// When non-empty, the request's input text must contain at least one of these
/// keywords; both sides are lowercased before the substring test.
/// Omitted from the serialized form when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub prompt_contains_any_of: Vec<String>,
/// Matches only when a cost estimate is supplied and is strictly greater than
/// this value; a request without an estimate never matches.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub estimated_cost_gt: Option<f64>,
/// Matches only when a cost estimate is supplied and is strictly less than
/// this value; a request without an estimate never matches.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub estimated_cost_lt: Option<f64>,
}
/// Action attached to a route.
///
/// Only `target_model` is always serialized; the other fields are omitted from
/// the serialized form while they hold their default value. Unknown keys in the
/// input are ignored on deserialization (see the legacy-key test in this file).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RouteAction {
/// Name of the model this route points at.
pub target_model: String,
/// Additional model names; omitted from the serialized form when empty.
/// NOTE(review): presumably tried in order if the target fails — confirm against the consumer.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub fallbacks: Vec<String>,
/// Defaults to `false` and is omitted from the serialized form when `false`.
/// NOTE(review): presumably tells the consumer to bypass caching — confirm.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub disable_cache: bool,
/// Optional cost limit; omitted from the serialized form when `None`.
/// NOTE(review): not enforced anywhere in this file — confirm where it is applied.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_cost_usd: Option<f64>,
}
/// Holds a list of routes and picks the first enabled one that matches a request.
///
/// `with_routes` and `add` keep the list sorted by descending priority.
#[derive(Debug, Clone, Default)]
pub struct RoutingEngine {
// Sorted by descending `priority` by the constructors/mutators of this type.
routes: Vec<Route>,
}
impl RoutingEngine {
    /// Creates an engine that holds no routes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an engine from `routes`, ordering them by descending priority.
    ///
    /// Routes with equal priority keep the relative order in which they were supplied.
    pub fn with_routes(routes: impl IntoIterator<Item = Route>) -> Self {
        let mut engine = Self {
            routes: routes.into_iter().collect::<Vec<Route>>(),
        };
        engine.sort_by_priority();
        engine
    }

    /// Adds `route` and restores the descending-priority order.
    ///
    /// A route whose priority ties with existing ones ends up after them.
    pub fn add(&mut self, route: Route) {
        self.routes.push(route);
        self.sort_by_priority();
    }

    /// Returns all routes (enabled or not) in evaluation order.
    pub fn routes(&self) -> &[Route] {
        self.routes.as_slice()
    }

    /// Returns the first enabled route matching the request, without a cost estimate.
    ///
    /// Equivalent to [`RoutingEngine::evaluate_with_cost`] with `None` as the estimate,
    /// so routes carrying an `estimated_cost_*` condition cannot match here.
    pub fn evaluate(
        &self,
        req: &ChatCompletionRequest,
        ctx: &RequestContext,
        input_tokens_estimate: u32,
    ) -> Option<&Route> {
        let no_cost_estimate: Option<f64> = None;
        self.evaluate_with_cost(req, ctx, input_tokens_estimate, no_cost_estimate)
    }

    /// Returns the first enabled route, in priority order, whose conditions all hold
    /// for the given request, context, token estimate and optional cost estimate.
    pub fn evaluate_with_cost(
        &self,
        req: &ChatCompletionRequest,
        ctx: &RequestContext,
        input_tokens_estimate: u32,
        estimated_cost_usd: Option<f64>,
    ) -> Option<&Route> {
        self.routes
            .iter()
            .filter(|candidate| candidate.enabled)
            .find(|candidate| {
                matches(
                    candidate,
                    req,
                    ctx,
                    input_tokens_estimate,
                    estimated_cost_usd,
                )
            })
    }

    /// Returns the first enabled route whose name equals `name` exactly.
    pub fn find_by_name(&self, name: &str) -> Option<&Route> {
        for candidate in &self.routes {
            if candidate.enabled && candidate.name == name {
                return Some(candidate);
            }
        }
        None
    }

    /// Stable sort by descending priority; ties keep their current relative order.
    fn sort_by_priority(&mut self) {
        self.routes
            .sort_by(|left, right| right.priority.cmp(&left.priority));
    }
}
/// Reports whether every condition set on route `r` holds for the request.
///
/// Conditions are checked in a fixed order and evaluation stops at the first
/// one that fails. Unset conditions (`None` / empty list) are treated as
/// satisfied. Cost conditions require `estimated_cost_usd` to be present.
fn matches(
    r: &Route,
    req: &ChatCompletionRequest,
    ctx: &RequestContext,
    input_tokens: u32,
    estimated_cost_usd: Option<f64>,
) -> bool {
    use tt_shared::capability_check::{request_has_audio, request_has_images, request_input_text};

    let cond = &r.when;

    // Model allow-list: an empty list accepts any model.
    (cond.model_in.is_empty() || cond.model_in.iter().any(|allowed| allowed == &req.model))
        // Token bounds are strict on both sides.
        && cond
            .input_tokens_lt
            .map_or(true, |limit| input_tokens < limit)
        && cond
            .input_tokens_gt
            .map_or(true, |limit| input_tokens > limit)
        // Cost bounds: a missing estimate fails the condition.
        && cond.estimated_cost_gt.map_or(true, |limit| {
            estimated_cost_usd.map_or(false, |cost| cost > limit)
        })
        && cond.estimated_cost_lt.map_or(true, |limit| {
            estimated_cost_usd.map_or(false, |cost| cost < limit)
        })
        // Tag must be present on the context and equal.
        && cond
            .tag_equals
            .as_deref()
            .map_or(true, |wanted| ctx.tag.as_deref() == Some(wanted))
        // Modality flags must equal the requested value (true or false).
        && cond
            .has_images
            .map_or(true, |wanted| request_has_images(req) == wanted)
        && cond
            .has_audio
            .map_or(true, |wanted| request_has_audio(req) == wanted)
        // Keyword test: case-insensitive substring, any keyword suffices.
        && (cond.prompt_contains_any_of.is_empty() || {
            let haystack = request_input_text(req).to_lowercase();
            cond.prompt_contains_any_of
                .iter()
                .any(|keyword| haystack.contains(keyword.to_lowercase().as_str()))
        })
}
// Unit tests for route matching, priority ordering, and the serialized form of `RouteAction`.
#[cfg(test)]
mod tests {
use super::*;
use tt_shared::{
context::{ProviderCredentials, SecretString},
messages::{ContentPart, ImageUrl, InputAudio},
ChatCompletionRequest, Message, MessageContent,
};
// Builds an enabled route with the given name, priority, model allow-list and target model;
// all other conditions and action fields are left at their defaults.
fn make_route(name: &str, priority: u32, model_in: Vec<&str>, target: &str) -> Route {
Route {
id: Uuid::now_v7(),
name: name.into(),
priority,
enabled: true,
when: RouteConditions {
model_in: model_in.into_iter().map(String::from).collect(),
..Default::default()
},
then: RouteAction {
target_model: target.into(),
fallbacks: Vec::new(),
disable_cache: false,
max_cost_usd: None,
},
}
}
// Builds a request for `model` with a single user text message "hi"; the remaining
// fields come from deserializing a minimal placeholder JSON document.
fn make_req(model: &str) -> ChatCompletionRequest {
ChatCompletionRequest {
model: model.into(),
messages: vec![Message::User {
content: MessageContent::Text("hi".into()),
name: None,
}],
..serde_json::from_str(r#"{"model":"placeholder","messages":[]}"#).unwrap()
}
}
// Builds a request for `model` with a single user message made of one content part.
fn make_req_with_part(model: &str, part: ContentPart) -> ChatCompletionRequest {
ChatCompletionRequest {
model: model.into(),
messages: vec![Message::User {
content: MessageContent::Parts(vec![part]),
name: None,
}],
..serde_json::from_str(r#"{"model":"placeholder","messages":[]}"#).unwrap()
}
}
// An image content part carrying a small data URL.
fn image_part() -> ContentPart {
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "data:image/png;base64,abc".into(),
detail: None,
},
}
}
// An audio content part with dummy data in "wav" format.
fn audio_part() -> ContentPart {
ContentPart::InputAudio {
input_audio: InputAudio {
data: "abc".into(),
format: "wav".into(),
},
}
}
// Name lookup returns enabled routes only and requires an exact name.
#[test]
fn find_by_name_matches_enabled_route_by_exact_name() {
let enabled = make_route("alpha", 10, vec!["gpt-4o"], "gpt-4o-mini");
let mut disabled = make_route("beta", 10, vec!["gpt-4o"], "gpt-4o-mini");
disabled.enabled = false;
let eng = RoutingEngine::with_routes(vec![enabled, disabled]);
assert!(eng.find_by_name("alpha").is_some());
assert_eq!(eng.find_by_name("alpha").unwrap().name, "alpha");
assert!(
eng.find_by_name("beta").is_none(),
"disabled route not found"
);
assert!(eng.find_by_name("missing").is_none());
}
// `has_images: Some(true)` matches an image request and rejects a text-only one.
#[test]
fn has_images_true_matches_only_image_requests() {
let route = Route {
when: RouteConditions {
has_images: Some(true),
..Default::default()
},
..make_route("vision", 10, vec![], "vision-mini")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(
&make_req_with_part("gpt-4o", image_part()),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_none());
}
// `has_images: Some(false)` matches a text-only request and rejects an image request.
#[test]
fn has_images_false_matches_only_non_image_requests() {
let route = Route {
when: RouteConditions {
has_images: Some(false),
..Default::default()
},
..make_route("text", 10, vec![], "cheap")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_some());
assert!(eng
.evaluate(
&make_req_with_part("gpt-4o", image_part()),
&make_ctx(None),
100
)
.is_none());
}
// `has_audio: Some(true)` matches an audio request and rejects an image-only one.
#[test]
fn has_audio_true_matches_only_audio_requests() {
let route = Route {
when: RouteConditions {
has_audio: Some(true),
..Default::default()
},
..make_route("audio", 10, vec![], "audio-model")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(
&make_req_with_part("gpt-4o", audio_part()),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(
&make_req_with_part("gpt-4o", image_part()),
&make_ctx(None),
100
)
.is_none());
}
// A modality condition and `model_in` must both hold for the route to match.
#[test]
fn modality_anded_with_model_in() {
let route = Route {
when: RouteConditions {
model_in: vec!["gpt-4o".into()],
has_images: Some(true),
..Default::default()
},
..make_route("both", 10, vec!["gpt-4o"], "vision-mini")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(
&make_req_with_part("gpt-4o", image_part()),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_none());
assert!(eng
.evaluate(
&make_req_with_part("other", image_part()),
&make_ctx(None),
100
)
.is_none());
}
// Builds a request context with nil org/key ids, empty credentials and the optional tag.
fn make_ctx(tag: Option<&str>) -> RequestContext {
RequestContext {
trace_id: Uuid::now_v7(),
org_id: Uuid::nil(),
api_key_id: Uuid::nil(),
credentials: ProviderCredentials {
api_key: SecretString::new(""),
base_url: None,
extra_headers: Vec::new(),
},
tag: tag.map(String::from),
deadline: None,
}
}
// An engine without routes never returns a match.
#[test]
fn empty_engine_matches_nothing() {
let eng = RoutingEngine::new();
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_none());
}
// A request whose model is in `model_in` matches and exposes the route's target model.
#[test]
fn model_in_matches() {
let eng = RoutingEngine::with_routes(vec![make_route(
"to-mini",
10,
vec!["gpt-4o"],
"gpt-4o-mini",
)]);
let m = eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.expect("should match");
assert_eq!(m.then.target_model, "gpt-4o-mini");
}
// Among several matching routes, the one with the highest priority wins
// regardless of insertion order.
#[test]
fn priority_descending_first_match_wins() {
let eng = RoutingEngine::with_routes(vec![
make_route("low", 1, vec!["gpt-4o"], "low-target"),
make_route("high", 100, vec!["gpt-4o"], "high-target"),
make_route("mid", 50, vec!["gpt-4o"], "mid-target"),
]);
let m = eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.unwrap();
assert_eq!(m.then.target_model, "high-target");
}
// A disabled route is skipped even when it has the highest priority.
#[test]
fn disabled_route_skipped() {
let mut route = make_route("disabled", 100, vec!["gpt-4o"], "never");
route.enabled = false;
let eng = RoutingEngine::with_routes(vec![
route,
make_route("enabled", 10, vec!["gpt-4o"], "winner"),
]);
let m = eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.unwrap();
assert_eq!(m.then.target_model, "winner");
}
// `input_tokens_lt` accepts estimates below the threshold and rejects those above it.
#[test]
fn token_lt_filters() {
let route = Route {
when: RouteConditions {
model_in: vec!["gpt-4o".into()],
input_tokens_lt: Some(500),
..Default::default()
},
..make_route("short-only", 10, vec!["gpt-4o"], "gpt-4o-mini")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_some());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 600)
.is_none());
}
// `input_tokens_gt` rejects estimates below the threshold and accepts those above it.
#[test]
fn token_gt_filters() {
let route = Route {
when: RouteConditions {
model_in: vec!["gpt-4o".into()],
input_tokens_gt: Some(1000),
..Default::default()
},
..make_route("long-only", 10, vec!["gpt-4o"], "claude-opus-4-7")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 500)
.is_none());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 1500)
.is_some());
}
// `tag_equals` requires the context tag to be present and equal.
#[test]
fn tag_equals_filters() {
let route = Route {
when: RouteConditions {
tag_equals: Some("background".into()),
..Default::default()
},
..make_route("bg-only", 10, vec![], "cheap-model")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
.is_none());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(Some("background")), 100)
.is_some());
assert!(eng
.evaluate(&make_req("gpt-4o"), &make_ctx(Some("foreground")), 100)
.is_none());
}
// An empty `model_in` list places no restriction on the request model.
#[test]
fn empty_model_in_matches_any_model() {
let route = make_route("any", 10, vec![], "target");
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(&make_req("claude-sonnet-4-6"), &make_ctx(None), 100)
.is_some());
}
// With every optional field at its default, only `target_model` is serialized.
#[test]
fn route_action_minimal_serializes_without_new_fields() {
let a = RouteAction {
target_model: "x".into(),
fallbacks: Vec::new(),
disable_cache: false,
max_cost_usd: None,
};
let json = serde_json::to_string(&a).unwrap();
assert_eq!(
json, r#"{"target_model":"x"}"#,
"empty fallbacks must be omitted from JSON"
);
}
// JSON that only carries `target_model` still deserializes; `fallbacks` defaults to empty.
#[test]
fn route_action_backward_compat_deserialize() {
let json = r#"{"target_model":"gpt-4o-mini"}"#;
let a: RouteAction = serde_json::from_str(json).unwrap();
assert_eq!(a.target_model, "gpt-4o-mini");
assert!(a.fallbacks.is_empty(), "fallbacks must default to empty");
}
// Non-empty `fallbacks` are serialized and survive a JSON round trip.
#[test]
fn route_action_full_round_trip() {
let original = RouteAction {
target_model: "claude-haiku-4-5".into(),
fallbacks: vec!["gpt-4o-mini".into(), "gemini-flash".into()],
disable_cache: false,
max_cost_usd: None,
};
let json = serde_json::to_string(&original).unwrap();
assert!(
json.contains("\"fallbacks\""),
"fallbacks must be present: {json}"
);
let roundtripped: RouteAction = serde_json::from_str(&json).unwrap();
assert_eq!(roundtripped.target_model, original.target_model);
assert_eq!(roundtripped.fallbacks, original.fallbacks);
}
// `disable_cache` is omitted when false, defaults to false on input, and is emitted when true.
#[test]
fn route_action_disable_cache_defaults_false_and_omits() {
let a = RouteAction {
target_model: "x".into(),
fallbacks: Vec::new(),
disable_cache: false,
max_cost_usd: None,
};
assert_eq!(
serde_json::to_string(&a).unwrap(),
r#"{"target_model":"x"}"#
);
let parsed: RouteAction = serde_json::from_str(r#"{"target_model":"m"}"#).unwrap();
assert!(!parsed.disable_cache);
let b = RouteAction {
disable_cache: true,
..a
};
assert!(serde_json::to_string(&b)
.unwrap()
.contains("\"disable_cache\":true"));
}
// JSON with `target_model` and `fallbacks` deserializes and re-serializes to the identical string.
#[test]
fn route_action_cross_type_wire_compat() {
let plan_side_json = r#"{"target_model":"claude-3-5-haiku","fallbacks":["gpt-4o-mini"]}"#;
let gateway_action: RouteAction = serde_json::from_str(plan_side_json).unwrap();
assert_eq!(gateway_action.target_model, "claude-3-5-haiku");
assert_eq!(gateway_action.fallbacks, vec!["gpt-4o-mini"]);
let reemitted = serde_json::to_string(&gateway_action).unwrap();
assert_eq!(reemitted, plan_side_json);
}
// An unknown legacy key is accepted on input and not written back on output.
#[test]
fn route_action_legacy_force_cache_layer_is_ignored() {
let legacy =
r#"{"target_model":"claude-3-5-haiku","fallbacks":["x"],"force_cache_layer":"l1"}"#;
let a: RouteAction = serde_json::from_str(legacy).unwrap();
assert_eq!(a.target_model, "claude-3-5-haiku");
assert_eq!(a.fallbacks, vec!["x"]);
let j = serde_json::to_string(&a).unwrap();
assert!(
!j.contains("force_cache_layer"),
"obsolete key must not be re-emitted: {j}"
);
}
// Builds a request for `model` with a single user text message containing `text`.
fn make_req_text(model: &str, text: &str) -> ChatCompletionRequest {
ChatCompletionRequest {
model: model.into(),
messages: vec![Message::User {
content: MessageContent::Text(text.into()),
name: None,
}],
..serde_json::from_str(r#"{"model":"placeholder","messages":[]}"#).unwrap()
}
}
// Any one keyword suffices, and the comparison ignores letter case.
#[test]
fn prompt_contains_matches_case_insensitive_any() {
let route = Route {
when: RouteConditions {
prompt_contains_any_of: vec!["confidential".into(), "salary".into()],
..Default::default()
},
..make_route("topic", 10, vec![], "local")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(
&make_req_text("gpt-4o", "This is a Confidential memo"),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(
&make_req_text("gpt-4o", "my SALARY is"),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(
&make_req_text("gpt-4o", "the weather today"),
&make_ctx(None),
100
)
.is_none());
}
// The keyword condition combined with `model_in`: a prompt without the keyword is rejected.
#[test]
fn prompt_contains_anded_with_model_in() {
let route = Route {
when: RouteConditions {
model_in: vec!["gpt-4o".into()],
prompt_contains_any_of: vec!["confidential".into()],
..Default::default()
},
..make_route("both", 10, vec!["gpt-4o"], "local")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate(
&make_req_text("gpt-4o", "confidential"),
&make_ctx(None),
100
)
.is_some());
assert!(eng
.evaluate(&make_req_text("gpt-4o", "hello"), &make_ctx(None), 100)
.is_none());
}
// `max_cost_usd` is omitted when `None` and round-trips through JSON when set.
#[test]
fn max_cost_usd_round_trips_and_omits_when_none() {
let mut a = make_route("x", 10, vec![], "gpt-4o-mini").then;
assert!(!serde_json::to_string(&a).unwrap().contains("max_cost_usd"));
a.max_cost_usd = Some(0.1);
let j = serde_json::to_string(&a).unwrap();
assert!(j.contains("\"max_cost_usd\":0.1"));
let back: RouteAction = serde_json::from_str(&j).unwrap();
assert_eq!(back.max_cost_usd, Some(0.1));
}
// `estimated_cost_gt` matches only estimates above the threshold; a missing estimate never matches.
#[test]
fn cost_gt_matches_above_threshold_only() {
let route = Route {
when: RouteConditions {
estimated_cost_gt: Some(0.02),
..Default::default()
},
..make_route("expensive", 10, vec![], "cheaper")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, Some(0.03))
.is_some());
assert!(eng
.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, Some(0.01))
.is_none());
assert!(eng
.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, None)
.is_none());
}
// `estimated_cost_lt` and `model_in` must both hold for the route to match.
#[test]
fn cost_lt_anded_with_model_in() {
let route = Route {
when: RouteConditions {
model_in: vec!["gpt-4o".into()],
estimated_cost_lt: Some(0.05),
..Default::default()
},
..make_route("cheap-small", 10, vec![], "target")
};
let eng = RoutingEngine::with_routes(vec![route]);
assert!(eng
.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, Some(0.01))
.is_some());
assert!(eng
.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, Some(0.09))
.is_none());
assert!(eng
.evaluate_with_cost(&make_req("claude-x"), &make_ctx(None), 100, Some(0.01))
.is_none());
}
}