1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17use grapsus_common::errors::{GrapsusError, GrapsusResult};
18use grapsus_config::upstreams::StickySessionConfig;
19
20type HmacSha256 = Hmac<Sha256>;
21
22#[derive(Debug, Clone)]
24pub struct StickySessionRuntimeConfig {
25 pub cookie_name: String,
27 pub cookie_ttl_secs: u64,
29 pub cookie_path: String,
31 pub cookie_secure: bool,
33 pub cookie_same_site: grapsus_config::upstreams::SameSitePolicy,
35 pub hmac_key: [u8; 32],
37}
38
39impl StickySessionRuntimeConfig {
40 pub fn from_config(config: &StickySessionConfig) -> Self {
42 use rand::Rng;
43
44 let mut hmac_key = [0u8; 32];
46 rand::rng().fill_bytes(&mut hmac_key);
47
48 Self {
49 cookie_name: config.cookie_name.clone(),
50 cookie_ttl_secs: config.cookie_ttl_secs,
51 cookie_path: config.cookie_path.clone(),
52 cookie_secure: config.cookie_secure,
53 cookie_same_site: config.cookie_same_site,
54 hmac_key,
55 }
56 }
57}
58
59pub struct StickySessionBalancer {
66 config: StickySessionRuntimeConfig,
68 targets: Vec<UpstreamTarget>,
70 fallback: Arc<dyn LoadBalancer>,
72 health_status: Arc<RwLock<HashMap<String, bool>>>,
74}
75
76impl StickySessionBalancer {
77 pub fn new(
79 targets: Vec<UpstreamTarget>,
80 config: StickySessionRuntimeConfig,
81 fallback: Arc<dyn LoadBalancer>,
82 ) -> Self {
83 trace!(
84 target_count = targets.len(),
85 cookie_name = %config.cookie_name,
86 cookie_ttl_secs = config.cookie_ttl_secs,
87 "Creating sticky session balancer"
88 );
89
90 let mut health_status = HashMap::new();
91 for target in &targets {
92 health_status.insert(target.full_address(), true);
93 }
94
95 Self {
96 config,
97 targets,
98 fallback,
99 health_status: Arc::new(RwLock::new(health_status)),
100 }
101 }
102
103 fn extract_affinity(&self, context: &RequestContext) -> Option<usize> {
107 let cookie_header = context.headers.get("cookie")?;
109
110 let cookie_value = cookie_header.split(';').find_map(|cookie| {
112 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
113 if parts.len() == 2 && parts[0] == self.config.cookie_name {
114 Some(parts[1].to_string())
115 } else {
116 None
117 }
118 })?;
119
120 let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
122 if parts.len() != 2 {
123 trace!(
124 cookie_value = %cookie_value,
125 "Invalid sticky cookie format (missing signature)"
126 );
127 return None;
128 }
129
130 let index: usize = parts[0].parse().ok()?;
131 let signature = parts[1];
132
133 if !self.verify_signature(index, signature) {
135 warn!(
136 cookie_value = %cookie_value,
137 "Invalid sticky cookie signature (possible tampering)"
138 );
139 return None;
140 }
141
142 if index >= self.targets.len() {
144 trace!(
145 index = index,
146 target_count = self.targets.len(),
147 "Sticky cookie index out of bounds"
148 );
149 return None;
150 }
151
152 trace!(
153 cookie_name = %self.config.cookie_name,
154 target_index = index,
155 "Extracted valid sticky session affinity"
156 );
157
158 Some(index)
159 }
160
161 pub fn generate_cookie_value(&self, target_index: usize) -> String {
163 let signature = self.sign_index(target_index);
164 format!("{}.{}", target_index, signature)
165 }
166
167 pub fn generate_set_cookie_header(&self, target_index: usize) -> String {
169 let cookie_value = self.generate_cookie_value(target_index);
170
171 let mut header = format!(
172 "{}={}; Path={}; Max-Age={}",
173 self.config.cookie_name,
174 cookie_value,
175 self.config.cookie_path,
176 self.config.cookie_ttl_secs
177 );
178
179 if self.config.cookie_secure {
180 header.push_str("; HttpOnly; Secure");
181 }
182
183 header.push_str(&format!("; SameSite={}", self.config.cookie_same_site));
184
185 header
186 }
187
188 fn sign_index(&self, index: usize) -> String {
190 let mut mac =
191 HmacSha256::new_from_slice(&self.config.hmac_key).expect("HMAC key length is valid");
192 mac.update(index.to_string().as_bytes());
193 let result = mac.finalize();
194 hex::encode(&result.into_bytes()[..8])
196 }
197
198 fn verify_signature(&self, index: usize, signature: &str) -> bool {
200 let expected = self.sign_index(index);
201 expected == signature
203 }
204
205 async fn is_target_healthy(&self, index: usize) -> bool {
207 if index >= self.targets.len() {
208 return false;
209 }
210
211 let target = &self.targets[index];
212 let health = self.health_status.read().await;
213 *health.get(&target.full_address()).unwrap_or(&true)
214 }
215
216 fn find_target_index(&self, address: &str) -> Option<usize> {
218 self.targets
219 .iter()
220 .position(|t| t.full_address() == address)
221 }
222
223 pub fn cookie_name(&self) -> &str {
225 &self.config.cookie_name
226 }
227
228 pub fn config(&self) -> &StickySessionRuntimeConfig {
230 &self.config
231 }
232}
233
234#[async_trait]
235impl LoadBalancer for StickySessionBalancer {
236 async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
237 trace!(
238 has_context = context.is_some(),
239 cookie_name = %self.config.cookie_name,
240 "Sticky session select called"
241 );
242
243 if let Some(ctx) = context {
245 if let Some(target_index) = self.extract_affinity(ctx) {
246 if self.is_target_healthy(target_index).await {
248 let target = &self.targets[target_index];
249
250 debug!(
251 target = %target.full_address(),
252 target_index = target_index,
253 cookie_name = %self.config.cookie_name,
254 "Sticky session hit - routing to affinity target"
255 );
256
257 return Ok(TargetSelection {
258 address: target.full_address(),
259 weight: target.weight,
260 metadata: {
261 let mut meta = HashMap::new();
262 meta.insert("sticky_session_hit".to_string(), "true".to_string());
263 meta.insert(
264 "sticky_target_index".to_string(),
265 target_index.to_string(),
266 );
267 meta.insert("algorithm".to_string(), "sticky_session".to_string());
268 meta
269 },
270 });
271 }
272
273 debug!(
274 target_index = target_index,
275 cookie_name = %self.config.cookie_name,
276 "Sticky target unhealthy, falling back to load balancer"
277 );
278 }
279 }
280
281 let mut selection = self.fallback.select(context).await?;
283
284 let target_index = self.find_target_index(&selection.address);
286
287 if let Some(index) = target_index {
288 selection
290 .metadata
291 .insert("sticky_session_new".to_string(), "true".to_string());
292 selection
293 .metadata
294 .insert("sticky_target_index".to_string(), index.to_string());
295 selection.metadata.insert(
296 "sticky_cookie_value".to_string(),
297 self.generate_cookie_value(index),
298 );
299 selection.metadata.insert(
300 "sticky_set_cookie_header".to_string(),
301 self.generate_set_cookie_header(index),
302 );
303
304 debug!(
305 target = %selection.address,
306 target_index = index,
307 cookie_name = %self.config.cookie_name,
308 "New sticky session assignment, will set cookie"
309 );
310 }
311
312 selection
313 .metadata
314 .insert("algorithm".to_string(), "sticky_session".to_string());
315
316 Ok(selection)
317 }
318
319 async fn report_health(&self, address: &str, healthy: bool) {
320 trace!(
321 target = %address,
322 healthy = healthy,
323 algorithm = "sticky_session",
324 "Updating target health status"
325 );
326
327 self.health_status
329 .write()
330 .await
331 .insert(address.to_string(), healthy);
332
333 self.fallback.report_health(address, healthy).await;
335 }
336
337 async fn healthy_targets(&self) -> Vec<String> {
338 self.fallback.healthy_targets().await
340 }
341
342 async fn release(&self, selection: &TargetSelection) {
343 self.fallback.release(selection).await;
345 }
346
347 async fn report_result(
348 &self,
349 selection: &TargetSelection,
350 success: bool,
351 latency: Option<std::time::Duration>,
352 ) {
353 self.fallback
355 .report_result(selection, success, latency)
356 .await;
357 }
358
359 async fn report_result_with_latency(
360 &self,
361 address: &str,
362 success: bool,
363 latency: Option<std::time::Duration>,
364 ) {
365 self.fallback
367 .report_result_with_latency(address, success, latency)
368 .await;
369 }
370}
371
372#[cfg(test)]
373mod tests {
374 use super::*;
375
376 fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
377 (0..count)
378 .map(|i| UpstreamTarget {
379 address: format!("10.0.0.{}", i + 1),
380 port: 8080,
381 weight: 100,
382 })
383 .collect()
384 }
385
386 fn create_test_config() -> StickySessionRuntimeConfig {
387 StickySessionRuntimeConfig {
388 cookie_name: "SERVERID".to_string(),
389 cookie_ttl_secs: 3600,
390 cookie_path: "/".to_string(),
391 cookie_secure: true,
392 cookie_same_site: grapsus_config::upstreams::SameSitePolicy::Lax,
393 hmac_key: [42u8; 32], }
395 }
396
397 #[test]
398 fn test_cookie_generation_and_validation() {
399 let targets = create_test_targets(3);
400 let config = create_test_config();
401
402 struct MockBalancer;
404
405 #[async_trait]
406 impl LoadBalancer for MockBalancer {
407 async fn select(
408 &self,
409 _context: Option<&RequestContext>,
410 ) -> GrapsusResult<TargetSelection> {
411 Ok(TargetSelection {
412 address: "10.0.0.1:8080".to_string(),
413 weight: 100,
414 metadata: HashMap::new(),
415 })
416 }
417 async fn report_health(&self, _address: &str, _healthy: bool) {}
418 async fn healthy_targets(&self) -> Vec<String> {
419 vec![]
420 }
421 }
422
423 let balancer = StickySessionBalancer::new(targets, config, Arc::new(MockBalancer));
424
425 let cookie_value = balancer.generate_cookie_value(1);
427 assert!(cookie_value.starts_with("1."));
428 assert_eq!(cookie_value.len(), 2 + 16); let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
432 assert!(balancer.verify_signature(1, parts[1]));
433
434 assert!(!balancer.verify_signature(1, "invalid"));
436 assert!(!balancer.verify_signature(2, parts[1])); }
438
439 #[test]
440 fn test_set_cookie_header_generation() {
441 let targets = create_test_targets(3);
442 let config = create_test_config();
443
444 struct MockBalancer;
445
446 #[async_trait]
447 impl LoadBalancer for MockBalancer {
448 async fn select(
449 &self,
450 _context: Option<&RequestContext>,
451 ) -> GrapsusResult<TargetSelection> {
452 unreachable!()
453 }
454 async fn report_health(&self, _address: &str, _healthy: bool) {}
455 async fn healthy_targets(&self) -> Vec<String> {
456 vec![]
457 }
458 }
459
460 let balancer = StickySessionBalancer::new(targets, config, Arc::new(MockBalancer));
461
462 let header = balancer.generate_set_cookie_header(0);
463 assert!(header.starts_with("SERVERID=0."));
464 assert!(header.contains("Path=/"));
465 assert!(header.contains("Max-Age=3600"));
466 assert!(header.contains("HttpOnly"));
467 assert!(header.contains("Secure"));
468 assert!(header.contains("SameSite=Lax"));
469 }
470
471 #[tokio::test]
472 async fn test_sticky_session_hit() {
473 let targets = create_test_targets(3);
474 let config = create_test_config();
475
476 struct MockBalancer;
477
478 #[async_trait]
479 impl LoadBalancer for MockBalancer {
480 async fn select(
481 &self,
482 _context: Option<&RequestContext>,
483 ) -> GrapsusResult<TargetSelection> {
484 panic!("Fallback should not be called for sticky hit");
486 }
487 async fn report_health(&self, _address: &str, _healthy: bool) {}
488 async fn healthy_targets(&self) -> Vec<String> {
489 vec![
490 "10.0.0.1:8080".to_string(),
491 "10.0.0.2:8080".to_string(),
492 "10.0.0.3:8080".to_string(),
493 ]
494 }
495 }
496
497 let balancer = StickySessionBalancer::new(targets, config, Arc::new(MockBalancer));
498
499 let cookie_value = balancer.generate_cookie_value(1);
501
502 let mut headers = HashMap::new();
504 headers.insert("cookie".to_string(), format!("SERVERID={}", cookie_value));
505
506 let context = RequestContext {
507 client_ip: None,
508 headers,
509 path: "/".to_string(),
510 method: "GET".to_string(),
511 };
512
513 let selection = balancer.select(Some(&context)).await.unwrap();
514
515 assert_eq!(selection.address, "10.0.0.2:8080");
517 assert_eq!(
518 selection.metadata.get("sticky_session_hit"),
519 Some(&"true".to_string())
520 );
521 assert_eq!(
522 selection.metadata.get("sticky_target_index"),
523 Some(&"1".to_string())
524 );
525 }
526
527 #[tokio::test]
528 async fn test_sticky_session_miss_sets_cookie() {
529 let targets = create_test_targets(3);
530 let config = create_test_config();
531
532 struct MockBalancer;
533
534 #[async_trait]
535 impl LoadBalancer for MockBalancer {
536 async fn select(
537 &self,
538 _context: Option<&RequestContext>,
539 ) -> GrapsusResult<TargetSelection> {
540 Ok(TargetSelection {
541 address: "10.0.0.2:8080".to_string(),
542 weight: 100,
543 metadata: HashMap::new(),
544 })
545 }
546 async fn report_health(&self, _address: &str, _healthy: bool) {}
547 async fn healthy_targets(&self) -> Vec<String> {
548 vec!["10.0.0.2:8080".to_string()]
549 }
550 }
551
552 let balancer = StickySessionBalancer::new(targets, config, Arc::new(MockBalancer));
553
554 let context = RequestContext {
556 client_ip: None,
557 headers: HashMap::new(),
558 path: "/".to_string(),
559 method: "GET".to_string(),
560 };
561
562 let selection = balancer.select(Some(&context)).await.unwrap();
563
564 assert_eq!(selection.address, "10.0.0.2:8080");
566 assert_eq!(
567 selection.metadata.get("sticky_session_new"),
568 Some(&"true".to_string())
569 );
570 assert!(selection.metadata.contains_key("sticky_cookie_value"));
571 assert!(selection.metadata.contains_key("sticky_set_cookie_header"));
572 }
573
574 #[tokio::test]
575 async fn test_unhealthy_target_falls_back() {
576 let targets = create_test_targets(3);
577 let config = create_test_config();
578
579 struct MockBalancer;
580
581 #[async_trait]
582 impl LoadBalancer for MockBalancer {
583 async fn select(
584 &self,
585 _context: Option<&RequestContext>,
586 ) -> GrapsusResult<TargetSelection> {
587 Ok(TargetSelection {
588 address: "10.0.0.3:8080".to_string(), weight: 100,
590 metadata: HashMap::new(),
591 })
592 }
593 async fn report_health(&self, _address: &str, _healthy: bool) {}
594 async fn healthy_targets(&self) -> Vec<String> {
595 vec!["10.0.0.3:8080".to_string()]
596 }
597 }
598
599 let balancer = StickySessionBalancer::new(targets, config, Arc::new(MockBalancer));
600
601 balancer.report_health("10.0.0.2:8080", false).await;
603
604 let cookie_value = balancer.generate_cookie_value(1);
606
607 let mut headers = HashMap::new();
608 headers.insert("cookie".to_string(), format!("SERVERID={}", cookie_value));
609
610 let context = RequestContext {
611 client_ip: None,
612 headers,
613 path: "/".to_string(),
614 method: "GET".to_string(),
615 };
616
617 let selection = balancer.select(Some(&context)).await.unwrap();
618
619 assert_eq!(selection.address, "10.0.0.3:8080");
621 assert_eq!(
622 selection.metadata.get("sticky_session_new"),
623 Some(&"true".to_string())
624 );
625 }
626}