stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! gRPC transport — `Tunnel.Pipe` bidi `BytesFrame` stream (prost + tonic).

use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::{Mutex, mpsc};
use tonic::transport::{Endpoint, Identity, Server, ServerTlsConfig};
use tonic::{Request, Response, Status, Streaming};

/// Generated protobuf / gRPC bindings for the `srx.transport.v1` package,
/// pulled in at compile time by `tonic::include_proto!`.
pub mod proto {
    tonic::include_proto!("srx.transport.v1");
}

use proto::tunnel_server::{Tunnel, TunnelServer};

/// gRPC client bridge: each [`Transport::send`](super::Transport::send) is one `BytesFrame`;
/// each [`Transport::recv`](super::Transport::recv) waits for one frame.
///
/// Construct with [`GrpcTransport::connect`] (`https://`) or
/// [`GrpcTransport::connect_insecure`] (`http://`).
pub struct GrpcTransport {
    /// Sender feeding the outbound request stream; becomes `None` once `close()` is called.
    out_tx: Mutex<Option<mpsc::Sender<proto::BytesFrame>>>,
    /// Receiver of inbound frame payloads forwarded by the spawned reader task.
    in_rx: Mutex<mpsc::Receiver<Bytes>>,
    /// Handle of the spawned task that drains the inbound response stream.
    _jh: Arc<tokio::task::JoinHandle<()>>,
}

impl GrpcTransport {
    /// Connect with TLS-first policy.
    ///
    /// Requires an `https://` endpoint with client TLS configured.
    /// For plaintext development/testing, use [`Self::connect_insecure`].
    pub async fn connect(endpoint: Endpoint) -> crate::error::Result<Self> {
        let uri = endpoint.uri().to_string();
        if !uri.starts_with("https://") {
            return Err(crate::error::SrxError::Transport(
                crate::error::TransportError::ConnectionFailed(
                    "gRPC secure connect requires https:// endpoint; use connect_insecure for plaintext"
                        .into(),
                ),
            ));
        }
        Self::connect_inner(endpoint).await
    }

    /// Explicit insecure/plaintext gRPC connect (`http://`) for local development/testing.
    pub async fn connect_insecure(endpoint: Endpoint) -> crate::error::Result<Self> {
        let uri = endpoint.uri().to_string();
        if !uri.starts_with("http://") {
            return Err(crate::error::SrxError::Transport(
                crate::error::TransportError::ConnectionFailed(
                    "connect_insecure expects http:// endpoint".into(),
                ),
            ));
        }
        Self::connect_inner(endpoint).await
    }

    async fn connect_inner(endpoint: Endpoint) -> crate::error::Result<Self> {
        let mut client = proto::tunnel_client::TunnelClient::connect(endpoint)
            .await
            .map_err(|e| {
                crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
                    e.to_string(),
                ))
            })?;

        let (out_tx, out_rx) = mpsc::channel::<proto::BytesFrame>(64);
        let (in_tx, in_rx) = mpsc::channel(64);

        let out_stream = tokio_stream::wrappers::ReceiverStream::new(out_rx);
        let response = client.pipe(Request::new(out_stream)).await.map_err(|e| {
            crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
                e.to_string(),
            ))
        })?;
        let mut inbound: Streaming<proto::BytesFrame> = response.into_inner();

        let jh = tokio::spawn(async move {
            loop {
                match inbound.message().await {
                    Ok(Some(msg)) => {
                        if in_tx.send(Bytes::from(msg.payload)).await.is_err() {
                            break;
                        }
                    }
                    Ok(None) => break,
                    Err(_) => break,
                }
            }
        });

        Ok(Self {
            out_tx: Mutex::new(Some(out_tx)),
            in_rx: Mutex::new(in_rx),
            _jh: Arc::new(jh),
        })
    }
}

#[async_trait]
impl super::Transport for GrpcTransport {
    fn kind(&self) -> super::TransportKind {
        super::TransportKind::Grpc
    }

    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let g = self.out_tx.lock().await;
        let tx = g.as_ref().ok_or(crate::error::SrxError::Transport(
            crate::error::TransportError::ChannelClosed,
        ))?;
        tx.send(proto::BytesFrame {
            payload: data.to_vec(),
        })
        .await
        .map_err(|_| {
            crate::error::SrxError::Transport(crate::error::TransportError::ChannelClosed)
        })?;
        Ok(())
    }

    async fn recv(&self) -> crate::error::Result<Bytes> {
        self.in_rx
            .lock()
            .await
            .recv()
            .await
            .ok_or(crate::error::SrxError::Transport(
                crate::error::TransportError::ChannelClosed,
            ))
    }

    async fn is_healthy(&self) -> bool {
        self.out_tx.lock().await.is_some()
    }

    async fn close(&self) -> crate::error::Result<()> {
        self.out_tx.lock().await.take();
        Ok(())
    }
}

/// Echo server: streams each `BytesFrame` back (for tests / local tunnel).
///
/// Stateless unit type; derives `Debug` so it can appear in logs and in
/// `Debug` output of types that embed it.
#[derive(Clone, Copy, Debug, Default)]
pub struct TunnelEcho;

#[tonic::async_trait]
impl Tunnel for TunnelEcho {
    /// Boxed response stream of echoed frames.
    type PipeStream =
        Pin<Box<dyn tokio_stream::Stream<Item = Result<proto::BytesFrame, Status>> + Send>>;

