use parking_lot::{Mutex, RwLock};
use std::collections::{BTreeMap, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::bytes::Bytes;
use crate::cx::{Cx, cap};
use super::client::CompressionEncoding;
use super::codec::{Codec, FramedCodec};
use super::reflection::ReflectionService;
use super::service::{NamedService, ServiceHandler};
use super::status::{GrpcError, Status, TransportErrorKind};
use super::streaming::{Metadata, Request, Response};
/// Single source of "now" for this module: the monotonic `std::time::Instant` clock.
///
/// Stream bookkeeping and deadline computation both read time through this helper.
fn wall_clock_instant_now() -> Instant {
    let now = Instant::now();
    now
}
/// Bookkeeping for one registered stream on a connection.
#[derive(Clone, Debug)]
struct StreamState {
    /// Last time activity was recorded; compared against the idle timeout during cleanup.
    last_activity: Instant,
    /// Registration time; doubles as an ownership token so a stale guard cannot
    /// remove a newer registration of the same stream id.
    registered_at: Instant,
}
/// Per-connection table of active streams, keyed by stream id.
///
/// The table is guarded by a `parking_lot::Mutex`, so every method takes `&self`.
#[derive(Debug)]
pub struct ConnectionState {
    active_streams: Mutex<HashMap<u32, StreamState>>,
}
impl ConnectionState {
    /// Creates an empty stream table.
    pub fn new() -> Self {
        let active_streams = Mutex::new(HashMap::new());
        Self { active_streams }
    }

    /// Registers `stream_id` as active.
    ///
    /// Registration is refused when the table already holds `max_concurrent` or more
    /// streams. On success, returns the registration instant. The caller passes it back
    /// to [`ConnectionState::remove_stream_if_owned`] as an ownership token.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the concurrency limit is reached.
    pub fn add_stream(&self, stream_id: u32, max_concurrent: u32) -> Result<Instant, String> {
        let mut streams = self.active_streams.lock();
        let in_flight = streams.len();
        if in_flight >= max_concurrent as usize {
            return Err(format!(
                "connection exceeds max_concurrent_streams: {} >= {}",
                in_flight, max_concurrent
            ));
        }
        let stamp = wall_clock_instant_now();
        // Inserting an id that is already present overwrites the old entry and its token.
        streams.insert(
            stream_id,
            StreamState {
                last_activity: stamp,
                registered_at: stamp,
            },
        );
        Ok(stamp)
    }

    /// Refreshes the activity timestamp of `stream_id`; unknown ids are ignored.
    pub fn update_stream_activity(&self, stream_id: u32) {
        if let Some(entry) = self.active_streams.lock().get_mut(&stream_id) {
            entry.last_activity = wall_clock_instant_now();
        }
    }

    /// Unconditionally forgets `stream_id`, if present.
    pub fn remove_stream(&self, stream_id: u32) {
        self.active_streams.lock().remove(&stream_id);
    }

    /// Evicts every stream idle for strictly longer than `idle_timeout`.
    ///
    /// Returns the evicted ids, in the table's iteration order.
    pub fn cleanup_idle_streams(&self, idle_timeout: Duration) -> Vec<u32> {
        let now = wall_clock_instant_now();
        let mut evicted = Vec::new();
        self.active_streams.lock().retain(|&id, entry| {
            let keep = now.duration_since(entry.last_activity) <= idle_timeout;
            if !keep {
                evicted.push(id);
            }
            keep
        });
        evicted
    }

    /// Number of streams currently registered on this connection.
    pub fn active_stream_count(&self) -> usize {
        self.active_streams.lock().len()
    }

    /// Removes `stream_id` only if it is still the registration stamped `registered_at`.
    pub fn remove_stream_if_owned(&self, stream_id: u32, registered_at: Instant) {
        let mut streams = self.active_streams.lock();
        let still_owned = streams
            .get(&stream_id)
            .is_some_and(|entry| entry.registered_at == registered_at);
        if still_owned {
            streams.remove(&stream_id);
        }
    }
}
/// Server-wide map from connection id to that connection's [`ConnectionState`].
///
/// Guarded by a `parking_lot::RwLock`. Adding and removing connections takes the
/// write lock; all per-stream operations take the read lock.
#[derive(Debug)]
pub struct ConnectionRegistry {
    connections: RwLock<HashMap<String, ConnectionState>>,
}
impl ConnectionRegistry {
    /// Creates a registry with no connections.
    pub fn new() -> Self {
        let connections = RwLock::new(HashMap::new());
        Self { connections }
    }

    /// Registers `connection_id` with an empty stream table.
    ///
    /// An existing entry with the same id is replaced, and its streams are forgotten.
    pub fn add_connection(&self, connection_id: String) {
        self.connections
            .write()
            .insert(connection_id, ConnectionState::new());
    }

    /// Drops the connection and all of its stream bookkeeping.
    pub fn remove_connection(&self, connection_id: &str) {
        self.connections.write().remove(connection_id);
    }

    /// Admits `stream_id` on `connection_id`, subject to `max_concurrent`.
    ///
    /// When `idle_timeout` is set, idle streams are evicted first (and logged to
    /// stderr), so that stale entries do not count against the limit.
    ///
    /// # Errors
    ///
    /// Fails if the connection is not registered or the concurrency limit is reached.
    pub fn enforce_stream_limits(
        &self,
        connection_id: &str,
        stream_id: u32,
        max_concurrent: u32,
        idle_timeout: Option<Duration>,
    ) -> Result<Instant, String> {
        let connections = self.connections.read();
        let Some(connection) = connections.get(connection_id) else {
            return Err(format!("connection not registered: {}", connection_id));
        };
        let evicted = idle_timeout
            .map(|timeout| connection.cleanup_idle_streams(timeout))
            .unwrap_or_default();
        if !evicted.is_empty() {
            eprintln!(
                "Cleaned up {} idle streams on connection {}: {:?}",
                evicted.len(),
                connection_id,
                evicted
            );
        }
        connection.add_stream(stream_id, max_concurrent)
    }

    /// Refreshes activity for a stream; unknown connections/streams are ignored.
    pub fn update_stream_activity(&self, connection_id: &str, stream_id: u32) {
        if let Some(connection) = self.connections.read().get(connection_id) {
            connection.update_stream_activity(stream_id);
        }
    }

    /// Unconditionally removes a stream; unknown connections are ignored.
    pub fn remove_stream(&self, connection_id: &str, stream_id: u32) {
        if let Some(connection) = self.connections.read().get(connection_id) {
            connection.remove_stream(stream_id);
        }
    }

    /// Removes a stream only if `registered_at` still identifies its registration.
    pub fn remove_stream_if_owned(
        &self,
        connection_id: &str,
        stream_id: u32,
        registered_at: Instant,
    ) {
        if let Some(connection) = self.connections.read().get(connection_id) {
            connection.remove_stream_if_owned(stream_id, registered_at);
        }
    }

    /// Returns `(connection_count, total_active_streams)`.
    pub fn get_stats(&self) -> (usize, usize) {
        let connections = self.connections.read();
        let total_streams = connections
            .values()
            .map(ConnectionState::active_stream_count)
            .sum::<usize>();
        (connections.len(), total_streams)
    }
}
/// Drop guard that unregisters a stream from the [`ConnectionRegistry`] on scope exit.
///
/// `registered_at` is the token returned at registration time. The removal is skipped
/// if the same stream id has since been re-registered.
struct StreamRegistrationGuard {
    registry: Arc<ConnectionRegistry>,
    connection_id: String,
    stream_id: u32,
    registered_at: Instant,
}
impl Drop for StreamRegistrationGuard {
    /// Releases the stream slot, but only if this guard still owns the registration.
    fn drop(&mut self) {
        let Self {
            registry,
            connection_id,
            stream_id,
            registered_at,
        } = self;
        registry.remove_stream_if_owned(connection_id.as_str(), *stream_id, *registered_at);
    }
}
/// Tunables for a [`Server`]; see `ServerConfig::default` for the defaults.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Largest single inbound message, in bytes. Passed to `FramedCodec` as the
    /// receive limit by `Server::framed_codec`.
    pub max_recv_message_size: usize,
    /// Largest single outbound message, in bytes. Passed to `FramedCodec` as the
    /// send limit by `Server::framed_codec`.
    pub max_send_message_size: usize,
    /// Optional cap on the total decoded bytes of one call, tracked by
    /// `RequestBodyMeter`; `None` disables the cap.
    pub max_request_body_bytes: Option<usize>,
    /// Initial connection-level flow-control window, in bytes (not read in this module).
    pub initial_connection_window_size: u32,
    /// Initial per-stream flow-control window, in bytes (not read in this module).
    pub initial_stream_window_size: u32,
    /// Per-connection limit on concurrently registered streams.
    pub max_concurrent_streams: u32,
    /// Keepalive interval in milliseconds (not read in this module).
    pub keepalive_interval_ms: Option<u64>,
    /// Keepalive timeout in milliseconds (not read in this module).
    pub keepalive_timeout_ms: Option<u64>,
    /// Timeout applied when a request carries no `grpc-timeout` header; `None` means
    /// such requests get no deadline.
    pub default_timeout: Option<Duration>,
    /// Upper bound applied to a client-supplied `grpc-timeout`.
    pub max_request_deadline: Option<Duration>,
    /// Encoding for outgoing messages, if any (not read in this module).
    pub send_compression: Option<CompressionEncoding>,
    /// Encodings accepted from clients.
    pub accept_compression: Vec<CompressionEncoding>,
    /// Maximum total bytes of request metadata (keys plus values); `0` disables the
    /// size check.
    pub max_metadata_size: usize,
    /// Streams idle longer than this are evicted when a new stream is admitted;
    /// `None` disables eviction.
    pub stream_idle_timeout: Option<Duration>,
}
/// Default cap on total request metadata bytes (8 KiB).
pub const DEFAULT_MAX_METADATA_SIZE: usize = 8_192;
/// Sums the byte lengths of every key and value in `metadata`, saturating at `usize::MAX`.
///
/// ASCII values count their string length; binary values count their raw byte length.
#[must_use]
pub fn metadata_byte_size(metadata: &super::streaming::Metadata) -> usize {
    metadata.iter().fold(0usize, |running, (key, value)| {
        let payload = match value {
            super::streaming::MetadataValue::Ascii(text) => text.len(),
            super::streaming::MetadataValue::Binary(blob) => blob.len(),
        };
        running.saturating_add(key.len()).saturating_add(payload)
    })
}
/// True when `key` begins with `grpc-`, compared ASCII case-insensitively.
fn metadata_key_uses_grpc_prefix(key: &str) -> bool {
    const RESERVED: &[u8] = b"grpc-";
    match key.as_bytes().get(..RESERVED.len()) {
        Some(head) => head.eq_ignore_ascii_case(RESERVED),
        None => false,
    }
}
/// The `grpc-*` request headers a client is permitted to send (case-insensitive).
fn grpc_request_header_is_allowed(key: &str) -> bool {
    const CLIENT_GRPC_HEADERS: [&str; 4] = [
        "grpc-timeout",
        "grpc-encoding",
        "grpc-accept-encoding",
        "grpc-message-type",
    ];
    CLIENT_GRPC_HEADERS
        .iter()
        .any(|allowed| key.eq_ignore_ascii_case(allowed))
}
/// True when `value` is exactly `prefix`, or `prefix` followed by a `+suffix` or
/// `;parameter`. Comparison is case-sensitive.
fn matches_media_type_prefix(value: &str, prefix: &str) -> bool {
    let Some(rest) = value.strip_prefix(prefix) else {
        return false;
    };
    matches!(rest.as_bytes().first(), None | Some(b'+' | b';'))
}
/// Accepts `application/grpc`, optionally followed by `+subtype` or `;params`, after
/// trimming surrounding whitespace.
fn grpc_content_type_is_allowed(value: &str) -> bool {
    const GRPC_MEDIA_TYPE: &str = "application/grpc";
    match value.trim().strip_prefix(GRPC_MEDIA_TYPE) {
        Some(suffix) => matches!(suffix.bytes().next(), None | Some(b'+' | b';')),
        None => false,
    }
}
/// The only `te` value accepted is `trailers` (case-insensitive, surrounding whitespace ignored).
fn grpc_te_header_is_allowed(value: &str) -> bool {
    let trimmed = value.trim();
    "trailers".eq_ignore_ascii_case(trimmed)
}
/// True when `name` is a non-empty RFC 7230 `token`: ASCII letters, digits, or one of
/// the `tchar` symbols ``!#$%&'*+-.^_`|~``.
fn is_valid_header_name_rfc7230(name: &str) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    let is_tchar = |byte: u8| byte.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&byte);
    !name.is_empty() && name.bytes().all(is_tchar)
}
/// True when every byte of `value` is visible ASCII (`0x21..=0x7E`), a space, or a
/// horizontal tab.
///
/// CR, LF, other control bytes, DEL, and non-ASCII bytes are all rejected.
fn is_valid_header_value_rfc7230(value: &str) -> bool {
    value
        .bytes()
        .all(|byte| byte == b' ' || byte == b'\t' || (0x21..=0x7E).contains(&byte))
}
/// Longest accepted metadata key, in bytes.
const MAX_HEADER_NAME_LEN: usize = 256;
/// Longest accepted metadata value (ASCII or binary), in bytes.
const MAX_HEADER_VALUE_LEN: usize = 8 * 1024;
/// Validates client-supplied request metadata before it reaches interceptors or handlers.
///
/// Entries are checked one at a time, in `metadata.iter()` order, and the first violation
/// is returned. For each entry the checks run in this order:
///
/// 1. the key must be a non-empty RFC 7230 `token` (`is_valid_header_name_rfc7230`);
/// 2. the key must be at most `MAX_HEADER_NAME_LEN` bytes;
/// 3. an ASCII value may contain only visible ASCII, space, or tab, and must be at most
///    `MAX_HEADER_VALUE_LEN` bytes; a binary value is only length-checked;
/// 4. a key with a case-insensitive `grpc-` prefix is rejected unless
///    `grpc_request_header_is_allowed` accepts it;
/// 5. an ASCII value must be unchanged by `sanitize_metadata_ascii_value`;
/// 6. a `content-type` entry must be an ASCII `application/grpc` media type, and a `te`
///    entry must be ASCII `trailers` (both key comparisons are case-insensitive).
///
/// # Errors
///
/// Returns `Status::invalid_argument` describing the first offending entry.
// NOTE(review): the first check interpolates a key that has just failed validation into
// the status message, so raw client bytes (e.g. CR/LF) can end up in it; confirm the
// message is escaped wherever it is emitted.
fn validate_inbound_metadata(metadata: &super::streaming::Metadata) -> Result<(), Status> {
for (key, value) in metadata.iter() {
// Key syntax first, then key length.
if !is_valid_header_name_rfc7230(key) {
return Err(Status::invalid_argument(format!(
"metadata key '{key}' contains invalid characters (RFC 7230 violation)"
)));
}
if key.len() > MAX_HEADER_NAME_LEN {
return Err(Status::invalid_argument(format!(
"metadata key '{key}' exceeds maximum length ({} > {})",
key.len(),
MAX_HEADER_NAME_LEN
)));
}
// Value syntax/length; binary values have no character restrictions here.
match value {
super::streaming::MetadataValue::Ascii(text) => {
if !is_valid_header_value_rfc7230(text) {
return Err(Status::invalid_argument(format!(
"metadata value for '{key}' contains disallowed CRLF or invalid characters (RFC 7230 violation)"
)));
}
if text.len() > MAX_HEADER_VALUE_LEN {
return Err(Status::invalid_argument(format!(
"metadata value for '{key}' exceeds maximum length ({} > {})",
text.len(),
MAX_HEADER_VALUE_LEN
)));
}
}
super::streaming::MetadataValue::Binary(bytes) => {
if bytes.len() > MAX_HEADER_VALUE_LEN {
return Err(Status::invalid_argument(format!(
"binary metadata value for '{key}' exceeds maximum length ({} > {})",
bytes.len(),
MAX_HEADER_VALUE_LEN
)));
}
}
}
// Clients may only send the allow-listed grpc-* request headers.
if metadata_key_uses_grpc_prefix(key) && !grpc_request_header_is_allowed(key) {
return Err(Status::invalid_argument(format!(
"client metadata key uses reserved grpc-* prefix: {key}"
)));
}
// Second guard: reject any ASCII value that the streaming module's sanitizer would rewrite.
if let super::streaming::MetadataValue::Ascii(text) = value {
if super::streaming::sanitize_metadata_ascii_value(text).as_ref() != text {
return Err(Status::invalid_argument(format!(
"metadata value for {key} contains disallowed control or non-ASCII bytes"
)));
}
}
// Transport-level headers that gRPC pins to specific values.
if key.eq_ignore_ascii_case("content-type") {
match value {
super::streaming::MetadataValue::Ascii(text)
if !grpc_content_type_is_allowed(text) =>
{
return Err(Status::invalid_argument(format!(
"content-type must be application/grpc(+proto|+json), got {text}"
)));
}
super::streaming::MetadataValue::Binary(_) => {
return Err(Status::invalid_argument(
"content-type must be an ASCII gRPC media type",
));
}
super::streaming::MetadataValue::Ascii(_) => {}
}
} else if key.eq_ignore_ascii_case("te") {
match value {
super::streaming::MetadataValue::Ascii(text)
if !grpc_te_header_is_allowed(text) =>
{
return Err(Status::invalid_argument(format!(
"te must be trailers for gRPC over HTTP/2, got {text}"
)));
}
super::streaming::MetadataValue::Binary(_) => {
return Err(Status::invalid_argument(
"te must be an ASCII trailers header",
));
}
super::streaming::MetadataValue::Ascii(_) => {}
}
}
}
Ok(())
}
/// Validates inbound metadata and then enforces the aggregate size `limit` (0 = unlimited).
///
/// # Errors
///
/// - Any error from `validate_inbound_metadata` (always `invalid_argument`).
/// - `resource_exhausted` when the summed key and value bytes exceed a non-zero `limit`.
pub fn enforce_metadata_size_limit(
    metadata: &super::streaming::Metadata,
    limit: usize,
) -> Result<(), Status> {
    validate_inbound_metadata(metadata)?;
    let actual = match limit {
        // Zero disables the size check; syntax validation above still applies.
        0 => return Ok(()),
        _ => metadata_byte_size(metadata),
    };
    if actual <= limit {
        return Ok(());
    }
    Err(Status::resource_exhausted(format!(
        "metadata exceeds max_metadata_size: {actual} bytes > {limit} bytes \
         (gRPC equivalent of HTTP 431 Request Header Fields Too Large; \
         see ServerConfig::max_metadata_size)"
    )))
}
/// Running total of decoded request bytes for one call, checked against an optional cap.
#[derive(Clone, Copy, Debug)]
pub struct RequestBodyMeter {
    /// Maximum aggregate bytes; `None` means unlimited.
    cap: Option<usize>,
    /// Bytes recorded so far (saturating).
    accumulated: usize,
}
impl RequestBodyMeter {
    /// Creates a meter with nothing recorded; a `cap` of `None` means unlimited.
    #[must_use]
    pub fn new(cap: Option<usize>) -> Self {
        Self {
            accumulated: 0,
            cap,
        }
    }

    /// Creates a meter whose cap is `config.max_request_body_bytes`.
    #[must_use]
    pub fn from_config(config: &ServerConfig) -> Self {
        let cap = config.max_request_body_bytes;
        Self::new(cap)
    }

    /// Total bytes recorded so far.
    #[must_use]
    pub fn bytes_accumulated(&self) -> usize {
        self.accumulated
    }

    /// The configured cap, if any.
    #[must_use]
    pub fn cap(&self) -> Option<usize> {
        self.cap
    }

    /// Adds `bytes` to the running total (saturating) and checks it against the cap.
    ///
    /// The total is updated even when the call fails, so later calls keep failing.
    ///
    /// # Errors
    ///
    /// Returns `resource_exhausted` once the total strictly exceeds the cap.
    pub fn record_message_bytes(&mut self, bytes: usize) -> Result<(), Status> {
        self.accumulated = self.accumulated.saturating_add(bytes);
        match self.cap {
            Some(cap) if self.accumulated > cap => Err(Status::resource_exhausted(format!(
                "request body exceeds max_request_body_bytes: {actual} bytes > {cap} bytes \
                 (aggregate of all decoded messages on this call; \
                 see ServerConfig::max_request_body_bytes)",
                actual = self.accumulated,
            ))),
            _ => Ok(()),
        }
    }
}
impl Default for ServerConfig {
    /// Conservative defaults:
    /// - 4 MiB message limits and 1 MiB windows;
    /// - 100 concurrent streams;
    /// - an 8 KiB metadata cap and a 60 s stream idle timeout;
    /// - identity-only compression, with no timeouts, keepalives, or body cap.
    fn default() -> Self {
        const FOUR_MIB: usize = 4 * 1024 * 1024;
        const ONE_MIB: u32 = 1024 * 1024;
        Self {
            max_recv_message_size: FOUR_MIB,
            max_send_message_size: FOUR_MIB,
            max_request_body_bytes: None,
            initial_connection_window_size: ONE_MIB,
            initial_stream_window_size: ONE_MIB,
            max_concurrent_streams: 100,
            keepalive_interval_ms: None,
            keepalive_timeout_ms: None,
            default_timeout: None,
            max_request_deadline: None,
            send_compression: None,
            accept_compression: vec![CompressionEncoding::Identity],
            max_metadata_size: DEFAULT_MAX_METADATA_SIZE,
            stream_idle_timeout: Some(Duration::from_secs(60)),
        }
    }
}
/// Fluent builder for [`Server`]; obtain one via `Server::builder()`.
#[derive(Default)]
pub struct ServerBuilder {
    /// Configuration accumulated by the setter methods.
    config: ServerConfig,
    /// Registered services keyed by service name.
    services: BTreeMap<String, Arc<dyn ServiceHandler>>,
    /// Reflection service, once enabled; later services are registered with it.
    reflection: Option<ReflectionService>,
    /// Interceptors in registration order.
    interceptors: Vec<Arc<dyn Interceptor>>,
}
impl std::fmt::Debug for ServerBuilder {
    /// Summarises services by count and reflection by an on/off flag; interceptors
    /// are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let services_summary = format!("[{} services]", self.services.len());
        let reflection_enabled = self.reflection.is_some();
        let mut out = f.debug_struct("ServerBuilder");
        out.field("config", &self.config);
        out.field("services", &services_summary);
        out.field("reflection_enabled", &reflection_enabled);
        out.finish()
    }
}
impl ServerBuilder {
#[must_use]
pub fn new() -> Self {
Self {
config: ServerConfig::default(),
services: BTreeMap::new(),
reflection: None,
interceptors: Vec::new(),
}
}
#[must_use]
pub fn interceptor<I>(mut self, interceptor: I) -> Self
where
I: Interceptor + 'static,
{
self.interceptors.push(Arc::new(interceptor));
self
}
#[must_use]
pub fn interceptor_arc(mut self, interceptor: Arc<dyn Interceptor>) -> Self {
self.interceptors.push(interceptor);
self
}
#[must_use]
pub fn max_recv_message_size(mut self, size: usize) -> Self {
self.config.max_recv_message_size = size;
self
}
#[must_use]
pub fn max_metadata_size(mut self, size: usize) -> Self {
self.config.max_metadata_size = size;
self
}
#[must_use]
pub fn stream_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
self.config.stream_idle_timeout = timeout;
self
}
#[must_use]
pub fn max_send_message_size(mut self, size: usize) -> Self {
self.config.max_send_message_size = size;
self
}
#[must_use]
pub fn max_request_body_bytes(mut self, size: usize) -> Self {
self.config.max_request_body_bytes = Some(size);
self
}
#[must_use]
pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.config.initial_connection_window_size = size;
self
}
#[must_use]
pub fn initial_stream_window_size(mut self, size: u32) -> Self {
self.config.initial_stream_window_size = size;
self
}
#[must_use]
pub fn max_concurrent_streams(mut self, max: u32) -> Self {
self.config.max_concurrent_streams = max;
self
}
#[must_use]
pub fn keepalive_interval(mut self, ms: u64) -> Self {
self.config.keepalive_interval_ms = Some(ms);
self
}
#[must_use]
pub fn keepalive_timeout(mut self, ms: u64) -> Self {
self.config.keepalive_timeout_ms = Some(ms);
self
}
#[must_use]
pub fn default_timeout(mut self, timeout: Duration) -> Self {
self.config.default_timeout = Some(timeout);
self
}
#[must_use]
pub fn max_request_deadline(mut self, max: Duration) -> Self {
self.config.max_request_deadline = Some(max);
self
}
#[must_use]
pub fn send_compression(mut self, encoding: CompressionEncoding) -> Self {
self.config.send_compression = Some(encoding);
self
}
#[must_use]
pub fn accept_compression(mut self, encoding: CompressionEncoding) -> Self {
self.config.accept_compression.push(encoding);
self
}
#[must_use]
pub fn accept_compressions(
mut self,
encodings: impl IntoIterator<Item = CompressionEncoding>,
) -> Self {
self.config.accept_compression.clear();
self.config.accept_compression.extend(encodings);
self
}
#[must_use]
pub fn add_service<S>(mut self, service: S) -> Self
where
S: NamedService + ServiceHandler + 'static,
{
let service_name = S::NAME.to_string();
let service: Arc<dyn ServiceHandler> = Arc::new(service);
if let Some(reflection) = self.reflection.as_ref()
&& service_name != ReflectionService::NAME
{
reflection.register_handler(service.as_ref());
}
self.services.insert(service_name, service);
self
}
#[must_use]
pub fn enable_reflection_with_auth<F>(mut self, auth_callback: F) -> Self
where
F: Fn(&Cx, &str) -> Result<(), Status> + Send + Sync + 'static,
{
let reflection = self
.reflection
.take()
.unwrap_or_default()
.with_auth(auth_callback);
for service in self.services.values() {
if service.descriptor().full_name() != ReflectionService::NAME {
reflection.register_handler(service.as_ref());
}
}
self.services.insert(
ReflectionService::NAME.to_string(),
Arc::new(reflection.clone()),
);
self.reflection = Some(reflection);
self
}
#[must_use]
pub fn enable_reflection_anonymous(mut self) -> Self {
let reflection = self.reflection.take().unwrap_or_default().allow_anonymous();
for service in self.services.values() {
if service.descriptor().full_name() != ReflectionService::NAME {
reflection.register_handler(service.as_ref());
}
}
self.services.insert(
ReflectionService::NAME.to_string(),
Arc::new(reflection.clone()),
);
self.reflection = Some(reflection);
self
}
#[deprecated(
since = "0.3.3",
note = "Use enable_reflection_with_auth() or enable_reflection_anonymous() to make auth choice explicit"
)]
#[must_use]
pub fn enable_reflection(mut self) -> Self {
let reflection = self.reflection.take().unwrap_or_default(); for service in self.services.values() {
if service.descriptor().full_name() != ReflectionService::NAME {
reflection.register_handler(service.as_ref());
}
}
self.services.insert(
ReflectionService::NAME.to_string(),
Arc::new(reflection.clone()),
);
self.reflection = Some(reflection);
self
}
#[must_use]
pub fn build(self) -> Server {
Server {
config: self.config,
services: self.services,
interceptors: self.interceptors,
connection_registry: Arc::new(ConnectionRegistry::new()),
}
}
}
/// A configured gRPC server: services, interceptors, and connection/stream bookkeeping.
pub struct Server {
    /// Immutable configuration captured at build time.
    config: ServerConfig,
    /// Registered services keyed by service name.
    services: BTreeMap<String, Arc<dyn ServiceHandler>>,
    /// Interceptors in registration order.
    interceptors: Vec<Arc<dyn Interceptor>>,
    /// Shared registry used for per-connection stream limits.
    connection_registry: Arc<ConnectionRegistry>,
}
impl std::fmt::Debug for Server {
    /// Shows the config and a service count; interceptors and the registry are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let services_summary = format!("[{} services]", self.services.len());
        let mut out = f.debug_struct("Server");
        out.field("config", &self.config);
        out.field("services", &services_summary);
        out.finish()
    }
}
impl Server {
    /// Returns a fresh [`ServerBuilder`] with the default configuration.
    #[must_use]
    pub fn builder() -> ServerBuilder {
        ServerBuilder::new()
    }

