use bytes::Buf;
use bytes::Bytes;
use http::Request;
use http::Response;
use http_body_util::BodyExt;
use http_body_util::Full;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Context;
use std::task::Poll;
use sui_http::middleware::callback::CallbackLayer;
use sui_http::middleware::callback::MakeCallbackHandler;
use sui_http::middleware::callback::RequestHandler;
use sui_http::middleware::callback::ResponseHandler;
use tower::Service;
use tower::ServiceBuilder;
use tower::ServiceExt;
#[derive(Clone, Default)]
struct GrpcEcho;
impl Service<Request<tonic::body::Body>> for GrpcEcho {
type Response = Response<tonic::body::Body>;
type Error = tonic::Status;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
Box::pin(async move {
let (_parts, body) = req.into_parts();
let collected = body
.collect()
.await
.map_err(|e| tonic::Status::internal(format!("body error: {e}")))?;
let bytes = collected.to_bytes();
Ok(Response::new(tonic::body::Body::new(Full::new(bytes))))
})
}
}
/// Observations recorded by the callback handlers during a single call.
///
/// Each test locks this after the call (and after draining any response body)
/// and asserts on the recorded state.
#[derive(Debug, Default)]
struct Events {
// Sum of `chunk.remaining()` over every request body chunk passed to `ReqH`.
request_bytes: usize,
// Set when `ReqH::on_end_of_stream` is invoked.
request_end_seen: bool,
// Sum of `chunk.remaining()` over every response body chunk passed to `RespH`.
response_bytes: usize,
// Set when `RespH::on_end_of_stream` is invoked.
response_end_seen: bool,
// Set when `RespH::on_response` receives the response parts.
response_seen: bool,
// `Display` rendering of every error passed to `RespH::on_service_error`.
service_errors: Vec<String>,
}
/// `MakeCallbackHandler` factory used by the tests. `Default` creates a fresh
/// shared `Events`; tests clone `.0` before moving the recorder into the layer
/// so they can inspect the events afterwards.
#[derive(Clone, Default)]
struct Recorder(Arc<Mutex<Events>>);
/// Request-side handler; writes into the `Events` shared with its `Recorder`.
struct ReqH(Arc<Mutex<Events>>);
/// Response-side handler; writes into the `Events` shared with its `Recorder`.
struct RespH(Arc<Mutex<Events>>);
/// Records request-body traffic into the shared `Events`.
impl RequestHandler for ReqH {
    /// Adds the chunk's remaining byte count to the running request total.
    fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
        let mut events = self.0.lock().unwrap();
        events.request_bytes += chunk.remaining();
    }

    /// Notes that the request body reached end-of-stream; trailers are ignored.
    fn on_end_of_stream(&mut self, _trailers: Option<&http::HeaderMap>) {
        let mut events = self.0.lock().unwrap();
        events.request_end_seen = true;
    }
}
/// Records response head, response-body traffic and inner-service errors
/// into the shared `Events`.
impl ResponseHandler for RespH {
    /// Notes that response parts were delivered to the handler.
    fn on_response(&mut self, _parts: &http::response::Parts) {
        let mut events = self.0.lock().unwrap();
        events.response_seen = true;
    }

