//! Operator-facing configuration for the wake server.

use serde::{Deserialize, Serialize};
use std::net::IpAddr;

6/// Operator-facing config block. Bind IP is inherited from the top-level
7/// `bind_addr`, not duplicated here.
8#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
9#[serde(deny_unknown_fields)]
10pub struct WakeServerConfig {
11	#[serde(default)]
12	pub enable: bool,
13
14	#[serde(default = "default_middleman_port")]
15	pub middleman_port: u16,
16
17	#[serde(default = "default_register_port")]
18	pub register_port: u16,
19
20	#[serde(default = "default_heartbeat_ms")]
21	pub heartbeat_ms: u32,
22
23	#[serde(default = "default_stale_after_ms")]
24	pub stale_after_ms: u64,
25
26	/// Bind IP, populated from the top-level `bind_addr` when the binary
27	/// constructs a `RuntimeConfig` (see [`RuntimeConfig::from_block`]).
28	/// Skipped at TOML parse time so operators never set it directly.
29	#[serde(skip)]
30	pub bind: Option<IpAddr>,
31}
32
33impl Default for WakeServerConfig {
34	fn default() -> Self {
35		Self {
36			enable: false,
37			middleman_port: default_middleman_port(),
38			register_port: default_register_port(),
39			heartbeat_ms: default_heartbeat_ms(),
40			stale_after_ms: default_stale_after_ms(),
41			bind: None,
42		}
43	}
44}
45
/// Serde default for `WakeServerConfig::middleman_port`.
fn default_middleman_port() -> u16 {
	const MIDDLEMAN_PORT: u16 = 9_999;
	MIDDLEMAN_PORT
}
/// Serde default for `WakeServerConfig::register_port`.
fn default_register_port() -> u16 {
	const REGISTER_PORT: u16 = 58_200;
	REGISTER_PORT
}
/// Serde default for `WakeServerConfig::heartbeat_ms`.
fn default_heartbeat_ms() -> u32 {
	const HEARTBEAT_MS: u32 = 20_000;
	HEARTBEAT_MS
}
/// Serde default for `WakeServerConfig::stale_after_ms`.
fn default_stale_after_ms() -> u64 {
	const STALE_AFTER_MS: u64 = 80_000;
	STALE_AFTER_MS
}

/// Validated runtime view consumed by `run()`. Constructed by the binary
/// from a `WakeServerConfig` plus the top-level `bind_addr`.
///
/// When built through [`RuntimeConfig::from_block`], the following hold:
/// - both ports are non-zero and distinct;
/// - `heartbeat_ms >= 1000`;
/// - `stale_after_ms >= heartbeat_ms`.
///
/// The fields are `pub`, so a value built or mutated directly skips those
/// checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeConfig {
	/// IP to bind, taken from the top-level `bind_addr`.
	pub bind: IpAddr,
	/// Validated middleman port (non-zero, distinct from `register_port`).
	pub middleman_port: u16,
	/// Validated register port (non-zero, distinct from `middleman_port`).
	pub register_port: u16,
	/// Validated heartbeat value (>= 1000).
	pub heartbeat_ms: u32,
	/// Validated staleness threshold (>= `heartbeat_ms`).
	pub stale_after_ms: u64,
}