    /// Returns the configuration this server was built with.
    #[must_use]
    pub fn config(&self) -> &ServerConfig {
        &self.config
    }

    /// Wraps `inner` in a [`FramedCodec`] that enforces this server's
    /// `max_send_message_size` / `max_recv_message_size`.
    #[must_use]
    pub fn framed_codec<C: Codec>(&self, inner: C) -> FramedCodec<C> {
        FramedCodec::with_message_size_limits(
            inner,
            self.config.max_send_message_size,
            self.config.max_recv_message_size,
        )
    }

    /// Returns the registered services keyed by service name.
    #[must_use]
    pub fn services(&self) -> &BTreeMap<String, Arc<dyn ServiceHandler>> {
        &self.services
    }

    /// Returns the shared connection registry.
    #[must_use]
    pub fn connection_registry(&self) -> &Arc<ConnectionRegistry> {
        &self.connection_registry
    }

    /// Registers a connection so that streams can be admitted on it.
    pub fn register_connection(&self, connection_id: String) {
        self.connection_registry.add_connection(connection_id);
    }

    /// Drops a connection and all of its stream bookkeeping.
    pub fn unregister_connection(&self, connection_id: &str) {
        self.connection_registry.remove_connection(connection_id);
    }

    /// Removes any `AuthContext` extension attached by interceptors.
    fn clear_auth_context_from_request(request: &mut Request<Bytes>) {
        let _ = request
            .extensions_mut()
            .remove_typed::<super::interceptor::AuthContext>();
    }

    /// Passes `status` through the `intercept_error_with_request` hook of each of
    /// `interceptors`, in reverse order, then strips the auth context from `request`.
    ///
    /// A hook that returns `Err(replacement)` replaces the status seen by the remaining
    /// hooks.
    fn unwind_error(
        interceptors: &[Arc<dyn Interceptor>],
        request: &mut Request<Bytes>,
        mut status: Status,
    ) -> Status {
        for interceptor in interceptors.iter().rev() {
            if let Err(replacement) = interceptor.intercept_error_with_request(request, &mut status)
            {
                status = replacement;
            }
        }
        Self::clear_auth_context_from_request(request);
        status
    }

    /// Returns the interceptors in registration order.
    #[must_use]
    pub fn interceptors(&self) -> &[Arc<dyn Interceptor>] {
        &self.interceptors
    }

    /// Runs a unary call through the interceptor chain and `handler`.
    ///
    /// 1. Inbound metadata is validated and size-checked. A failure here is returned
    ///    directly, without running any interceptor hooks.
    /// 2. `intercept_request` runs in registration order. If one fails, the error hooks of
    ///    the interceptors that already ran (including the failing one) run in reverse.
    /// 3. A deadline is derived from `grpc-timeout` and the config, and `handler` is raced
    ///    against it.
    /// 4. On success, `intercept_response_with_request` runs in reverse registration order.
    ///
    /// Every failure after step 2 goes through all interceptors' error hooks in reverse
    /// order, and the auth context is then stripped. Such failures are a handler error, an
    /// already-expired or elapsed deadline, or a response-hook error.
    ///
    /// # Errors
    ///
    /// Returns the (possibly interceptor-replaced) `Status` of the first failure.
    pub async fn dispatch_unary<H, F>(
        &self,
        mut request: Request<Bytes>,
        handler: H,
    ) -> Result<Response<Bytes>, Status>
    where
        H: FnOnce(Request<Bytes>) -> F,
        F: Future<Output = Result<Response<Bytes>, Status>>,
    {
        enforce_metadata_size_limit(request.metadata(), self.config.max_metadata_size)?;

        for (index, interceptor) in self.interceptors.iter().enumerate() {
            if let Err(status) = interceptor.intercept_request(&mut request) {
                let already_ran = &self.interceptors[..=index];
                return Err(Self::unwind_error(already_ran, &mut request, status));
            }
        }

        let call_context = CallContext::from_metadata_at_with_max_deadline(
            request.metadata().clone(),
            self.config.default_timeout,
            self.config.max_request_deadline,
            None,
            wall_clock_instant_now(),
        );
        let mut request_snapshot = request.snapshot(Bytes::new());

        // Deadline failures are ordinary call errors: they take the same unwind path as
        // handler errors so that every interceptor's error hook observes them.
        let response_result = match call_context.deadline() {
            Some(_) if call_context.is_expired() => {
                Err(Status::deadline_exceeded("Request deadline already expired"))
            }
            Some(deadline) => {
                let now = wall_clock_instant_now();
                let remaining_duration = deadline.saturating_duration_since(now);
                let handler_future = handler(request);
                match crate::time::timeout(crate::time::wall_now(), remaining_duration, handler_future)
                    .await
                {
                    Ok(result) => result,
                    Err(_elapsed) => Err(Status::deadline_exceeded("Request deadline exceeded")),
                }
            }
            None => handler(request).await,
        };

        let mut response = match response_result {
            Ok(response) => response,
            Err(status) => {
                return Err(Self::unwind_error(
                    &self.interceptors,
                    &mut request_snapshot,
                    status,
                ));
            }
        };
        for interceptor in self.interceptors.iter().rev() {
            if let Err(status) =
                interceptor.intercept_response_with_request(&request_snapshot, &mut response)
            {
                return Err(Self::unwind_error(
                    &self.interceptors,
                    &mut request_snapshot,
                    status,
                ));
            }
        }
        Ok(response)
    }

    /// Like [`Server::dispatch_unary`], but first admits `stream_id` on `connection_id`
    /// under `max_concurrent_streams` (evicting idle streams when configured).
    ///
    /// The stream is unregistered when this future completes or is dropped.
    ///
    /// # Errors
    ///
    /// Returns `resource_exhausted` when the stream cannot be admitted. Otherwise it
    /// returns whatever `dispatch_unary` returns.
    pub async fn dispatch_unary_with_stream_enforcement<H, F>(
        &self,
        connection_id: String,
        stream_id: u32,
        request: Request<Bytes>,
        handler: H,
    ) -> Result<Response<Bytes>, Status>
    where
        H: FnOnce(Request<Bytes>) -> F,
        F: Future<Output = Result<Response<Bytes>, Status>>,
    {
        let registered_at = self
            .connection_registry
            .enforce_stream_limits(
                &connection_id,
                stream_id,
                self.config.max_concurrent_streams,
                self.config.stream_idle_timeout,
            )
            .map_err(|limit_error| {
                Status::resource_exhausted(format!(
                    "stream limit enforcement failed: {limit_error}"
                ))
            })?;
        // The `registered_at` token stops the guard from removing a newer registration
        // that reused this stream id.
        let _stream_guard = StreamRegistrationGuard {
            registry: Arc::clone(&self.connection_registry),
            connection_id,
            stream_id,
            registered_at,
        };
        self.dispatch_unary(request, handler).await
    }

    /// Refreshes the idle timer of a stream.
    pub fn update_stream_activity(&self, connection_id: &str, stream_id: u32) {
        self.connection_registry
            .update_stream_activity(connection_id, stream_id);
    }

    /// Returns `(connection_count, total_active_streams)`.
    pub fn get_connection_stats(&self) -> (usize, usize) {
        self.connection_registry.get_stats()
    }

    /// Looks up a service by name.
    #[must_use]
    pub fn get_service(&self, name: &str) -> Option<&Arc<dyn ServiceHandler>> {
        self.services.get(name)
    }

    /// Returns the registered service names in sorted order.
    pub fn service_names(&self) -> Vec<&str> {
        self.services.keys().map(String::as_str).collect()
    }

    /// Checks that at least one service is registered, binds a TCP listener on `addr`,
    /// and puts it in non-blocking mode.
    ///
    /// NOTE(review): the listener is dropped when this returns `Ok(())`, so no
    /// connections are accepted here. Confirm whether an accept loop is intended.
    ///
    /// # Errors
    ///
    /// - A protocol error when no services are registered.
    /// - A transport error when binding or the non-blocking setup fails.
    #[allow(clippy::unused_async)]
    pub async fn serve(self, addr: &str) -> Result<(), GrpcError> {
        if self.services.is_empty() {
            return Err(GrpcError::protocol(
                "cannot serve gRPC server without registered services",
            ));
        }
        let listener = std::net::TcpListener::bind(addr).map_err(|error| {
            GrpcError::transport_kind(
                TransportErrorKind::from_io_error_kind(error.kind()),
                format!("bind failed: {error}"),
            )
        })?;
        listener.set_nonblocking(true).map_err(|error| {
            GrpcError::transport_kind(
                TransportErrorKind::from_io_error_kind(error.kind()),
                format!("nonblocking setup failed: {error}"),
            )
        })?;
        Ok(())
    }
}
/// Parses a `grpc-timeout` header value (`TimeoutValue TimeoutUnit`).
///
/// `TimeoutValue` is 1 to 8 ASCII digits. `TimeoutUnit` is one of `H`, `M`, `S`,
/// `m`, `u`, `n`: hours, minutes, seconds, milliseconds, microseconds, nanoseconds.
///
/// Returns `None` for anything else. That includes a sign (`+5S`), which `u64::from_str`
/// would otherwise accept, as well as whitespace and non-ASCII input.
#[must_use]
pub fn parse_grpc_timeout(header: &str) -> Option<Duration> {
    // ASCII-only guarantees the split below lands on a char boundary.
    if !header.is_ascii() {
        return None;
    }
    let (digits, unit) = header.split_at(header.len().checked_sub(1)?);
    if digits.is_empty() || digits.len() > 8 || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let value: u64 = digits.parse().ok()?;
    match unit {
        "H" => Some(Duration::from_secs(value.checked_mul(3600)?)),
        "M" => Some(Duration::from_secs(value.checked_mul(60)?)),
        "S" => Some(Duration::from_secs(value)),
        "m" => Some(Duration::from_millis(value)),
        "u" => Some(Duration::from_micros(value)),
        "n" => Some(Duration::from_nanos(value)),
        _ => None,
    }
}
/// Renders `duration` as a `grpc-timeout` value with at most 8 digits.
///
/// The result is chosen in two passes:
/// 1. The coarsest unit that represents the duration exactly within 8 digits.
/// 2. If there is none, the finest unit that fits, truncating toward zero. As a last
///    resort, hours clamped to `99999999H`.
///
/// A zero duration renders as `"0n"`.
#[must_use]
pub fn format_grpc_timeout(duration: Duration) -> String {
    // Largest value the 8-digit TimeoutValue field can carry.
    const MAX_VALUE: u128 = 99_999_999;
    let nanos = duration.as_nanos();
    // Zero is special-cased; the exact-pass table would otherwise render it as "0H".
    if nanos == 0 {
        return "0n".to_string();
    }
    let secs = u128::from(duration.as_secs());
    let millis = duration.as_millis();
    let micros = duration.as_micros();

    // Pass 1: (value in unit, nanoseconds per unit, suffix), coarsest first.
    let exact = [
        (secs / 3_600, 3_600_000_000_000_u128, 'H'),
        (secs / 60, 60_000_000_000, 'M'),
        (secs, 1_000_000_000, 'S'),
        (millis, 1_000_000, 'm'),
        (micros, 1_000, 'u'),
        (nanos, 1, 'n'),
    ];
    if let Some((value, _, unit)) = exact
        .into_iter()
        .find(|&(value, unit_nanos, _)| value <= MAX_VALUE && nanos.is_multiple_of(unit_nanos))
    {
        return format!("{value}{unit}");
    }

    // Pass 2: finest unit that fits, truncating. Nanoseconds cannot fit here, or
    // pass 1 would already have returned.
    let lossy = [(micros, 'u'), (millis, 'm'), (secs, 'S'), (secs / 60, 'M')];
    if let Some((value, unit)) = lossy.into_iter().find(|&(value, _)| value <= MAX_VALUE) {
        return format!("{value}{unit}");
    }
    format!("{}H", (secs / 3_600).min(MAX_VALUE))
}
/// Per-call information available to handlers: metadata, deadline, and peer address.
#[derive(Debug)]
pub struct CallContext {
    /// Request metadata as received.
    metadata: Metadata,
    /// Absolute deadline, if any.
    deadline: Option<Instant>,
    /// Peer address, if known.
    peer_addr: Option<String>,
    /// Clock used by the "now"-relative helpers (`remaining`, `is_expired`, ...).
    time_getter: fn() -> Instant,
}
impl CallContext {
    /// Creates a context with empty metadata, no deadline, no peer address, and the
    /// wall-clock time source.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: Metadata::new(),
            deadline: None,
            peer_addr: None,
            time_getter: wall_clock_instant_now,
        }
    }

    /// Builds a context from request metadata, measuring the deadline from "now" on the
    /// wall clock. See [`CallContext::from_metadata_at_with_max_deadline`].
    #[must_use]
    pub fn from_metadata(
        metadata: Metadata,
        default_timeout: Option<Duration>,
        peer_addr: Option<String>,
    ) -> Self {
        Self::from_metadata_with_time_getter(
            metadata,
            default_timeout,
            peer_addr,
            wall_clock_instant_now,
        )
    }

    /// Like [`CallContext::from_metadata`], but reads "now" from `time_getter`.
    /// `time_getter` is also kept as the context's clock.
    #[must_use]
    pub fn from_metadata_with_time_getter(
        metadata: Metadata,
        default_timeout: Option<Duration>,
        peer_addr: Option<String>,
        time_getter: fn() -> Instant,
    ) -> Self {
        Self::from_metadata_at(metadata, default_timeout, peer_addr, time_getter())
            .with_time_getter(time_getter)
    }

    /// Builds a context whose deadline is measured from `now`, with no deadline cap.
    #[must_use]
    pub fn from_metadata_at(
        metadata: Metadata,
        default_timeout: Option<Duration>,
        peer_addr: Option<String>,
        now: Instant,
    ) -> Self {
        Self::from_metadata_at_with_max_deadline(metadata, default_timeout, None, peer_addr, now)
    }

    /// Builds a context whose deadline is `now + timeout`, where `timeout` is:
    ///
    /// - the peer's well-formed ASCII `grpc-timeout`, clamped to `max_request_deadline`
    ///   when one is set; otherwise
    /// - `default_timeout`. This covers a missing header and a binary or unparseable one.
    ///   A malformed header therefore can neither disable the server default nor escape
    ///   the cap.
    ///
    /// If `now + timeout` overflows `Instant`, the context has no deadline.
    #[must_use]
    pub fn from_metadata_at_with_max_deadline(
        metadata: Metadata,
        default_timeout: Option<Duration>,
        max_request_deadline: Option<Duration>,
        peer_addr: Option<String>,
        now: Instant,
    ) -> Self {
        let peer_timeout = match metadata.get("grpc-timeout") {
            Some(super::streaming::MetadataValue::Ascii(s)) => parse_grpc_timeout(s),
            Some(super::streaming::MetadataValue::Binary(_)) | None => None,
        };
        let timeout = match peer_timeout {
            Some(peer) => Some(max_request_deadline.map_or(peer, |cap| peer.min(cap))),
            None => default_timeout,
        };
        let deadline = timeout.and_then(|t| now.checked_add(t));
        Self {
            metadata,
            deadline,
            peer_addr,
            time_getter: wall_clock_instant_now,
        }
    }

    /// Creates a context with only an absolute `deadline` set.
    #[must_use]
    pub fn with_deadline(deadline: Instant) -> Self {
        Self {
            metadata: Metadata::new(),
            deadline: Some(deadline),
            peer_addr: None,
            time_getter: wall_clock_instant_now,
        }
    }

    /// Replaces the clock used by the "now"-relative helpers.
    #[must_use]
    pub const fn with_time_getter(mut self, time_getter: fn() -> Instant) -> Self {
        self.time_getter = time_getter;
        self
    }

    /// Returns the clock used by the "now"-relative helpers.
    #[must_use]
    pub const fn time_getter(&self) -> fn() -> Instant {
        self.time_getter
    }

    /// Request metadata.
    #[must_use]
    pub fn metadata(&self) -> &Metadata {
        &self.metadata
    }

    /// Absolute deadline, if any.
    #[must_use]
    pub fn deadline(&self) -> Option<Instant> {
        self.deadline
    }

    /// Peer address, if known.
    #[must_use]
    pub fn peer_addr(&self) -> Option<&str> {
        self.peer_addr.as_deref()
    }

    /// Time left before the deadline, according to this context's clock.
    #[must_use]
    pub fn remaining(&self) -> Option<Duration> {
        self.remaining_at((self.time_getter)())
    }

    /// Time left before the deadline as of `now`.
    ///
    /// Returns `None` when there is no deadline or when `now` is past it.
    #[must_use]
    pub fn remaining_at(&self, now: Instant) -> Option<Duration> {
        self.deadline.and_then(|d| d.checked_duration_since(now))
    }

    /// `grpc-timeout` header value for the remaining time, using this context's clock.
    #[must_use]
    pub fn timeout_header_value(&self) -> Option<String> {
        self.timeout_header_value_at((self.time_getter)())
    }

    /// `grpc-timeout` header value for the time remaining as of `now`.
    ///
    /// A passed deadline renders as `"0n"`.
    #[must_use]
    pub fn timeout_header_value_at(&self, now: Instant) -> Option<String> {
        self.deadline
            .map(|deadline| format_grpc_timeout(deadline.saturating_duration_since(now)))
    }

    /// Writes this call's remaining time into `metadata` as `grpc-timeout`, using this
    /// context's clock. See [`CallContext::propagate_timeout_to_at`].
    pub fn propagate_timeout_to(&self, metadata: &mut Metadata) -> bool {
        self.propagate_timeout_to_at(metadata, (self.time_getter)())
    }

    /// Writes `min(existing parseable grpc-timeout, remaining time as of now)` into
    /// `metadata` as `grpc-timeout`.
    ///
    /// Returns `false`, leaving `metadata` untouched, when this context has no deadline.
    pub fn propagate_timeout_to_at(&self, metadata: &mut Metadata, now: Instant) -> bool {
        let Some(parent_remaining) = self
            .deadline
            .map(|deadline| deadline.saturating_duration_since(now))
        else {
            return false;
        };
        let effective = match metadata.get("grpc-timeout") {
            Some(super::streaming::MetadataValue::Ascii(existing)) => parse_grpc_timeout(existing)
                .map_or(parent_remaining, |child| child.min(parent_remaining)),
            Some(super::streaming::MetadataValue::Binary(_)) | None => parent_remaining,
        };
        let _ = metadata.insert_or_replace("grpc-timeout", format_grpc_timeout(effective));
        true
    }

    /// Whether the deadline has been reached, according to this context's clock.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.is_expired_at((self.time_getter)())
    }

    /// Whether `now` is at or past the deadline. Always `false` without a deadline.
    #[must_use]
    pub fn is_expired_at(&self, now: Instant) -> bool {
        self.deadline.is_some_and(|deadline| now >= deadline)
    }

    /// Pairs this context with a capability context `cx`.
    #[must_use]
    pub fn with_cx<'a>(&'a self, cx: &'a Cx) -> CallContextWithCx<'a> {
        CallContextWithCx { call: self, cx }
    }
}
impl Default for CallContext {
fn default() -> Self {
Self::new()
}
}
/// Borrowed pairing of a [`CallContext`] with the capability context (`Cx`) of the call.
pub struct CallContextWithCx<'a> {
    /// The per-call information.
    call: &'a CallContext,
    /// The capability context.
    cx: &'a Cx,
}
impl CallContextWithCx<'_> {
    /// The underlying call context.
    #[must_use]
    pub fn call(&self) -> &CallContext {
        self.call
    }

    /// Request metadata of the call.
    #[must_use]
    pub fn metadata(&self) -> &Metadata {
        self.call().metadata()
    }

    /// Absolute deadline of the call, if any.
    #[must_use]
    pub fn deadline(&self) -> Option<std::time::Instant> {
        self.call().deadline()
    }

    /// Peer address, if known.
    #[must_use]
    pub fn peer_addr(&self) -> Option<&str> {
        self.call().peer_addr()
    }

    /// Delegates to `CallContext::is_expired`.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.call().is_expired()
    }

    /// Delegates to `CallContext::remaining`.
    #[must_use]
    pub fn remaining(&self) -> Option<Duration> {
        self.call().remaining()
    }

    /// Delegates to `CallContext::timeout_header_value`.
    #[must_use]
    pub fn timeout_header_value(&self) -> Option<String> {
        self.call().timeout_header_value()
    }

    /// Delegates to `CallContext::propagate_timeout_to`.
    pub fn propagate_timeout_to(&self, metadata: &mut Metadata) -> bool {
        self.call().propagate_timeout_to(metadata)
    }

    /// The full capability context.
    #[must_use]
    pub fn cx(&self) -> &Cx {
        self.cx
    }

    /// A copy of the capability context restricted to `Caps`.
    #[must_use]
    pub fn cx_narrow<Caps>(&self) -> Cx<Caps>
    where
        Caps: cap::SubsetOf<cap::All>,
    {
        self.cx().restrict::<Caps>()
    }

    /// A copy of the capability context with no capabilities.
    #[must_use]
    pub fn cx_readonly(&self) -> Cx<cap::None> {
        self.cx().restrict::<cap::None>()
    }
}
/// Hooks invoked around unary dispatch.
///
/// `Server::dispatch_unary` calls `intercept_request` in registration order and the
/// response hooks in reverse order. An `Err` from any hook aborts the call.
pub trait Interceptor: Send + Sync {
    /// Inspects or mutates the request before the handler runs.
    fn intercept_request(&self, request: &mut Request<Bytes>) -> Result<(), Status>;

    /// Inspects or mutates a successful response.
    fn intercept_response(&self, response: &mut Response<Bytes>) -> Result<(), Status>;

    /// Request-aware variant of `intercept_response`.
    ///
    /// The default ignores the request and delegates to `intercept_response`.
    fn intercept_response_with_request(
        &self,
        request: &Request<Bytes>,
        response: &mut Response<Bytes>,
    ) -> Result<(), Status> {
        let _unused_request = request;
        self.intercept_response(response)
    }

    /// Observes (via `status`) or replaces (via `Err(replacement)`) a failing call's status.
    ///
    /// The default does nothing.
    fn intercept_error_with_request(
        &self,
        request: &Request<Bytes>,
        status: &mut Status,
    ) -> Result<(), Status> {
        let (_unused_request, _unused_status) = (request, status);
        Ok(())
    }
}
/// Interceptor that accepts every request and response unchanged.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoopInterceptor;
impl Interceptor for NoopInterceptor {
    fn intercept_request(&self, _: &mut Request<Bytes>) -> Result<(), Status> {
        Ok(())
    }
    fn intercept_response(&self, _: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }
}
/// Interceptor that gates requests on a metadata `validator` callback.
#[derive(Debug)]
pub struct AuthInterceptor<F> {
    /// Called with the request metadata; an `Err` rejects the request.
    validator: F,
}
impl<F> AuthInterceptor<F>
where
F: Fn(&Metadata) -> Result<(), Status> + Send + Sync,
{
#[must_use]
pub fn new(validator: F) -> Self {
Self { validator }
}
}
impl<F> Interceptor for AuthInterceptor<F>
where
    F: Fn(&Metadata) -> Result<(), Status> + Send + Sync,
{
    /// Rejects the request with whatever `Status` the validator returns.
    fn intercept_request(&self, request: &mut Request<Bytes>) -> Result<(), Status> {
        let validate = &self.validator;
        validate(request.metadata())
    }

    /// Responses pass through untouched.
    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }
}
/// Boxed, `Send` future resolving to a unary call's result.
pub type UnaryFuture<Resp> =
    Pin<Box<dyn Future<Output = Result<Response<Resp>, Status>> + Send + 'static>>;
/// Type-erased, shareable unary handler: maps a request to a [`UnaryFuture`].
pub type UnaryHandler<Req, Resp> =
    Box<dyn Fn(Request<Req>) -> UnaryFuture<Resp> + Send + Sync + 'static>;
