Skip to main content

grapsus_proxy/upstream/
sticky_session.rs

//! Cookie-based sticky session load balancer
//!
//! Routes requests to the same backend based on an HMAC-signed affinity cookie.
//! Falls back to a configurable algorithm when no valid cookie is present
//! or the affinity target is unhealthy.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17use grapsus_common::errors::{GrapsusError, GrapsusResult};
18use grapsus_config::upstreams::StickySessionConfig;
19
20type HmacSha256 = Hmac<Sha256>;
21
22/// Runtime configuration for sticky sessions
23#[derive(Debug, Clone)]
24pub struct StickySessionRuntimeConfig {
25    /// Cookie name for session affinity
26    pub cookie_name: String,
27    /// Cookie TTL in seconds
28    pub cookie_ttl_secs: u64,
29    /// Cookie path
30    pub cookie_path: String,
31    /// Whether to set Secure and HttpOnly flags
32    pub cookie_secure: bool,
33    /// SameSite policy
34    pub cookie_same_site: grapsus_config::upstreams::SameSitePolicy,
35    /// HMAC key for signing cookie values
36    pub hmac_key: [u8; 32],
37}
38
39impl StickySessionRuntimeConfig {
40    /// Create runtime config from parsed config, generating HMAC key
41    pub fn from_config(config: &StickySessionConfig) -> Self {
42        use rand::Rng;
43
44        // Generate random HMAC key
45        let mut hmac_key = [0u8; 32];
46        rand::rng().fill_bytes(&mut hmac_key);
47
48        Self {
49            cookie_name: config.cookie_name.clone(),
50            cookie_ttl_secs: config.cookie_ttl_secs,
51            cookie_path: config.cookie_path.clone(),
52            cookie_secure: config.cookie_secure,
53            cookie_same_site: config.cookie_same_site,
54            hmac_key,
55        }
56    }
57}
58
59/// Cookie-based sticky session load balancer
60///
61/// This balancer wraps a fallback load balancer and adds session affinity
62/// based on cookies. When a client has a valid affinity cookie, requests
63/// are routed to the same backend. Otherwise, the fallback balancer is used
64/// and a new cookie is set.
65pub struct StickySessionBalancer {
66    /// Runtime configuration
67    config: StickySessionRuntimeConfig,
68    /// All upstream targets
69    targets: Vec<UpstreamTarget>,
70    /// Fallback load balancer
71    fallback: Arc<dyn LoadBalancer>,
72    /// Target health status
73    health_status: Arc<RwLock<HashMap<String, bool>>>,
74}
75
76impl StickySessionBalancer {
77    /// Create a new sticky session balancer
78    pub fn new(
79        targets: Vec<UpstreamTarget>,
80        config: StickySessionRuntimeConfig,
81        fallback: Arc<dyn LoadBalancer>,
82    ) -> Self {
83        trace!(
84            target_count = targets.len(),
85            cookie_name = %config.cookie_name,
86            cookie_ttl_secs = config.cookie_ttl_secs,
87            "Creating sticky session balancer"
88        );
89
90        let mut health_status = HashMap::new();
91        for target in &targets {
92            health_status.insert(target.full_address(), true);
93        }
94
95        Self {
96            config,
97            targets,
98            fallback,
99            health_status: Arc::new(RwLock::new(health_status)),
100        }
101    }
102
103    /// Extract and validate sticky cookie from request
104    ///
105    /// Returns the target index if the cookie is valid and properly signed.
106    fn extract_affinity(&self, context: &RequestContext) -> Option<usize> {
107        // Get cookie header
108        let cookie_header = context.headers.get("cookie")?;
109
110        // Parse cookies and find our sticky session cookie
111        let cookie_value = cookie_header.split(';').find_map(|cookie| {
112            let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
113            if parts.len() == 2 && parts[0] == self.config.cookie_name {
114                Some(parts[1].to_string())
115            } else {
116                None
117            }
118        })?;
119
120        // Validate cookie format: "{index}.{signature}"
121        let parts: Vec<&str> = cookie_value.splitn(2, '.').collect();
122        if parts.len() != 2 {
123            trace!(
124                cookie_value = %cookie_value,
125                "Invalid sticky cookie format (missing signature)"
126            );
127            return None;
128        }
129
130        let index: usize = parts[0].parse().ok()?;
131        let signature = parts[1];
132
133        // Verify HMAC signature
134        if !self.verify_signature(index, signature) {
135            warn!(
136                cookie_value = %cookie_value,
137                "Invalid sticky cookie signature (possible tampering)"
138            );
139            return None;
140        }
141
142        // Verify index is valid
143        if index >= self.targets.len() {
144            trace!(
145                index = index,
146                target_count = self.targets.len(),
147                "Sticky cookie index out of bounds"
148            );
149            return None;
150        }
151
152        trace!(
153            cookie_name = %self.config.cookie_name,
154            target_index = index,
155            "Extracted valid sticky session affinity"
156        );
157
158        Some(index)
159    }
160
161    /// Generate signed cookie value for target
162    pub fn generate_cookie_value(&self, target_index: usize) -> String {
163        let signature = self.sign_index(target_index);
164        format!("{}.{}", target_index, signature)
165    }
166
167    /// Generate full Set-Cookie header value
168    pub fn generate_set_cookie_header(&self, target_index: usize) -> String {
169        let cookie_value = self.generate_cookie_value(target_index);
170
171        let mut header = format!(
172            "{}={}; Path={}; Max-Age={}",
173            self.config.cookie_name,
174            cookie_value,
175            self.config.cookie_path,
176            self.config.cookie_ttl_secs
177        );
178
179        if self.config.cookie_secure {
180            header.push_str("; HttpOnly; Secure");
181        }
182
183        header.push_str(&format!("; SameSite={}", self.config.cookie_same_site));
184
185        header
186    }
187
188    /// Sign target index with HMAC-SHA256
189    fn sign_index(&self, index: usize) -> String {
190        let mut mac =
191            HmacSha256::new_from_slice(&self.config.hmac_key).expect("HMAC key length is valid");
192        mac.update(index.to_string().as_bytes());
193        let result = mac.finalize();
194        // Use first 8 bytes of signature (16 hex chars) for compactness
195        hex::encode(&result.into_bytes()[..8])
196    }
197
198    /// Verify HMAC signature for target index
199    fn verify_signature(&self, index: usize, signature: &str) -> bool {
200        let expected = self.sign_index(index);
201        // Constant-time comparison
202        expected == signature
203    }
204
205    /// Check if target at index is healthy
206    async fn is_target_healthy(&self, index: usize) -> bool {
207        if index >= self.targets.len() {
208            return false;
209        }
210
211        let target = &self.targets[index];
212        let health = self.health_status.read().await;
213        *health.get(&target.full_address()).unwrap_or(&true)
214    }
215
216    /// Find target index by address
217    fn find_target_index(&self, address: &str) -> Option<usize> {
218        self.targets
219            .iter()
220            .position(|t| t.full_address() == address)
221    }
222
223    /// Get the cookie name
224    pub fn cookie_name(&self) -> &str {
225        &self.config.cookie_name
226    }
227
228    /// Get the config for Set-Cookie header generation
229    pub fn config(&self) -> &StickySessionRuntimeConfig {
230        &self.config
231    }
232}
233
234#[async_trait]
235impl LoadBalancer for StickySessionBalancer {
236    async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
237        trace!(
238            has_context = context.is_some(),
239            cookie_name = %self.config.cookie_name,
240            "Sticky session select called"
241        );
242
243        // Try to extract affinity from cookie
244        if let Some(ctx) = context {
245            if let Some(target_index) = self.extract_affinity(ctx) {
246                // Check if target is healthy
247                if self.is_target_healthy(target_index).await {
248                    let target = &self.targets[target_index];
249
250                    debug!(
251                        target = %target.full_address(),
252                        target_index = target_index,
253                        cookie_name = %self.config.cookie_name,
254                        "Sticky session hit - routing to affinity target"
255                    );
256
257                    return Ok(TargetSelection {
258                        address: target.full_address(),
259                        weight: target.weight,
260                        metadata: {
261                            let mut meta = HashMap::new();
262                            meta.insert("sticky_session_hit".to_string(), "true".to_string());
263                            meta.insert(
264                                "sticky_target_index".to_string(),
265                                target_index.to_string(),
266                            );
267                            meta.insert("algorithm".to_string(), "sticky_session".to_string());
268                            meta
269                        },
270                    });
271                }
272
273                debug!(
274                    target_index = target_index,
275                    cookie_name = %self.config.cookie_name,
276                    "Sticky target unhealthy, falling back to load balancer"
277                );
278            }
279        }
280
281        // No valid cookie or target unavailable - use fallback
282        let mut selection = self.fallback.select(context).await?;
283
284        // Find target index for the selected address
285        let target_index = self.find_target_index(&selection.address);
286
287        if let Some(index) = target_index {
288            // Mark that we need to set a new cookie
289            selection
290                .metadata
291                .insert("sticky_session_new".to_string(), "true".to_string());
292            selection
293                .metadata
294                .insert("sticky_target_index".to_string(), index.to_string());
295            selection.metadata.insert(
296                "sticky_cookie_value".to_string(),
297                self.generate_cookie_value(index),
298            );
299            selection.metadata.insert(
300                "sticky_set_cookie_header".to_string(),
301                self.generate_set_cookie_header(index),
302            );
303
304            debug!(
305                target = %selection.address,
306                target_index = index,
307                cookie_name = %self.config.cookie_name,
308                "New sticky session assignment, will set cookie"
309            );
310        }
311
312        selection
313            .metadata
314            .insert("algorithm".to_string(), "sticky_session".to_string());
315
316        Ok(selection)
317    }
318
319    async fn report_health(&self, address: &str, healthy: bool) {
320        trace!(
321            target = %address,
322            healthy = healthy,
323            algorithm = "sticky_session",
324            "Updating target health status"
325        );
326
327        // Update local health status
328        self.health_status
329            .write()
330            .await
331            .insert(address.to_string(), healthy);
332
333        // Propagate to fallback balancer
334        self.fallback.report_health(address, healthy).await;
335    }
336
337    async fn healthy_targets(&self) -> Vec<String> {
338        // Delegate to fallback balancer for consistency
339        self.fallback.healthy_targets().await
340    }
341
342    async fn release(&self, selection: &TargetSelection) {
343        // Delegate to fallback balancer
344        self.fallback.release(selection).await;
345    }
346
347    async fn report_result(
348        &self,
349        selection: &TargetSelection,
350        success: bool,
351        latency: Option<std::time::Duration>,
352    ) {
353        // Delegate to fallback balancer
354        self.fallback
355            .report_result(selection, success, latency)
356            .await;
357    }
358
359    async fn report_result_with_latency(
360        &self,
361        address: &str,
362        success: bool,
363        latency: Option<std::time::Duration>,
364    ) {
365        // Delegate to fallback balancer
366        self.fallback
367            .report_result_with_latency(address, success, latency)
368            .await;
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback balancer stub shared by every test.
    ///
    /// `select` returns `selected` (weight 100, empty metadata). When `selected`
    /// is `None`, `select` panics; those tests expect the fallback never to run.
    struct StubBalancer {
        selected: Option<&'static str>,
    }

    #[async_trait]
    impl LoadBalancer for StubBalancer {
        async fn select(
            &self,
            _context: Option<&RequestContext>,
        ) -> GrapsusResult<TargetSelection> {
            let address = self
                .selected
                .expect("Fallback should not be called in this test");
            Ok(TargetSelection {
                address: address.to_string(),
                weight: 100,
                metadata: HashMap::new(),
            })
        }
        async fn report_health(&self, _address: &str, _healthy: bool) {}
        async fn healthy_targets(&self) -> Vec<String> {
            self.selected.iter().map(|address| address.to_string()).collect()
        }
    }

    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget {
                address: format!("10.0.0.{}", i + 1),
                port: 8080,
                weight: 100,
            })
            .collect()
    }

    fn create_test_config() -> StickySessionRuntimeConfig {
        StickySessionRuntimeConfig {
            cookie_name: "SERVERID".to_string(),
            cookie_ttl_secs: 3600,
            cookie_path: "/".to_string(),
            cookie_secure: true,
            cookie_same_site: grapsus_config::upstreams::SameSitePolicy::Lax,
            hmac_key: [42u8; 32], // Fixed key for testing
        }
    }

    /// Balancer over targets 10.0.0.1-3:8080 whose fallback picks `fallback_address`.
    fn create_test_balancer(fallback_address: Option<&'static str>) -> StickySessionBalancer {
        StickySessionBalancer::new(
            create_test_targets(3),
            create_test_config(),
            Arc::new(StubBalancer {
                selected: fallback_address,
            }),
        )
    }

    /// Request context whose `cookie` header is `cookie`, if any.
    fn create_test_context(cookie: Option<String>) -> RequestContext {
        let mut headers = HashMap::new();
        if let Some(cookie) = cookie {
            headers.insert("cookie".to_string(), cookie);
        }
        RequestContext {
            client_ip: None,
            headers,
            path: "/".to_string(),
            method: "GET".to_string(),
        }
    }

    #[test]
    fn test_cookie_generation_and_validation() {
        let balancer = create_test_balancer(None);

        let cookie_value = balancer.generate_cookie_value(1);
        assert!(cookie_value.starts_with("1."));
        assert_eq!(cookie_value.len(), 2 + 16); // "1." + 16 hex chars

        let (_, signature) = cookie_value.split_once('.').unwrap();
        assert!(balancer.verify_signature(1, signature));

        assert!(!balancer.verify_signature(1, "invalid"));
        assert!(!balancer.verify_signature(2, signature)); // Wrong index
    }

    #[test]
    fn test_signature_length_mismatch_rejected() {
        let balancer = create_test_balancer(None);
        let cookie_value = balancer.generate_cookie_value(1);
        let (_, signature) = cookie_value.split_once('.').unwrap();

        assert!(!balancer.verify_signature(1, ""));
        assert!(!balancer.verify_signature(1, &signature[..8])); // Truncated
        assert!(!balancer.verify_signature(1, &format!("{}00", signature))); // Extended
    }

    #[test]
    fn test_set_cookie_header_generation() {
        let balancer = create_test_balancer(None);

        let header = balancer.generate_set_cookie_header(0);
        assert!(header.starts_with("SERVERID=0."));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Max-Age=3600"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Lax"));
    }

    #[tokio::test]
    async fn test_sticky_session_hit() {
        // Fallback panics if consulted: a valid cookie must short-circuit it.
        let balancer = create_test_balancer(None);
        let cookie_value = balancer.generate_cookie_value(1);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should route to target 1 (10.0.0.2:8080)
        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_hit"),
            Some(&"true".to_string())
        );
        assert_eq!(
            selection.metadata.get("sticky_target_index"),
            Some(&"1".to_string())
        );
    }

    #[tokio::test]
    async fn test_sticky_cookie_found_among_other_cookies() {
        let balancer = create_test_balancer(None);
        let cookie_value = balancer.generate_cookie_value(2);
        let context = create_test_context(Some(format!(
            "session=abc; SERVERID={}; theme=dark",
            cookie_value
        )));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.3:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_hit"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_sticky_session_miss_sets_cookie() {
        let balancer = create_test_balancer(Some("10.0.0.2:8080"));
        let context = create_test_context(None);

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should use fallback and mark for cookie setting
        assert_eq!(selection.address, "10.0.0.2:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
        assert!(selection.metadata.contains_key("sticky_cookie_value"));
        assert!(selection.metadata.contains_key("sticky_set_cookie_header"));
    }

    #[tokio::test]
    async fn test_tampered_cookie_falls_back() {
        let balancer = create_test_balancer(Some("10.0.0.3:8080"));
        let context = create_test_context(Some("SERVERID=1.0000000000000000".to_string()));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.3:8080");
        assert!(!selection.metadata.contains_key("sticky_session_hit"));
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_out_of_range_index_falls_back() {
        let balancer = create_test_balancer(Some("10.0.0.1:8080"));
        // Correctly signed, but only three targets exist.
        let cookie_value = balancer.generate_cookie_value(7);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection.address, "10.0.0.1:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }

    #[tokio::test]
    async fn test_unhealthy_target_falls_back() {
        let balancer = create_test_balancer(Some("10.0.0.3:8080"));

        // Mark target 1 as unhealthy, then present a valid cookie for it
        balancer.report_health("10.0.0.2:8080", false).await;
        let cookie_value = balancer.generate_cookie_value(1);
        let context = create_test_context(Some(format!("SERVERID={}", cookie_value)));

        let selection = balancer.select(Some(&context)).await.unwrap();

        // Should fall back to another target and set new cookie
        assert_eq!(selection.address, "10.0.0.3:8080");
        assert_eq!(
            selection.metadata.get("sticky_session_new"),
            Some(&"true".to_string())
        );
    }
}