use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, Result};
use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use h2::client::SendRequest;
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;
use tokio_rustls::TlsConnector;
use crate::buf as buf_pool;
/// A cached HTTP/2 client handle stored in a [`GrpcPool`] slot.
struct PooledConn {
    /// Handle used to open new streams on the pooled connection; callers
    /// receive clones of it.
    send_request: SendRequest<Bytes>,
}
/// Pool of HTTP/2-over-TLS client connections keyed by (TLS SNI, address).
pub struct GrpcPool {
    /// One lock-protected slot per pool key; `None` means no connection is
    /// currently cached for that endpoint.
    conns: DashMap<String, Arc<Mutex<Option<PooledConn>>>>,
    /// TLS client configuration shared by every connection the pool opens.
    tls_config: Arc<rustls::ClientConfig>,
}
impl GrpcPool {
    /// Creates an empty pool whose connections share one TLS client
    /// configuration built by `build_tls_config`.
    ///
    /// # Errors
    /// Propagates any error returned by `build_tls_config`.
    pub fn new() -> Result<Self> {
        let tls_config = build_tls_config()?;
        Ok(Self {
            conns: DashMap::new(),
            tls_config: Arc::new(tls_config),
        })
    }

    /// Builds the map key for an endpoint: `"<tls_sni>:<addr>"`.
    fn pool_key(addr: &str, tls_sni: &str) -> String {
        format!("{}:{}", tls_sni, addr)
    }

    /// Returns a send handle for the HTTP/2-over-TLS connection to `addr`
    /// (presenting `tls_sni`), reusing the cached connection if there is one.
    ///
    /// When nothing is cached, a new connection is established via `connect_h2`
    /// *without* holding the slot lock, so a slow or failing handshake does not
    /// queue other callers behind its timeouts. Concurrent callers may therefore
    /// each connect. The first to finish caches its connection; later finishers
    /// return that cached handle and drop their own.
    ///
    /// NOTE: a cached handle is returned without any liveness check; callers that
    /// see it fail should call [`GrpcPool::evict`] so the next call reconnects.
    ///
    /// # Errors
    /// Propagates any error from `connect_h2`.
    pub async fn get_or_create(&self, addr: &str, tls_sni: &str) -> Result<SendRequest<Bytes>> {
        let key = Self::pool_key(addr, tls_sni);
        let slot = self
            .conns
            .entry(key.clone())
            .or_insert_with(|| Arc::new(Mutex::new(None)))
            .clone();
        {
            let guard = slot.lock().await;
            if let Some(conn) = &*guard {
                tracing::debug!("reusing cached H2 connection for {}", key);
                return Ok(conn.send_request.clone());
            }
        }
        let send_request = connect_h2(addr, tls_sni, self.tls_config.clone()).await?;
        // Wait for the lock rather than `try_lock`: it is only ever held
        // briefly, and skipping the cache under contention would force the next
        // caller to open yet another connection.
        let mut guard = slot.lock().await;
        if let Some(existing) = &*guard {
            // Another caller cached a connection while we were connecting;
            // converge on it so callers share a single connection.
            return Ok(existing.send_request.clone());
        }
        *guard = Some(PooledConn {
            send_request: send_request.clone(),
        });
        Ok(send_request)
    }

