1use crate::control::protocol::{Request, Response};
9use crate::gateway::pool::{MappingInfo, MappingState, PoolStatus};
10use std::path::{Path, PathBuf};
11use std::time::Instant;
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
13use tokio::net::UnixListener;
14use tokio::sync::watch;
15use tracing::{debug, info, warn};
16
/// Filesystem path of the Unix domain socket the gateway control server binds.
pub const GATEWAY_SOCKET_PATH: &str = "/run/fips/gateway.sock";

/// Upper bound, in bytes, on a single request line; longer requests are
/// answered with a "request too large" error.
const MAX_REQUEST_SIZE: usize = 4096;

/// Time budget applied separately to reading the request and to writing the
/// response on each control connection.
const IO_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
25
/// Point-in-time view of gateway state that the control socket serves to
/// clients.
///
/// A snapshot is cloned out of the watch channel for every request, which is
/// why the type is `Clone`.
#[derive(Clone)]
pub struct GatewaySnapshot {
    /// Address-pool counters, reported as the `pool_*` fields of `show_gateway`.
    pub pool: PoolStatus,
    /// Per-mapping details listed by `show_mappings`.
    pub mappings: Vec<MappingInfo>,
    /// Number of NAT mappings, reported as `nat_mappings` by `show_gateway`.
    pub nat_mappings: usize,
    /// DNS listen address as text, copied from [`SnapshotConfig::dns_listen`].
    pub dns_listen: String,
    /// Whole seconds elapsed since the start instant given to `build_snapshot`.
    pub uptime_secs: u64,
    /// Pool CIDR as text, copied from [`SnapshotConfig::pool_cidr`].
    pub pool_cidr: String,
    /// LAN interface name, copied from [`SnapshotConfig::lan_interface`].
    pub lan_interface: String,
    /// Upstream DNS address as text, copied from [`SnapshotConfig::dns_upstream`].
    pub dns_upstream: String,
    /// DNS TTL setting (presumably seconds — TODO confirm against the DNS module).
    pub dns_ttl: u32,
    /// Pool grace period (presumably seconds — TODO confirm against the pool module).
    pub pool_grace_period: u64,
}
41
/// Unix domain socket server answering read-only gateway status queries.
///
/// The socket file is removed again when the value is dropped.
pub struct GatewayControlSocket {
    /// Bound listener that `accept_loop` accepts client connections from.
    listener: UnixListener,
    /// Path the listener was bound to; kept so the file can be removed on drop.
    socket_path: PathBuf,
}
47
48impl GatewayControlSocket {
49 pub fn bind() -> Result<Self, std::io::Error> {
54 let socket_path = PathBuf::from(GATEWAY_SOCKET_PATH);
55
56 if let Some(parent) = socket_path.parent()
58 && !parent.exists()
59 {
60 std::fs::create_dir_all(parent)?;
61 debug!(path = %parent.display(), "Created gateway control socket directory");
62 }
63
64 if socket_path.exists() {
66 Self::remove_stale_socket(&socket_path)?;
67 }
68
69 let listener = UnixListener::bind(&socket_path)?;
70
71 use std::os::unix::fs::PermissionsExt;
73 std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o770))?;
74 Self::chown_to_fips_group(&socket_path);
75 if let Some(parent) = socket_path.parent() {
76 Self::chown_to_fips_group(parent);
77 }
78
79 info!(path = %socket_path.display(), "Gateway control socket listening");
80
81 Ok(Self {
82 listener,
83 socket_path,
84 })
85 }
86
87 fn remove_stale_socket(path: &Path) -> Result<(), std::io::Error> {
89 match std::os::unix::net::UnixStream::connect(path) {
90 Ok(_) => Err(std::io::Error::new(
91 std::io::ErrorKind::AddrInUse,
92 format!("gateway control socket already in use: {}", path.display()),
93 )),
94 Err(_) => {
95 debug!(path = %path.display(), "Removing stale gateway control socket");
96 std::fs::remove_file(path)?;
97 Ok(())
98 }
99 }
100 }
101
102 fn chown_to_fips_group(path: &Path) {
104 use std::ffi::CString;
105 use std::os::unix::ffi::OsStrExt;
106
107 let group_name = CString::new("fips").unwrap();
108 let grp = unsafe { libc::getgrnam(group_name.as_ptr()) };
109 if grp.is_null() {
110 debug!(
111 "'fips' group not found, skipping chown for {}",
112 path.display()
113 );
114 return;
115 }
116 let gid = unsafe { (*grp).gr_gid };
117
118 let c_path = match CString::new(path.as_os_str().as_bytes()) {
119 Ok(p) => p,
120 Err(_) => return,
121 };
122 let ret = unsafe { libc::chown(c_path.as_ptr(), u32::MAX, gid) };
123 if ret != 0 {
124 warn!(
125 path = %path.display(),
126 error = %std::io::Error::last_os_error(),
127 "Failed to chown gateway control socket to 'fips' group"
128 );
129 }
130 }
131
132 pub async fn accept_loop(self, snapshot_rx: watch::Receiver<Option<GatewaySnapshot>>) {
134 loop {
135 let (stream, _addr) = match self.listener.accept().await {
136 Ok(conn) => conn,
137 Err(e) => {
138 warn!(error = %e, "Gateway control socket accept failed");
139 continue;
140 }
141 };
142
143 let rx = snapshot_rx.clone();
144 tokio::spawn(async move {
145 if let Err(e) = Self::handle_connection(stream, rx).await {
146 debug!(error = %e, "Gateway control connection error");
147 }
148 });
149 }
150 }
151
152 async fn handle_connection(
154 stream: tokio::net::UnixStream,
155 snapshot_rx: watch::Receiver<Option<GatewaySnapshot>>,
156 ) -> Result<(), Box<dyn std::error::Error>> {
157 let (reader, mut writer) = stream.into_split();
158 let mut buf_reader = BufReader::new(reader);
159 let mut line = String::new();
160
161 let read_result = tokio::time::timeout(IO_TIMEOUT, async {
163 let mut total = 0usize;
164 loop {
165 let n = buf_reader.read_line(&mut line).await?;
166 if n == 0 {
167 break;
168 }
169 total += n;
170 if total > MAX_REQUEST_SIZE {
171 return Err(std::io::Error::new(
172 std::io::ErrorKind::InvalidData,
173 "request too large",
174 ));
175 }
176 if line.ends_with('\n') {
177 break;
178 }
179 }
180 Ok(())
181 })
182 .await;
183
184 let response = match read_result {
185 Ok(Ok(())) if line.is_empty() => Response::error("empty request"),
186 Ok(Ok(())) => match serde_json::from_str::<Request>(line.trim()) {
187 Ok(request) => dispatch_command(&request.command, &snapshot_rx),
188 Err(e) => Response::error(format!("invalid request: {e}")),
189 },
190 Ok(Err(e)) => Response::error(format!("read error: {e}")),
191 Err(_) => Response::error("read timeout"),
192 };
193
194 let json = serde_json::to_string(&response)?;
196 let write_result = tokio::time::timeout(IO_TIMEOUT, async {
197 writer.write_all(json.as_bytes()).await?;
198 writer.write_all(b"\n").await?;
199 writer.shutdown().await?;
200 Ok::<_, std::io::Error>(())
201 })
202 .await;
203
204 if let Err(_) | Ok(Err(_)) = write_result {
205 debug!("Gateway control socket write failed or timed out");
206 }
207
208 Ok(())
209 }
210
211 fn cleanup(&self) {
213 if self.socket_path.exists() {
214 if let Err(e) = std::fs::remove_file(&self.socket_path) {
215 warn!(
216 path = %self.socket_path.display(),
217 error = %e,
218 "Failed to remove gateway control socket"
219 );
220 } else {
221 debug!(path = %self.socket_path.display(), "Gateway control socket removed");
222 }
223 }
224 }
225}
226
227impl Drop for GatewayControlSocket {
228 fn drop(&mut self) {
229 self.cleanup();
230 }
231}
232
233fn dispatch_command(
235 command: &str,
236 snapshot_rx: &watch::Receiver<Option<GatewaySnapshot>>,
237) -> Response {
238 let snapshot = match snapshot_rx.borrow().clone() {
239 Some(s) => s,
240 None => return Response::error("gateway not yet initialized"),
241 };
242
243 match command {
244 "show_gateway" => build_show_gateway(&snapshot),
245 "show_mappings" => build_show_mappings(&snapshot),
246 _ => Response::error(format!("unknown command: {command}")),
247 }
248}
249
250fn build_show_gateway(snapshot: &GatewaySnapshot) -> Response {
252 Response::ok(serde_json::json!({
253 "pool_total": snapshot.pool.total,
254 "pool_allocated": snapshot.pool.allocated,
255 "pool_active": snapshot.pool.active,
256 "pool_draining": snapshot.pool.draining,
257 "pool_free": snapshot.pool.free,
258 "nat_mappings": snapshot.nat_mappings,
259 "dns_listen": snapshot.dns_listen,
260 "uptime_secs": snapshot.uptime_secs,
261 "pool_cidr": snapshot.pool_cidr,
262 "lan_interface": snapshot.lan_interface,
263 "dns_upstream": snapshot.dns_upstream,
264 "dns_ttl": snapshot.dns_ttl,
265 "pool_grace_period": snapshot.pool_grace_period,
266 }))
267}
268
269fn build_show_mappings(snapshot: &GatewaySnapshot) -> Response {
271 let mappings: Vec<serde_json::Value> = snapshot
272 .mappings
273 .iter()
274 .map(|m| {
275 serde_json::json!({
276 "virtual_ip": m.virtual_ip.to_string(),
277 "mesh_addr": m.mesh_addr.to_string(),
278 "node_addr": m.node_addr.to_string(),
279 "dns_name": m.dns_name,
280 "state": mapping_state_str(m.state),
281 "sessions": m.session_count,
282 "age_secs": m.age_secs,
283 "last_ref_secs": m.last_ref_secs,
284 })
285 })
286 .collect();
287
288 Response::ok(serde_json::json!({ "mappings": mappings }))
289}
290
291fn mapping_state_str(state: MappingState) -> &'static str {
293 match state {
294 MappingState::Allocated => "Allocated",
295 MappingState::Active => "Active",
296 MappingState::Draining => "Draining",
297 }
298}
299
/// Configuration values that `build_snapshot` copies unchanged into every
/// [`GatewaySnapshot`].
pub struct SnapshotConfig {
    /// Pool CIDR as text.
    pub pool_cidr: String,
    /// LAN interface name.
    pub lan_interface: String,
    /// Upstream DNS address as text.
    pub dns_upstream: String,
    /// DNS listen address as text.
    pub dns_listen: String,
    /// DNS TTL setting (presumably seconds — TODO confirm against the DNS module).
    pub dns_ttl: u32,
    /// Pool grace period (presumably seconds — TODO confirm against the pool module).
    pub pool_grace_period: u64,
}
309
310pub fn build_snapshot(
314 pool_status: PoolStatus,
315 mappings: Vec<MappingInfo>,
316 nat_mappings: usize,
317 start_time: Instant,
318 config: &SnapshotConfig,
319) -> GatewaySnapshot {
320 GatewaySnapshot {
321 pool: pool_status,
322 mappings,
323 nat_mappings,
324 dns_listen: config.dns_listen.clone(),
325 uptime_secs: start_time.elapsed().as_secs(),
326 pool_cidr: config.pool_cidr.clone(),
327 lan_interface: config.lan_interface.clone(),
328 dns_upstream: config.dns_upstream.clone(),
329 dns_ttl: config.dns_ttl,
330 pool_grace_period: config.pool_grace_period,
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NodeAddr;

    /// Fixture: a snapshot holding a single active mapping.
    fn make_snapshot() -> GatewaySnapshot {
        let pool = PoolStatus {
            total: 65535,
            allocated: 1,
            active: 0,
            draining: 0,
            free: 65534,
        };
        let mapping = MappingInfo {
            virtual_ip: "fd01::1".parse().unwrap(),
            mesh_addr: "fd97:467a::1".parse().unwrap(),
            node_addr: NodeAddr::from_bytes([0; 16]),
            dns_name: "npub1test.fips".to_string(),
            state: MappingState::Active,
            session_count: 3,
            age_secs: 120,
            last_ref_secs: 5,
        };

        GatewaySnapshot {
            pool,
            mappings: vec![mapping],
            nat_mappings: 1,
            dns_listen: "[fd02::10]:53".to_string(),
            uptime_secs: 3600,
            pool_cidr: "fd01::/112".to_string(),
            lan_interface: "br-lan".to_string(),
            dns_upstream: "127.0.0.1:5354".to_string(),
            dns_ttl: 60,
            pool_grace_period: 60,
        }
    }

    #[test]
    fn test_show_gateway_response() {
        let resp = build_show_gateway(&make_snapshot());
        assert_eq!(resp.status, "ok");

        let data = resp.data.unwrap();
        assert_eq!(data["pool_total"], 65535);
        assert_eq!(data["pool_free"], 65534);
        assert_eq!(data["nat_mappings"], 1);
        assert_eq!(data["dns_listen"], "[fd02::10]:53");
        assert_eq!(data["uptime_secs"], 3600);
    }

    #[test]
    fn test_show_mappings_response() {
        let resp = build_show_mappings(&make_snapshot());
        assert_eq!(resp.status, "ok");

        let data = resp.data.unwrap();
        let rows = data["mappings"].as_array().unwrap();
        assert_eq!(rows.len(), 1);

        let first = &rows[0];
        assert_eq!(first["state"], "Active");
        assert_eq!(first["sessions"], 3);
        assert_eq!(first["virtual_ip"], "fd01::1");
    }

    #[test]
    fn test_unknown_command() {
        // `_tx` keeps the sender alive until the end of the test.
        let (_tx, rx) = watch::channel(Some(make_snapshot()));
        let resp = dispatch_command("bogus", &rx);
        assert_eq!(resp.status, "error");
        assert!(resp.message.unwrap().contains("unknown command: bogus"));
    }

    #[test]
    fn test_not_initialized() {
        // `_tx` keeps the sender alive until the end of the test.
        let (_tx, rx) = watch::channel::<Option<GatewaySnapshot>>(None);
        let resp = dispatch_command("show_gateway", &rx);
        assert_eq!(resp.status, "error");
        assert!(resp.message.unwrap().contains("not yet initialized"));
    }

    #[test]
    fn test_mapping_state_str() {
        let expected = [
            (MappingState::Allocated, "Allocated"),
            (MappingState::Active, "Active"),
            (MappingState::Draining, "Draining"),
        ];
        for (state, label) in expected {
            assert_eq!(mapping_state_str(state), label);
        }
    }

    #[test]
    fn test_empty_mappings() {
        // Same configuration values as the shared fixture, but with an idle
        // pool and no mappings at all.
        let snapshot = GatewaySnapshot {
            pool: PoolStatus {
                total: 255,
                allocated: 0,
                active: 0,
                draining: 0,
                free: 255,
            },
            mappings: vec![],
            nat_mappings: 0,
            dns_listen: "[::1]:53".to_string(),
            uptime_secs: 0,
            ..make_snapshot()
        };

        let resp = build_show_mappings(&snapshot);
        assert_eq!(resp.status, "ok");

        let data = resp.data.unwrap();
        let rows = data["mappings"].as_array().unwrap();
        assert!(rows.is_empty());
    }
}