use std::cell::RefCell;
use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use serde_json::json;
use crate::value::{ErrorCategory, VmError, VmValue};
use super::api::{LlmCallOptions, LlmRoutePolicy};
use super::cost::{calculate_cost_for_provider, peek_total_cost, LlmBudgetEnvelope};
// Key set to `true` in the summary dict returned by `build_routing_policy`;
// `extract_routing_policy` rejects dicts that lack it.
const ROUTING_POLICY_TAG: &str = "__routing_policy__";
// Summary-dict key holding the integer handle into the policy registry.
const HANDLE_KEY: &str = "__handle__";
// Fallback primary timeout (ms) used for races when neither the link nor the
// failover rules configure one.
const DEFAULT_RACE_PRIMARY_TIMEOUT_MS: u64 = 120_000;
// HTTP statuses that trigger failover when no failover rules are configured.
const DEFAULT_FAILOVER_STATUSES: &[u16] = &[408, 429, 500, 502, 503, 504];
/// What the router does when a link's projected cost would breach a budget cap.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum BudgetExceedAction {
/// Fail the routed call with a budget-exceeded error.
Abort,
/// Record the link as skipped and continue with the next chain entry.
Skip,
/// Emit the budget event but still attempt the link.
Warn,
}
impl BudgetExceedAction {
    /// Parses the `on_exceed` setting, ignoring surrounding whitespace and ASCII case.
    ///
    /// An empty string means `abort`; `downgrade` is accepted as an alias for `skip`.
    ///
    /// # Errors
    /// Returns a thrown runtime error for any other value.
    fn parse(value: &str) -> Result<Self, VmError> {
        let normalized = value.trim().to_ascii_lowercase();
        let action = match normalized.as_str() {
            "" | "abort" => Self::Abort,
            "downgrade" | "skip" => Self::Skip,
            "warn" => Self::Warn,
            other => {
                return Err(runtime_error(format!(
                    "routing_policy.budget.on_exceed: expected one of abort|skip|warn, got {other:?}"
                )));
            }
        };
        Ok(action)
    }

    /// Canonical lowercase name of the action, as surfaced in summaries and events.
    fn as_str(self) -> &'static str {
        let name = match self {
            Self::Warn => "warn",
            Self::Skip => "skip",
            Self::Abort => "abort",
        };
        name
    }
}
/// One provider/model target in a routing chain.
#[derive(Clone, Debug)]
pub(crate) struct ChainLink {
    /// Provider name for this link.
    pub provider: String,
    /// Model name for this link.
    pub model: String,
    /// Optional per-link timeout in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Optional human-readable name; see [`ChainLink::display_label`].
    pub label: Option<String>,
}

impl ChainLink {
    /// Returns the explicit label when one is set, otherwise `provider:model`.
    pub(crate) fn display_label(&self) -> String {
        match &self.label {
            Some(custom) => custom.clone(),
            None => format!("{}:{}", self.provider, self.model),
        }
    }
}
/// Conditions under which a failed link hands over to the next chain entry.
#[derive(Clone, Debug, Default)]
pub(crate) struct FailoverRules {
/// HTTP status codes that trigger failover.
pub on_status: Vec<u16>,
/// Timeout in milliseconds applied to links that do not set their own.
pub on_timeout_ms: Option<u64>,
/// Error-category names (or aliases) that trigger failover.
pub on_error_kinds: Vec<String>,
/// Upper bound on attempts; when unset the chain length is used.
pub max_attempts: Option<usize>,
}
/// Latency-related routing settings.
#[derive(Clone, Debug, Default)]
pub(crate) struct LatencyRules {
/// Parsed and echoed in the policy summary; not otherwise read in the visible
/// part of this file.
pub target_p95_ms: Option<u64>,
/// Delay in milliseconds after which the next chain entry is raced against
/// the still-running primary.
pub race_after_ms: Option<u64>,
}
/// Cost caps attached to a routing policy.
#[derive(Clone, Debug, Default)]
pub(crate) struct BudgetRules {
/// Cap on the projected cost of a single call, in USD.
pub per_call_usd: Option<f64>,
/// Cap on session cost plus the projected call cost, in USD.
pub session_usd: Option<f64>,
/// Action on breach; `None` is treated as abort by `on_exceed_or_abort`.
pub on_exceed: Option<BudgetExceedAction>,
}
impl BudgetRules {
    /// Returns the configured breach action, defaulting to abort when unset.
    pub(crate) fn on_exceed_or_abort(&self) -> BudgetExceedAction {
        match self.on_exceed {
            Some(action) => action,
            None => BudgetExceedAction::Abort,
        }
    }

