1pub mod commands;
12pub mod firewall_state;
13pub mod listening;
14pub mod protocol;
15pub mod queries;
16
17use crate::config::ControlConfig;
18use protocol::{Request, Response};
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::sync::{mpsc, oneshot};
21use tracing::{debug, info, warn};
22
/// Largest accepted control request, in bytes, counting the trailing newline.
const MAX_REQUEST_SIZE: usize = 4 * 1024;

/// Deadline applied separately to each phase of a control connection: reading the
/// request, waiting for the node's reply, and writing the response.
const IO_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
28
29pub type ControlMessage = (Request, oneshot::Sender<Response>);
31
32async fn handle_connection_generic<S>(
37 stream: S,
38 control_tx: mpsc::Sender<ControlMessage>,
39) -> Result<(), Box<dyn std::error::Error>>
40where
41 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
42{
43 let (reader, mut writer) = tokio::io::split(stream);
44 let mut buf_reader = BufReader::new(reader);
45 let mut line = String::new();
46
47 let read_result = tokio::time::timeout(IO_TIMEOUT, async {
49 let mut total = 0usize;
50 loop {
51 let n = buf_reader.read_line(&mut line).await?;
52 if n == 0 {
53 break; }
55 total += n;
56 if total > MAX_REQUEST_SIZE {
57 return Err(std::io::Error::new(
58 std::io::ErrorKind::InvalidData,
59 "request too large",
60 ));
61 }
62 if line.ends_with('\n') {
63 break;
64 }
65 }
66 Ok(())
67 })
68 .await;
69
70 let response = match read_result {
71 Ok(Ok(())) if line.is_empty() => Response::error("empty request"),
72 Ok(Ok(())) => {
73 match serde_json::from_str::<Request>(line.trim()) {
75 Ok(request) => {
76 let (resp_tx, resp_rx) = oneshot::channel();
78 if control_tx.send((request, resp_tx)).await.is_err() {
79 Response::error("node shutting down")
80 } else {
81 match tokio::time::timeout(IO_TIMEOUT, resp_rx).await {
82 Ok(Ok(resp)) => resp,
83 Ok(Err(_)) => Response::error("response channel closed"),
84 Err(_) => Response::error("query timeout"),
85 }
86 }
87 }
88 Err(e) => Response::error(format!("invalid request: {}", e)),
89 }
90 }
91 Ok(Err(e)) => Response::error(format!("read error: {}", e)),
92 Err(_) => Response::error("read timeout"),
93 };
94
95 let json = serde_json::to_string(&response)?;
97 let write_result = tokio::time::timeout(IO_TIMEOUT, async {
98 writer.write_all(json.as_bytes()).await?;
99 writer.write_all(b"\n").await?;
100 writer.shutdown().await?;
101 Ok::<_, std::io::Error>(())
102 })
103 .await;
104
105 if let Err(_) | Ok(Err(_)) = write_result {
106 debug!("Control socket write failed or timed out");
107 }
108
109 Ok(())
110}
111
112#[cfg(unix)]
117mod unix_impl {
118 use super::*;
119 use std::path::{Path, PathBuf};
120 use tokio::net::UnixListener;
121
122 pub struct ControlSocket {
126 listener: UnixListener,
127 socket_path: PathBuf,
128 }
129
130 impl ControlSocket {
131 pub fn bind(config: &ControlConfig) -> Result<Self, std::io::Error> {
136 let socket_path = PathBuf::from(&config.socket_path);
137
138 if let Some(parent) = socket_path.parent()
140 && !parent.exists()
141 {
142 std::fs::create_dir_all(parent)?;
143 debug!(path = %parent.display(), "Created control socket directory");
144 }
145
146 if socket_path.exists() {
148 Self::remove_stale_socket(&socket_path)?;
149 }
150
151 let listener = UnixListener::bind(&socket_path)?;
152
153 use std::os::unix::fs::PermissionsExt;
156 std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o770))?;
157 Self::chown_to_fips_group(&socket_path);
158 if let Some(parent) = socket_path.parent() {
159 Self::chown_to_fips_group(parent);
160 }
161
162 info!(path = %socket_path.display(), "Control socket listening");
163
164 Ok(Self {
165 listener,
166 socket_path,
167 })
168 }
169
170 fn remove_stale_socket(path: &Path) -> Result<(), std::io::Error> {
175 match std::os::unix::net::UnixStream::connect(path) {
177 Ok(_) => {
178 Err(std::io::Error::new(
180 std::io::ErrorKind::AddrInUse,
181 format!("control socket already in use: {}", path.display()),
182 ))
183 }
184 Err(_) => {
185 debug!(path = %path.display(), "Removing stale control socket");
187 std::fs::remove_file(path)?;
188 Ok(())
189 }
190 }
191 }
192
193 fn chown_to_fips_group(path: &Path) {
195 use std::ffi::CString;
196 use std::os::unix::ffi::OsStrExt;
197
198 let group_name = CString::new("fips").unwrap();
200 let grp = unsafe { libc::getgrnam(group_name.as_ptr()) };
201 if grp.is_null() {
202 debug!(
203 "'fips' group not found, skipping chown for {}",
204 path.display()
205 );
206 return;
207 }
208 let gid = unsafe { (*grp).gr_gid };
209
210 let c_path = match CString::new(path.as_os_str().as_bytes()) {
211 Ok(p) => p,
212 Err(_) => return,
213 };
214 let ret = unsafe { libc::chown(c_path.as_ptr(), u32::MAX, gid) };
215 if ret != 0 {
216 warn!(
217 path = %path.display(),
218 error = %std::io::Error::last_os_error(),
219 "Failed to chown control socket to 'fips' group"
220 );
221 }
222 }
223
224 pub async fn accept_loop(self, control_tx: mpsc::Sender<ControlMessage>) {
233 loop {
234 let (stream, _addr) = match self.listener.accept().await {
235 Ok(conn) => conn,
236 Err(e) => {
237 warn!(error = %e, "Control socket accept failed");
238 continue;
239 }
240 };
241
242 let tx = control_tx.clone();
243 tokio::spawn(async move {
244 if let Err(e) = handle_connection_generic(stream, tx).await {
245 debug!(error = %e, "Control connection error");
246 }
247 });
248 }
249 }
250
251 pub fn socket_path(&self) -> &Path {
253 &self.socket_path
254 }
255
256 fn cleanup(&self) {
258 if self.socket_path.exists() {
259 if let Err(e) = std::fs::remove_file(&self.socket_path) {
260 warn!(
261 path = %self.socket_path.display(),
262 error = %e,
263 "Failed to remove control socket"
264 );
265 } else {
266 debug!(path = %self.socket_path.display(), "Control socket removed");
267 }
268 }
269 }
270 }
271
272 impl Drop for ControlSocket {
273 fn drop(&mut self) {
274 self.cleanup();
275 }
276 }
277}
278
#[cfg(windows)]
mod windows_impl {
    use super::*;
    use tokio::net::TcpListener;

