network_protocol/transport/
local.rs1use futures::{SinkExt, StreamExt};
2#[cfg(unix)]
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6#[cfg(unix)]
7use tokio::net::{UnixListener, UnixStream};
8use tokio::sync::{mpsc, Mutex};
9use tokio_util::codec::Framed;
10use tracing::{debug, error, info, instrument, warn};
11
12use crate::core::codec::PacketCodec;
13use crate::error::Result;
14#[cfg(windows)]
15use std::net::SocketAddr;
16#[cfg(windows)]
17use tokio::net::{TcpListener, TcpStream};
18
19#[cfg(unix)]
24#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
25pub async fn start_server<P: AsRef<Path>>(path: P) -> Result<()> {
26 let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
28
29 let shutdown_tx_clone = shutdown_tx.clone();
31 tokio::spawn(async move {
32 if let Ok(()) = tokio::signal::ctrl_c().await {
33 info!("Received CTRL+C signal, shutting down");
34 let _ = shutdown_tx_clone.send(()).await;
35 }
36 });
37
38 start_server_with_shutdown(path, shutdown_rx).await
40}
41
42#[cfg(unix)]
44#[instrument(skip(path, shutdown_rx), fields(socket_path = %path.as_ref().display()))]
45pub async fn start_server_with_shutdown<P: AsRef<Path>>(
46 path: P,
47 mut shutdown_rx: mpsc::Receiver<()>,
48) -> Result<()> {
49 if path.as_ref().exists() {
50 tokio::fs::remove_file(&path).await.ok();
51 }
52
53 let path_string = path.as_ref().to_string_lossy().to_string();
55
56 let listener = UnixListener::bind(&path)?;
57 info!(path = %path_string, "Listening on unix socket");
58
59 let active_connections = Arc::new(Mutex::new(0u32));
61
62 loop {
64 tokio::select! {
65 _ = shutdown_rx.recv() => {
67 info!("Shutting down server. Waiting for connections to close...");
68
69 let timeout = tokio::time::sleep(Duration::from_secs(10));
71 tokio::pin!(timeout);
72
73 loop {
74 tokio::select! {
75 _ = &mut timeout => {
76 warn!("Shutdown timeout reached, forcing exit");
77 break;
78 }
79 _ = tokio::time::sleep(Duration::from_millis(500)) => {
80 let connections = *active_connections.lock().await;
81 info!(connections = %connections, "Waiting for connections to close");
82 if connections == 0 {
83 info!("All connections closed, shutting down");
84 break;
85 }
86 }
87 }
88 }
89
90 if Path::new(&path_string).exists() {
92 if let Err(e) = tokio::fs::remove_file(&path_string).await {
93 error!(error = %e, path = %path_string, "Failed to remove socket file");
94 } else {
95 info!(path = %path_string, "Removed socket file");
96 }
97 }
98
99 return Ok(());
100 }
101
102 accept_result = listener.accept() => {
104 match accept_result {
105 Ok((stream, _)) => {
106 let active_connections = active_connections.clone();
107
108 {
110 let mut count = active_connections.lock().await;
111 *count += 1;
112 }
113
114 tokio::spawn(async move {
115 let mut framed = Framed::new(stream, PacketCodec);
116
117 while let Some(Ok(packet)) = framed.next().await {
118 debug!("Received packet of {} bytes", packet.payload.len());
119
120 let _ = framed.send(packet).await;
122 }
123
124 let mut count = active_connections.lock().await;
126 *count -= 1;
127 });
128 }
129 Err(e) => {
130 error!(error = %e, "Error accepting connection");
131 }
132 }
133 }
134 }
135 }
136}
137
#[cfg(windows)]
/// Starts the echo server on Windows, using a loopback TCP socket.
///
/// The listener binds to `127.0.0.1` on the port taken from the last segment
/// of `path` (see `extract_port_or_default`; 8080 when no port can be parsed).
/// Every accepted connection is served on its own task, which decodes packets
/// with [`PacketCodec`] and sends each one straight back to the peer.
///
/// On CTRL+C the server stops accepting, waits up to 10 seconds for the
/// active connections to finish and returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the TCP listener cannot be bound.
#[instrument(skip(path))]
pub async fn start_server<S: AsRef<str>>(path: S) -> Result<()> {
    let addr = format!("127.0.0.1:{}", extract_port_or_default(path.as_ref()));

    let listener = TcpListener::bind(&addr).await?;
    info!(address = %addr, "Listening (Windows compatibility mode)");

    // Number of connection tasks currently running; polled during shutdown.
    let active_connections = Arc::new(Mutex::new(0u32));

    // `shutdown_tx` is deliberately kept alive until this function returns:
    // if every sender were dropped, `recv()` would yield `None` and the
    // server would shut down immediately.
    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

    let signal_tx = shutdown_tx.clone();
    tokio::spawn(async move {
        match tokio::signal::ctrl_c().await {
            Ok(()) => {
                info!("Received shutdown signal, initiating graceful shutdown");
                let _ = signal_tx.send(()).await;
            }
            Err(e) => {
                // Without a signal listener the server keeps running but can
                // no longer be stopped via CTRL+C; make that visible.
                error!(error = %e, "Failed to listen for shutdown signal");
            }
        }
    });

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                info!("Shutting down server. Waiting for connections to close...");

                // Overall deadline for the drain phase.
                let timeout = tokio::time::sleep(Duration::from_secs(10));
                tokio::pin!(timeout);

                loop {
                    // Check before sleeping so an idle server exits at once
                    // instead of always paying the first poll interval.
                    let connections = *active_connections.lock().await;
                    info!(connections = %connections, "Waiting for connections to close");
                    if connections == 0 {
                        info!("All connections closed, shutting down");
                        break;
                    }

                    tokio::select! {
                        _ = &mut timeout => {
                            warn!("Shutdown timeout reached, forcing exit");
                            break;
                        }
                        _ = tokio::time::sleep(Duration::from_millis(500)) => {}
                    }
                }

                return Ok(());
            }

            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer)) => {
                        info!(peer = %peer, "New connection established");
                        let active_connections = Arc::clone(&active_connections);

                        // Count the connection before spawning so a shutdown
                        // racing with this accept still sees it.
                        *active_connections.lock().await += 1;

                        tokio::spawn(async move {
                            let mut framed = Framed::new(stream, PacketCodec);

                            // Echo loop: ends when the peer disconnects or a
                            // frame fails to decode.
                            while let Some(Ok(packet)) = framed.next().await {
                                debug!(bytes = packet.payload.len(), "Packet received");

                                // A failed write means the peer is unreachable;
                                // stop instead of reading from a dead connection.
                                if framed.send(packet).await.is_err() {
                                    warn!(peer = %peer, "Failed to echo packet, closing connection");
                                    break;
                                }
                            }

                            *active_connections.lock().await -= 1;
                            info!(peer = %peer, "Connection closed");
                        });
                    }
                    Err(e) => {
                        error!(error = %e, "Error accepting connection");
                    }
                }
            }
        }
    }
}
232
233#[cfg(unix)]
238#[instrument(skip(path), fields(socket_path = %path.as_ref().display()))]
239pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Framed<UnixStream, PacketCodec>> {
240 let stream = UnixStream::connect(path).await?;
241 Ok(Framed::new(stream, PacketCodec))
242}
243
#[cfg(windows)]
/// Connects to the loopback TCP server used on Windows.
///
/// The port is taken from the last segment of `path` via
/// `extract_port_or_default`; the host is always `127.0.0.1`. On success the
/// stream is wrapped in a [`Framed`] transport that uses [`PacketCodec`].
///
/// # Errors
///
/// Returns an error if the TCP connection cannot be established.
#[instrument(skip(path))]
pub async fn connect<S: AsRef<str>>(path: S) -> Result<Framed<TcpStream, PacketCodec>> {
    let port = extract_port_or_default(path.as_ref());
    let target = format!("127.0.0.1:{}", port);

    let socket = TcpStream::connect(&target).await?;
    let transport = Framed::new(socket, PacketCodec);
    Ok(transport)
}
253
254#[cfg(windows)]
/// Extracts a TCP port from the last segment of `path`, defaulting to 8080.
///
/// The path is split on both `/` and `\` so that native Windows paths such as
/// `C:\sockets\9000` work as well as forward-slash paths. If the final segment
/// is not a valid `u16` (non-numeric, empty, or out of range) the default port
/// 8080 is returned.
fn extract_port_or_default(path: &str) -> u16 {
    // `rsplit` yields the final segment first, so there is no need to walk the
    // whole path; it always yields at least one (possibly empty) item.
    path.rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .and_then(|segment| segment.parse::<u16>().ok())
        .unwrap_or(8080)
}