1use super::traits::*;
7use std::time::Duration;
8
9#[derive(Debug, Clone)]
15pub struct SelectionCriteria {
16 pub capability: Capability,
18 pub min_health: HealthState,
20 pub max_latency: Option<Duration>,
22 pub min_privacy: PrivacyLevel,
24 pub preferred_regions: Vec<RegionId>,
26 pub excluded_nodes: Vec<NodeId>,
28}
29
30impl Default for SelectionCriteria {
31 fn default() -> Self {
32 Self {
33 capability: Capability::Generate,
34 min_health: HealthState::Degraded,
35 max_latency: None,
36 min_privacy: PrivacyLevel::Public,
37 preferred_regions: vec![],
38 excluded_nodes: vec![],
39 }
40 }
41}
42
/// Routing policy that prefers candidates with lower estimated latency.
///
/// Candidates whose estimated latency exceeds `max_latency` are ineligible.
/// Eligible candidates score higher the faster they are.
#[derive(Debug, Clone, PartialEq)]
pub struct LatencyPolicy {
    /// Multiplier applied to the normalized latency score (default 1.0).
    pub weight: f64,
    /// Latency ceiling (default 5 s). The score falls to 0.0 at this latency,
    /// and slower candidates are rejected.
    pub max_latency: Duration,
}

impl Default for LatencyPolicy {
    /// Weight 1.0 with a 5-second latency ceiling.
    fn default() -> Self {
        Self {
            weight: 1.0,
            max_latency: Duration::from_secs(5),
        }
    }
}
66
67impl RoutingPolicyTrait for LatencyPolicy {
68 fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
69 let latency_ms = candidate.target.estimated_latency.as_millis() as f64;
70 let max_ms = self.max_latency.as_millis() as f64;
71
72 if latency_ms >= max_ms {
73 return 0.0;
74 }
75
76 let score = 1.0 - (latency_ms / max_ms);
78 score * self.weight
79 }
80
81 fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
82 candidate.target.estimated_latency <= self.max_latency
83 }
84
85 fn name(&self) -> &'static str {
86 "latency"
87 }
88}
89
/// Routing policy that ranks candidates by their precomputed `locality_score`.
///
/// It never excludes a candidate; it only influences ranking.
#[derive(Debug, Clone, PartialEq)]
pub struct LocalityPolicy {
    /// Multiplier applied to the locality score (default 1.0).
    pub weight: f64,
    /// Boost factor for same-region candidates (default 0.3).
    pub same_region_boost: f64,
    /// Penalty factor for cross-region candidates (default 0.1).
    pub cross_region_penalty: f64,
}

impl Default for LocalityPolicy {
    /// Weight 1.0, same-region boost 0.3, cross-region penalty 0.1.
    fn default() -> Self {
        Self {
            weight: 1.0,
            same_region_boost: 0.3,
            cross_region_penalty: 0.1,
        }
    }
}
112
113impl RoutingPolicyTrait for LocalityPolicy {
114 fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
115 let base_score = 0.5;
117
118 let score = base_score + candidate.scores.locality_score * self.same_region_boost;
120 score * self.weight
121 }
122
123 fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
124 true }
126
127 fn name(&self) -> &'static str {
128 "locality"
129 }
130}
131
132#[derive(Default)]
136pub struct PrivacyPolicy {
137 pub region_privacy: std::collections::HashMap<RegionId, PrivacyLevel>,
139}
140
141impl PrivacyPolicy {
142 #[must_use]
144 pub fn with_region(mut self, region: RegionId, level: PrivacyLevel) -> Self {
145 self.region_privacy.insert(region, level);
146 self
147 }
148}
149
150impl RoutingPolicyTrait for PrivacyPolicy {
151 fn score(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
152 1.0 }
154
155 fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
156 let region_level = self
157 .region_privacy
158 .get(&candidate.target.region_id)
159 .copied()
160 .unwrap_or(PrivacyLevel::Internal);
162
163 region_level >= request.qos.privacy
165 }
166
167 fn name(&self) -> &'static str {
168 "privacy"
169 }
170}
171
172pub struct CostPolicy {
176 pub weight: f64,
178 pub region_costs: std::collections::HashMap<RegionId, f64>,
180}
181
182impl Default for CostPolicy {
183 fn default() -> Self {
184 Self {
185 weight: 1.0,
186 region_costs: std::collections::HashMap::new(),
187 }
188 }
189}
190
191impl CostPolicy {
192 #[must_use]
193 pub fn with_region_cost(mut self, region: RegionId, cost: f64) -> Self {
194 self.region_costs.insert(region, cost.clamp(0.0, 1.0));
195 self
196 }
197}
198
199impl RoutingPolicyTrait for CostPolicy {
200 fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
201 let region_cost = self
202 .region_costs
203 .get(&candidate.target.region_id)
204 .copied()
205 .unwrap_or(0.5);
206
207 let cost_tolerance = request.qos.cost_tolerance as f64 / 100.0;
208
209 let score = if cost_tolerance > 0.5 {
212 candidate.scores.throughput_score
214 } else {
215 1.0 - region_cost
217 };
218
219 score * self.weight
220 }
221
222 fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
223 true
224 }
225
226 fn name(&self) -> &'static str {
227 "cost"
228 }
229}
230
/// Routing policy that ranks candidates by their precomputed health score and
/// rejects candidates with no health.
///
/// NOTE(review): `healthy_score` and `degraded_score` are not read by the
/// `score`/`is_eligible` impl in this file, which uses
/// `candidate.scores.health_score` directly. Confirm whether another module
/// consumes them or whether they are dead configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct HealthPolicy {
    /// Multiplier applied to the candidate's health score (default 2.0).
    pub weight: f64,
    /// Score associated with a healthy node (default 1.0).
    pub healthy_score: f64,
    /// Score associated with a degraded node (default 0.3).
    pub degraded_score: f64,
}

