use std::io;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use futures::future;
use futures::stream;
use futures::stream::StreamExt as _;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tonic::transport::Channel;
use tonic::transport::Endpoint;
use tonic::transport::Uri;
use tonic::transport::server::Connected;
use tonic::transport::server::Router;
use tower::service_fn;
/// Pairs an independent read half with an independent write half so the two
/// can be carried around (and driven) as a single value.
///
/// No trait bounds are placed on the type itself; each trait implementation
/// states the bounds it needs.
pub struct DuplexChannel<R, W> {
    /// Half that incoming bytes are read from.
    read: R,
    /// Half that outgoing bytes are written to.
    write: W,
}

impl<R, W> DuplexChannel<R, W> {
    /// Builds a channel from an already-constructed reader and writer.
    ///
    /// Both halves are moved into the channel unchanged.
    pub fn new(read: R, write: W) -> Self {
        DuplexChannel { read, write }
    }
}
/// Reads are served entirely by the wrapped reader `R`.
///
/// `W` only has to be `Unpin`: that makes the whole channel `Unpin`, which is
/// what allows the pinned `self` to be turned back into a plain `&mut`.
impl<R, W> AsyncRead for DuplexChannel<R, W>
where
    R: AsyncRead + Unpin,
    W: Unpin,
{
    /// Polls the read half, filling `buf` with whatever it produces.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The channel is `Unpin`, so unwrapping the pin is safe and lets us
        // re-pin just the reader field.
        let channel = self.get_mut();
        Pin::new(&mut channel.read).poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdowns are served entirely by the wrapped writer `W`.
///
/// `R` only has to be `Unpin`: that makes the whole channel `Unpin`, which is
/// what allows `&mut self.write` to be borrowed through the pinned `self`.
impl<R, W> AsyncWrite for DuplexChannel<R, W>
where
    R: Unpin,
    W: AsyncWrite + Unpin,
{
    /// Attempts to write `buf` to the write half, returning the byte count it
    /// reports.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.write).poll_write(cx, buf)
    }

    /// Attempts a vectored write on the write half.
    ///
    /// Forwarded explicitly: the trait's provided implementation only writes
    /// the first non-empty buffer via `poll_write`, which would discard the
    /// inner writer's own vectored-write support.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.write).poll_write_vectored(cx, bufs)
    }

    /// Flushes the write half.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.write).poll_flush(cx)
    }

    /// Shuts down the write half. The read half is not touched.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.write).poll_shutdown(cx)
    }

    /// Reports whether the write half has an efficient vectored write, so
    /// callers that consult this flag see the inner writer's answer instead of
    /// the trait default of `false`.
    fn is_write_vectored(&self) -> bool {
        self.write.is_write_vectored()
    }
}
/// Lets a `DuplexChannel` be used where a `Connected` I/O type is required.
///
/// The channel exposes no per-connection metadata, so the connect info is the
/// unit type.
impl<R, W> Connected for DuplexChannel<R, W> {
    type ConnectInfo = ();

    /// Returns `()`; there is nothing to report about the connection.
    fn connect_info(&self) -> Self::ConnectInfo {}
}
/// Builds a tonic [`Channel`] that runs over the single, already-established
/// stream `io`.
///
/// `name` is only used to form the endpoint URI `http://{name}.invalid`; the
/// connector ignores the URI it is handed and never dials anything.
///
/// The stream can be handed out exactly once. Any later connection attempt by
/// the channel fails with an `io::Error`.
///
/// # Errors
///
/// Returns an error if the URI built from `name` is rejected by `Endpoint`, or
/// if establishing the channel over `io` fails.
pub async fn connect_client<T>(io: T, name: &str) -> anyhow::Result<Channel>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // One-shot slot: the first connector call empties it, later calls find `None`.
    let mut pending_stream = Some(hyper_util::rt::TokioIo::new(io));

    let endpoint = Endpoint::try_from(format!("http://{name}.invalid"))?;

    let connector = service_fn(move |_: Uri| {
        let outcome = match pending_stream.take() {
            Some(stream) => Ok(stream),
            None => Err(io::Error::other(
                "test-runner gRPC channel cannot reconnect",
            )),
        };
        future::ready(outcome)
    });

    let connected = endpoint.connect_with_connector(connector).await?;
    Ok(connected)
}
/// Serves `router` over exactly one pre-established connection, `io`, until
/// `shutdown` resolves or the server reports an error.
///
/// The incoming-connection stream yields `io` once and then stays pending
/// forever, so it never signals end-of-stream.
/// NOTE(review): presumably this keeps the server alive until `shutdown`
/// fires rather than letting it stop when the stream is exhausted — confirm
/// against tonic's `serve_with_incoming_shutdown` behavior.
///
/// # Errors
///
/// Propagates whatever `tonic::transport::Error` the server returns.
pub async fn serve_connection<I, F>(
    io: I,
    router: Router,
    shutdown: F,
) -> Result<(), tonic::transport::Error>
where
    I: AsyncRead + AsyncWrite + Connected + Send + Unpin + 'static,
    F: std::future::Future<Output = ()> + Send + 'static,
{
    // The one connection this server will ever accept.
    let only_connection = stream::once(future::ready(io));
    // A second element that never arrives, so the stream never completes.
    let never_arrives = stream::once(future::pending());

    let incoming = only_connection
        .chain(never_arrives)
        .map(Ok::<_, io::Error>);

    router
        .serve_with_incoming_shutdown(incoming, shutdown)
        .await
}