/// Shorthand for a successful unary result wrapping `message`.
pub fn ok<T>(message: T) -> Result<Response<T>, Status> {
    let response = Response::new(message);
    Ok(response)
}
/// Shorthand for a failed unary result carrying `status`.
pub fn err<T>(status: Status) -> Result<Response<T>, Status> {
    Result::Err(status)
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send,
unused_must_use
)]
use super::*;
use crate::bytes::{BufMut, BytesMut};
use crate::grpc::service::ServiceDescriptor;
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
struct TestService;
impl NamedService for TestService {
const NAME: &'static str = "test.TestService";
}
impl ServiceHandler for TestService {
fn descriptor(&self) -> &ServiceDescriptor {
static DESC: ServiceDescriptor = ServiceDescriptor::new("TestService", "test", &[]);
&DESC
}
fn method_names(&self) -> Vec<&str> {
vec![]
}
}
/// Builder settings and added services must be visible on the built server.
#[test]
fn test_server_builder() {
    init_test("test_server_builder");
    let builder = Server::builder()
        .max_recv_message_size(1024 * 1024)
        .max_concurrent_streams(50)
        .add_service(TestService);
    let server = builder.build();

    let max_recv = server.config().max_recv_message_size;
    crate::assert_with_log!(max_recv == 1024 * 1024, "max_recv", 1024 * 1024, max_recv);

    let max_streams = server.config().max_concurrent_streams;
    crate::assert_with_log!(max_streams == 50, "max_streams", 50, max_streams);

    let registered = server.get_service("test.TestService").is_some();
    crate::assert_with_log!(registered, "service exists", true, registered);
    crate::test_complete!("test_server_builder");
}
/// `enable_reflection_anonymous` registers reflection alongside the user
/// service, and both appear in `service_names`.
#[test]
fn test_server_builder_enable_reflection_anonymous() {
    init_test("test_server_builder_enable_reflection_anonymous");
    let server = Server::builder()
        .add_service(TestService)
        .enable_reflection_anonymous()
        .build();

    let reflection_registered = server.get_service(ReflectionService::NAME).is_some();
    crate::assert_with_log!(
        reflection_registered,
        "reflection exists",
        true,
        reflection_registered
    );

    let listed = server.service_names();
    let lists_test = listed.contains(&"test.TestService");
    crate::assert_with_log!(lists_test, "test service retained", true, lists_test);
    let lists_reflection = listed.contains(&ReflectionService::NAME);
    crate::assert_with_log!(
        lists_reflection,
        "reflection service listed",
        true,
        lists_reflection
    );
    crate::test_complete!("test_server_builder_enable_reflection_anonymous");
}
/// `enable_reflection_with_auth` registers reflection alongside the user
/// service, and both appear in `service_names`.
#[test]
fn test_server_builder_enable_reflection_with_auth() {
    init_test("test_server_builder_enable_reflection_with_auth");
    let server = Server::builder()
        .add_service(TestService)
        .enable_reflection_with_auth(|_cx, _method| Ok(()))
        .build();

    let reflection_registered = server.get_service(ReflectionService::NAME).is_some();
    crate::assert_with_log!(
        reflection_registered,
        "reflection exists",
        true,
        reflection_registered
    );

    let listed = server.service_names();
    let lists_test = listed.contains(&"test.TestService");
    crate::assert_with_log!(lists_test, "test service retained", true, lists_test);
    let lists_reflection = listed.contains(&ReflectionService::NAME);
    crate::assert_with_log!(
        lists_reflection,
        "reflection service listed",
        true,
        lists_reflection
    );
    crate::test_complete!("test_server_builder_enable_reflection_with_auth");
}
/// A service added *after* reflection is enabled must still be registered.
#[test]
fn test_server_builder_reflection_tracks_late_registration() {
    init_test("test_server_builder_reflection_tracks_late_registration");
    let server = Server::builder()
        .enable_reflection_anonymous()
        .add_service(TestService)
        .build();

    let reflection_registered = server.get_service(ReflectionService::NAME).is_some();
    crate::assert_with_log!(
        reflection_registered,
        "reflection exists",
        true,
        reflection_registered
    );
    let late_registered = server.get_service("test.TestService").is_some();
    crate::assert_with_log!(late_registered, "late service exists", true, late_registered);
    crate::test_complete!("test_server_builder_reflection_tracks_late_registration");
}
/// The deprecated `enable_reflection()` still registers reflection. A default
/// `ReflectionService` refuses `list_services` with `PermissionDenied` and
/// points at both opt-in methods.
#[test]
#[allow(deprecated)]
fn test_deprecated_enable_reflection_defaults_to_locked() {
    init_test("test_deprecated_enable_reflection_defaults_to_locked");
    let server = Server::builder()
        .add_service(TestService)
        .enable_reflection()
        .build();
    let reflection_registered = server.get_service(ReflectionService::NAME).is_some();
    crate::assert_with_log!(
        reflection_registered,
        "reflection exists",
        true,
        reflection_registered
    );

    // NOTE(review): the lock check below runs against a freshly constructed
    // `ReflectionService::new()`, not the instance `enable_reflection()`
    // registered on `server`. It only shows that the default constructor is
    // locked; confirm that `enable_reflection()` uses that default.
    let locked = ReflectionService::new();
    let outcome = locked.list_services();
    let refused = outcome.is_err();
    crate::assert_with_log!(refused, "locked reflection should fail", true, refused);

    if let Err(status) = outcome {
        let is_permission_denied =
            status.code() == super::super::status::Code::PermissionDenied;
        crate::assert_with_log!(
            is_permission_denied,
            "should be PermissionDenied",
            true,
            is_permission_denied
        );
        let text = status.message();
        let names_both_options = text.contains(".with_auth") && text.contains(".allow_anonymous");
        crate::assert_with_log!(
            names_both_options,
            "error should mention both auth options",
            true,
            names_both_options
        );
    }
    crate::test_complete!("test_deprecated_enable_reflection_defaults_to_locked");
}
/// A `with_auth` callback that always denies must make `list_services` fail
/// with the callback's `PermissionDenied` status.
#[test]
fn test_reflection_auth_callback_enforcement() {
    init_test("test_reflection_auth_callback_enforcement");
    let reflection = ReflectionService::new().with_auth(|_cx, method| {
        Err(Status::permission_denied(format!("denied: {method}")))
    });
    // The guard is bound to a name so it lives until the end of the test.
    let _cx_guard = Cx::set_current(Some(Cx::for_testing_with_remote(
        crate::remote::RemoteCap::new(),
    )));

    let outcome = reflection.list_services();
    let denied = outcome.is_err();
    crate::assert_with_log!(denied, "auth callback should deny", true, denied);

    if let Err(status) = outcome {
        let is_permission_denied =
            status.code() == super::super::status::Code::PermissionDenied;
        crate::assert_with_log!(
            is_permission_denied,
            "should be PermissionDenied",
            true,
            is_permission_denied
        );
        let carries_callback_text = status.message().contains("denied:");
        crate::assert_with_log!(
            carries_callback_text,
            "message should contain 'denied:'",
            true,
            carries_callback_text
        );
    }
    crate::test_complete!("test_reflection_auth_callback_enforcement");
}
/// With `allow_anonymous`, `list_services` succeeds and includes registered
/// handlers.
#[test]
fn test_reflection_anonymous_allows_access() {
    init_test("test_reflection_anonymous_allows_access");
    let reflection = ReflectionService::new().allow_anonymous();
    reflection.register_handler(&TestService);
    // The guard is bound to a name so it lives until the end of the test.
    let _cx_guard = Cx::set_current(Some(Cx::for_testing_with_remote(
        crate::remote::RemoteCap::new(),
    )));

    let outcome = reflection.list_services();
    let allowed = outcome.is_ok();
    crate::assert_with_log!(allowed, "anonymous should allow", true, allowed);

    if let Ok(services) = outcome {
        let has_test_service = services.iter().any(|service| service == "test.TestService");
        crate::assert_with_log!(
            has_test_service,
            "should list test service",
            true,
            has_test_service
        );
    }
    crate::test_complete!("test_reflection_anonymous_allows_access");
}
/// `service_names` lists every service added through the builder.
#[test]
fn test_server_service_names() {
    init_test("test_server_service_names");
    let server = Server::builder().add_service(TestService).build();
    let listed = server.service_names().contains(&"test.TestService");
    crate::assert_with_log!(listed, "contains service name", true, listed);
    crate::test_complete!("test_server_service_names");
}
/// Serving with an empty registry fails with a protocol error.
#[test]
fn test_server_serve_requires_service_registration() {
    init_test("test_server_serve_requires_service_registration");
    let empty_server = Server::builder().build();
    let err = futures_lite::future::block_on(empty_server.serve("127.0.0.1:0"))
        .expect_err("serving without services should fail");
    let is_protocol = matches!(err, GrpcError::Protocol(_));
    crate::assert_with_log!(
        is_protocol,
        "protocol error for empty service registry",
        true,
        is_protocol
    );
    crate::test_complete!("test_server_serve_requires_service_registration");
}
/// An unparsable listen address fails with a transport error.
#[test]
fn test_server_serve_rejects_invalid_address() {
    init_test("test_server_serve_rejects_invalid_address");
    let server = Server::builder().add_service(TestService).build();
    let err = futures_lite::future::block_on(server.serve("not-an-addr"))
        .expect_err("invalid listen address should fail");
    let is_transport = matches!(err, GrpcError::Transport(_, _));
    crate::assert_with_log!(
        is_transport,
        "transport error for invalid address",
        true,
        is_transport
    );
    crate::test_complete!("test_server_serve_rejects_invalid_address");
}
/// Serving on an ephemeral loopback port succeeds when a service is
/// registered.
#[test]
fn test_server_serve_bind_probe() {
    init_test("test_server_serve_bind_probe");
    let server = Server::builder().add_service(TestService).build();
    let outcome = futures_lite::future::block_on(server.serve("127.0.0.1:0"));
    let probe_ok = outcome.is_ok();
    crate::assert_with_log!(probe_ok, "bind probe succeeds", true, probe_ok);
    crate::test_complete!("test_server_serve_bind_probe");
}
/// Binding a port that is already held surfaces as
/// `Transport(ProtocolViolation, "...bind failed...")`, and that error maps to
/// status `Internal`.
#[test]
fn test_server_serve_addr_in_use_preserves_non_retryable_kind() {
    init_test("test_server_serve_addr_in_use_preserves_non_retryable_kind");
    // `reserved` stays alive until the end of the test so the server's bind
    // collides with it.
    let reserved = std::net::TcpListener::bind("127.0.0.1:0")
        .expect("test should reserve an ephemeral TCP port");
    let addr = reserved
        .local_addr()
        .expect("reserved listener should expose local addr")
        .to_string();

    let server = Server::builder().add_service(TestService).build();
    let err = futures_lite::future::block_on(server.serve(&addr))
        .expect_err("binding an already-held port should fail");

    let GrpcError::Transport(kind, message) = &err else {
        panic!("expected typed transport error for AddrInUse, got {err:?}");
    };
    crate::assert_with_log!(
        *kind == TransportErrorKind::ProtocolViolation,
        "addr-in-use transport kind",
        TransportErrorKind::ProtocolViolation,
        *kind
    );
    let has_bind_context = message.contains("bind failed");
    crate::assert_with_log!(
        has_bind_context,
        "message contains bind context",
        true,
        has_bind_context
    );

    let status = err.into_status();
    crate::assert_with_log!(
        status.code() == crate::grpc::status::Code::Internal,
        "addr-in-use status code",
        crate::grpc::status::Code::Internal,
        status.code()
    );
    crate::test_complete!("test_server_serve_addr_in_use_preserves_non_retryable_kind");
}
/// A `host:port` listen address using a hostname (`localhost`) is accepted.
#[test]
fn test_server_serve_accepts_hostname_address() {
    init_test("test_server_serve_accepts_hostname_address");
    let server = Server::builder().add_service(TestService).build();
    let outcome = futures_lite::future::block_on(server.serve("localhost:0"));
    let accepted = outcome.is_ok();
    crate::assert_with_log!(
        accepted,
        "bind probe accepts hostname form",
        true,
        accepted
    );
    crate::test_complete!("test_server_serve_accepts_hostname_address");
}
/// A default `CallContext` is empty: no metadata, deadline or peer, and it is
/// never expired. It can also be wrapped with a `Cx` and narrowed.
#[test]
fn test_call_context() {
    init_test("test_call_context");
    let ctx = CallContext::new();

    let metadata_is_empty = ctx.metadata().is_empty();
    crate::assert_with_log!(metadata_is_empty, "metadata empty", true, metadata_is_empty);
    let no_deadline = ctx.deadline().is_none();
    crate::assert_with_log!(no_deadline, "deadline none", true, no_deadline);
    let no_peer = ctx.peer_addr().is_none();
    crate::assert_with_log!(no_peer, "peer none", true, no_peer);
    let expired = ctx.is_expired();
    crate::assert_with_log!(!expired, "not expired", false, expired);

    // Only needs to type-check: the views are bound but not inspected.
    let cx = Cx::for_testing();
    let wrapped = ctx.with_cx(&cx);
    let _read_only_view = wrapped.cx_readonly();
    let _narrowed_view = wrapped.cx_narrow::<cap::CapSet<true, true, false, false, false>>();
    crate::test_complete!("test_call_context");
}
/// `is_expired_at` is inclusive: a context counts as expired exactly at its
/// deadline, but not 1 ms before it.
#[test]
fn test_call_context_expiry_boundary_is_inclusive() {
    init_test("test_call_context_expiry_boundary_is_inclusive");
    let now = std::time::Instant::now();
    // Builds a context with the given deadline directly from its fields.
    let context_with_deadline = |deadline: std::time::Instant| CallContext {
        metadata: Metadata::new(),
        deadline: Some(deadline),
        peer_addr: None,
        time_getter: wall_clock_instant_now,
    };

    let at_deadline = context_with_deadline(now).is_expired_at(now);
    crate::assert_with_log!(
        at_deadline,
        "expired at deadline boundary",
        true,
        at_deadline
    );

    let before_deadline =
        context_with_deadline(now + std::time::Duration::from_millis(1)).is_expired_at(now);
    crate::assert_with_log!(
        !before_deadline,
        "not expired before deadline",
        false,
        before_deadline
    );
    crate::test_complete!("test_call_context_expiry_boundary_is_inclusive");
}
/// The injected time getter drives `remaining()` and `is_expired()`, so
/// deadline expiry can be tested by moving a fake clock instead of sleeping.
#[test]
fn test_call_context_time_getter_controls_deadline_helpers_without_sleep() {
    use std::sync::OnceLock;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    // Fake clock: a base instant fixed on first use plus an adjustable
    // nanosecond offset.
    static BASE: OnceLock<Instant> = OnceLock::new();
    static NOW_OFFSET_NS: AtomicU64 = AtomicU64::new(0);
    fn test_now() -> Instant {
        let offset = Duration::from_nanos(NOW_OFFSET_NS.load(Ordering::Relaxed));
        BASE.get_or_init(Instant::now)
            .checked_add(offset)
            .expect("test instant overflow")
    }

    init_test("test_call_context_time_getter_controls_deadline_helpers_without_sleep");
    NOW_OFFSET_NS.store(0, Ordering::Relaxed);

    // "5m" is a 5 ms budget, measured from the fake clock's current reading.
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "5m");
    let ctx = CallContext::from_metadata_with_time_getter(metadata, None, None, test_now);

    let five_ms = Some(Duration::from_millis(5));
    let initial_remaining = ctx.remaining();
    crate::assert_with_log!(
        initial_remaining == five_ms,
        "remaining uses custom time getter at construction time",
        five_ms,
        initial_remaining
    );

    // Advance the fake clock by 6 ms, past the 5 ms deadline.
    NOW_OFFSET_NS.store(6_000_000, Ordering::Relaxed);
    let expired = ctx.is_expired();
    crate::assert_with_log!(
        expired,
        "is_expired follows custom time getter without sleeping",
        true,
        expired
    );
    let nothing_left = ctx.remaining().is_none();
    crate::assert_with_log!(
        nothing_left,
        "remaining returns none after custom-clock expiry",
        true,
        nothing_left
    );
    crate::test_complete!(
        "test_call_context_time_getter_controls_deadline_helpers_without_sleep"
    );
}
/// With no `grpc-timeout` header, the deadline is `now + default_timeout`.
#[test]
fn test_call_context_default_timeout_applies_when_header_absent() {
    init_test("test_call_context_default_timeout_applies_when_header_absent");
    let now = std::time::Instant::now();
    let fallback = std::time::Duration::from_secs(3);
    let expected = now.checked_add(fallback);
    let ctx = CallContext::from_metadata_at(Metadata::new(), Some(fallback), None, now);
    let actual = ctx.deadline();
    crate::assert_with_log!(
        actual == expected,
        "default timeout applies when grpc-timeout header is absent",
        expected,
        actual
    );
    crate::test_complete!("test_call_context_default_timeout_applies_when_header_absent");
}
/// A present but malformed `grpc-timeout` header gives no deadline; the
/// default timeout is not used as a fallback.
#[test]
fn test_call_context_malformed_timeout_does_not_use_default() {
    init_test("test_call_context_malformed_timeout_does_not_use_default");
    let now = std::time::Instant::now();
    let fallback = std::time::Duration::from_secs(3);
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "bogus");
    let ctx = CallContext::from_metadata_at(metadata, Some(fallback), None, now);
    let no_deadline = ctx.deadline().is_none();
    crate::assert_with_log!(
        no_deadline,
        "malformed grpc-timeout does not use the default timeout",
        true,
        no_deadline
    );
    crate::test_complete!("test_call_context_malformed_timeout_does_not_use_default");
}
/// A malformed `grpc-timeout` with no default timeout yields no deadline.
#[test]
fn test_call_context_malformed_timeout_without_default_yields_no_deadline() {
    init_test("test_call_context_malformed_timeout_without_default_yields_no_deadline");
    let now = std::time::Instant::now();
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "bogus");
    let no_deadline = CallContext::from_metadata_at(metadata, None, None, now)
        .deadline()
        .is_none();
    crate::assert_with_log!(
        no_deadline,
        "malformed grpc-timeout with no default yields no deadline",
        true,
        no_deadline
    );
    crate::test_complete!(
        "test_call_context_malformed_timeout_without_default_yields_no_deadline"
    );
}
/// A nine-digit timeout value exceeds the 8-digit gRPC limit and must not
/// parse.
#[test]
fn test_parse_grpc_timeout_rejects_more_than_eight_digits() {
    init_test("test_parse_grpc_timeout_rejects_more_than_eight_digits");
    // "100000000" has nine digits.
    let rejected = parse_grpc_timeout("100000000S").is_none();
    crate::assert_with_log!(
        rejected,
        "oversized timeout literal must be rejected per gRPC 8-digit limit",
        true,
        rejected
    );
    crate::test_complete!("test_parse_grpc_timeout_rejects_more_than_eight_digits");
}
/// `from_metadata_at` installs `wall_clock_instant_now` as the time getter.
#[test]
fn test_call_context_from_metadata_at_default_time_getter_is_wall_clock() {
    init_test("test_call_context_from_metadata_at_default_time_getter_is_wall_clock");
    let ctx = CallContext::from_metadata_at(
        Metadata::new(),
        None,
        None,
        std::time::Instant::now(),
    );
    let installed = ctx.time_getter();
    let wall_clock: fn() -> std::time::Instant = wall_clock_instant_now;
    assert!(
        std::ptr::fn_addr_eq(installed, wall_clock),
        "from_metadata_at must default time_getter to wall_clock_instant_now"
    );
    crate::test_complete!(
        "test_call_context_from_metadata_at_default_time_getter_is_wall_clock"
    );
}
/// `with_time_getter` replaces the default wall-clock getter installed by
/// `from_metadata_at`.
#[test]
fn test_call_context_with_time_getter_chain_overrides_default() {
    init_test("test_call_context_with_time_getter_chain_overrides_default");
    let recorded = std::time::Instant::now();
    // This getter must be observably different from `wall_clock_instant_now`.
    // It always returns the same instant, one hour after its first call, which
    // the wall clock cannot produce during this test. A getter whose body is
    // just `Instant::now()` would be indistinguishable from the default, and
    // identical functions may be merged by the compiler or linker, so a
    // function-pointer comparison could pass even if the override were
    // ignored.
    fn fixed_time() -> std::time::Instant {
        static FIXED: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
        *FIXED.get_or_init(|| std::time::Instant::now() + std::time::Duration::from_secs(3600))
    }
    let ctx = CallContext::from_metadata_at(Metadata::new(), None, None, recorded)
        .with_time_getter(fixed_time);
    let getter = ctx.time_getter();
    // Check behavior rather than pointer identity, which `std::ptr::fn_addr_eq`
    // documents as unreliable.
    assert_eq!(
        getter(),
        fixed_time(),
        "with_time_getter must replace the default — fixed_time wasn't installed"
    );
    crate::test_complete!("test_call_context_with_time_getter_chain_overrides_default");
}
/// An out-of-range `grpc-timeout` (nine digits) yields no deadline rather
/// than falling back to the default or being treated as valid.
#[test]
fn test_call_context_oversized_timeout_header_fails_closed() {
    init_test("test_call_context_oversized_timeout_header_fails_closed");
    let now = std::time::Instant::now();
    let fallback = std::time::Duration::from_secs(3);
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "100000000S");
    let no_deadline = CallContext::from_metadata_at(metadata, Some(fallback), None, now)
        .deadline()
        .is_none();
    crate::assert_with_log!(
        no_deadline,
        "oversized timeout header must not be treated as an unbounded valid deadline",
        true,
        no_deadline
    );
    crate::test_complete!("test_call_context_oversized_timeout_header_fails_closed");
}
/// `timeout_header_value_at` encodes the remaining budget ("250m"), and once
/// the deadline has passed it reports a zero timeout ("0n").
#[test]
fn test_call_context_timeout_header_value_uses_remaining_budget() {
    init_test("test_call_context_timeout_header_value_uses_remaining_budget");
    let now = std::time::Instant::now();
    let deadline = now + std::time::Duration::from_millis(250);
    let ctx = CallContext::with_deadline(deadline);

    let live = ctx.timeout_header_value_at(now);
    crate::assert_with_log!(
        live.as_deref() == Some("250m"),
        "timeout header preserves remaining duration",
        Some("250m"),
        live.as_deref()
    );

    let after_deadline = deadline + std::time::Duration::from_millis(1);
    let lapsed = ctx.timeout_header_value_at(after_deadline);
    crate::assert_with_log!(
        lapsed.as_deref() == Some("0n"),
        "expired deadlines propagate as zero timeout",
        Some("0n"),
        lapsed.as_deref()
    );
    crate::test_complete!("test_call_context_timeout_header_value_uses_remaining_budget");
}
/// A child timeout longer than the parent's remaining budget is overwritten
/// with the parent budget, leaving exactly one `grpc-timeout` entry.
#[test]
fn test_call_context_propagate_timeout_to_clamps_existing_child_timeout() {
    init_test("test_call_context_propagate_timeout_to_clamps_existing_child_timeout");
    let now = std::time::Instant::now();
    let ctx = CallContext::with_deadline(now + std::time::Duration::from_secs(5));
    // The child asks for 10 s, more than the parent's 5 s budget.
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "10S");

    let wrote = ctx.propagate_timeout_to_at(&mut metadata, now);
    crate::assert_with_log!(wrote, "propagation writes timeout header", true, wrote);

    let clamped = matches!(
        metadata.get("grpc-timeout"),
        Some(crate::grpc::MetadataValue::Ascii(value)) if value == "5S"
    );
    crate::assert_with_log!(
        clamped,
        "existing child timeout is attenuated to parent deadline",
        true,
        metadata.get("grpc-timeout").is_some()
    );

    let timeout_count = metadata
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("grpc-timeout"))
        .count();
    crate::assert_with_log!(
        timeout_count == 1,
        "propagation keeps a single grpc-timeout entry",
        1,
        timeout_count
    );
    crate::test_complete!(
        "test_call_context_propagate_timeout_to_clamps_existing_child_timeout"
    );
}
/// When the child metadata has no `grpc-timeout`, propagation inserts the
/// parent's remaining budget.
#[test]
fn test_call_context_propagate_timeout_to_inserts_when_absent() {
    init_test("test_call_context_propagate_timeout_to_inserts_when_absent");
    let now = std::time::Instant::now();
    let ctx = CallContext::with_deadline(now + std::time::Duration::from_millis(750));
    let mut metadata = Metadata::new();

    let wrote = ctx.propagate_timeout_to_at(&mut metadata, now);
    crate::assert_with_log!(wrote, "propagation inserts missing timeout", true, wrote);

    let inserted = matches!(
        metadata.get("grpc-timeout"),
        Some(crate::grpc::MetadataValue::Ascii(value)) if value == "750m"
    );
    crate::assert_with_log!(
        inserted,
        "propagation inserts parent remaining timeout when absent",
        true,
        metadata.get("grpc-timeout").is_some()
    );
    crate::test_complete!("test_call_context_propagate_timeout_to_inserts_when_absent");
}
/// A malformed child `grpc-timeout` is replaced by the parent budget, leaving
/// exactly one `grpc-timeout` entry.
#[test]
fn test_call_context_propagate_timeout_to_repairs_malformed_child_timeout() {
    init_test("test_call_context_propagate_timeout_to_repairs_malformed_child_timeout");
    let now = std::time::Instant::now();
    let ctx = CallContext::with_deadline(now + std::time::Duration::from_secs(5));
    let mut metadata = Metadata::new();
    metadata.insert("grpc-timeout", "bogus");

    let wrote = ctx.propagate_timeout_to_at(&mut metadata, now);
    crate::assert_with_log!(wrote, "propagation writes repaired timeout", true, wrote);

    let repaired = matches!(
        metadata.get("grpc-timeout"),
        Some(crate::grpc::MetadataValue::Ascii(value)) if value == "5S"
    );
    crate::assert_with_log!(
        repaired,
        "malformed child timeout replaced with parent deadline",
        true,
        metadata.get("grpc-timeout").is_some()
    );

    let timeout_count = metadata
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("grpc-timeout"))
        .count();
    crate::assert_with_log!(
        timeout_count == 1,
        "repaired child timeout does not leave duplicates",
        1,
        timeout_count
    );
    crate::test_complete!(
        "test_call_context_propagate_timeout_to_repairs_malformed_child_timeout"
    );
}
/// `NoopInterceptor` accepts both requests and responses.
#[test]
fn test_noop_interceptor() {
    init_test("test_noop_interceptor");
    let interceptor = NoopInterceptor;

    let mut request = Request::new(Bytes::new());
    let request_ok = interceptor.intercept_request(&mut request).is_ok();
    crate::assert_with_log!(request_ok, "request ok", true, request_ok);

    let mut response = Response::new(Bytes::new());
    let response_ok = interceptor.intercept_response(&mut response).is_ok();
    crate::assert_with_log!(response_ok, "response ok", true, response_ok);
    crate::test_complete!("test_noop_interceptor");
}
/// `AuthInterceptor` rejects a request without `authorization` metadata and
/// accepts it once the header is added.
#[test]
fn test_auth_interceptor() {
    init_test("test_auth_interceptor");
    let interceptor = AuthInterceptor::new(|metadata| match metadata.get("authorization") {
        Some(_) => Ok(()),
        None => Err(Status::unauthenticated("missing authorization")),
    });

    let mut request = Request::new(Bytes::new());
    let rejected = interceptor.intercept_request(&mut request).is_err();
    crate::assert_with_log!(rejected, "missing auth err", true, rejected);

    request
        .metadata_mut()
        .insert("authorization", "Bearer token");
    let accepted = interceptor.intercept_request(&mut request).is_ok();
    crate::assert_with_log!(accepted, "auth ok", true, accepted);
    crate::test_complete!("test_auth_interceptor");
}
/// The default `ServerConfig` caps metadata at `DEFAULT_MAX_METADATA_SIZE`,
/// which is 8 KiB.
#[test]
fn server_config_default_caps_metadata_at_8_kib() {
    init_test("server_config_default_caps_metadata_at_8_kib");
    let default_cap = ServerConfig::default().max_metadata_size;
    assert_eq!(
        default_cap, DEFAULT_MAX_METADATA_SIZE,
        "default max_metadata_size must equal DEFAULT_MAX_METADATA_SIZE (8 KiB)"
    );
    assert_eq!(default_cap, 8 * 1024);
    crate::test_complete!("server_config_default_caps_metadata_at_8_kib");
}
/// Small, well-formed metadata passes an 8 KiB cap.
#[test]
fn enforce_metadata_size_limit_accepts_under_cap() {
    init_test("enforce_metadata_size_limit_accepts_under_cap");
    let mut metadata = super::super::streaming::Metadata::new();
    for (key, value) in [("authorization", "Bearer abc"), ("x-request-id", "deadbeef")] {
        metadata.insert(key, value);
    }
    assert!(metadata_byte_size(&metadata) > 0);
    enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect("under-cap metadata must pass enforcement");
    crate::test_complete!("enforce_metadata_size_limit_accepts_under_cap");
}
/// Metadata just over the 8 KiB cap is rejected with `RESOURCE_EXHAUSTED`,
/// and the error message mentions metadata.
#[test]
fn enforce_metadata_size_limit_rejects_over_cap_with_resource_exhausted() {
    init_test("enforce_metadata_size_limit_rejects_over_cap_with_resource_exhausted");
    // Two 4 KiB values plus their header names total slightly more than
    // 8 KiB.
    let mut metadata = super::super::streaming::Metadata::new();
    let chunk = "A".repeat(4 * 1024);
    metadata.insert("x-attack-a", chunk.clone());
    metadata.insert("x-attack-b", chunk);

    let Err(status) = enforce_metadata_size_limit(&metadata, 8 * 1024) else {
        panic!(
            "~8 KiB of metadata (two 4 KiB values) must be rejected by 8 KiB cap, \
             but enforcement passed"
        );
    };
    assert_eq!(
        status.code(),
        super::super::status::Code::ResourceExhausted,
        "must reject with RESOURCE_EXHAUSTED, got {:?}",
        status.code()
    );
    // Any mention of the limit, such as `max_metadata_size`, contains
    // "metadata", so one substring check covers it.
    let msg = format!("{status}");
    assert!(
        msg.contains("metadata"),
        "error message must mention the metadata limit, got: {msg}"
    );
    crate::test_complete!(
        "enforce_metadata_size_limit_rejects_over_cap_with_resource_exhausted"
    );
}
/// A limit of 0 disables the size cap, so about 1 MiB of metadata is
/// accepted.
#[test]
fn enforce_metadata_size_limit_zero_disables_cap() {
    init_test("enforce_metadata_size_limit_zero_disables_cap");
    let chunk = "A".repeat(4 * 1024);
    let mut metadata = super::super::streaming::Metadata::new();
    (0..256).for_each(|index| {
        metadata.insert(format!("x-anything-{index}"), chunk.clone());
    });
    enforce_metadata_size_limit(&metadata, 0)
        .expect("limit=0 must disable enforcement (no-cap convention)");
    crate::test_complete!("enforce_metadata_size_limit_zero_disables_cap");
}
/// An ASCII value containing CR/LF is rejected with `InvalidArgument`, and the
/// error names the offending header.
#[test]
fn enforce_metadata_size_limit_rejects_ascii_control_chars() {
    init_test("enforce_metadata_size_limit_rejects_ascii_control_chars");
    let poisoned = super::super::streaming::MetadataValue::Ascii("line1\r\nline2".to_string());
    let metadata = super::super::streaming::Metadata::from_raw_entries_for_tests(vec![(
        "x-request-id".to_string(),
        poisoned,
    )]);
    let status = enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect_err("CRLF-bearing metadata must be rejected");
    assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
    let msg = status.to_string();
    assert!(
        msg.contains("x-request-id") && msg.contains("disallowed"),
        "error message must mention the offending header, got: {msg}"
    );
    crate::test_complete!("enforce_metadata_size_limit_rejects_ascii_control_chars");
}
/// Client-supplied `grpc-status` is rejected as a reserved `grpc-*` header.
#[test]
fn enforce_metadata_size_limit_rejects_reserved_grpc_header() {
    init_test("enforce_metadata_size_limit_rejects_reserved_grpc_header");
    let mut metadata = super::super::streaming::Metadata::new();
    metadata.insert("grpc-status", "0");
    let status = enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect_err("client grpc-status metadata must be rejected");
    assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
    let msg = status.to_string();
    let names_reserved_prefix =
        msg.contains("grpc-status") && msg.contains("reserved grpc-* prefix");
    assert!(
        names_reserved_prefix,
        "error message must mention reserved grpc-* prefix, got: {msg}"
    );
    crate::test_complete!("enforce_metadata_size_limit_rejects_reserved_grpc_header");
}
/// A non-gRPC `content-type` is rejected, and the error points to
/// `application/grpc`.
#[test]
fn enforce_metadata_size_limit_rejects_non_grpc_content_type() {
    init_test("enforce_metadata_size_limit_rejects_non_grpc_content_type");
    let mut metadata = super::super::streaming::Metadata::new();
    metadata.insert("content-type", "application/json");
    let status = enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect_err("non-gRPC content-type must be rejected");
    assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
    let msg = status.to_string();
    let names_media_type = msg.contains("content-type") && msg.contains("application/grpc");
    assert!(
        names_media_type,
        "error message must mention the required gRPC media type, got: {msg}"
    );
    crate::test_complete!("enforce_metadata_size_limit_rejects_non_grpc_content_type");
}
/// A `te` header other than `trailers` is rejected with `InvalidArgument`.
#[test]
fn enforce_metadata_size_limit_rejects_non_trailers_te() {
    init_test("enforce_metadata_size_limit_rejects_non_trailers_te");
    let mut metadata = super::super::streaming::Metadata::new();
    metadata.insert("te", "gzip");
    let status = enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect_err("non-trailers te must be rejected");
    assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
    let msg = status.to_string();
    // NOTE(review): `contains("te")` matches almost any English message, so
    // in practice only the "trailers" check constrains the text. Consider
    // matching the header name with its surrounding punctuation once the
    // message format is pinned down.
    assert!(
        msg.contains("te") && msg.contains("trailers"),
        "error message must mention the trailers requirement, got: {msg}"
    );
    crate::test_complete!("enforce_metadata_size_limit_rejects_non_trailers_te");
}
/// Protocol-owned request headers, including the allowed `grpc-*` ones, pass
/// enforcement.
#[test]
fn enforce_metadata_size_limit_allows_grpc_request_protocol_headers() {
    init_test("enforce_metadata_size_limit_allows_grpc_request_protocol_headers");
    let protocol_headers = [
        ("content-type", "application/grpc+proto"),
        ("te", "trailers"),
        ("grpc-timeout", "5S"),
        ("grpc-encoding", "identity"),
        ("grpc-accept-encoding", "identity,gzip"),
        ("grpc-message-type", "test.EchoRequest"),
    ];
    let mut metadata = super::super::streaming::Metadata::new();
    for (key, value) in protocol_headers {
        metadata.insert(key, value);
    }
    enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect("protocol-owned request grpc-* headers must remain allowed");
    crate::test_complete!("enforce_metadata_size_limit_allows_grpc_request_protocol_headers");
}
/// `ServerBuilder::max_metadata_size` overrides the default cap.
#[test]
fn server_builder_max_metadata_size_overrides_default() {
    init_test("server_builder_max_metadata_size_overrides_default");
    let configured = ServerBuilder::new()
        .max_metadata_size(16 * 1024)
        .build()
        .config()
        .max_metadata_size;
    assert_eq!(configured, 16 * 1024);
    crate::test_complete!("server_builder_max_metadata_size_overrides_default");
}
/// Header-name validation rejects empty names, whitespace, control characters
/// and RFC 7230 separators, and accepts ordinary token names.
///
/// The cases are table-driven and each assertion names its input, so a
/// failure identifies the exact offending name.
#[test]
fn test_rfc7230_header_name_validation_rejects_invalid_characters() {
    init_test("test_rfc7230_header_name_validation_rejects_invalid_characters");
    let rejected = [
        "",
        "invalid space",
        "invalid\r",
        "invalid\n",
        "invalid\t",
        "invalid:header",
        "invalid;header",
        "invalid(header",
        "invalid)header",
        "invalid<header",
        "invalid>header",
        "invalid@header",
        "invalid,header",
        "invalid\\header",
        "invalid\"header",
        "invalid/header",
        "invalid[header",
        "invalid]header",
        "invalid?header",
        "invalid=header",
        "invalid{header",
        "invalid}header",
    ];
    for name in rejected {
        assert!(
            !is_valid_header_name_rfc7230(name),
            "header name {name:?} must be rejected"
        );
    }
    let accepted = [
        "valid-header",
        "valid_header",
        "validheader123",
        "x-custom-header",
        "content-type",
        "x-trace-id",
        "authorization",
    ];
    for name in accepted {
        assert!(
            is_valid_header_name_rfc7230(name),
            "header name {name:?} must be accepted"
        );
    }
    crate::test_complete!("test_rfc7230_header_name_validation_rejects_invalid_characters");
}
/// Header-value validation rejects CR/LF (injection and response-splitting
/// payloads), NUL and other control bytes, and DEL. It accepts ordinary
/// printable values, including the empty value.
///
/// The cases are table-driven and each assertion names its input, so a
/// failure identifies the exact offending value.
#[test]
fn test_rfc7230_header_value_validation_rejects_crlf_injection() {
    init_test("test_rfc7230_header_value_validation_rejects_crlf_injection");
    let rejected = [
        "value1\r\ninjected-header: evil",
        "value1\ninjected-header: evil",
        "value1\rinjected-header: evil",
        "\r\nevil-header: value",
        "normal\r\nContent-Length: 0",
        "test\r\n\r\nHTTP/1.1 200 OK",
        "value\x00control",
        "value\x01control",
        "value\x02control",
        "value\x7Fcontrol",
    ];
    for value in rejected {
        assert!(
            !is_valid_header_value_rfc7230(value),
            "header value {value:?} must be rejected"
        );
    }
    let accepted = [
        "valid header value",
        "Bearer abc123",
        "application/grpc+proto",
        "trailers",
        "5S",
        "identity,gzip",
        "",
    ];
    for value in accepted {
        assert!(
            is_valid_header_value_rfc7230(value),
            "header value {value:?} must be accepted"
        );
    }
    crate::test_complete!("test_rfc7230_header_value_validation_rejects_crlf_injection");
}
/// Enforcement rejects header names with a space or with CR/LF, each with
/// `InvalidArgument` and a descriptive message.
#[test]
fn test_enforce_metadata_rejects_rfc7230_header_name_violations() {
    init_test("test_enforce_metadata_rejects_rfc7230_header_name_violations");
    // Builds single-entry metadata via `from_raw_entries_for_tests`, asserts
    // that enforcement rejects it with `InvalidArgument`, and returns the
    // rendered status.
    let rejection_text = |name: &str, why: &str| {
        let metadata = super::super::streaming::Metadata::from_raw_entries_for_tests(vec![(
            name.to_string(),
            super::super::streaming::MetadataValue::Ascii("value".to_string()),
        )]);
        let status = enforce_metadata_size_limit(&metadata, 8 * 1024).expect_err(why);
        assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
        status.to_string()
    };

    let spaced = rejection_text("invalid header", "header name with space must be rejected");
    assert!(spaced.contains("invalid characters"));

    let crlf = rejection_text("invalid\r\nheader", "header name with CRLF must be rejected");
    assert!(crlf.contains("RFC 7230 violation"));
    crate::test_complete!("test_enforce_metadata_rejects_rfc7230_header_name_violations");
}
/// Enforcement rejects CRLF header-injection and response-splitting payloads
/// in metadata values with `InvalidArgument`.
#[test]
fn test_enforce_metadata_rejects_header_injection_attacks() {
    init_test("test_enforce_metadata_rejects_header_injection_attacks");
    // Builds single-entry metadata via `from_raw_entries_for_tests`, asserts
    // that enforcement rejects it with `InvalidArgument`, and returns the
    // rendered status.
    let rejection_text = |key: &str, value: &str, why: &str| {
        let metadata = super::super::streaming::Metadata::from_raw_entries_for_tests(vec![(
            key.to_string(),
            super::super::streaming::MetadataValue::Ascii(value.to_string()),
        )]);
        let status = enforce_metadata_size_limit(&metadata, 8 * 1024).expect_err(why);
        assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
        status.to_string()
    };

    let injection = rejection_text(
        "x-trace-id",
        "normal\r\ninjected-header: evil",
        "CRLF injection attack must be rejected",
    );
    assert!(injection.contains("CRLF"));

    rejection_text(
        "authorization",
        "Bearer token\r\n\r\nHTTP/1.1 200 OK",
        "response splitting attack must be rejected",
    );
    crate::test_complete!("test_enforce_metadata_rejects_header_injection_attacks");
}
/// Header names, ASCII values, and binary values longer than the configured
/// per-header maxima must each be rejected with `InvalidArgument`.
#[test]
fn test_enforce_metadata_rejects_oversized_headers() {
    init_test("test_enforce_metadata_rejects_oversized_headers");
    // (name, value, expect_err context, whether the message must mention the length cap)
    let cases = vec![
        (
            format!("x-{}", "a".repeat(MAX_HEADER_NAME_LEN)),
            super::super::streaming::MetadataValue::Ascii("value".to_string()),
            "oversized header name must be rejected",
            true,
        ),
        (
            "x-large-value".to_string(),
            super::super::streaming::MetadataValue::Ascii("a".repeat(MAX_HEADER_VALUE_LEN + 1)),
            "oversized header value must be rejected",
            true,
        ),
        (
            "x-large-binary".to_string(),
            super::super::streaming::MetadataValue::Binary(
                vec![0u8; MAX_HEADER_VALUE_LEN + 1].into(),
            ),
            "oversized binary value must be rejected",
            false,
        ),
    ];
    for (name, value, context, mentions_length) in cases {
        let metadata =
            super::super::streaming::Metadata::from_raw_entries_for_tests(vec![(name, value)]);
        let status = enforce_metadata_size_limit(&metadata, 64 * 1024).expect_err(context);
        assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
        if mentions_length {
            assert!(status.to_string().contains("exceeds maximum length"));
        }
    }
    crate::test_complete!("test_enforce_metadata_rejects_oversized_headers");
}
/// Ordinary, well-formed headers must pass metadata validation unchanged.
#[test]
fn test_enforce_metadata_allows_valid_rfc7230_headers() {
    init_test("test_enforce_metadata_allows_valid_rfc7230_headers");
    let headers = [
        ("x-trace-id", "abc123def456"),
        ("authorization", "Bearer valid-token"),
        ("content-type", "application/grpc+proto"),
        ("user-agent", "grpc-client/1.0"),
        ("x-custom-header", "valid value with spaces"),
    ];
    let mut metadata = super::super::streaming::Metadata::new();
    for (key, value) in headers {
        metadata.insert(key, value);
    }
    enforce_metadata_size_limit(&metadata, 8 * 1024)
        .expect("valid RFC 7230 compliant headers must be accepted");
    crate::test_complete!("test_enforce_metadata_allows_valid_rfc7230_headers");
}
/// `dispatch_unary` must reject request metadata containing a CRLF sequence
/// with `InvalidArgument`, and must do so without invoking the handler.
#[test]
fn test_dispatch_unary_rejects_header_injection_before_handler() {
use futures_lite::future::block_on;
init_test("test_dispatch_unary_rejects_header_injection_before_handler");
let server = Server::builder().max_metadata_size(8 * 1024).build();
// Header value embedding "\r\n" followed by what would be a second header line.
let metadata = super::super::streaming::Metadata::from_raw_entries_for_tests(vec![(
"x-trace-id".to_string(),
super::super::streaming::MetadataValue::Ascii(
"valid\r\ninjected-header: malicious".to_string(),
),
)]);
let request = Request::with_metadata(Bytes::new(), metadata);
// Set to true only if the handler's future is actually polled.
let mut handler_invoked = false;
let result = block_on(server.dispatch_unary(request, |_req| async {
handler_invoked = true;
Ok(Response::new(Bytes::from_static(b"should-not-reach")))
}));
assert!(result.is_err(), "CRLF injection must be rejected");
assert!(
!handler_invoked,
"handler must NOT be invoked for header injection attempts"
);
let status = result.unwrap_err();
assert_eq!(status.code(), super::super::status::Code::InvalidArgument);
assert!(format!("{status}").contains("CRLF"));
crate::test_complete!("test_dispatch_unary_rejects_header_injection_before_handler");
}
/// Two threads concurrently register, touch, and remove disjoint stream ids
/// (0..100 and 100..200) on one shared connection. Afterwards the connection
/// must still exist and no stream may be left registered.
#[test]
fn test_connection_registry_concurrent_operations() {
    init_test("test_connection_registry_concurrent_operations");
    let registry = Arc::new(ConnectionRegistry::new());
    let connection_id = "test-connection".to_string();
    registry.add_connection(connection_id.clone());
    let registry_clone = Arc::clone(&registry);
    let connection_id_clone = connection_id.clone();
    let handle = std::thread::spawn(move || {
        for i in 0..100 {
            let stream_id = i;
            let result =
                registry_clone.enforce_stream_limits(&connection_id_clone, stream_id, 200, None);
            if result.is_ok() {
                registry_clone.update_stream_activity(&connection_id_clone, stream_id);
                registry_clone.remove_stream(&connection_id_clone, stream_id);
            }
        }
    });
    for i in 100..200 {
        let stream_id = i;
        let result = registry.enforce_stream_limits(&connection_id, stream_id, 200, None);
        if result.is_ok() {
            registry.update_stream_activity(&connection_id, stream_id);
            registry.remove_stream(&connection_id, stream_id);
        }
        // Interleave stats reads with the writer thread's mutations.
        let (_conn_count, _stream_count) = registry.get_stats();
    }
    handle.join().expect("thread should complete successfully");
    let (conn_count, stream_count) = registry.get_stats();
    assert_eq!(conn_count, 1);
    assert_eq!(stream_count, 0);
    crate::test_complete!("test_connection_registry_concurrent_operations");
}
/// A reader thread repeatedly snapshots `get_stats` while this thread adds
/// and removes streams across ten connections. The run must finish without
/// panics or deadlock, and stream churn must not add or drop connections.
#[test]
fn test_connection_registry_concurrent_read_write() {
    init_test("test_connection_registry_concurrent_read_write");
    let registry = Arc::new(ConnectionRegistry::new());
    let connection_count = 10;
    for i in 0..connection_count {
        registry.add_connection(format!("connection-{}", i));
    }
    let registry_clone = Arc::clone(&registry);
    let reader_handle = std::thread::spawn(move || {
        for _ in 0..1000 {
            let (_conn_count, _stream_count) = registry_clone.get_stats();
            std::thread::yield_now();
        }
    });
    for i in 0..connection_count {
        let connection_id = format!("connection-{}", i);
        for stream_id in 0..5 {
            let _ = registry.enforce_stream_limits(&connection_id, stream_id, 50, None);
        }
        for stream_id in 0..3 {
            registry.remove_stream(&connection_id, stream_id);
        }
    }
    reader_handle
        .join()
        .expect("reader thread should complete successfully");
    // Previously the test asserted nothing. The connection set must be
    // unaffected by stream add/remove churn.
    let (conn_count, _stream_count) = registry.get_stats();
    assert_eq!(conn_count, 10);
    crate::test_complete!("test_connection_registry_concurrent_read_write");
}
/// Four workers each cycle 25 distinct stream ids through
/// add → update → remove on one shared `ConnectionState`. No stream may
/// remain active once every worker has joined.
#[test]
fn test_connection_state_thread_safety() {
    init_test("test_connection_state_thread_safety");
    let shared = Arc::new(ConnectionState::new());
    let workers = 4;
    let per_worker = 25;
    let joins: Vec<_> = (0..workers)
        .map(|worker| {
            let state = Arc::clone(&shared);
            std::thread::spawn(move || {
                // Each worker owns the contiguous id range [first, first + per_worker).
                let first = worker * per_worker;
                for stream_id in first..first + per_worker {
                    if state.add_stream(stream_id, 200).is_ok() {
                        state.update_stream_activity(stream_id);
                        state.remove_stream(stream_id);
                    }
                }
            })
        })
        .collect();
    for join in joins {
        join.join().expect("thread should complete successfully");
    }
    assert_eq!(shared.active_stream_count(), 0);
    crate::test_complete!("test_connection_state_thread_safety");
}
/// Eight threads interleave enforce/update/stats/remove across five shared
/// connections. Completion proves there is no deadlock. The final snapshot
/// proves every stream that was added was also removed and that no
/// connection was lost.
#[test]
fn test_connection_registry_no_deadlocks_under_load() {
    init_test("test_connection_registry_no_deadlocks_under_load");
    let registry = Arc::new(ConnectionRegistry::new());
    let num_connections = 5;
    let num_threads = 8;
    for i in 0..num_connections {
        registry.add_connection(format!("conn-{}", i));
    }
    let mut handles = Vec::new();
    for thread_id in 0..num_threads {
        let reg = Arc::clone(&registry);
        let handle = std::thread::spawn(move || {
            for i in 0..50 {
                let conn_id = format!("conn-{}", i % num_connections);
                // Stream ids are unique across threads: thread t uses [50t, 50t + 50).
                let stream_id = (thread_id * 50) + i;
                let _ = reg.enforce_stream_limits(&conn_id, stream_id, 100, None);
                reg.update_stream_activity(&conn_id, stream_id);
                let _ = reg.get_stats();
                reg.remove_stream(&conn_id, stream_id);
                if i % 10 == 0 {
                    std::thread::yield_now();
                }
            }
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().expect("no deadlocks should occur");
    }
    // Each thread removes every stream it registered, so nothing may leak.
    let (conn_count, stream_count) = registry.get_stats();
    assert_eq!(conn_count, 5);
    assert_eq!(stream_count, 0);
    crate::test_complete!("test_connection_registry_no_deadlocks_under_load");
}
/// The `Debug` rendering of the default `ServerConfig` names the type and key fields.
#[test]
fn server_config_debug() {
    let rendered = format!("{:?}", ServerConfig::default());
    for needle in ["ServerConfig", "max_recv_message_size", "max_concurrent_streams"] {
        assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
    }
}
/// `ServerConfig::clone` must copy every overridden field and leave the
/// source untouched.
///
/// The clone goes through a generic `T: Clone` helper so that `Clone::clone`
/// really runs. A plain `let b = a;` only moves or bit-copies and never
/// exercises the `Clone` impl this test is named for.
#[test]
fn server_config_clone() {
    fn clone_via_trait<T: Clone>(value: &T) -> T {
        value.clone()
    }
    let config = ServerConfig {
        max_recv_message_size: 1024,
        max_send_message_size: 2048,
        ..Default::default()
    };
    let config2 = clone_via_trait(&config);
    assert_eq!(config2.max_recv_message_size, 1024);
    assert_eq!(config2.max_send_message_size, 2048);
    // The original must still be intact and usable after cloning.
    assert_eq!(config.max_recv_message_size, 1024);
    assert_eq!(config.max_send_message_size, 2048);
}
/// Pins every documented default of `ServerConfig`.
#[test]
fn server_config_default_values() {
    let defaults = ServerConfig::default();
    // Message-size caps: 4 MiB in each direction.
    assert_eq!(defaults.max_recv_message_size, 4_194_304);
    assert_eq!(defaults.max_send_message_size, 4_194_304);
    // Initial connection and stream windows: 1 MiB each.
    assert_eq!(defaults.initial_connection_window_size, 1_048_576);
    assert_eq!(defaults.initial_stream_window_size, 1_048_576);
    assert_eq!(defaults.max_concurrent_streams, 100);
    // Keepalive settings are unset unless configured through the builder.
    assert!(defaults.keepalive_interval_ms.is_none());
    assert!(defaults.keepalive_timeout_ms.is_none());
}
/// `ServerBuilder`'s `Debug` output names the type and exposes its config.
#[test]
fn server_builder_debug() {
    let rendered = format!("{:?}", ServerBuilder::new());
    for needle in ["ServerBuilder", "config"] {
        assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
    }
}
/// `ServerBuilder` is constructible through `Default`.
#[test]
fn server_builder_default() {
    let builder: ServerBuilder = Default::default();
    let rendered = format!("{builder:?}");
    assert!(rendered.contains("ServerBuilder"));
}
/// A built `Server` renders its type name and config in `Debug` output.
#[test]
fn server_debug() {
    let rendered = format!("{:?}", Server::builder().build());
    for needle in ["Server", "config"] {
        assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
    }
}
/// `CallContext`'s `Debug` output names the type and its metadata field.
#[test]
fn call_context_debug() {
    let rendered = format!("{:?}", CallContext::new());
    for needle in ["CallContext", "metadata"] {
        assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
    }
}
/// A default `CallContext` carries no deadline, no peer address, and no metadata.
#[test]
fn call_context_default() {
    let context: CallContext = Default::default();
    assert!(context.deadline().is_none());
    assert!(context.peer_addr().is_none());
    assert!(context.metadata().is_empty());
}
/// `NoopInterceptor` must implement `Debug`, `Clone`, `Copy` and `Default`.
///
/// Each trait is exercised through a generic helper bounded on exactly that
/// trait, so the test genuinely calls `Clone::clone` and `Default::default`.
/// The previous form (`let cloned = interceptor;`, `let default = NoopInterceptor;`)
/// only copied or constructed the unit struct and never touched those impls.
#[test]
fn noop_interceptor_debug_clone_copy_default() {
    fn via_clone<T: Clone>(value: &T) -> T {
        value.clone()
    }
    fn via_copy<T: Copy>(value: &T) -> T {
        *value
    }
    fn via_default<T: Default>() -> T {
        T::default()
    }

    let interceptor = NoopInterceptor;
    let dbg = format!("{interceptor:?}");
    assert!(dbg.contains("NoopInterceptor"));
    let cloned = via_clone(&interceptor);
    assert_eq!(format!("{cloned:?}"), dbg);
    let copied = via_copy(&interceptor);
    assert_eq!(format!("{copied:?}"), dbg);
    let default: NoopInterceptor = via_default();
    assert_eq!(format!("{default:?}"), dbg);
}
/// `ok(v)` wraps `v` in a successful `Response`.
#[test]
fn ok_utility_returns_ok_response() {
    let result: Result<Response<i32>, Status> = ok(42);
    match result {
        Ok(response) => assert_eq!(response.into_inner(), 42),
        Err(_) => panic!("ok() must produce an Ok response"),
    }
}
/// `err(status)` yields the error variant.
#[test]
fn err_utility_returns_err_status() {
    let outcome: Result<Response<i32>, Status> = err(Status::not_found("missing"));
    assert!(outcome.is_err(), "err() must produce an Err status");
}
/// Keepalive builder setters are stored verbatim, in milliseconds, on the config.
#[test]
fn server_builder_keepalive() {
    let server = Server::builder()
        .keepalive_interval(5000)
        .keepalive_timeout(2000)
        .build();
    let config = server.config();
    assert_eq!(config.keepalive_interval_ms, Some(5000));
    assert_eq!(config.keepalive_timeout_ms, Some(2000));
}
/// Window-size builder setters override the config defaults.
#[test]
fn server_builder_window_sizes() {
    let server = Server::builder()
        .initial_connection_window_size(512 * 1024)
        .initial_stream_window_size(256 * 1024)
        .build();
    let config = server.config();
    assert_eq!(config.initial_connection_window_size, 512 * 1024);
    assert_eq!(config.initial_stream_window_size, 256 * 1024);
}
/// Looking up an unregistered service name yields `None`.
#[test]
fn server_get_service_missing() {
    let server = Server::builder().build();
    let lookup = server.get_service("nonexistent");
    assert!(lookup.is_none());
}
mod grpc_timeout_conformance {
    //! Conformance checks for the `grpc-timeout` header codec
    //! (`parse_grpc_timeout` / `format_grpc_timeout`). The header is an
    //! ASCII digit string of at most 8 digits followed by one unit letter
    //! from `H M S m u n`.
    //!
    //! Each test prints a one-line JSON verdict to stderr on success.
    use super::*;

    /// GRPC-TIMEOUT-1: every unit letter parses to the matching duration.
    #[test]
    fn grpc_timeout_1_all_six_units_parse() {
        let cases = &[
            ("1H", Duration::from_secs(3600)),
            ("2M", Duration::from_secs(120)),
            ("30S", Duration::from_secs(30)),
            ("500m", Duration::from_millis(500)),
            ("250u", Duration::from_micros(250)),
            ("42n", Duration::from_nanos(42)),
        ];
        for (input, expected) in cases {
            let got = parse_grpc_timeout(input);
            assert_eq!(
                got,
                Some(*expected),
                "GRPC-TIMEOUT-1: {input:?} must parse to {expected:?}",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-1\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-2: values with more than 8 digits are rejected, even all-zero ones.
    #[test]
    fn grpc_timeout_2_reject_more_than_eight_digits() {
        let inputs = &["100000000S", "999999999m", "123456789n", "000000000H"];
        for input in inputs {
            assert_eq!(
                parse_grpc_timeout(input),
                None,
                "GRPC-TIMEOUT-2: {input:?} must be rejected (>8 digits)",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-2\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-3: malformed inputs are rejected. These include empty
    /// input, a missing unit or value, whitespace, a wrong-case or unknown
    /// unit, sign, fraction, and non-ASCII digits.
    #[test]
    fn grpc_timeout_3_reject_malformed() {
        let rejected = &[
            "", "S", "100", " 10S", "10 S", "10s", "10x", "-1S", "1.5S", "abc", "١٠S",
        ];
        for input in rejected {
            assert_eq!(
                parse_grpc_timeout(input),
                None,
                "GRPC-TIMEOUT-3: {input:?} must be rejected",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-3\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-4: for durations representable without precision loss,
    /// parsing the formatter's output returns the original duration.
    #[test]
    fn grpc_timeout_4_format_parse_roundtrip() {
        let lossless = &[
            Duration::ZERO,
            Duration::from_nanos(1),
            Duration::from_nanos(42),
            Duration::from_micros(250),
            Duration::from_millis(500),
            Duration::from_secs(30),
            Duration::from_secs(120),
            Duration::from_secs(3600),
            Duration::from_secs(7200),
        ];
        for d in lossless {
            let formatted = format_grpc_timeout(*d);
            let parsed = parse_grpc_timeout(&formatted).unwrap_or_else(|| {
                panic!("GRPC-TIMEOUT-4: formatter output {formatted:?} not parseable")
            });
            assert_eq!(
                parsed, *d,
                "GRPC-TIMEOUT-4: round-trip diverged for {d:?} → {formatted:?} → {parsed:?}",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-4\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-5: formatter output is always 1–8 ASCII digits plus a
    /// spec unit, even for durations too large to express exactly.
    #[test]
    fn grpc_timeout_5_formatter_respects_eight_digit_ceiling() {
        let samples = &[
            Duration::ZERO,
            Duration::from_nanos(1),
            Duration::from_secs(1),
            Duration::from_secs(999_999_999),
            Duration::MAX,
        ];
        for d in samples {
            let formatted = format_grpc_timeout(*d);
            // Guard before `split_at(len - 1)`, which would underflow on "".
            assert!(
                !formatted.is_empty(),
                "format_grpc_timeout returned empty string for duration {:?}",
                d
            );
            let (digits, unit) = formatted.split_at(formatted.len() - 1);
            assert!(
                matches!(unit, "H" | "M" | "S" | "m" | "u" | "n"),
                "GRPC-TIMEOUT-5: unit {unit:?} not in spec set for input {d:?}",
            );
            assert!(
                (1..=8).contains(&digits.len()),
                "GRPC-TIMEOUT-5: digits {digits:?} length out of [1,8] for input {d:?}",
            );
            assert!(
                digits.bytes().all(|b| b.is_ascii_digit()),
                "GRPC-TIMEOUT-5: digits {digits:?} contains non-ASCII-digit for input {d:?}",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-5\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-6: zero formats canonically as `"0n"` and parses back to zero.
    #[test]
    fn grpc_timeout_6_zero_duration_fail_fast_representation() {
        let formatted = format_grpc_timeout(Duration::ZERO);
        let parsed = parse_grpc_timeout(&formatted).expect("zero parses");
        assert_eq!(parsed, Duration::ZERO);
        assert_eq!(
            formatted, "0n",
            "GRPC-TIMEOUT-6: zero must format as canonical \"0n\"",
        );
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-6\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-7: the largest 8-digit value of every unit must never
    /// wrap or truncate. Any accepted value must equal the exact product, and
    /// zero must parse to `Duration::ZERO` for every unit.
    ///
    /// The maxima all fit: 99_999_999 h ≈ 3.6e11 s fits in a u64 second count.
    /// If the multiplication were instead done in nanoseconds it would
    /// overflow u64, which is exactly the bug this test must catch.
    #[test]
    fn grpc_timeout_7_overflow_rejected_not_wrapped() {
        let safe = parse_grpc_timeout("99999999H");
        assert!(
            safe.is_some(),
            "GRPC-TIMEOUT-7: 99_999_999H fits in u64 seconds and must parse",
        );
        let max: u64 = 99_999_999;
        let exact_maxima = [
            ("H", Duration::from_secs(max * 3600)),
            ("M", Duration::from_secs(max * 60)),
            ("S", Duration::from_secs(max)),
            ("m", Duration::from_millis(max)),
            ("u", Duration::from_micros(max)),
            ("n", Duration::from_nanos(max)),
        ];
        for (unit, exact) in exact_maxima {
            let input = format!("99999999{unit}");
            if let Some(parsed) = parse_grpc_timeout(&input) {
                assert_eq!(
                    parsed, exact,
                    "GRPC-TIMEOUT-7: {input:?} was accepted but wrapped/truncated",
                );
            }
            let input = format!("00000000{unit}");
            if let Some(parsed) = parse_grpc_timeout(&input) {
                assert_eq!(
                    parsed,
                    Duration::ZERO,
                    "GRPC-TIMEOUT-7: {input:?} was accepted but is not ZERO",
                );
            }
            let input = format!("0{unit}");
            let parsed = parse_grpc_timeout(&input);
            assert_eq!(
                parsed,
                Some(Duration::ZERO),
                "GRPC-TIMEOUT-7: 0{unit} must parse to ZERO",
            );
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-7\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }

    /// GRPC-TIMEOUT-8: control bytes, zero-width or BOM prefixes, multibyte
    /// letters, and emoji must never panic the parser.
    #[test]
    fn grpc_timeout_8_no_panic_on_adversarial_input() {
        let adversarial: &[&str] = &[
            "",
            "\0",
            "\0\0\0S",
            "\n10S",
            "10S\n",
            "\u{FEFF}10S",
            "\u{200B}10S",
            "1\0S",
            "\x7f10S",
            "10\x00S",
            "ääääääääS",
            "10😀",
        ];
        for input in adversarial {
            let _ = parse_grpc_timeout(input);
        }
        eprintln!("{{\"id\":\"GRPC-TIMEOUT-8\",\"verdict\":\"PASS\",\"level\":\"Must\"}}",);
    }
}
/// Test interceptor that counts its request and response hook invocations
/// and appends `req:{name}` / `resp:{name}` entries to an event log, so tests
/// can assert on interceptor ordering.
#[derive(Debug)]
struct CountingInterceptor {
/// Label embedded in every event entry this interceptor writes.
name: &'static str,
/// Number of `intercept_request` calls observed.
request_count: std::sync::atomic::AtomicUsize,
/// Number of `intercept_response` calls observed.
response_count: std::sync::atomic::AtomicUsize,
/// Ordered event log; presumably shared with other interceptors in the same test.
events: Arc<parking_lot::Mutex<Vec<String>>>,
}
impl CountingInterceptor {
    /// Creates an interceptor labelled `name` that logs into `events`,
    /// with both counters starting at zero.
    fn new(name: &'static str, events: Arc<parking_lot::Mutex<Vec<String>>>) -> Self {
        use std::sync::atomic::AtomicUsize;
        Self {
            events,
            request_count: AtomicUsize::new(0),
            response_count: AtomicUsize::new(0),
            name,
        }
    }
}
impl Interceptor for CountingInterceptor {
    /// Bumps the request counter, then logs `req:{name}`. Never rejects.
    fn intercept_request(&self, _request: &mut Request<Bytes>) -> Result<(), Status> {
        use std::sync::atomic::Ordering;
        self.request_count.fetch_add(1, Ordering::SeqCst);
        let mut log = self.events.lock();
        log.push(format!("req:{}", self.name));
        Ok(())
    }

    /// Bumps the response counter, then logs `resp:{name}`. Never rejects.
    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        use std::sync::atomic::Ordering;
        self.response_count.fetch_add(1, Ordering::SeqCst);
        let mut log = self.events.lock();
        log.push(format!("resp:{}", self.name));
        Ok(())
    }
}
/// Test interceptor whose request hook always fails with `Unauthenticated`,
/// logging `req:reject` / `resp:reject` into `events` when its hooks run.
#[derive(Debug)]
struct RejectingInterceptor {
/// Ordered event log written by both hooks.
events: Arc<parking_lot::Mutex<Vec<String>>>,
}
impl Interceptor for RejectingInterceptor {
    /// Logs `req:reject`, then unconditionally rejects the call.
    fn intercept_request(&self, _request: &mut Request<Bytes>) -> Result<(), Status> {
        self.events.lock().push(String::from("req:reject"));
        let rejection = Status::unauthenticated("rejected by RejectingInterceptor");
        Err(rejection)
    }

    /// Logs `resp:reject` and lets the response through.
    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        self.events.lock().push(String::from("resp:reject"));
        Ok(())
    }
}
/// Test interceptor that attaches an `AuthContext` (principal `svc-a`) to the
/// request, then records which principal is visible from the request in the
/// response-with-request hook.
#[derive(Debug)]
struct AuthContextEchoInterceptor {
/// Principal observed on the success path; `None` if no `AuthContext` was found.
seen_principal: Arc<parking_lot::Mutex<Option<String>>>,
}
impl Interceptor for AuthContextEchoInterceptor {
    /// Stores an `AuthContext` for principal `svc-a` with scope `read:rpc`
    /// in the request's typed extensions.
    fn intercept_request(&self, request: &mut Request<Bytes>) -> Result<(), Status> {
        let extensions = request.extensions_mut();
        let auth = crate::grpc::interceptor::AuthContext::with_principal("svc-a")
            .with_scopes(["read:rpc"]);
        extensions.insert_typed(auth);
        Ok(())
    }

    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    /// Records the principal of the request's `AuthContext` (or `None`) once
    /// a successful response is produced.
    fn intercept_response_with_request(
        &self,
        request: &Request<Bytes>,
        _response: &mut Response<Bytes>,
    ) -> Result<(), Status> {
        let auth = request
            .extensions()
            .get_typed::<crate::grpc::interceptor::AuthContext>();
        let principal = auth.map(|ctx| ctx.principal.clone());
        let mut slot = self.seen_principal.lock();
        *slot = principal;
        Ok(())
    }
}
/// Like `AuthContextEchoInterceptor`, but records the visible principal from
/// the error-with-request hook, i.e. on the failure path.
#[derive(Debug)]
struct AuthContextErrorEchoInterceptor {
/// Principal observed on the error path; `None` if no `AuthContext` was found.
seen_principal: Arc<parking_lot::Mutex<Option<String>>>,
}
impl Interceptor for AuthContextErrorEchoInterceptor {
    /// Stores an `AuthContext` for principal `svc-a` with scope `read:rpc`
    /// in the request's typed extensions.
    fn intercept_request(&self, request: &mut Request<Bytes>) -> Result<(), Status> {
        let extensions = request.extensions_mut();
        let auth = crate::grpc::interceptor::AuthContext::with_principal("svc-a")
            .with_scopes(["read:rpc"]);
        extensions.insert_typed(auth);
        Ok(())
    }

    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    /// Records the principal of the request's `AuthContext` (or `None`) when
    /// the call ends in an error. The status itself is left untouched.
    fn intercept_error_with_request(
        &self,
        request: &Request<Bytes>,
        _status: &mut Status,
    ) -> Result<(), Status> {
        let auth = request
            .extensions()
            .get_typed::<crate::grpc::interceptor::AuthContext>();
        let principal = auth.map(|ctx| ctx.principal.clone());
        let mut slot = self.seen_principal.lock();
        *slot = principal;
        Ok(())
    }
}
/// Test interceptor that lets requests through but fails every response in
/// the response-with-request hook with `Internal`.
#[derive(Debug)]
struct ResponseErrorInterceptor;
impl Interceptor for ResponseErrorInterceptor {
    fn intercept_request(&self, _request: &mut Request<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    /// Always converts a successful response into an `Internal` error.
    fn intercept_response_with_request(
        &self,
        _request: &Request<Bytes>,
        _response: &mut Response<Bytes>,
    ) -> Result<(), Status> {
        let failure = Status::internal("response interceptor exploded");
        Err(failure)
    }
}
/// Reproduction command embedded verbatim (as `exact_rch_command=`) in every
/// `GRPC_INTERCEPTOR_RATE_LIMIT` log line emitted by `log_interceptor_cleanup_case`.
const EXACT_INTERCEPTOR_MULTISTACK_RATE_LIMIT_CLEANUP_RCH_COMMAND: &str = "rch exec -- env CARGO_TARGET_DIR=${TMPDIR:-/tmp}/rch_target_asupersync_eqpd3i_interceptor cargo test -p asupersync --lib interceptor_multistack_rate_limit_cleanup -- --nocapture";
/// Which phase of a unary call the failure-injection matrix breaks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MatrixFailureStage {
    Request,
    Response,
    Handler,
}
impl MatrixFailureStage {
    /// Lower-case stage name used in structured log lines.
    fn label(self) -> &'static str {
        use MatrixFailureStage::{Handler, Request, Response};
        match self {
            Request => "request",
            Response => "response",
            Handler => "handler",
        }
    }
}
/// One layer of a multi-interceptor stack used by the rate-limit cleanup
/// matrix. Each hook logs `{phase}:{index}:slots={n}`, where `n` is the
/// shared limiter's `current_count()`, and the layer can optionally fail at
/// one stage.
#[derive(Debug)]
struct MatrixInterceptor {
/// Position of this interceptor in the stack; embedded in event entries.
index: usize,
/// Limiter whose `current_count()` is sampled into every event entry.
limiter: Arc<crate::grpc::interceptor::RateLimitInterceptor>,
/// Ordered event log shared by the stack.
events: Arc<parking_lot::Mutex<Vec<String>>>,
/// Stage at which this interceptor injects a failure, if any.
fail_stage: Option<MatrixFailureStage>,
}
impl MatrixInterceptor {
    /// Builds stack layer `index` that samples `limiter`, logs into `events`,
    /// and fails at `fail_stage` when that is `Some`.
    fn new(
        index: usize,
        limiter: Arc<crate::grpc::interceptor::RateLimitInterceptor>,
        events: Arc<parking_lot::Mutex<Vec<String>>>,
        fail_stage: Option<MatrixFailureStage>,
    ) -> Self {
        Self {
            fail_stage,
            events,
            limiter,
            index,
        }
    }

    /// Appends `"{phase}:{index}:slots={count}"`, sampling the limiter while
    /// the event log is locked.
    fn record(&self, phase: &str) {
        let mut log = self.events.lock();
        let slots = self.limiter.current_count();
        log.push(format!("{phase}:{}:slots={slots}", self.index));
    }
}
impl Interceptor for MatrixInterceptor {
    /// Logs `req`, then fails with `FailedPrecondition` when configured to
    /// fail at the request stage.
    fn intercept_request(&self, _request: &mut Request<Bytes>) -> Result<(), Status> {
        self.record("req");
        match self.fail_stage {
            Some(MatrixFailureStage::Request) => Err(Status::failed_precondition(format!(
                "request interceptor {} rejected",
                self.index
            ))),
            _ => Ok(()),
        }
    }

    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    /// Logs `resp`, then fails with `Internal` when configured to fail at the
    /// response stage.
    fn intercept_response_with_request(
        &self,
        _request: &Request<Bytes>,
        _response: &mut Response<Bytes>,
    ) -> Result<(), Status> {
        self.record("resp");
        match self.fail_stage {
            Some(MatrixFailureStage::Response) => Err(Status::internal(format!(
                "response interceptor {} exploded",
                self.index
            ))),
            _ => Ok(()),
        }
    }

    /// Logs `cleanup` on the error path; never alters the status.
    fn intercept_error_with_request(
        &self,
        _request: &Request<Bytes>,
        _status: &mut Status,
    ) -> Result<(), Status> {
        self.record("cleanup");
        Ok(())
    }
}
/// Builds the exact event trace a stack of `stack_depth` `MatrixInterceptor`s
/// is expected to record for one failure-injection case.
///
/// - no failure: every `req` in order, then every `resp` in reverse;
/// - request failure at `i`: `req` 0..=i, then `cleanup` i..=0;
/// - response failure at `i`: every `req`, `resp` from the top down to `i`,
///   then `cleanup` for the whole stack in reverse;
/// - handler failure: every `req`, then `cleanup` for the whole stack in reverse.
///
/// Every entry carries `slots=1`.
///
/// # Panics
/// If a request or response failure is requested without `failing_interceptor_index`.
fn expected_interceptor_cleanup_events(
    stack_depth: usize,
    failing_interceptor_index: Option<usize>,
    failure_stage: Option<MatrixFailureStage>,
) -> Vec<String> {
    /// One `"{phase}:{index}:slots=1"` entry per index, in iteration order.
    fn events(
        phase: &'static str,
        indices: impl Iterator<Item = usize>,
    ) -> impl Iterator<Item = String> {
        indices.map(move |index| format!("{phase}:{index}:slots=1"))
    }

    let Some(stage) = failure_stage else {
        return events("req", 0..stack_depth)
            .chain(events("resp", (0..stack_depth).rev()))
            .collect();
    };
    match stage {
        MatrixFailureStage::Request => {
            let failing_index =
                failing_interceptor_index.expect("request failure requires interceptor index");
            events("req", 0..=failing_index)
                .chain(events("cleanup", (0..=failing_index).rev()))
                .collect()
        }
        MatrixFailureStage::Response => {
            let failing_index =
                failing_interceptor_index.expect("response failure requires interceptor index");
            events("req", 0..stack_depth)
                .chain(events("resp", (failing_index..stack_depth).rev()))
                .chain(events("cleanup", (0..stack_depth).rev()))
                .collect()
        }
        MatrixFailureStage::Handler => events("req", 0..stack_depth)
            .chain(events("cleanup", (0..stack_depth).rev()))
            .collect(),
    }
}
/// Asserts the dispatch outcome expected for `failure_stage` and returns a
/// short result-kind label for logging.
///
/// - `None` → the response body must be `matrix-ok`; returns `"ok"`.
/// - request failure → status `FailedPrecondition`; returns `"FailedPrecondition"`.
/// - response or handler failure → status `Internal`; returns `"Internal"`.
///
/// # Panics
/// With `context` as the message when the result has the wrong variant, or
/// on a body or code mismatch.
fn assert_interceptor_cleanup_result(
    result: Result<Response<Bytes>, Status>,
    failure_stage: Option<MatrixFailureStage>,
    context: &str,
) -> &'static str {
    let Some(stage) = failure_stage else {
        let response = result.expect(context);
        assert_eq!(response.get_ref().as_ref(), b"matrix-ok");
        return "ok";
    };
    let (expected_code, kind) = match stage {
        MatrixFailureStage::Request => (
            super::super::Code::FailedPrecondition,
            "FailedPrecondition",
        ),
        MatrixFailureStage::Response | MatrixFailureStage::Handler => {
            (super::super::Code::Internal, "Internal")
        }
    };
    let status = result.expect_err(context);
    assert_eq!(status.code(), expected_code);
    kind
}
/// Prints one `GRPC_INTERCEPTOR_RATE_LIMIT` key=value line to stdout that
/// summarises a cleanup-matrix case. A missing failing index or failure stage
/// is rendered as `none`, and events are joined with `>`.
fn log_interceptor_cleanup_case(
    request_id: &str,
    stack_depth: usize,
    failing_interceptor_index: Option<usize>,
    failure_stage: Option<MatrixFailureStage>,
    slot_count_before: u32,
    slot_count_after: u32,
    release_count: usize,
    first_result_kind: &str,
    replay_result_kind: &str,
    events: &[String],
    final_verdict: &str,
) {
    let failing_index_label = match failing_interceptor_index {
        Some(index) => index.to_string(),
        None => "none".to_string(),
    };
    let stage_label = match failure_stage {
        Some(stage) => stage.label(),
        None => "none",
    };
    let event_trace = events.join(">");
    println!(
        "GRPC_INTERCEPTOR_RATE_LIMIT \
        request_id={} \
        stack_depth={} \
        failing_interceptor_index={} \
        failure_stage={} \
        slot_count_before={} \
        slot_count_after={} \
        release_count={} \
        response_error_kind={} \
        replay_result_kind={} \
        cancellation_state=none_unary_dispatch \
        event_trace={} \
        exact_rch_command=\"{}\" \
        artifact_paths=none \
        no_slot_leak_verdict={}",
        request_id,
        stack_depth,
        failing_index_label,
        stage_label,
        slot_count_before,
        slot_count_after,
        release_count,
        first_result_kind,
        replay_result_kind,
        event_trace,
        EXACT_INTERCEPTOR_MULTISTACK_RATE_LIMIT_CLEANUP_RCH_COMMAND,
        final_verdict,
    );
}
/// Minimal single-threaded executor for tests: drives `fut` to completion on
/// the current thread and returns its output.
///
/// The context uses a no-op waker, so readiness is never signalled and the
/// future is simply re-polled. The thread yields between polls. Without that,
/// a future that is pending on work done by another thread turns the loop
/// into a hot spin that can starve sibling worker threads, such as the
/// concurrent metadata-isolation scenario's workers.
///
/// The future is pinned on the stack with `pin!`; no heap allocation is needed.
fn block_on<F: Future>(fut: F) -> F::Output {
    use std::task::{Context, Poll, Waker};
    let mut cx = Context::from_waker(Waker::noop());
    let mut fut = std::pin::pin!(fut);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
        std::thread::yield_now();
    }
}
/// Reproduction command embedded verbatim (as `exact_rch_command=`) in every
/// `GRPC_UNARY_METADATA_ISOLATION` log line emitted by `log_grpc_unary_metadata_case`.
const EXACT_GRPC_UNARY_METADATA_ISOLATION_RCH_COMMAND: &str = "rch exec -- env CARGO_TARGET_DIR=${TMPDIR:-/tmp}/rch_target_asupersync_6lxh8c_metadata cargo test -p asupersync --lib grpc_unary_metadata_isolation -- --nocapture";
/// Per-call observations collected by `MetadataIsolationInterceptor` and the
/// test handler, keyed by the call's `x-call-id`. Fingerprints come from
/// `sanitized_metadata_fingerprint`.
#[derive(Clone, Debug, Default)]
struct UnaryMetadataIsolationRecord {
/// Request metadata as seen by `intercept_request`.
request_fingerprint: Option<String>,
/// Request metadata as seen by the handler before it mutates anything.
handler_before_fingerprint: Option<String>,
/// Request metadata after the handler's own local mutations.
handler_after_fingerprint: Option<String>,
/// Request metadata as seen by the response or error interceptor hook.
snapshot_fingerprint: Option<String>,
/// Number of `x-dup` entries, counted case-insensitively.
duplicate_key_count: usize,
/// `"{code:?}:{message}"` of the final status when the call failed.
status_fingerprint: Option<String>,
}
/// Interceptor that fingerprints request metadata at every hook. It writes the
/// results into a shared record map so tests can check that concurrent unary
/// calls never observe each other's metadata.
#[derive(Debug)]
struct MetadataIsolationInterceptor {
/// Records keyed by `x-call-id`; shared with the test handler.
records: Arc<
parking_lot::Mutex<std::collections::BTreeMap<String, UnaryMetadataIsolationRecord>>,
>,
}
impl MetadataIsolationInterceptor {
    /// Wraps the shared record map.
    fn new(
        records: Arc<
            parking_lot::Mutex<
                std::collections::BTreeMap<String, UnaryMetadataIsolationRecord>,
            >,
        >,
    ) -> Self {
        Self { records }
    }

    /// The call's `x-call-id`. A binary value is summarised by its length, and
    /// an absent header becomes `missing-call-id`.
    fn call_id(metadata: &Metadata) -> String {
        let Some(value) = metadata.get("x-call-id") else {
            return "missing-call-id".to_string();
        };
        match value {
            super::super::streaming::MetadataValue::Ascii(text) => text.clone(),
            super::super::streaming::MetadataValue::Binary(raw) => {
                format!("binary-call-id:{}", raw.len())
            }
        }
    }
}
impl Interceptor for MetadataIsolationInterceptor {
    /// Records the inbound fingerprint and `x-dup` count for this call.
    fn intercept_request(&self, request: &mut Request<Bytes>) -> Result<(), Status> {
        let metadata = request.metadata();
        let call_id = Self::call_id(metadata);
        let fingerprint = sanitized_metadata_fingerprint(metadata);
        let duplicates = metadata_key_count(metadata, "x-dup");
        let mut records = self.records.lock();
        let record = records.entry(call_id).or_default();
        record.request_fingerprint = Some(fingerprint);
        record.duplicate_key_count = duplicates;
        Ok(())
    }

    fn intercept_response(&self, _response: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }

    /// Echoes the call id, request fingerprint and `x-dup` count into the
    /// response metadata, then records the snapshot.
    fn intercept_response_with_request(
        &self,
        request: &Request<Bytes>,
        response: &mut Response<Bytes>,
    ) -> Result<(), Status> {
        let metadata = request.metadata();
        let call_id = Self::call_id(metadata);
        let snapshot = sanitized_metadata_fingerprint(metadata);
        let duplicates = metadata_key_count(metadata, "x-dup");
        let echo = response.metadata_mut();
        let _ = echo.insert("x-call-id-echo", call_id.clone());
        let _ = echo.insert("x-request-snapshot", snapshot.clone());
        let _ = echo.insert("x-request-dup-count", duplicates.to_string());
        let mut records = self.records.lock();
        let record = records.entry(call_id).or_default();
        record.snapshot_fingerprint = Some(snapshot);
        record.duplicate_key_count = duplicates;
        Ok(())
    }

    /// On failure, records the request snapshot plus a `"{code:?}:{message}"`
    /// fingerprint of the status. The status is not modified.
    fn intercept_error_with_request(
        &self,
        request: &Request<Bytes>,
        status: &mut Status,
    ) -> Result<(), Status> {
        let metadata = request.metadata();
        let call_id = Self::call_id(metadata);
        let snapshot = sanitized_metadata_fingerprint(metadata);
        let duplicates = metadata_key_count(metadata, "x-dup");
        let status_line = format!("{:?}:{}", status.code(), status.message());
        let mut records = self.records.lock();
        let record = records.entry(call_id).or_default();
        record.snapshot_fingerprint = Some(snapshot);
        record.duplicate_key_count = duplicates;
        record.status_fingerprint = Some(status_line);
        Ok(())
    }
}
/// Renders one metadata entry for fingerprinting.
///
/// - binary values become `key=bin:{len}`;
/// - ASCII values are passed through `sanitize_metadata_ascii_value` first;
/// - `authorization`, `x-trace-id` and `grpc-timeout` are redacted to
///   `key=redacted:{sanitized_len}`;
/// - any other ASCII key renders as `key={sanitized}`.
fn metadata_value_fingerprint(
    key: &str,
    value: &super::super::streaming::MetadataValue,
) -> String {
    let redact = matches!(key, "authorization" | "x-trace-id" | "grpc-timeout");
    match value {
        super::super::streaming::MetadataValue::Binary(raw) => {
            format!("{key}=bin:{}", raw.len())
        }
        super::super::streaming::MetadataValue::Ascii(text) => {
            let sanitized = super::super::streaming::sanitize_metadata_ascii_value(text);
            if redact {
                format!("{key}=redacted:{}", sanitized.len())
            } else {
                format!("{key}={sanitized}")
            }
        }
    }
}
/// Order-independent fingerprint of a metadata map. It consists of the
/// per-entry fingerprints sorted and joined with `|`, or `"empty"` when there
/// are no entries.
fn sanitized_metadata_fingerprint(metadata: &Metadata) -> String {
    let mut parts: Vec<String> = Vec::new();
    for (key, value) in metadata.iter() {
        parts.push(metadata_value_fingerprint(key, value));
    }
    if parts.is_empty() {
        return "empty".to_string();
    }
    parts.sort();
    parts.join("|")
}
/// Number of entries whose key equals `key`, compared ASCII case-insensitively.
fn metadata_key_count(metadata: &Metadata, key: &str) -> usize {
    let mut count = 0;
    for (existing_key, _) in metadata.iter() {
        if existing_key.eq_ignore_ascii_case(key) {
            count += 1;
        }
    }
    count
}
/// Owned copy of the ASCII value stored under `key`; `None` when the key is
/// absent or holds a binary value.
fn metadata_ascii_value(metadata: &Metadata, key: &str) -> Option<String> {
    if let Some(super::super::streaming::MetadataValue::Ascii(text)) = metadata.get(key) {
        Some(text.clone())
    } else {
        None
    }
}
/// Declarative description of one concurrent unary call in the metadata
/// isolation scenario; turned into a request by `build_unary_metadata_request`.
#[derive(Clone, Debug)]
struct UnaryMetadataCase {
/// Value of `x-call-id`; also used as the request body and in derived header values.
call_id: &'static str,
/// Values inserted, in order, under the repeated `x-dup` key.
duplicate_values: &'static [&'static str],
/// Adds a binary `trace-context` entry holding the call id bytes.
include_binary: bool,
/// Adds an `authorization: Bearer secret-{call_id}` header.
include_auth: bool,
/// Adds an `x-trace-id: trace-{call_id}-token` header.
include_trace: bool,
/// When non-zero, adds an `x-large` header of this many `x` characters.
large_value_len: usize,
/// When true, the scenario's handler fails the call with `Status::cancelled`.
cancel: bool,
}
/// Result of one worker in the metadata isolation scenario, together with
/// the expectations captured before the call was dispatched.
#[derive(Debug)]
struct UnaryMetadataOutcome {
/// The call's `x-call-id`.
call_id: String,
/// Fingerprint of the request metadata as built, before dispatch.
expected_request_fingerprint: String,
/// `x-dup` entry count of the request as built.
expected_duplicate_count: usize,
/// Response metadata on success; `None` when the call failed.
response_metadata: Option<Metadata>,
/// Final status on failure; `None` when the call succeeded.
status: Option<Status>,
}
/// Builds the request for `case`. The body is the call id's bytes. Metadata
/// is inserted in a fixed order: the call id and standard gRPC headers, then
/// the optional auth, trace and binary entries, then each `x-dup` value, then
/// the optional large header.
fn build_unary_metadata_request(case: &UnaryMetadataCase) -> Request<Bytes> {
    let id = case.call_id;
    let mut metadata = Metadata::new();
    for (key, value) in [
        ("x-call-id", id),
        ("content-type", "application/grpc+proto"),
        ("te", "trailers"),
    ] {
        let _ = metadata.insert(key, value);
    }
    if case.include_auth {
        let _ = metadata.insert("authorization", format!("Bearer secret-{id}"));
    }
    if case.include_trace {
        let _ = metadata.insert("x-trace-id", format!("trace-{id}-token"));
    }
    if case.include_binary {
        let _ = metadata.insert_bin("trace-context", Bytes::from(id.as_bytes().to_vec()));
    }
    for &duplicate in case.duplicate_values {
        let _ = metadata.insert("x-dup", duplicate.to_string());
    }
    if case.large_value_len > 0 {
        let _ = metadata.insert("x-large", "x".repeat(case.large_value_len));
    }
    let body = Bytes::from(id.as_bytes().to_vec());
    Request::with_metadata(body, metadata)
}
/// Prints one `GRPC_UNARY_METADATA_ISOLATION` key=value line to stdout that
/// summarises a single call of an isolation scenario. All arguments are
/// interpolated verbatim; `exact_rch_command` is always
/// `EXACT_GRPC_UNARY_METADATA_ISOLATION_RCH_COMMAND`, and `artifact_paths` is
/// always `none`.
fn log_grpc_unary_metadata_case(
scenario_id: &str,
call_id: &str,
sanitized_metadata_fingerprint: &str,
handler_observed_fingerprint: &str,
response_trailer_fingerprint: &str,
cancellation_state: &str,
mismatch_count: usize,
leaked_key_list: &str,
final_isolation_verdict: &str,
) {
println!(
"GRPC_UNARY_METADATA_ISOLATION \
scenario_id={} \
call_id={} \
sanitized_metadata_fingerprint={} \
handler_observed_fingerprint={} \
response_trailer_fingerprint={} \
cancellation_state={} \
mismatch_count={} \
leaked_key_list={} \
exact_rch_command=\"{}\" \
artifact_paths=none \
final_isolation_verdict={}",
scenario_id,
call_id,
sanitized_metadata_fingerprint,
handler_observed_fingerprint,
response_trailer_fingerprint,
cancellation_state,
mismatch_count,
leaked_key_list,
EXACT_GRPC_UNARY_METADATA_ISOLATION_RCH_COMMAND,
final_isolation_verdict,
);
}
/// Runs `cases` as concurrent unary calls through one shared `Server` and
/// asserts that no call's metadata leaks into another call.
///
/// Each case runs on its own scoped thread. The handler writes its view of the
/// request metadata into a shared `records` map. The same map is handed to
/// `MetadataIsolationInterceptor`, which presumably fills the request,
/// snapshot, duplicate-count and status fields (TODO confirm). The handler
/// then waits on a `Barrier` sized to `cases.len()`, so every handler is in
/// flight before any of them mutates its own request copy.
///
/// After all workers join, each outcome is compared against the recorded
/// fingerprints and the response metadata. Any *other* call id found in this
/// call's fingerprints is reported as a leak. One evidence line is logged per
/// call.
///
/// # Panics
/// Panics on any mismatch or leak, if a worker thread panics, or if a call
/// produced no isolation record.
///
/// NOTE(review): every case's handler must reach `barrier.wait()`. If a call
/// were rejected before its handler ran, the remaining workers would block
/// forever.
fn run_grpc_unary_metadata_isolation_scenario(scenario_id: &str, cases: &[UnaryMetadataCase]) {
// Per-call isolation evidence, keyed by call id.
let records = Arc::new(parking_lot::Mutex::new(std::collections::BTreeMap::<
String,
UnaryMetadataIsolationRecord,
>::new()));
let server = std::sync::Arc::new(
Server::builder()
.add_service(TestService)
.interceptor(MetadataIsolationInterceptor::new(Arc::clone(&records)))
.build(),
);
// Every handler blocks on this until all `cases.len()` calls are in flight,
// which forces the calls to overlap.
let barrier = std::sync::Arc::new(std::sync::Barrier::new(cases.len()));
let outcomes = std::thread::scope(|scope| {
let mut joins = Vec::new();
for case in cases.iter().cloned() {
let server = std::sync::Arc::clone(&server);
let barrier = std::sync::Arc::clone(&barrier);
let records = Arc::clone(&records);
joins.push(scope.spawn(move || {
let request = build_unary_metadata_request(&case);
// Expected values are computed client-side, before the server sees the request.
let expected_request_fingerprint =
sanitized_metadata_fingerprint(request.metadata());
let expected_duplicate_count = metadata_key_count(request.metadata(), "x-dup");
let call_id = case.call_id.to_string();
let cancel = case.cancel;
let result = block_on(server.dispatch_unary(request, {
let barrier = std::sync::Arc::clone(&barrier);
let records = Arc::clone(&records);
let call_id = call_id.clone();
move |mut request| {
let barrier = std::sync::Arc::clone(&barrier);
let records = Arc::clone(&records);
let call_id = call_id.clone();
async move {
// Record what the handler sees before any local mutation.
let handler_before =
sanitized_metadata_fingerprint(request.metadata());
{
let mut map = records.lock();
let record = map.entry(call_id.clone()).or_default();
record.handler_before_fingerprint = Some(handler_before);
}
barrier.wait();
// These mutations touch only this handler's request copy; they
// must not appear in the interceptor snapshot or in other calls.
let _ = request
.metadata_mut()
.insert("x-local-handler-only", format!("mut-{call_id}"));
let _ = request.metadata_mut().insert_or_replace(
"authorization",
format!("Bearer handler-mutated-{call_id}"),
);
let handler_after =
sanitized_metadata_fingerprint(request.metadata());
{
let mut map = records.lock();
let record = map.entry(call_id.clone()).or_default();
record.handler_after_fingerprint = Some(handler_after.clone());
}
if cancel {
Err(Status::cancelled(format!("cancelled-{call_id}")))
} else {
// Echo the handler's own call id and fingerprint so the
// checks below can prove they were not swapped.
let mut response = Response::new(request.into_inner());
let _ = response
.metadata_mut()
.insert("x-handler-call-id", call_id.clone());
let _ = response
.metadata_mut()
.insert("x-handler-fingerprint", handler_after);
Ok(response)
}
}
}
}));
match result {
Ok(response) => UnaryMetadataOutcome {
call_id,
expected_request_fingerprint,
expected_duplicate_count,
response_metadata: Some(response.metadata().clone()),
status: None,
},
Err(status) => UnaryMetadataOutcome {
call_id,
expected_request_fingerprint,
expected_duplicate_count,
response_metadata: None,
status: Some(status),
},
}
}));
}
joins
.into_iter()
.map(|join| {
join.join()
.expect("metadata isolation worker must complete")
})
.collect::<Vec<_>>()
});
// All workers have joined; take an owned snapshot of the evidence map.
let records = records.lock().clone();
let all_call_ids = outcomes
.iter()
.map(|outcome| outcome.call_id.clone())
.collect::<Vec<_>>();
for outcome in outcomes {
let record = records
.get(&outcome.call_id)
.expect("every call must produce an isolation record");
let mut mismatches = Vec::new();
// The interceptor's view, the handler's pre-mutation view and the snapshot
// must all equal the client-side fingerprint.
if record.request_fingerprint.as_deref()
!= Some(outcome.expected_request_fingerprint.as_str())
{
mismatches.push("request_fingerprint");
}
if record.handler_before_fingerprint.as_deref()
!= Some(outcome.expected_request_fingerprint.as_str())
{
mismatches.push("handler_before_fingerprint");
}
if record.snapshot_fingerprint.as_deref()
!= Some(outcome.expected_request_fingerprint.as_str())
{
mismatches.push("snapshot_fingerprint");
}
if record.duplicate_key_count != outcome.expected_duplicate_count {
mismatches.push("duplicate_key_count");
}
let handler_after = record
.handler_after_fingerprint
.as_deref()
.expect("handler_after_fingerprint must be recorded");
assert!(
handler_after.contains("x-local-handler-only=mut-"),
"{}: handler-local mutation must stay visible to the handler copy",
outcome.call_id
);
// Completed call: the response metadata must carry this call's own values.
// Cancelled call: check the status code and the recorded status fingerprint.
let response_trailer_fingerprint = if let Some(ref response_metadata) =
outcome.response_metadata
{
let echoed_call_id = metadata_ascii_value(response_metadata, "x-call-id-echo")
.expect("response interceptor must echo request call id");
let handler_call_id = metadata_ascii_value(response_metadata, "x-handler-call-id")
.expect("handler must echo its local call id");
let request_snapshot =
metadata_ascii_value(response_metadata, "x-request-snapshot")
.expect("response interceptor must preserve request snapshot");
let duplicate_key_count =
metadata_ascii_value(response_metadata, "x-request-dup-count")
.expect("response interceptor must surface duplicate count");
let handler_fingerprint =
metadata_ascii_value(response_metadata, "x-handler-fingerprint")
.expect("handler must surface local metadata fingerprint");
if echoed_call_id != outcome.call_id {
mismatches.push("response_call_id_echo");
}
if handler_call_id != outcome.call_id {
mismatches.push("handler_call_id_echo");
}
if request_snapshot != outcome.expected_request_fingerprint {
mismatches.push("response_request_snapshot");
}
if duplicate_key_count != outcome.expected_duplicate_count.to_string() {
mismatches.push("response_duplicate_key_count");
}
if handler_fingerprint != handler_after {
mismatches.push("response_handler_fingerprint");
}
sanitized_metadata_fingerprint(response_metadata)
} else {
let status = outcome
.status
.as_ref()
.expect("cancelled/error case must carry status");
if status.code() != super::super::Code::Cancelled {
mismatches.push("cancelled_status_code");
}
let expected_status = format!("Cancelled:cancelled-{}", outcome.call_id);
if record.status_fingerprint.as_deref() != Some(expected_status.as_str()) {
mismatches.push("status_fingerprint");
}
expected_status
};
// A leak is any *other* call id appearing in one of this call's fingerprints.
let leaked_key_list = all_call_ids
.iter()
.filter(|other| **other != outcome.call_id)
.filter(|other| {
outcome
.expected_request_fingerprint
.contains(other.as_str())
|| record
.snapshot_fingerprint
.as_deref()
.is_some_and(|value| value.contains(other.as_str()))
|| handler_after.contains(other.as_str())
|| response_trailer_fingerprint.contains(other.as_str())
})
.cloned()
.collect::<Vec<_>>();
let mismatch_count = mismatches.len().saturating_add(leaked_key_list.len());
assert_eq!(
mismatch_count, 0,
"{}: metadata isolation mismatches={:?} leaked={:?}",
outcome.call_id, mismatches, leaked_key_list
);
let cancellation_state = if outcome.response_metadata.is_some() {
"completed"
} else {
"cancelled_overlap"
};
let leaked_key_summary = if leaked_key_list.is_empty() {
"none".to_string()
} else {
leaked_key_list.join("|")
};
log_grpc_unary_metadata_case(
scenario_id,
&outcome.call_id,
&outcome.expected_request_fingerprint,
handler_after,
&response_trailer_fingerprint,
cancellation_state,
mismatch_count,
&leaked_key_summary,
"pass",
);
}
}
/// Two counting interceptors must wrap the handler like an onion: request
/// hooks fire in registration order and response hooks fire in reverse.
#[test]
fn mfk14i_dispatch_unary_runs_interceptor_chain_around_handler() {
    init_test("mfk14i_dispatch_unary_runs_interceptor_chain_around_handler");
    // Shared, ordered log that both interceptors append to.
    let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(CountingInterceptor::new("A", Arc::clone(&events)))
        .interceptor(CountingInterceptor::new("B", Arc::clone(&events)))
        .build();

    let request = Request::with_metadata(Bytes::from_static(b"hello"), Metadata::new());
    let response = block_on(server.dispatch_unary(request, |req| async move {
        Ok(Response::new(req.into_inner()))
    }))
    .expect("dispatch must succeed");
    assert_eq!(response.get_ref().as_ref(), b"hello");

    let expected: Vec<String> = ["req:A", "req:B", "resp:B", "resp:A"]
        .iter()
        .map(|label| label.to_string())
        .collect();
    let actual = events.lock().clone();
    assert_eq!(
        actual, expected,
        "interceptors must fire in registration order on requests \
         and REVERSE order on responses; got {actual:?}"
    );
}
/// A rejecting interceptor in the middle of the chain must stop everything
/// after it: later request hooks, the handler, and every response hook.
#[test]
fn mfk14i_dispatch_unary_rejected_request_short_circuits_handler_and_response_chain() {
    init_test(
        "mfk14i_dispatch_unary_rejected_request_short_circuits_handler_and_response_chain",
    );
    let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
    let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(CountingInterceptor::new("A", Arc::clone(&events)))
        .interceptor(RejectingInterceptor {
            events: Arc::clone(&events),
        })
        .interceptor(CountingInterceptor::new("after", Arc::clone(&events)))
        .build();

    let request = Request::with_metadata(Bytes::from_static(b"x"), Metadata::new());
    let flag_source = Arc::clone(&handler_called);
    let result = block_on(server.dispatch_unary(request, move |req| {
        // The flag only flips if the handler future is actually polled.
        let flag = Arc::clone(&flag_source);
        async move {
            flag.store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(Response::new(req.into_inner()))
        }
    }));

    let err = result.expect_err("rejected request must surface as Err");
    assert_eq!(err.code(), super::super::Code::Unauthenticated);
    let handler_ran = handler_called.load(std::sync::atomic::Ordering::SeqCst);
    assert!(
        !handler_ran,
        "handler must NOT be invoked when an earlier interceptor rejects"
    );
    let actual = events.lock().clone();
    assert_eq!(
        actual,
        vec!["req:A".to_string(), "req:reject".to_string()],
        "post-reject interceptors (request and response side) must NOT fire; \
         got {actual:?}"
    );
}
/// When the handler fails, only the request-side hook may have run; the
/// response-side chain must be skipped entirely.
#[test]
fn mfk14i_dispatch_unary_handler_error_skips_response_chain() {
    init_test("mfk14i_dispatch_unary_handler_error_skips_response_chain");
    let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(CountingInterceptor::new("A", Arc::clone(&events)))
        .build();

    let outcome = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::new(), Metadata::new()),
        |_req| async move { Err::<Response<Bytes>, _>(Status::internal("handler exploded")) },
    ));
    assert!(outcome.is_err());

    let actual = events.lock().clone();
    assert_eq!(
        actual,
        vec!["req:A".to_string()],
        "response-side chain must NOT fire on handler error; got {actual:?}"
    );
}
/// A handler error must still be reported to error-side interceptors with the
/// request's `AuthContext` (principal `svc-a`) visible.
#[test]
fn dispatch_unary_preserves_auth_context_for_error_interceptors() {
    init_test("dispatch_unary_preserves_auth_context_for_error_interceptors");
    let seen_principal = Arc::new(parking_lot::Mutex::new(None));
    let echo = AuthContextErrorEchoInterceptor {
        seen_principal: Arc::clone(&seen_principal),
    };
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(echo)
        .build();

    let result = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::new(), Metadata::new()),
        |_req| async move {
            Err::<Response<Bytes>, _>(Status::permission_denied("denied by handler"))
        },
    ));
    assert!(result.is_err(), "handler error must surface");

    let seen = seen_principal.lock().clone();
    assert_eq!(
        seen.as_deref(),
        Some("svc-a"),
        "error-side interceptors must still observe request AuthContext"
    );
}
/// Test interceptor that rejects every inbound request with `Unauthenticated`
/// and passes every response through unchanged.
#[derive(Debug)]
struct RequestRejectingInterceptor;
impl Interceptor for RequestRejectingInterceptor {
    fn intercept_request(&self, _: &mut Request<Bytes>) -> Result<(), Status> {
        Err(Status::unauthenticated("request interceptor rejection"))
    }
    fn intercept_response(&self, _: &mut Response<Bytes>) -> Result<(), Status> {
        Ok(())
    }
}
/// A request-side rejection must surface as `Unauthenticated` without running
/// the handler. The earlier `AuthContextErrorEchoInterceptor` must still
/// observe principal `svc-a` on the error path.
///
/// NOTE(review): despite the name, this only asserts what the error hook sees
/// before cleanup; it does not check that `AuthContext` is cleared afterwards.
#[test]
fn dispatch_unary_clears_auth_context_on_request_interceptor_error() {
    init_test("dispatch_unary_clears_auth_context_on_request_interceptor_error");
    let seen_principal = Arc::new(parking_lot::Mutex::new(None));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(AuthContextErrorEchoInterceptor {
            seen_principal: Arc::clone(&seen_principal),
        })
        .interceptor(RequestRejectingInterceptor)
        .build();
    let request = Request::with_metadata(Bytes::new(), Metadata::new());
    let result = block_on(server.dispatch_unary(request, |_req| async {
        panic!("handler should not be called when request interceptor rejects");
    }));
    // `expect_err` both asserts the failure and hands back the status, replacing
    // the earlier `assert!(is_err())` + `unwrap_err()` pair.
    let status = result.expect_err("request interceptor error must surface");
    assert_eq!(status.code(), crate::grpc::status::Code::Unauthenticated);
    let seen = seen_principal.lock().clone();
    assert_eq!(
        seen.as_deref(),
        Some("svc-a"),
        "error interceptor should see AuthContext before cleanup"
    );
}
/// A handler that outlives the request's `grpc-timeout` deadline must be cut
/// off with `DeadlineExceeded`.
///
/// NOTE(review): despite the name, this test only asserts the timeout status;
/// it does not inspect `AuthContext` or the request stashed by the handler.
#[test]
fn dispatch_unary_clears_auth_context_on_handler_timeout() {
    init_test("dispatch_unary_clears_auth_context_on_handler_timeout");
    // The handler parks its request here before sleeping.
    let handler_request = Arc::new(parking_lot::Mutex::new(None::<Request<Bytes>>));
    let handler_request_ref = Arc::clone(&handler_request);
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(AuthContextErrorEchoInterceptor {
            seen_principal: Arc::new(parking_lot::Mutex::new(None)),
        })
        .build();
    let mut metadata = Metadata::new();
    // In the gRPC wire format, unit `m` means milliseconds: a 1ms deadline
    // versus the 100ms sleep below.
    metadata.insert("grpc-timeout", "1m");
    let request = Request::with_metadata(Bytes::new(), metadata);
    let result = block_on(server.dispatch_unary(request, move |req| {
        let handler_request_ref = Arc::clone(&handler_request_ref);
        async move {
            *handler_request_ref.lock() = Some(req);
            crate::time::sleep(crate::time::wall_now(), Duration::from_millis(100)).await;
            Ok::<Response<Bytes>, Status>(Response::new(Bytes::new()))
        }
    }));
    let status = result.expect_err("handler should timeout");
    assert_eq!(status.code(), crate::grpc::status::Code::DeadlineExceeded);
}
/// The handler must see the `AuthContext` (principal `svc-a`). When it then
/// fails, the error-side interceptor must still observe that context before it
/// is cleared.
#[test]
fn dispatch_unary_clears_auth_context_on_handler_error() {
    init_test("dispatch_unary_clears_auth_context_on_handler_error");
    let seen_principal = Arc::new(parking_lot::Mutex::new(None));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(AuthContextErrorEchoInterceptor {
            seen_principal: Arc::clone(&seen_principal),
        })
        .build();
    let request = Request::with_metadata(Bytes::new(), Metadata::new());
    let result = block_on(server.dispatch_unary(request, |req| async move {
        // `expect` replaces the `assert!(is_some())` + `unwrap()` pair.
        let auth = req
            .extensions()
            .get_typed::<crate::grpc::interceptor::AuthContext>()
            .expect("handler should see AuthContext");
        assert_eq!(auth.principal, "svc-a");
        Err::<Response<Bytes>, _>(Status::internal("handler error"))
    }));
    assert!(result.is_err(), "handler error must surface");
    let seen = seen_principal.lock().clone();
    assert_eq!(
        seen.as_deref(),
        Some("svc-a"),
        "error interceptor should see AuthContext before it's cleared"
    );
}
/// A response-hook failure (`ResponseErrorInterceptor`) must surface as
/// `Internal`. The error-side interceptor must still observe principal `svc-a`
/// while cleaning up.
#[test]
fn dispatch_unary_clears_auth_context_on_response_interceptor_error() {
    init_test("dispatch_unary_clears_auth_context_on_response_interceptor_error");
    let seen_principal = Arc::new(parking_lot::Mutex::new(None));
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(AuthContextErrorEchoInterceptor {
            seen_principal: Arc::clone(&seen_principal),
        })
        .interceptor(ResponseErrorInterceptor)
        .build();
    let request = Request::with_metadata(Bytes::new(), Metadata::new());
    let result = block_on(server.dispatch_unary(request, |req| async move {
        let auth = req
            .extensions()
            .get_typed::<crate::grpc::interceptor::AuthContext>();
        assert!(auth.is_some(), "handler should see AuthContext");
        Ok::<Response<Bytes>, Status>(Response::new(Bytes::new()))
    }));
    let status = result.expect_err("response interceptor error must surface");
    assert_eq!(status.code(), crate::grpc::status::Code::Internal);
    let seen = seen_principal.lock().clone();
    assert_eq!(
        seen.as_deref(),
        Some("svc-a"),
        "error interceptor should see AuthContext during cleanup"
    );
}
/// With a one-slot rate limiter, a failing handler must give its slot back so
/// the next call can run.
#[test]
fn dispatch_unary_releases_rate_limit_slot_on_handler_error() {
    init_test("dispatch_unary_releases_rate_limit_slot_on_handler_error");
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(crate::grpc::interceptor::rate_limiter(1))
        .build();

    let failing = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::new(), Metadata::new()),
        |_req| async move { Err::<Response<Bytes>, _>(Status::internal("handler exploded")) },
    ));
    let failed_with_internal =
        matches!(failing, Err(ref status) if status.code() == super::super::Code::Internal);
    assert!(
        failed_with_internal,
        "first call must surface the handler error, not resource exhaustion"
    );

    // If the slot leaked, this call would be refused instead of echoing.
    let echoed = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::from_static(b"ok"), Metadata::new()),
        |req| async move { Ok::<Response<Bytes>, Status>(Response::new(req.into_inner())) },
    ))
    .expect("slot must be released after handler error");
    assert_eq!(echoed.get_ref().as_ref(), b"ok");
}
/// With a one-slot rate limiter, a response-hook failure must also release
/// the slot: the second identical call must reach the same response-hook
/// error rather than being blocked by the limiter.
#[test]
fn dispatch_unary_releases_rate_limit_slot_on_response_hook_error() {
    init_test("dispatch_unary_releases_rate_limit_slot_on_response_hook_error");
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(crate::grpc::interceptor::rate_limiter(1))
        .interceptor(ResponseErrorInterceptor)
        .build();

    // Both calls are identical: the handler succeeds and the response hook fails.
    let dispatch_once = || {
        block_on(server.dispatch_unary(
            Request::with_metadata(Bytes::new(), Metadata::new()),
            |_req| async move {
                Ok::<Response<Bytes>, Status>(Response::new(Bytes::from_static(b"ignored")))
            },
        ))
    };

    let first_result = dispatch_once();
    assert!(
        matches!(first_result, Err(ref status) if status.code() == super::super::Code::Internal),
        "first call must surface the response-hook error"
    );
    let second_result = dispatch_once();
    assert!(
        matches!(second_result, Err(ref status) if status.code() == super::super::Code::Internal),
        "second call must not be blocked by a leaked rate-limit slot"
    );
}
/// Conformance matrix for interceptor-chain cleanup around a one-slot rate
/// limiter.
///
/// Each row places `stack_depth` `MatrixInterceptor`s after the limiter. A row
/// may make one interceptor fail at the request or response stage, or may make
/// the handler fail. Every row dispatches twice on the same server. The first
/// dispatch's events must match `expected_interceptor_cleanup_events`, and the
/// limiter's slot count must be zero before and after each dispatch. One
/// evidence line is logged per row.
#[test]
fn conformance_interceptor_multistack_rate_limit_cleanup_matrix_logs_evidence() {
init_test("conformance_interceptor_multistack_rate_limit_cleanup_matrix_logs_evidence");
// (request_id, stack_depth, failing_interceptor_index, failure_stage)
let cases = [
("success_depth_1", 1usize, None, None),
("success_depth_2", 2usize, None, None),
("success_depth_5", 5usize, None, None),
(
"request_fail_depth_1_idx_0",
1usize,
Some(0usize),
Some(MatrixFailureStage::Request),
),
(
"request_fail_depth_2_idx_1",
2usize,
Some(1usize),
Some(MatrixFailureStage::Request),
),
(
"request_fail_depth_5_idx_4",
5usize,
Some(4usize),
Some(MatrixFailureStage::Request),
),
(
"response_fail_depth_1_idx_0",
1usize,
Some(0usize),
Some(MatrixFailureStage::Response),
),
(
"response_fail_depth_2_idx_1",
2usize,
Some(1usize),
Some(MatrixFailureStage::Response),
),
(
"response_fail_depth_5_idx_4",
5usize,
Some(4usize),
Some(MatrixFailureStage::Response),
),
(
"handler_error_depth_5",
5usize,
None,
Some(MatrixFailureStage::Handler),
),
];
for (request_id, stack_depth, failing_interceptor_index, failure_stage) in cases {
// The limiter is shared with each MatrixInterceptor so that it can report slot counts.
let limiter = Arc::new(crate::grpc::interceptor::rate_limiter(1));
let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
let mut builder = Server::builder()
.add_service(TestService)
.interceptor_arc(limiter.clone());
for index in 0..stack_depth {
// Only the configured interceptor fails. A `Handler` stage is injected
// by the handler closure below, never by an interceptor.
let fail_stage = if failing_interceptor_index == Some(index) {
failure_stage.filter(|stage| *stage != MatrixFailureStage::Handler)
} else {
None
};
builder = builder.interceptor_arc(Arc::new(MatrixInterceptor::new(
index,
Arc::clone(&limiter),
Arc::clone(&events),
fail_stage,
)));
}
let server = builder.build();
let slot_count_before = limiter.current_count();
assert_eq!(
slot_count_before, 0,
"{request_id}: slot count must start at zero"
);
let first_request =
Request::with_metadata(Bytes::from_static(b"matrix"), Metadata::new());
let first_result = block_on(server.dispatch_unary(first_request, |_req| async move {
match failure_stage {
Some(MatrixFailureStage::Handler) => {
Err::<Response<Bytes>, _>(Status::internal("handler exploded"))
}
_ => Ok::<Response<Bytes>, Status>(Response::new(Bytes::from_static(
b"matrix-ok",
))),
}
}));
let first_result_kind = assert_interceptor_cleanup_result(
first_result,
failure_stage,
"first dispatch must match the configured outcome",
);
let first_events = events.lock().clone();
let expected_events = expected_interceptor_cleanup_events(
stack_depth,
failing_interceptor_index,
failure_stage,
);
assert_eq!(
first_events, expected_events,
"{request_id}: interceptor events must prove short-circuit and cleanup order"
);
let slot_count_after = limiter.current_count();
assert_eq!(
slot_count_after, 0,
"{request_id}: failing or succeeding dispatch must release the rate-limit slot"
);
// Replay on the same server. A slot leaked by the first dispatch would change
// this outcome, and a double release would show up in the count check below.
events.lock().clear();
let replay_request =
Request::with_metadata(Bytes::from_static(b"matrix"), Metadata::new());
let replay_result =
block_on(server.dispatch_unary(replay_request, |_req| async move {
match failure_stage {
Some(MatrixFailureStage::Handler) => {
Err::<Response<Bytes>, _>(Status::internal("handler exploded"))
}
_ => Ok::<Response<Bytes>, Status>(Response::new(Bytes::from_static(
b"matrix-ok",
))),
}
}));
let replay_result_kind = assert_interceptor_cleanup_result(
replay_result,
failure_stage,
"replay dispatch must prove no leaked or double-released slot",
);
assert_eq!(
limiter.current_count(),
0,
"{request_id}: replay dispatch must also leave the rate-limit slot count at zero"
);
// NOTE(review): `usize::from(bool)` is at most 1, so this only proves that at
// least one first-dispatch event ended with `slots=1`. It does not prove
// "exactly one", as the message below claims.
let release_count =
usize::from(first_events.iter().any(|event| event.ends_with("slots=1")));
assert_eq!(
release_count, 1,
"{request_id}: cleanup matrix must observe exactly one acquired slot per call"
);
log_interceptor_cleanup_case(
request_id,
stack_depth,
failing_interceptor_index,
failure_stage,
slot_count_before,
slot_count_after,
release_count,
first_result_kind,
replay_result_kind,
&first_events,
"pass",
);
}
}
/// Two overlapping calls: a fully decorated call that completes, and a bare
/// call whose handler returns `Cancelled`. Neither may see the other's
/// metadata.
#[test]
fn grpc_unary_metadata_isolation_two_call_cancelled_overlap() {
    init_test("grpc_unary_metadata_isolation_two_call_cancelled_overlap");
    let completed = UnaryMetadataCase {
        call_id: "call-alpha",
        cancel: false,
        duplicate_values: &["alpha-0", "alpha-1"],
        include_auth: true,
        include_binary: true,
        include_trace: true,
        large_value_len: 0,
    };
    let cancelled = UnaryMetadataCase {
        call_id: "call-bravo",
        cancel: true,
        duplicate_values: &[],
        include_auth: false,
        include_binary: false,
        include_trace: false,
        large_value_len: 0,
    };
    run_grpc_unary_metadata_isolation_scenario(
        "two_call_cancelled_overlap",
        &[completed, cancelled],
    );
    crate::test_complete!("grpc_unary_metadata_isolation_two_call_cancelled_overlap");
}
/// Five overlapping calls, each with a different metadata shape, run through
/// the isolation scenario. The shapes are: duplicates plus binary, auth and
/// trace; a trace plus a 3072-byte value; one duplicate plus binary; a
/// cancelled call with auth; and a bare call.
#[test]
fn conformance_grpc_unary_metadata_isolation_many_call_matrix_logs_evidence() {
    init_test("conformance_grpc_unary_metadata_isolation_many_call_matrix_logs_evidence");
    let scenario = "many_call_mixed_metadata";
    let cases = [
        UnaryMetadataCase {
            call_id: "call-charlie",
            cancel: false,
            duplicate_values: &["charlie-0", "charlie-1"],
            include_auth: true,
            include_binary: true,
            include_trace: true,
            large_value_len: 0,
        },
        UnaryMetadataCase {
            call_id: "call-delta",
            cancel: false,
            duplicate_values: &[],
            include_auth: false,
            include_binary: false,
            include_trace: true,
            large_value_len: 3072,
        },
        UnaryMetadataCase {
            call_id: "call-echo",
            cancel: false,
            duplicate_values: &["echo-0"],
            include_auth: false,
            include_binary: true,
            include_trace: false,
            large_value_len: 0,
        },
        UnaryMetadataCase {
            call_id: "call-foxtrot",
            cancel: true,
            duplicate_values: &[],
            include_auth: true,
            include_binary: false,
            include_trace: false,
            large_value_len: 0,
        },
        UnaryMetadataCase {
            call_id: "call-golf",
            cancel: false,
            duplicate_values: &[],
            include_auth: false,
            include_binary: false,
            include_trace: false,
            large_value_len: 0,
        },
    ];
    run_grpc_unary_metadata_isolation_scenario(scenario, &cases);
    crate::test_complete!(
        "conformance_grpc_unary_metadata_isolation_many_call_matrix_logs_evidence"
    );
}
/// With an empty interceptor chain, dispatch simply runs the handler and
/// returns its response.
#[test]
fn mfk14i_server_with_no_interceptors_runs_handler_directly() {
    init_test("mfk14i_server_with_no_interceptors_runs_handler_directly");
    let server = Server::builder().add_service(TestService).build();
    assert_eq!(server.interceptors().len(), 0);

    let response = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::from_static(b"echo"), Metadata::new()),
        |req| async move { Ok(Response::new(req.into_inner())) },
    ))
    .expect("dispatch must succeed");
    assert_eq!(response.get_ref().as_ref(), b"echo");
}
/// On a successful call, response-side interceptors must still see the
/// request's `AuthContext` (principal `svc-a`).
#[test]
fn dispatch_unary_preserves_auth_context_for_response_interceptors() {
    init_test("dispatch_unary_preserves_auth_context_for_response_interceptors");
    let seen_principal = Arc::new(parking_lot::Mutex::new(None));
    let echo = AuthContextEchoInterceptor {
        seen_principal: Arc::clone(&seen_principal),
    };
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(echo)
        .build();

    let response = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::from_static(b"ping"), Metadata::new()),
        |req| async move { Ok(Response::new(req.into_inner())) },
    ))
    .expect("dispatch must succeed");
    assert_eq!(response.get_ref().as_ref(), b"ping");

    let observed = seen_principal.lock().clone();
    assert_eq!(observed, Some("svc-a".to_string()));
}
/// An `AuthInterceptor` that requires an `authorization` entry must reject
/// calls that lack it with `Unauthenticated`, and let calls that have it
/// through to the handler.
#[test]
fn mfk14i_auth_interceptor_actually_blocks_unauthenticated_calls() {
    init_test("mfk14i_auth_interceptor_actually_blocks_unauthenticated_calls");
    let auth = AuthInterceptor::new(|metadata: &Metadata| -> Result<(), Status> {
        metadata
            .get("authorization")
            .map(|_| ())
            .ok_or_else(|| Status::unauthenticated("missing authorization"))
    });
    let server = Server::builder()
        .add_service(TestService)
        .interceptor(auth)
        .build();

    // Call without credentials: must be refused before reaching the handler.
    let unauth_result = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::new(), Metadata::new()),
        |_req| async move { Ok(Response::new(Bytes::from_static(b"should not reach"))) },
    ));
    let rejected_unauthenticated = matches!(
        unauth_result,
        Err(ref s) if s.code() == super::super::Code::Unauthenticated
    );
    assert!(
        rejected_unauthenticated,
        "missing-auth call must be rejected with Unauthenticated; got {unauth_result:?}"
    );

    // Call with credentials: must reach the handler.
    let mut authed_md = Metadata::new();
    authed_md.insert("authorization", "Bearer xyz");
    let response = block_on(server.dispatch_unary(
        Request::with_metadata(Bytes::new(), authed_md),
        |_req| async move { Ok(Response::new(Bytes::from_static(b"ok"))) },
    ))
    .expect("authed call must succeed");
    assert_eq!(response.get_ref().as_ref(), b"ok");
}
/// A single 128-byte metadata value against a 64-byte cap must be rejected
/// with `RESOURCE_EXHAUSTED`, and the handler closure must never be called.
#[test]
fn test_dispatch_unary_enforces_max_metadata_size() {
    use futures_lite::future::block_on;
    init_test("test_dispatch_unary_enforces_max_metadata_size");
    let server = Server::builder().max_metadata_size(64).build();
    let oversized = "a".repeat(128);
    let mut metadata = Metadata::new();
    metadata.insert("x-large-trace-id", oversized.as_str());

    let invoked = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let result = {
        let invoked = std::sync::Arc::clone(&invoked);
        block_on(server.dispatch_unary(
            Request::with_metadata(Bytes::new(), metadata),
            move |_req| {
                // Set on closure call itself, so even building the future counts.
                invoked.store(true, std::sync::atomic::Ordering::Relaxed);
                async move { Ok(Response::new(Bytes::from_static(b"ok"))) }
            },
        ))
    };

    let err = result.expect_err("oversized metadata must reject");
    assert_eq!(
        err.code(),
        crate::grpc::status::Code::ResourceExhausted,
        "rejection must use RESOURCE_EXHAUSTED per gRPC convention"
    );
    assert!(
        !invoked.load(std::sync::atomic::Ordering::Relaxed),
        "handler must NOT be invoked when metadata cap is exceeded"
    );
    crate::test_complete!("test_dispatch_unary_enforces_max_metadata_size");
}
/// A small metadata entry well under an 8 KiB cap must not trip the size
/// check.
#[test]
fn test_dispatch_unary_within_metadata_cap_succeeds() {
    use futures_lite::future::block_on;
    init_test("test_dispatch_unary_within_metadata_cap_succeeds");
    let server = Server::builder().max_metadata_size(8 * 1024).build();
    let request = {
        let mut metadata = Metadata::new();
        metadata.insert("x-trace-id", "abc123");
        Request::with_metadata(Bytes::new(), metadata)
    };
    let response = block_on(server.dispatch_unary(request, |_req| async move {
        Ok(Response::new(Bytes::from_static(b"ok")))
    }))
    .expect("call within cap must succeed");
    assert_eq!(response.get_ref().as_ref(), b"ok");
    crate::test_complete!("test_dispatch_unary_within_metadata_cap_succeeds");
}
/// An ASCII metadata value containing CR/LF must be rejected with
/// `InvalidArgument` before the handler closure is called.
#[test]
fn test_dispatch_unary_rejects_invalid_metadata_before_handler() {
    use futures_lite::future::block_on;
    init_test("test_dispatch_unary_rejects_invalid_metadata_before_handler");
    let server = Server::builder().max_metadata_size(8 * 1024).build();
    // Raw entries bypass whatever validation the regular insert path applies.
    let malformed = (
        "x-request-id".to_string(),
        crate::grpc::MetadataValue::Ascii("line1\r\nline2".to_string()),
    );
    let metadata = Metadata::from_raw_entries_for_tests(vec![malformed]);

    let invoked = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let result = {
        let invoked = std::sync::Arc::clone(&invoked);
        block_on(server.dispatch_unary(
            Request::with_metadata(Bytes::new(), metadata),
            move |_req| {
                invoked.store(true, std::sync::atomic::Ordering::Relaxed);
                async move { Ok(Response::new(Bytes::from_static(b"ok"))) }
            },
        ))
    };

    let err = result.expect_err("invalid metadata must reject");
    assert_eq!(err.code(), crate::grpc::status::Code::InvalidArgument);
    assert!(
        !invoked.load(std::sync::atomic::Ordering::Relaxed),
        "handler must NOT be invoked when inbound metadata is malformed"
    );
    crate::test_complete!("test_dispatch_unary_rejects_invalid_metadata_before_handler");
}
/// A non-gRPC `content-type` together with a `te` other than `trailers` must
/// be rejected with `InvalidArgument` before the handler closure is called.
#[test]
fn test_dispatch_unary_rejects_invalid_protocol_headers_before_handler() {
    use futures_lite::future::block_on;
    init_test("test_dispatch_unary_rejects_invalid_protocol_headers_before_handler");
    let server = Server::builder().max_metadata_size(8 * 1024).build();
    let ascii = |key: &str, value: &str| {
        (
            key.to_string(),
            crate::grpc::MetadataValue::Ascii(value.to_string()),
        )
    };
    let metadata = Metadata::from_raw_entries_for_tests(vec![
        ascii("content-type", "application/json"),
        ascii("te", "chunked"),
    ]);

    let invoked = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let result = {
        let invoked = std::sync::Arc::clone(&invoked);
        block_on(server.dispatch_unary(
            Request::with_metadata(Bytes::new(), metadata),
            move |_req| {
                invoked.store(true, std::sync::atomic::Ordering::Relaxed);
                async move { Ok(Response::new(Bytes::from_static(b"ok"))) }
            },
        ))
    };

    let err = result.expect_err("invalid protocol headers must reject");
    assert_eq!(err.code(), crate::grpc::status::Code::InvalidArgument);
    assert!(
        !invoked.load(std::sync::atomic::Ordering::Relaxed),
        "handler must NOT be invoked when unary protocol headers are malformed"
    );
    crate::test_complete!(
        "test_dispatch_unary_rejects_invalid_protocol_headers_before_handler"
    );
}
/// A connection capped at five concurrent streams must admit streams 1..=5
/// and refuse the sixth with an `exceeds max_concurrent_streams` reason.
#[test]
fn test_connection_registry_enforces_stream_limits() {
    init_test("test_connection_registry_enforces_stream_limits");
    let registry = ConnectionRegistry::new();
    let connection_id = String::from("test-conn-1");
    registry.add_connection(connection_id.clone());

    (1..=5).for_each(|stream_id| {
        let admitted = registry
            .enforce_stream_limits(&connection_id, stream_id, 5, None)
            .is_ok();
        assert!(admitted, "Should accept stream {} within limit", stream_id);
    });

    let overflow = registry.enforce_stream_limits(&connection_id, 6, 5, None);
    assert!(
        overflow.is_err(),
        "Should reject stream that exceeds max_concurrent_streams"
    );
    let reason = overflow.unwrap_err();
    assert!(reason.contains("exceeds max_concurrent_streams"));

    registry.remove_connection(&connection_id);
    crate::test_complete!("test_connection_registry_enforces_stream_limits");
}
/// A stream idle longer than the supplied timeout is reaped when the next
/// stream is admitted, so the stream count stays at one.
#[test]
fn test_connection_registry_idle_timeout() {
    init_test("test_connection_registry_idle_timeout");
    let registry = ConnectionRegistry::new();
    let connection_id = String::from("test-conn-idle");
    registry.add_connection(connection_id.clone());

    let initial = registry.enforce_stream_limits(&connection_id, 1, 10, None);
    assert!(initial.is_ok(), "Should accept initial stream");
    let (conns, live_streams) = registry.get_stats();
    assert_eq!(conns, 1);
    assert_eq!(live_streams, 1);

    // Let stream 1 go idle for longer than the 1ms timeout used below.
    std::thread::sleep(std::time::Duration::from_millis(2));
    let idle_timeout = Some(std::time::Duration::from_millis(1));
    let after_cleanup = registry.enforce_stream_limits(&connection_id, 2, 10, idle_timeout);
    assert!(
        after_cleanup.is_ok(),
        "Should accept new stream after idle cleanup"
    );
    let (conns, live_streams) = registry.get_stats();
    assert_eq!(conns, 1);
    assert_eq!(live_streams, 1);

    registry.remove_connection(&connection_id);
    crate::test_complete!("test_connection_registry_idle_timeout");
}
/// Two parked dispatches fill a two-slot connection, so a third is refused
/// with `ResourceExhausted`. Dropping the parked futures must free both slots.
#[test]
fn test_server_stream_enforcement_integration() {
    use futures_lite::future::block_on;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll, Waker};
    init_test("test_server_stream_enforcement_integration");
    let server = Server::builder()
        .max_concurrent_streams(2)
        .stream_idle_timeout(None)
        .build();
    let connection_id = "test-integration-conn".to_string();
    server.register_connection(connection_id.clone());
    {
        let first_inflight = server.dispatch_unary_with_stream_enforcement(
            connection_id.clone(),
            1,
            Request::with_metadata(Bytes::from_static(b"test"), Metadata::new()),
            |_req| async {
                std::future::pending::<()>().await;
                Ok::<Response<Bytes>, Status>(Response::new(Bytes::new()))
            },
        );
        let mut first_inflight = pin!(first_inflight);
        let second_inflight = server.dispatch_unary_with_stream_enforcement(
            connection_id.clone(),
            2,
            Request::with_metadata(Bytes::from_static(b"test2"), Metadata::new()),
            |_req| async {
                std::future::pending::<()>().await;
                Ok::<Response<Bytes>, Status>(Response::new(Bytes::new()))
            },
        );
        let mut second_inflight = pin!(second_inflight);

        // Poll each once with a no-op waker so both streams register and park.
        let mut cx = Context::from_waker(Waker::noop());
        assert!(matches!(first_inflight.as_mut().poll(&mut cx), Poll::Pending));
        assert!(matches!(second_inflight.as_mut().poll(&mut cx), Poll::Pending));
        assert_eq!(
            server.connection_registry.get_stats().1,
            2,
            "two in-flight streams should consume both stream slots",
        );

        let overflow = block_on(server.dispatch_unary_with_stream_enforcement(
            connection_id.clone(),
            3,
            Request::with_metadata(Bytes::from_static(b"test3"), Metadata::new()),
            |req| async move { Ok(Response::new(req.into_inner())) },
        ));
        assert!(overflow.is_err(), "Third stream should be rejected");
        assert_eq!(
            overflow.unwrap_err().code(),
            crate::grpc::status::Code::ResourceExhausted
        );
        // Both parked futures are dropped at the end of this scope.
    }
    assert_eq!(
        server.connection_registry.get_stats().1,
        0,
        "dropping in-flight dispatch futures should release stream slots",
    );
    server.unregister_connection(&connection_id);
    crate::test_complete!("test_server_stream_enforcement_integration");
}
#[test]
fn test_dispatch_unary_drop_during_handler_releases_stream_registration() {
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll, Waker};

    init_test("test_dispatch_unary_drop_during_handler_releases_stream_registration");

    // The 60 s idle timeout is far longer than this test runs, so the idle
    // sweep cannot be what frees the slot; only dropping the future can.
    let server = Server::builder()
        .max_concurrent_streams(2)
        .stream_idle_timeout(Some(std::time::Duration::from_secs(60)))
        .build();
    let conn = "rst-stream-leak-conn".to_string();
    server.register_connection(conn.clone());

    {
        let in_flight = server.dispatch_unary_with_stream_enforcement(
            conn.clone(),
            7,
            Request::with_metadata(Bytes::from_static(b"will-be-cancelled"), Metadata::new()),
            |_req| async {
                let () = std::future::pending().await;
                unreachable!("handler must never resolve in this test");
            },
        );
        let mut in_flight = pin!(in_flight);

        // Drive the dispatch into its handler exactly once.
        let mut poll_cx = Context::from_waker(Waker::noop());
        let first_poll = in_flight.as_mut().poll(&mut poll_cx);
        assert!(
            matches!(first_poll, Poll::Pending),
            "the pending() handler must keep the dispatch parked",
        );

        let (_, registered) = server.connection_registry.get_stats();
        assert_eq!(
            registered,
            1,
            "stream must be registered while the dispatch is in flight",
        );
    }

    // `in_flight` has been dropped mid-handler (the RST_STREAM scenario).
    let (_, total_streams) = server.connection_registry.get_stats();
    assert_eq!(
        total_streams, 0,
        "RST_STREAM mid-handler must release the stream registration; \
        pre-fix this leaked until the idle sweep ran",
    );

    server.unregister_connection(&conn);
    crate::test_complete!(
        "test_dispatch_unary_drop_during_handler_releases_stream_registration"
    );
}
#[test]
fn test_connection_hoarding_attack_simulation() {
    init_test("test_connection_hoarding_attack_simulation");

    // Each connection may hold at most three concurrent streams.
    let server = Server::builder()
        .max_concurrent_streams(3)
        .stream_idle_timeout(Some(std::time::Duration::from_secs(60)))
        .build();

    // Read the enforced limits once; both fields are plain copyable values.
    let stream_limit = server.config().max_concurrent_streams;
    let idle_timeout = server.config().stream_idle_timeout;
    let attacker_id = |n: i32| format!("attacker-conn-{}", n);

    for conn_num in 1..=5 {
        let connection_id = attacker_id(conn_num);
        server.register_connection(connection_id.clone());

        // Fill every available slot on this connection.
        for stream_id in 1..=3 {
            let outcome = server.connection_registry.enforce_stream_limits(
                &connection_id,
                stream_id,
                stream_limit,
                idle_timeout,
            );
            assert!(
                outcome.is_ok(),
                "Stream {} on connection {} should succeed within limits",
                stream_id,
                conn_num
            );
        }

        // One stream beyond the limit must be refused.
        let overflow = server.connection_registry.enforce_stream_limits(
            &connection_id,
            4,
            stream_limit,
            idle_timeout,
        );
        assert!(
            overflow.is_err(),
            "Fourth stream should be rejected due to limit"
        );
    }

    // 5 connections x 3 accepted streams each.
    let (active_connections, total_streams) = server.get_connection_stats();
    assert_eq!(active_connections, 5, "Should track 5 connections");
    assert_eq!(
        total_streams, 15,
        "Should track maxed-out attacker streams across all connections"
    );

    for conn_num in 1..=5 {
        server.unregister_connection(&attacker_id(conn_num));
    }
    crate::test_complete!("test_connection_hoarding_attack_simulation");
}
#[test]
fn grpc_message_size_limit_enforcement_audit() {
    init_test("grpc_message_size_limit_enforcement_audit");

    let max_message_size = 64;
    let server = Server::builder()
        .max_recv_message_size(max_message_size)
        .build();

    // Builds a length-prefixed frame: flag byte 0, the *declared* body length
    // as a u32, then `body` (which may be shorter than what is declared).
    let frame = |declared_len: u32, body: &[u8]| {
        let mut buf = BytesMut::new();
        buf.put_u8(0);
        buf.put_u32(declared_len);
        buf.extend_from_slice(body);
        buf
    };

    // Case 1: declares one byte over the limit but only ships a short prefix.
    let oversized_payload = vec![0x42u8; max_message_size + 1];
    let prefix_len = max_message_size.min(16);
    let mut oversized_frame = frame(
        oversized_payload.len() as u32,
        &oversized_payload[..prefix_len],
    );
    let mut decoder = server.framed_codec(crate::grpc::IdentityCodec);
    let error = decoder
        .decode_message(&mut oversized_frame)
        .expect_err("Oversized message must be rejected");

    let too_large = matches!(error, crate::grpc::GrpcError::MessageTooLarge);
    crate::assert_with_log!(
        too_large,
        "Must reject with MessageTooLarge error",
        true,
        too_large
    );

    let status = error.into_status();
    let code = status.code();
    crate::assert_with_log!(
        code == crate::grpc::Code::ResourceExhausted,
        "Must use RESOURCE_EXHAUSTED status code per gRPC spec",
        crate::grpc::Code::ResourceExhausted,
        code
    );

    let mentions_size = status.message().contains("message too large");
    crate::assert_with_log!(
        mentions_size,
        "Error message must indicate size violation",
        true,
        mentions_size
    );

    // Case 2: a complete frame exactly at the limit decodes successfully.
    let exact_limit_payload = vec![0x43u8; max_message_size];
    let mut exact_frame = frame(
        exact_limit_payload.len() as u32,
        exact_limit_payload.as_slice(),
    );
    let mut exact_decoder = server.framed_codec(crate::grpc::IdentityCodec);
    let exact_ok = exact_decoder.decode_message(&mut exact_frame).is_ok();
    crate::assert_with_log!(
        exact_ok,
        "Message exactly at size limit must succeed",
        true,
        exact_ok
    );

    // Case 3: a 1 GiB declared length with only 32 bytes buffered must still
    // be rejected as too large.
    let huge_declared_size: u32 = 1024 * 1024 * 1024;
    let mut dos_frame = frame(huge_declared_size, &[0x44u8; 32][..]);
    let mut dos_decoder = server.framed_codec(crate::grpc::IdentityCodec);
    let dos_error = dos_decoder
        .decode_message(&mut dos_frame)
        .expect_err("Huge declared size must be rejected");
    let dos_too_large = matches!(dos_error, crate::grpc::GrpcError::MessageTooLarge);
    crate::assert_with_log!(
        dos_too_large,
        "Must reject huge declared size even with partial buffer",
        true,
        dos_too_large
    );

    crate::test_complete!("grpc_message_size_limit_enforcement_audit");
}
mod grpc_deadline_enforcement_audit {
    //! Deadline-enforcement audits: `grpc-timeout` parsing, server deadline
    //! configuration and clamping, cancellation of timed-out async handlers,
    //! and propagation of the remaining timeout to outbound metadata.
    use super::*;
    use crate::grpc::Code;
    use crate::grpc::MetadataValue;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::{Duration, Instant};

