use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tonic::body::Body;
use tonic::client::GrpcService;
use tonic::codegen::http::{Request, Response};
use tonic_web::{GrpcWebCall, GrpcWebClientService};
use tower::{Layer, ServiceBuilder};
/// Wraps a gRPC transport (for example a channel) with [`GrpcWebWrapperLayer`].
///
/// The returned service is tonic-web's [`GrpcWebClientService`] around a
/// [`GrpcWebWrapperService`] that owns `channel`.
pub fn wrap_channel_with_grpc_web<C>(channel: C) -> GrpcWebWrapperLayerService<C>
where
    C: GrpcService<Body>,
{
    // A builder holding just the single gRPC-web wrapping layer.
    let stack = ServiceBuilder::new().layer(GrpcWebWrapperLayer);
    stack.service(channel)
}
/// Tower [`Layer`] that turns a gRPC transport `S` into a
/// [`GrpcWebWrapperLayerService<S>`].
///
/// It first wraps `S` in [`GrpcWebWrapperService`], then wraps that in tonic-web's
/// [`GrpcWebClientService`].
#[derive(Copy, Clone)]
pub struct GrpcWebWrapperLayer;

/// Service type produced by [`GrpcWebWrapperLayer`]: a [`GrpcWebClientService`]
/// whose inner service is a [`GrpcWebWrapperService`] around `S`.
pub type GrpcWebWrapperLayerService<S> = GrpcWebClientService<GrpcWebWrapperService<S>>;

impl<S> Layer<S> for GrpcWebWrapperLayer
where
    S: GrpcService<Body>,
{
    type Service = GrpcWebWrapperLayerService<S>;

    fn layer(&self, service: S) -> Self::Service {
        // Innermost adapter first, then the tonic-web client wrapper on top of it.
        let adapter = GrpcWebWrapperService { service };
        GrpcWebClientService::new(adapter)
    }
}
/// Adapter placed directly beneath tonic-web's [`GrpcWebClientService`] (see
/// [`GrpcWebWrapperLayer`]).
///
/// It accepts requests whose body is a `GrpcWebCall<Body>`. It re-boxes that body
/// into a plain [`Body`] and forwards the request to the wrapped gRPC transport `S`.
#[derive(Clone)]
pub struct GrpcWebWrapperService<S> {
    /// The wrapped gRPC transport that ultimately sends the request.
    service: S,
}

impl<S> tower::Service<Request<GrpcWebCall<Body>>> for GrpcWebWrapperService<S>
where
    S: GrpcService<Body> + Clone + Send + 'static,
    <S as GrpcService<Body>>::Future: Send,
    <S as GrpcService<Body>>::ResponseBody: Send,
{
    type Response = Response<<S as GrpcService<Body>>::ResponseBody>;
    type Error = <S as GrpcService<Body>>::Error;
    // Boxed so the future can be handed to
    // `pin_future_with_otel_context_if_available` and remain `Send`.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated directly to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Re-boxes the request body as [`Body`] and forwards the request to the wrapped
    /// service. Method, URI, version, all header values and extensions are preserved.
    fn call(&mut self, req: Request<GrpcWebCall<Body>>) -> Self::Future {
        // `poll_ready` was driven on `self.service`, so the future must use that exact
        // (ready) instance. Leave a fresh clone in `self` and move the ready one into
        // the future.
        let mut service = self.service.clone();
        std::mem::swap(&mut self.service, &mut service);
        super::common::pin_future_with_otel_context_if_available(async move {
            // Swap only the body type. Rebuilding the request by hand would drop
            // repeated header values (`HeaderMap::into_iter` yields `None` keys for
            // them) and all request extensions.
            let request = req.map(Body::new);
            service.call(request).await
        })
    }
}