use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::{Mutex, mpsc};
use tonic::transport::{Endpoint, Identity, Server, ServerTlsConfig};
use tonic::{Request, Response, Status, Streaming};
/// Protobuf message and gRPC service types for the `srx.transport.v1` package,
/// included from build-time generated code via `tonic::include_proto!`.
pub mod proto {
tonic::include_proto!("srx.transport.v1");
}
use proto::tunnel_server::{Tunnel, TunnelServer};
/// Client side of a bidirectional gRPC `Pipe` tunnel, used through the
/// [`super::Transport`] trait.
///
/// Outbound payloads are pushed into `out_tx`, which feeds the gRPC request
/// stream. A background task copies frames from the gRPC response stream into
/// `in_rx`. That task is aborted when the transport is dropped.
pub struct GrpcTransport {
    /// Sender feeding the gRPC request stream; `None` once `close()` has run.
    out_tx: Mutex<Option<mpsc::Sender<proto::BytesFrame>>>,
    /// Receiver for payloads read from the gRPC response stream.
    in_rx: Mutex<mpsc::Receiver<Bytes>>,
    /// Handle of the background reader task.
    _jh: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for GrpcTransport {
    fn drop(&mut self) {
        // Dropping a `JoinHandle` only detaches the task. Without an explicit
        // abort the reader would keep the response stream (and the underlying
        // connection) alive until the server ends it or sends another frame,
        // even though nobody can read from `in_rx` anymore.
        self._jh.abort();
    }
}
impl GrpcTransport {
pub async fn connect(endpoint: Endpoint) -> crate::error::Result<Self> {
let uri = endpoint.uri().to_string();
if !uri.starts_with("https://") {
return Err(crate::error::SrxError::Transport(
crate::error::TransportError::ConnectionFailed(
"gRPC secure connect requires https:// endpoint; use connect_insecure for plaintext"
.into(),
),
));
}
Self::connect_inner(endpoint).await
}
pub async fn connect_insecure(endpoint: Endpoint) -> crate::error::Result<Self> {
let uri = endpoint.uri().to_string();
if !uri.starts_with("http://") {
return Err(crate::error::SrxError::Transport(
crate::error::TransportError::ConnectionFailed(
"connect_insecure expects http:// endpoint".into(),
),
));
}
Self::connect_inner(endpoint).await
}
async fn connect_inner(endpoint: Endpoint) -> crate::error::Result<Self> {
let mut client = proto::tunnel_client::TunnelClient::connect(endpoint)
.await
.map_err(|e| {
crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
e.to_string(),
))
})?;
let (out_tx, out_rx) = mpsc::channel::<proto::BytesFrame>(64);
let (in_tx, in_rx) = mpsc::channel(64);
let out_stream = tokio_stream::wrappers::ReceiverStream::new(out_rx);
let response = client.pipe(Request::new(out_stream)).await.map_err(|e| {
crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
e.to_string(),
))
})?;
let mut inbound: Streaming<proto::BytesFrame> = response.into_inner();
let jh = tokio::spawn(async move {
loop {
match inbound.message().await {
Ok(Some(msg)) => {
if in_tx.send(Bytes::from(msg.payload)).await.is_err() {
break;
}
}
Ok(None) => break,
Err(_) => break,
}
}
});
Ok(Self {
out_tx: Mutex::new(Some(out_tx)),
in_rx: Mutex::new(in_rx),
_jh: Arc::new(jh),
})
}
}
#[async_trait]
impl super::Transport for GrpcTransport {
fn kind(&self) -> super::TransportKind {
super::TransportKind::Grpc
}
async fn send(&self, data: Bytes) -> crate::error::Result<()> {
let g = self.out_tx.lock().await;
let tx = g.as_ref().ok_or(crate::error::SrxError::Transport(
crate::error::TransportError::ChannelClosed,
))?;
tx.send(proto::BytesFrame {
payload: data.to_vec(),
})
.await
.map_err(|_| {
crate::error::SrxError::Transport(crate::error::TransportError::ChannelClosed)
})?;
Ok(())
}
async fn recv(&self) -> crate::error::Result<Bytes> {
self.in_rx
.lock()
.await
.recv()
.await
.ok_or(crate::error::SrxError::Transport(
crate::error::TransportError::ChannelClosed,
))
}
async fn is_healthy(&self) -> bool {
self.out_tx.lock().await.is_some()
}
async fn close(&self) -> crate::error::Result<()> {
self.out_tx.lock().await.take();
Ok(())
}
}
/// Stateless `Tunnel` service that echoes each received frame back to the
/// caller. It is served by `serve_tunnel_echo` and `serve_tunnel_echo_tls`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TunnelEcho;
#[tonic::async_trait]
impl Tunnel for TunnelEcho {
    type PipeStream =
        Pin<Box<dyn tokio_stream::Stream<Item = Result<proto::BytesFrame, Status>> + Send>>;

