// praxis_core/config/cluster/mod.rs

// SPDX-License-Identifier: LGPL-3.0-only
// Copyright (c) 2024 Shane Utt

//! Upstream cluster definitions: endpoints, load-balancing strategies, and timeouts.

6mod endpoint;
7mod health_check;
8mod load_balancer_strategy;
9
10use std::sync::Arc;
11
12pub use endpoint::Endpoint;
13pub use health_check::{HealthCheckConfig, HealthCheckType};
14pub use load_balancer_strategy::{ConsistentHashOpts, LoadBalancerStrategy, ParameterisedStrategy, SimpleStrategy};
15use serde::{Deserialize, Serialize};
16
// -----------------------------------------------------------------------------
// Cluster
// -----------------------------------------------------------------------------

21/// A named group of upstream endpoints.
22///
23/// ```
24/// # use praxis_core::config::Cluster;
25/// let yaml = r#"
26/// name: "backend"
27/// endpoints: ["10.0.0.1:8080"]
28/// connection_timeout_ms: 5000
29/// idle_timeout_ms: 30000
30/// "#;
31/// let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
32/// assert_eq!(cluster.connection_timeout_ms, Some(5000));
33/// assert_eq!(cluster.idle_timeout_ms, Some(30000));
34/// assert!(cluster.read_timeout_ms.is_none());
35/// assert!(cluster.tls.is_none());
36/// ```
37#[derive(Debug, Clone, Deserialize, Serialize)]
38pub struct Cluster {
39    /// Unique name for the cluster.
40    pub name: Arc<str>,
41
42    /// TCP connection timeout in milliseconds.
43    #[serde(default)]
44    pub connection_timeout_ms: Option<u64>,
45
46    /// List of endpoints for the cluster. Each entry is either a plain
47    /// `"host:port"` string or a `{ address, weight }` object.
48    pub endpoints: Vec<Endpoint>,
49
50    /// Active health check configuration for this cluster.
51    #[serde(default)]
52    pub health_check: Option<HealthCheckConfig>,
53
54    /// Idle connection timeout in milliseconds.
55    #[serde(default)]
56    pub idle_timeout_ms: Option<u64>,
57
58    /// Load-balancing algorithm for this cluster. Defaults to `round_robin`.
59    #[serde(default)]
60    pub load_balancer_strategy: LoadBalancerStrategy,
61
62    /// Read timeout in milliseconds.
63    #[serde(default)]
64    pub read_timeout_ms: Option<u64>,
65
66    /// TLS settings for upstream connections.
67    ///
68    /// Presence implies TLS is enabled. Omit for plaintext HTTP.
69    #[serde(default)]
70    pub tls: Option<praxis_tls::ClusterTls>,
71
72    /// Total connection timeout in milliseconds (TCP + TLS).
73    #[serde(default)]
74    pub total_connection_timeout_ms: Option<u64>,
75
76    /// Write timeout in milliseconds.
77    #[serde(default)]
78    pub write_timeout_ms: Option<u64>,
79}
80
81impl Cluster {
82    /// Build a cluster with only a name and endpoints; all other
83    /// fields use their defaults (no timeouts, no TLS, no health
84    /// check, `round_robin` strategy).
85    ///
86    /// ```
87    /// use praxis_core::config::Cluster;
88    /// use praxis_tls::ClusterTls;
89    ///
90    /// let c = Cluster {
91    ///     tls: Some(ClusterTls::default()),
92    ///     ..Cluster::with_defaults("backend", vec!["10.0.0.1:443".into()])
93    /// };
94    /// assert_eq!(&*c.name, "backend");
95    /// assert!(c.tls.is_some());
96    /// assert!(c.tls.as_ref().unwrap().verify);
97    /// ```
98    pub fn with_defaults(name: &str, endpoints: Vec<Endpoint>) -> Self {
99        Self {
100            connection_timeout_ms: None,
101            endpoints,
102            health_check: None,
103            idle_timeout_ms: None,
104            load_balancer_strategy: LoadBalancerStrategy::default(),
105            name: Arc::from(name),
106            read_timeout_ms: None,
107            tls: None,
108            total_connection_timeout_ms: None,
109            write_timeout_ms: None,
110        }
111    }
112}
113
// -----------------------------------------------------------------------------
// Tests
// -----------------------------------------------------------------------------

