apr_cli/federation/
policy.rs

1//! Routing policies for federation
2//!
3//! Policies determine HOW nodes are selected for inference requests.
4//! Multiple policies can be composed (scored, weighted, chained).
5
6use super::traits::*;
7use std::time::Duration;
8
9// ============================================================================
10// Selection Criteria
11// ============================================================================
12
13/// Criteria for selecting nodes
14#[derive(Debug, Clone)]
15pub struct SelectionCriteria {
16    /// Required capability
17    pub capability: Capability,
18    /// Minimum health state
19    pub min_health: HealthState,
20    /// Maximum latency
21    pub max_latency: Option<Duration>,
22    /// Required privacy level
23    pub min_privacy: PrivacyLevel,
24    /// Preferred regions (in order)
25    pub preferred_regions: Vec<RegionId>,
26    /// Excluded nodes
27    pub excluded_nodes: Vec<NodeId>,
28}
29
30impl Default for SelectionCriteria {
31    fn default() -> Self {
32        Self {
33            capability: Capability::Generate,
34            min_health: HealthState::Degraded,
35            max_latency: None,
36            min_privacy: PrivacyLevel::Public,
37            preferred_regions: vec![],
38            excluded_nodes: vec![],
39        }
40    }
41}
42
43// ============================================================================
44// Concrete Routing Policies
45// ============================================================================
46
47/// Latency-based routing policy
48///
49/// Scores nodes inversely proportional to their latency.
50/// Lower latency = higher score.
51pub struct LatencyPolicy {
52    /// Weight for this policy in composite scoring
53    pub weight: f64,
54    /// Maximum acceptable latency (nodes above this get score 0)
55    pub max_latency: Duration,
56}
57
58impl Default for LatencyPolicy {
59    fn default() -> Self {
60        Self {
61            weight: 1.0,
62            max_latency: Duration::from_secs(5),
63        }
64    }
65}
66
67impl RoutingPolicyTrait for LatencyPolicy {
68    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
69        let latency_ms = candidate.target.estimated_latency.as_millis() as f64;
70        let max_ms = self.max_latency.as_millis() as f64;
71
72        if latency_ms >= max_ms {
73            return 0.0;
74        }
75
76        // Score: 1.0 at 0ms, 0.0 at max_latency
77        let score = 1.0 - (latency_ms / max_ms);
78        score * self.weight
79    }
80
81    fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
82        candidate.target.estimated_latency <= self.max_latency
83    }
84
85    fn name(&self) -> &'static str {
86        "latency"
87    }
88}
89
/// Locality-based routing policy.
///
/// Favours nodes in the same region as the request origin, which helps with
/// data sovereignty and latency.
pub struct LocalityPolicy {
    /// Multiplier applied to this policy's score.
    pub weight: f64,
    /// Boost factor for same-region candidates.
    pub same_region_boost: f64,
    /// Penalty for cross-region candidates.
    /// NOTE(review): no code in this file reads this field.
    pub cross_region_penalty: f64,
}

