use super::traits::*;
use std::time::Duration;
/// Filter describing which nodes are acceptable for a request.
///
/// NOTE(review): the code visible in this file only constructs this struct
/// (in its `Default` impl); the per-field meanings below are inferred from
/// names and types — confirm against the code that consumes it.
#[derive(Debug, Clone)]
pub struct SelectionCriteria {
/// Capability a candidate is presumably required to provide.
pub capability: Capability,
/// Presumably the lowest acceptable health state.
pub min_health: HealthState,
/// Optional latency bound; `None` presumably means "no bound" — TODO confirm.
pub max_latency: Option<Duration>,
/// Presumably the lowest acceptable privacy level.
pub min_privacy: PrivacyLevel,
/// Regions to favor; whether order matters is not visible here.
pub preferred_regions: Vec<RegionId>,
/// Nodes that presumably must never be selected.
pub excluded_nodes: Vec<NodeId>,
}
impl Default for SelectionCriteria {
fn default() -> Self {
Self {
capability: Capability::Generate,
min_health: HealthState::Degraded,
max_latency: None,
min_privacy: PrivacyLevel::Public,
preferred_regions: vec![],
excluded_nodes: vec![],
}
}
}
/// Routing policy that ranks candidates by their estimated latency.
///
/// Derives `Debug` and `Clone` so the policy can be logged and duplicated,
/// in line with `SelectionCriteria` in this module.
#[derive(Debug, Clone)]
pub struct LatencyPolicy {
    /// Multiplier applied to the raw latency score.
    pub weight: f64,
    /// Latency budget: a candidate whose estimate reaches it scores zero, and
    /// one whose estimate exceeds it is ineligible.
    pub max_latency: Duration,
}

impl Default for LatencyPolicy {
    /// Returns a policy with weight `1.0` and a five-second latency budget.
    fn default() -> Self {
        Self {
            weight: 1.0,
            max_latency: Duration::from_secs(5),
        }
    }
}
impl RoutingPolicyTrait for LatencyPolicy {
    /// Scores a candidate linearly by how far its estimated latency sits
    /// below `max_latency`, scaled by `weight`.
    ///
    /// Returns `0.0` when the estimate is at or above the budget. That guard
    /// also covers a zero budget, so the division never sees a zero divisor.
    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
        // Work at nanosecond resolution. `as_millis` truncates, which made
        // every candidate score 0.0 whenever the budget was below one
        // millisecond and hid sub-millisecond differences between candidates.
        // For whole-millisecond durations the quotient is unchanged.
        let latency_ns = candidate.target.estimated_latency.as_nanos() as f64;
        let budget_ns = self.max_latency.as_nanos() as f64;
        if latency_ns >= budget_ns {
            return 0.0;
        }
        (1.0 - latency_ns / budget_ns) * self.weight
    }

    /// A candidate is eligible when its estimated latency does not exceed
    /// `max_latency`.
    fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
        candidate.target.estimated_latency <= self.max_latency
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "latency"
    }
}
/// Routing policy that ranks candidates by their locality score.
///
/// Derives `Debug` and `Clone` so the policy can be logged and duplicated,
/// in line with `SelectionCriteria` in this module.
#[derive(Debug, Clone)]
pub struct LocalityPolicy {
    /// Multiplier applied to the raw locality score.
    pub weight: f64,
    /// Factor applied to the candidate's locality score before it is added
    /// to the base score.
    pub same_region_boost: f64,
    /// NOTE(review): not read by the `RoutingPolicyTrait` impl in this file —
    /// confirm whether a cross-region deduction is still intended.
    pub cross_region_penalty: f64,
}

