network_protocol/transport/
local.rs

1use futures::{SinkExt, StreamExt};
2#[cfg(unix)]
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6#[cfg(unix)]
7use tokio::net::{UnixListener, UnixStream};
8use tokio::sync::{mpsc, Mutex};
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::core::codec::PacketCodec;
13use crate::error::Result;
14#[cfg(windows)]
15use std::net::SocketAddr;
16#[cfg(windows)]
17use tokio::net::{TcpListener, TcpStream};
18
19/// Start a local server for IPC
20///
21/// On Unix systems, this uses Unix Domain Sockets
22/// On Windows, this falls back to TCP localhost connections
23#[cfg(unix)]
24#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
25pub async fn start_server<P: AsRef<Path>>(path: P) -> Result<()> {
26    // Create internal shutdown channel
27    let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
28
29    // Set up ctrl-c handler that sends to our internal shutdown channel
30    let shutdown_tx_clone = shutdown_tx.clone();
31    tokio::spawn(async move {
32        if let Ok(()) = tokio::signal::ctrl_c().await {
33            info!("Received CTRL+C signal, shutting down");
34            let _ = shutdown_tx_clone.send(()).await;
35        }
36    });
37
38    // Start with our internal shutdown channel
39    start_server_with_shutdown(path, shutdown_rx).await
40}
41
42/// Start a Unix domain socket server with an external shutdown channel
43#[cfg(unix)]
44#[instrument(skip(path, shutdown_rx), fields(socket_path = %path.as_ref().display()))]
45pub async fn start_server_with_shutdown<P: AsRef<Path>>(
46    path: P,
47    mut shutdown_rx: mpsc::Receiver<()>,
48) -> Result<()> {
49    if path.as_ref().exists() {
50        tokio::fs::remove_file(&path).await.ok();
51    }
52
53    // Store path for cleanup on shutdown
54    let path_string = path.as_ref().to_string_lossy().to_string();
55
56    let listener = UnixListener::bind(&path)?;
57    info!(path = %path_string, "Listening on unix socket");
58
59    // Track active connections
60    let active_connections = Arc::new(Mutex::new(0u32));
61
62    // Server main loop with graceful shutdown
63    loop {
64        tokio::select! {
65            // Check for shutdown signal from the provided shutdown_rx channel
66            _ = shutdown_rx.recv() => {
67                info!("Shutting down server. Waiting for connections to close...");
68
69                // Wait for active connections to close (with timeout)
70                let timeout = tokio::time::sleep(Duration::from_secs(10));
71                tokio::pin!(timeout);
72
73                loop {
74                    tokio::select! {
75                        _ = &mut timeout => {
76                            warn!("Shutdown timeout reached, forcing exit");
77                            break;
78                        }
79                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
80                            let connections = *active_connections.lock().await;
81                            info!(connections = %connections, "Waiting for connections to close");
82                            if connections == 0 {
83                                info!("All connections closed, shutting down");
84                                break;
85                            }
86                        }
87                    }
88                }
89
90                // Clean up socket file
91                if Path::new(&path_string).exists() {
92                    if let Err(e) = tokio::fs::remove_file(&path_string).await {
93                        error!(error = %e, path = %path_string, "Failed to remove socket file");
94                    } else {
95                        info!(path = %path_string, "Removed socket file");
96                    }
97                }
98
99                return Ok(());
100            }
101
102            // Accept new connections
103            accept_result = listener.accept() => {
104                match accept_result {
105                    Ok((stream, _)) => {
106                        let active_connections = active_connections.clone();
107
108                        // Increment active connections counter
109                        {
110                            let mut count = active_connections.lock().await;
111                            *count += 1;
112                        }
113
114                        tokio::spawn(async move {
115                            let mut framed = Framed::new(stream, PacketCodec);
116
117                            while let Some(Ok(packet)) = framed.next().await {
118                                debug!("Received packet of {} bytes", packet.payload.len());
119
120                                // Echo it back
121                                let _ = framed.send(packet).await;
122                            }
123
124                            // Decrement connection counter when connection closes
125                            let mut count = active_connections.lock().await;
126                            *count -= 1;
127                        });
128                    }
129                    Err(e) => {
130                        error!(error = %e, "Error accepting connection");
131                    }
132                }
133            }
134        }
135    }
136}
137
/// Windows implementation using TCP on localhost instead of Unix sockets.
///
/// `path` is interpreted as a port on `127.0.0.1` (see
/// `extract_port_or_default`). The server accepts connections until CTRL+C is
/// received; every connection is served on its own task, which echoes each
/// decoded packet back to the peer.
///
/// On shutdown the server stops accepting, waits up to 10 seconds for the
/// active connections to finish (checking every 500 ms) and returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the TCP listener cannot be bound.
#[cfg(windows)]
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    // On Windows, interpret the path as a port number on localhost
    let addr = format!("127.0.0.1:{}", extract_port_or_default(path.as_ref()));

    let listener = TcpListener::bind(&addr).await?;
    info!(address = %addr, "Listening (Windows compatibility mode)");

    // Track active connections
    let active_connections = Arc::new(Mutex::new(0u32));

    // Shutdown channel. `shutdown_tx` stays alive in this scope for the whole
    // server lifetime so the channel does not close (which the loop below
    // would treat as a shutdown request) if the signal task ends early.
    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

    // Forward CTRL+C into the shutdown channel
    let signal_tx = shutdown_tx.clone();
    tokio::spawn(async move {
        match tokio::signal::ctrl_c().await {
            Ok(()) => {
                info!("Received shutdown signal, initiating graceful shutdown");
                let _ = signal_tx.send(()).await;
            }
            // Previously ignored silently, leaving no hint that CTRL+C
            // would never trigger a graceful shutdown.
            Err(e) => error!(error = %e, "Failed to listen for CTRL+C signal"),
        }
    });

    // Server main loop with graceful shutdown
    loop {
        tokio::select! {
            // Shutdown requested
            _ = shutdown_rx.recv() => {
                info!("Shutting down server. Waiting for connections to close...");

                // Wait for active connections to close (with timeout)
                let deadline = tokio::time::sleep(Duration::from_secs(10));
                tokio::pin!(deadline);

                loop {
                    // Check before sleeping so an idle server exits immediately
                    let connections = *active_connections.lock().await;
                    if connections == 0 {
                        info!("All connections closed, shutting down");
                        break;
                    }
                    info!(connections = %connections, "Waiting for connections to close");

                    tokio::select! {
                        _ = &mut deadline => {
                            warn!("Shutdown timeout reached, forcing exit");
                            break;
                        }
                        _ = tokio::time::sleep(Duration::from_millis(500)) => {}
                    }
                }

                return Ok(());
            }

            // Accept new connections
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer)) => {
                        info!(peer = %peer, "New connection established");
                        let active_connections = active_connections.clone();

                        // Count the connection before spawning so a shutdown
                        // that starts right now already sees it.
                        {
                            let mut count = active_connections.lock().await;
                            *count += 1;
                        }

                        tokio::spawn(async move {
                            let mut framed = Framed::new(stream, PacketCodec);

                            loop {
                                match framed.next().await {
                                    Some(Ok(packet)) => {
                                        debug!(bytes = packet.payload.len(), "Packet received");

                                        // Echo it back; stop serving once the
                                        // peer can no longer be written to.
                                        if framed.send(packet).await.is_err() {
                                            warn!(peer = %peer, "Failed to echo packet, closing connection");
                                            break;
                                        }
                                    }
                                    Some(Err(_)) => {
                                        warn!(peer = %peer, "Failed to decode packet, closing connection");
                                        break;
                                    }
                                    // Peer closed the stream
                                    None => break,
                                }
                            }

                            // Decrement connection counter when connection closes
                            let mut count = active_connections.lock().await;
                            *count -= 1;
                            info!(peer = %peer, "Connection closed");
                        });
                    }
                    Err(e) => {
                        error!(error = %e, "Error accepting connection");
                    }
                }
            }
        }
    }
}
232
233/// Connect to a local IPC socket
234///
235/// On Unix systems, this uses Unix Domain Sockets
236/// On Windows, this falls back to TCP localhost connections
237#[cfg(unix)]
238#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
239pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Framed<UnixStream, PacketCodec>> {
240    let stream = UnixStream::connect(path).await?;
241    Ok(Framed::new(stream, PacketCodec))
242}
243
/// Open a client connection to the local IPC endpoint described by `path`.
///
/// This is the Windows variant: `path` is interpreted as a port on
/// `127.0.0.1` (via `extract_port_or_default`) and the connection is made
/// over TCP, then wrapped in a [`Framed`] transport using [`PacketCodec`].
///
/// # Errors
///
/// Returns an error if the TCP connection cannot be established.
#[cfg(windows)]
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(path: S) -> Result<Framed<TcpStream, PacketCodec>> {
    // The "path" only carries a localhost port on this platform
    let port = extract_port_or_default(path.as_ref());
    let target = format!("127.0.0.1:{}", port);

    let socket = TcpStream::connect(&target).await?;
    let transport = Framed::new(socket, PacketCodec);
    Ok(transport)
}
253
254#[cfg(windows)]
/// Extract a TCP port from the final component of `path`.
///
/// The path is split on both `/` and `\` (Windows paths normally use the
/// latter), and the last component is parsed as a `u16`. If that component is
/// not a valid port number (including an empty component, such as after a
/// trailing separator), the default port 8080 is returned.
fn extract_port_or_default(path: &str) -> u16 {
    const DEFAULT_PORT: u16 = 8080;

    // `rsplit` always yields at least one (possibly empty) component, and
    // walking from the back avoids consuming the whole iterator.
    path.rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .and_then(|segment| segment.parse::<u16>().ok())
        .unwrap_or(DEFAULT_PORT)
}