    /// Builds the budget envelope for these rules.
    ///
    /// Returns `None` when neither the per-call nor the session cap is set;
    /// token limits are always left unset.
    pub(crate) fn envelope(&self) -> Option<LlmBudgetEnvelope> {
        if self.per_call_usd.is_none() && self.session_usd.is_none() {
            return None;
        }
        Some(LlmBudgetEnvelope {
            max_cost_usd: self.per_call_usd,
            total_budget_usd: self.session_usd,
            max_input_tokens: None,
            max_output_tokens: None,
        })
    }
}
/// Observability settings for a routing policy.
#[derive(Clone, Debug, Default)]
pub(crate) struct ObserveRules {
/// Event-category prefix for routing events; when unset, `dispatch_label`
/// falls back to `llm.routing`.
pub emit_event: Option<String>,
}
/// Fully parsed routing policy as stored in the policy registry.
#[derive(Clone, Debug)]
pub(crate) struct RoutingPolicyConfig {
/// Ordered provider/model targets; `build_routing_policy` rejects an empty list.
pub chain: Vec<ChainLink>,
pub failover: FailoverRules,
pub latency: LatencyRules,
pub budget: BudgetRules,
pub observe: ObserveRules,
/// Policy label; `build_routing_policy` synthesizes one when none is given.
pub label: String,
}
impl RoutingPolicyConfig {
    /// Event-category prefix for this policy: the configured `emit_event`
    /// name, or `llm.routing` when none is set.
    pub(crate) fn dispatch_label(&self) -> String {
        match self.observe.emit_event.as_deref() {
            Some(name) => name.to_string(),
            None => "llm.routing".to_string(),
        }
    }
}
// Per-thread table of interned policies keyed by handle. Because it is
// thread-local, a handle only resolves on the thread that interned it.
thread_local! {
static POLICY_REGISTRY: RefCell<BTreeMap<u64, Rc<RoutingPolicyConfig>>> =
const { RefCell::new(BTreeMap::new()) };
}
// Process-wide source of handle values; starts at 1.
static POLICY_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Stores `policy` in this thread's registry and returns its freshly
/// allocated handle.
fn intern_policy(policy: RoutingPolicyConfig) -> u64 {
    let shared = Rc::new(policy);
    let handle = POLICY_COUNTER.fetch_add(1, Ordering::SeqCst);
    POLICY_REGISTRY.with(|cell| {
        cell.borrow_mut().insert(handle, shared);
    });
    handle
}
/// Looks up an interned policy in this thread's registry.
fn lookup_policy(handle: u64) -> Option<Rc<RoutingPolicyConfig>> {
    POLICY_REGISTRY.with(|cell| {
        let policies = cell.borrow();
        policies.get(&handle).map(Rc::clone)
    })
}
/// Test helper: drops every policy interned on the current thread.
#[cfg(test)]
pub(crate) fn clear_policy_registry() {
    POLICY_REGISTRY.with(|cell| {
        cell.borrow_mut().clear();
    });
}
/// Wraps `message` in a thrown string value.
fn runtime_error(message: String) -> VmError {
    let payload = VmValue::String(Rc::from(message));
    VmError::Thrown(payload)
}
/// Reads an optional string field; a missing or nil entry yields an empty string.
///
/// # Errors
/// Returns a runtime error when the entry holds any other type.
fn parse_label(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<String, VmError> {
    let Some(value) = dict.get(key) else {
        return Ok(String::new());
    };
    match value {
        VmValue::Nil => Ok(String::new()),
        VmValue::String(text) => Ok(text.to_string()),
        other => Err(runtime_error(format!(
            "routing_policy.{key}: expected a string, got {}",
            other.type_name()
        ))),
    }
}
/// Reads an optional non-negative number as `u64`.
///
/// Missing or nil entries yield `None`. Non-negative ints and finite
/// non-negative floats are accepted; floats are converted with an `as` cast.
///
/// # Errors
/// Returns a runtime error naming the value's type for anything else.
fn parse_pos_u64(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<Option<u64>, VmError> {
    let value = match dict.get(key) {
        None | Some(VmValue::Nil) => return Ok(None),
        Some(value) => value,
    };
    let parsed = match value {
        VmValue::Int(n) if *n >= 0 => Some(*n as u64),
        VmValue::Float(f) if f.is_finite() && *f >= 0.0 => Some(*f as u64),
        _ => None,
    };
    match parsed {
        Some(number) => Ok(Some(number)),
        None => Err(runtime_error(format!(
            "routing_policy.{key}: expected a non-negative integer (got {})",
            value.type_name()
        ))),
    }
}
/// Same as `parse_pos_u64`, with the result cast to `usize`.
fn parse_pos_usize(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<Option<usize>, VmError> {
    let parsed = parse_pos_u64(dict, key)?;
    Ok(parsed.map(|value| value as usize))
}
/// Reads an optional non-negative number as `f64`.
///
/// Missing or nil entries yield `None`; non-negative ints and finite
/// non-negative floats are accepted.
///
/// # Errors
/// Returns a runtime error naming the value's type for anything else.
fn parse_pos_f64(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<Option<f64>, VmError> {
    let value = match dict.get(key) {
        None | Some(VmValue::Nil) => return Ok(None),
        Some(value) => value,
    };
    let parsed = match value {
        VmValue::Int(n) if *n >= 0 => Some(*n as f64),
        VmValue::Float(f) if f.is_finite() && *f >= 0.0 => Some(*f),
        _ => None,
    };
    match parsed {
        Some(number) => Ok(Some(number)),
        None => Err(runtime_error(format!(
            "routing_policy.{key}: expected a non-negative number (got {})",
            value.type_name()
        ))),
    }
}
/// Reads an optional list of strings.
///
/// Accepts either a list (each item is rendered with `display()`) or a single
/// comma-separated string. In both forms entries are trimmed, empty entries
/// are dropped, and duplicates are removed while keeping first-seen order.
/// Previously only the list form de-duplicated, so `"a,b,a"` and
/// `["a", "b", "a"]` produced different results.
///
/// # Errors
/// Returns a runtime error when the entry is neither nil, a list, nor a string.
fn parse_string_list(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<Vec<String>, VmError> {
    // Shared normalisation so both input forms follow the same rules.
    fn push_unique(out: &mut Vec<String>, candidate: &str) {
        let candidate = candidate.trim();
        if !candidate.is_empty() && !out.iter().any(|existing| existing == candidate) {
            out.push(candidate.to_string());
        }
    }

    match dict.get(key) {
        Some(VmValue::Nil) | None => Ok(Vec::new()),
        Some(VmValue::List(items)) => {
            let mut out = Vec::with_capacity(items.len());
            for item in items.iter() {
                push_unique(&mut out, &item.display());
            }
            Ok(out)
        }
        Some(VmValue::String(s)) => {
            let mut out = Vec::new();
            for chunk in s.split(',') {
                push_unique(&mut out, chunk);
            }
            Ok(out)
        }
        Some(other) => Err(runtime_error(format!(
            "routing_policy.{key}: expected a list of strings (got {})",
            other.type_name()
        ))),
    }
}
/// Reads an optional list of HTTP status codes.
///
/// Missing or nil entries yield an empty list.
///
/// # Errors
/// Returns a runtime error when the entry is not a list, when an item is not
/// an integer, or when a code falls outside `100..=599`.
fn parse_status_list(dict: &BTreeMap<String, VmValue>, key: &str) -> Result<Vec<u16>, VmError> {
    let items = match dict.get(key) {
        None | Some(VmValue::Nil) => return Ok(Vec::new()),
        Some(VmValue::List(items)) => items,
        Some(_) => {
            return Err(runtime_error(format!(
                "routing_policy.failover.{key}: expected a list of HTTP status codes"
            )));
        }
    };
    items
        .iter()
        .map(|item| -> Result<u16, VmError> {
            let Some(code) = item.as_int() else {
                return Err(runtime_error(format!(
                    "routing_policy.failover.{key}: expected integer status codes (got {})",
                    item.type_name()
                )));
            };
            if !(100..=599).contains(&code) {
                return Err(runtime_error(format!(
                    "routing_policy.failover.{key}: {code} is not a valid HTTP status (100..=599)"
                )));
            }
            Ok(code as u16)
        })
        .collect()
}
/// Splits a `provider:model` string at the first colon.
///
/// Both halves are trimmed; returns `None` when there is no colon or either
/// half is empty after trimming.
fn split_target(target: &str) -> Option<(String, String)> {
    let (raw_provider, raw_model) = target.trim().split_once(':')?;
    match (raw_provider.trim(), raw_model.trim()) {
        ("", _) | (_, "") => None,
        (provider, model) => Some((provider.to_owned(), model.to_owned())),
    }
}
/// Parses one chain entry, given either as a `{provider, model, ...}` dict or
/// as a `"provider:model"` string.
///
/// Dict entries may also carry `timeout_ms` and `label`. `provider` and
/// `model` are rendered with `display()` and trimmed. An explicit nil for
/// either field is now treated as missing rather than being stringified, so
/// it is reported by the "both provider and model are required" error, in
/// line with how the other parsers in this file treat nil.
///
/// # Errors
/// Returns a runtime error when the entry has the wrong type, the string form
/// cannot be split, a required field is empty, or `timeout_ms` / `label` have
/// the wrong type.
fn parse_chain_link(value: &VmValue, idx: usize) -> Result<ChainLink, VmError> {
    let dict = match value {
        VmValue::Dict(dict) => dict,
        VmValue::String(target) => {
            let (provider, model) = split_target(target).ok_or_else(|| {
                runtime_error(format!(
                    "routing_policy.chain[{idx}]: expected {{provider, model}} dict or \"provider:model\" string, got {value:?}"
                ))
            })?;
            return Ok(ChainLink {
                provider,
                model,
                timeout_ms: None,
                label: None,
            });
        }
        other => {
            return Err(runtime_error(format!(
                "routing_policy.chain[{idx}]: expected dict or string, got {}",
                other.type_name()
            )));
        }
    };

    // Missing and nil both count as "not provided".
    let text_field = |key: &str| -> String {
        match dict.get(key) {
            None | Some(VmValue::Nil) => String::new(),
            Some(field) => field.display().trim().to_string(),
        }
    };
    let provider = text_field("provider");
    let model = text_field("model");
    if provider.is_empty() || model.is_empty() {
        return Err(runtime_error(format!(
            "routing_policy.chain[{idx}]: both provider and model are required (got provider={provider:?}, model={model:?})"
        )));
    }

    let timeout_ms = parse_pos_u64(dict, "timeout_ms")?;
    // An empty label is the same as no label.
    let label = Some(parse_label(dict, "label")?).filter(|text| !text.is_empty());
    Ok(ChainLink {
        provider,
        model,
        timeout_ms,
        label,
    })
}
/// Parses the optional `failover` section; missing or nil yields defaults.
///
/// # Errors
/// Returns a runtime error when the section is not a dict or a field is invalid.
fn parse_failover(value: Option<&VmValue>) -> Result<FailoverRules, VmError> {
    let dict = match value {
        None | Some(VmValue::Nil) => return Ok(FailoverRules::default()),
        Some(VmValue::Dict(dict)) => dict,
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.failover: expected dict, got {}",
                other.type_name()
            )));
        }
    };
    let on_status = parse_status_list(dict, "on_status")?;
    let on_timeout_ms = parse_pos_u64(dict, "on_timeout_ms")?;
    let on_error_kinds = parse_string_list(dict, "on_error_kinds")?;
    let max_attempts = parse_pos_usize(dict, "max_attempts")?;
    Ok(FailoverRules {
        on_status,
        on_timeout_ms,
        on_error_kinds,
        max_attempts,
    })
}
/// Parses the optional `latency` section; missing or nil yields defaults.
///
/// # Errors
/// Returns a runtime error when the section is not a dict or a field is invalid.
fn parse_latency(value: Option<&VmValue>) -> Result<LatencyRules, VmError> {
    let dict = match value {
        None | Some(VmValue::Nil) => return Ok(LatencyRules::default()),
        Some(VmValue::Dict(dict)) => dict,
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.latency: expected dict, got {}",
                other.type_name()
            )));
        }
    };
    let target_p95_ms = parse_pos_u64(dict, "target_p95_ms")?;
    let race_after_ms = parse_pos_u64(dict, "race_after_ms")?;
    Ok(LatencyRules {
        target_p95_ms,
        race_after_ms,
    })
}
/// Parses the optional `budget` section; missing or nil yields defaults.
///
/// `on_exceed` is validated before the numeric caps.
///
/// # Errors
/// Returns a runtime error when the section is not a dict or a field is invalid.
fn parse_budget(value: Option<&VmValue>) -> Result<BudgetRules, VmError> {
    let dict = match value {
        None | Some(VmValue::Nil) => return Ok(BudgetRules::default()),
        Some(VmValue::Dict(dict)) => dict,
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.budget: expected dict, got {}",
                other.type_name()
            )));
        }
    };
    let on_exceed = match dict.get("on_exceed") {
        None | Some(VmValue::Nil) => None,
        Some(VmValue::String(raw)) => Some(BudgetExceedAction::parse(raw)?),
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.budget.on_exceed: expected a string, got {}",
                other.type_name()
            )));
        }
    };
    let per_call_usd = parse_pos_f64(dict, "per_call_usd")?;
    let session_usd = parse_pos_f64(dict, "session_usd")?;
    Ok(BudgetRules {
        per_call_usd,
        session_usd,
        on_exceed,
    })
}
/// Parses the optional `observe` section; missing or nil yields defaults.
///
/// `emit_event` is trimmed, and a blank value is treated as unset.
///
/// # Errors
/// Returns a runtime error when the section is not a dict or `emit_event` is
/// not a string.
fn parse_observe(value: Option<&VmValue>) -> Result<ObserveRules, VmError> {
    let dict = match value {
        None | Some(VmValue::Nil) => return Ok(ObserveRules::default()),
        Some(VmValue::Dict(dict)) => dict,
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.observe: expected dict, got {}",
                other.type_name()
            )));
        }
    };
    let emit_event = match dict.get("emit_event") {
        None | Some(VmValue::Nil) => None,
        Some(VmValue::String(raw)) => Some(raw.trim())
            .filter(|name| !name.is_empty())
            .map(str::to_string),
        Some(other) => {
            return Err(runtime_error(format!(
                "routing_policy.observe.emit_event: expected a string, got {}",
                other.type_name()
            )));
        }
    };
    Ok(ObserveRules { emit_event })
}
/// Validates a routing-policy config dict, interns the parsed policy, and
/// returns a summary dict describing it.
///
/// The summary carries the routing-policy tag, the label, the chain, the
/// failover / latency / observe sections, the budget section (only when a
/// cost cap is set), and the registry handle.
///
/// # Errors
/// Returns a runtime error when `chain` is missing, not a list, or empty, or
/// when any section fails to parse.
pub(crate) fn build_routing_policy(config: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
    let Some(chain_value) = config.get("chain") else {
        return Err(runtime_error(
            "routing_policy: `chain` is required (list of {provider, model})".to_string(),
        ));
    };
    let VmValue::List(chain_items) = chain_value else {
        return Err(runtime_error(format!(
            "routing_policy.chain: expected a list, got {}",
            chain_value.type_name()
        )));
    };
    if chain_items.is_empty() {
        return Err(runtime_error(
            "routing_policy.chain: at least one {provider, model} entry is required".to_string(),
        ));
    }
    // Stops at the first invalid entry.
    let chain = chain_items
        .iter()
        .enumerate()
        .map(|(idx, item)| parse_chain_link(item, idx))
        .collect::<Result<Vec<_>, _>>()?;

    let failover = parse_failover(config.get("failover"))?;
    let latency = parse_latency(config.get("latency"))?;
    let budget = parse_budget(config.get("budget"))?;
    let observe = parse_observe(config.get("observe"))?;
    let label = match parse_label(config, "label")? {
        text if text.is_empty() => format!("routing_policy(chain={})", chain.len()),
        text => text,
    };

    // Build the summary before the parsed pieces are moved into the config.
    let mut summary = BTreeMap::new();
    summary.insert(ROUTING_POLICY_TAG.to_string(), VmValue::Bool(true));
    summary.insert(
        "label".to_string(),
        VmValue::String(Rc::from(label.as_str())),
    );
    summary.insert("chain".to_string(), chain_summary_value(&chain));
    if budget.envelope().is_some() {
        summary.insert("budget".to_string(), budget_value(&budget));
    }
    summary.insert("failover".to_string(), failover_value(&failover));
    summary.insert("latency".to_string(), latency_value(&latency));
    summary.insert("observe".to_string(), observe_value(&observe));

    let handle = intern_policy(RoutingPolicyConfig {
        chain,
        failover,
        latency,
        budget,
        observe,
        label,
    });
    summary.insert(HANDLE_KEY.to_string(), VmValue::Int(handle as i64));
    Ok(VmValue::Dict(Rc::new(summary)))
}
/// Renders the chain as a list of dicts with `provider`, `model`, and, when
/// set, `timeout_ms` and `label`.
fn chain_summary_value(chain: &[ChainLink]) -> VmValue {
    let mut items = Vec::with_capacity(chain.len());
    for link in chain {
        let mut entry = BTreeMap::new();
        entry.insert(
            "provider".to_string(),
            VmValue::String(Rc::from(link.provider.as_str())),
        );
        entry.insert(
            "model".to_string(),
            VmValue::String(Rc::from(link.model.as_str())),
        );
        if let Some(timeout) = link.timeout_ms {
            entry.insert("timeout_ms".to_string(), VmValue::Int(timeout as i64));
        }
        if let Some(label) = link.label.as_deref() {
            entry.insert("label".to_string(), VmValue::String(Rc::from(label)));
        }
        items.push(VmValue::Dict(Rc::new(entry)));
    }
    VmValue::List(Rc::new(items))
}
/// Renders the failover rules as a dict. `on_status` and `on_error_kinds` are
/// always present; the optional numeric fields appear only when set.
fn failover_value(failover: &FailoverRules) -> VmValue {
    let mut statuses = Vec::with_capacity(failover.on_status.len());
    for code in &failover.on_status {
        statuses.push(VmValue::Int(*code as i64));
    }
    let mut kinds = Vec::with_capacity(failover.on_error_kinds.len());
    for kind in &failover.on_error_kinds {
        kinds.push(VmValue::String(Rc::from(kind.as_str())));
    }

    let mut dict = BTreeMap::new();
    dict.insert("on_status".to_string(), VmValue::List(Rc::new(statuses)));
    dict.insert("on_error_kinds".to_string(), VmValue::List(Rc::new(kinds)));
    if let Some(ms) = failover.on_timeout_ms {
        dict.insert("on_timeout_ms".to_string(), VmValue::Int(ms as i64));
    }
    if let Some(max) = failover.max_attempts {
        dict.insert("max_attempts".to_string(), VmValue::Int(max as i64));
    }
    VmValue::Dict(Rc::new(dict))
}
/// Renders the latency rules as a dict holding only the fields that are set.
fn latency_value(latency: &LatencyRules) -> VmValue {
    let fields = [
        ("target_p95_ms", latency.target_p95_ms),
        ("race_after_ms", latency.race_after_ms),
    ];
    let mut dict = BTreeMap::new();
    for (key, value) in fields.iter() {
        if let Some(ms) = value {
            dict.insert(key.to_string(), VmValue::Int(*ms as i64));
        }
    }
    VmValue::Dict(Rc::new(dict))
}
/// Renders the budget rules as a dict: the caps that are set, plus the
/// effective `on_exceed` action (abort when unset).
fn budget_value(budget: &BudgetRules) -> VmValue {
    let caps = [
        ("per_call_usd", budget.per_call_usd),
        ("session_usd", budget.session_usd),
    ];
    let mut dict = BTreeMap::new();
    for (key, cap) in caps.iter() {
        if let Some(usd) = cap {
            dict.insert(key.to_string(), VmValue::Float(*usd));
        }
    }
    let action = budget.on_exceed_or_abort().as_str();
    dict.insert("on_exceed".to_string(), VmValue::String(Rc::from(action)));
    VmValue::Dict(Rc::new(dict))
}
/// Renders the observe rules as a dict; `emit_event` appears only when set.
fn observe_value(observe: &ObserveRules) -> VmValue {
    let dict: BTreeMap<String, VmValue> = observe
        .emit_event
        .iter()
        .map(|event| {
            (
                "emit_event".to_string(),
                VmValue::String(Rc::from(event.as_str())),
            )
        })
        .collect();
    VmValue::Dict(Rc::new(dict))
}
/// Resolves the `routing` option of an `llm_call` into an interned policy.
///
/// Returns `Ok(None)` when there are no options, no `routing` key, or the
/// value is nil or `false`.
///
/// # Errors
/// Returns a runtime error when the value is not a dict, is a dict without
/// the routing-policy tag, has no integer handle, or the handle is not in the
/// registry.
pub(crate) fn extract_routing_policy(
    options: Option<&BTreeMap<String, VmValue>>,
) -> Result<Option<Rc<RoutingPolicyConfig>>, VmError> {
    let routing = match options.and_then(|opts| opts.get("routing")) {
        None | Some(VmValue::Nil) | Some(VmValue::Bool(false)) => return Ok(None),
        Some(VmValue::Dict(dict)) => dict,
        Some(other) => {
            return Err(runtime_error(format!(
                "llm_call(... routing: ...): expected a routing_policy(...) value, got {}",
                other.type_name()
            )));
        }
    };
    // Only dicts produced by routing_policy({...}) carry the tag.
    if !matches!(routing.get(ROUTING_POLICY_TAG), Some(VmValue::Bool(true))) {
        return Err(runtime_error(
            "llm_call(... routing: ...): pass the result of routing_policy({...}); the routing key does not accept a bare dict".to_string(),
        ));
    }
    let Some(handle) = routing.get(HANDLE_KEY).and_then(|v| v.as_int()) else {
        return Err(runtime_error(
            "llm_call(... routing: ...): routing policy handle missing — re-create it with routing_policy({...})".to_string(),
        ));
    };
    match lookup_policy(handle as u64) {
        Some(policy) => Ok(Some(policy)),
        None => Err(runtime_error(
            "llm_call(... routing: ...): routing policy handle expired — re-create it with routing_policy({...})".to_string(),
        )),
    }
}
/// Record of one routed call: every attempt made and which one was selected.
#[derive(Clone, Debug)]
pub(crate) struct RoutingTrace {
/// Label of the policy that produced this trace.
pub label: String,
/// Attempts in the order they were recorded.
pub attempts: Vec<RoutingAttempt>,
/// Position within `attempts` of the successful attempt, if any.
pub selected: Option<usize>,
/// Value of `peek_total_cost()` sampled when routing starts and again on success.
pub session_cost_usd: f64,
}
/// One entry in a `RoutingTrace`.
#[derive(Clone, Debug)]
pub(crate) struct RoutingAttempt {
/// Attempt number as assigned by the router (callers pass `attempts_used + 1`).
pub index: usize,
pub provider: String,
pub model: String,
/// Display label of the chain link that was tried.
pub label: String,
pub status: AttemptStatus,
/// Elapsed time in milliseconds; zero for budget-skipped links.
pub duration_ms: u64,
/// Cost computed from the result's token counts; set only on success.
pub cost_usd: Option<f64>,
/// Failure or skip details, when available.
pub error: Option<RoutingErrorSnapshot>,
}
/// Outcome of a single routing attempt.
#[derive(Clone, Copy, Debug)]
pub(crate) enum AttemptStatus {
    /// The link returned a result.
    Succeeded,
    /// The link returned an error (also the initial state of a pending record).
    Failed,
    /// The link was not called because of a budget breach with the skip action.
    Skipped,
    /// The link was running in a race that the other link finished first.
    RaceLost,
}

