1use std::net::SocketAddr;
18use std::sync::Arc;
19
20use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
21use tokio::net::{TcpListener, TcpStream};
22use tokio::sync::Mutex;
23
24use crate::error::Result;
25use crate::local::LocalMesh;
26use crate::message::PeerMessage;
27
28pub struct TcpMesh {
30 local: LocalMesh,
31}
32
33impl TcpMesh {
34 pub fn new(local: LocalMesh) -> Self {
36 Self { local }
37 }
38
39 pub async fn serve(self, addr: SocketAddr) -> Result<()> {
41 let listener = TcpListener::bind(addr).await?;
42 tracing::info!(%addr, "tcp mesh listening (plain)");
43 loop {
44 let (socket, peer) = listener.accept().await?;
45 tracing::debug!(?peer, "tcp peer connected");
46 let local = self.local.clone();
47 tokio::spawn(async move {
48 if let Err(e) = handle_connection(socket, local).await {
49 tracing::warn!(?peer, ?e, "tcp connection ended");
50 }
51 });
52 }
53 }
54
55 #[cfg(feature = "tls")]
61 pub async fn serve_tls(
62 self,
63 addr: SocketAddr,
64 acceptor: tokio_rustls::TlsAcceptor,
65 ) -> Result<()> {
66 let listener = TcpListener::bind(addr).await?;
67 tracing::info!(%addr, "tcp mesh listening (tls)");
68 loop {
69 let (socket, peer) = listener.accept().await?;
70 let acceptor = acceptor.clone();
71 let local = self.local.clone();
72 tokio::spawn(async move {
73 match acceptor.accept(socket).await {
74 Ok(tls) => {
75 if let Err(e) = handle_connection(tls, local).await {
76 tracing::warn!(?peer, ?e, "tls connection ended");
77 }
78 }
79 Err(e) => tracing::warn!(?peer, ?e, "tls handshake failed"),
80 }
81 });
82 }
83 }
84
85 pub async fn connect(addr: SocketAddr, hello: PeerMessage) -> Result<TcpClient> {
88 let mut socket = TcpStream::connect(addr).await?;
89 let line = format!("{}\n", serde_json::to_string(&hello)?);
90 socket.write_all(line.as_bytes()).await?;
91 Ok(TcpClient {
92 inner: Arc::new(Mutex::new(
93 Box::new(socket) as Box<dyn AsyncWrite + Send + Unpin>
94 )),
95 })
96 }
97
98 #[cfg(feature = "tls")]
105 pub async fn connect_tls(
106 addr: SocketAddr,
107 hello: PeerMessage,
108 server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,
109 connector: tokio_rustls::TlsConnector,
110 ) -> Result<TcpClient> {
111 let stream = TcpStream::connect(addr).await?;
112 let mut tls = connector.connect(server_name, stream).await?;
113 let line = format!("{}\n", serde_json::to_string(&hello)?);
114 tls.write_all(line.as_bytes()).await?;
115 Ok(TcpClient {
116 inner: Arc::new(Mutex::new(
117 Box::new(tls) as Box<dyn AsyncWrite + Send + Unpin>
118 )),
119 })
120 }
121}
122
123#[derive(Clone)]
125pub struct TcpClient {
126 inner: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>>,
127}
128
129impl TcpClient {
130 pub async fn send(&self, msg: &PeerMessage) -> Result<()> {
132 let line = format!("{}\n", serde_json::to_string(msg)?);
133 self.inner.lock().await.write_all(line.as_bytes()).await?;
134 Ok(())
135 }
136}
137
138async fn handle_connection<S>(socket: S, local: LocalMesh) -> Result<()>
143where
144 S: AsyncRead + AsyncWrite + Unpin,
145{
146 let mut reader = BufReader::new(socket);
147 let mut line = String::new();
148 let mut sender_id: Option<String> = None;
149 loop {
150 line.clear();
151 let n = reader.read_line(&mut line).await?;
152 if n == 0 {
153 break;
154 }
155 let trimmed = line.trim();
156 if trimmed.is_empty() {
157 continue;
158 }
159 let msg: PeerMessage = match serde_json::from_str(trimmed) {
160 Ok(m) => m,
161 Err(e) => {
162 tracing::warn!(?e, "discarding malformed line");
163 continue;
164 }
165 };
166 if let PeerMessage::Hello { from, capabilities } = &msg {
167 sender_id = Some(from.clone());
168 let _ = local
169 .join(from.clone(), capabilities.clone(), Vec::new())
170 .await;
171 }
172 let sender = sender_id.clone().unwrap_or_else(|| msg.sender().clone());
173 let (_p, handle) = local
174 .join(format!("ephemeral:{sender}"), Vec::new(), Vec::new())
175 .await?;
176 handle.publish(msg).await?;
177 local.leave(&handle.id).await?;
178 }
179 if let Some(id) = sender_id {
180 let _ = local.leave(&id).await;
181 }
182 Ok(())
183}
184
#[cfg(feature = "tls")]
/// Builds a rustls server configuration from a PEM certificate chain and a PEM
/// private key, with client authentication disabled.
///
/// `cert_pem` may hold several certificates, end-entity first, as expected by
/// rustls' `with_single_cert`. The first private key found in `key_pem` is used.
///
/// # Errors
/// - either PEM input fails to parse;
/// - `cert_pem` contains no certificates, or `key_pem` contains no private key;
/// - rustls rejects the certificate/key pair.
///
/// # Panics
/// `rustls::ServerConfig::builder()` relies on the process-default crypto
/// provider. rustls panics if none is installed and none can be inferred from
/// its enabled crate features.
pub fn tls_server_config(cert_pem: &[u8], key_pem: &[u8]) -> anyhow::Result<rustls::ServerConfig> {
    use rustls::pki_types::{CertificateDer, PrivateKeyDer};
    use rustls_pemfile::{certs, private_key};
    use std::io::Cursor;

    let chain: Vec<CertificateDer<'static>> = certs(&mut Cursor::new(cert_pem))
        .collect::<std::result::Result<_, _>>()
        .map_err(|e| anyhow::anyhow!("cert parse: {e}"))?;
    // A PEM without any CERTIFICATE block parses "successfully" as an empty
    // list; report that explicitly rather than as a vaguer failure later.
    if chain.is_empty() {
        anyhow::bail!("no certificates found in cert PEM");
    }

    let key: PrivateKeyDer<'static> = private_key(&mut Cursor::new(key_pem))
        .map_err(|e| anyhow::anyhow!("key parse: {e}"))?
        .ok_or_else(|| anyhow::anyhow!("no private key found"))?;

    rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, key)
        .map_err(|e| anyhow::anyhow!("tls config: {e}"))
}
213
#[cfg(feature = "tls")]
/// Builds a rustls client configuration that trusts exactly the CA
/// certificates in `ca_cert_pem`, with no client authentication.
///
/// # Errors
/// - the PEM input fails to parse;
/// - it contains no certificates;
/// - rustls refuses to add one of them as a trust anchor.
///
/// # Panics
/// `rustls::ClientConfig::builder()` relies on the process-default crypto
/// provider. rustls panics if none is installed and none can be inferred from
/// its enabled crate features.
pub fn tls_client_config(ca_cert_pem: &[u8]) -> anyhow::Result<rustls::ClientConfig> {
    use rustls::pki_types::CertificateDer;
    use rustls::RootCertStore;
    use rustls_pemfile::certs;
    use std::io::Cursor;

    let ca_certs: Vec<CertificateDer<'static>> = certs(&mut Cursor::new(ca_cert_pem))
        .collect::<std::result::Result<_, _>>()
        .map_err(|e| anyhow::anyhow!("ca cert parse: {e}"))?;
    // An empty trust store cannot verify any server, and that would only show
    // up later as a handshake failure. Fail fast with a clear message instead.
    if ca_certs.is_empty() {
        anyhow::bail!("no CA certificates found in PEM");
    }

    let mut roots = RootCertStore::empty();
    for cert in ca_certs {
        roots
            .add(cert)
            .map_err(|e| anyhow::anyhow!("add root: {e}"))?;
    }
    Ok(rustls::ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth())
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::PeerCapability;
    use serde_json::json;
    use std::net::Ipv4Addr;

    /// One-element capability list named `name`, without a version.
    fn caps(name: &str) -> Vec<PeerCapability> {
        let only = PeerCapability {
            name: name.into(),
            version: None,
        };
        vec![only]
    }

    /// Waits for the first `Broadcast` arriving on `$rx`.
    ///
    /// Makes at most six `recv` attempts, each bounded by `$millis` milliseconds,
    /// and skips other message kinds. Evaluates to `Some((from, topic))`, or to
    /// `None` as soon as one wait times out or the channel closes.
    macro_rules! next_broadcast {
        ($rx:expr, $millis:expr) => {{
            let mut seen = None;
            for _ in 0..6 {
                let polled = tokio::time::timeout(
                    std::time::Duration::from_millis($millis),
                    $rx.recv(),
                )
                .await;
                match polled {
                    Ok(Some(PeerMessage::Broadcast { from, topic, .. })) => {
                        seen = Some((from, topic));
                        break;
                    }
                    Ok(Some(_)) => {}
                    _ => break,
                }
            }
            seen
        }};
    }

    #[tokio::test]
    async fn tcp_round_trip_delivers_broadcast() {
        let local = LocalMesh::new();
        let (mut listener_handle, _h) = local.join("listener", caps("x"), vec![]).await.unwrap();

        let server = TcpMesh::new(local.clone());
        let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
        let listener = TcpListener::bind(bind_addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept_mesh = local.clone();
        let accept_task = tokio::spawn(async move {
            let _ = server;
            loop {
                match listener.accept().await {
                    Ok((socket, _)) => {
                        let mesh = accept_mesh.clone();
                        tokio::spawn(async move {
                            let _ = handle_connection(socket, mesh).await;
                        });
                    }
                    Err(_) => break,
                }
            }
        });

        let client = TcpMesh::connect(
            addr,
            PeerMessage::Hello {
                from: "remote".into(),
                capabilities: caps("remote"),
            },
        )
        .await
        .unwrap();
        let outbound = PeerMessage::broadcast("remote", "topic", json!({"v": 1}));
        client.send(&outbound).await.unwrap();

        let (from, topic) = match next_broadcast!(listener_handle.receiver, 400) {
            Some(pair) => pair,
            None => panic!("expected a Broadcast to arrive"),
        };
        assert_eq!(from, "remote");
        assert_eq!(topic, "topic");
        accept_task.abort();
    }

    #[cfg(feature = "tls")]
    #[tokio::test]
    async fn tls_round_trip_delivers_broadcast() {
        use rcgen::generate_simple_self_signed;
        use rustls::pki_types::ServerName;
        use std::sync::Arc;
        use tokio_rustls::{TlsAcceptor, TlsConnector};

        // rustls needs a process-default provider; a second install attempt
        // (e.g. from another test) just returns an error we ignore.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let generated = generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert_pem = generated.cert.pem();
        let key_pem = generated.key_pair.serialize_pem();

        let acceptor = TlsAcceptor::from(Arc::new(
            tls_server_config(cert_pem.as_bytes(), key_pem.as_bytes()).unwrap(),
        ));
        let connector =
            TlsConnector::from(Arc::new(tls_client_config(cert_pem.as_bytes()).unwrap()));

        let local = LocalMesh::new();
        let (mut listener_handle, _h) = local.join("listener", caps("x"), vec![]).await.unwrap();

        let std_listener =
            std::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).unwrap();
        std_listener.set_nonblocking(true).unwrap();
        let tls_listener = TcpListener::from_std(std_listener).unwrap();
        let addr = tls_listener.local_addr().unwrap();

        let accept_mesh = local.clone();
        let accept_task = tokio::spawn(async move {
            loop {
                let (socket, peer) = match tls_listener.accept().await {
                    Ok(pair) => pair,
                    Err(_) => break,
                };
                let acceptor = acceptor.clone();
                let mesh = accept_mesh.clone();
                tokio::spawn(async move {
                    match acceptor.accept(socket).await {
                        Ok(tls) => {
                            let _ = handle_connection(tls, mesh).await;
                        }
                        Err(e) => tracing::warn!(?peer, ?e, "tls handshake failed in test"),
                    }
                });
            }
        });

        let hello = PeerMessage::Hello {
            from: "tls-remote".into(),
            capabilities: caps("tls-remote"),
        };
        let server_name = ServerName::try_from("localhost").unwrap();
        let client = TcpMesh::connect_tls(addr, hello, server_name, connector)
            .await
            .unwrap();
        let outbound = PeerMessage::broadcast("tls-remote", "tls-topic", json!({"secure": true}));
        client.send(&outbound).await.unwrap();

        let (from, topic) = match next_broadcast!(listener_handle.receiver, 500) {
            Some(pair) => pair,
            None => panic!("expected a TLS Broadcast to arrive"),
        };
        assert_eq!(from, "tls-remote");
        assert_eq!(topic, "tls-topic");
        accept_task.abort();
    }
}