    /// `grpc-timeout: 5S` parsed at a fixed `now` yields a deadline within
    /// 1 ms of `now + 5s`. That deadline is live at `now` and expired just after.
    #[test]
    fn audit_grpc_timeout_header_parsing_is_sound() {
        init_test("audit_grpc_timeout_header_parsing_is_sound");
        let now = Instant::now();
        let mut metadata = Metadata::new();
        metadata.insert("grpc-timeout", "5S");
        let context = CallContext::from_metadata_at(metadata, None, None, now);
        let deadline = context.deadline().expect("deadline should be parsed");
        let expected_deadline = now + Duration::from_secs(5);
        // Absolute difference between the two instants, whichever is later.
        let deadline_delta = deadline
            .checked_duration_since(expected_deadline)
            .or_else(|| expected_deadline.checked_duration_since(deadline))
            .expect("deadline delta should be representable");
        crate::assert_with_log!(
            deadline_delta < Duration::from_millis(1),
            "grpc-timeout header correctly parsed to deadline",
            true,
            deadline_delta < Duration::from_millis(1)
        );
        assert!(
            !context.is_expired_at(now),
            "should not be expired immediately"
        );
        assert!(
            context.is_expired_at(deadline + Duration::from_millis(1)),
            "should be expired after deadline"
        );
        crate::test_complete!("audit_grpc_timeout_header_parsing_is_sound");
    }

