//! Routing tables for gateway pools: types, JSON (de)serialization, and validation
//! (`mvm_core/routing.rs`).

use std::collections::HashSet;
use std::net::IpAddr;

use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};

6/// A routing table for a gateway pool's instances.
7/// Maps inbound traffic (by port, path prefix, or source) to target worker instances.
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
9pub struct RoutingTable {
10    pub routes: Vec<Route>,
11}
12
13/// A single routing rule: match inbound traffic and forward to a target.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Route {
16    #[serde(default)]
17    pub name: String,
18    pub match_rule: MatchRule,
19    pub target: RouteTarget,
20}
21
22/// Criteria for matching inbound traffic.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct MatchRule {
25    #[serde(default)]
26    pub port: Option<u16>,
27    #[serde(default)]
28    pub path_prefix: Option<String>,
29    #[serde(default)]
30    pub source_cidr: Option<String>,
31}
32
33/// Target for matched traffic.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct RouteTarget {
36    pub pool_id: String,
37    #[serde(default)]
38    pub instance_selector: InstanceSelector,
39    #[serde(default)]
40    pub target_port: Option<u16>,
41}
42
43/// Strategy for selecting an instance within the target pool.
44#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46pub enum InstanceSelector {
47    #[default]
48    Any,
49    ByIp(String),
50    LeastConnections,
51}
52
53impl RoutingTable {
54    pub fn from_json(json: &str) -> Result<Self> {
55        let table: Self = serde_json::from_str(json)?;
56        table.validate()?;
57        Ok(table)
58    }
59
60    pub fn to_json(&self) -> Result<String> {
61        Ok(serde_json::to_string_pretty(self)?)
62    }
63
64    pub fn validate(&self) -> Result<()> {
65        let mut seen_ports: HashSet<u16> = HashSet::new();
66
67        for (i, route) in self.routes.iter().enumerate() {
68            if route.match_rule.port.is_none()
69                && route.match_rule.path_prefix.is_none()
70                && route.match_rule.source_cidr.is_none()
71            {
72                bail!(
73                    "Route {} ({}) has no match criteria — at least one of port, path_prefix, or source_cidr must be set",
74                    i,
75                    route.name,
76                );
77            }
78
79            if let Some(port) = route.match_rule.port
80                && !seen_ports.insert(port)
81            {
82                bail!(
83                    "Route {} ({}) has duplicate port {}: another route already matches this port",
84                    i,
85                    route.name,
86                    port,
87                );
88            }
89        }
90
91        Ok(())
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a route to `pool_id` with the default selector and no target port.
    fn route(name: &str, match_rule: MatchRule, pool_id: &str) -> Route {
        Route {
            name: name.to_string(),
            match_rule,
            target: RouteTarget {
                pool_id: pool_id.to_string(),
                instance_selector: InstanceSelector::Any,
                target_port: None,
            },
        }
    }

    /// Match criteria consisting of the given port only.
    fn port_only(port: u16) -> MatchRule {
        MatchRule {
            port: Some(port),
            path_prefix: None,
            source_cidr: None,
        }
    }

    /// Match criteria consisting of the given path prefix only.
    fn path_only(prefix: &str) -> MatchRule {
        MatchRule {
            port: None,
            path_prefix: Some(prefix.to_string()),
            source_cidr: None,
        }
    }

    #[test]
    fn test_routing_table_serde_roundtrip() {
        let table = RoutingTable {
            routes: vec![
                Route {
                    name: "slack-webhook".to_string(),
                    match_rule: MatchRule {
                        port: Some(8080),
                        path_prefix: Some("/webhook/slack".to_string()),
                        source_cidr: None,
                    },
                    target: RouteTarget {
                        pool_id: "workers".to_string(),
                        instance_selector: InstanceSelector::Any,
                        target_port: Some(8080),
                    },
                },
                Route {
                    name: "telegram-bot".to_string(),
                    match_rule: MatchRule {
                        port: Some(8443),
                        path_prefix: None,
                        source_cidr: Some("149.154.160.0/20".to_string()),
                    },
                    target: RouteTarget {
                        pool_id: "workers".to_string(),
                        instance_selector: InstanceSelector::ByIp("10.240.3.5".to_string()),
                        target_port: None,
                    },
                },
            ],
        };

        let json = table.to_json().unwrap();
        let parsed = RoutingTable::from_json(&json).unwrap();
        assert_eq!(parsed.routes.len(), 2);

        // Check every field so a serde attribute regression on any of them is caught.
        let slack = &parsed.routes[0];
        assert_eq!(slack.name, "slack-webhook");
        assert_eq!(slack.match_rule.port, Some(8080));
        assert_eq!(slack.match_rule.path_prefix.as_deref(), Some("/webhook/slack"));
        assert_eq!(slack.match_rule.source_cidr, None);
        assert_eq!(slack.target.pool_id, "workers");
        assert_eq!(slack.target.instance_selector, InstanceSelector::Any);
        assert_eq!(slack.target.target_port, Some(8080));

        let telegram = &parsed.routes[1];
        assert_eq!(telegram.name, "telegram-bot");
        assert_eq!(telegram.match_rule.port, Some(8443));
        assert_eq!(telegram.match_rule.path_prefix, None);
        assert_eq!(
            telegram.match_rule.source_cidr.as_deref(),
            Some("149.154.160.0/20")
        );
        assert_eq!(telegram.target.pool_id, "workers");
        assert_eq!(
            telegram.target.instance_selector,
            InstanceSelector::ByIp("10.240.3.5".to_string())
        );
        assert_eq!(telegram.target.target_port, None);
    }

    #[test]
    fn test_empty_routing_table() {
        let table = RoutingTable::default();
        assert!(table.validate().is_ok());
        let json = table.to_json().unwrap();
        let parsed = RoutingTable::from_json(&json).unwrap();
        assert!(parsed.routes.is_empty());
    }

    #[test]
    fn test_validation_rejects_empty_match() {
        let no_criteria = MatchRule {
            port: None,
            path_prefix: None,
            source_cidr: None,
        };
        let table = RoutingTable {
            routes: vec![route("bad-route", no_criteria, "workers")],
        };
        let err = table.validate().unwrap_err();
        assert!(err.to_string().contains("no match criteria"));
    }

    #[test]
    fn test_validation_rejects_duplicate_port() {
        let table = RoutingTable {
            routes: vec![
                route("first", port_only(8080), "workers"),
                route("second", port_only(8080), "other"),
            ],
        };
        let msg = table.validate().unwrap_err().to_string();
        assert!(msg.contains("duplicate port 8080"));
        // The error points at the later route, not the first claimant.
        assert!(msg.contains("Route 1 (second)"));
    }

    #[test]
    fn test_routes_without_port_do_not_conflict() {
        let table = RoutingTable {
            routes: vec![
                route("a", path_only("/a"), "workers"),
                route("b", path_only("/b"), "workers"),
            ],
        };
        assert!(table.validate().is_ok());
    }

    #[test]
    fn test_instance_selector_serde() {
        let variants = [
            (InstanceSelector::Any, "\"any\""),
            (
                InstanceSelector::ByIp("10.0.0.1".to_string()),
                "{\"by_ip\":\"10.0.0.1\"}",
            ),
            (InstanceSelector::LeastConnections, "\"least_connections\""),
        ];

        for (selector, expected) in &variants {
            let json = serde_json::to_string(selector).unwrap();
            assert_eq!(&json, expected);
            let parsed: InstanceSelector = serde_json::from_str(&json).unwrap();
            assert_eq!(&parsed, selector);
        }
    }

    #[test]
    fn test_instance_selector_default_is_any() {
        assert_eq!(InstanceSelector::default(), InstanceSelector::Any);
    }

    #[test]
    fn test_route_with_path_prefix_only() {
        let mut api = route("api", path_only("/api/v1"), "workers");
        api.target.instance_selector = InstanceSelector::LeastConnections;
        api.target.target_port = Some(3000);
        let table = RoutingTable { routes: vec![api] };
        assert!(table.validate().is_ok());
    }

    #[test]
    fn test_route_with_source_cidr_only() {
        let cidr_only = MatchRule {
            port: None,
            path_prefix: None,
            source_cidr: Some("10.0.0.0/8".to_string()),
        };
        let table = RoutingTable {
            routes: vec![route("trusted", cidr_only, "internal")],
        };
        assert!(table.validate().is_ok());
    }

    #[test]
    fn test_from_json_applies_field_defaults() {
        let table = RoutingTable::from_json(
            r#"{"routes": [{"match_rule": {"port": 80}, "target": {"pool_id": "workers"}}]}"#,
        )
        .unwrap();
        let only = &table.routes[0];
        assert_eq!(only.name, "");
        assert_eq!(only.match_rule.port, Some(80));
        assert_eq!(only.match_rule.path_prefix, None);
        assert_eq!(only.match_rule.source_cidr, None);
        assert_eq!(only.target.instance_selector, InstanceSelector::Any);
        assert_eq!(only.target.target_port, None);
    }

    #[test]
    fn test_from_json_rejects_invalid_input() {
        // `from_json` must run validation, not just deserialization.
        let err = RoutingTable::from_json(
            r#"{"routes": [{"name": "empty", "match_rule": {}, "target": {"pool_id": "workers"}}]}"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("no match criteria"));

        assert!(RoutingTable::from_json("not json").is_err());
    }

    /// An explicit empty `routes` array parses to an empty table.
    #[test]
    fn test_backward_compat_no_routing_table() {
        let json = r#"{"routes": []}"#;
        let parsed: RoutingTable = serde_json::from_str(json).unwrap();
        assert!(parsed.routes.is_empty());
    }
}