    /// Echoes every inbound frame back on the response stream, in order.
    ///
    /// The response ends when the client finishes its request stream. If reading
    /// the request stream fails, that `Status` is forwarded as the final response
    /// item. Previously it was dropped, and the call looked like a clean finish.
    async fn pipe(
        &self,
        request: Request<Streaming<proto::BytesFrame>>,
    ) -> Result<Response<Self::PipeStream>, Status> {
        let mut inbound = request.into_inner();
        let echoed = async_stream::stream! {
            loop {
                match inbound.message().await {
                    Ok(Some(frame)) => {
                        yield Ok(frame);
                    }
                    Ok(None) => break,
                    Err(status) => {
                        yield Err(status);
                        break;
                    }
                }
            }
        };
        Ok(Response::new(Box::pin(echoed)))
    }
}
/// Spawns a `TunnelEcho` gRPC server, without TLS, that accepts connections
/// from `listener`.
///
/// Returns the handle of the spawned server task. Any error from the server is
/// discarded; abort the handle to stop serving.
pub fn serve_tunnel_echo(listener: tokio::net::TcpListener) -> tokio::task::JoinHandle<()> {
    let router = Server::builder().add_service(TunnelServer::new(TunnelEcho));
    let connections = tokio_stream::wrappers::TcpListenerStream::new(listener);
    tokio::spawn(async move {
        let _ = router.serve_with_incoming(connections).await;
    })
}
/// Spawns a TLS-terminating `TunnelEcho` gRPC server that accepts connections
/// from `listener`, presenting `identity` as its certificate and key.
///
/// Returns the handle of the spawned server task. Errors from the running
/// server are discarded.
///
/// # Errors
/// Returns `TransportError::ConnectionFailed` if the TLS configuration is rejected.
pub fn serve_tunnel_echo_tls(
    listener: tokio::net::TcpListener,
    identity: Identity,
) -> crate::error::Result<tokio::task::JoinHandle<()>> {
    let tls = ServerTlsConfig::new().identity(identity);
    let router = Server::builder()
        .tls_config(tls)
        .map_err(|err| {
            crate::error::SrxError::Transport(crate::error::TransportError::ConnectionFailed(
                err.to_string(),
            ))
        })?
        .add_service(TunnelServer::new(TunnelEcho));
    let connections = tokio_stream::wrappers::TcpListenerStream::new(listener);
    let handle = tokio::spawn(async move {
        let _ = router.serve_with_incoming(connections).await;
    });
    Ok(handle)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::Transport;
    use std::time::Duration;
    use tonic::transport::{Certificate, ClientTlsConfig};

    /// Upper bound for any single network step, so a broken echo path fails the
    /// test instead of hanging the test runner.
    const STEP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Awaits `fut`, panicking with `what` if it does not finish within `STEP_TIMEOUT`.
    async fn within<F: std::future::Future>(what: &str, fut: F) -> F::Output {
        tokio::time::timeout(STEP_TIMEOUT, fut)
            .await
            .unwrap_or_else(|_| panic!("timed out: {}", what))
    }

    /// True if `res` is a `ConnectionFailed` error whose message contains `needle`.
    fn is_rejection(res: &crate::error::Result<GrpcTransport>, needle: &str) -> bool {
        matches!(
            res,
            Err(crate::error::SrxError::Transport(
                crate::error::TransportError::ConnectionFailed(msg)
            )) if msg.contains(needle)
        )
    }

    /// Self-signed certificate for `localhost`: the server identity and the
    /// matching CA certificate for the client.
    fn localhost_grpc_tls_identity() -> (Identity, Certificate) {
        let ck = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert_pem = ck.cert.pem();
        let key_pem = ck.signing_key.serialize_pem();
        let identity = Identity::from_pem(cert_pem.as_bytes(), key_pem.as_bytes());
        let ca = Certificate::from_pem(cert_pem.as_bytes());
        (identity, ca)
    }

    // The listeners below are bound (and listening) before the server task is
    // spawned. Client connections queue in the OS backlog until the server
    // accepts them, so no warm-up sleep is needed.

    #[tokio::test]
    async fn grpc_echo_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let serve = serve_tunnel_echo(listener);
        let endpoint = Endpoint::from_shared(format!("http://{}", addr)).unwrap();
        let client = within("connect", GrpcTransport::connect_insecure(endpoint))
            .await
            .unwrap();
        within("send", client.send(Bytes::from_static(b"grpc-ping")))
            .await
            .unwrap();
        let got = within("recv", client.recv()).await.unwrap();
        assert_eq!(got.as_ref(), b"grpc-ping");
        client.close().await.unwrap();
        serve.abort();
    }

    #[tokio::test]
    async fn grpc_echo_roundtrip_tls() {
        let (identity, ca_cert) = localhost_grpc_tls_identity();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let serve = serve_tunnel_echo_tls(listener, identity).unwrap();
        let endpoint = Endpoint::from_shared(format!("https://{}", addr))
            .unwrap()
            .tls_config(
                ClientTlsConfig::new()
                    .ca_certificate(ca_cert)
                    .domain_name("localhost"),
            )
            .unwrap();
        let client = within("connect", GrpcTransport::connect(endpoint))
            .await
            .unwrap();
        within("send", client.send(Bytes::from_static(b"grpc-tls-ping")))
            .await
            .unwrap();
        let got = within("recv", client.recv()).await.unwrap();
        assert_eq!(got.as_ref(), b"grpc-tls-ping");
        client.close().await.unwrap();
        serve.abort();
    }

    #[tokio::test]
    async fn grpc_connect_rejects_insecure_endpoint() {
        // A live plaintext server ensures the failure comes from the scheme
        // check, not from having nothing to connect to.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let serve = serve_tunnel_echo(listener);
        let endpoint = Endpoint::from_shared(format!("http://{}", addr)).unwrap();
        let res = within("connect", GrpcTransport::connect(endpoint)).await;
        assert!(
            is_rejection(&res, "requires https://"),
            "connect() must reject plaintext endpoint"
        );
        serve.abort();
    }

    #[tokio::test]
    async fn grpc_connect_insecure_rejects_https_endpoint() {
        let endpoint = Endpoint::from_static("https://127.0.0.1:1");
        let res = within("connect", GrpcTransport::connect_insecure(endpoint)).await;
        assert!(
            is_rejection(&res, "expects http://"),
            "connect_insecure() must reject https endpoint"
        );
    }
}