    /// Stores the `Display` rendering of the inner service's error.
    fn on_service_error<E: std::fmt::Display + 'static>(&mut self, error: &E) {
        // Render before locking so formatting does not happen under the mutex.
        let rendered = error.to_string();
        let mut events = self.0.lock().unwrap();
        events.service_errors.push(rendered);
    }

    /// Adds the chunk's remaining byte count to the running response total.
    fn on_body_chunk<B: Buf>(&mut self, chunk: &B) {
        let mut events = self.0.lock().unwrap();
        events.response_bytes += chunk.remaining();
    }

    /// Notes that the response body reached end-of-stream; trailers are ignored.
    fn on_end_of_stream(&mut self, _trailers: Option<&http::HeaderMap>) {
        let mut events = self.0.lock().unwrap();
        events.response_end_seen = true;
    }
}
impl MakeCallbackHandler for Recorder {
type RequestHandler = ReqH;
type ResponseHandler = RespH;
fn make_handler(
&self,
_request: &http::request::Parts,
) -> (Self::RequestHandler, Self::ResponseHandler) {
(ReqH(self.0.clone()), RespH(self.0.clone()))
}
}
/// Happy path: a `CallbackLayer` in front of a tonic-typed echo service must
/// observe the full request body, the response head, the full response body
/// and both end-of-stream events, with no service errors.
#[tokio::test]
async fn callback_layer_bridges_into_tonic_service() {
    const PAYLOAD: &[u8] = b"hello tonic";

    let recorder = Recorder::default();
    // Keep a handle on the shared event log; `recorder` itself moves into the layer.
    let events = recorder.0.clone();
    // `map_request` converts whatever body the callback layer passes down into
    // the `tonic::body::Body` that `GrpcEcho` accepts. No response mapping is
    // needed: `GrpcEcho` already yields `Response<tonic::body::Body>`.
    let mut stack = ServiceBuilder::new()
        .layer(CallbackLayer::new(recorder))
        .map_request(|req: Request<_>| req.map(tonic::body::Body::new))
        .service(GrpcEcho);

    let request = Request::new(Full::new(Bytes::from_static(PAYLOAD)));
    let response = stack.ready().await.unwrap().call(request).await.unwrap();
    // Drain the response body before inspecting events, so the response-side
    // chunk and end-of-stream callbacks have had a chance to run.
    let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
    assert_eq!(body_bytes, Bytes::from_static(PAYLOAD));

    let events = events.lock().unwrap();
    assert_eq!(events.request_bytes, PAYLOAD.len());
    assert!(events.request_end_seen);
    assert!(events.response_seen);
    assert_eq!(events.response_bytes, PAYLOAD.len());
    assert!(events.response_end_seen);
    assert!(events.service_errors.is_empty());
}
/// Error path: when the inner tonic service fails, the `tonic::Status` must
/// propagate unchanged, the callback layer must report it exactly once via
/// `on_service_error`, and no response callbacks may fire.
#[tokio::test]
async fn callback_layer_observes_tonic_service_error() {
    /// Service that rejects every call with `Unavailable("nope")` without
    /// reading the request.
    #[derive(Clone, Default)]
    struct FailingGrpc;
    impl Service<Request<tonic::body::Body>> for FailingGrpc {
        type Response = Response<tonic::body::Body>;
        type Error = tonic::Status;
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
        fn call(&mut self, _req: Request<tonic::body::Body>) -> Self::Future {
            Box::pin(async { Err(tonic::Status::unavailable("nope")) })
        }
    }

    let recorder = Recorder::default();
    let events = recorder.0.clone();
    let mut stack = ServiceBuilder::new()
        .layer(CallbackLayer::new(recorder))
        .map_request(|req: Request<_>| req.map(tonic::body::Body::new))
        .service(FailingGrpc);

    let request = Request::new(Full::new(Bytes::from_static(b"ping")));
    let result = stack.ready().await.unwrap().call(request).await;
    let Err(status) = result else {
        panic!("expected tonic::Status error");
    };
    // Assert on the structured fields instead of the `Display` rendering, so the
    // check does not depend on how tonic formats a `Status`.
    assert_eq!(status.code(), tonic::Code::Unavailable);
    assert_eq!(status.message(), "nope");

    let events = events.lock().unwrap();
    // The handler stores `Display` output, so only a substring check is possible here.
    assert_eq!(
        events.service_errors.len(),
        1,
        "unexpected service_errors: {:?}",
        events.service_errors
    );
    assert!(
        events.service_errors[0].contains("nope")
            || events.service_errors[0].to_lowercase().contains("unavailable"),
        "unexpected service_errors: {:?}",
        events.service_errors
    );
    assert!(!events.response_seen);
    assert_eq!(events.response_bytes, 0);
}
/// Echoing stand-in with the service signature of a tonic client transport
/// channel (note the `tonic::transport::Error` error type).
///
/// It never returns `Err`; if reading the request body fails, it panics
/// instead. NOTE(review): presumably this is because `tonic::transport::Error`
/// cannot be constructed here — confirm.
#[derive(Clone, Default)]
struct MockChannel;

impl Service<Request<tonic::body::Body>> for MockChannel {
    type Response = Response<tonic::body::Body>;
    type Error = tonic::transport::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Always ready; there is no real connection to establish.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
        Box::pin(async move {
            let payload = req
                .into_body()
                .collect()
                .await
                .expect("MockChannel body collect cannot fail in this test")
                .to_bytes();
            let reply_body = tonic::body::Body::new(Full::new(payload));
            Ok(Response::new(reply_body))
        })
    }
}
#[tokio::test]
async fn callback_layer_wraps_tonic_client_channel() {
let recorder = Recorder::default();
let events = recorder.0.clone();
let mut client = ServiceBuilder::new()
.layer(CallbackLayer::new(recorder))
.map_request(|req: Request<_>| req.map(tonic::body::Body::new))
.service(MockChannel);
let request = Request::new(tonic::body::Body::new(Full::new(Bytes::from_static(
b"outbound rpc",
))));
let response = client.ready().await.unwrap().call(request).await.unwrap();
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body_bytes, Bytes::from_static(b"outbound rpc"));
let events = events.lock().unwrap();
assert_eq!(events.request_bytes, b"outbound rpc".len());
assert!(events.request_end_seen);
assert!(events.response_seen);
assert_eq!(events.response_bytes, b"outbound rpc".len());
assert!(events.response_end_seen);
assert!(events.service_errors.is_empty());
}
#[tokio::test]
async fn callback_layer_observes_tonic_client_channel_error() {
#[derive(Clone, Default)]
struct UnreachableChannel;
impl Service<Request<tonic::body::Body>> for UnreachableChannel {
type Response = Response<tonic::body::Body>;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<tonic::body::Body>) -> Self::Future {
Box::pin(async { Err("connection refused".into()) })
}
}
let recorder = Recorder::default();
let events = recorder.0.clone();
let mut client = ServiceBuilder::new()
.layer(CallbackLayer::new(recorder))
.map_request(|req: Request<_>| req.map(tonic::body::Body::new))
.service(UnreachableChannel);
let request = Request::new(tonic::body::Body::new(Full::new(Bytes::from_static(
b"outbound rpc",
))));
let result = client.ready().await.unwrap().call(request).await;
let err = match result {
Ok(_) => panic!("expected channel error"),
Err(err) => err,
};
assert!(err.to_string().contains("connection refused"));
let events = events.lock().unwrap();
assert_eq!(events.service_errors.len(), 1);
assert!(events.service_errors[0].contains("connection refused"));
assert!(!events.response_seen);
assert_eq!(events.response_bytes, 0);
}