use std::error::Error;
use std::future::Future;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::Instant;
use bytes::Buf;
use bytes::BufMut as _;
use bytes::Bytes;
use futures::stream::StreamExt;
use http::Request as HttpRequest;
use http::Response as HttpResponse;
use http::Uri;
use http::uri::PathAndQuery;
use hyper::client::conn::http2::Builder;
use hyper::client::conn::http2::SendRequest;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_stream::Stream;
use tokio_stream::wrappers::ReceiverStream;
use tonic::Code;
use tonic::Request as TonicRequest;
use tonic::Status as TonicStatus;
use tonic::Streaming;
use tonic::body::Body;
use tonic::client::Grpc;
use tonic::client::GrpcService;
use tonic::codec::Codec;
use tonic::codec::Decoder;
use tonic::codec::EncodeBuf;
use tonic::codec::Encoder;
use tonic::metadata::MetadataMap as TonicMeta;
use tower::ServiceBuilder;
use tower::buffer::Buffer;
use tower::buffer::future::ResponseFuture as BufferResponseFuture;
use tower::limit::ConcurrencyLimitLayer;
use tower::limit::RateLimitLayer;
use tower::util::BoxService;
use tower_service::Service as TowerService;
use crate::StatusCodeError;
use crate::StatusError;
use crate::client::CallOptions;
use crate::client::Invoke;
use crate::client::RecvStream;
use crate::client::ResponseStreamItem;
use crate::client::SendOptions;
use crate::client::SendStream;
use crate::client::name_resolution::TCP_IP_NETWORK_TYPE;
use crate::client::name_resolution::UNIX_NETWORK_TYPE;
use crate::client::transport::SecurityOpts;
use crate::client::transport::Transport;
use crate::client::transport::TransportOptions;
use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY;
use crate::core::RecvMessage;
use crate::core::RequestHeaders;
use crate::core::ResponseHeaders;
use crate::core::SendMessage;
use crate::core::Trailers;
use crate::credentials::client::DynClientConnectionSecurityInfo;
use crate::credentials::dyn_wrapper::DynChannelCredentials;
use crate::rt::BoxedTaskHandle;
use crate::rt::GrpcRuntime;
use crate::rt::TcpOptions;
use crate::rt::UnixSocketOptions;
use crate::rt::hyper_wrapper::HyperCompatExec;
use crate::rt::hyper_wrapper::HyperCompatTimer;
use crate::rt::hyper_wrapper::HyperStream;
use crate::status::Status;
#[cfg(test)]
mod test;
const DEFAULT_BUFFER_SIZE: usize = 1024;
pub(crate) type BoxError = Box<dyn Error + Send + Sync>;
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T, TonicStatus>> + Send>>;
/// Registers this module's transport with the global transport registry,
/// once for the TCP/IP network type and once for the Unix-socket type.
pub(crate) fn reg() {
    let transports = [
        (TCP_IP_NETWORK_TYPE, NetworkType::Tcp),
        (UNIX_NETWORK_TYPE, NetworkType::Unix),
    ];
    for (name, network_type) in transports {
        GLOBAL_TRANSPORT_REGISTRY.add_transport(name, TransportBuilder { network_type });
    }
}
/// Kind of socket a [`TransportBuilder`] opens in `connect`.
#[derive(Clone, Copy, Debug)]
enum NetworkType {
    /// The address is parsed as a `SocketAddr` and a TCP stream is opened.
    Tcp,
    /// The address is used as a filesystem path to a Unix-domain socket.
    Unix,
}

/// `Transport` implementation registered (per network type) by `reg`.
struct TransportBuilder {
    /// Selects how `connect` interprets its address argument.
    network_type: NetworkType,
}
/// A connected transport: a tonic gRPC client over the buffered HTTP/2
/// sender, plus the task that drives the underlying hyper connection.
struct TonicTransport {
    /// Cloned for every call started by `invoke`.
    grpc: Grpc<TonicService>,
    /// Handle of the task running the hyper connection future.
    task_handle: BoxedTaskHandle,
    /// Runtime used to spawn the per-call tasks.
    runtime: GrpcRuntime,
}

impl Drop for TonicTransport {
    // Abort the connection task so the connection does not outlive the
    // transport that owns it.
    fn drop(&mut self) {
        let Self { task_handle, .. } = self;
        task_handle.abort();
    }
}
impl Invoke for TonicTransport {
    type SendStream = TonicSendStream;
    type RecvStream = TonicRecvStream;

