1pub mod commands;
12pub mod protocol;
13pub mod queries;
14
use crate::config::ControlConfig;
use protocol::{Request, Response};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, warn};
20
/// Largest accepted request line, in bytes, counting its trailing newline.
const MAX_REQUEST_SIZE: usize = 4 * 1024;

/// Deadline applied separately to reading a request, waiting for the reply
/// to a forwarded request, and writing the response back to the client.
const IO_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
26
27pub type ControlMessage = (Request, oneshot::Sender<Response>);
29
30async fn handle_connection_generic<S>(
35 stream: S,
36 control_tx: mpsc::Sender<ControlMessage>,
37) -> Result<(), Box<dyn std::error::Error>>
38where
39 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
40{
41 let (reader, mut writer) = tokio::io::split(stream);
42 let mut buf_reader = BufReader::new(reader);
43 let mut line = String::new();
44
45 let read_result = tokio::time::timeout(IO_TIMEOUT, async {
47 let mut total = 0usize;
48 loop {
49 let n = buf_reader.read_line(&mut line).await?;
50 if n == 0 {
51 break; }
53 total += n;
54 if total > MAX_REQUEST_SIZE {
55 return Err(std::io::Error::new(
56 std::io::ErrorKind::InvalidData,
57 "request too large",
58 ));
59 }
60 if line.ends_with('\n') {
61 break;
62 }
63 }
64 Ok(())
65 })
66 .await;
67
68 let response = match read_result {
69 Ok(Ok(())) if line.is_empty() => Response::error("empty request"),
70 Ok(Ok(())) => {
71 match serde_json::from_str::<Request>(line.trim()) {
73 Ok(request) => {
74 let (resp_tx, resp_rx) = oneshot::channel();
76 if control_tx.send((request, resp_tx)).await.is_err() {
77 Response::error("node shutting down")
78 } else {
79 match tokio::time::timeout(IO_TIMEOUT, resp_rx).await {
80 Ok(Ok(resp)) => resp,
81 Ok(Err(_)) => Response::error("response channel closed"),
82 Err(_) => Response::error("query timeout"),
83 }
84 }
85 }
86 Err(e) => Response::error(format!("invalid request: {}", e)),
87 }
88 }
89 Ok(Err(e)) => Response::error(format!("read error: {}", e)),
90 Err(_) => Response::error("read timeout"),
91 };
92
93 let json = serde_json::to_string(&response)?;
95 let write_result = tokio::time::timeout(IO_TIMEOUT, async {
96 writer.write_all(json.as_bytes()).await?;
97 writer.write_all(b"\n").await?;
98 writer.shutdown().await?;
99 Ok::<_, std::io::Error>(())
100 })
101 .await;
102
103 if let Err(_) | Ok(Err(_)) = write_result {
104 debug!("Control socket write failed or timed out");
105 }
106
107 Ok(())
108}
109
110#[cfg(unix)]
115mod unix_impl {
116 use super::*;
117 use std::path::{Path, PathBuf};
118 use tokio::net::UnixListener;
119
120 pub struct ControlSocket {
124 listener: UnixListener,
125 socket_path: PathBuf,
126 }
127
128 impl ControlSocket {
129 pub fn bind(config: &ControlConfig) -> Result<Self, std::io::Error> {
134 let socket_path = PathBuf::from(&config.socket_path);
135
136 if let Some(parent) = socket_path.parent()
138 && !parent.exists()
139 {
140 std::fs::create_dir_all(parent)?;
141 debug!(path = %parent.display(), "Created control socket directory");
142 }
143
144 if socket_path.exists() {
146 Self::remove_stale_socket(&socket_path)?;
147 }
148
149 let listener = UnixListener::bind(&socket_path)?;
150
151 use std::os::unix::fs::PermissionsExt;
154 std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o770))?;
155 Self::chown_to_fips_group(&socket_path);
156 if let Some(parent) = socket_path.parent() {
157 Self::chown_to_fips_group(parent);
158 }
159
160 info!(path = %socket_path.display(), "Control socket listening");
161
162 Ok(Self {
163 listener,
164 socket_path,
165 })
166 }
167
168 fn remove_stale_socket(path: &Path) -> Result<(), std::io::Error> {
173 match std::os::unix::net::UnixStream::connect(path) {
175 Ok(_) => {
176 Err(std::io::Error::new(
178 std::io::ErrorKind::AddrInUse,
179 format!("control socket already in use: {}", path.display()),
180 ))
181 }
182 Err(_) => {
183 debug!(path = %path.display(), "Removing stale control socket");
185 std::fs::remove_file(path)?;
186 Ok(())
187 }
188 }
189 }
190
191 fn chown_to_fips_group(path: &Path) {
193 use std::ffi::CString;
194 use std::os::unix::ffi::OsStrExt;
195
196 let group_name = CString::new("fips").unwrap();
198 let grp = unsafe { libc::getgrnam(group_name.as_ptr()) };
199 if grp.is_null() {
200 debug!(
201 "'fips' group not found, skipping chown for {}",
202 path.display()
203 );
204 return;
205 }
206 let gid = unsafe { (*grp).gr_gid };
207
208 let c_path = match CString::new(path.as_os_str().as_bytes()) {
209 Ok(p) => p,
210 Err(_) => return,
211 };
212 let ret = unsafe { libc::chown(c_path.as_ptr(), u32::MAX, gid) };
213 if ret != 0 {
214 warn!(
215 path = %path.display(),
216 error = %std::io::Error::last_os_error(),
217 "Failed to chown control socket to 'fips' group"
218 );
219 }
220 }
221
222 pub async fn accept_loop(self, control_tx: mpsc::Sender<ControlMessage>) {
231 loop {
232 let (stream, _addr) = match self.listener.accept().await {
233 Ok(conn) => conn,
234 Err(e) => {
235 warn!(error = %e, "Control socket accept failed");
236 continue;
237 }
238 };
239
240 let tx = control_tx.clone();
241 tokio::spawn(async move {
242 if let Err(e) = handle_connection_generic(stream, tx).await {
243 debug!(error = %e, "Control connection error");
244 }
245 });
246 }
247 }
248
249 pub fn socket_path(&self) -> &Path {
251 &self.socket_path
252 }
253
254 fn cleanup(&self) {
256 if self.socket_path.exists() {
257 if let Err(e) = std::fs::remove_file(&self.socket_path) {
258 warn!(
259 path = %self.socket_path.display(),
260 error = %e,
261 "Failed to remove control socket"
262 );
263 } else {
264 debug!(path = %self.socket_path.display(), "Control socket removed");
265 }
266 }
267 }
268 }
269
270 impl Drop for ControlSocket {
271 fn drop(&mut self) {
272 self.cleanup();
273 }
274 }
275}
276
#[cfg(windows)]
mod windows_impl {
    use super::*;
    use tokio::net::TcpListener;

