Skip to main content

network_protocol/transport/
local.rs

1#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
2use futures::{SinkExt, StreamExt};
3#[cfg(unix)]
4use std::path::Path;
5#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
6use std::sync::Arc;
7#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
8use std::time::Duration;
9#[cfg(unix)]
10use tokio::net::{UnixListener, UnixStream};
11#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
12use tokio::sync::{mpsc, Mutex};
13use tokio_util::codec::Framed;
14#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
15use tracing::{debug, error, warn};
16use tracing::{info, instrument};
17
18use crate::core::codec::PacketCodec;
19use crate::error::Result;
20
21// Windows will use named pipes via windows_pipe module
22// Keeping TCP fallback for compatibility
23#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
24use crate::transport::windows_pipe;
25
26#[cfg(all(windows, feature = "use-tcp-on-windows"))]
27use std::net::SocketAddr;
28#[cfg(all(windows, feature = "use-tcp-on-windows"))]
29use tokio::net::{TcpListener, TcpStream};
30
31/// Start a local server for IPC
32///
33/// On Unix systems, this uses Unix Domain Sockets
34/// On Windows, this falls back to TCP localhost connections
35#[cfg(unix)]
36#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
37pub async fn start_server<P: AsRef<Path>>(path: P) -> Result<()> {
38    // Create internal shutdown channel
39    let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
40
41    // Set up ctrl-c handler that sends to our internal shutdown channel
42    let shutdown_tx_clone = shutdown_tx.clone();
43    tokio::spawn(async move {
44        if let Ok(()) = tokio::signal::ctrl_c().await {
45            info!("Received CTRL+C signal, shutting down");
46            let _ = shutdown_tx_clone.send(()).await;
47        }
48    });
49
50    // Start with our internal shutdown channel
51    start_server_with_shutdown(path, shutdown_rx).await
52}
53
54/// Start a Unix domain socket server with an external shutdown channel
55#[cfg(unix)]
56#[instrument(skip(path, shutdown_rx), fields(socket_path = %path.as_ref().display()))]
57pub async fn start_server_with_shutdown<P: AsRef<Path>>(
58    path: P,
59    mut shutdown_rx: mpsc::Receiver<()>,
60) -> Result<()> {
61    if path.as_ref().exists() {
62        tokio::fs::remove_file(&path).await.ok();
63    }
64
65    // Store path for cleanup on shutdown
66    let path_string = path.as_ref().to_string_lossy().to_string();
67
68    let listener = UnixListener::bind(&path)?;
69    info!(path = %path_string, "Listening on unix socket");
70
71    // Track active connections
72    let active_connections = Arc::new(Mutex::new(0u32));
73
74    // Server main loop with graceful shutdown
75    loop {
76        tokio::select! {
77            // Check for shutdown signal from the provided shutdown_rx channel
78            _ = shutdown_rx.recv() => {
79                info!("Shutting down server. Waiting for connections to close...");
80
81                // Wait for active connections to close (with timeout)
82                let timeout = tokio::time::sleep(Duration::from_secs(10));
83                tokio::pin!(timeout);
84
85                loop {
86                    tokio::select! {
87                        _ = &mut timeout => {
88                            warn!("Shutdown timeout reached, forcing exit");
89                            break;
90                        }
91                        _ = tokio::time::sleep(Duration::from_millis(500)) => {
92                            let connections = *active_connections.lock().await;
93                            info!(connections = %connections, "Waiting for connections to close");
94                            if connections == 0 {
95                                info!("All connections closed, shutting down");
96                                break;
97                            }
98                        }
99                    }
100                }
101
102                // Clean up socket file
103                if Path::new(&path_string).exists() {
104                    if let Err(e) = tokio::fs::remove_file(&path_string).await {
105                        error!(error = %e, path = %path_string, "Failed to remove socket file");
106                    } else {
107                        info!(path = %path_string, "Removed socket file");
108                    }
109                }
110
111                return Ok(());
112            }
113
114            // Accept new connections
115            accept_result = listener.accept() => {
116                match accept_result {
117                    Ok((stream, _)) => {
118                        let active_connections = active_connections.clone();
119
120                        // Increment active connections counter
121                        {
122                            let mut count = active_connections.lock().await;
123                            *count += 1;
124                        }
125
126                        tokio::spawn(async move {
127                            let mut framed = Framed::new(stream, PacketCodec);
128
129                            while let Some(Ok(packet)) = framed.next().await {
130                                debug!("Received packet of {} bytes", packet.payload.len());
131
132                                // Echo it back
133                                let _ = framed.send(packet).await;
134                            }
135
136                            // Decrement connection counter when connection closes
137                            let mut count = active_connections.lock().await;
138                            *count -= 1;
139                        });
140                    }
141                    Err(e) => {
142                        error!(error = %e, "Error accepting connection");
143                    }
144                }
145            }
146        }
147    }
148}
149
/// Start the local IPC server on Windows using Named Pipes.
///
/// `path` is mapped to a pipe name by `convert_to_pipe_name`, and the server
/// is then run by `windows_pipe::start_server`. This variant is compiled only
/// when the `use-tcp-on-windows` feature is disabled; with the feature
/// enabled the TCP implementation is used instead.
#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    let requested = path.as_ref();
    let pipe = convert_to_pipe_name(requested);

    info!(pipe = %pipe, "Starting Windows Named Pipe server");
    windows_pipe::start_server(&pipe).await
}
163
/// Windows implementation using TCP on localhost (legacy fallback).
///
/// Compiled only when the `use-tcp-on-windows` feature is enabled. `path` is
/// reduced to a port number by `extract_port_or_default` and the server
/// listens on `127.0.0.1:<port>`. Every accepted connection is served on its
/// own task, which echoes each decoded packet back to the peer and stops at
/// the first decode or send failure.
///
/// On CTRL+C the server stops accepting, waits up to 10 seconds for the
/// active connections to finish and returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound.
#[cfg(all(windows, feature = "use-tcp-on-windows"))]
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    // On Windows, interpret the path as a port number on localhost
    let addr = format!("127.0.0.1:{}", extract_port_or_default(path.as_ref()));

    let listener = TcpListener::bind(&addr).await?;
    info!(address = %addr, "Listening (Windows compatibility mode)");

    // Number of connection tasks that are currently running.
    let active_connections = Arc::new(Mutex::new(0u32));

    // Create shutdown channel
    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

    // Spawn ctrl-c handler. Only a clone moves into the task; `shutdown_tx`
    // stays alive in this scope so the channel is not closed if the handler
    // finishes without sending.
    let shutdown_tx_clone = shutdown_tx.clone();
    tokio::spawn(async move {
        if let Ok(()) = tokio::signal::ctrl_c().await {
            info!("Received shutdown signal, initiating graceful shutdown");
            let _ = shutdown_tx_clone.send(()).await;
        }
    });

    // Server main loop with graceful shutdown
    loop {
        tokio::select! {
            // Check for shutdown signal
            _ = shutdown_rx.recv() => {
                info!("Shutting down server. Waiting for connections to close...");

                // Overall deadline for the drain phase
                let timeout = tokio::time::sleep(Duration::from_secs(10));
                tokio::pin!(timeout);

                loop {
                    // Check before sleeping so an idle server exits at once
                    // instead of always paying the first polling interval.
                    let connections = *active_connections.lock().await;
                    if connections == 0 {
                        info!("All connections closed, shutting down");
                        break;
                    }
                    info!(connections = %connections, "Waiting for connections to close");

                    tokio::select! {
                        _ = &mut timeout => {
                            warn!("Shutdown timeout reached, forcing exit");
                            break;
                        }
                        _ = tokio::time::sleep(Duration::from_millis(500)) => {}
                    }
                }

                return Ok(());
            }

            // Accept new connections
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer)) => {
                        info!(peer = %peer, "New connection established");
                        let active_connections = Arc::clone(&active_connections);

                        // Count the connection before its task starts so the
                        // shutdown path can never miss it.
                        {
                            let mut count = active_connections.lock().await;
                            *count += 1;
                        }

                        tokio::spawn(async move {
                            let mut framed = Framed::new(stream, PacketCodec);

                            while let Some(frame) = framed.next().await {
                                let packet = match frame {
                                    Ok(packet) => packet,
                                    Err(e) => {
                                        warn!(peer = %peer, error = ?e, "Failed to decode packet, closing connection");
                                        break;
                                    }
                                };
                                debug!(bytes = packet.payload.len(), "Packet received");

                                // Echo it back; a failed write means the peer
                                // is gone, so stop reading from it.
                                if let Err(e) = framed.send(packet).await {
                                    warn!(peer = %peer, error = ?e, "Failed to echo packet, closing connection");
                                    break;
                                }
                            }

                            // Decrement connection counter when connection closes
                            let mut count = active_connections.lock().await;
                            *count -= 1;
                            info!(peer = %peer, "Connection closed");
                        });
                    }
                    Err(e) => {
                        error!(error = %e, "Error accepting connection");
                    }
                }
            }
        }
    }
}
261
262/// Connect to a local IPC socket
263///
264/// On Unix systems, this uses Unix Domain Sockets
265/// On Windows, this falls back to TCP localhost connections
266#[cfg(unix)]
267#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
268pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Framed<UnixStream, PacketCodec>> {
269    let stream = UnixStream::connect(path).await?;
270    Ok(Framed::new(stream, PacketCodec))
271}
272
/// Connect to a local IPC server on Windows using Named Pipes.
///
/// `path` is mapped to a pipe name by `convert_to_pipe_name`, and the
/// connection itself is made by `windows_pipe::connect`. Compiled only when
/// the `use-tcp-on-windows` feature is disabled.
#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(
    path: S,
) -> Result<Framed<tokio::net::windows::named_pipe::NamedPipeClient, PacketCodec>> {
    let requested = path.as_ref();
    let pipe = convert_to_pipe_name(requested);
    windows_pipe::connect(&pipe).await
}
284
/// Connect to a local IPC server on Windows over TCP (legacy fallback).
///
/// Compiled only when the `use-tcp-on-windows` feature is enabled. `path` is
/// reduced to a port number by `extract_port_or_default`, and the connection
/// is made to `127.0.0.1:<port>`.
///
/// # Errors
///
/// Returns an error if the TCP connection cannot be established.
#[cfg(all(windows, feature = "use-tcp-on-windows"))]
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(path: S) -> Result<Framed<TcpStream, PacketCodec>> {
    let port = extract_port_or_default(path.as_ref());
    let target = format!("127.0.0.1:{}", port);

    Ok(Framed::new(TcpStream::connect(&target).await?, PacketCodec))
}
297
298#[cfg(all(windows, feature = "use-tcp-on-windows"))]
/// Derive a TCP port from a path-like string.
///
/// The final component of `path` — split on either `/` or `\`, since this
/// helper is only built for Windows where both separators occur — is parsed
/// as a `u16`. Returns 8080 when that component is empty, is not a number, or
/// does not fit in a `u16`.
fn extract_port_or_default(path: &str) -> u16 {
    const DEFAULT_PORT: u16 = 8080;

    // `rsplit` yields the last component first, so no full traversal is needed.
    path.rsplit(|c: char| matches!(c, '/' | '\\'))
        .next()
        .and_then(|segment| segment.parse::<u16>().ok())
        .unwrap_or(DEFAULT_PORT)
}
307
308/// Convert a path string to a Windows named pipe name
309///
/// Handles various input formats and converts them to the
/// `\\.\pipe\name` form required by Windows Named Pipes.
312#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
fn convert_to_pipe_name(path: &str) -> String {
    // A fully-qualified pipe name is passed through untouched.
    if path.starts_with("\\\\.\\pipe\\") {
        return path.to_owned();
    }

    // Drop leading forward slashes, then flatten every remaining separator
    // (forward or backward slash) into an underscore.
    let flattened: String = path
        .trim_start_matches('/')
        .chars()
        .map(|ch| if matches!(ch, '/' | '\\') { '_' } else { ch })
        .collect();

    // Nothing left after trimming: fall back to the default pipe name.
    let leaf = match flattened.as_str() {
        "" => "network_protocol",
        other => other,
    };

    format!("\\\\.\\pipe\\{}", leaf)
}