    /// Documents a known limitation. A handler that blocks its thread with
    /// `std::thread::sleep` is not interrupted by the async deadline, so it
    /// completes and returns `Ok` even though its 1 ms `grpc-timeout` elapsed.
    #[test]
    fn audit_deadline_enforcement_blocking_limitation() {
        use futures_lite::future::block_on;
        init_test("audit_deadline_enforcement_blocking_limitation");
        let server = Server::builder().build();
        let mut metadata = Metadata::new();
        // "1m" is one millisecond in grpc-timeout units.
        metadata.insert("grpc-timeout", "1m");
        let request = Request::with_metadata(Bytes::from_static(b"test"), metadata);
        let handler_completed = Arc::new(AtomicBool::new(false));
        let handler_completed_clone = Arc::clone(&handler_completed);
        let start_time = Instant::now();
        let result = block_on(server.dispatch_unary(request, move |req| async move {
            futures_lite::future::yield_now().await;
            // Blocks the executing thread; the deadline timer cannot preempt it.
            std::thread::sleep(Duration::from_millis(5));
            handler_completed_clone.store(true, Ordering::Relaxed);
            Ok::<Response<Bytes>, Status>(Response::new(req.into_inner()))
        }));
        assert!(
            handler_completed.load(Ordering::Relaxed),
            "EXPECTED: Blocking operations complete even past deadline (async limitation)"
        );
        assert!(
            result.is_ok(),
            "EXPECTED: Blocking handlers cannot be cancelled by async timeouts"
        );
        // `thread::sleep` sleeps at least the requested 5 ms, so the handler
        // necessarily outlived its 1 ms deadline.
        assert!(
            start_time.elapsed() >= Duration::from_millis(5),
            "Handler ran past deadline due to blocking operation"
        );
        crate::test_complete!("audit_deadline_enforcement_blocking_limitation");
    }

