Skip to main content

oxide_mesh/
tcp.rs

1//! JSON-line framed TCP transport — plain and TLS.
2//!
3//! `TcpMesh` is a thin server: it accepts connections, frames messages as one
4//! JSON-encoded [`PeerMessage`] per line, and forwards each into the supplied
5//! handler. Outbound dispatch happens through the same [`LocalMesh`] used by
6//! in-process peers, so a TCP peer behaves identically to a local one once
7//! its `Hello` has been processed.
8//!
9//! # TLS
10//!
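//! Enable the `tls` Cargo feature to unlock [`TcpMesh::serve_tls`] and
//! [`TcpMesh::connect_tls`]. With the configs built by [`tls_server_config`]
//! and [`tls_client_config`], only the server presents a certificate, and the
//! client verifies it against a trusted CA. For mutual TLS (mTLS), build the
//! server's `ServerConfig` yourself with a client-certificate verifier
//! ([`tls_server_config`] does not request client certificates), and give the
//! client a `ClientConfig` that includes its own certificate.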
16
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
21use tokio::net::{TcpListener, TcpStream};
22use tokio::sync::Mutex;
23
24use crate::error::Result;
25use crate::local::LocalMesh;
26use crate::message::PeerMessage;
27
28/// TCP mesh server bound to a [`LocalMesh`] for routing.
29pub struct TcpMesh {
30    local: LocalMesh,
31}
32
33impl TcpMesh {
34    /// Build a TCP wrapper around `local`.
35    pub fn new(local: LocalMesh) -> Self {
36        Self { local }
37    }
38
39    /// Bind to `addr` and serve **plain** TCP connections until cancellation.
40    pub async fn serve(self, addr: SocketAddr) -> Result<()> {
41        let listener = TcpListener::bind(addr).await?;
42        tracing::info!(%addr, "tcp mesh listening (plain)");
43        loop {
44            let (socket, peer) = listener.accept().await?;
45            tracing::debug!(?peer, "tcp peer connected");
46            let local = self.local.clone();
47            tokio::spawn(async move {
48                if let Err(e) = handle_connection(socket, local).await {
49                    tracing::warn!(?peer, ?e, "tcp connection ended");
50                }
51            });
52        }
53    }
54
55    /// Bind to `addr` and serve **TLS** connections.
56    ///
57    /// `acceptor` is a [`tokio_rustls::TlsAcceptor`] built from a
58    /// [`rustls::ServerConfig`]. Use [`tls_server_config`] to build one from
59    /// PEM bytes, or construct it manually for mTLS.
60    #[cfg(feature = "tls")]
61    pub async fn serve_tls(
62        self,
63        addr: SocketAddr,
64        acceptor: tokio_rustls::TlsAcceptor,
65    ) -> Result<()> {
66        let listener = TcpListener::bind(addr).await?;
67        tracing::info!(%addr, "tcp mesh listening (tls)");
68        loop {
69            let (socket, peer) = listener.accept().await?;
70            let acceptor = acceptor.clone();
71            let local = self.local.clone();
72            tokio::spawn(async move {
73                match acceptor.accept(socket).await {
74                    Ok(tls) => {
75                        if let Err(e) = handle_connection(tls, local).await {
76                            tracing::warn!(?peer, ?e, "tls connection ended");
77                        }
78                    }
79                    Err(e) => tracing::warn!(?peer, ?e, "tls handshake failed"),
80                }
81            });
82        }
83    }
84
85    /// Connect to a remote `addr` over **plain** TCP, send `hello`, and
86    /// return a [`TcpClient`].
87    pub async fn connect(addr: SocketAddr, hello: PeerMessage) -> Result<TcpClient> {
88        let mut socket = TcpStream::connect(addr).await?;
89        let line = format!("{}\n", serde_json::to_string(&hello)?);
90        socket.write_all(line.as_bytes()).await?;
91        Ok(TcpClient {
92            inner: Arc::new(Mutex::new(
93                Box::new(socket) as Box<dyn AsyncWrite + Send + Unpin>
94            )),
95        })
96    }
97
98    /// Connect to a remote `addr` over **TLS**, send `hello`, and return a
99    /// [`TcpClient`].
100    ///
101    /// `server_name` must match the CN / SAN in the server certificate.
102    /// `connector` is a [`tokio_rustls::TlsConnector`] built from a
103    /// [`rustls::ClientConfig`] that trusts the server's CA.
104    #[cfg(feature = "tls")]
105    pub async fn connect_tls(
106        addr: SocketAddr,
107        hello: PeerMessage,
108        server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,
109        connector: tokio_rustls::TlsConnector,
110    ) -> Result<TcpClient> {
111        let stream = TcpStream::connect(addr).await?;
112        let mut tls = connector.connect(server_name, stream).await?;
113        let line = format!("{}\n", serde_json::to_string(&hello)?);
114        tls.write_all(line.as_bytes()).await?;
115        Ok(TcpClient {
116            inner: Arc::new(Mutex::new(
117                Box::new(tls) as Box<dyn AsyncWrite + Send + Unpin>
118            )),
119        })
120    }
121}
122
123/// Client handle returned by [`TcpMesh::connect`] or [`TcpMesh::connect_tls`].
124#[derive(Clone)]
125pub struct TcpClient {
126    inner: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>>,
127}
128
129impl TcpClient {
130    /// Send `msg` as a single JSON line.
131    pub async fn send(&self, msg: &PeerMessage) -> Result<()> {
132        let line = format!("{}\n", serde_json::to_string(msg)?);
133        self.inner.lock().await.write_all(line.as_bytes()).await?;
134        Ok(())
135    }
136}
137
138// ---------------------------------------------------------------------------
139// Generic connection handler — works over plain TCP and TLS streams.
140// ---------------------------------------------------------------------------
141
142async fn handle_connection<S>(socket: S, local: LocalMesh) -> Result<()>
143where
144    S: AsyncRead + AsyncWrite + Unpin,
145{
146    let mut reader = BufReader::new(socket);
147    let mut line = String::new();
148    let mut sender_id: Option<String> = None;
149    loop {
150        line.clear();
151        let n = reader.read_line(&mut line).await?;
152        if n == 0 {
153            break;
154        }
155        let trimmed = line.trim();
156        if trimmed.is_empty() {
157            continue;
158        }
159        let msg: PeerMessage = match serde_json::from_str(trimmed) {
160            Ok(m) => m,
161            Err(e) => {
162                tracing::warn!(?e, "discarding malformed line");
163                continue;
164            }
165        };
166        if let PeerMessage::Hello { from, capabilities } = &msg {
167            sender_id = Some(from.clone());
168            let _ = local
169                .join(from.clone(), capabilities.clone(), Vec::new())
170                .await;
171        }
172        let sender = sender_id.clone().unwrap_or_else(|| msg.sender().clone());
173        let (_p, handle) = local
174            .join(format!("ephemeral:{sender}"), Vec::new(), Vec::new())
175            .await?;
176        handle.publish(msg).await?;
177        local.leave(&handle.id).await?;
178    }
179    if let Some(id) = sender_id {
180        let _ = local.leave(&id).await;
181    }
182    Ok(())
183}
184
185// ---------------------------------------------------------------------------
186// TLS config helpers
187// ---------------------------------------------------------------------------
188
/// Build a [`rustls::ServerConfig`] from a PEM-encoded certificate chain and
/// private key.
///
/// `cert_pem` may hold several `CERTIFICATE` blocks; rustls expects the
/// end-entity certificate first, followed by any intermediates. The first
/// private key found in `key_pem` is used.
///
/// The returned config does **not** request client certificates; for mTLS
/// build the config with `.with_client_cert_verifier(...)` instead of
/// `.with_no_client_auth()`.
///
/// NOTE(review): `ServerConfig::builder()` relies on rustls's process-wide
/// default crypto provider (the tests install `ring` explicitly); confirm
/// callers install one.
///
/// # Errors
///
/// Fails if either PEM input cannot be parsed, if `cert_pem` contains no
/// certificate, if `key_pem` contains no private key, or if rustls rejects
/// the certificate/key pair.
#[cfg(feature = "tls")]
pub fn tls_server_config(cert_pem: &[u8], key_pem: &[u8]) -> anyhow::Result<rustls::ServerConfig> {
    use rustls::pki_types::{CertificateDer, PrivateKeyDer};
    use rustls_pemfile::{certs, private_key};
    use std::io::Cursor;

    let chain: Vec<CertificateDer<'static>> = certs(&mut Cursor::new(cert_pem))
        .collect::<std::result::Result<_, _>>()
        .map_err(|e| anyhow::anyhow!("cert parse: {e}"))?;
    // An empty chain would only surface later as an opaque handshake failure.
    if chain.is_empty() {
        anyhow::bail!("no certificates found");
    }

    let key: PrivateKeyDer<'static> = private_key(&mut Cursor::new(key_pem))
        .map_err(|e| anyhow::anyhow!("key parse: {e}"))?
        .ok_or_else(|| anyhow::anyhow!("no private key found"))?;

    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, key)
        .map_err(|e| anyhow::anyhow!("tls config: {e}"))?;
    Ok(config)
}
213
/// Build a [`rustls::ClientConfig`] that trusts the CA certificate(s) in
/// `ca_cert_pem`.
///
/// Suitable for connecting to servers that present a self-signed or
/// `rcgen`-generated certificate: pass that certificate itself as the CA.
/// The returned config presents no client certificate.
///
/// # Errors
///
/// Fails if the PEM cannot be parsed, if it contains no certificate, or if
/// rustls refuses to add a certificate as a trust anchor.
#[cfg(feature = "tls")]
pub fn tls_client_config(ca_cert_pem: &[u8]) -> anyhow::Result<rustls::ClientConfig> {
    use rustls::pki_types::CertificateDer;
    use rustls::RootCertStore;
    use rustls_pemfile::certs;
    use std::io::Cursor;

    let anchors: Vec<CertificateDer<'static>> = certs(&mut Cursor::new(ca_cert_pem))
        .collect::<std::result::Result<_, _>>()
        .map_err(|e| anyhow::anyhow!("ca cert parse: {e}"))?;
    // An empty trust store would make every server certificate fail
    // verification later; report the real cause here instead.
    if anchors.is_empty() {
        anyhow::bail!("no CA certificates found");
    }

    let mut roots = RootCertStore::empty();
    for cert in anchors {
        roots
            .add(cert)
            .map_err(|e| anyhow::anyhow!("add root: {e}"))?;
    }
    let config = rustls::ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(config)
}
239
240// ---------------------------------------------------------------------------
241// Tests
242// ---------------------------------------------------------------------------
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::PeerCapability;
    use serde_json::json;
    use std::net::Ipv4Addr;

    /// One unversioned capability called `name`.
    fn caps(name: &str) -> Vec<PeerCapability> {
        let only = PeerCapability {
            name: name.into(),
            version: None,
        };
        vec![only]
    }

    /// A plain-TCP client's `Hello` + `Broadcast` reach a listener on the local mesh.
    #[tokio::test]
    async fn tcp_round_trip_delivers_broadcast() {
        let mesh = LocalMesh::new();
        let (mut observer, _observer_handle) =
            mesh.join("listener", caps("x"), vec![]).await.unwrap();

        // `TcpMesh::serve` cannot report the port it bound, so the accept loop
        // is reproduced here on an ephemeral port; `server` only exercises `new`.
        let server = TcpMesh::new(mesh.clone());
        let tcp_listener = TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .await
            .unwrap();
        let addr = tcp_listener.local_addr().unwrap();
        let accept_mesh = mesh.clone();
        let accept_task = tokio::spawn(async move {
            let _ = server;
            while let Ok((socket, _)) = tcp_listener.accept().await {
                let conn_mesh = accept_mesh.clone();
                tokio::spawn(async move {
                    let _ = handle_connection(socket, conn_mesh).await;
                });
            }
        });

        let client = TcpMesh::connect(
            addr,
            PeerMessage::Hello {
                from: "remote".into(),
                capabilities: caps("remote"),
            },
        )
        .await
        .unwrap();
        let outgoing = PeerMessage::broadcast("remote", "topic", json!({"v": 1}));
        client.send(&outgoing).await.unwrap();

        // Skip unrelated traffic; give up after six receives or the first timeout.
        let mut delivered = None;
        for _ in 0..6 {
            let waited = tokio::time::timeout(
                std::time::Duration::from_millis(400),
                observer.receiver.recv(),
            )
            .await;
            match waited {
                Ok(Some(PeerMessage::Broadcast { from, topic, .. })) => {
                    delivered = Some((from, topic));
                    break;
                }
                Ok(Some(_)) => {}
                _ => break,
            }
        }
        let (from, topic) = delivered.expect("expected a Broadcast to arrive");
        assert_eq!(from, "remote");
        assert_eq!(topic, "topic");
        accept_task.abort();
    }

    /// A TLS client's `Hello` + `Broadcast` reach a listener. An rcgen
    /// self-signed certificate serves as both the server identity and the
    /// client's trust anchor.
    #[cfg(feature = "tls")]
    #[tokio::test]
    async fn tls_round_trip_delivers_broadcast() {
        use rcgen::generate_simple_self_signed;
        use rustls::pki_types::ServerName;
        use std::sync::Arc;
        use tokio_rustls::{TlsAcceptor, TlsConnector};

        // rustls needs a process-wide crypto provider; the result is ignored
        // so a repeat install from another test is harmless.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let generated = generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert_pem = generated.cert.pem();
        let key_pem = generated.key_pair.serialize_pem();

        let server_cfg = tls_server_config(cert_pem.as_bytes(), key_pem.as_bytes()).unwrap();
        let acceptor = TlsAcceptor::from(Arc::new(server_cfg));
        // The client trusts the self-signed certificate directly as its CA.
        let client_cfg = tls_client_config(cert_pem.as_bytes()).unwrap();
        let connector = TlsConnector::from(Arc::new(client_cfg));

        let mesh = LocalMesh::new();
        let (mut observer, _observer_handle) =
            mesh.join("listener", caps("x"), vec![]).await.unwrap();

        // Bind an ephemeral port through std, then hand the socket to tokio.
        let std_listener =
            std::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).unwrap();
        std_listener.set_nonblocking(true).unwrap();
        let tls_listener = TcpListener::from_std(std_listener).unwrap();
        let addr = tls_listener.local_addr().unwrap();

        // Mirrors the per-connection handling of `TcpMesh::serve_tls`.
        let accept_mesh = mesh.clone();
        let accept_task = tokio::spawn(async move {
            while let Ok((socket, peer)) = tls_listener.accept().await {
                let acceptor = acceptor.clone();
                let conn_mesh = accept_mesh.clone();
                tokio::spawn(async move {
                    match acceptor.accept(socket).await {
                        Ok(tls) => {
                            let _ = handle_connection(tls, conn_mesh).await;
                        }
                        Err(e) => tracing::warn!(?peer, ?e, "tls handshake failed in test"),
                    }
                });
            }
        });

        let server_name = ServerName::try_from("localhost").unwrap();
        let hello = PeerMessage::Hello {
            from: "tls-remote".into(),
            capabilities: caps("tls-remote"),
        };
        let client = TcpMesh::connect_tls(addr, hello, server_name, connector)
            .await
            .unwrap();
        let outgoing =
            PeerMessage::broadcast("tls-remote", "tls-topic", json!({"secure": true}));
        client.send(&outgoing).await.unwrap();

        // Skip unrelated traffic; give up after six receives or the first timeout.
        let mut delivered = None;
        for _ in 0..6 {
            let waited = tokio::time::timeout(
                std::time::Duration::from_millis(500),
                observer.receiver.recv(),
            )
            .await;
            match waited {
                Ok(Some(PeerMessage::Broadcast { from, topic, .. })) => {
                    delivered = Some((from, topic));
                    break;
                }
                Ok(Some(_)) => {}
                _ => break,
            }
        }
        let (from, topic) = delivered.expect("expected a TLS Broadcast to arrive");
        assert_eq!(from, "tls-remote");
        assert_eq!(topic, "tls-topic");
        accept_task.abort();
    }
}