Skip to main content

praxis_core/config/cluster/
endpoint.rs

1// SPDX-License-Identifier: LGPL-3.0-only
2// Copyright (c) 2024 Shane Utt
3
4//! Upstream endpoint definition with optional weighting.
5
6use serde::{Deserialize, Serialize};
7
8// -----------------------------------------------------------------------------
9// Endpoint
10// -----------------------------------------------------------------------------
11
12/// A single upstream endpoint, with an optional forwarding weight.
13///
14/// Accepts either a plain `"host:port"` string (weight defaults to 1) or an
15/// object with an explicit `weight` field:
16///
17/// ```yaml
18/// endpoints:
19///   - "10.0.0.1:8080"
20///   - address: "10.0.0.2:8080"
21///     weight: 3
22/// ```
23#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(untagged)]
25pub enum Endpoint {
26    /// Plain `host:port` string; weight is implicitly 1.
27    Simple(String),
28
29    /// Endpoint with an explicit address and forwarding weight.
30    Weighted {
31        /// Socket address as `host:port`.
32        address: String,
33
34        /// Relative forwarding weight. Higher values receive proportionally more
35        /// traffic. Defaults to 1.
36        #[serde(default = "default_weight")]
37        weight: u32,
38    },
39}
40
/// Weight assumed when a weighted endpoint omits its `weight` key.
const DEFAULT_WEIGHT: u32 = 1;

/// Serde `default` hook for the `weight` field of [`Endpoint::Weighted`].
///
/// Always returns [`DEFAULT_WEIGHT`].
fn default_weight() -> u32 {
    DEFAULT_WEIGHT
}
45
46impl Endpoint {
47    /// Returns the `host:port` address string.
48    ///
49    /// ```
50    /// use praxis_core::config::Endpoint;
51    ///
52    /// let simple: Endpoint = "10.0.0.1:8080".into();
53    /// assert_eq!(simple.address(), "10.0.0.1:8080");
54    /// ```
55    pub fn address(&self) -> &str {
56        match self {
57            Self::Simple(s) => s,
58            Self::Weighted { address, .. } => address,
59        }
60    }
61
62    /// Returns the forwarding weight (1 for `Simple` endpoints).
63    ///
64    /// ```
65    /// use praxis_core::config::Endpoint;
66    ///
67    /// let simple: Endpoint = "10.0.0.1:8080".into();
68    /// assert_eq!(simple.weight(), 1);
69    /// ```
70    pub fn weight(&self) -> u32 {
71        match self {
72            Self::Simple(_) => 1,
73            Self::Weighted { weight, .. } => *weight,
74        }
75    }
76}
77
78impl From<String> for Endpoint {
79    fn from(s: String) -> Self {
80        Self::Simple(s)
81    }
82}
83
84impl From<&str> for Endpoint {
85    fn from(s: &str) -> Self {
86        Self::Simple(s.to_owned())
87    }
88}
89
90// -----------------------------------------------------------------------------
91// Tests
92// -----------------------------------------------------------------------------
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_endpoint_has_weight_one() {
        let ep: Endpoint = "10.0.0.1:8080".into();
        assert_eq!(ep.address(), "10.0.0.1:8080", "simple endpoint address mismatch");
        assert_eq!(ep.weight(), 1, "simple endpoint should default to weight 1");
    }

    #[test]
    fn weighted_endpoint_preserves_weight() {
        let yaml = r#"
address: "10.0.0.2:8080"
weight: 3
"#;
        let ep: Endpoint = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(ep.address(), "10.0.0.2:8080", "weighted endpoint address mismatch");
        assert_eq!(ep.weight(), 3, "weighted endpoint should preserve configured weight");
    }

    #[test]
    fn weighted_endpoint_defaults_weight_to_one() {
        let yaml = "address: \"10.0.0.1:80\"";
        let ep: Endpoint = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(ep.weight(), 1, "omitted weight should default to 1");
    }

    #[test]
    fn from_string() {
        let ep = Endpoint::from("10.0.0.1:80".to_owned());
        assert_eq!(ep.address(), "10.0.0.1:80", "From<String> should preserve address");
    }

    #[test]
    fn from_str_and_from_string_agree() {
        assert_eq!(
            Endpoint::from("10.0.0.1:80"),
            Endpoint::from("10.0.0.1:80".to_owned()),
            "From<&str> and From<String> should build the same endpoint"
        );
    }

    #[test]
    fn parse_mixed_list() {
        let yaml = r#"
- "10.0.0.1:8080"
- address: "10.0.0.2:8080"
  weight: 3
"#;
        let eps: Vec<Endpoint> = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(eps.len(), 2, "mixed list should parse two endpoints");
        assert_eq!(eps[0].weight(), 1, "simple entry should have weight 1");
        assert_eq!(eps[1].weight(), 3, "weighted entry should have weight 3");
    }

    // The untagged representation must serialize `Simple` as a bare scalar,
    // not a map, so configs written back out keep the short form.
    #[test]
    fn simple_endpoint_serializes_as_plain_string() {
        let ep: Endpoint = "10.0.0.1:8080".into();
        let yaml = serde_yaml::to_string(&ep).unwrap();

        let as_string: String = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(
            as_string, "10.0.0.1:8080",
            "simple endpoint should serialize as a bare string"
        );

        let back: Endpoint = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(back, ep, "simple endpoint should survive a serialize/deserialize round trip");
    }

    #[test]
    fn weighted_endpoint_round_trips() {
        let ep = Endpoint::Weighted { address: "10.0.0.2:8080".to_owned(), weight: 3 };
        let yaml = serde_yaml::to_string(&ep).unwrap();
        let back: Endpoint = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(back, ep, "weighted endpoint should survive a serialize/deserialize round trip");
    }
}