    /// Port used when `socket_path` cannot be parsed as a TCP port number.
    const DEFAULT_CONTROL_PORT: u16 = 21210;

    /// Pause after a failed `accept` before retrying, so a persistent error
    /// does not turn the accept loop into a busy spin that floods the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    /// Windows control endpoint: a TCP listener bound to `127.0.0.1`.
    ///
    /// `config.socket_path` is interpreted as a port number rather than a path.
    ///
    /// NOTE(review): unlike the Unix socket there is no file-permission or
    /// group check, so any local process can connect — confirm this is intended.
    pub struct ControlSocket {
        listener: TcpListener,
        /// Port the listener is actually bound to (OS-assigned if 0 was requested).
        port: u16,
    }

    impl ControlSocket {
        /// Binds the loopback control listener described by `config`.
        ///
        /// An unparsable port logs a warning and falls back to
        /// `DEFAULT_CONTROL_PORT`.
        ///
        /// # Errors
        ///
        /// Fails in either of these cases:
        /// * the port cannot be bound (e.g. it is already in use);
        /// * the listener cannot be handed over to the Tokio reactor.
        pub fn bind(config: &ControlConfig) -> Result<Self, std::io::Error> {
            let requested = match config.socket_path.parse::<u16>() {
                Ok(port) => port,
                Err(e) => {
                    warn!(
                        path = %config.socket_path,
                        error = %e,
                        default = DEFAULT_CONTROL_PORT,
                        "Invalid control port, using default"
                    );
                    DEFAULT_CONTROL_PORT
                }
            };

            let addr = std::net::SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, requested));
            let std_listener = std::net::TcpListener::bind(addr)?;
            // Tokio requires an adopted std listener to be non-blocking.
            std_listener.set_nonblocking(true)?;
            let listener = TcpListener::from_std(std_listener)?;
            // With port 0 the OS picks a free port; report the one really bound.
            let port = listener.local_addr()?.port();

            info!(port = port, "Control socket listening on localhost");

            Ok(Self { listener, port })
        }

        /// TCP port the control listener is bound to.
        pub fn port(&self) -> u16 {
            self.port
        }

        /// Accepts loopback connections forever, serving each on its own task.
        ///
        /// Accept errors are logged and followed by a short back-off. Peers
        /// that are not loopback addresses are rejected; this is a
        /// defense-in-depth check on top of the loopback bind.
        pub async fn accept_loop(self, control_tx: mpsc::Sender<ControlMessage>) {
            loop {
                let (stream, addr) = match self.listener.accept().await {
                    Ok(conn) => conn,
                    Err(e) => {
                        warn!(error = %e, "Control socket accept failed");
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                        continue;
                    }
                };

                if !addr.ip().is_loopback() {
                    warn!(addr = %addr, "Rejected non-localhost control connection");
                    continue;
                }

                let tx = control_tx.clone();
                tokio::spawn(async move {
                    if let Err(e) = handle_connection_generic(stream, tx).await {
                        debug!(error = %e, "Control connection error");
                    }
                });
            }
        }
    }
}
368
369#[cfg(unix)]
371pub use unix_impl::ControlSocket;
372#[cfg(windows)]
373pub use windows_impl::ControlSocket;
374
#[cfg(test)]
mod tests {
    #[cfg(windows)]
    use super::*;

    /// Builds an enabled control config whose `socket_path` holds `port_spec`.
    #[cfg(windows)]
    fn enabled_config(port_spec: &str) -> ControlConfig {
        ControlConfig {
            enabled: true,
            socket_path: port_spec.to_owned(),
        }
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn test_tcp_control_socket_bind() {
        // Port "0" lets the OS choose any free port.
        let _socket =
            ControlSocket::bind(&enabled_config("0")).expect("failed to bind control socket");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn test_tcp_control_socket_invalid_port_uses_default() {
        // A bind failure is tolerated (presumably the default port can be busy
        // on the test machine); only a successful bind is checked.
        let Ok(socket) = ControlSocket::bind(&enabled_config("not-a-port")) else {
            return;
        };
        assert_eq!(socket.port(), 21210);
    }
}