impl AttemptStatus {
    /// Snake-case name of the status.
    fn as_str(self) -> &'static str {
        let name: &'static str = match self {
            AttemptStatus::RaceLost => "race_lost",
            AttemptStatus::Skipped => "skipped",
            AttemptStatus::Failed => "failed",
            AttemptStatus::Succeeded => "succeeded",
        };
        name
    }
}
/// Plain-data summary of an error, used in traces and routing events.
#[derive(Clone, Debug)]
pub(crate) struct RoutingErrorSnapshot {
/// Error category name (e.g. `budget_exceeded`).
pub category: String,
/// Human-readable error message.
pub message: String,
/// HTTP status code extracted from the error text, when one was found.
pub status: Option<u16>,
}
/// Decides whether `error` should trigger failover to the next chain link.
///
/// Returns `(eligible, snapshot)`; the snapshot of the error is built in
/// every case. Checks run in this order, the first match winning:
/// 1. the extracted HTTP status is listed in `rules.on_status`;
/// 2. the error category is `Timeout` (eligible regardless of the rules);
/// 3. an entry of `rules.on_error_kinds` names the category, directly or via
///    one of the aliases below;
/// 4. no rules are configured at all and the error matches the built-in
///    default statuses or categories.
fn matches_failover(rules: &FailoverRules, error: &VmError) -> (bool, RoutingErrorSnapshot) {
let category = crate::value::error_to_category(error);
// Prefer a structured message; otherwise fall back to the error's Display text.
let message = match error {
VmError::CategorizedError { message, .. } => message.clone(),
VmError::Thrown(VmValue::String(s)) => s.to_string(),
VmError::Thrown(VmValue::Dict(d)) => d
.get("message")
.map(|v| v.display())
.unwrap_or_else(|| error.to_string()),
_ => error.to_string(),
};
// The status code is recovered by scanning the error's text.
let status = extract_status_code(error);
let snapshot = RoutingErrorSnapshot {
category: category.as_str().to_string(),
message: message.clone(),
status,
};
if let Some(code) = status {
if rules.on_status.contains(&code) {
return (true, snapshot);
}
}
// Timeouts are always eligible for failover.
if matches!(category, ErrorCategory::Timeout) {
return (true, snapshot);
}
let category_label = category.as_str();
let kind_match = rules.on_error_kinds.iter().any(|kind| {
// Configured kinds are compared case-insensitively.
let normalized = kind.trim().to_ascii_lowercase();
if normalized == category_label {
return true;
}
// Accepted aliases for each category.
matches!(
(normalized.as_str(), category.clone()),
("rate_limit", ErrorCategory::RateLimit)
| ("overloaded", ErrorCategory::Overloaded)
| ("transient", ErrorCategory::TransientNetwork)
| ("transient_network", ErrorCategory::TransientNetwork)
| ("network", ErrorCategory::TransientNetwork)
| ("timeout", ErrorCategory::Timeout)
| ("schema_validation", ErrorCategory::SchemaValidation)
| ("auth", ErrorCategory::Auth)
| ("provider_error", ErrorCategory::ServerError)
| ("server_error", ErrorCategory::ServerError)
| ("provider_5xx", ErrorCategory::ServerError)
| ("generic", ErrorCategory::Generic)
| ("budget_exceeded", ErrorCategory::BudgetExceeded)
| ("circuit_open", ErrorCategory::CircuitOpen)
| ("egress_blocked", ErrorCategory::EgressBlocked)
| ("cancelled", ErrorCategory::Cancelled)
| ("tool_error", ErrorCategory::ToolError)
| ("tool_rejected", ErrorCategory::ToolRejected)
| ("not_found", ErrorCategory::NotFound)
)
});
if kind_match {
return (true, snapshot);
}
// Built-in defaults apply only when the policy configured no status list,
// no error kinds, and no failover timeout.
let defaults_active = rules.on_status.is_empty()
&& rules.on_error_kinds.is_empty()
&& rules.on_timeout_ms.is_none();
if defaults_active {
let by_status = status
.map(|code| DEFAULT_FAILOVER_STATUSES.contains(&code))
.unwrap_or(false);
let by_category = matches!(
category,
ErrorCategory::RateLimit
| ErrorCategory::Overloaded
| ErrorCategory::TransientNetwork
| ErrorCategory::Timeout
| ErrorCategory::ServerError
);
if by_status || by_category {
return (true, snapshot);
}
}
(false, snapshot)
}
fn extract_status_code(error: &VmError) -> Option<u16> {
let message = error.to_string();
extract_status_from_text(&message)
}
/// Finds an HTTP status code mentioned in free-form error text.
///
/// The text is searched case-insensitively for the markers `http `,
/// `status_code: `, `status: ` and `status `, in that priority order. Every
/// occurrence of a marker is tried, not just the first, so a message such as
/// `"status unknown; upstream returned status 429"` still yields `429`. If no
/// marker is followed by a valid code, a code at the very start of the
/// message is accepted.
fn extract_status_from_text(message: &str) -> Option<u16> {
    const NEEDLES: [&str; 4] = ["http ", "status_code: ", "status: ", "status "];
    // ASCII lowercasing never changes digits, so the tail can be sliced from
    // the lowered copy directly.
    let lowered = message.to_ascii_lowercase();
    for needle in NEEDLES.iter() {
        for (idx, matched) in lowered.match_indices(*needle) {
            let tail = &lowered[idx + matched.len()..];
            if let Some(code) = parse_leading_status(tail) {
                return Some(code);
            }
        }
    }
    parse_leading_status(message)
}

