1use anyhow::Result;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicUsize, Ordering};
36use tracing::{debug, info};
37
38use crate::pool::DeadpoolConnectionProvider;
39use crate::types::{BackendId, ClientId};
40
41#[derive(Debug, Clone)]
43struct BackendInfo {
44 id: BackendId,
46 name: String,
48 provider: DeadpoolConnectionProvider,
50 pending_count: Arc<AtomicUsize>,
52}
53
54#[derive(Debug)]
90pub struct BackendSelector {
91 backends: Vec<BackendInfo>,
93 current_backend: AtomicUsize,
95}
96
97impl Default for BackendSelector {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl BackendSelector {
104 #[must_use]
106 pub fn new() -> Self {
107 Self {
108 backends: Vec::with_capacity(4),
110 current_backend: AtomicUsize::new(0),
111 }
112 }
113
114 pub fn add_backend(
116 &mut self,
117 backend_id: BackendId,
118 name: String,
119 provider: DeadpoolConnectionProvider,
120 ) {
121 info!("Added backend {:?} ({})", backend_id, name);
122 self.backends.push(BackendInfo {
123 id: backend_id,
124 name,
125 provider,
126 pending_count: Arc::new(AtomicUsize::new(0)),
127 });
128 }
129
130 fn select_backend(&self) -> Option<&BackendInfo> {
132 if self.backends.is_empty() {
133 return None;
134 }
135
136 let index = self.current_backend.fetch_add(1, Ordering::Relaxed) % self.backends.len();
137 Some(&self.backends[index])
138 }
139
140 pub fn route_command_sync(&self, _client_id: ClientId, _command: &str) -> Result<BackendId> {
143 let backend = self.select_backend().ok_or_else(|| {
144 anyhow::anyhow!(
145 "No backends available for routing (total backends: {})",
146 self.backends.len()
147 )
148 })?;
149
150 backend.pending_count.fetch_add(1, Ordering::Relaxed);
152
153 debug!(
154 "Selected backend {:?} ({}) for command",
155 backend.id, backend.name
156 );
157
158 Ok(backend.id)
159 }
160
161 pub fn complete_command_sync(&self, backend_id: BackendId) {
163 if let Some(backend) = self.backends.iter().find(|b| b.id == backend_id) {
164 backend.pending_count.fetch_sub(1, Ordering::Relaxed);
165 }
166 }
167
168 #[must_use]
170 pub fn get_backend_provider(
171 &self,
172 backend_id: BackendId,
173 ) -> Option<&DeadpoolConnectionProvider> {
174 self.backends
175 .iter()
176 .find(|b| b.id == backend_id)
177 .map(|b| &b.provider)
178 }
179
180 #[must_use]
182 #[inline]
183 pub fn backend_count(&self) -> usize {
184 self.backends.len()
185 }
186
187 #[must_use]
189 pub fn backend_load(&self, backend_id: BackendId) -> Option<usize> {
190 self.backends
191 .iter()
192 .find(|b| b.id == backend_id)
193 .map(|b| b.pending_count.load(Ordering::Relaxed))
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway provider used only to register test backends.
    fn create_test_provider() -> DeadpoolConnectionProvider {
        DeadpoolConnectionProvider::new(
            "localhost".to_string(),
            9999,
            "test".to_string(),
            2,
            None,
            None,
        )
    }

    /// Returns a selector with backends 0, 1 and 2 named `backend-{i}`,
    /// registered in index order.
    fn three_backend_selector() -> BackendSelector {
        let mut selector = BackendSelector::new();
        for i in 0..3 {
            selector.add_backend(
                BackendId::from_index(i),
                format!("backend-{}", i),
                create_test_provider(),
            );
        }
        selector
    }

    #[test]
    fn test_router_creation() {
        assert_eq!(BackendSelector::new().backend_count(), 0);
    }

    #[test]
    fn test_add_backend() {
        let mut selector = BackendSelector::new();
        selector.add_backend(
            BackendId::from_index(0),
            "test-backend".to_string(),
            create_test_provider(),
        );
        assert_eq!(selector.backend_count(), 1);
    }

    #[test]
    fn test_add_multiple_backends() {
        assert_eq!(three_backend_selector().backend_count(), 3);
    }

    #[test]
    fn test_no_backends_fails() {
        let outcome = BackendSelector::new().route_command_sync(ClientId::new(), "LIST\r\n");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_round_robin_selection() {
        let selector = three_backend_selector();
        let client = ClientId::new();
        let commands = [
            "LIST\r\n", "DATE\r\n", "HELP\r\n", "LIST\r\n", "DATE\r\n", "HELP\r\n",
        ];

        let picked: Vec<usize> = commands
            .iter()
            .map(|cmd| selector.route_command_sync(client, cmd).unwrap().as_index())
            .collect();

        assert_eq!(picked, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_backend_load_tracking() {
        let mut selector = BackendSelector::new();
        let client = ClientId::new();
        let id = BackendId::from_index(0);
        selector.add_backend(id, "test".to_string(), create_test_provider());

        assert_eq!(selector.backend_load(id), Some(0));

        // Each routed command raises the load by one...
        for (expected, cmd) in [(1, "LIST\r\n"), (2, "DATE\r\n")] {
            selector.route_command_sync(client, cmd).unwrap();
            assert_eq!(selector.backend_load(id), Some(expected));
        }

        // ...and each completion lowers it again.
        for expected in [1, 0] {
            selector.complete_command_sync(id);
            assert_eq!(selector.backend_load(id), Some(expected));
        }
    }

    #[test]
    fn test_get_backend_provider() {
        let mut selector = BackendSelector::new();
        let id = BackendId::from_index(0);
        selector.add_backend(id, "test".to_string(), create_test_provider());

        assert!(selector.get_backend_provider(id).is_some());
        assert!(selector
            .get_backend_provider(BackendId::from_index(999))
            .is_none());
    }

    #[test]
    fn test_load_balancing_fairness() {
        let selector = three_backend_selector();
        let client = ClientId::new();

        let mut hits = [0usize; 3];
        for _ in 0..9 {
            let chosen = selector.route_command_sync(client, "LIST\r\n").unwrap();
            hits[chosen.as_index()] += 1;
        }

        assert_eq!(hits, [3, 3, 3]);
    }
}