impl Default for LocalityPolicy {
    /// Returns a policy with weight `1.0`, boost `0.3`, and penalty `0.1`.
    fn default() -> Self {
        Self {
            weight: 1.0,
            same_region_boost: 0.3,
            cross_region_penalty: 0.1,
        }
    }
}
impl RoutingPolicyTrait for LocalityPolicy {
    /// Computes `(0.5 + locality_score * same_region_boost) * weight`.
    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
        const BASE_SCORE: f64 = 0.5;
        let boost = candidate.scores.locality_score * self.same_region_boost;
        (BASE_SCORE + boost) * self.weight
    }

    /// Locality never excludes a candidate.
    fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
        true
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "locality"
    }
}
/// Routing policy that gates candidates on a per-region privacy level.
#[derive(Default)]
pub struct PrivacyPolicy {
/// Privacy level assigned to each region; regions missing from the map are
/// treated as `PrivacyLevel::Internal` by `is_eligible`.
pub region_privacy: std::collections::HashMap<RegionId, PrivacyLevel>,
}
impl PrivacyPolicy {
    /// Builder-style setter: records `level` for `region`, replacing any
    /// level previously stored for that region, and returns the policy.
    #[must_use]
    pub fn with_region(self, region: RegionId, level: PrivacyLevel) -> Self {
        let mut policy = self;
        policy.region_privacy.insert(region, level);
        policy
    }
}
impl RoutingPolicyTrait for PrivacyPolicy {
    /// Privacy does not influence ranking: every candidate scores `1.0`.
    fn score(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
        1.0
    }

    /// A candidate is eligible when the privacy level of its region compares
    /// `>=` the level requested in `request.qos.privacy`. Regions with no
    /// entry in `region_privacy` are treated as `PrivacyLevel::Internal`.
    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
        let region_level = match self.region_privacy.get(&candidate.target.region_id) {
            Some(&level) => level,
            None => PrivacyLevel::Internal,
        };
        region_level >= request.qos.privacy
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "privacy"
    }
}
/// Routing policy that scores candidates by region cost or by throughput,
/// depending on the request's cost tolerance.
pub struct CostPolicy {
/// Multiplier applied to the raw cost score.
pub weight: f64,
/// Cost per region. `with_region_cost` stores values clamped to
/// `[0.0, 1.0]`; `score` falls back to `0.5` for regions missing from the map.
pub region_costs: std::collections::HashMap<RegionId, f64>,
}
impl Default for CostPolicy {
fn default() -> Self {
Self {
weight: 1.0,
region_costs: std::collections::HashMap::new(),
}
}
}
impl CostPolicy {
    /// Builder-style setter: stores `cost`, clamped to `[0.0, 1.0]`, for
    /// `region` (replacing any earlier entry) and returns the policy.
    #[must_use]
    pub fn with_region_cost(self, region: RegionId, cost: f64) -> Self {
        let bounded = cost.clamp(0.0, 1.0);
        let mut policy = self;
        policy.region_costs.insert(region, bounded);
        policy
    }
}
impl RoutingPolicyTrait for CostPolicy {
    /// Scores a candidate according to the request's cost tolerance.
    ///
    /// `request.qos.cost_tolerance` is divided by `100.0`; when the result is
    /// above `0.5` the candidate's throughput score is used, otherwise the
    /// score is `1.0 - region_cost`. Regions without a configured cost use
    /// `0.5`. The result is multiplied by `weight`.
    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
        let region_cost = match self.region_costs.get(&candidate.target.region_id) {
            Some(cost) => *cost,
            None => 0.5,
        };
        let tolerance = request.qos.cost_tolerance as f64 / 100.0;
        let use_throughput = tolerance > 0.5;
        let raw = if use_throughput {
            candidate.scores.throughput_score
        } else {
            1.0 - region_cost
        };
        raw * self.weight
    }

    /// Cost never excludes a candidate.
    fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
        true
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "cost"
    }
}
/// Routing policy that ranks candidates by their health score.
///
/// Derives `Debug` and `Clone` so the policy can be logged and duplicated,
/// in line with `SelectionCriteria` in this module.
#[derive(Debug, Clone)]
pub struct HealthPolicy {
    /// Multiplier applied to the candidate's health score.
    pub weight: f64,
    /// NOTE(review): not read by the `RoutingPolicyTrait` impl in this file —
    /// confirm where (or whether) this value is consumed.
    pub healthy_score: f64,
    /// NOTE(review): not read by the `RoutingPolicyTrait` impl in this file —
    /// confirm where (or whether) this value is consumed.
    pub degraded_score: f64,
}