    /// Starts a streaming call described by `headers`.
    ///
    /// Request messages written to the returned send half travel through a
    /// bounded channel (capacity 1) into the request body. A spawned task
    /// drives `Grpc::streaming` and hands its result to the receive half
    /// through a oneshot channel. The receive half can cut the request stream
    /// short through a shared `Notify`.
    ///
    /// If the method path is not a valid `PathAndQuery`, or the service does
    /// not become ready, both halves are returned already failed (see
    /// `err_streams`). `options` is not consulted by this transport.
    async fn invoke(
        &self,
        headers: RequestHeaders,
        options: CallOptions,
    ) -> (Self::SendStream, Self::RecvStream) {
        let (outbound_tx, outbound_rx) = mpsc::channel(1);
        let stop = Arc::new(Notify::new());
        // The request body ends either when the sender is dropped or when
        // `stop` is notified, whichever comes first.
        let outbound =
            ReceiverStream::new(outbound_rx).take_until(Arc::clone(&stop).notified_owned());

        let (method, metadata) = headers.into_parts();
        let mut request = TonicRequest::new(Box::pin(outbound));
        *request.metadata_mut() = metadata.into();

        let path = match PathAndQuery::from_maybe_shared(method) {
            Ok(path) => path,
            Err(_) => {
                return err_streams(StatusError::new(StatusCodeError::Internal, "invalid path"));
            }
        };

        let mut grpc = self.grpc.clone();
        if let Err(e) = grpc.ready().await {
            let reason = format!("Service was not ready: {e}");
            return err_streams(StatusError::new(StatusCodeError::Unavailable, reason));
        }

        let (headers_tx, headers_rx) = oneshot::channel();
        let call = async move {
            let outcome = grpc.streaming(request, path, BufCodec {}).await;
            // The receive half may already be gone; nothing to do then.
            let _ = headers_tx.send(outcome);
        };
        self.runtime.spawn(Box::pin(call));

        let send_half = TonicSendStream {
            sender: Ok(outbound_tx),
        };
        let recv_half = TonicRecvStream {
            state: StreamState::AwaitingHeaders(headers_rx),
            stop_notify: Some(stop),
        };
        (send_half, recv_half)
    }
}
/// Converts a tonic `Status` (plus optional trailing metadata) into a
/// `Trailers` item.
///
/// `Code::Ok` maps to `Ok(())`. Any other code maps to a `StatusError` with
/// the same numeric code and message.
///
/// When `md` is `None`, the status's own metadata is used instead, if it has
/// any. tonic moves a response's trailers into the `Status` when it turns a
/// non-OK `grpc-status` into an error, so after such an error
/// `Streaming::trailers()` has nothing left to return. Without this fallback
/// the server's trailing metadata would be silently lost. Statuses created
/// locally carry empty metadata, so they are reported without metadata as
/// before.
fn trailers_from_tonic_status(status: TonicStatus, md: Option<TonicMeta>) -> ResponseStreamItem {
    let md = md.or_else(|| {
        let carried = status.metadata();
        (!carried.is_empty()).then(|| carried.clone())
    });
    let status_res = match status.code() {
        Code::Ok => Ok(()),
        code => Err(StatusError::new(
            StatusCodeError::from(code as i32),
            status.message(),
        )),
    };
    trailers_from_status(status_res, md)
}
/// Builds a `Trailers` item from a final status and optional tonic metadata.
///
/// If the metadata cannot be converted, the given status is replaced by an
/// `Internal` error that describes the conversion failure.
fn trailers_from_status(status: Status, md: Option<TonicMeta>) -> ResponseStreamItem {
    let trailers = match md {
        None => Trailers::new(status),
        Some(raw) => match raw.try_into() {
            Ok(metadata) => Trailers::new(status).with_metadata(metadata),
            Err(e) => Trailers::new(Err(StatusError::new(
                StatusCodeError::Internal,
                format!("failed to parse metadata: {e}"),
            ))),
        },
    };
    ResponseStreamItem::Trailers(trailers)
}
/// Send half of a call started by `TonicTransport::invoke`.
///
/// `sender` is `Err(())` once a final message has been sent, and from the
/// start for calls that could not be started (see `err_streams`).
struct TonicSendStream {
    sender: Result<mpsc::Sender<Box<dyn Buf + Send + Sync>>, ()>,
}

