network_protocol/transport/
local.rs1#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
2use futures::{SinkExt, StreamExt};
3#[cfg(unix)]
4use std::path::Path;
5#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
6use std::sync::Arc;
7#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
8use std::time::Duration;
9#[cfg(unix)]
10use tokio::net::{UnixListener, UnixStream};
11#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
12use tokio::sync::{mpsc, Mutex};
13use tokio_util::codec::Framed;
14#[cfg(any(unix, all(windows, feature = "use-tcp-on-windows")))]
15use tracing::{debug, error, warn};
16use tracing::{info, instrument};
17
18use crate::core::codec::PacketCodec;
19use crate::error::Result;
20
21#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
24use crate::transport::windows_pipe;
25
26#[cfg(all(windows, feature = "use-tcp-on-windows"))]
27use std::net::SocketAddr;
28#[cfg(all(windows, feature = "use-tcp-on-windows"))]
29use tokio::net::{TcpListener, TcpStream};
30
31#[cfg(unix)]
36#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
37pub async fn start_server<P: AsRef<Path>>(path: P) -> Result<()> {
38 let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
40
41 let shutdown_tx_clone = shutdown_tx.clone();
43 tokio::spawn(async move {
44 if let Ok(()) = tokio::signal::ctrl_c().await {
45 info!("Received CTRL+C signal, shutting down");
46 let _ = shutdown_tx_clone.send(()).await;
47 }
48 });
49
50 start_server_with_shutdown(path, shutdown_rx).await
52}
53
54#[cfg(unix)]
56#[instrument(skip(path, shutdown_rx), fields(socket_path = %path.as_ref().display()))]
57pub async fn start_server_with_shutdown<P: AsRef<Path>>(
58 path: P,
59 mut shutdown_rx: mpsc::Receiver<()>,
60) -> Result<()> {
61 if path.as_ref().exists() {
62 tokio::fs::remove_file(&path).await.ok();
63 }
64
65 let path_string = path.as_ref().to_string_lossy().to_string();
67
68 let listener = UnixListener::bind(&path)?;
69 info!(path = %path_string, "Listening on unix socket");
70
71 let active_connections = Arc::new(Mutex::new(0u32));
73
74 loop {
76 tokio::select! {
77 _ = shutdown_rx.recv() => {
79 info!("Shutting down server. Waiting for connections to close...");
80
81 let timeout = tokio::time::sleep(Duration::from_secs(10));
83 tokio::pin!(timeout);
84
85 loop {
86 tokio::select! {
87 _ = &mut timeout => {
88 warn!("Shutdown timeout reached, forcing exit");
89 break;
90 }
91 _ = tokio::time::sleep(Duration::from_millis(500)) => {
92 let connections = *active_connections.lock().await;
93 info!(connections = %connections, "Waiting for connections to close");
94 if connections == 0 {
95 info!("All connections closed, shutting down");
96 break;
97 }
98 }
99 }
100 }
101
102 if Path::new(&path_string).exists() {
104 if let Err(e) = tokio::fs::remove_file(&path_string).await {
105 error!(error = %e, path = %path_string, "Failed to remove socket file");
106 } else {
107 info!(path = %path_string, "Removed socket file");
108 }
109 }
110
111 return Ok(());
112 }
113
114 accept_result = listener.accept() => {
116 match accept_result {
117 Ok((stream, _)) => {
118 let active_connections = active_connections.clone();
119
120 {
122 let mut count = active_connections.lock().await;
123 *count += 1;
124 }
125
126 tokio::spawn(async move {
127 let mut framed = Framed::new(stream, PacketCodec);
128
129 while let Some(Ok(packet)) = framed.next().await {
130 debug!("Received packet of {} bytes", packet.payload.len());
131
132 let _ = framed.send(packet).await;
134 }
135
136 let mut count = active_connections.lock().await;
138 *count -= 1;
139 });
140 }
141 Err(e) => {
142 error!(error = %e, "Error accepting connection");
143 }
144 }
145 }
146 }
147 }
148}
149
#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
/// Starts the Windows named-pipe server for `path`.
///
/// `path` is first mapped to a pipe name with `convert_to_pipe_name`. All
/// serving is then delegated to `windows_pipe::start_server`.
///
/// # Errors
/// Propagates any error returned by `windows_pipe::start_server`.
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    let pipe = convert_to_pipe_name(path.as_ref());
    info!(pipe = %pipe, "Starting Windows Named Pipe server");
    windows_pipe::start_server(&pipe).await
}
163
#[cfg(all(windows, feature = "use-tcp-on-windows"))]
/// Runs the TCP echo server used on Windows when the `use-tcp-on-windows`
/// feature is enabled.
///
/// The server binds to `127.0.0.1` on the port derived from `path` by
/// `extract_port_or_default`. Every packet decoded by [`PacketCodec`] is sent
/// back to its sender. A connection ends when the peer disconnects, a frame
/// fails to decode, or an echo fails to send.
///
/// CTRL+C triggers a graceful shutdown: the listener is closed, then the
/// function waits up to 10 seconds for the tracked connections to finish
/// before returning.
///
/// # Errors
/// Returns an error if binding the listener fails.
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    let addr = format!("127.0.0.1:{}", extract_port_or_default(path.as_ref()));

    let listener = TcpListener::bind(&addr).await?;
    info!(address = %addr, "Listening (Windows compatibility mode)");

    let active_connections = Arc::new(Mutex::new(0u32));

    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

    // The signal task gets a clone so `shutdown_tx` stays alive here. Otherwise a
    // failing `ctrl_c()` would drop the only sender and `recv()` would yield `None`.
    let signal_tx = shutdown_tx.clone();
    tokio::spawn(async move {
        if let Ok(()) = tokio::signal::ctrl_c().await {
            info!("Received shutdown signal, initiating graceful shutdown");
            let _ = signal_tx.send(()).await;
        }
    });

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => break,
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer)) => {
                        info!(peer = %peer, "New connection established");
                        let active_connections = Arc::clone(&active_connections);
                        *active_connections.lock().await += 1;

                        tokio::spawn(async move {
                            let mut framed = Framed::new(stream, PacketCodec);

                            while let Some(frame) = framed.next().await {
                                let packet = match frame {
                                    Ok(packet) => packet,
                                    Err(_) => {
                                        warn!(peer = %peer, "Failed to decode packet; closing connection");
                                        break;
                                    }
                                };
                                debug!(bytes = packet.payload.len(), "Packet received");

                                // Echo the packet back; a failed write means the peer is gone.
                                if framed.send(packet).await.is_err() {
                                    warn!(peer = %peer, "Failed to send packet; closing connection");
                                    break;
                                }
                            }

                            *active_connections.lock().await -= 1;
                            info!(peer = %peer, "Connection closed");
                        });
                    }
                    Err(e) => {
                        error!(error = %e, "Error accepting connection");
                        // Back off briefly so persistent accept errors do not busy-spin.
                        tokio::time::sleep(Duration::from_millis(100)).await;
                    }
                }
            }
        }
    }

    // Stop accepting new clients while the existing ones drain.
    drop(listener);
    info!("Shutting down server. Waiting for connections to close...");

    let deadline = tokio::time::sleep(Duration::from_secs(10));
    tokio::pin!(deadline);

    loop {
        let connections = *active_connections.lock().await;
        if connections == 0 {
            info!("All connections closed, shutting down");
            break;
        }
        info!(connections = %connections, "Waiting for connections to close");

        tokio::select! {
            _ = &mut deadline => {
                warn!("Shutdown timeout reached, forcing exit");
                break;
            }
            _ = tokio::time::sleep(Duration::from_millis(500)) => {}
        }
    }

    Ok(())
}
261
262#[cfg(unix)]
267#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
268pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Framed<UnixStream, PacketCodec>> {
269 let stream = UnixStream::connect(path).await?;
270 Ok(Framed::new(stream, PacketCodec))
271}
272
#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
/// Opens a client connection to the Windows named pipe derived from `path`.
///
/// The pipe name comes from `convert_to_pipe_name`; connecting and framing
/// are handled by `windows_pipe::connect`.
///
/// # Errors
/// Propagates any error returned by `windows_pipe::connect`.
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(
    path: S,
) -> Result<Framed<tokio::net::windows::named_pipe::NamedPipeClient, PacketCodec>> {
    windows_pipe::connect(&convert_to_pipe_name(path.as_ref())).await
}
284
#[cfg(all(windows, feature = "use-tcp-on-windows"))]
/// Opens a TCP client connection to `127.0.0.1`.
///
/// The port is derived from `path` by `extract_port_or_default`, and the
/// stream is wrapped in a [`Framed`] transport using [`PacketCodec`].
///
/// # Errors
/// Propagates any I/O error from `TcpStream::connect`.
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(path: S) -> Result<Framed<TcpStream, PacketCodec>> {
    let port = extract_port_or_default(path.as_ref());
    let loopback = SocketAddr::from(([127, 0, 0, 1], port));

    let stream = TcpStream::connect(loopback).await?;
    Ok(Framed::new(stream, PacketCodec))
}
297
#[cfg(all(windows, feature = "use-tcp-on-windows"))]
/// Derives a TCP port from the final component of `path`, falling back to 8080.
///
/// The text after the last `/` or `\` separator is parsed as a `u16`, so
/// `"/tmp/9000"` and `"C:\\sockets\\9000"` both yield `9000`. The fallback of
/// 8080 is used in two cases:
/// - the component is not a valid port number;
/// - the component is `0`. That value would make the server bind to an
///   OS-chosen ephemeral port that clients could never find.
fn extract_port_or_default(path: &str) -> u16 {
    path.rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .and_then(|segment| segment.parse::<u16>().ok())
        .filter(|&port| port != 0)
        .unwrap_or(8080)
}
307
#[cfg(all(windows, not(feature = "use-tcp-on-windows")))]
/// Maps a socket-style path to a Windows named-pipe name.
///
/// Names that already start with `\\.\pipe\` are returned unchanged. The prefix
/// is matched case-insensitively because Windows pipe names are not
/// case-sensitive. For any other input:
/// - leading `/` characters are stripped;
/// - remaining `/` and `\` characters become `_`;
/// - the result is placed under `\\.\pipe\`.
///
/// If nothing is left after stripping, the name `network_protocol` is used.
fn convert_to_pipe_name(path: &str) -> String {
    const PIPE_PREFIX: &str = "\\\\.\\pipe\\";

    // `get` instead of slicing: it returns `None` rather than panicking when the
    // prefix length falls inside a multi-byte character.
    let already_pipe = path
        .get(..PIPE_PREFIX.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(PIPE_PREFIX));
    if already_pipe {
        return path.to_string();
    }

    let name = path
        .trim_start_matches('/')
        .replace('/', "_")
        .replace('\\', "_");

    let name = if name.is_empty() {
        "network_protocol"
    } else {
        &name
    };

    format!("{}{}", PIPE_PREFIX, name)
}