    /// A handler parked on an await point is cancelled once its 5 ms deadline
    /// expires. The call fails with DEADLINE_EXCEEDED, and the handler future
    /// is dropped before it can set its completion flag.
    #[test]
    fn audit_deadline_enforcement_works_for_async_operations() {
        use futures_lite::future::block_on;
        init_test("audit_deadline_enforcement_works_for_async_operations");
        let server = Server::builder().build();
        let mut metadata = Metadata::new();
        // "5m" is five milliseconds in grpc-timeout units.
        metadata.insert("grpc-timeout", "5m");
        let request = Request::with_metadata(Bytes::from_static(b"test"), metadata);
        let handler_started = Arc::new(AtomicBool::new(false));
        let handler_completed = Arc::new(AtomicBool::new(false));
        let handler_started_clone = Arc::clone(&handler_started);
        let handler_completed_clone = Arc::clone(&handler_completed);
        let result = block_on(server.dispatch_unary(request, move |req| async move {
            handler_started_clone.store(true, Ordering::Relaxed);
            futures_lite::future::yield_now().await;
            // Never resolves: only the deadline can end this call.
            std::future::pending::<()>().await;
            handler_completed_clone.store(true, Ordering::Relaxed);
            Ok::<Response<Bytes>, Status>(Response::new(req.into_inner()))
        }));
        assert!(
            handler_started.load(Ordering::Relaxed),
            "Handler should start execution"
        );
        let status = result
            .expect_err("Request should fail with DEADLINE_EXCEEDED for async operations");
        assert!(
            !handler_completed.load(Ordering::Relaxed),
            "Timed-out async handler should be dropped before completion"
        );
        assert_eq!(
            status.code(),
            Code::DeadlineExceeded,
            "Should return DEADLINE_EXCEEDED status"
        );
        crate::test_complete!("audit_deadline_enforcement_works_for_async_operations");
    }