impl SendStream for TonicSendStream {
    /// Encodes `msg` and queues it on the request body.
    ///
    /// Fails if the stream is already closed, encoding fails, or the request
    /// body's receiver is gone. A successful send with `options.final_msg`
    /// closes the stream by dropping the sender, which ends the request body.
    async fn send(&mut self, msg: &dyn SendMessage, options: SendOptions) -> Result<(), ()> {
        let Ok(tx) = &self.sender else {
            return Err(());
        };
        let Ok(buf) = msg.encode() else {
            return Err(());
        };
        if tx.send(buf).await.is_err() {
            return Err(());
        }
        if options.final_msg {
            self.sender = Err(());
        }
        Ok(())
    }
}
/// Receive half of a call started by `TonicTransport::invoke`.
struct TonicRecvStream {
    /// How far the response has progressed.
    state: StreamState,
    /// Used to stop the outbound request stream. Taken (left `None`) after a
    /// decode failure; `None` from the start for streams built by
    /// `err_streams`.
    stop_notify: Option<Arc<Notify>>,
}

/// Progress of a call's response, as driven by `TonicRecvStream::recv`.
enum StreamState {
    /// The call could not be started; reported once as trailers.
    Error(StatusError),
    /// Waiting for the spawned call task to deliver the response (headers
    /// plus body stream) or a failure status.
    AwaitingHeaders(oneshot::Receiver<Result<tonic::Response<Streaming<Bytes>>, TonicStatus>>),
    /// Headers have been reported; messages are read from the body.
    Streaming(Streaming<Bytes>),
    /// Terminal: every further `recv` reports `StreamClosed`.
    Closed,
}
impl RecvStream for TonicRecvStream {
    /// Advances the response side of the call by one step.
    ///
    /// The first successful call yields the response headers. Each later call
    /// decodes one message into `msg` and yields `Message`, until the body
    /// ends and trailers are yielded. Failures are reported as trailers: the
    /// call task going away, an error status, or undecodable headers or
    /// messages. After trailers, every further call yields `StreamClosed`.
    async fn recv(&mut self, msg: &mut dyn RecvMessage) -> ResponseStreamItem {
        // Move the state out, leaving `Closed`; only paths that can make
        // further progress store a live state back.
        match std::mem::replace(&mut self.state, StreamState::Closed) {
            StreamState::Closed => ResponseStreamItem::StreamClosed,
            StreamState::Error(error) => ResponseStreamItem::Trailers(Trailers::new(Err(error))),
            StreamState::AwaitingHeaders(rx) => {
                let response = match rx.await {
                    Ok(Ok(response)) => response,
                    Ok(Err(status)) => return trailers_from_tonic_status(status, None),
                    // The task driving the call went away without a result.
                    Err(_) => {
                        return trailers_from_status(
                            Err(StatusError::new(StatusCodeError::Unknown, "Task cancelled")),
                            None,
                        );
                    }
                };
                let (metadata, stream, _extensions) = response.into_parts();
                match metadata.try_into() {
                    Ok(md) => {
                        self.state = StreamState::Streaming(stream);
                        ResponseStreamItem::Headers(ResponseHeaders::new().with_metadata(md))
                    }
                    Err(e) => self.decode_failure(e),
                }
            }
            StreamState::Streaming(mut stream) => match stream.message().await {
                Ok(Some(mut buf)) => match msg.decode(&mut buf) {
                    Ok(()) => {
                        self.state = StreamState::Streaming(stream);
                        ResponseStreamItem::Message
                    }
                    Err(e) => self.decode_failure(e),
                },
                Ok(None) => {
                    let md = stream.trailers().await.unwrap_or_default();
                    trailers_from_status(Ok(()), md)
                }
                Err(status) => {
                    let md = stream.trailers().await.unwrap_or_default();
                    trailers_from_tonic_status(status, md)
                }
            },
        }
    }
}

impl TonicRecvStream {
    /// Stops the outbound request stream (if not already stopped) and builds
    /// `Internal` trailers for a response that could not be decoded.
    fn decode_failure(&mut self, e: impl std::fmt::Display) -> ResponseStreamItem {
        if let Some(notify) = self.stop_notify.take() {
            notify.notify_one();
        }
        trailers_from_status(
            Err(StatusError::new(
                StatusCodeError::Internal,
                format!("error decoding response: {e}"),
            )),
            None,
        )
    }
}
impl Drop for TonicRecvStream {
    // Dropping the receive half also ends the outbound request stream,
    // unless `stop_notify` was already consumed or never set.
    fn drop(&mut self) {
        if let Some(notify) = self.stop_notify.as_deref() {
            notify.notify_one();
        }
    }
}
/// Builds a pre-failed pair of streams for a call that never started.
///
/// The send half is already closed. The receive half reports `status` as
/// trailers on its first `recv` and owns no stop signal.
fn err_streams(status: StatusError) -> (TonicSendStream, TonicRecvStream) {
    let send_half = TonicSendStream { sender: Err(()) };
    let recv_half = TonicRecvStream {
        stop_notify: None,
        state: StreamState::Error(status),
    };
    (send_half, recv_half)
}
impl Transport for TransportBuilder {
    type Service = TonicTransport;

