use std::sync::Arc;
use anyhow::{anyhow, Result};
use bytes::Bytes;
use h2::client::ResponseFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::relay::outbound::{InboundStream, Outbound, OutboundContext, OutboundFuture};
use crate::relay::transport::grpc::{
encode_grpc_frame, grpc_to_raw, raw_to_grpc, send_grpc_data, GrpcPool,
};
use crate::vmess::validator::Upstream;
/// [`Outbound`] implementation that relays traffic through a gRPC tunnel opened on a
/// pooled HTTP/2 connection.
///
/// Carries no state of its own; everything per-connection comes from the
/// `OutboundContext` passed to `relay`.
pub struct GrpcOutbound;
impl Outbound for GrpcOutbound {
fn relay(
self: Box<Self>,
inbound: Box<dyn InboundStream>,
ctx: OutboundContext,
) -> OutboundFuture {
Box::pin(async move {
relay_grpc(
inbound,
ctx.upstream,
ctx.runtime.grpc_pool.clone(),
ctx.auth_id,
ctx.peer,
)
.await
})
}
}
/// Relays `inbound` to `upstream` through a gRPC tunnel opened on a connection from `pool`.
///
/// Sequence:
/// 1. open a tunnel with [`open_grpc_tunnel`] and send `auth_id` as the first gRPC frame;
/// 2. spawn a task copying `inbound` -> tunnel (`raw_to_grpc`);
/// 3. wait for the upstream response headers and require HTTP status 200;
/// 4. spawn a task copying the response body -> `inbound` (`grpc_to_raw`) and wait for
///    `relay_until_one_side_finishes` to return.
///
/// The upload task is spawned before waiting for the response headers, so servers that
/// only answer after reading request data still make progress.
///
/// `peer` is only used in log messages.
///
/// Failures that suggest the pooled HTTP/2 connection is broken evict the
/// `(upstream.addr, tls_sni)` entry from `pool`, so later tunnels reconnect. These are:
/// sending the auth frame, the upload task erroring, and reading the response headers.
///
/// # Errors
///
/// Returns an error if any of the following fail:
/// - the tunnel cannot be opened;
/// - the auth frame cannot be sent;
/// - the response headers cannot be read;
/// - the upstream answers with a non-200 status.
///
/// Once relaying has started, the outcome of the relay tasks is not propagated: the
/// function returns `Ok(())` after `relay_until_one_side_finishes` completes.
async fn relay_grpc(
    inbound: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
    upstream: Arc<Upstream>,
    pool: Arc<GrpcPool>,
    auth_id: [u8; 16],
    peer: std::net::SocketAddr,
) -> Result<()> {
    let GrpcTunnel {
        service_name,
        tls_sni,
        response_future,
        mut send_stream,
    } = open_grpc_tunnel(upstream.clone(), pool.clone()).await?;

    // The auth id travels as the first gRPC message on the stream.
    let frame = encode_grpc_frame(&auth_id);
    let auth_sent = send_grpc_data(&mut send_stream, frame, false).await;
    if auth_sent.is_err() {
        // Failing on the very first write most likely means the pooled connection is
        // unusable. Evict it, as the other failure paths below do, so it is not reused.
        pool.evict(&upstream.addr, &tls_sni);
    }
    auth_sent?;

    let (inbound_reader, inbound_writer) = tokio::io::split(inbound);

    // Upload direction: inbound -> upstream. The task owns clones of the pool key so it
    // can evict the connection if the upload fails.
    let upstream_addr = upstream.addr.clone();
    let tls_sni2 = tls_sni.clone();
    let pool2 = pool.clone();
    let t1 = tokio::spawn(async move {
        let result = raw_to_grpc(inbound_reader, send_stream).await;
        if result.is_err() {
            pool2.evict(&upstream_addr, &tls_sni2);
        }
        result
    });

    let response = match response_future.await {
        Ok(response) => response,
        Err(e) => {
            pool.evict(&upstream.addr, &tls_sni);
            t1.abort();
            let _ = t1.await;
            return Err(anyhow!("response headers: {}", e));
        }
    };

    // gRPC over HTTP/2 always answers with `:status 200`. Anything else (e.g. 404 for an
    // unknown service path) means the upstream is not serving this gRPC method, so its
    // body must not be relayed. The HTTP/2 connection itself worked, so it stays pooled.
    if response.status() != 200 {
        t1.abort();
        let _ = t1.await;
        return Err(anyhow!(
            "upstream returned HTTP status {}",
            response.status()
        ));
    }

    tracing::info!(
        "{} → {} [grpc/{} sni={}] relaying",
        peer,
        upstream.addr,
        service_name,
        tls_sni,
    );

    // Download direction: upstream response body -> inbound.
    let recv_stream = response.into_body();
    let t2 = tokio::spawn(async move { grpc_to_raw(recv_stream, inbound_writer).await });

    let started = std::time::Instant::now();
    crate::relay::transport::grpc::relay_until_one_side_finishes("grpc relay", t1, t2).await;
    tracing::info!(
        "{} → {} [grpc/{} sni={}] closed ({:.2}s)",
        peer,
        upstream.addr,
        service_name,
        tls_sni,
        started.elapsed().as_secs_f64(),
    );
    Ok(())
}
/// An opened gRPC request stream, as produced by [`open_grpc_tunnel`].
///
/// The request has already been sent with its body left open (`end_of_stream = false`).
/// Data written to `send_stream` goes upstream, and `response_future` resolves to the
/// upstream's response.
pub(crate) struct GrpcTunnel {
    /// gRPC service name copied from the upstream's `Transport::Grpc` configuration.
    pub(crate) service_name: String,
    /// TLS SNI copied from the upstream's `Transport::Grpc` configuration. Together with
    /// the upstream address, it is the key used with the pool's `get_or_create`/`evict`.
    pub(crate) tls_sni: String,
    /// Resolves once the upstream sends its response headers.
    pub(crate) response_future: ResponseFuture,
    /// Request-body half of the stream.
    pub(crate) send_stream: h2::SendStream<Bytes>,
}
/// Opens a new gRPC request stream to `upstream` on a pooled HTTP/2 connection.
///
/// The connection is looked up (or created) in `pool` under `(upstream.addr, tls_sni)`.
/// If the cached connection fails its readiness check, it is evicted and a fresh one is
/// created once; a second readiness failure is returned as an error.
///
/// The request is a `POST` to the upstream's configured `request_uri`. It carries the
/// headers `content-type: application/grpc`, `te: trailers` and a `grpc-go/1.48.0` user
/// agent. It is sent without ending the stream, so the caller can keep writing body data.
///
/// # Errors
///
/// Fails if any of the following happens:
/// - `upstream.transport` is not `Transport::Grpc`;
/// - the pool cannot provide a connection;
/// - the request cannot be built;
/// - no ready connection is obtained after one reconnect;
/// - `send_request` itself fails.
pub(crate) async fn open_grpc_tunnel(
    upstream: Arc<Upstream>,
    pool: Arc<GrpcPool>,
) -> Result<GrpcTunnel> {
    use crate::vmess::validator::Transport;

    let (service_name, tls_sni, request_uri) = if let Transport::Grpc {
        service_name,
        tls_sni,
        request_uri,
    } = &upstream.transport
    {
        (service_name.clone(), tls_sni.clone(), request_uri.clone())
    } else {
        return Err(anyhow!("open_grpc_tunnel called on non-gRPC upstream"));
    };

    let addr = &upstream.addr;
    let mut sender = pool.get_or_create(addr, &tls_sni).await?;

    let request = http::Request::builder()
        .method("POST")
        .uri(request_uri)
        .header("content-type", "application/grpc")
        .header("user-agent", "grpc-go/1.48.0")
        .header("te", "trailers")
        .body(())
        .map_err(|e| anyhow!("build request: {}", e))?;

    // A pooled connection may have died since it was cached; probe it before use and
    // reconnect exactly once if it is no longer usable.
    let cached_is_dead = std::future::poll_fn(|cx| sender.poll_ready(cx))
        .await
        .is_err();
    if cached_is_dead {
        tracing::debug!("cached H2 connection dead for {} -- reconnecting", addr);
        pool.evict(addr, &tls_sni);
        sender = pool.get_or_create(addr, &tls_sni).await?;
        std::future::poll_fn(|cx| sender.poll_ready(cx))
            .await
            .map_err(|e| anyhow!("h2 not ready after reconnect: {}", e))?;
    }

    let (response_future, send_stream) = sender
        .send_request(request, false)
        .map_err(|e| anyhow!("send_request: {}", e))?;

    Ok(GrpcTunnel {
        service_name,
        tls_sni,
        response_future,
        send_stream,
    })
}
// Tests for the gRPC ("gun-lite") framing helpers from `crate::relay::transport::grpc`,
// plus end-to-end compatibility checks against an in-memory h2 server, prost-encoded
// messages and a real tonic gRPC server.
//
// Frame layout exercised throughout:
//   [1-byte flag = 0][4-byte big-endian length N][N-byte protobuf payload]
// The payload is a single field-1 `bytes` value: 0x0A, varint(len), data.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::transport::grpc::{
        decode_gun_payload, read_varint, varint_size, write_varint,
    };
    use bytes::{Buf, BufMut, BytesMut};

    // Each value must encode to exactly `varint_size(v)` bytes and decode back to itself,
    // with `read_varint` reporting the same number of consumed bytes. The values straddle
    // the 7-bit group boundaries (127/128, 16383/16384, ...).
    #[test]
    fn test_varint_roundtrip() {
        for v in [
            0u64,
            1,
            127,
            128,
            255,
            300,
            16383,
            16384,
            65535,
            65536,
            1 << 21,
            u32::MAX as u64,
        ] {
            let mut buf = BytesMut::new();
            write_varint(&mut buf, v);
            let expected_size = varint_size(v);
            assert_eq!(buf.len(), expected_size, "varint_size mismatch for {}", v);
            let (decoded, consumed) = read_varint(&buf).unwrap();
            assert_eq!(decoded, v, "roundtrip failed for {}", v);
            assert_eq!(consumed, expected_size);
        }
    }

    // Size boundaries: each varint byte holds 7 value bits.
    #[test]
    fn test_varint_size() {
        assert_eq!(varint_size(0), 1);
        assert_eq!(varint_size(127), 1);
        assert_eq!(varint_size(128), 2);
        assert_eq!(varint_size(16383), 2);
        assert_eq!(varint_size(16384), 3);
    }

    // An empty payload still produces a complete 7-byte frame:
    // flag 0, length 2, tag 0x0A, inner length 0.
    #[test]
    fn test_encode_grpc_frame_empty() {
        let frame = encode_grpc_frame(&[]);
        assert_eq!(frame[0], 0);
        let outer_len = u32::from_be_bytes(frame[1..5].try_into().unwrap());
        assert_eq!(outer_len, 2);
        assert_eq!(frame[5], 0x0A);
        assert_eq!(frame[6], 0x00);
        assert_eq!(frame.len(), 7);
    }

    // Short payload: the inner length fits in a single varint byte, so the protobuf
    // payload is tag + 1 length byte + data.
    #[test]
    fn test_encode_grpc_frame_data() {
        let data = b"hello world";
        let frame = encode_grpc_frame(data);
        assert_eq!(frame[0], 0);
        let outer_len = u32::from_be_bytes(frame[1..5].try_into().unwrap()) as usize;
        assert_eq!(outer_len, 1 + 1 + data.len());
        assert_eq!(frame[5], 0x0A);
        assert_eq!(frame[6], data.len() as u8);
        assert_eq!(&frame[7..], data);
    }

    // 64 KiB payload: the inner length needs a multi-byte varint, and the outer length
    // must account for it.
    #[test]
    fn test_encode_grpc_frame_large() {
        let data = vec![0xABu8; 65536];
        let frame = encode_grpc_frame(&data);
        assert_eq!(frame[0], 0);
        let outer_len = u32::from_be_bytes(frame[1..5].try_into().unwrap()) as usize;
        assert_eq!(outer_len, 1 + varint_size(65536) + 65536);
        assert_eq!(frame[5], 0x0A);
        assert_eq!(frame.len(), 5 + outer_len);
    }

    // Encode, then parse the frame by hand, for sizes around the varint boundaries.
    // The decoded bytes must equal the input.
    #[test]
    fn test_encode_decode_roundtrip() {
        for size in [0usize, 1, 127, 128, 255, 256, 1000, 16384] {
            let data: Vec<u8> = (0..size).map(|i| i as u8).collect();
            let frame = encode_grpc_frame(&data);
            assert_eq!(frame[0], 0);
            let outer_len = u32::from_be_bytes(frame[1..5].try_into().unwrap()) as usize;
            assert_eq!(frame.len(), 5 + outer_len);
            let proto = &frame[5..5 + outer_len];
            assert_eq!(proto[0], 0x0A);
            let (inner_len, varint_len) = read_varint(&proto[1..]).unwrap();
            assert_eq!(inner_len as usize, size);
            let decoded = &proto[1 + varint_len..];
            assert_eq!(decoded, data.as_slice());
        }
    }

    // A hand-built protobuf payload (tag, varint length, data) decodes to its data.
    #[test]
    fn test_decode_gun_payload() {
        let data = b"test data";
        let mut payload = BytesMut::new();
        payload.put_u8(0x0A);
        write_varint(&mut payload, data.len() as u64);
        payload.put_slice(data);
        let decoded = decode_gun_payload(&payload).unwrap();
        assert_eq!(decoded, data);
    }

    // An empty payload is accepted and yields an empty slice (not None).
    #[test]
    fn test_decode_gun_payload_empty() {
        assert_eq!(decode_gun_payload(&[]).unwrap(), &[] as &[u8]);
    }

    // Any tag other than 0x0A is rejected.
    #[test]
    fn test_decode_gun_payload_wrong_tag() {
        assert!(decode_gun_payload(&[0x0B, 0x05, 0, 0, 0, 0, 0]).is_none());
    }

    // Decodes two back-to-back frames from one buffer with a streaming-style loop:
    // - stop when the 5-byte prefix or the frame body is incomplete;
    // - skip payloads whose tag is not 0x0A or whose inner length overruns the frame.
    // NOTE(review): this test never awaits; a plain #[test] would suffice.
    #[tokio::test]
    async fn test_grpc_frame_decode_from_buffer() {
        let payload1 = b"first message";
        let payload2 = b"second message";
        let mut combined = BytesMut::new();
        combined.extend_from_slice(&encode_grpc_frame(payload1));
        combined.extend_from_slice(&encode_grpc_frame(payload2));
        let mut out = Vec::new();
        let mut buf = combined;
        loop {
            if buf.len() < 5 {
                break;
            }
            let outer_len = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize;
            if buf.len() < 5 + outer_len {
                break;
            }
            let proto = &buf[5..5 + outer_len];
            // Absolute byte range of the data inside `buf`, if the payload is well formed.
            let write_range = if !proto.is_empty() && proto[0] == 0x0A {
                read_varint(&proto[1..]).and_then(|(inner_len, varint_len)| {
                    let data_start = 5 + 1 + varint_len;
                    let data_end = data_start + inner_len as usize;
                    if data_end <= 5 + outer_len {
                        Some(data_start..data_end)
                    } else {
                        None
                    }
                })
            } else {
                None
            };
            if let Some(range) = write_range {
                out.extend_from_slice(&buf[range]);
            }
            buf.advance(5 + outer_len);
        }
        assert_eq!(out, b"first messagesecond message");
    }

    // A frame that arrives in two pieces: the first 3 bytes are too short to contain the
    // 5-byte prefix. Once the rest arrives, the frame decodes normally.
    // NOTE(review): this test never awaits; a plain #[test] would suffice.
    #[tokio::test]
    async fn test_grpc_frame_decode_fragmented() {
        let payload = b"full message here";
        let frame = encode_grpc_frame(payload);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&frame[..3]);
        assert!(buf.len() < 5);
        buf.extend_from_slice(&frame[3..]);
        let outer_len = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize;
        let proto = &buf[5..5 + outer_len];
        assert_eq!(proto[0], 0x0A);
        let (inner_len, varint_len) = read_varint(&proto[1..]).unwrap();
        let decoded = &proto[1 + varint_len..1 + varint_len + inner_len as usize];
        assert_eq!(decoded, payload);
    }

    // Test helper: concatenates the data of every complete frame in `buf`.
    // A trailing incomplete frame is ignored, and payloads `decode_gun_payload` rejects
    // are skipped.
    fn decode_all_frames(mut buf: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        while buf.len() >= 5 {
            let outer_len = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize;
            if buf.len() < 5 + outer_len {
                break;
            }
            if let Some(data) = decode_gun_payload(&buf[5..5 + outer_len]) {
                out.extend_from_slice(data);
            }
            buf = &buf[5 + outer_len..];
        }
        out
    }

    // End-to-end over an in-memory h2 connection (tokio `duplex`). It checks that:
    // - the frames `raw_to_grpc` sends decode on the server side;
    // - the echo frame the server sends back is decoded by `grpc_to_raw`.
    #[tokio::test]
    async fn test_gun_lite_compat_full_relay() {
        use std::sync::{Arc, Mutex};
        use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
        let (client_io, server_io) = duplex(256 * 1024);
        let (mut send_request, conn) = h2::client::handshake(client_io).await.unwrap();
        tokio::spawn(async move {
            let _ = conn.await;
        });
        let server_conn: h2::server::Connection<_, Bytes> =
            h2::server::handshake(server_io).await.unwrap();
        // The oneshot sender is shared with every per-stream handler; `take()` ensures
        // only the first handler reports its decoded payload.
        let (result_tx, result_rx) = tokio::sync::oneshot::channel::<Vec<u8>>();
        let result_tx = Arc::new(Mutex::new(Some(result_tx)));
        tokio::spawn(async move {
            let mut conn = server_conn;
            while let Some(result) = conn.accept().await {
                let (req, mut respond) = result.unwrap();
                let tx = result_tx.clone();
                tokio::spawn(async move {
                    // Read the whole request body before responding, releasing flow-control
                    // capacity per chunk so the client can keep sending.
                    let mut body = req.into_body();
                    let mut raw = BytesMut::new();
                    while let Some(chunk) = body.data().await {
                        let chunk = chunk.unwrap();
                        let _ = body.flow_control().release_capacity(chunk.len());
                        raw.extend_from_slice(&chunk);
                    }
                    let decoded = decode_all_frames(&raw);
                    let resp = http::Response::builder()
                        .status(200)
                        .header("content-type", "application/grpc")
                        .body(())
                        .unwrap();
                    // Echo back as a single frame that ends the stream (no trailers).
                    let mut send = respond.send_response(resp, false).unwrap();
                    send.send_data(encode_grpc_frame(&decoded), true).unwrap();
                    if let Some(tx) = tx.lock().unwrap().take() {
                        let _ = tx.send(decoded);
                    }
                });
            }
        });
        std::future::poll_fn(|cx| send_request.poll_ready(cx))
            .await
            .unwrap();
        let req = http::Request::builder()
            .method("POST")
            .uri("/TestService/Tun")
            .header("content-type", "application/grpc")
            .header("user-agent", "grpc-go/1.48.0")
            .header("te", "trailers")
            .body(())
            .unwrap();
        let (resp_fut, send_stream) = send_request.send_request(req, false).unwrap();
        let (inbound_r, mut inbound_w) = duplex(64 * 1024);
        let t_send = tokio::spawn(async move { raw_to_grpc(inbound_r, send_stream).await });
        let payload = b"hello from the gun-lite relay compatibility test";
        inbound_w.write_all(payload).await.unwrap();
        // Dropping the writer gives `raw_to_grpc` EOF. The server above only responds
        // once the request body has ended.
        drop(inbound_w);
        let response = resp_fut.await.unwrap();
        assert_eq!(response.status(), 200);
        let (mut out_r, out_w) = duplex(64 * 1024);
        let t_recv = tokio::spawn(async move { grpc_to_raw(response.into_body(), out_w).await });
        let mut client_got = Vec::new();
        out_r.read_to_end(&mut client_got).await.unwrap();
        t_send.await.unwrap().unwrap();
        t_recv.await.unwrap().unwrap();
        let server_got = result_rx.await.unwrap();
        assert_eq!(
            server_got, payload,
            "server could not decode client gun-lite frames"
        );
        assert_eq!(
            client_got, payload,
            "client could not decode server gun-lite response"
        );
    }

    // Protobuf message with a single `bytes` field (tag 1), generated by prost. It is used
    // to cross-check the hand-rolled encoder/decoder against a standard implementation.
    #[derive(Clone, PartialEq, ::prost::Message)]
    struct HunkMsg {
        #[prost(bytes = "bytes", tag = "1")]
        data: Bytes,
    }

    // prost must be able to decode the protobuf payload produced by `encode_grpc_frame`.
    #[test]
    fn test_prost_compat_encode() {
        use prost::Message as _;
        let payload = b"gun-lite encodes as standard gRPC protobuf";
        let frame = encode_grpc_frame(payload);
        let outer_len = u32::from_be_bytes(frame[1..5].try_into().unwrap()) as usize;
        let hunk = HunkMsg::decode(&frame[5..5 + outer_len])
            .expect("prost must decode gun-lite protobuf payload");
        assert_eq!(hunk.data.as_ref(), payload as &[u8]);
    }

    // `decode_gun_payload` must be able to decode a prost-encoded `HunkMsg`.
    #[test]
    fn test_prost_compat_decode() {
        use prost::Message as _;
        let payload = b"prost-encoded HunkMsg is gun-lite compatible";
        let hunk = HunkMsg {
            data: Bytes::copy_from_slice(payload),
        };
        let mut proto = BytesMut::new();
        hunk.encode(&mut proto).unwrap();
        let decoded = decode_gun_payload(&proto)
            .expect("decode_gun_payload must handle prost-encoded HunkMsg");
        assert_eq!(decoded, payload as &[u8]);
    }

    // End-to-end against a real tonic server on a loopback TCP port. The server:
    // - handles a client-streaming request of `HunkMsg`;
    // - reports the concatenated data via a oneshot;
    // - echoes the data back as one `HunkMsg`.
    // The client side is a raw h2 connection driven by `raw_to_grpc`/`grpc_to_raw`.
    #[tokio::test]
    async fn test_tonic_server_compat() {
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;
        use tokio::sync::oneshot;
        use tokio_stream::wrappers::TcpListenerStream;
        use tonic::codec::ProstCodec;
        use tonic::server::Grpc;
        use tonic::{Request, Response, Status, Streaming};
        let (result_tx, result_rx) = oneshot::channel::<Vec<u8>>();
        let result_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(result_tx)));
        // Hand-written tower service standing in for generated tonic code. NAME is the
        // service path prefix, matching the request URI `/TestService/Tun` below.
        #[derive(Clone)]
        struct EchoSvc {
            tx: std::sync::Arc<std::sync::Mutex<Option<oneshot::Sender<Vec<u8>>>>>,
        }
        impl tonic::server::NamedService for EchoSvc {
            const NAME: &'static str = "TestService";
        }
        impl tower::Service<http::Request<tonic::body::BoxBody>> for EchoSvc {
            type Response = http::Response<tonic::body::BoxBody>;
            type Error = std::convert::Infallible;
            type Future = Pin<
                Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
            >;
            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }
            fn call(&mut self, req: http::Request<tonic::body::BoxBody>) -> Self::Future {
                let tx = self.tx.clone();
                Box::pin(async move {
                    let codec = ProstCodec::<HunkMsg, HunkMsg>::default();
                    let mut grpc = Grpc::new(codec);
                    let handler = tower::service_fn(move |req: Request<Streaming<HunkMsg>>| {
                        let tx = tx.clone();
                        async move {
                            // Collect every streamed message until the client half-closes.
                            let mut stream = req.into_inner();
                            let mut all: Vec<u8> = Vec::new();
                            while let Some(h) = stream.message().await? {
                                all.extend_from_slice(&h.data);
                            }
                            if let Some(s) = tx.lock().unwrap().take() {
                                let _ = s.send(all.clone());
                            }
                            let echo = HunkMsg {
                                data: Bytes::copy_from_slice(&all),
                            };
                            let out = tokio_stream::iter(vec![Ok::<_, Status>(echo)]);
                            Ok::<_, Status>(Response::new(out))
                        }
                    });
                    Ok(grpc.streaming(handler, req).await)
                })
            }
        }
        // Bind to an ephemeral port before spawning the server, so the address is known.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let svc = EchoSvc { tx: result_tx };
        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(svc)
                .serve_with_incoming(TcpListenerStream::new(listener))
                .await
                .expect("tonic server error");
        });
        // Gives the server task a chance to start. The listener is already bound above,
        // so the connect below should not depend on this yield.
        tokio::task::yield_now().await;
        let tcp = tokio::net::TcpStream::connect(server_addr).await.unwrap();
        tcp.set_nodelay(true).unwrap();
        let (mut send_req, h2_conn) = h2::client::handshake(tcp).await.unwrap();
        tokio::spawn(async move {
            let _ = h2_conn.await;
        });
        std::future::poll_fn(|cx| send_req.poll_ready(cx))
            .await
            .unwrap();
        let req = http::Request::builder()
            .method("POST")
            .uri("/TestService/Tun")
            .header("content-type", "application/grpc")
            .header("user-agent", "grpc-go/1.48.0")
            .header("te", "trailers")
            .body(())
            .unwrap();
        let (resp_fut, send_stream) = send_req.send_request(req, false).unwrap();
        let (inbound_r, mut inbound_w) = tokio::io::duplex(64 * 1024);
        let t_send = tokio::spawn(async move { raw_to_grpc(inbound_r, send_stream).await });
        let payload = b"hello from gun-lite h2 client to tonic gRPC server";
        inbound_w.write_all(payload).await.unwrap();
        // EOF on the inbound side lets `raw_to_grpc` finish the request stream.
        drop(inbound_w);
        let response = resp_fut.await.unwrap();
        assert_eq!(response.status(), 200);
        let (mut out_r, out_w) = tokio::io::duplex(64 * 1024);
        let t_recv = tokio::spawn(async move { grpc_to_raw(response.into_body(), out_w).await });
        let mut client_got = Vec::new();
        out_r.read_to_end(&mut client_got).await.unwrap();
        t_send.await.unwrap().unwrap();
        t_recv.await.unwrap().unwrap();
        let server_got = result_rx
            .await
            .expect("tonic server must report decoded payload");
        assert_eq!(
            server_got, payload as &[u8],
            "tonic decoded our gun-lite frames correctly"
        );
        assert_eq!(
            client_got, payload as &[u8],
            "grpc_to_raw decoded tonic's gRPC response correctly"
        );
    }
}