    /// Port used when `config.socket_path` does not parse as a TCP port number.
    const DEFAULT_CONTROL_PORT: u16 = 21210;

    /// Pause after a failed `accept()`.
    ///
    /// Without it, persistent errors would turn the accept loop into a busy spin
    /// that floods the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    /// Control endpoint on Windows: a TCP listener bound to `127.0.0.1`.
    ///
    /// `config.socket_path` is interpreted as the port number.
    ///
    /// NOTE(review): unlike the Unix socket's `0o770` mode, nothing here restricts
    /// which local users may connect.
    pub struct ControlSocket {
        listener: TcpListener,
        /// Port actually bound, which may differ from the requested one when
        /// port 0 was requested.
        port: u16,
    }

    impl ControlSocket {
        /// Binds the control listener on `127.0.0.1`.
        ///
        /// The port is taken from `config.socket_path`. An unparsable value falls
        /// back to `DEFAULT_CONTROL_PORT` with a warning. A value of `0` lets the
        /// OS pick a free port, which [`ControlSocket::port`] then reports.
        ///
        /// # Errors
        ///
        /// Fails if binding, switching to non-blocking mode, querying the local
        /// address, or registering with the tokio runtime fails.
        pub fn bind(config: &ControlConfig) -> Result<Self, std::io::Error> {
            let requested_port: u16 = match config.socket_path.parse() {
                Ok(p) => p,
                Err(e) => {
                    warn!(
                        path = %config.socket_path,
                        error = %e,
                        default = DEFAULT_CONTROL_PORT,
                        "Invalid control port, using default"
                    );
                    DEFAULT_CONTROL_PORT
                }
            };

            let addr = std::net::SocketAddr::from(([127, 0, 0, 1], requested_port));
            let std_listener = std::net::TcpListener::bind(addr)?;
            std_listener.set_nonblocking(true)?;
            // Record the port really bound; with port 0 the OS chooses one.
            let port = std_listener.local_addr()?.port();
            let listener = TcpListener::from_std(std_listener)?;

            info!(port = port, "Control socket listening on localhost");

            Ok(Self { listener, port })
        }

        /// TCP port the control listener is bound to.
        pub fn port(&self) -> u16 {
            self.port
        }

        /// Accepts connections forever, serving each on its own spawned task.
        ///
        /// Connections from non-loopback addresses are rejected. Accept errors
        /// are logged, followed by a short back-off.
        pub async fn accept_loop(self, control_tx: mpsc::Sender<ControlMessage>) {
            loop {
                let (stream, addr) = match self.listener.accept().await {
                    Ok(conn) => conn,
                    Err(e) => {
                        warn!(error = %e, "Control socket accept failed");
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                        continue;
                    }
                };

                // Defence in depth: the listener is bound to 127.0.0.1 anyway.
                if !addr.ip().is_loopback() {
                    warn!(addr = %addr, "Rejected non-localhost control connection");
                    continue;
                }

                let tx = control_tx.clone();
                tokio::spawn(async move {
                    if let Err(e) = handle_connection_generic(stream, tx).await {
                        debug!(error = %e, "Control connection error");
                    }
                });
            }
        }
    }
}
370
371#[cfg(unix)]
373pub use unix_impl::ControlSocket;
374#[cfg(windows)]
375pub use windows_impl::ControlSocket;
376
#[cfg(test)]
mod tests {
    #[cfg(windows)]
    use super::*;

    /// Builds an enabled control config whose `socket_path` carries `port_spec`.
    #[cfg(windows)]
    fn enabled_config(port_spec: &str) -> ControlConfig {
        ControlConfig {
            enabled: true,
            socket_path: port_spec.to_string(),
        }
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn test_tcp_control_socket_bind() {
        // Port "0" asks the OS for any free port, so binding must succeed.
        let _socket =
            ControlSocket::bind(&enabled_config("0")).expect("failed to bind control socket");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn test_tcp_control_socket_invalid_port_uses_default() {
        // The default port may already be taken on the host, so only check the
        // port when the bind actually succeeded.
        if let Ok(port) = ControlSocket::bind(&enabled_config("not-a-port")).map(|s| s.port()) {
            assert_eq!(port, 21210);
        }
    }
}