    /// Opens a new HTTP/2 connection to `address` and wraps it as a
    /// [`TonicTransport`].
    ///
    /// For [`NetworkType::Tcp`], `address` is parsed as a `SocketAddr`. For
    /// [`NetworkType::Unix`], it is used as a filesystem path. The raw stream
    /// is passed through the configured credentials' handshake, then through
    /// hyper's HTTP/2 handshake configured from `opts`.
    ///
    /// On success, returns three things:
    /// - the service;
    /// - the security info produced by the credentials handshake;
    /// - a receiver that reports how the spawned connection task ended
    ///   (`Ok(())`, or `Err` with the hyper error text).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message in any of these cases:
    /// - the authority does not form a valid URI;
    /// - the address cannot be parsed;
    /// - `opts.connect_deadline` elapses while opening the raw stream;
    /// - a handshake fails.
    async fn connect(
        &self,
        address: String,
        runtime: GrpcRuntime,
        security_info: &SecurityOpts,
        opts: &TransportOptions,
    ) -> Result<
        (
            Self::Service,
            DynClientConnectionSecurityInfo,
            oneshot::Receiver<Result<(), String>>,
        ),
        String,
    > {
        // Build the origin first so that an invalid authority is rejected
        // before any socket is opened, any handshake is run or any connection
        // task is spawned.
        let authority = security_info.authority.host_port_string();
        let uri = Uri::from_maybe_shared(format!("http://{authority}"))
            .map_err(|e| format!("failed to create URL with authority {authority}: {e}"))?;

        // hyper's builder setters return `&mut Builder`; clone to own the
        // configured builder so the optional settings can be applied below.
        let mut settings = Builder::<HyperCompatExec>::new(HyperCompatExec {
            inner: runtime.clone(),
        })
        .timer(HyperCompatTimer {
            inner: runtime.clone(),
        })
        .initial_stream_window_size(opts.init_stream_window_size)
        .initial_connection_window_size(opts.init_connection_window_size)
        .keep_alive_interval(opts.http2_keep_alive_interval)
        .clone();
        if let Some(val) = opts.http2_keep_alive_timeout {
            settings.keep_alive_timeout(val);
        }
        if let Some(val) = opts.http2_keep_alive_while_idle {
            settings.keep_alive_while_idle(val);
        }
        if let Some(val) = opts.http2_adaptive_window {
            settings.adaptive_window(val);
        }
        if let Some(val) = opts.http2_max_header_list_size {
            settings.max_header_list_size(val);
        }

        let transport_fut = match self.network_type {
            NetworkType::Tcp => {
                let addr: SocketAddr =
                    SocketAddr::from_str(&address).map_err(|err| err.to_string())?;
                runtime.tcp_stream(
                    addr,
                    TcpOptions {
                        enable_nodelay: opts.tcp_nodelay,
                        keepalive: opts.tcp_keepalive,
                    },
                )
            }
            NetworkType::Unix => {
                runtime.unix_stream(PathBuf::from(&address), UnixSocketOptions::default())
            }
        };

        // NOTE: the deadline bounds only opening the raw stream; the
        // credentials and HTTP/2 handshakes below are not covered by it.
        let transport = if let Some(deadline) = opts.connect_deadline {
            let timeout = deadline.saturating_duration_since(Instant::now());
            tokio::select! {
                _ = runtime.sleep(timeout) => {
                    return Err("timed out waiting for transport stream to connect".to_string());
                }
                transport = transport_fut => transport?,
            }
        } else {
            transport_fut.await?
        };

        let handshake_output = security_info
            .credentials
            .dyn_connect(
                &security_info.authority,
                transport,
                &security_info.handshake_info,
                &runtime,
            )
            .await?;
        let transport = HyperStream::new(handshake_output.endpoint);
        let (sender, connection) = settings
            .handshake(transport)
            .await
            .map_err(|err| err.to_string())?;

        // Drive the connection on its own task and report how it ended.
        let (tx, rx) = oneshot::channel();
        let task_handle = runtime.spawn(Box::pin(async move {
            if let Err(err) = connection.await {
                let _ = tx.send(Err(err.to_string()));
            } else {
                let _ = tx.send(Ok(()));
            }
        }));

        // Optional concurrency/rate limits, then a `Buffer` so the service can
        // be cloned per call; its worker runs on a detached task.
        let sender = SendRequestWrapper::from(sender);
        let service = ServiceBuilder::new()
            .option_layer(opts.concurrency_limit.map(ConcurrencyLimitLayer::new))
            .option_layer(opts.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
            .map_err(Into::<BoxError>::into)
            .service(sender);
        let service = BoxService::new(service);
        let (service, worker) = Buffer::pair(service, DEFAULT_BUFFER_SIZE);
        runtime.spawn(Box::pin(worker));

        let grpc = Grpc::with_origin(TonicService { inner: service }, uri);
        let service = TonicTransport {
            grpc,
            task_handle,
            runtime,
        };
        Ok((service, handshake_output.security, rx))
    }
}
struct SendRequestWrapper {
inner: SendRequest<Body>,
}
impl From<SendRequest<Body>> for SendRequestWrapper {
fn from(inner: SendRequest<Body>) -> Self {
Self { inner }
}
}
impl TowerService<HttpRequest<Body>> for SendRequestWrapper {
type Response = HttpResponse<Body>;
type Error = BoxError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: http::Request<Body>) -> Self::Future {
let fut = self.inner.send_request(req);
Box::pin(async move { fut.await.map_err(Into::into).map(|res| res.map(Body::new)) })
}
}
#[derive(Clone)]
struct TonicService {
inner: Buffer<http::Request<Body>, BoxFuture<'static, Result<http::Response<Body>, BoxError>>>,
}
impl GrpcService<Body> for TonicService {
type ResponseBody = Body;
type Error = BoxError;
type Future = ResponseFuture;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
tower::Service::poll_ready(&mut self.inner, cx)
}
fn call(&mut self, request: http::Request<Body>) -> Self::Future {
ResponseFuture {
inner: tower::Service::call(&mut self.inner, request),
}
}
}
pub(crate) struct ResponseFuture {
inner: BufferResponseFuture<BoxFuture<'static, Result<HttpResponse<Body>, BoxError>>>,
}
impl Future for ResponseFuture {
type Output = Result<http::Response<Body>, BoxError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx)
}
}
/// Pass-through codec.
///
/// Outgoing messages are already-encoded `Buf`s and incoming messages are
/// returned as raw `Bytes`. Serialization is left to the caller's
/// `SendMessage`/`RecvMessage` implementations.
pub(crate) struct BufCodec {}