    /// Forgets the cached connection for (`addr`, `tls_sni`) so the next
    /// [`GrpcPool::get_or_create`] establishes a fresh one.
    ///
    /// The slot is removed from the map instead of being cleared through
    /// `try_lock`. That way eviction cannot be silently skipped while another
    /// task holds the slot lock. Handles already returned to callers are not
    /// affected.
    pub fn evict(&self, addr: &str, tls_sni: &str) {
        let key = Self::pool_key(addr, tls_sni);
        self.conns.remove(&key);
    }
}
/// Opens a TCP connection to `addr`, performs a TLS handshake presenting
/// `tls_sni`, then an HTTP/2 client handshake, each bounded by its own timeout.
///
/// The HTTP/2 connection future is spawned onto the Tokio runtime and logs at
/// debug level if it finishes with an error.
///
/// # Errors
/// Fails on any of the three timeouts, on I/O or handshake errors, if
/// `TCP_NODELAY` cannot be set, or if `tls_sni` is not a valid server name.
async fn connect_h2(
    addr: &str,
    tls_sni: &str,
    tls_config: Arc<rustls::ClientConfig>,
) -> Result<SendRequest<Bytes>> {
    tracing::debug!(
        "establishing new H2/TLS connection -> {} (sni={})",
        addr,
        tls_sni
    );

    // Stage 1: TCP.
    let tcp_stream = match timeout(CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
        Ok(connected) => connected?,
        Err(_elapsed) => return Err(anyhow!("connect timeout: {}", addr)),
    };
    // Disable Nagle's algorithm (TCP_NODELAY).
    tcp_stream.set_nodelay(true)?;

    // Stage 2: TLS.
    let tls_connector = TlsConnector::from(tls_config);
    let server_name = match ServerName::try_from(tls_sni.to_owned()) {
        Ok(name) => name,
        Err(_) => return Err(anyhow!("invalid TLS SNI: {}", tls_sni)),
    };
    let tls_stream = match timeout(
        TLS_HANDSHAKE_TIMEOUT,
        tls_connector.connect(server_name, tcp_stream),
    )
    .await
    {
        Ok(handshaken) => handshaken?,
        Err(_elapsed) => {
            return Err(anyhow!(
                "TLS handshake timeout: {} (sni={})",
                addr,
                tls_sni
            ))
        }
    };

    // Stage 3: HTTP/2.
    let (send_request, h2_driver) =
        match timeout(H2_HANDSHAKE_TIMEOUT, h2::client::handshake(tls_stream)).await {
            Ok(handshaken) => handshaken?,
            Err(_elapsed) => {
                return Err(anyhow!(
                    "H2 handshake timeout: {} (sni={})",
                    addr,
                    tls_sni
                ))
            }
        };
    tracing::debug!("H2 connection established -> {} (sni={})", addr, tls_sni);

    tokio::spawn(async move {
        if let Err(e) = h2_driver.await {
            tracing::debug!("gRPC H2 connection closed: {}", e);
        }
    });
    Ok(send_request)
}
/// Builds the TLS client configuration shared by every pooled connection.
/// It uses the platform-native root certificates, no client authentication,
/// and an ALPN list advertising only `h2`.
///
/// Problems loading or parsing native certificates are logged as warnings
/// rather than returned, so this currently always returns `Ok`.
fn build_tls_config() -> Result<rustls::ClientConfig> {
    let cert_result = rustls_native_certs::load_native_certs();
    let mut root_store = rustls::RootCertStore::empty();
    // `add_parsable_certificates` skips certificates rustls cannot parse and
    // reports how many were added and how many were skipped.
    let (added, ignored) = root_store.add_parsable_certificates(cert_result.certs);
    if ignored > 0 {
        tracing::warn!(
            "ignored {} unparsable native root certificate(s)",
            ignored
        );
    }
    if !cert_result.errors.is_empty() {
        tracing::warn!(
            "some native certs failed to load: {} error(s)",
            cert_result.errors.len()
        );
    }
    if added == 0 {
        // With no trust anchors, server certificate verification cannot
        // succeed; surface that now instead of only at connect time.
        tracing::warn!(
            "no native root certificates loaded; TLS server verification will fail"
        );
    }
    let mut config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    config.alpn_protocols = vec![b"h2".to_vec()];
    Ok(config)
}
/// `Bytes` owner used by `encode_grpc_frame`. It holds the pooled buffer that
/// backs an encoded frame, and its `Drop` impl returns that buffer to `buf_pool`.
struct PooledFrameOwner {
    /// The frame buffer; `Some` until `Drop` takes it.
    buf: Option<BytesMut>,
}
impl AsRef<[u8]> for PooledFrameOwner {
    /// Exposes the bytes of the owned frame buffer.
    ///
    /// # Panics
    /// Panics if the buffer has already been taken; in this file only `Drop`
    /// takes it.
    fn as_ref(&self) -> &[u8] {
        let frame: Option<&[u8]> = self.buf.as_deref();
        frame.expect("pooled frame owner must hold a buffer")
    }
}
impl Drop for PooledFrameOwner {
    /// Clears the owned buffer, if it is still present, and hands it back to
    /// `buf_pool`.
    fn drop(&mut self) {
        let Some(mut recycled) = self.buf.take() else {
            return;
        };
        recycled.clear();
        buf_pool::put(recycled);
    }
}
/// Returns how many bytes `write_varint` emits for `v`: one byte per 7-bit
/// group, and at least one byte (zero encodes as a single byte).
pub(crate) fn varint_size(v: u64) -> usize {
    // Each successor drops one 7-bit group; the sequence stops once the
    // remaining value fits in a single byte.
    std::iter::successors(Some(v), |&rest| (rest >= 0x80).then_some(rest >> 7)).count()
}
/// Appends `v` to `buf` as a base-128 varint: 7-bit groups, least significant
/// first, with the high bit set on every byte except the last.
pub(crate) fn write_varint(buf: &mut BytesMut, mut v: u64) {
    while v >= 0x80 {
        // Low 7 bits plus the continuation flag.
        buf.put_u8((v & 0x7F) as u8 | 0x80);
        v >>= 7;
    }
    buf.put_u8(v as u8);
}
/// Decodes a base-128 varint (7-bit groups, least significant first, high bit
/// = "more bytes follow") from the start of `bytes`.
///
/// Returns the decoded value and the number of bytes consumed. Returns `None`
/// in three cases: the input ends before the varint terminates, the varint
/// spans more than 10 bytes, or the encoded value does not fit in a `u64`.
pub(crate) fn read_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    // A u64 needs at most ceil(64 / 7) = 10 base-128 groups.
    const MAX_VARINT_LEN: usize = 10;
    let mut result = 0u64;
    for (i, &b) in bytes.iter().take(MAX_VARINT_LEN).enumerate() {
        let shift = 7 * i as u32;
        let group = u64::from(b & 0x7F);
        // At shift 63 only one bit of the u64 remains. Any higher bit in this
        // group would be silently shifted out, so reject the varint as overflowing.
        if shift == 63 && group > 1 {
            return None;
        }
        result |= group << shift;
        if b < 0x80 {
            return Some((result, i + 1));
        }
    }
    None
}
/// Encodes `data` as one gRPC message carrying a gun payload and returns it
/// as `Bytes` backed by a buffer from `buf_pool`.
///
/// Layout:
/// - a 5-byte gRPC header: compression flag `0`, then the big-endian u32 length;
/// - the payload: tag `0x0A` (protobuf field 1, length-delimited), the varint
///   length of `data`, then `data` itself.
///
/// NOTE(review): the header length is cast to `u32`, so this assumes the
/// payload is smaller than 4 GiB.
pub(crate) fn encode_grpc_frame(data: &[u8]) -> Bytes {
    let data_len = data.len();
    let payload_len = 1 + varint_size(data_len as u64) + data_len;

    let mut frame = buf_pool::get(5 + payload_len);
    // gRPC message header.
    frame.put_u8(0);
    frame.put_u32(payload_len as u32);
    // Gun payload.
    frame.put_u8(0x0A);
    write_varint(&mut frame, data_len as u64);
    frame.put_slice(data);

    Bytes::from_owner(PooledFrameOwner { buf: Some(frame) })
}
/// Extracts the data bytes from a gun payload. The payload is a tag byte
/// `0x0A` (protobuf field 1, length-delimited), then a varint length, then
/// the data.
///
/// An empty payload decodes to an empty slice, and bytes after the declared
/// data are ignored. Returns `None` for a different tag, a malformed varint,
/// or a declared length that exceeds the available bytes.
pub(crate) fn decode_gun_payload(payload: &[u8]) -> Option<&[u8]> {
    let Some((&tag, rest)) = payload.split_first() else {
        return Some(&[]);
    };
    if tag != 0x0A {
        return None;
    }
    let (inner_len, varint_len) = read_varint(rest)?;
    // The length is read off the wire. Convert and slice with checked
    // operations so a huge value yields `None` instead of overflowing
    // `data_start + inner_len`: that overflow is a panic in debug builds and
    // an out-of-order slice panic in release builds.
    let inner_len = usize::try_from(inner_len).ok()?;
    rest.get(varint_len..)?.get(..inner_len)
}
/// Maximum time allowed for the TCP connect in `connect_h2`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum time allowed for the TLS handshake in `connect_h2`.
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum time allowed for the HTTP/2 client handshake in `connect_h2`.
const H2_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Size passed to `buf_pool::get` for `raw_to_grpc`'s read buffer, and the
/// byte cap on each read there; each read is sent as one gRPC frame.
const RAW_TO_GRPC_READ_BUF_SIZE: usize = 16 * 1024;
/// Size passed to `buf_pool::get` for a `GrpcFrameReader`'s reassembly buffer.
const GRPC_TO_RAW_INIT_BUF_SIZE: usize = 16 * 1024;
/// Largest gRPC message length (the 4-byte header field) accepted by
/// `GrpcFrameReader::next_frame` and `decode_grpc_frame_data`.
const MAX_GRPC_FRAME_SIZE: usize = 16 * 1024 * 1024;
/// Pumps bytes from `reader` into `send_stream` as gRPC/gun frames.
///
/// Each read returns at most `RAW_TO_GRPC_READ_BUF_SIZE` bytes and is sent as
/// one frame. At EOF an empty END_STREAM data frame is sent; any error from
/// that final send is ignored. The read buffer is returned to `buf_pool`
/// whether or not the pump succeeds.
///
/// # Errors
/// Propagates read errors from `reader` and send errors from `send_grpc_data`.
pub(crate) async fn raw_to_grpc(
    mut reader: impl AsyncRead + Unpin,
    mut send_stream: h2::SendStream<Bytes>,
) -> Result<()> {
    use bytes::BufMut;
    let mut chunk_buf = buf_pool::get(RAW_TO_GRPC_READ_BUF_SIZE);
    let outcome: Result<()> = async {
        loop {
            let read = reader
                .read_buf(&mut (&mut chunk_buf).limit(RAW_TO_GRPC_READ_BUF_SIZE))
                .await?;
            if read == 0 {
                // EOF: best-effort END_STREAM.
                let _ = send_grpc_data(&mut send_stream, Bytes::new(), true).await;
                return Ok(());
            }
            let frame = encode_grpc_frame(&chunk_buf[..read]);
            send_grpc_data(&mut send_stream, frame, false).await?;
            chunk_buf.clear();
        }
    }
    .await;
    buf_pool::put(chunk_buf);
    outcome
}
/// Sends `data` on `send_stream`, respecting h2 send flow control.
///
/// An empty `data` is handed to `send_data` immediately, with END_STREAM set
/// as requested, without reserving capacity. Otherwise the function reserves
/// capacity for the remaining bytes and waits for h2 to report available
/// capacity. It then sends a chunk no larger than that, and repeats until all
/// bytes are sent. END_STREAM is set only on the final chunk, and only when
/// `end_of_stream` is true.
///
/// # Errors
/// Fails if the stream closes while waiting for capacity, or if h2 reports a
/// capacity or send error.
pub(crate) async fn send_grpc_data(
    send_stream: &mut h2::SendStream<Bytes>,
    mut data: Bytes,
    end_of_stream: bool,
) -> Result<()> {
    if data.is_empty() {
        return send_stream
            .send_data(data, end_of_stream)
            .map_err(|e| anyhow!("send grpc data: {}", e));
    }
    loop {
        send_stream.reserve_capacity(data.len());
        let granted = match std::future::poll_fn(|cx| send_stream.poll_capacity(cx)).await {
            None => return Err(anyhow!("gRPC send stream closed")),
            Some(Err(e)) => return Err(anyhow!("poll grpc send capacity: {}", e)),
            Some(Ok(granted)) => granted,
        };
        if granted == 0 {
            // Nothing usable yet; reserve again and wait for more capacity.
            continue;
        }
        let chunk = data.split_to(granted.min(data.len()));
        let last_chunk = data.is_empty();
        send_stream
            .send_data(chunk, end_of_stream && last_chunk)
            .map_err(|e| anyhow!("send grpc data: {}", e))?;
        if last_chunk {
            return Ok(());
        }
    }
}
/// Reassembles length-prefixed gRPC messages from an h2 receive stream.
pub(crate) struct GrpcFrameReader {
    /// Stream the DATA chunks are read from.
    recv_stream: h2::RecvStream,
    /// Bytes received but not yet returned as complete frames. Taken from
    /// `buf_pool` in `new` and returned to it on drop.
    buf: BytesMut,
}
impl GrpcFrameReader {
    /// Wraps `recv_stream`, taking a reassembly buffer from `buf_pool`.
    pub(crate) fn new(recv_stream: h2::RecvStream) -> Self {
        Self {
            recv_stream,
            buf: buf_pool::get(GRPC_TO_RAW_INIT_BUF_SIZE),
        }
    }

