pub mod pool;
pub mod retry;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use futures::SinkExt;
use crate::codec::{Codec, DefaultCodec};
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{Metadata, RpcRequest, RpcResponse, RpcStatus};
use crate::rpc::middleware::{Interceptor, MiddlewareStack};
use crate::rpc::stream::{rpc_stream_channel, RpcStreamReceiver};
use crate::transport::connection::{
validate_metadata_key, RPC_HEADER_PREFIX, RPC_STATUS_HEADER, RPC_STATUS_MESSAGE_HEADER,
};
use crate::transport::tls::TlsConfig;
use tokio_quiche::buf_factory::BufFactory;
use tokio_quiche::http3::driver::{InboundFrame, NewClientRequest, OutboundFrame};
use self::pool::{ClientConnectionPool, ConnectionGuard, H3Response};
use self::retry::RetryPolicy;
/// RPC client that issues calls over connections taken from a shared pool.
///
/// Bundles the connection pool, the body codec, the interceptor stack and the
/// retry policy used by `call_with_metadata`. Instances are created through
/// `Client::builder()`.
pub struct Client {
// Source of connections for every call; exposed read-only through `pool()`.
pool: Arc<ClientConnectionPool>,
// Encodes request bodies and decodes response bodies and stream chunks.
codec: Arc<dyn Codec>,
// Applied to each request before it is sent and to unary responses before they are returned.
middleware: Arc<MiddlewareStack>,
// Controls how many attempts a unary call makes and the backoff between them.
retry_policy: RetryPolicy,
}
impl Client {
/// Returns a fresh [`ClientBuilder`] for configuring and constructing a `Client`.
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
/// Performs a unary RPC to `method` with `body` and no caller-supplied metadata.
///
/// Thin wrapper that forwards to `call_with_metadata` with an empty [`Metadata`];
/// its result and errors are returned unchanged.
pub async fn call(&self, method: &str, body: Vec<u8>) -> Result<RpcResponse, KnafehError> {
self.call_with_metadata(method, body, Metadata::new()).await
}
/// Performs a unary RPC to `method`, retrying according to the client's retry policy.
///
/// The body is encoded with the codec, wrapped in an [`RpcRequest`] together with
/// `metadata`, passed through the request middleware, and its metadata keys are validated.
/// Each attempt is executed by `execute_call`.
///
/// If the (post-middleware) metadata contains `x-rpc-timeout-ms`, a single deadline is
/// derived from it that bounds all attempts and all backoff sleeps together.
///
/// An attempt is retried when it yields a retryable status or a retryable error and
/// attempts remain. The response that is finally returned goes through the response
/// middleware, and its body is decoded only when the status is OK.
///
/// # Errors
///
/// * Codec, middleware and metadata-validation errors are propagated as-is.
/// * [`KnafehError::InvalidMessage`] if `x-rpc-timeout-ms` is not a valid `u64`.
/// * [`KnafehError::Timeout`] once the deadline has passed.
/// * The last error of a failed attempt once it is not retryable or attempts run out.
pub async fn call_with_metadata(
    &self,
    method: &str,
    body: Vec<u8>,
    metadata: Metadata,
) -> Result<RpcResponse, KnafehError> {
    let encoded_body = self.codec.encode(&body)?;
    let mut request = RpcRequest {
        method: method.to_string(),
        metadata,
        body: encoded_body,
    };
    self.middleware.apply_request(&mut request).await?;
    validate_metadata(&request.metadata)?;

    // Computed once so the same deadline covers every attempt and every backoff sleep.
    // `checked_add` avoids the panic `Instant + Duration` raises on overflow; a timeout too
    // large to represent is treated as "no deadline".
    let call_timeout = request_timeout(&request.metadata)?;
    let deadline =
        call_timeout.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout));

    let mut last_error = None;
    // Saturate rather than overflow: with `max_retries` at its maximum, `+ 1` would panic in
    // debug builds and wrap to zero attempts in release builds.
    let max_attempts = self.retry_policy.max_retries.saturating_add(1);
    for attempt in 0..max_attempts {
        if attempt > 0 {
            let backoff = self.retry_policy.backoff_for(attempt - 1);
            tracing::debug!(
                method = method,
                attempt = attempt,
                backoff_ms = backoff.as_millis() as u64,
                "retrying RPC call"
            );
            if let Some(deadline) = deadline {
                let now = tokio::time::Instant::now();
                if now >= deadline {
                    return Err(KnafehError::Timeout);
                }
                // Never sleep past the deadline; an unrepresentable wake-up time is clamped
                // to the deadline as well.
                let wake_at = now
                    .checked_add(backoff)
                    .map_or(deadline, |wake| std::cmp::min(wake, deadline));
                tokio::time::sleep_until(wake_at).await;
                if tokio::time::Instant::now() >= deadline {
                    return Err(KnafehError::Timeout);
                }
            } else {
                tokio::time::sleep(backoff).await;
            }
        }

        // Outer `Result` carries the deadline outcome, inner `Result` the attempt outcome.
        let attempt_result = match deadline {
            Some(deadline) => tokio::time::timeout_at(deadline, self.execute_call(&request))
                .await
                .map_err(|_| KnafehError::Timeout),
            None => Ok(self.execute_call(&request).await),
        };

        match attempt_result {
            Ok(Ok(mut response)) => {
                if RetryPolicy::is_retryable_status(response.status.code)
                    && attempt + 1 < max_attempts
                {
                    tracing::warn!(
                        method = method,
                        attempt = attempt,
                        status = ?response.status.code,
                        "retryable status"
                    );
                    last_error = Some(KnafehError::Service {
                        code: response.status.code,
                        message: response.status.message.clone(),
                    });
                    continue;
                }
                self.middleware.apply_response(&mut response).await?;
                // Only successful responses carry a codec-encoded payload to decode.
                if response.status.is_ok() {
                    response.body = self.codec.decode(&response.body)?;
                }
                return Ok(response);
            }
            Ok(Err(e)) | Err(e) => {
                if RetryPolicy::is_retryable(&e) && attempt + 1 < max_attempts {
                    tracing::warn!(
                        method = method,
                        attempt = attempt,
                        error = %e,
                        "retryable error"
                    );
                    last_error = Some(e);
                    continue;
                }
                return Err(e);
            }
        }
    }

    // Only reached if the loop body never ran; build the fallback error lazily.
    Err(last_error
        .unwrap_or_else(|| KnafehError::Transport("all retries exhausted".to_string())))
}
/// Starts a server-streaming RPC to `method` and returns the receiving end of the stream.
///
/// The request is built with empty metadata (the request middleware receives it mutably),
/// validated, and sent on a pooled connection with `x-rpc-method-kind: server_streaming`.
/// This path does not consult the retry policy or the `x-rpc-timeout-ms` metadata.
///
/// After response headers with an OK status arrive, a spawned task reads body frames,
/// splits them into chunks framed by a 4-byte big-endian length prefix, decodes each chunk
/// with the codec and forwards it to the returned receiver.
///
/// # Errors
///
/// * Codec, middleware, metadata-validation and pool-acquire errors are propagated.
/// * [`KnafehError::ConnectionClosed`] if the request cannot be queued, the body sender is
///   never delivered, writing the body fails, or the response channel is dropped.
/// * [`KnafehError::Transport`] if the response is not a streaming response.
/// * [`KnafehError::InvalidMessage`] if the status header is missing or malformed.
/// * [`KnafehError::Service`] if the response status is not OK.
pub async fn server_stream(
&self,
method: &str,
body: Vec<u8>,
) -> Result<RpcStreamReceiver, KnafehError> {
let encoded_body = self.codec.encode(&body)?;
let mut request = RpcRequest {
method: method.to_string(),
metadata: Metadata::new(),
body: encoded_body,
};
self.middleware.apply_request(&mut request).await?;
validate_metadata(&request.metadata)?;
let conn_handle = self.pool.acquire().await?;
// Presumably releases the connection if this function returns early; it is detached
// below once the reader task takes over responsibility for calling `pool.release`.
let guard = ConnectionGuard::new(Arc::clone(&self.pool), conn_handle.id);
let request_id = conn_handle
.inner
.next_request_id
.fetch_add(1, Ordering::Relaxed);
// Register the response channel before the request is queued on the connection.
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
conn_handle
.inner
.register_pending_stream(request_id, response_tx)
.await;
// While armed, dropping this guard schedules removal of the pending-stream entry.
let mut pending_guard =
PendingRequestGuard::streaming(Arc::clone(&conn_handle.inner), request_id);
let h3_headers = self.build_h3_request_headers(&request, "server_streaming");
// A body-writer channel is only set up when there is a body to send.
let (body_writer, body_rx_opt) = if !request.body.is_empty() {
let (tx, rx) = tokio::sync::oneshot::channel();
(Some(tx), Some(rx))
} else {
(None, None)
};
if conn_handle
.inner
.request_sender
.send(NewClientRequest {
request_id,
headers: h3_headers,
body_writer,
})
.is_err()
{
return Err(KnafehError::ConnectionClosed);
}
if let Some(body_rx) = body_rx_opt {
// Wait for the frame sender to be handed back through the body-writer channel.
let frame_sender_result = body_rx.await;
if frame_sender_result.is_err() {
return Err(KnafehError::ConnectionClosed);
}
let mut frame_sender = frame_sender_result.unwrap();
let buf = BufFactory::buf_from_slice(&request.body);
// The whole body goes out as one frame; the `true` flag presumably marks end of stream.
if frame_sender
.send(OutboundFrame::body(buf, true))
.await
.is_err()
{
return Err(KnafehError::ConnectionClosed);
}
}
// First `?` handles a dropped channel, second `?` an error delivered through it.
let h3_response = response_rx
.await
.map_err(|_| KnafehError::ConnectionClosed)??;
pending_guard.disarm();
let (headers, mut recv, read_fin) = match h3_response {
H3Response::Streaming {
headers,
recv,
read_fin,
} => (headers, recv, read_fin),
_ => return Err(KnafehError::Transport("expected streaming response".into())),
};
let status_code = parse_status_from_headers(&headers)?;
if status_code != RpcStatusCode::Ok {
let msg = parse_status_message_from_headers(&headers);
return Err(KnafehError::Service {
code: status_code,
message: msg,
});
}
let (stream_tx, stream_rx) = rpc_stream_channel(32);
let codec = Arc::clone(&self.codec);
// From here on the spawned task calls `pool.release(conn_id)` on every exit path.
let (pool, conn_id) = guard.detach();
tokio::spawn(async move {
use bytes::BytesMut;
use crate::transport::quic_wire::MAX_MESSAGE_SIZE as MAX_CHUNK_SIZE;
// Accumulates body bytes across frames; a chunk may span several frames.
let mut accum = BytesMut::new();
// `read_fin` set means no body frames are read at all.
// NOTE(review): if `recv` yields `None` before a frame with `fin` is seen, the loop ends
// and the stream is closed without sending an error — confirm this is intended.
if !read_fin {
while let Some(frame) = recv.recv().await {
if let InboundFrame::Body(data, fin) = frame {
accum.extend_from_slice(&data);
// Drain every complete `[u32 big-endian length][payload]` chunk currently buffered.
while accum.len() >= 4 {
let len = u32::from_be_bytes([accum[0], accum[1], accum[2], accum[3]])
as usize;
// Reject oversized chunks before waiting for (and buffering) their payload.
if len > MAX_CHUNK_SIZE {
let _ = stream_tx
.send_error(KnafehError::InvalidMessage(format!(
"stream chunk size {len} exceeds maximum {MAX_CHUNK_SIZE}"
)))
.await;
pool.release(conn_id);
return;
}
// Payload not fully received yet: wait for more frames.
if accum.len() < 4 + len {
break; }
// Discard the 4-byte length prefix, then take the payload.
let _ = accum.split_to(4); let chunk = accum.split_to(len).to_vec();
match codec.decode(&chunk) {
Ok(decoded) => {
// A failed send ends the task; nothing further is forwarded.
if stream_tx.send(decoded).await.is_err() {
pool.release(conn_id);
return;
}
}
Err(e) => {
let _ = stream_tx.send_error(e).await;
pool.release(conn_id);
return;
}
}
}
if fin {
// Leftover bytes at end of stream mean a truncated chunk.
if !accum.is_empty() {
let _ = stream_tx
.send_error(KnafehError::InvalidMessage(format!(
"stream ended with {} incomplete bytes",
accum.len()
)))
.await;
}
break;
}
}
}
}
// Drop the sender before the connection is released.
drop(stream_tx);
pool.release(conn_id);
});
Ok(stream_rx)
}
/// Assembles the HTTP/3 header list for `request`.
///
/// The list starts with the fixed entries (`:method POST`, `:scheme https`, the pool
/// hostname as `:authority`, `/<method>` as `:path`, the codec's `content-type` and
/// `x-rpc-method-kind` set to `method_kind`), followed by one header per metadata entry.
/// Metadata keys that do not already start with `x-rpc-` are sent with that prefix added.
fn build_h3_request_headers(
    &self,
    request: &RpcRequest,
    method_kind: &str,
) -> Vec<quiche::h3::Header> {
    use quiche::h3::Header;

    let authority = self.pool.hostname();
    let path = format!("/{}", request.method);
    let content_type = self.codec.content_type();

    let mut h3_headers = vec![
        Header::new(b":method", b"POST"),
        Header::new(b":scheme", b"https"),
        Header::new(b":authority", authority.as_bytes()),
        Header::new(b":path", path.as_bytes()),
        Header::new(b"content-type", content_type.as_bytes()),
        Header::new(b"x-rpc-method-kind", method_kind.as_bytes()),
    ];

    for (name, value) in &request.metadata {
        // Only allocate a new name when the prefix actually has to be added.
        let prefixed;
        let wire_name: &[u8] = if name.starts_with("x-rpc-") {
            name.as_bytes()
        } else {
            prefixed = format!("x-rpc-{name}");
            prefixed.as_bytes()
        };
        h3_headers.push(Header::new(wire_name, value.as_bytes()));
    }

    h3_headers
}
/// Runs one unary attempt for `request` on a connection taken from the pool.
///
/// Registers a pending-response slot under a fresh request id, queues the request
/// (sending the body as a single frame when it is non-empty), waits for the response and
/// converts it into an [`RpcResponse`].
///
/// # Errors
///
/// * Pool-acquire errors and errors delivered through the response channel are propagated.
/// * [`KnafehError::ConnectionClosed`] if the request cannot be queued, the body sender is
///   never delivered, writing the body fails, or the response channel is dropped.
/// * [`KnafehError::Transport`] if the response is not a complete (unary) response.
async fn execute_call(&self, request: &RpcRequest) -> Result<RpcResponse, KnafehError> {
    let conn_handle = self.pool.acquire().await?;
    // Kept alive for the whole attempt; presumably hands the connection back on drop.
    let _guard = ConnectionGuard::new(Arc::clone(&self.pool), conn_handle.id);

    let conn = &conn_handle.inner;
    let request_id = conn.next_request_id.fetch_add(1, Ordering::Relaxed);

    // The response slot is registered before the request is queued on the connection.
    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    conn.register_pending(request_id, response_tx).await;
    // While armed, dropping this guard schedules removal of the pending entry.
    let mut pending_guard = PendingRequestGuard::unary(Arc::clone(conn), request_id);

    let headers = self.build_h3_request_headers(request, "unary");
    // A body-writer channel is only needed when there is a body to send.
    let (body_writer, body_rx_opt) = if request.body.is_empty() {
        (None, None)
    } else {
        let (tx, rx) = tokio::sync::oneshot::channel();
        (Some(tx), Some(rx))
    };

    conn.request_sender
        .send(NewClientRequest {
            request_id,
            headers,
            body_writer,
        })
        .map_err(|_| KnafehError::ConnectionClosed)?;

    if let Some(body_rx) = body_rx_opt {
        let mut frame_sender = body_rx
            .await
            .map_err(|_| KnafehError::ConnectionClosed)?;
        let payload = BufFactory::buf_from_slice(&request.body);
        frame_sender
            .send(OutboundFrame::body(payload, true))
            .await
            .map_err(|_| KnafehError::ConnectionClosed)?;
    }

    // First `?` covers a dropped channel, second `?` an error delivered through it.
    let h3_response = response_rx
        .await
        .map_err(|_| KnafehError::ConnectionClosed)??;
    pending_guard.disarm();

    if let H3Response::Complete { headers, body } = h3_response {
        parse_h3_response(headers, body)
    } else {
        Err(KnafehError::Transport("expected unary response".into()))
    }
}
/// Returns a shared reference to the connection pool backing this client.
pub fn pool(&self) -> &ClientConnectionPool {
&self.pool
}
}
/// Which pending-request registry a `PendingRequestGuard` cleans up on drop.
enum PendingRequestKind {
// Cleanup calls `remove_pending`.
Unary,
// Cleanup calls `remove_pending_stream`.
Streaming,
}
/// Drop guard for a request registered as pending on a connection.
///
/// If dropped while still armed, it schedules removal of the pending entry for
/// `request_id`. Call `disarm()` to turn the guard into a no-op.
struct PendingRequestGuard {
// Connection state on which the pending entry was registered.
inner: Arc<pool::ConnectionInner>,
// Identifier under which the pending entry was registered.
request_id: u64,
// Selects which removal method the cleanup uses.
kind: PendingRequestKind,
// `true` until `disarm()` is called; cleanup only runs while this is set.
armed: bool,
}
impl PendingRequestGuard {
fn unary(inner: Arc<pool::ConnectionInner>, request_id: u64) -> Self {
Self {
inner,
request_id,
kind: PendingRequestKind::Unary,
armed: true,
}
}
fn streaming(inner: Arc<pool::ConnectionInner>, request_id: u64) -> Self {
Self {
inner,
request_id,
kind: PendingRequestKind::Streaming,
armed: true,
}
}
fn disarm(&mut self) {
self.armed = false;
}
}
impl Drop for PendingRequestGuard {
    /// Schedules removal of the pending entry when the guard is dropped while armed.
    ///
    /// The removal methods are async and `drop` cannot await, so the work is spawned on
    /// the current Tokio runtime. Without a runtime the cleanup is skipped and only logged.
    fn drop(&mut self) {
        if !self.armed {
            return;
        }

        let request_id = self.request_id;
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                let inner = Arc::clone(&self.inner);
                let is_streaming = matches!(self.kind, PendingRequestKind::Streaming);
                runtime.spawn(async move {
                    if is_streaming {
                        inner.remove_pending_stream(request_id).await;
                    } else {
                        inner.remove_pending(request_id).await;
                    }
                });
            }
            Err(_) => {
                tracing::debug!(
                    request_id,
                    "pending request cleanup skipped outside Tokio runtime"
                );
            }
        }
    }
}
/// Reads the per-call timeout from the `x-rpc-timeout-ms` metadata entry.
///
/// Returns `Ok(None)` when the entry is absent and `Ok(Some(duration))` when it holds a
/// millisecond count that parses as `u64`.
///
/// # Errors
///
/// [`KnafehError::InvalidMessage`] if the entry is present but not a valid `u64`.
fn request_timeout(metadata: &Metadata) -> Result<Option<Duration>, KnafehError> {
    metadata
        .get("x-rpc-timeout-ms")
        .map(|raw| {
            raw.parse::<u64>()
                .map(Duration::from_millis)
                .map_err(|e| {
                    KnafehError::InvalidMessage(format!(
                        "invalid x-rpc-timeout-ms metadata: {e}"
                    ))
                })
        })
        .transpose()
}
/// Builds an [`RpcResponse`] from the headers and body of a complete HTTP/3 response.
///
/// The status code and message come from the RPC status headers, the metadata from the
/// remaining RPC-prefixed headers; `body` is passed through untouched.
///
/// # Errors
///
/// Propagates the error from `parse_status_from_headers` when the status header is
/// missing or malformed.
fn parse_h3_response(
    headers: Vec<quiche::h3::Header>,
    body: Vec<u8>,
) -> Result<RpcResponse, KnafehError> {
    // The status code is parsed first so a malformed response fails before anything else.
    let status = RpcStatus {
        code: parse_status_from_headers(&headers)?,
        message: parse_status_message_from_headers(&headers),
    };
    let metadata = parse_metadata_from_headers(&headers);

    Ok(RpcResponse {
        status,
        metadata,
        body,
    })
}
/// Extracts the RPC status code from the first header named `RPC_STATUS_HEADER`.
///
/// The header value must be UTF-8 text holding a decimal `u8`, which is then mapped
/// through `RpcStatusCode::from_u8`.
///
/// # Errors
///
/// [`KnafehError::InvalidMessage`] if the header is absent, is not valid UTF-8, or does
/// not parse as a `u8`.
fn parse_status_from_headers(headers: &[quiche::h3::Header]) -> Result<RpcStatusCode, KnafehError> {
    use quiche::h3::NameValue;

    let wanted = RPC_STATUS_HEADER.as_bytes();
    let Some(status_header) = headers.iter().find(|h| h.name() == wanted) else {
        return Err(KnafehError::InvalidMessage(format!(
            "missing required {RPC_STATUS_HEADER} header"
        )));
    };

    let text = std::str::from_utf8(status_header.value()).map_err(|e| {
        KnafehError::InvalidMessage(format!("invalid RPC status header UTF-8: {e}"))
    })?;
    let code = text.parse::<u8>().map_err(|e| {
        KnafehError::InvalidMessage(format!("invalid RPC status header value: {e}"))
    })?;

    Ok(RpcStatusCode::from_u8(code))
}
/// Returns the value of the first header named `RPC_STATUS_MESSAGE_HEADER`.
///
/// Invalid UTF-8 in the value is replaced lossily. An empty string is returned when the
/// header is not present.
fn parse_status_message_from_headers(headers: &[quiche::h3::Header]) -> String {
    use quiche::h3::NameValue;

    let wanted = RPC_STATUS_MESSAGE_HEADER.as_bytes();
    headers
        .iter()
        .find(|h| h.name() == wanted)
        .map(|h| String::from_utf8_lossy(h.value()).into_owned())
        .unwrap_or_default()
}
/// Collects response metadata from headers whose name starts with `RPC_HEADER_PREFIX`.
///
/// The status and status-message headers are skipped. For every other matching header the
/// prefix is stripped from the name to form the key; key and value are converted to
/// strings lossily. Entries are inserted in header order.
fn parse_metadata_from_headers(headers: &[quiche::h3::Header]) -> Metadata {
    use quiche::h3::NameValue;

    let prefix = RPC_HEADER_PREFIX.as_bytes();
    let status_name = RPC_STATUS_HEADER.as_bytes();
    let message_name = RPC_STATUS_MESSAGE_HEADER.as_bytes();

    let mut metadata = Metadata::new();
    for header in headers {
        let name = header.name();
        // Status headers are reported through `RpcStatus`, not as metadata.
        let is_status_header = name == status_name || name == message_name;
        if is_status_header {
            continue;
        }
        let Some(suffix) = name.strip_prefix(prefix) else {
            continue;
        };
        metadata.insert(
            String::from_utf8_lossy(suffix).into_owned(),
            String::from_utf8_lossy(header.value()).into_owned(),
        );
    }
    metadata
}
/// Checks every key of `metadata` with `validate_metadata_key`.
///
/// # Errors
///
/// Stops at and returns the error for the first key that fails validation.
fn validate_metadata(metadata: &Metadata) -> Result<(), KnafehError> {
    metadata.keys().try_for_each(|key| {
        validate_metadata_key(key)?;
        Ok(())
    })
}
/// Builder for [`Client`]; obtain one with `Client::builder()` and finish with `build()`.
pub struct ClientBuilder {
// Endpoint handed to the connection pool; `build()` fails when it is unset.
endpoint: Option<String>,
// TLS settings; `build()` falls back to `TlsConfig::client_insecure` when unset.
tls_config: Option<TlsConfig>,
// Body codec; `build()` falls back to `DefaultCodec` when unset.
codec: Option<Arc<dyn Codec>>,
// Interceptors registered so far.
middleware: MiddlewareStack,
// Size passed to the connection pool (initially 4).
pool_size: usize,
// Retry policy for unary calls (initially `RetryPolicy::none()`).
retry_policy: RetryPolicy,
}
impl ClientBuilder {
    /// Creates a builder with nothing configured: no endpoint, TLS config or codec, an
    /// empty middleware stack, a pool size of 4 and `RetryPolicy::none()`.
    fn new() -> Self {
        Self {
            endpoint: None,
            tls_config: None,
            codec: None,
            middleware: MiddlewareStack::new(),
            pool_size: 4,
            retry_policy: RetryPolicy::none(),
        }
    }

