nntp_proxy/router/mod.rs

//! Backend server selection and load balancing
//!
//! This module selects backend servers in round-robin order and keeps a
//! simple per-backend count of pending commands for monitoring.
//!
//! # Overview
//!
//! The `BackendSelector` provides thread-safe backend selection for routing
//! NNTP commands across multiple backend servers. It uses a lock-free
//! round-robin algorithm built on atomic operations for concurrent access.
//!
//! # Usage
//!
//! ```no_run
//! use nntp_proxy::router::BackendSelector;
//! use nntp_proxy::types::{BackendId, ClientId};
//! # use nntp_proxy::pool::DeadpoolConnectionProvider;
//!
//! let mut selector = BackendSelector::new();
//! # let provider = DeadpoolConnectionProvider::new(
//! #     "localhost".to_string(), 119, "test".to_string(), 10, None, None
//! # );
//! selector.add_backend(BackendId::from_index(0), "server1".to_string(), provider);
//!
//! // Route a command
//! let client_id = ClientId::new();
//! let backend_id = selector.route_command_sync(client_id, "LIST").unwrap();
//!
//! // After command completes
//! selector.complete_command_sync(backend_id);
//! ```

33use anyhow::Result;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicUsize, Ordering};
36use tracing::{debug, info};
37
38use crate::pool::DeadpoolConnectionProvider;
39use crate::types::{BackendId, ClientId};
40
41/// Backend connection information
42#[derive(Debug, Clone)]
43struct BackendInfo {
44    /// Backend identifier
45    id: BackendId,
46    /// Server name for logging
47    name: String,
48    /// Connection provider for this backend
49    provider: DeadpoolConnectionProvider,
50    /// Number of pending requests on this backend (for load balancing)
51    pending_count: Arc<AtomicUsize>,
52}
53
54/// Selects backend servers using round-robin with load tracking
55///
56/// # Thread Safety
57///
58/// This struct is designed for concurrent access across multiple threads.
59/// The round-robin counter and pending counts use atomic operations for
60/// lock-free performance.
61///
62/// # Load Balancing
63///
64/// - **Strategy**: Round-robin rotation through available backends
65/// - **Tracking**: Atomic counters track pending commands per backend
66/// - **Monitoring**: Load statistics available via `backend_load()`
67///
68/// # Examples
69///
70/// ```no_run
71/// # use nntp_proxy::router::BackendSelector;
72/// # use nntp_proxy::types::{BackendId, ClientId};
73/// # use nntp_proxy::pool::DeadpoolConnectionProvider;
74/// let mut selector = BackendSelector::new();
75///
76/// # let provider = DeadpoolConnectionProvider::new(
77/// #     "localhost".to_string(), 119, "test".to_string(), 10, None, None
78/// # );
79/// selector.add_backend(
80///     BackendId::from_index(0),
81///     "backend-1".to_string(),
82///     provider,
83/// );
84///
85/// // Route commands
86/// let backend = selector.route_command_sync(ClientId::new(), "LIST")?;
87/// # Ok::<(), anyhow::Error>(())
88/// ```
89#[derive(Debug)]
90pub struct BackendSelector {
91    /// Backend connection providers
92    backends: Vec<BackendInfo>,
93    /// Current backend index for round-robin selection
94    current_backend: AtomicUsize,
95}
96
97impl Default for BackendSelector {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl BackendSelector {
104    /// Create a new backend selector
105    #[must_use]
106    pub fn new() -> Self {
107        Self {
108            // Pre-allocate for typical number of backend servers (most setups have 2-8)
109            backends: Vec::with_capacity(4),
110            current_backend: AtomicUsize::new(0),
111        }
112    }
113
114    /// Add a backend server to the router
115    pub fn add_backend(
116        &mut self,
117        backend_id: BackendId,
118        name: String,
119        provider: DeadpoolConnectionProvider,
120    ) {
121        info!("Added backend {:?} ({})", backend_id, name);
122        self.backends.push(BackendInfo {
123            id: backend_id,
124            name,
125            provider,
126            pending_count: Arc::new(AtomicUsize::new(0)),
127        });
128    }
129
130    /// Select the next backend using round-robin strategy
131    fn select_backend(&self) -> Option<&BackendInfo> {
132        if self.backends.is_empty() {
133            return None;
134        }
135
136        let index = self.current_backend.fetch_add(1, Ordering::Relaxed) % self.backends.len();
137        Some(&self.backends[index])
138    }
139
140    /// Select a backend for the given command using round-robin
141    /// Returns the backend ID to use for this command
142    pub fn route_command_sync(&self, _client_id: ClientId, _command: &str) -> Result<BackendId> {
143        let backend = self.select_backend().ok_or_else(|| {
144            anyhow::anyhow!(
145                "No backends available for routing (total backends: {})",
146                self.backends.len()
147            )
148        })?;
149
150        // Increment pending count for load tracking
151        backend.pending_count.fetch_add(1, Ordering::Relaxed);
152
153        debug!(
154            "Selected backend {:?} ({}) for command",
155            backend.id, backend.name
156        );
157
158        Ok(backend.id)
159    }
160
161    /// Mark a command as complete, decrementing the pending count
162    pub fn complete_command_sync(&self, backend_id: BackendId) {
163        if let Some(backend) = self.backends.iter().find(|b| b.id == backend_id) {
164            backend.pending_count.fetch_sub(1, Ordering::Relaxed);
165        }
166    }
167
168    /// Get the connection provider for a backend
169    #[must_use]
170    pub fn get_backend_provider(
171        &self,
172        backend_id: BackendId,
173    ) -> Option<&DeadpoolConnectionProvider> {
174        self.backends
175            .iter()
176            .find(|b| b.id == backend_id)
177            .map(|b| &b.provider)
178    }
179
180    /// Get the number of backends
181    #[must_use]
182    #[inline]
183    pub fn backend_count(&self) -> usize {
184        self.backends.len()
185    }
186
187    /// Get backend load (pending requests) for monitoring
188    #[must_use]
189    pub fn backend_load(&self, backend_id: BackendId) -> Option<usize> {
190        self.backends
191            .iter()
192            .find(|b| b.id == backend_id)
193            .map(|b| b.pending_count.load(Ordering::Relaxed))
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider for localhost:9999 with a pool size of 2; these tests only exercise routing.
    fn create_test_provider() -> DeadpoolConnectionProvider {
        DeadpoolConnectionProvider::new(
            "localhost".to_string(),
            9999,
            "test".to_string(),
            2,
            None,
            None,
        )
    }

    /// Build a selector with backends 0, 1 and 2 named `backend-0` through `backend-2`.
    fn three_backend_router() -> BackendSelector {
        let mut router = BackendSelector::new();
        for i in 0..3 {
            router.add_backend(
                BackendId::from_index(i),
                format!("backend-{}", i),
                create_test_provider(),
            );
        }
        router
    }

    #[test]
    fn test_router_creation() {
        assert_eq!(BackendSelector::new().backend_count(), 0);
    }

    #[test]
    fn test_add_backend() {
        let mut router = BackendSelector::new();
        router.add_backend(
            BackendId::from_index(0),
            "test-backend".to_string(),
            create_test_provider(),
        );

        assert_eq!(router.backend_count(), 1);
    }

    #[test]
    fn test_add_multiple_backends() {
        assert_eq!(three_backend_router().backend_count(), 3);
    }

    #[test]
    fn test_no_backends_fails() {
        let outcome = BackendSelector::new().route_command_sync(ClientId::new(), "LIST\r\n");

        assert!(outcome.is_err());
    }

    #[test]
    fn test_round_robin_selection() {
        let router = three_backend_router();
        let client_id = ClientId::new();

        let commands = [
            "LIST\r\n", "DATE\r\n", "HELP\r\n", "LIST\r\n", "DATE\r\n", "HELP\r\n",
        ];
        let picked: Vec<_> = commands
            .iter()
            .map(|&cmd| router.route_command_sync(client_id, cmd).unwrap().as_index())
            .collect();

        // Backends are visited in order, wrapping around after the third
        assert_eq!(picked, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_backend_load_tracking() {
        let mut router = BackendSelector::new();
        let client_id = ClientId::new();
        let backend_id = BackendId::from_index(0);
        router.add_backend(backend_id, "test".to_string(), create_test_provider());

        // A fresh backend carries no load
        assert_eq!(router.backend_load(backend_id), Some(0));

        // Every routed command raises the load by one
        router
            .route_command_sync(client_id, "LIST\r\n")
            .expect("a backend is registered");
        assert_eq!(router.backend_load(backend_id), Some(1));
        router
            .route_command_sync(client_id, "DATE\r\n")
            .expect("a backend is registered");
        assert_eq!(router.backend_load(backend_id), Some(2));

        // Every completion lowers it again
        router.complete_command_sync(backend_id);
        assert_eq!(router.backend_load(backend_id), Some(1));
        router.complete_command_sync(backend_id);
        assert_eq!(router.backend_load(backend_id), Some(0));
    }

    #[test]
    fn test_get_backend_provider() {
        let mut router = BackendSelector::new();
        let backend_id = BackendId::from_index(0);
        router.add_backend(backend_id, "test".to_string(), create_test_provider());

        assert!(router.get_backend_provider(backend_id).is_some());
        assert!(router
            .get_backend_provider(BackendId::from_index(999))
            .is_none());
    }

    #[test]
    fn test_load_balancing_fairness() {
        let router = three_backend_router();
        let client_id = ClientId::new();

        let mut backend_counts = [0; 3];
        for _ in 0..9 {
            let backend_id = router.route_command_sync(client_id, "LIST\r\n").unwrap();
            backend_counts[backend_id.as_index()] += 1;
        }

        // Nine commands over three backends: exactly three each
        assert_eq!(backend_counts, [3, 3, 3]);
    }
}