    /// Returns the next complete gRPC message, including its 5-byte header.
    /// Returns `Ok(None)` once the stream has ended with no bytes left over.
    ///
    /// Received DATA is released back to h2 flow control as soon as it is
    /// buffered.
    ///
    /// # Errors
    /// - a header declares a length above `MAX_GRPC_FRAME_SIZE`;
    /// - h2 reports a receive error;
    /// - the stream ends while a partial frame is still buffered (those bytes
    ///   would otherwise be silently lost).
    pub(crate) async fn next_frame(&mut self) -> Result<Option<Bytes>> {
        loop {
            if let Some(frame_len) = self.complete_frame_len()? {
                return Ok(Some(self.buf.split_to(frame_len).freeze()));
            }
            match self.recv_stream.data().await {
                Some(Ok(chunk)) => {
                    let _ = self
                        .recv_stream
                        .flow_control()
                        .release_capacity(chunk.len());
                    self.buf.extend_from_slice(&chunk);
                }
                Some(Err(e)) => return Err(anyhow!("recv grpc data: {}", e)),
                None if self.buf.is_empty() => return Ok(None),
                None => {
                    return Err(anyhow!(
                        "gRPC stream ended mid-frame with {} buffered byte(s)",
                        self.buf.len()
                    ))
                }
            }
        }
    }