70impl RuntimeConfig {
71	/// Build a runtime config from a parsed `[wake_server]` block plus a
72	/// resolved bind IP. Returns `Err(_)` with a human-readable message on
73	/// validation failure.
74	pub fn from_block(block: &WakeServerConfig, bind: IpAddr) -> Result<Self, String> {
75		if block.middleman_port == 0 {
76			return Err("middleman_port must be > 0".into());
77		}
78		if block.register_port == 0 {
79			return Err("register_port must be > 0".into());
80		}
81		if block.middleman_port == block.register_port {
82			return Err(format!(
83				"middleman_port and register_port must differ (both {})",
84				block.middleman_port
85			));
86		}
87		if block.heartbeat_ms < 1000 {
88			return Err(format!(
89				"heartbeat_ms must be >= 1000 (got {})",
90				block.heartbeat_ms
91			));
92		}
93		if block.stale_after_ms < block.heartbeat_ms as u64 {
94			return Err(format!(
95				"stale_after_ms ({}) must be >= heartbeat_ms ({})",
96				block.stale_after_ms, block.heartbeat_ms
97			));
98		}
99		Ok(Self {
100			bind,
101			middleman_port: block.middleman_port,
102			register_port: block.register_port,
103			heartbeat_ms: block.heartbeat_ms,
104			stale_after_ms: block.stale_after_ms,
105		})
106	}
107}
108
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv4Addr;

	fn loopback() -> IpAddr {
		Ipv4Addr::LOCALHOST.into()
	}

	#[test]
	fn defaults_match_spec() {
		let cfg: WakeServerConfig = toml::from_str("").unwrap();
		assert!(!cfg.enable);
		assert_eq!(cfg.middleman_port, 9999);
		assert_eq!(cfg.register_port, 58200);
		assert_eq!(cfg.heartbeat_ms, 20000);
		assert_eq!(cfg.stale_after_ms, 80000);
		assert_eq!(cfg.bind, None);
	}

	/// An empty TOML block and `Default::default()` must produce the same
	/// config. This keeps the serde `default` attributes and the hand-written
	/// `Default` impl in sync.
	#[test]
	fn serde_defaults_agree_with_default_impl() {
		let cfg: WakeServerConfig = toml::from_str("").unwrap();
		assert_eq!(cfg, WakeServerConfig::default());
	}

	#[test]
	fn deny_unknown_fields() {
		let result: Result<WakeServerConfig, _> = toml::from_str("totally_made_up_field = 1");
		assert!(result.is_err());
	}

	/// `bind` is `#[serde(skip)]`, so with `deny_unknown_fields` operators
	/// cannot set it from TOML.
	#[test]
	fn bind_cannot_be_set_from_toml() {
		let result: Result<WakeServerConfig, _> = toml::from_str("bind = \"127.0.0.1\"");
		assert!(result.is_err());
	}

	#[test]
	fn runtime_rejects_zero_ports() {
		let zero_middleman = WakeServerConfig {
			middleman_port: 0,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&zero_middleman, loopback()).is_err());

		let zero_register = WakeServerConfig {
			register_port: 0,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&zero_register, loopback()).is_err());
	}

	#[test]
	fn from_block_zero_register_port_message_mentions_register_port() {
		let block = WakeServerConfig {
			register_port: 0,
			..WakeServerConfig::default()
		};
		let err = RuntimeConfig::from_block(&block, loopback())
			.expect_err("zero register_port must error");
		assert!(
			err.contains("register_port"),
			"expected error to mention register_port, got: {err}"
		);
	}

	#[test]
	fn runtime_rejects_equal_ports() {
		let block = WakeServerConfig {
			middleman_port: 5000,
			register_port: 5000,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&block, loopback()).is_err());
	}

	#[test]
	fn runtime_rejects_low_heartbeat() {
		let block = WakeServerConfig {
			heartbeat_ms: 500,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&block, loopback()).is_err());
	}

	#[test]
	fn runtime_rejects_stale_below_heartbeat() {
		let block = WakeServerConfig {
			heartbeat_ms: 5000,
			stale_after_ms: 1000,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&block, loopback()).is_err());
	}

	/// Both lower bounds are inclusive: `heartbeat_ms == 1000` and
	/// `stale_after_ms == heartbeat_ms` are accepted.
	#[test]
	fn runtime_accepts_inclusive_boundaries() {
		let block = WakeServerConfig {
			heartbeat_ms: 1000,
			stale_after_ms: 1000,
			..WakeServerConfig::default()
		};
		assert!(RuntimeConfig::from_block(&block, loopback()).is_ok());
	}

	#[test]
	fn runtime_accepts_defaults_with_loopback_bind() {
		let block = WakeServerConfig::default();
		let rt = RuntimeConfig::from_block(&block, loopback()).unwrap();
		assert_eq!(rt.bind, loopback());
		assert_eq!(rt.middleman_port, 9999);
		assert_eq!(rt.register_port, 58200);
		assert_eq!(rt.heartbeat_ms, 20000);
		assert_eq!(rt.stale_after_ms, 80000);
	}
}