impl Default for HealthPolicy {
    /// Returns a policy with weight `2.0`, healthy score `1.0`, and degraded
    /// score `0.3`.
    fn default() -> Self {
        Self {
            weight: 2.0,
            healthy_score: 1.0,
            degraded_score: 0.3,
        }
    }
}
impl RoutingPolicyTrait for HealthPolicy {
    /// Returns the candidate's health score multiplied by `weight`.
    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
        let health = candidate.scores.health_score;
        health * self.weight
    }

    /// A candidate is eligible only when its health score is strictly
    /// positive.
    fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
        let health = candidate.scores.health_score;
        health > 0.0
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "health"
    }
}
/// Routing policy that combines several member policies: its score is the
/// mean of the members' scores, and a candidate is eligible only when every
/// member accepts it.
pub struct CompositePolicy {
/// Member policies, stored in the order they were added.
policies: Vec<Box<dyn RoutingPolicyTrait>>,
}
impl CompositePolicy {
    /// Creates a composite with no member policies.
    ///
    /// This is not the same as `Default::default()`, which returns the stack
    /// built by [`CompositePolicy::enterprise_default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            policies: Vec::new(),
        }
    }

    /// Builder-style method: appends `policy` to the member list and returns
    /// the composite.
    #[must_use]
    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
        self.policies.push(Box::new(policy));
        self
    }

    /// Builds the standard stack: health, latency, privacy, locality, and
    /// cost policies, each with its default configuration.
    #[must_use]
    pub fn enterprise_default() -> Self {
        Self::new()
            .with_policy(HealthPolicy::default())
            .with_policy(LatencyPolicy::default())
            .with_policy(PrivacyPolicy::default())
            .with_policy(LocalityPolicy::default())
            .with_policy(CostPolicy::default())
    }
}
impl Default for CompositePolicy {
fn default() -> Self {
Self::enterprise_default()
}
}
impl RoutingPolicyTrait for CompositePolicy {
    /// Returns the arithmetic mean of the member policies' scores, or `1.0`
    /// when the composite has no members.
    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
        match self.policies.len() {
            0 => 1.0,
            member_count => {
                let sum: f64 = self
                    .policies
                    .iter()
                    .map(|member| member.score(candidate, request))
                    .sum();
                sum / member_count as f64
            }
        }
    }

    /// A candidate is eligible only when every member policy accepts it;
    /// evaluation stops at the first member that rejects it.
    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
        for member in &self.policies {
            if !member.is_eligible(candidate, request) {
                return false;
            }
        }
        true
    }

    /// Identifier reported for this policy.
    fn name(&self) -> &'static str {
        "composite"
    }
}
/// Wrapper around a single boxed routing policy, created through the named
/// constructors on this type.
pub struct RoutingPolicy {
// NOTE(review): `inner` is never read in the code visible here, which is why
// the dead-code lint is suppressed — confirm whether a delegating accessor
// or trait impl is planned.
#[allow(dead_code)]
inner: Box<dyn RoutingPolicyTrait>,
}
impl RoutingPolicy {
pub fn latency() -> Self {
Self {
inner: Box::new(LatencyPolicy::default()),
}
}
pub fn locality() -> Self {
Self {
inner: Box::new(LocalityPolicy::default()),
}
}
pub fn privacy() -> Self {
Self {
inner: Box::new(PrivacyPolicy::default()),
}
}
pub fn cost() -> Self {
Self {
inner: Box::new(CostPolicy::default()),
}
}
pub fn health() -> Self {
Self {
inner: Box::new(HealthPolicy::default()),
}
}
pub fn enterprise() -> Self {
Self {
inner: Box::new(CompositePolicy::enterprise_default()),
}
}
}
#[cfg(test)]
#[path = "policy_tests.rs"]
mod tests;