// SPDX-License-Identifier: LGPL-3.0-only
// Copyright (c) 2024 Shane Utt

//! Load-balancing strategy types for upstream clusters.

use serde::{Deserialize, Serialize};

// -----------------------------------------------------------------------------
// LoadBalancerStrategy
// -----------------------------------------------------------------------------

12/// Load-balancing algorithm used by a cluster.
13#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
14#[serde(untagged)]
15pub enum LoadBalancerStrategy {
16    /// Plain-string strategies: `"round_robin"` or `"least_connections"`.
17    Simple(SimpleStrategy),
18
19    /// Consistent-hash strategy with an optional hash-key header.
20    Parameterised(ParameterisedStrategy),
21}
22
23impl Default for LoadBalancerStrategy {
24    fn default() -> Self {
25        Self::Simple(SimpleStrategy::RoundRobin)
26    }
27}
28
29/// String-serialisable load-balancing strategies.
30#[derive(Debug, Clone, Default, Deserialize, PartialEq, Eq, Serialize)]
31#[serde(rename_all = "snake_case")]
32pub enum SimpleStrategy {
33    /// Cycle through endpoints in order, respecting weights.
34    #[default]
35    RoundRobin,
36
37    /// Pick the endpoint with the fewest active in-flight requests.
38    LeastConnections,
39}
40
41/// Load-balancing strategies that carry parameters.
42#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
43pub enum ParameterisedStrategy {
44    /// Hash a request attribute to route requests to a stable endpoint.
45    #[serde(rename = "consistent_hash")]
46    ConsistentHash(ConsistentHashOpts),
47}
48
49/// Options for the `consistent_hash` load-balancing strategy.
50#[derive(Debug, Clone, Default, Deserialize, PartialEq, Eq, Serialize)]
51pub struct ConsistentHashOpts {
52    /// Name of the request header to use as the hash key.
53    ///
54    /// Falls back to the request URI path when the header is absent or when this field is `None`.
55    #[serde(default)]
56    pub header: Option<String>,
57}
58
// -----------------------------------------------------------------------------
// Tests
// -----------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn load_balancer_strategy_defaults_to_round_robin() {
        assert_eq!(
            LoadBalancerStrategy::default(),
            LoadBalancerStrategy::Simple(SimpleStrategy::RoundRobin),
            "default strategy should be round_robin"
        );
    }

    #[test]
    fn load_balancer_strategy_parses_round_robin() {
        let yaml = "round_robin";
        let strategy: LoadBalancerStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            strategy,
            LoadBalancerStrategy::Simple(SimpleStrategy::RoundRobin),
            "should parse 'round_robin' string"
        );
    }

    #[test]
    fn load_balancer_strategy_parses_least_connections() {
        let yaml = "least_connections";
        let strategy: LoadBalancerStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            strategy,
            LoadBalancerStrategy::Simple(SimpleStrategy::LeastConnections),
            "should parse 'least_connections' string"
        );
    }

    #[test]
    fn load_balancer_strategy_parses_consistent_hash() {
        let yaml = r#"
consistent_hash:
  header: "X-User-Id"
"#;
        let strategy: LoadBalancerStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            strategy,
            LoadBalancerStrategy::Parameterised(ParameterisedStrategy::ConsistentHash(ConsistentHashOpts {
                header: Some("X-User-Id".into()),
            })),
            "should parse consistent_hash with header"
        );
    }

    #[test]
    fn load_balancer_strategy_parses_consistent_hash_without_header() {
        let yaml = "consistent_hash: {}";
        let strategy: LoadBalancerStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            strategy,
            LoadBalancerStrategy::Parameterised(ParameterisedStrategy::ConsistentHash(ConsistentHashOpts {
                header: None,
            })),
            "should parse consistent_hash with no header"
        );
    }

    #[test]
    fn load_balancer_strategy_rejects_unknown_strategy() {
        let result = serde_yaml::from_str::<LoadBalancerStrategy>("random");
        assert!(
            result.is_err(),
            "unknown strategy names should fail to parse, got {result:?}"
        );
    }

    #[test]
    fn load_balancer_strategy_rejects_bare_consistent_hash() {
        // `consistent_hash` carries options, so it must be written as a map
        // (`consistent_hash: {}`), not as a plain string.
        let result = serde_yaml::from_str::<LoadBalancerStrategy>("consistent_hash");
        assert!(
            result.is_err(),
            "bare 'consistent_hash' string should fail to parse, got {result:?}"
        );
    }
}