    /// Echo every inbound frame back to the caller, unchanged and in order.
    ///
    /// The response stream ends when the request stream ends. If the request
    /// stream fails, the error `Status` is forwarded to the caller before the
    /// response ends, instead of being reported as a clean end of stream.
    async fn pipe(
        &self,
        request: Request<Streaming<proto::BytesFrame>>,
    ) -> Result<Response<Self::PipeStream>, Status> {
        let mut inbound = request.into_inner();
        let echoed = async_stream::stream! {
            loop {
                match inbound.message().await {
                    Ok(Some(frame)) => yield Ok(frame),
                    Ok(None) => break,
                    Err(status) => {
                        yield Err(status);
                        break;
                    }
                }
            }
        };
        Ok(Response::new(Box::pin(echoed)))
    }
}

/// Run a [`TunnelEcho`] server over plaintext HTTP/2 on an already-bound TCP listener.
///
/// Returns the handle of the spawned serving task; any error the server ends
/// with is discarded.
pub fn serve_tunnel_echo(listener: tokio::net::TcpListener) -> tokio::task::JoinHandle<()> {
    let router = Server::builder().add_service(TunnelServer::new(TunnelEcho));
    let connections = tokio_stream::wrappers::TcpListenerStream::new(listener);
    tokio::spawn(async move {
        // Best-effort: the serving result is intentionally ignored.
        let _ = router.serve_with_incoming(connections).await;
    })
}

/// Serve [`TunnelEcho`] with TLS (HTTP/2 over TLS). Use PEM from the same CA the client trusts.
pub fn serve_tunnel_echo_tls(
    listener: tokio::net::TcpListener,
    identity: Identity,
) -> crate::error::Result<tokio::task::JoinHandle<()>> {
    let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
    let svc = TunnelServer::new(TunnelEcho);
    let server = Server::builder()
        .tls_config(ServerTlsConfig::new().identity(identity))
        .map_err(|e| {
            crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
                e.to_string(),
            ))
        })?
        .add_service(svc);
    Ok(tokio::spawn(async move {
        let _ = server.serve_with_incoming(incoming).await;
    }))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::Transport;
    use tonic::transport::{Certificate, ClientTlsConfig};

    /// Self-signed `localhost` certificate, returned as the server identity and
    /// as the CA certificate the client should trust.
    fn localhost_grpc_tls_identity() -> (Identity, Certificate) {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert_pem = generated.cert.pem();
        let key_pem = generated.signing_key.serialize_pem();
        (
            Identity::from_pem(cert_pem.as_bytes(), key_pem.as_bytes()),
            Certificate::from_pem(cert_pem.as_bytes()),
        )
    }

    /// Bind an ephemeral loopback port and report the address it landed on.
    async fn bind_loopback() -> (tokio::net::TcpListener, std::net::SocketAddr) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Short pause that gives the spawned server task time to start.
    async fn settle(millis: u64) {
        tokio::time::sleep(std::time::Duration::from_millis(millis)).await;
    }

    /// Plaintext round trip: one frame out, the same frame back.
    #[tokio::test]
    async fn grpc_echo_roundtrip() {
        let (listener, addr) = bind_loopback().await;
        let server = serve_tunnel_echo(listener);
        settle(30).await;

        let endpoint = Endpoint::from_shared(format!("http://{}", addr)).unwrap();
        let transport = GrpcTransport::connect_insecure(endpoint).await.unwrap();

        transport
            .send(Bytes::from_static(b"grpc-ping"))
            .await
            .unwrap();
        let echoed = transport.recv().await.unwrap();
        assert_eq!(echoed.as_ref(), b"grpc-ping");
        transport.close().await.unwrap();

        server.abort();
    }

    /// TLS round trip against a server using a self-signed `localhost` identity.
    #[tokio::test]
    async fn grpc_echo_roundtrip_tls() {
        let (identity, ca_cert) = localhost_grpc_tls_identity();
        let (listener, addr) = bind_loopback().await;
        let server = serve_tunnel_echo_tls(listener, identity).unwrap();
        settle(50).await;

        let tls = ClientTlsConfig::new()
            .ca_certificate(ca_cert)
            .domain_name("localhost");
        let endpoint = Endpoint::from_shared(format!("https://{}", addr))
            .unwrap()
            .tls_config(tls)
            .unwrap();
        let transport = GrpcTransport::connect(endpoint).await.unwrap();

        transport
            .send(Bytes::from_static(b"grpc-tls-ping"))
            .await
            .unwrap();
        let echoed = transport.recv().await.unwrap();
        assert_eq!(echoed.as_ref(), b"grpc-tls-ping");
        transport.close().await.unwrap();

        server.abort();
    }

    /// `connect` must refuse an `http://` endpoint even when a server is listening.
    #[tokio::test]
    async fn grpc_connect_rejects_insecure_endpoint() {
        let (listener, addr) = bind_loopback().await;
        let server = serve_tunnel_echo(listener);
        settle(30).await;

        let endpoint = Endpoint::from_shared(format!("http://{}", addr)).unwrap();
        let outcome = GrpcTransport::connect(endpoint).await;
        assert!(outcome.is_err(), "connect() must reject plaintext endpoint");

        server.abort();
    }
}