    /// Sets the endpoint the connection pool is created for. Required by `build()`.
    pub fn endpoint(self, endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: Some(endpoint.into()),
            ..self
        }
    }

    /// Sets the TLS configuration used by the connection pool.
    pub fn tls(self, config: TlsConfig) -> Self {
        Self {
            tls_config: Some(config),
            ..self
        }
    }

    /// Sets the codec used for request and response bodies.
    pub fn codec(self, codec: impl Codec) -> Self {
        Self {
            codec: Some(Arc::new(codec)),
            ..self
        }
    }

    /// Appends `interceptor` to the middleware stack.
    pub fn add_interceptor(mut self, interceptor: impl Interceptor) -> Self {
        self.middleware.add(Arc::new(interceptor));
        self
    }

    /// Sets the size passed to the connection pool.
    pub fn pool_size(self, size: usize) -> Self {
        Self {
            pool_size: size,
            ..self
        }
    }

    /// Sets the retry policy applied to unary calls.
    pub fn retry(self, policy: RetryPolicy) -> Self {
        Self {
            retry_policy: policy,
            ..self
        }
    }

    /// Consumes the builder and constructs the [`Client`].
    ///
    /// # Errors
    ///
    /// [`KnafehError::Transport`] if no endpoint was configured.
    pub async fn build(self) -> Result<Client, KnafehError> {
        let Self {
            endpoint,
            tls_config,
            codec,
            middleware,
            pool_size,
            retry_policy,
        } = self;

        let Some(endpoint) = endpoint else {
            return Err(KnafehError::Transport("endpoint is required".to_string()));
        };
        // NOTE(review): with no TLS config supplied this falls back to
        // `TlsConfig::client_insecure`, which by its name presumably relaxes certificate
        // verification — confirm that this default is intended.
        let tls_config = tls_config.unwrap_or_else(TlsConfig::client_insecure);
        let codec = codec.unwrap_or_else(|| Arc::new(DefaultCodec::new()));
        let pool = ClientConnectionPool::new(endpoint, pool_size, tls_config);

        Ok(Client {
            pool: Arc::new(pool),
            codec,
            middleware: Arc::new(middleware),
            retry_policy,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the reserved `x-rpc-timeout-ms` key configures a timeout; a look-alike key
    /// without the prefix is ignored.
    #[test]
    fn request_timeout_only_uses_reserved_timeout_header() {
        let mut metadata = Metadata::new();

        // An unprefixed key must not be picked up as a timeout.
        metadata.insert(String::from("timeout-ms"), String::from("5"));
        let without_reserved_key = request_timeout(&metadata).unwrap();
        assert_eq!(without_reserved_key, None);

        // The reserved key is parsed as milliseconds.
        metadata.insert(String::from("x-rpc-timeout-ms"), String::from("7"));
        let with_reserved_key = request_timeout(&metadata).unwrap();
        assert_eq!(with_reserved_key, Some(Duration::from_millis(7)));
    }
}