    /// Returns the total length (header + payload) of the first buffered
    /// frame if it is fully buffered, or `None` if more data is needed.
    ///
    /// # Errors
    /// Fails if the header declares a length above `MAX_GRPC_FRAME_SIZE`.
    fn complete_frame_len(&self) -> Result<Option<usize>> {
        let Some(header) = self.buf.get(..5) else {
            return Ok(None);
        };
        let outer_len =
            u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        if outer_len > MAX_GRPC_FRAME_SIZE {
            return Err(anyhow!("gRPC frame too large: {} bytes", outer_len));
        }
        let frame_len = 5 + outer_len;
        Ok((self.buf.len() >= frame_len).then_some(frame_len))
    }
}
impl Drop for GrpcFrameReader {
fn drop(&mut self) {
let mut buf = std::mem::take(&mut self.buf);
buf.clear();
buf_pool::put(buf);
}
}
/// Returns the gun data carried by a complete gRPC message `frame` (5-byte
/// header + payload).
///
/// Returns `None` in three cases: the header is incomplete, the declared
/// length exceeds `MAX_GRPC_FRAME_SIZE` or the bytes present, or the payload
/// is not a valid gun payload. Bytes beyond the declared length are ignored.
pub(crate) fn decode_grpc_frame_data(frame: &[u8]) -> Option<&[u8]> {
    let header = frame.get(..5)?;
    let outer_len = u32::from_be_bytes(header[1..5].try_into().ok()?) as usize;
    if outer_len > MAX_GRPC_FRAME_SIZE {
        return None;
    }
    let payload = frame.get(5..5 + outer_len)?;
    decode_gun_payload(payload)
}
/// Pumps gRPC/gun frames from `recv_stream` into `writer` as raw bytes.
///
/// Frames that fail to decode, and empty payloads, are skipped. The writer is
/// flushed after every payload, so buffering writers (e.g. TLS streams or
/// `BufWriter`) deliver tunnel data promptly instead of holding it until the
/// gRPC stream ends. It is flushed once more at the end.
///
/// # Errors
/// Propagates frame-reader errors and write/flush errors from `writer`.
pub(crate) async fn grpc_to_raw(
    recv_stream: h2::RecvStream,
    mut writer: impl AsyncWrite + Unpin,
) -> Result<()> {
    let mut reader = GrpcFrameReader::new(recv_stream);
    while let Some(frame) = reader.next_frame().await? {
        let Some(data) = decode_grpc_frame_data(&frame) else {
            continue;
        };
        if data.is_empty() {
            continue;
        }
        writer.write_all(data).await?;
        writer.flush().await?;
    }
    writer.flush().await?;
    Ok(())
}
/// Forwards complete gRPC frames from `reader` unchanged onto `send_stream`.
/// At end of input it sends an empty END_STREAM data frame, ignoring any
/// error from that final send.
///
/// # Errors
/// Propagates errors from `GrpcFrameReader::next_frame` and `send_grpc_data`.
pub(crate) async fn grpc_frames_to_grpc(
    mut reader: GrpcFrameReader,
    mut send_stream: h2::SendStream<Bytes>,
) -> Result<()> {
    loop {
        match reader.next_frame().await? {
            Some(frame) => send_grpc_data(&mut send_stream, frame, false).await?,
            None => break,
        }
    }
    // Best-effort END_STREAM.
    let _ = send_grpc_data(&mut send_stream, Bytes::new(), true).await;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_grpc_frame_roundtrips_through_decoder() {
        for data in [&b""[..], b"hello", &[0xAB; 256][..]] {
            let frame = encode_grpc_frame(data);
            assert_eq!(decode_grpc_frame_data(&frame), Some(data));
        }
    }

    #[test]
    fn decode_grpc_frame_data_rejects_malformed_frames() {
        assert_eq!(decode_grpc_frame_data(&[0, 0, 0, 0]), None);
        let mut truncated = encode_grpc_frame(b"hello").to_vec();
        truncated.pop();
        assert_eq!(decode_grpc_frame_data(&truncated), None);
        let mut wrong_tag = encode_grpc_frame(b"hello").to_vec();
        wrong_tag[5] = 0x0B;
        assert_eq!(decode_grpc_frame_data(&wrong_tag), None);
        let mut too_large = vec![0; 5];
        too_large[1..5].copy_from_slice(&((MAX_GRPC_FRAME_SIZE as u32) + 1).to_be_bytes());
        assert_eq!(decode_grpc_frame_data(&too_large), None);
    }

    #[test]
    fn pool_key_is_stable_and_endpoint_specific() {
        assert_eq!(
            GrpcPool::pool_key("example.com:443", "example.com"),
            GrpcPool::pool_key("example.com:443", "example.com")
        );
        assert_ne!(
            GrpcPool::pool_key("example.com:443", "example.com"),
            GrpcPool::pool_key("example.com:443", "alt.example.com")
        );
        assert_ne!(
            GrpcPool::pool_key("example.com:443", "example.com"),
            GrpcPool::pool_key("example.com:8443", "example.com")
        );
    }

    #[test]
    fn varint_roundtrips_and_matches_reported_size() {
        for v in [
            0u64,
            1,
            0x7F,
            0x80,
            300,
            16_383,
            16_384,
            u64::from(u32::MAX),
            u64::MAX,
        ] {
            let mut buf = BytesMut::new();
            write_varint(&mut buf, v);
            assert_eq!(buf.len(), varint_size(v));
            assert_eq!(read_varint(&buf), Some((v, buf.len())));
        }
    }

    #[test]
    fn read_varint_rejects_truncated_input() {
        assert_eq!(read_varint(&[]), None);
        assert_eq!(read_varint(&[0x80]), None);
        assert_eq!(read_varint(&[0xFF, 0xFF]), None);
    }

    #[test]
    fn decode_gun_payload_handles_empty_and_truncated_payloads() {
        assert_eq!(decode_gun_payload(&[]), Some(&b""[..]));
        assert_eq!(
            decode_gun_payload(&[0x0A, 0x03, b'a', b'b', b'c']),
            Some(&b"abc"[..])
        );
        assert_eq!(decode_gun_payload(&[0x0A, 0x05, b'a']), None);
        assert_eq!(decode_gun_payload(&[0x12, 0x00]), None);
    }
}