impl Default for LocalityPolicy {
    /// Weight 1.0, same-region boost 0.3, cross-region penalty 0.1.
    fn default() -> Self {
        let (weight, same_region_boost, cross_region_penalty) = (1.0, 0.3, 0.1);
        Self {
            weight,
            same_region_boost,
            cross_region_penalty,
        }
    }
}
112
113impl RoutingPolicyTrait for LocalityPolicy {
114    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
115        // Check if request has tenant locality preference
116        let base_score = 0.5;
117
118        // For now, use locality score from candidate
119        let score = base_score + candidate.scores.locality_score * self.same_region_boost;
120        score * self.weight
121    }
122
123    fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
124        true // Locality is a preference, not a hard requirement
125    }
126
127    fn name(&self) -> &'static str {
128        "locality"
129    }
130}
131
132/// Privacy-based routing policy
133///
134/// Enforces data sovereignty by filtering nodes based on privacy level.
135#[derive(Default)]
136pub struct PrivacyPolicy {
137    /// Region privacy levels
138    pub region_privacy: std::collections::HashMap<RegionId, PrivacyLevel>,
139}
140
141impl PrivacyPolicy {
142    /// Add a region with its privacy level
143    #[must_use]
144    pub fn with_region(mut self, region: RegionId, level: PrivacyLevel) -> Self {
145        self.region_privacy.insert(region, level);
146        self
147    }
148}
149
150impl RoutingPolicyTrait for PrivacyPolicy {
151    fn score(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
152        1.0 // Privacy is binary: eligible or not
153    }
154
155    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
156        let region_level = self
157            .region_privacy
158            .get(&candidate.target.region_id)
159            .copied()
160            // Default to Internal - unknown regions can handle internal traffic
161            .unwrap_or(PrivacyLevel::Internal);
162
163        // Region must meet or exceed request's privacy requirement
164        region_level >= request.qos.privacy
165    }
166
167    fn name(&self) -> &'static str {
168        "privacy"
169    }
170}
171
172/// Cost-based routing policy
173///
174/// Balances cost vs performance based on user tolerance.
175pub struct CostPolicy {
176    /// Weight for this policy
177    pub weight: f64,
178    /// Cost per region (0.0 = cheapest, 1.0 = most expensive)
179    pub region_costs: std::collections::HashMap<RegionId, f64>,
180}
181
182impl Default for CostPolicy {
183    fn default() -> Self {
184        Self {
185            weight: 1.0,
186            region_costs: std::collections::HashMap::new(),
187        }
188    }
189}
190
191impl CostPolicy {
192    #[must_use]
193    pub fn with_region_cost(mut self, region: RegionId, cost: f64) -> Self {
194        self.region_costs.insert(region, cost.clamp(0.0, 1.0));
195        self
196    }
197}
198
199impl RoutingPolicyTrait for CostPolicy {
200    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
201        let region_cost = self
202            .region_costs
203            .get(&candidate.target.region_id)
204            .copied()
205            .unwrap_or(0.5);
206
207        let cost_tolerance = request.qos.cost_tolerance as f64 / 100.0;
208
209        // High tolerance = prefer fast (expensive)
210        // Low tolerance = prefer cheap
211        let score = if cost_tolerance > 0.5 {
212            // User tolerates cost, score performance
213            candidate.scores.throughput_score
214        } else {
215            // User wants cheap, invert cost
216            1.0 - region_cost
217        };
218
219        score * self.weight
220    }
221
222    fn is_eligible(&self, _candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
223        true
224    }
225
226    fn name(&self) -> &'static str {
227        "cost"
228    }
229}
230
/// Health-based routing policy.
///
/// Strongly penalizes unhealthy or degraded nodes.
pub struct HealthPolicy {
    /// Multiplier applied to this policy's score.
    pub weight: f64,
    /// Score assigned to healthy nodes.
    /// NOTE(review): no code in this file reads this field.
    pub healthy_score: f64,
    /// Score assigned to degraded nodes.
    /// NOTE(review): no code in this file reads this field.
    pub degraded_score: f64,
}