impl Codec for BufCodec {
    type Decode = Bytes;
    type Decoder = BytesDecoder;
    type Encode = Box<dyn Buf + Send + Sync>;
    type Encoder = BufEncoder;

    fn decoder(&mut self) -> Self::Decoder {
        BytesDecoder {}
    }

    fn encoder(&mut self) -> Self::Encoder {
        BufEncoder {}
    }
}
/// Encoder that copies a `Bytes` message into the outgoing buffer unchanged.
pub struct BytesEncoder {}

impl Encoder for BytesEncoder {
    type Item = Bytes;
    type Error = TonicStatus;

    fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
        let payload: &[u8] = &item;
        dst.put_slice(payload);
        Ok(())
    }
}

/// Encoder that drains a boxed `Buf` message into the outgoing buffer.
pub struct BufEncoder {}

impl Encoder for BufEncoder {
    type Item = Box<dyn Buf + Send + Sync>;
    type Error = TonicStatus;

    fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
        let mut source = item;
        dst.put(&mut *source);
        Ok(())
    }
}
/// Decoder that returns a message's bytes as-is.
#[derive(Debug)]
pub struct BytesDecoder {}

impl Decoder for BytesDecoder {
    type Item = Bytes;
    type Error = TonicStatus;

    /// Copies every remaining byte of the buffer tonic provides into a
    /// `Bytes`. An empty buffer yields an empty message, not `None`.
    fn decode(
        &mut self,
        src: &mut tonic::codec::DecodeBuf<'_>,
    ) -> Result<Option<Self::Item>, Self::Error> {
        let len = src.remaining();
        let message = src.copy_to_bytes(len);
        Ok(Some(message))
    }
}