    /// Builder-provided `default_timeout` and `max_request_deadline` are
    /// surfaced unchanged (wrapped in `Some`) through `Server::config`.
    #[test]
    fn audit_server_deadline_configuration_is_sound() {
        init_test("audit_server_deadline_configuration_is_sound");
        let server = Server::builder()
            .default_timeout(Duration::from_secs(30))
            .max_request_deadline(Duration::from_secs(60))
            .build();
        let config = server.config();
        assert_eq!(
            config.default_timeout,
            Some(Duration::from_secs(30)),
            "default_timeout configuration preserved"
        );
        assert_eq!(
            config.max_request_deadline,
            Some(Duration::from_secs(60)),
            "max_request_deadline configuration preserved"
        );
        crate::test_complete!("audit_server_deadline_configuration_is_sound");
    }

    /// A one-hour peer `grpc-timeout` is clamped to the 60 s server maximum.
    /// The result is checked within a ±1 s window.
    #[test]
    fn audit_max_request_deadline_clamping_is_sound() {
        init_test("audit_max_request_deadline_clamping_is_sound");
        let now = Instant::now();
        let mut metadata = Metadata::new();
        metadata.insert("grpc-timeout", "3600S");
        let context = CallContext::from_metadata_at_with_max_deadline(
            metadata,
            None,
            Some(Duration::from_secs(60)),
            None,
            now,
        );
        let deadline = context.deadline().expect("deadline should be set");
        let clamped_duration = deadline.duration_since(now);
        crate::assert_with_log!(
            clamped_duration <= Duration::from_secs(61),
            "peer timeout correctly clamped to server max_request_deadline",
            true,
            clamped_duration <= Duration::from_secs(61)
        );
        crate::assert_with_log!(
            clamped_duration >= Duration::from_secs(59),
            "clamped deadline is approximately the max value",
            true,
            clamped_duration >= Duration::from_secs(59)
        );
        crate::test_complete!("audit_max_request_deadline_clamping_is_sound");
    }