impl Default for HealthPolicy {
    /// Weight 2.0 (health counts double), healthy score 1.0, degraded score 0.3.
    fn default() -> Self {
        let (healthy_score, degraded_score) = (1.0, 0.3);
        Self {
            // Health is deliberately weighted above the other policies.
            weight: 2.0,
            healthy_score,
            degraded_score,
        }
    }
}
252
253impl RoutingPolicyTrait for HealthPolicy {
254    fn score(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> f64 {
255        candidate.scores.health_score * self.weight
256    }
257
258    fn is_eligible(&self, candidate: &RouteCandidate, _request: &InferenceRequest) -> bool {
259        // Must have some health score
260        candidate.scores.health_score > 0.0
261    }
262
263    fn name(&self) -> &'static str {
264        "health"
265    }
266}
267
268// ============================================================================
269// Composite Policy
270// ============================================================================
271
272/// Combines multiple policies with weighted scoring
273pub struct CompositePolicy {
274    policies: Vec<Box<dyn RoutingPolicyTrait>>,
275}
276
277impl CompositePolicy {
278    pub fn new() -> Self {
279        Self { policies: vec![] }
280    }
281
282    #[must_use]
283    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
284        self.policies.push(Box::new(policy));
285        self
286    }
287
288    /// Create default enterprise policy
289    pub fn enterprise_default() -> Self {
290        Self::new()
291            .with_policy(HealthPolicy::default())
292            .with_policy(LatencyPolicy::default())
293            .with_policy(PrivacyPolicy::default())
294            .with_policy(LocalityPolicy::default())
295            .with_policy(CostPolicy::default())
296    }
297}
298
299impl Default for CompositePolicy {
300    fn default() -> Self {
301        Self::enterprise_default()
302    }
303}
304
305impl RoutingPolicyTrait for CompositePolicy {
306    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64 {
307        if self.policies.is_empty() {
308            return 1.0;
309        }
310
311        let total: f64 = self
312            .policies
313            .iter()
314            .map(|p| p.score(candidate, request))
315            .sum();
316
317        total / self.policies.len() as f64
318    }
319
320    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool {
321        // Must pass ALL policies
322        self.policies
323            .iter()
324            .all(|p| p.is_eligible(candidate, request))
325    }
326
327    fn name(&self) -> &'static str {
328        "composite"
329    }
330}
331
332// ============================================================================
333// Wrapper type for export
334// ============================================================================
335
336/// Routing policy configuration
337pub struct RoutingPolicy {
338    #[allow(dead_code)]
339    inner: Box<dyn RoutingPolicyTrait>,
340}
341
342impl RoutingPolicy {
343    pub fn latency() -> Self {
344        Self {
345            inner: Box::new(LatencyPolicy::default()),
346        }
347    }
348
349    pub fn locality() -> Self {
350        Self {
351            inner: Box::new(LocalityPolicy::default()),
352        }
353    }
354
355    pub fn privacy() -> Self {
356        Self {
357            inner: Box::new(PrivacyPolicy::default()),
358        }
359    }
360
361    pub fn cost() -> Self {
362        Self {
363            inner: Box::new(CostPolicy::default()),
364        }
365    }
366
367    pub fn health() -> Self {
368        Self {
369            inner: Box::new(HealthPolicy::default()),
370        }
371    }
372
373    pub fn enterprise() -> Self {
374        Self {
375            inner: Box::new(CompositePolicy::enterprise_default()),
376        }
377    }
378}
379
380// ============================================================================
381// Tests
382// ============================================================================
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate on `node1` in `us-west` with the given latency and
    /// health score; the remaining scores are fixed.
    fn mock_candidate(latency_ms: u64, health_score: f64) -> RouteCandidate {
        let target = RouteTarget {
            node_id: NodeId(String::from("node1")),
            region_id: RegionId(String::from("us-west")),
            endpoint: String::from("http://node1:8080"),
            estimated_latency: Duration::from_millis(latency_ms),
            score: 0.0,
        };
        let scores = RouteScores {
            latency_score: 1.0 - (latency_ms as f64 / 5000.0),
            throughput_score: 0.8,
            cost_score: 0.5,
            locality_score: 0.7,
            health_score,
            total: 0.0,
        };

        RouteCandidate {
            target,
            scores,
            eligible: true,
            rejection_reason: None,
        }
    }

    /// Builds a candidate with default scores in the given region.
    fn regional_candidate(
        node: &str,
        region: &str,
        endpoint: &str,
        latency_ms: u64,
    ) -> RouteCandidate {
        RouteCandidate {
            target: RouteTarget {
                node_id: NodeId(node.to_string()),
                region_id: RegionId(region.to_string()),
                endpoint: endpoint.to_string(),
                estimated_latency: Duration::from_millis(latency_ms),
                score: 0.0,
            },
            scores: RouteScores::default(),
            eligible: true,
            rejection_reason: None,
        }
    }

    /// Builds a `Generate` request with default QoS and no tenant.
    fn mock_request() -> InferenceRequest {
        InferenceRequest {
            capability: Capability::Generate,
            input: Vec::new(),
            qos: QoSRequirements::default(),
            request_id: String::from("req-1"),
            tenant_id: None,
        }
    }

    #[test]
    fn test_latency_policy_scoring() {
        let policy = LatencyPolicy::default();
        let request = mock_request();
        let score_at = |latency_ms| policy.score(&mock_candidate(latency_ms, 1.0), &request);

        // A fast node must outscore a slow one.
        let (fast_score, slow_score) = (score_at(100), score_at(4000));

        assert!(fast_score > slow_score);
        assert!(fast_score > 0.9); // 100ms against a 5000ms budget
    }

    #[test]
    fn test_latency_policy_eligibility() {
        let policy = LatencyPolicy {
            max_latency: Duration::from_secs(2),
            ..LatencyPolicy::default()
        };
        let request = mock_request();

        let within_budget = mock_candidate(1000, 1.0);
        let over_budget = mock_candidate(3000, 1.0);

        assert!(policy.is_eligible(&within_budget, &request));
        assert!(!policy.is_eligible(&over_budget, &request));
    }

    #[test]
    fn test_health_policy_scoring() {
        let policy = HealthPolicy::default();
        let request = mock_request();
        let score_for = |health| policy.score(&mock_candidate(100, health), &request);

        let healthy_score = score_for(1.0);
        let degraded_score = score_for(0.3);

        assert!(healthy_score > degraded_score);
    }

    #[test]
    fn test_composite_policy() {
        let policy = CompositePolicy::enterprise_default();
        let request = mock_request();

        let good_score = policy.score(&mock_candidate(100, 1.0), &request);
        let bad_score = policy.score(&mock_candidate(4000, 0.2), &request);

        assert!(good_score > bad_score);
    }

    #[test]
    fn test_privacy_policy_eligibility() {
        let policy = PrivacyPolicy::default()
            .with_region(RegionId("eu-west".to_string()), PrivacyLevel::Confidential)
            .with_region(RegionId("us-east".to_string()), PrivacyLevel::Public);

        let request = {
            let mut confidential = mock_request();
            confidential.qos.privacy = PrivacyLevel::Confidential;
            confidential
        };

        let eu_candidate = regional_candidate("node-eu", "eu-west", "http://eu:8080", 100);
        let us_candidate = regional_candidate("node-us", "us-east", "http://us:8080", 50);

        // EU is registered as confidential, so it satisfies the request.
        assert!(policy.is_eligible(&eu_candidate, &request));
        // US is registered as public, which is below confidential.
        assert!(!policy.is_eligible(&us_candidate, &request));
    }
}
520
521impl Default for RouteScores {
522    fn default() -> Self {
523        Self {
524            latency_score: 0.5,
525            throughput_score: 0.5,
526            cost_score: 0.5,
527            locality_score: 0.5,
528            health_score: 1.0,
529            total: 0.5,
530        }
531    }
532}