/// Unit tests for [`Cluster`] deserialization, defaults, and serde round-tripping.
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `name` and `endpoints` are supplied; everything else must default.
    #[test]
    fn parse_cluster_minimal() {
        let yaml = r#"
name: "backend"
endpoints: ["10.0.0.1:8080"]
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(&*cluster.name, "backend", "cluster name mismatch");
        assert_eq!(
            cluster.endpoints[0].address(),
            "10.0.0.1:8080",
            "endpoint address mismatch"
        );
        assert_eq!(cluster.endpoints[0].weight(), 1, "default weight should be 1");
        assert_eq!(
            cluster.load_balancer_strategy,
            LoadBalancerStrategy::default(),
            "strategy should default"
        );
        assert!(
            cluster.connection_timeout_ms.is_none(),
            "connection_timeout should default to None"
        );
        assert!(cluster.idle_timeout_ms.is_none(), "idle_timeout should default to None");
        assert!(cluster.read_timeout_ms.is_none(), "read_timeout should default to None");
        assert!(cluster.write_timeout_ms.is_none(), "write_timeout should default to None");
        assert!(
            cluster.total_connection_timeout_ms.is_none(),
            "total_connection_timeout should default to None"
        );
        assert!(cluster.health_check.is_none(), "health_check should default to None");
        assert!(cluster.tls.is_none(), "tls should default to None");
    }

    /// Endpoints may mix the plain-string and `{ address, weight }` forms.
    #[test]
    fn parse_cluster_with_weights() {
        let yaml = r#"
name: "backend"
endpoints:
  - "10.0.0.1:8080"
  - address: "10.0.0.2:8080"
    weight: 3
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cluster.endpoints.len(), 2, "should parse two endpoints");
        assert_eq!(cluster.endpoints[0].weight(), 1, "simple endpoint weight should be 1");
        assert_eq!(cluster.endpoints[1].weight(), 3, "weighted endpoint weight should be 3");
    }

    /// Every timeout key, including `total_connection_timeout_ms`, is parsed.
    #[test]
    fn parse_cluster_with_timeouts() {
        let yaml = r#"
name: "backend"
endpoints: ["10.0.0.1:8080"]
connection_timeout_ms: 5000
idle_timeout_ms: 30000
read_timeout_ms: 10000
total_connection_timeout_ms: 7000
write_timeout_ms: 10000
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(
            cluster.connection_timeout_ms,
            Some(5000),
            "connection_timeout_ms mismatch"
        );
        assert_eq!(cluster.idle_timeout_ms, Some(30000), "idle_timeout_ms mismatch");
        assert_eq!(cluster.read_timeout_ms, Some(10000), "read_timeout_ms mismatch");
        assert_eq!(
            cluster.total_connection_timeout_ms,
            Some(7000),
            "total_connection_timeout_ms mismatch"
        );
        assert_eq!(cluster.write_timeout_ms, Some(10000), "write_timeout_ms mismatch");
    }

    /// Serializing and then deserializing preserves the populated fields.
    #[test]
    fn cluster_roundtrips_via_serde() {
        let cluster = Cluster {
            connection_timeout_ms: Some(1000),
            ..Cluster::with_defaults("web", vec!["10.0.0.1:80".into()])
        };
        let value = serde_yaml::to_value(&cluster).unwrap();
        let back: Cluster = serde_yaml::from_value(value).unwrap();
        assert_eq!(back.name, cluster.name, "name should roundtrip");
        assert_eq!(back.endpoints, cluster.endpoints, "endpoints should roundtrip");
        assert_eq!(
            back.connection_timeout_ms, cluster.connection_timeout_ms,
            "timeout should roundtrip"
        );
    }

    /// A `tls` mapping with an `sni` key yields `Some` TLS config carrying that SNI.
    #[test]
    fn tls_and_sni_parse_correctly() {
        let yaml = r#"
name: "backend"
endpoints: ["10.0.0.1:443"]
tls:
  sni: "api.example.com"
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert!(cluster.tls.is_some(), "tls should be present");
        assert_eq!(
            cluster.tls.as_ref().unwrap().sni.as_deref(),
            Some("api.example.com"),
            "sni mismatch"
        );
    }

    /// An empty `tls` mapping enables TLS with certificate verification on.
    #[test]
    fn tls_verify_defaults_to_true() {
        let yaml = r#"
name: "backend"
endpoints: ["10.0.0.1:443"]
tls: {}
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert!(cluster.tls.as_ref().unwrap().verify, "verify should default to true");
    }

    /// `verify: false` is honoured when set explicitly.
    #[test]
    fn tls_verify_can_be_disabled() {
        let yaml = r#"
name: "backend"
endpoints: ["10.0.0.1:443"]
tls:
  verify: false
"#;
        let cluster: Cluster = serde_yaml::from_str(yaml).unwrap();
        assert!(
            !cluster.tls.as_ref().unwrap().verify,
            "verify should be false when explicitly set"
        );
    }

    /// `with_defaults` leaves TLS unset.
    #[test]
    fn no_tls_by_default() {
        let cluster = Cluster::with_defaults("web", vec!["10.0.0.1:80".into()]);
        assert!(cluster.tls.is_none(), "tls should be None by default");
    }
}