    /// End-to-end check that the server cap is enforced. A 5 ms
    /// `max_request_deadline` caps a 1 s peer timeout, and a handler that
    /// sleeps 20 ms is cancelled with DEADLINE_EXCEEDED before it finishes.
    #[test]
    fn deadline_enforcement_contract_is_executable() {
        use futures_lite::future::block_on;
        init_test("deadline_enforcement_contract_is_executable");
        let now = Instant::now();
        let mut metadata = Metadata::new();
        metadata.insert("grpc-timeout", "3600S");
        let context = CallContext::from_metadata_at_with_max_deadline(
            metadata,
            None,
            Some(Duration::from_millis(5)),
            None,
            now,
        );
        let deadline = context.deadline().expect("capped deadline should be set");
        assert!(
            deadline.duration_since(now) <= Duration::from_millis(5),
            "server max_request_deadline should cap peer-supplied timeout"
        );
        let server = Server::builder()
            .max_request_deadline(Duration::from_millis(5))
            .build();
        let mut request_metadata = Metadata::new();
        request_metadata.insert("grpc-timeout", "1S");
        let request = Request::with_metadata(Bytes::from_static(b"test"), request_metadata);
        let handler_completed = Arc::new(AtomicBool::new(false));
        let handler_completed_clone = Arc::clone(&handler_completed);
        let result = block_on(server.dispatch_unary(request, move |_req| async move {
            futures_lite::future::yield_now().await;
            // Async sleep well past the 5 ms cap.
            let _ =
                crate::time::sleep(crate::time::wall_now(), Duration::from_millis(20)).await;
            handler_completed_clone.store(true, Ordering::Relaxed);
            Ok::<Response<Bytes>, Status>(Response::new(Bytes::from_static(b"late")))
        }));
        let status = result.expect_err("handler should be cancelled at capped deadline");
        assert_eq!(status.code(), Code::DeadlineExceeded);
        assert!(
            !handler_completed.load(Ordering::Relaxed),
            "timed-out async handler future should be dropped before completion"
        );
        crate::test_complete!("deadline_enforcement_contract_is_executable");
    }

    /// An unparseable `grpc-timeout` value does not cause the request to fail.
    #[test]
    fn audit_malformed_deadline_handling_is_sound() {
        use futures_lite::future::block_on;
        init_test("audit_malformed_deadline_handling_is_sound");
        let server = Server::builder()
            .default_timeout(Duration::from_secs(5))
            .build();
        let mut metadata = Metadata::new();
        metadata.insert("grpc-timeout", "invalid-format");
        let request = Request::with_metadata(Bytes::from_static(b"test"), metadata);
        let result = block_on(server.dispatch_unary(request, |req| async move {
            Ok::<Response<Bytes>, Status>(Response::new(req.into_inner()))
        }));
        assert!(
            result.is_ok(),
            "Malformed grpc-timeout should not prevent request processing"
        );
        crate::test_complete!("audit_malformed_deadline_handling_is_sound");
    }

    /// A context with 10 s remaining writes an ASCII `grpc-timeout` header to
    /// outbound metadata. That header parses back to roughly 10 s.
    #[test]
    fn audit_deadline_propagation_is_sound() {
        init_test("audit_deadline_propagation_is_sound");
        let now = Instant::now();
        let context = CallContext::with_deadline(now + Duration::from_secs(10));
        let mut outbound_metadata = Metadata::new();
        let propagated = context.propagate_timeout_to_at(&mut outbound_metadata, now);
        assert!(
            propagated,
            "deadline should be propagated to outbound metadata"
        );
        let propagated_header = outbound_metadata
            .get("grpc-timeout")
            .expect("grpc-timeout header should be added to outbound metadata");
        let MetadataValue::Ascii(header_value) = propagated_header else {
            panic!("grpc-timeout should be ASCII metadata value");
        };
        let timeout = parse_grpc_timeout(header_value)
            .expect("propagated timeout should be parseable");
        assert!(
            (Duration::from_secs(9)..=Duration::from_secs(11)).contains(&timeout),
            "propagated timeout should be approximately 10 seconds"
        );
        crate::test_complete!("audit_deadline_propagation_is_sound");
    }
}
mod grpc_streaming_trailer_emission_audit {
    //! Audits of trailer and frame shape for completed and cancelled gRPC
    //! streaming responses.
    //!
    //! These tests build `Metadata` and frame values directly, so they check
    //! insertion-order and flag invariants rather than the server's encoder.
    //!
    //! NOTE(review): the assertions require `grpc-status` to be the *last*
    //! trailer. The gRPC PROTOCOL-HTTP2 grammar writes
    //! `Trailers -> Status [Status-Message] *Custom-Metadata`; confirm which
    //! ordering this crate intends to guarantee.
    use super::*;
    use crate::grpc::{Code, Metadata, Status};
    use crate::http::h2::frame::{DataFrame, HeadersFrame};

    /// With custom trailers inserted before it, `grpc-status` is the final
    /// entry when iterating the metadata.
    #[test]
    fn audit_grpc_status_final_trailer_requirement() {
        init_test("audit_grpc_status_final_trailer_requirement");
        let mut response_metadata = Metadata::new();
        response_metadata.insert("x-custom-trailer", "application-data");
        response_metadata.insert("x-request-id", "req-12345");
        response_metadata.insert("grpc-status", "0");
        let headers: Vec<_> = response_metadata.iter().collect();
        let grpc_status_pos = headers
            .iter()
            .position(|(key, _)| *key == "grpc-status")
            .expect("grpc-status must be present");
        let last_pos = headers.len().saturating_sub(1);
        crate::assert_with_log!(
            grpc_status_pos == last_pos,
            "grpc-status must be final trailer per gRPC HTTP/2 spec",
            true,
            grpc_status_pos == last_pos
        );
        eprintln!(
            "{{\"audit\":\"GRPC_TRAILER_ORDERING\",\"status\":\"SOUND\",\"requirement\":\"grpc-status final trailer\"}}"
        );
        crate::test_complete!("audit_grpc_status_final_trailer_requirement");
    }

    /// DATA frames carrying messages do not end the stream. The trailing
    /// HEADERS frame sets both END_STREAM and END_HEADERS.
    #[test]
    fn audit_http2_frame_sequence_for_streaming_completion() {
        init_test("audit_http2_frame_sequence_for_streaming_completion");
        // Arguments presumably (stream_id, payload, end_stream); confirm
        // against `DataFrame::new`.
        let data_frame_1 = DataFrame::new(
            1,
            crate::bytes::Bytes::from_static(b"response-1"),
            false,
        );
        let data_frame_2 = DataFrame::new(
            1,
            crate::bytes::Bytes::from_static(b"response-2"),
            false,
        );
        let trailer_headers =
            crate::bytes::Bytes::from_static(b"grpc-status: 0\r\ngrpc-message: success\r\n");
        // Both flags set; which bool maps to END_STREAM vs END_HEADERS is
        // defined by `HeadersFrame::new`.
        let final_headers_frame = HeadersFrame::new(1, trailer_headers, true, true);
        crate::assert_with_log!(
            !data_frame_1.end_stream && !data_frame_2.end_stream,
            "DATA frames before final headers must not have END_STREAM",
            true,
            !data_frame_1.end_stream && !data_frame_2.end_stream
        );
        crate::assert_with_log!(
            final_headers_frame.end_stream,
            "Final HEADERS frame MUST have END_STREAM per RFC 9113 §8.1",
            true,
            final_headers_frame.end_stream
        );
        crate::assert_with_log!(
            final_headers_frame.end_headers,
            "Final HEADERS frame MUST have END_HEADERS",
            true,
            final_headers_frame.end_headers
        );
        eprintln!(
            "{{\"audit\":\"HTTP2_FRAME_SEQUENCE\",\"status\":\"SOUND\",\"requirement\":\"proper frame ordering\"}}"
        );
        crate::test_complete!("audit_http2_frame_sequence_for_streaming_completion");
    }

    /// A cancelled status maps to wire code 1 and carries a message. When the
    /// cancellation trailers are built, `grpc-status` ends up last.
    #[test]
    fn audit_cancellation_path_trailer_emission() {
        init_test("audit_cancellation_path_trailer_emission");
        let cancellation_status = Status::cancelled("client requested cancellation");
        let status_code = cancellation_status.code() as i32;
        crate::assert_with_log!(
            status_code == 1,
            "Cancelled streams must emit grpc-status: 1 per gRPC spec",
            1,
            status_code
        );
        let status_message = cancellation_status.message();
        crate::assert_with_log!(
            !status_message.is_empty(),
            "Cancellation status must include descriptive message",
            true,
            !status_message.is_empty()
        );
        // grpc-message goes in first so that grpc-status is the final entry.
        // Previously grpc-status was inserted first, and the ordering check
        // was masked by an always-true `headers.len() == 2` escape hatch.
        let mut cancellation_trailers = Metadata::new();
        cancellation_trailers.insert("grpc-message", status_message);
        cancellation_trailers.insert("grpc-status", status_code.to_string());
        let headers: Vec<_> = cancellation_trailers.iter().collect();
        let grpc_status_pos = headers
            .iter()
            .position(|(key, _)| *key == "grpc-status")
            .expect("grpc-status must be present in cancellation");
        let last_pos = headers.len().saturating_sub(1);
        crate::assert_with_log!(
            grpc_status_pos == last_pos,
            "grpc-status ordering maintained even in cancellation",
            true,
            grpc_status_pos == last_pos
        );
        eprintln!(
            "{{\"audit\":\"CANCELLATION_TRAILERS\",\"status\":\"SOUND\",\"requirement\":\"proper cancellation signaling\"}}"
        );
        crate::test_complete!("audit_cancellation_path_trailer_emission");
    }

    /// Each listed `Code` maps to its expected numeric wire value. That value's
    /// decimal string form parses back to the same integer.
    #[test]
    fn audit_grpc_status_validation_for_server_responses() {
        init_test("audit_grpc_status_validation_for_server_responses");
        let valid_statuses = [
            (Code::Ok, 0),
            (Code::Cancelled, 1),
            (Code::Unknown, 2),
            (Code::InvalidArgument, 3),
            (Code::DeadlineExceeded, 4),
            (Code::NotFound, 5),
            (Code::Internal, 13),
            (Code::Unavailable, 14),
            (Code::Unauthenticated, 16),
        ];
        for (status_code, expected_wire_value) in valid_statuses {
            let status = Status::new(status_code, "test message");
            let wire_value = status.code() as i32;
            crate::assert_with_log!(
                wire_value == expected_wire_value,
                format!(
                    "Status {:?} maps to correct wire value {}",
                    status_code, expected_wire_value
                ),
                expected_wire_value,
                wire_value
            );
            // Round-trip: the text form must parse back to the same value,
            // not merely to *some* integer.
            let round_trips = wire_value.to_string().parse::<i32>().ok() == Some(wire_value);
            crate::assert_with_log!(
                round_trips,
                "grpc-status wire value must be valid integer",
                true,
                round_trips
            );
        }
        eprintln!(
            "{{\"audit\":\"GRPC_STATUS_VALIDATION\",\"status\":\"SOUND\",\"requirement\":\"valid status codes\"}}"
        );
        crate::test_complete!("audit_grpc_status_validation_for_server_responses");
    }

    /// Models a completed server-streaming call: a request, several response
    /// messages, then completion metadata. In that metadata, custom `x-`
    /// trailers precede `grpc-status`, which is the final entry.
    #[test]
    fn audit_server_streaming_complete_lifecycle() {
        init_test("audit_server_streaming_complete_lifecycle");
        let request = super::Request::with_metadata(
            crate::bytes::Bytes::from_static(b"stream-request"),
            Metadata::new(),
        );
        crate::assert_with_log!(
            request.get_ref().as_ref() == b"stream-request",
            "Server streaming lifecycle must start from the request payload",
            true,
            request.get_ref().as_ref() == b"stream-request"
        );
        let streaming_responses = vec![
            b"message-1".to_vec(),
            b"message-2".to_vec(),
            b"message-3".to_vec(),
        ];
        let mut completion_metadata = Metadata::new();
        completion_metadata.insert("x-response-count", "3");
        completion_metadata.insert("x-processing-time", "142ms");
        completion_metadata.insert("grpc-message", "stream completed successfully");
        completion_metadata.insert("grpc-status", "0");
        crate::assert_with_log!(
            !streaming_responses.is_empty(),
            "Server streaming must include response messages",
            true,
            !streaming_responses.is_empty()
        );
        let has_grpc_status = completion_metadata.get("grpc-status").is_some();
        crate::assert_with_log!(
            has_grpc_status,
            "Completion metadata MUST include grpc-status trailer",
            true,
            has_grpc_status
        );
        let trailer_headers: Vec<_> = completion_metadata.iter().collect();
        // Fail loudly rather than silently skipping the ordering checks.
        let grpc_status_pos = trailer_headers
            .iter()
            .position(|(key, _)| *key == "grpc-status")
            .expect("grpc-status must be present");
        let custom_trailer_exists = trailer_headers
            .iter()
            .take(grpc_status_pos)
            .any(|(key, _)| key.starts_with("x-"));
        crate::assert_with_log!(
            custom_trailer_exists,
            "custom trailers must be emitted before grpc-status",
            true,
            custom_trailer_exists
        );
        let last_pos = trailer_headers.len().saturating_sub(1);
        crate::assert_with_log!(
            grpc_status_pos == last_pos,
            "grpc-status must be final trailer even with custom trailers present",
            true,
            grpc_status_pos == last_pos
        );
        eprintln!(
            "{{\"audit\":\"STREAMING_LIFECYCLE\",\"status\":\"SOUND\",\"requirement\":\"complete response flow\"}}"
        );
        crate::test_complete!("audit_server_streaming_complete_lifecycle");
    }
}
}