/// Parses the run of ASCII digits at the start of `text` (after leading
/// whitespace) as an HTTP status code.
///
/// Returns `None` when there are no leading digits, the number does not fit
/// in `u16`, or it falls outside `100..=599`.
fn parse_leading_status(text: &str) -> Option<u16> {
    let text = text.trim_start();
    let digit_count = text.chars().take_while(|c| c.is_ascii_digit()).count();
    if digit_count == 0 {
        return None;
    }
    // ASCII digits are one byte each, so the char count is also the byte length.
    text[..digit_count]
        .parse::<u16>()
        .ok()
        .filter(|code| (100..=599).contains(code))
}
/// Logs a routing event under the category `<dispatch>.<event>`.
///
/// The metadata is passed through unchanged, except that an `event` key is
/// added when the caller did not supply one.
fn emit_routing_event(
    dispatch: &str,
    event: &str,
    metadata: serde_json::Map<String, serde_json::Value>,
) {
    let mut meta: BTreeMap<String, serde_json::Value> = BTreeMap::new();
    for (key, value) in metadata {
        meta.insert(key, value);
    }
    if !meta.contains_key("event") {
        meta.insert(
            "event".to_string(),
            serde_json::Value::String(event.to_string()),
        );
    }
    let category = format!("{dispatch}.{event}");
    crate::events::log_info_meta(&category, "", meta);
}
/// Builds the error snapshot for a budget breach. `kind` names the breached
/// limit and is used as the prefix of the message.
fn budget_overrun_snapshot(
    cap: f64,
    projected: f64,
    session: f64,
    kind: &str,
) -> RoutingErrorSnapshot {
    let message = format!(
        "{kind} budget exceeded (cap=${cap:.6}, projected=${projected:.6}, session=${session:.6})"
    );
    RoutingErrorSnapshot {
        category: String::from("budget_exceeded"),
        message,
        status: None,
    }
}
/// Derives the call options for one chain link from the caller's base options.
///
/// The link's provider and model are pinned, and all other routing and
/// fallback settings are cleared. The timeout comes from the link, or else
/// from the failover rules; it is converted to whole seconds by integer
/// division, with a minimum of one. Policy cost caps override the matching
/// fields of any existing budget. The API key is resolved for the link's
/// provider and left empty when resolution fails.
fn link_options(
    base: &LlmCallOptions,
    policy: &RoutingPolicyConfig,
    link: &ChainLink,
) -> LlmCallOptions {
    let mut opts = base.clone();
    opts.provider = link.provider.clone();
    opts.model = link.model.clone();
    opts.route_policy = LlmRoutePolicy::Always(format!("{}:{}", link.provider, link.model));
    opts.fallback_chain = Vec::new();
    opts.route_fallbacks = Vec::new();
    opts.routing_decision = None;
    opts.routing_policy = None;

    let effective_timeout_ms = link.timeout_ms.or(policy.failover.on_timeout_ms);
    if let Some(ms) = effective_timeout_ms {
        opts.timeout = Some(std::cmp::max(ms / 1000, 1));
    }

    if let Some(envelope) = policy.budget.envelope() {
        let mut merged = opts.budget.take().unwrap_or_default();
        if envelope.max_cost_usd.is_some() {
            merged.max_cost_usd = envelope.max_cost_usd;
        }
        if envelope.total_budget_usd.is_some() {
            merged.total_budget_usd = envelope.total_budget_usd;
        }
        opts.budget = Some(merged);
    }

    // Never carry over the base key: use the link provider's key or none.
    opts.api_key = match super::helpers::resolve_api_key(&link.provider) {
        Ok(key) => key,
        Err(_) => String::new(),
    };
    opts
}
/// Checks a link's projected cost against the policy's budget caps before it
/// is called.
///
/// Returns:
/// - `Ok(true)` when the link may be attempted (no caps, no breach, or the
///   breach action is `Warn`);
/// - `Ok(false)` when the link is skipped; a `Skipped` attempt is pushed onto
///   `trace_attempts`;
/// - `Err((error, snapshot))` when the breach action is `Abort`.
///
/// A `budget_exceeded` routing event is emitted on every breach, whatever the
/// action.
fn check_link_budget(
policy: &RoutingPolicyConfig,
opts: &LlmCallOptions,
dispatch: &str,
attempt_idx: usize,
link_label: &str,
trace_attempts: &mut Vec<RoutingAttempt>,
) -> Result<bool, (VmError, RoutingErrorSnapshot)> {
// No caps configured: nothing to check.
let Some(rules_envelope) = policy.budget.envelope() else {
return Ok(true);
};
let session_cost = peek_total_cost();
let projection = super::cost::project_llm_call_cost(opts, session_cost);
let action = policy.budget.on_exceed_or_abort();
let mut breach = None::<(super::cost::BudgetLimitKind, f64, &'static str)>;
// The per-call cap is checked first and takes precedence over the session cap.
if let Some(max) = rules_envelope.max_cost_usd {
if projection.projected_cost_usd > max {
breach = Some((super::cost::BudgetLimitKind::PerCallCost, max, "per_call"));
}
}
if breach.is_none() {
if let Some(max) = rules_envelope.total_budget_usd {
// The session cap compares the running total plus this call's projection.
if session_cost + projection.projected_cost_usd > max {
breach = Some((super::cost::BudgetLimitKind::TotalCost, max, "session"));
}
}
}
let Some((limit_kind, limit_value, kind_label)) = breach else {
return Ok(true);
};
let snapshot = budget_overrun_snapshot(
limit_value,
projection.projected_cost_usd,
session_cost,
kind_label,
);
let mut meta = serde_json::Map::new();
meta.insert("policy".to_string(), json!(policy.label.clone()));
meta.insert("attempt".to_string(), json!(attempt_idx));
meta.insert("provider".to_string(), json!(opts.provider.clone()));
meta.insert("model".to_string(), json!(opts.model.clone()));
meta.insert("link_label".to_string(), json!(link_label));
meta.insert("kind".to_string(), json!(kind_label));
meta.insert("limit_usd".to_string(), json!(limit_value));
meta.insert(
"projected_cost_usd".to_string(),
json!(projection.projected_cost_usd),
);
meta.insert("session_cost_usd".to_string(), json!(session_cost));
meta.insert("on_exceed".to_string(), json!(action.as_str()));
emit_routing_event(dispatch, "budget_exceeded", meta);
match action {
BudgetExceedAction::Abort => Err((
super::cost::budget_exceeded_error(&projection, limit_kind, limit_value),
snapshot,
)),
BudgetExceedAction::Skip => {
// Record the skip so it shows up in the routing trace.
trace_attempts.push(RoutingAttempt {
index: attempt_idx,
provider: opts.provider.clone(),
model: opts.model.clone(),
label: link_label.to_string(),
status: AttemptStatus::Skipped,
duration_ms: 0,
cost_usd: None,
error: Some(snapshot),
});
Ok(false)
}
// Warn: the event above is the only effect; the link is still attempted.
BudgetExceedAction::Warn => Ok(true),
}
}
/// Computes the cost of a completed call from its provider, model, and token counts.
fn project_link_cost_usd(result: &super::api::LlmResult) -> f64 {
    let provider = &result.provider;
    let model = &result.model;
    let input_tokens = result.input_tokens;
    let output_tokens = result.output_tokens;
    calculate_cost_for_provider(provider, model, input_tokens, output_tokens)
}
/// Converts a duration to whole milliseconds, saturating at `u64::MAX`.
fn duration_ms(elapsed: Duration) -> u64 {
    let millis = elapsed.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}
/// Runs a single LLM call for one chain link through `observed_llm_call`,
/// with a retry configuration of zero retries and zero backoff.
async fn execute_link(
    opts: &LlmCallOptions,
    bridge: Option<&Rc<crate::bridge::HostBridge>>,
) -> Result<super::api::LlmResult, VmError> {
    let no_retry = super::agent_observe::LlmRetryConfig {
        retries: 0,
        backoff_ms: 0,
    };
    let bridge_present = bridge.is_some();
    let call = super::agent_observe::observed_llm_call(
        opts,
        None,
        bridge,
        &no_retry,
        None,
        false,
        bridge_present,
        None,
    );
    call.await
}
/// Creates an attempt record in its initial state: status `Failed`, with no
/// cost and no error. Callers update it once the outcome is known.
fn pending_attempt_record(
    attempt_no: usize,
    link: &ChainLink,
    label: &str,
    elapsed: Duration,
) -> RoutingAttempt {
    let elapsed_ms = duration_ms(elapsed);
    RoutingAttempt {
        index: attempt_no,
        label: label.to_owned(),
        provider: link.provider.clone(),
        model: link.model.clone(),
        duration_ms: elapsed_ms,
        status: AttemptStatus::Failed,
        cost_usd: None,
        error: None,
    }
}
/// Executes an LLM call under a routing policy, walking the chain until a
/// link succeeds, a non-failover error occurs, or attempts run out.
///
/// For each link the budget is checked first; the link is then called either
/// directly or, when `latency.race_after_ms` is set and both another link and
/// two attempts remain, raced against the next link. Routing events
/// (`decision`, `attempt`, `exhausted`, plus those emitted by helpers) are
/// logged under the policy's dispatch label.
///
/// Returns the successful result together with the trace of attempts.
///
/// # Errors
/// Returns the last error seen when the chain is exhausted, when an error is
/// not eligible for failover, or when the budget action is abort. Also errors
/// when `failover.max_attempts` is zero.
pub(crate) async fn execute_with_routing(
policy: &RoutingPolicyConfig,
base_opts: LlmCallOptions,
bridge: Option<&Rc<crate::bridge::HostBridge>>,
) -> Result<(super::api::LlmResult, RoutingTrace), VmError> {
let dispatch = policy.dispatch_label();
let mut trace = RoutingTrace {
label: policy.label.clone(),
attempts: Vec::new(),
selected: None,
session_cost_usd: peek_total_cost(),
};
// Without an explicit cap, every chain link may be tried once.
let max_attempts = policy.failover.max_attempts.unwrap_or(policy.chain.len());
if max_attempts == 0 {
return Err(runtime_error(
"routing_policy.failover.max_attempts: must be >= 1".to_string(),
));
}
let mut last_error: Option<VmError> = None;
let mut last_snapshot: Option<RoutingErrorSnapshot> = None;
let mut attempts_used: usize = 0;
// Announce the routing plan before any link is tried.
let mut decision_meta = serde_json::Map::new();
decision_meta.insert("policy".to_string(), json!(policy.label.clone()));
decision_meta.insert("chain_length".to_string(), json!(policy.chain.len()));
decision_meta.insert("max_attempts".to_string(), json!(max_attempts));
decision_meta.insert(
"chain".to_string(),
serde_json::Value::Array(
policy
.chain
.iter()
.map(|link| {
json!({
"provider": link.provider,
"model": link.model,
"label": link.display_label(),
})
})
.collect(),
),
);
emit_routing_event(&dispatch, "decision", decision_meta);
let mut idx = 0usize;
while idx < policy.chain.len() && attempts_used < max_attempts {
let link = policy.chain[idx].clone();
let opts = link_options(&base_opts, policy, &link);
let link_label = link.display_label();
let mut local_attempts: Vec<RoutingAttempt> = Vec::new();
// Budget gate: proceed, skip this link, or abort the whole call.
match check_link_budget(
policy,
&opts,
&dispatch,
attempts_used + 1,
&link_label,
&mut local_attempts,
) {
Ok(true) => {}
Ok(false) => {
// A skipped link still counts against max_attempts.
trace.attempts.extend(local_attempts);
idx += 1;
attempts_used += 1;
continue;
}
Err((err, snapshot)) => {
trace.attempts.extend(local_attempts);
last_error = Some(err);
last_snapshot = Some(snapshot);
break;
}
}
trace.attempts.extend(local_attempts);
let attempt_no = attempts_used + 1;
let start = std::time::Instant::now();
let mut attempt_meta = serde_json::Map::new();
attempt_meta.insert("policy".to_string(), json!(policy.label.clone()));
attempt_meta.insert("attempt".to_string(), json!(attempt_no));
attempt_meta.insert("provider".to_string(), json!(link.provider.clone()));
attempt_meta.insert("model".to_string(), json!(link.model.clone()));
attempt_meta.insert("link_label".to_string(), json!(link_label.clone()));
emit_routing_event(&dispatch, "attempt", attempt_meta);
let race_after_ms = policy.latency.race_after_ms;
let primary_timeout_ms = link
.timeout_ms
.or(policy.failover.on_timeout_ms)
.unwrap_or(DEFAULT_RACE_PRIMARY_TIMEOUT_MS);
// Race only when a next link exists and two attempts are still available.
let race_outcome = if let Some(race_after) = race_after_ms {
if idx + 1 < policy.chain.len() && attempts_used + 2 <= max_attempts {
let backup_link = policy.chain[idx + 1].clone();
let backup_opts = link_options(&base_opts, policy, &backup_link);
let backup_label = backup_link.display_label();
Some(
run_race(
&dispatch,
policy,
attempts_used,
&link,
&link_label,
&opts,
bridge,
race_after,
primary_timeout_ms,
backup_label,
backup_opts,
)
.await,
)
} else {
None
}
} else {
None
};
let (result, mut attempt_records) = if let Some(outcome) = race_outcome {
outcome
} else {
// Plain single-link attempt.
let result = execute_link(&opts, bridge).await;
(
result,
vec![pending_attempt_record(
attempt_no,
&link,
&link_label,
start.elapsed(),
)],
)
};
// Number of attempts this iteration accounts for, judged by the records returned.
let consumed = attempt_records.len().max(1);
match result {
Ok(value) => {
// Promote a still-pending record (Failed with no error) to Succeeded.
if let Some(record) = attempt_records
.iter_mut()
.find(|rec| matches!(rec.status, AttemptStatus::Failed) && rec.error.is_none())
{
record.status = AttemptStatus::Succeeded;
record.cost_usd = Some(project_link_cost_usd(&value));
}
let starting_len = trace.attempts.len();
trace.attempts.extend(attempt_records);
// Select the newly added successful attempt whose provider/model match the result.
trace.selected = trace
.attempts
.iter()
.enumerate()
.skip(starting_len)
.find(|(_, a)| {
matches!(a.status, AttemptStatus::Succeeded)
&& a.provider == value.provider
&& a.model == value.model
})
.map(|(idx, _)| idx);
trace.session_cost_usd = peek_total_cost();
return Ok((value, trace));
}
Err(err) => {
let (eligible, snapshot) = matches_failover(&policy.failover, &err);
// Attach the error to the pending record, if there is one.
if let Some(record) = attempt_records
.iter_mut()
.find(|rec| matches!(rec.status, AttemptStatus::Failed) && rec.error.is_none())
{
record.error = Some(snapshot.clone());
}
trace.attempts.extend(attempt_records);
last_snapshot = Some(snapshot);
attempts_used += consumed;
// Errors that are not eligible for failover end routing immediately.
if !eligible {
last_error = Some(err);
break;
}
last_error = Some(err);
// Advance past every link this iteration consumed.
idx += consumed;
continue;
}
}
}
let err = last_error.unwrap_or_else(|| {
runtime_error("routing_policy: chain exhausted with no attempts (empty chain?)".to_string())
});
// Report the terminal failure along with the last error seen.
let mut meta = serde_json::Map::new();
meta.insert("policy".to_string(), json!(policy.label.clone()));
meta.insert("attempts".to_string(), json!(trace.attempts.len()));
if let Some(snapshot) = last_snapshot {
meta.insert("last_error_category".to_string(), json!(snapshot.category));
meta.insert("last_error_message".to_string(), json!(snapshot.message));
if let Some(status) = snapshot.status {
meta.insert("last_error_status".to_string(), json!(status));
}
}
emit_routing_event(&dispatch, "exhausted", meta);
Err(err)
}
/// Runs the primary link and, if it has not finished within `race_after_ms`,
/// races it against a backup link.
///
/// Returns the result of whichever call finished first together with the
/// attempt records produced along the way (one record when the primary
/// finished before the race began, two records otherwise).
///
/// * `dispatch` / `policy` - only used to tag the emitted routing events
///   (`policy.label` is written into the event metadata).
/// * `attempts_used` - number of attempts already made; the primary is
///   recorded as attempt `attempts_used + 1`, the backup as `attempts_used + 2`.
/// * `link` / `link_label` / `opts` - the primary link, its label and call options.
/// * `race_after_ms` - how long the primary runs alone before the backup starts.
/// * `primary_timeout_ms` - contributes to the overall deadline timer (see below).
/// * `backup_label` / `backup_opts` - label and call options of the backup call.
///
/// Events emitted: `race_started` once the backup is launched, then
/// `race_won` and `race_lost` when either side completes.
///
/// NOTE(review): the first call to *complete* wins even when its result is
/// `Err`; the slower call is not awaited in that case. Confirm that the caller
/// handles failover for this situation.
// The signature mirrors the state the caller already holds, hence the many arguments.
#[allow(clippy::too_many_arguments)]
async fn run_race(
dispatch: &str,
policy: &RoutingPolicyConfig,
attempts_used: usize,
link: &ChainLink,
link_label: &str,
opts: &LlmCallOptions,
bridge: Option<&Rc<crate::bridge::HostBridge>>,
race_after_ms: u64,
primary_timeout_ms: u64,
backup_label: String,
backup_opts: LlmCallOptions,
) -> (Result<super::api::LlmResult, VmError>, Vec<RoutingAttempt>) {
// Started before the future is built so the elapsed time covers the whole primary call.
let primary_start = std::time::Instant::now();
let primary_attempt_no = attempts_used + 1;
let backup_attempt_no = attempts_used + 2;
let primary_link = link.clone();
let primary_label = link_label.to_string();
let primary_opts = opts.clone();
// Pinned on the heap so the same future can be polled by both select! blocks below.
let mut primary_future = Box::pin(async move {
let res = execute_link(&primary_opts, bridge).await;
(res, primary_start.elapsed())
});
// Phase 1: primary alone. `biased` polls the arms in written order, so the
// primary is checked before the race timer.
tokio::select! {
biased;
primary = &mut primary_future => {
// Primary completed before the race window opened: single attempt record.
let (res, elapsed) = primary;
let mut record = pending_attempt_record(
primary_attempt_no,
&primary_link,
&primary_label,
elapsed,
);
// Only a successful result upgrades the record; on `Err` it keeps
// whatever `pending_attempt_record` produced.
if let Ok(ref v) = res {
record.status = AttemptStatus::Succeeded;
record.cost_usd = Some(project_link_cost_usd(v));
}
(res, vec![record])
}
_ = crate::clock_mock::sleep(Duration::from_millis(race_after_ms)) => {
// Phase 2: the primary is still running, so announce the race and start the backup.
let mut race_meta = serde_json::Map::new();
race_meta.insert("policy".to_string(), json!(policy.label.clone()));
race_meta.insert("race_after_ms".to_string(), json!(race_after_ms));
race_meta.insert("primary_label".to_string(), json!(primary_label.clone()));
race_meta.insert("backup_label".to_string(), json!(backup_label.clone()));
emit_routing_event(dispatch, "race_started", race_meta);
let backup_start = std::time::Instant::now();
// Synthetic link describing the backup for its attempt record.
// NOTE(review): `timeout_ms` is copied from the primary link, not from the
// backup's own chain entry - confirm this is intended.
let backup_link_clone = ChainLink {
provider: backup_opts.provider.clone(),
model: backup_opts.model.clone(),
timeout_ms: link.timeout_ms,
label: Some(backup_label.clone()),
};
let mut backup_future = Box::pin({
let backup_opts = backup_opts.clone();
async move {
let res = execute_link(&backup_opts, bridge).await;
(res, backup_start.elapsed())
}
});
// NOTE(review): this timer is started here, i.e. after `race_after_ms` has
// already elapsed, so the total wait measured from `primary_start` can reach
// `race_after_ms + primary_deadline`. Confirm the extra offset is intended.
let primary_deadline = primary_timeout_ms.saturating_add(race_after_ms);
tokio::select! {
biased;
primary = &mut primary_future => {
// Primary finished first: it is the winner, the backup is recorded as RaceLost.
let (res, elapsed) = primary;
let mut primary_record = pending_attempt_record(
primary_attempt_no,
&primary_link,
&primary_label,
elapsed,
);
if let Ok(ref v) = res {
primary_record.status = AttemptStatus::Succeeded;
primary_record.cost_usd = Some(project_link_cost_usd(v));
}
// The backup's duration is measured up to this point; the backup
// future itself is not awaited any further.
let mut backup_record = pending_attempt_record(
backup_attempt_no,
&backup_link_clone,
&backup_label,
backup_start.elapsed(),
);
backup_record.status = AttemptStatus::RaceLost;
// `race_won` and `race_lost` share the same winner/loser metadata;
// the latter additionally carries a `reason`.
let mut meta = serde_json::Map::new();
meta.insert("policy".to_string(), json!(policy.label.clone()));
meta.insert("winner".to_string(), json!(primary_label));
meta.insert("loser".to_string(), json!(backup_label));
emit_routing_event(dispatch, "race_won", meta.clone());
let mut lost_meta = meta;
lost_meta.insert("reason".to_string(), json!("primary_finished_first"));
emit_routing_event(dispatch, "race_lost", lost_meta);
(res, vec![primary_record, backup_record])
}
backup = &mut backup_future => {
// Backup finished first: mirror image of the arm above.
let (res, elapsed) = backup;
let mut backup_record = pending_attempt_record(
backup_attempt_no,
&backup_link_clone,
&backup_label,
elapsed,
);
if let Ok(ref v) = res {
backup_record.status = AttemptStatus::Succeeded;
backup_record.cost_usd = Some(project_link_cost_usd(v));
}
let mut primary_record = pending_attempt_record(
primary_attempt_no,
&primary_link,
&primary_label,
primary_start.elapsed(),
);
primary_record.status = AttemptStatus::RaceLost;
let mut meta = serde_json::Map::new();
meta.insert("policy".to_string(), json!(policy.label.clone()));
meta.insert("winner".to_string(), json!(backup_label));
meta.insert("loser".to_string(), json!(primary_label));
emit_routing_event(dispatch, "race_won", meta.clone());
let mut lost_meta = meta;
lost_meta.insert("reason".to_string(), json!("backup_finished_first"));
emit_routing_event(dispatch, "race_lost", lost_meta);
(res, vec![primary_record, backup_record])
}
_ = crate::clock_mock::sleep(Duration::from_millis(primary_deadline)) => {
// Neither call completed before the deadline timer fired. Both records
// report the deadline as their duration and keep the status assigned by
// `pending_attempt_record`; no routing event is emitted on this path.
let primary_record = pending_attempt_record(
primary_attempt_no,
&primary_link,
&primary_label,
Duration::from_millis(primary_deadline),
);
let backup_record = pending_attempt_record(
backup_attempt_no,
&backup_link_clone,
&backup_label,
Duration::from_millis(primary_deadline),
);
(
Err(runtime_error(
"routing_policy: race exhausted both primary and backup attempts".to_string(),
)),
vec![primary_record, backup_record],
)
}
}
}
}
}
pub(crate) fn trace_to_decision(
trace: &RoutingTrace,
policy: &RoutingPolicyConfig,
) -> super::api::LlmRoutingDecision {
use super::api::{LlmRouteAlternative, LlmRoutingDecision};
let mut alternatives = Vec::with_capacity(trace.attempts.len());
for (idx, attempt) in trace.attempts.iter().enumerate() {
let selected = trace.selected == Some(idx);
let reason = match attempt.status {
AttemptStatus::Succeeded => "selected".to_string(),
AttemptStatus::Failed => attempt
.error
.as_ref()
.map(|e| format!("failed:{}", e.category))
.unwrap_or_else(|| "failed".to_string()),
AttemptStatus::Skipped => "skipped:budget".to_string(),
AttemptStatus::RaceLost => "race_lost".to_string(),
};
let quality_tier = crate::llm_config::model_tier(&attempt.model);
let pricing = super::cost::pricing_per_1k_for(&attempt.provider, &attempt.model);
alternatives.push(LlmRouteAlternative {
available: true,
cost_per_1k_in: pricing.map(|p| p.0),
cost_per_1k_out: pricing.map(|p| p.1),
latency_p50_ms: super::cost::latency_p50_ms_for(&attempt.provider),
provider: attempt.provider.clone(),
model: attempt.model.clone(),
quality_tier,
selected,
reason,
});
}
let selected_idx = trace.selected.unwrap_or(0);
let (selected_provider, selected_model) = trace
.attempts
.get(selected_idx)
.map(|a| (a.provider.clone(), a.model.clone()))
.unwrap_or_else(|| {
policy
.chain
.first()
.map(|link| (link.provider.clone(), link.model.clone()))
.unwrap_or_default()
});
LlmRoutingDecision {
policy: format!("routing_policy({})", policy.label),
requested_quality: None,
selected_provider,
selected_model,
alternatives,
}
}
pub(crate) fn trace_to_vm_attempts(trace: &RoutingTrace) -> VmValue {
let items: Vec<VmValue> = trace
.attempts
.iter()
.map(|attempt| {
let mut dict = BTreeMap::new();
dict.insert("index".to_string(), VmValue::Int(attempt.index as i64));
dict.insert(
"provider".to_string(),
VmValue::String(Rc::from(attempt.provider.clone())),
);
dict.insert(
"model".to_string(),
VmValue::String(Rc::from(attempt.model.clone())),
);
dict.insert(
"label".to_string(),
VmValue::String(Rc::from(attempt.label.clone())),
);
dict.insert(
"status".to_string(),
VmValue::String(Rc::from(attempt.status.as_str())),
);
dict.insert(
"duration_ms".to_string(),
VmValue::Int(attempt.duration_ms as i64),
);
if let Some(cost) = attempt.cost_usd {
dict.insert("cost_usd".to_string(), VmValue::Float(cost));
}
if let Some(error) = &attempt.error {
let mut err_dict = BTreeMap::new();
err_dict.insert(
"category".to_string(),
VmValue::String(Rc::from(error.category.clone())),
);
err_dict.insert(
"message".to_string(),
VmValue::String(Rc::from(error.message.clone())),
);
if let Some(status) = error.status {
err_dict.insert("status".to_string(), VmValue::Int(status as i64));
}
dict.insert("error".to_string(), VmValue::Dict(Rc::new(err_dict)));
}
VmValue::Dict(Rc::new(dict))
})
.collect();
VmValue::List(Rc::new(items))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned dict body from borrowed key/value pairs.
    fn dict(items: &[(&str, VmValue)]) -> BTreeMap<String, VmValue> {
        let mut out = BTreeMap::new();
        for (key, value) in items {
            out.insert(key.to_string(), value.clone());
        }
        out
    }

    /// Shorthand for a VM string value.
    fn text(raw: &str) -> VmValue {
        VmValue::String(Rc::from(raw))
    }

    /// Shorthand for a VM list value.
    fn list(items: Vec<VmValue>) -> VmValue {
        VmValue::List(Rc::new(items))
    }

    /// Shorthand for a VM dict value built from key/value pairs.
    fn nested(items: &[(&str, VmValue)]) -> VmValue {
        VmValue::Dict(Rc::new(dict(items)))
    }

    /// Extracts the message of a thrown string error; any other error fails the test.
    fn thrown_message(err: VmError) -> String {
        match err {
            VmError::Thrown(VmValue::String(s)) => s.to_string(),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    // A well-formed config yields a tagged handle that resolves to the parsed policy.
    #[test]
    fn build_routing_policy_validates_chain() {
        clear_policy_registry();
        let chain = list(vec![
            text("mock:mock"),
            nested(&[("provider", text("mock")), ("model", text("mock-2"))]),
        ]);
        let failover = nested(&[
            (
                "on_status",
                list(vec![VmValue::Int(429), VmValue::Int(500)]),
            ),
            ("max_attempts", VmValue::Int(2)),
        ]);
        let budget = nested(&[
            ("per_call_usd", VmValue::Float(0.5)),
            ("on_exceed", text("abort")),
        ]);
        let config = dict(&[("chain", chain), ("failover", failover), ("budget", budget)]);

        let tagged = build_routing_policy(&config).expect("validates");
        let fields = tagged.as_dict().expect("dict");
        assert!(matches!(
            fields.get(ROUTING_POLICY_TAG),
            Some(VmValue::Bool(true))
        ));
        assert!(fields.contains_key(HANDLE_KEY));

        let handle = fields.get(HANDLE_KEY).and_then(|v| v.as_int()).unwrap();
        let registered = lookup_policy(handle as u64).expect("policy registered");
        assert_eq!(registered.chain.len(), 2);
        assert_eq!(registered.failover.on_status, vec![429, 500]);
    }

    // An empty chain is refused with an "at least one" message.
    #[test]
    fn build_rejects_empty_chain() {
        clear_policy_registry();
        let config = dict(&[("chain", list(Vec::new()))]);
        let message = thrown_message(build_routing_policy(&config).unwrap_err());
        assert!(message.contains("at least one"));
    }

    // Status codes that are not valid HTTP statuses are refused.
    #[test]
    fn build_rejects_invalid_status_code() {
        clear_policy_registry();
        let config = dict(&[
            ("chain", list(vec![text("mock:mock")])),
            (
                "failover",
                nested(&[("on_status", list(vec![VmValue::Int(42)]))]),
            ),
        ]);
        let message = thrown_message(build_routing_policy(&config).unwrap_err());
        assert!(message.contains("not a valid HTTP status"));
    }

    // With default rules, an HTTP 429 runtime error is eligible for failover.
    #[test]
    fn matches_failover_default_status() {
        let defaults = FailoverRules::default();
        let failure = VmError::Runtime("HTTP 429 rate limit".to_string());
        let (eligible, snapshot) = matches_failover(&defaults, &failure);
        assert!(eligible);
        assert_eq!(snapshot.status, Some(429));
    }

    // An explicitly listed error kind makes a categorized error eligible.
    #[test]
    fn matches_failover_explicit_kind() {
        let rules = FailoverRules {
            on_error_kinds: vec!["rate_limit".to_string()],
            ..Default::default()
        };
        let failure = VmError::CategorizedError {
            message: "throttled".to_string(),
            category: ErrorCategory::RateLimit,
        };
        let (eligible, _) = matches_failover(&rules, &failure);
        assert!(eligible);
    }

    // Schema validation errors are not eligible under the default rules.
    #[test]
    fn rejects_non_failover_error_by_default() {
        let defaults = FailoverRules::default();
        let failure = VmError::CategorizedError {
            message: "schema mismatch".to_string(),
            category: ErrorCategory::SchemaValidation,
        };
        let (eligible, _) = matches_failover(&defaults, &failure);
        assert!(!eligible);
    }

    // Budget limits are carried over into the envelope unchanged.
    #[test]
    fn budget_envelope_round_trips() {
        let rules = BudgetRules {
            per_call_usd: Some(0.25),
            session_usd: Some(5.0),
            on_exceed: Some(BudgetExceedAction::Skip),
        };
        let envelope = rules.envelope().unwrap();
        assert_eq!(envelope.max_cost_usd, Some(0.25));
        assert_eq!(envelope.total_budget_usd, Some(5.0));
    }
}