Skip to main content

simple_agents_router/
latency.rs

//! Latency-based routing implementation.
//!
//! Routes requests to the provider with the lowest observed latency.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderHealth, Result,
7    SimpleAgentsError,
8};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Tuning knobs for latency-based routing.
#[derive(Debug, Clone)]
pub struct LatencyRouterConfig {
    /// Smoothing factor of the exponential moving average, expected in `0.0..=1.0`.
    ///
    /// Larger values give the most recent latency observation more weight.
    pub alpha: f64,
    /// Average latency at or above which a provider is reported as degraded.
    pub slow_threshold: Duration,
}

impl Default for LatencyRouterConfig {
    /// Returns `alpha = 0.2` and a two-second slow threshold.
    fn default() -> Self {
        let alpha = 0.2;
        let slow_threshold = Duration::from_millis(2_000);
        Self {
            alpha,
            slow_threshold,
        }
    }
}
30
31#[derive(Clone, Copy, Debug)]
32struct LatencyStats {
33    avg_latency_ms: f64,
34    samples: u64,
35    health: ProviderHealth,
36}
37
38impl LatencyStats {
39    fn new() -> Self {
40        Self {
41            avg_latency_ms: 0.0,
42            samples: 0,
43            health: ProviderHealth::Healthy,
44        }
45    }
46
47    fn record(&mut self, latency: Duration, alpha: f64, slow_threshold: Duration) {
48        let latency_ms = latency.as_secs_f64() * 1000.0;
49        if self.samples == 0 {
50            self.avg_latency_ms = latency_ms;
51        } else {
52            let previous = self.avg_latency_ms;
53            self.avg_latency_ms = (alpha * latency_ms) + ((1.0 - alpha) * previous);
54        }
55        self.samples = self.samples.saturating_add(1);
56
57        let threshold_ms = slow_threshold.as_secs_f64() * 1000.0;
58        self.health = if self.avg_latency_ms >= threshold_ms {
59            ProviderHealth::Degraded
60        } else {
61            ProviderHealth::Healthy
62        };
63    }
64}
65
66/// Router that selects providers based on observed latency.
67pub struct LatencyRouter {
68    providers: Vec<Arc<dyn Provider>>,
69    stats: Mutex<Vec<LatencyStats>>,
70    counter: AtomicUsize,
71    config: LatencyRouterConfig,
72}
73
74impl LatencyRouter {
75    /// Create a latency router with default configuration.
76    ///
77    /// # Errors
78    /// Returns a routing error if no providers are supplied.
79    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
80        Self::with_config(providers, LatencyRouterConfig::default())
81    }
82
83    /// Create a latency router with explicit configuration.
84    ///
85    /// # Errors
86    /// Returns a routing error if no providers are supplied.
87    pub fn with_config(
88        providers: Vec<Arc<dyn Provider>>,
89        config: LatencyRouterConfig,
90    ) -> Result<Self> {
91        if providers.is_empty() {
92            return Err(SimpleAgentsError::Routing(
93                "no providers configured".to_string(),
94            ));
95        }
96
97        let stats = vec![LatencyStats::new(); providers.len()];
98        Ok(Self {
99            providers,
100            stats: Mutex::new(stats),
101            counter: AtomicUsize::new(0),
102            config,
103        })
104    }
105
106    /// Return the number of configured providers.
107    pub fn provider_count(&self) -> usize {
108        self.providers.len()
109    }
110
111    /// Execute a completion request using latency-based selection.
112    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
113        let index = self.select_provider_index()?;
114        let provider = &self.providers[index];
115        let start = Instant::now();
116        let provider_request = provider.transform_request(request)?;
117        let provider_response = provider.execute(provider_request).await?;
118        let response = provider.transform_response(provider_response)?;
119        self.record_latency(index, start.elapsed());
120        Ok(response)
121    }
122
123    /// Execute a streaming request using latency-based selection.
124    pub async fn stream(
125        &self,
126        request: &CompletionRequest,
127    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
128        let index = self.select_provider_index()?;
129        let provider = &self.providers[index];
130        let provider_request = provider.transform_request(request)?;
131        provider.execute_stream(provider_request).await
132    }
133
134    fn select_provider_index(&self) -> Result<usize> {
135        let len = self.providers.len();
136        if len == 0 {
137            return Err(SimpleAgentsError::Routing(
138                "no providers configured".to_string(),
139            ));
140        }
141
142        let stats = self.stats.lock().expect("latency stats lock poisoned");
143        let mut best_index: Option<usize> = None;
144        let mut best_latency = f64::MAX;
145        let mut has_samples = false;
146        let mut has_healthy = false;
147
148        for stat in stats.iter() {
149            if stat.samples == 0 {
150                continue;
151            }
152            has_samples = true;
153            if stat.health == ProviderHealth::Healthy {
154                has_healthy = true;
155            }
156        }
157
158        if has_samples {
159            for (index, stat) in stats.iter().enumerate() {
160                if stat.samples == 0 {
161                    continue;
162                }
163                if has_healthy && stat.health != ProviderHealth::Healthy {
164                    continue;
165                }
166                if stat.avg_latency_ms < best_latency {
167                    best_latency = stat.avg_latency_ms;
168                    best_index = Some(index);
169                }
170            }
171        }
172
173        if let Some(index) = best_index {
174            return Ok(index);
175        }
176
177        let index = self.counter.fetch_add(1, Ordering::Relaxed);
178        Ok(index % len)
179    }
180
181    fn record_latency(&self, index: usize, latency: Duration) {
182        let mut stats = self.stats.lock().expect("latency stats lock poisoned");
183        if let Some(stat) = stats.get_mut(index) {
184            stat.record(latency, self.config.alpha, self.config.slow_threshold);
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider stub that answers every request immediately with a canned response.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            MockProvider { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            let outgoing = ProviderRequest::new("http://example.com");
            Ok(outgoing)
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            let raw = ProviderResponse::new(200, serde_json::Value::Null);
            Ok(raw)
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Build a router over stub providers with the given names and configuration.
    fn router_with(names: &[&'static str], config: LatencyRouterConfig) -> LatencyRouter {
        let providers: Vec<Arc<dyn Provider>> = names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect();
        LatencyRouter::with_config(providers, config).unwrap()
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match LatencyRouter::new(Vec::new()) {
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn selects_lowest_latency_provider() {
        let router = router_with(&["p1", "p2"], LatencyRouterConfig::default());

        router.record_latency(0, Duration::from_millis(250));
        router.record_latency(1, Duration::from_millis(50));

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn prefers_healthy_over_degraded() {
        let strict = LatencyRouterConfig {
            alpha: 1.0,
            slow_threshold: Duration::from_millis(100),
        };
        let router = router_with(&["p1", "p2"], strict);

        router.record_latency(0, Duration::from_millis(400));
        router.record_latency(1, Duration::from_millis(80));

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn round_robin_when_no_metrics() {
        let router = router_with(&["p1", "p2"], LatencyRouterConfig::default());

        let picks: Vec<usize> = (0..2)
            .map(|_| router.select_provider_index().unwrap())
            .collect();

        assert_eq!(picks, vec![0, 1]);
    }

    #[tokio::test]
    async fn records_latency_on_success() {
        let router = router_with(&["p1"], LatencyRouterConfig::default());

        let _ = router.complete(&build_request()).await.unwrap();

        let samples = router.stats.lock().expect("latency stats lock poisoned")[0].samples;
        assert_eq!(samples, 1);
    }
}
318}