impl Default for HealthPolicy {
    /// Weight 2.0, healthy score 1.0, degraded score 0.3.
    fn default() -> Self {
        Self {
            weight: 2.0,
            healthy_score: 1.0,
            degraded_score: 0.3,
        }
    }
}
252
253impl RoutingPolicyTrait for HealthPolicy {
254 fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
255 candidate.scores.health_score * self.weight
256 }
257
258 fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
259 candidate.scores.health_score > 0.0
261 }
262
263 fn name(&self) -> &'static str {
264 "health"
265 }
266}
267
268pub struct CompositePolicy {
274 policies: Vec<Box<dyn RoutingPolicyTrait>>,
275}
276
277impl CompositePolicy {
278 pub fn new() -> Self {
279 Self { policies: vec![] }
280 }
281
282 #[must_use]
283 pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
284 self.policies.push(Box::new(policy));
285 self
286 }
287
288 pub fn enterprise_default() -> Self {
290 Self::new()
291 .with_policy(HealthPolicy::default())
292 .with_policy(LatencyPolicy::default())
293 .with_policy(PrivacyPolicy::default())
294 .with_policy(LocalityPolicy::default())
295 .with_policy(CostPolicy::default())
296 }
297}
298
299impl Default for CompositePolicy {
300 fn default() -> Self {
301 Self::enterprise_default()
302 }
303}
304
305impl RoutingPolicyTrait for CompositePolicy {
306 fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
307 if self.policies.is_empty() {
308 return 1.0;
309 }
310
311 let total: f64 = self
312 .policies
313 .iter()
314 .map(|p| p.score(candidate, request))
315 .sum();
316
317 total / self.policies.len() as f64
318 }
319
320 fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
321 self.policies
323 .iter()
324 .all(|p| p.is_eligible(candidate, request))
325 }
326
327 fn name(&self) -> &'static str {
328 "composite"
329 }
330}
331
332pub struct RoutingPolicy {
338 #[allow(dead_code)]
339 inner: Box<dyn RoutingPolicyTrait>,
340}
341
342impl RoutingPolicy {
343 pub fn latency() -> Self {
344 Self {
345 inner: Box::new(LatencyPolicy::default()),
346 }
347 }
348
349 pub fn locality() -> Self {
350 Self {
351 inner: Box::new(LocalityPolicy::default()),
352 }
353 }
354
355 pub fn privacy() -> Self {
356 Self {
357 inner: Box::new(PrivacyPolicy::default()),
358 }
359 }
360
361 pub fn cost() -> Self {
362 Self {
363 inner: Box::new(CostPolicy::default()),
364 }
365 }
366
367 pub fn health() -> Self {
368 Self {
369 inner: Box::new(HealthPolicy::default()),
370 }
371 }
372
373 pub fn enterprise() -> Self {
374 Self {
375 inner: Box::new(CompositePolicy::enterprise_default()),
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Candidate in `us-west` with the given latency and health score.
    /// The other sub-scores are fixed (throughput 0.8, cost 0.5, locality 0.7).
    fn mock_candidate(latency_ms: u64, health_score: f64) -> RouteCandidate {
        RouteCandidate {
            target: RouteTarget {
                node_id: NodeId("node1".to_string()),
                region_id: RegionId("us-west".to_string()),
                endpoint: "http://node1:8080".to_string(),
                estimated_latency: Duration::from_millis(latency_ms),
                score: 0.0,
            },
            scores: RouteScores {
                latency_score: 1.0 - (latency_ms as f64 / 5000.0),
                throughput_score: 0.8,
                cost_score: 0.5,
                locality_score: 0.7,
                health_score,
                total: 0.0,
            },
            eligible: true,
            rejection_reason: None,
        }
    }

    /// Generation request with default QoS and no tenant.
    fn mock_request() -> InferenceRequest {
        InferenceRequest {
            capability: Capability::Generate,
            input: vec![],
            qos: QoSRequirements::default(),
            request_id: "req-1".to_string(),
            tenant_id: None,
        }
    }

    #[test]
    fn test_latency_policy_scoring() {
        let policy = LatencyPolicy::default();
        let request = mock_request();

        let fast = mock_candidate(100, 1.0);
        let slow = mock_candidate(4000, 1.0);

        let fast_score = policy.score(&fast, &request);
        let slow_score = policy.score(&slow, &request);

        assert!(fast_score > slow_score);
        assert!(fast_score > 0.9);
    }

    #[test]
    fn test_latency_policy_zero_ceiling_scores_zero() {
        // A zero ceiling must not divide by zero (no NaN/inf scores).
        let policy = LatencyPolicy {
            max_latency: Duration::ZERO,
            ..Default::default()
        };
        let request = mock_request();

        assert_eq!(policy.score(&mock_candidate(0, 1.0), &request), 0.0);
    }

    #[test]
    fn test_latency_policy_eligibility() {
        let policy = LatencyPolicy {
            max_latency: Duration::from_secs(2),
            ..Default::default()
        };
        let request = mock_request();

        let fast = mock_candidate(1000, 1.0);
        let slow = mock_candidate(3000, 1.0);

        assert!(policy.is_eligible(&fast, &request));
        assert!(!policy.is_eligible(&slow, &request));
    }

    #[test]
    fn test_health_policy_scoring() {
        let policy = HealthPolicy::default();
        let request = mock_request();

        let healthy = mock_candidate(100, 1.0);
        let degraded = mock_candidate(100, 0.3);

        let healthy_score = policy.score(&healthy, &request);
        let degraded_score = policy.score(&degraded, &request);

        assert!(healthy_score > degraded_score);
    }

    #[test]
    fn test_health_policy_rejects_zero_health() {
        let policy = HealthPolicy::default();
        let request = mock_request();

        assert!(!policy.is_eligible(&mock_candidate(100, 0.0), &request));
        assert!(policy.is_eligible(&mock_candidate(100, 0.3), &request));
    }

    #[test]
    fn test_composite_policy() {
        let policy = CompositePolicy::enterprise_default();
        let request = mock_request();

        let good = mock_candidate(100, 1.0);
        let bad = mock_candidate(4000, 0.2);

        let good_score = policy.score(&good, &request);
        let bad_score = policy.score(&bad, &request);

        assert!(good_score > bad_score);
    }

    #[test]
    fn test_empty_composite_policy_is_neutral() {
        let policy = CompositePolicy::new();
        let request = mock_request();
        let candidate = mock_candidate(100, 1.0);

        assert_eq!(policy.score(&candidate, &request), 1.0);
        assert!(policy.is_eligible(&candidate, &request));
    }

    #[test]
    fn test_privacy_policy_eligibility() {
        let policy = PrivacyPolicy::default()
            .with_region(RegionId("eu-west".to_string()), PrivacyLevel::Confidential)
            .with_region(RegionId("us-east".to_string()), PrivacyLevel::Public);

        let mut request = mock_request();
        request.qos.privacy = PrivacyLevel::Confidential;

        let eu_candidate = RouteCandidate {
            target: RouteTarget {
                node_id: NodeId("node-eu".to_string()),
                region_id: RegionId("eu-west".to_string()),
                endpoint: "http://eu:8080".to_string(),
                estimated_latency: Duration::from_millis(100),
                score: 0.0,
            },
            scores: RouteScores::default(),
            eligible: true,
            rejection_reason: None,
        };

        let us_candidate = RouteCandidate {
            target: RouteTarget {
                node_id: NodeId("node-us".to_string()),
                region_id: RegionId("us-east".to_string()),
                endpoint: "http://us:8080".to_string(),
                estimated_latency: Duration::from_millis(50),
                score: 0.0,
            },
            scores: RouteScores::default(),
            eligible: true,
            rejection_reason: None,
        };

        // The EU region meets the Confidential requirement; the Public-only US
        // region does not.
        assert!(policy.is_eligible(&eu_candidate, &request));
        assert!(!policy.is_eligible(&us_candidate, &request));
    }
}
520
521impl Default for RouteScores {
522 fn default() -> Self {
523 Self {
524 latency_score: 0.5,
525 throughput_score: 0.5,
526 cost_score: 0.5,
527 locality_score: 0.5,
528 health_score: 1.0,
529 total: 0.5,
530 }
531 }
532}