use std::io::Result as IoResult;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::sync::CancellationToken;
use crate::conn::{Holding, HttpBuilder};
use crate::fuse::{ArcFuseFactory, ArcFusewire};
use crate::http::HttpConnection;
use crate::service::HyperHandler;
use super::{Accepted, Acceptor, Listener};
/// A connection that came from one of the two sides of a joined listener.
///
/// The `AsyncRead`, `AsyncWrite` and `HttpConnection` impls for this type delegate to
/// whichever variant is held.
pub enum JoinedStream<A, B> {
    /// Connection accepted by the first (`a`) side of the join.
    A(A),
    /// Connection accepted by the second (`b`) side of the join.
    B(B),
}
impl<A, B> AsyncRead for JoinedStream<A, B>
where
    A: AsyncRead + Send + Unpin + 'static,
    B: AsyncRead + Send + Unpin + 'static,
{
    /// Reads from whichever inner stream this value wraps.
    #[inline]
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IoResult<()>> {
        // Both halves are `Unpin`, so the outer pin can be dropped and the inner stream re-pinned.
        let this = self.get_mut();
        match this {
            Self::A(inner) => Pin::new(inner).poll_read(cx, buf),
            Self::B(inner) => Pin::new(inner).poll_read(cx, buf),
        }
    }
}
impl<A, B> AsyncWrite for JoinedStream<A, B>
where
    A: AsyncWrite + Send + Unpin + 'static,
    B: AsyncWrite + Send + Unpin + 'static,
{
    /// Writes `buf` to whichever inner stream this value wraps.
    #[inline]
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        match self.get_mut() {
            JoinedStream::A(a) => Pin::new(a).poll_write(cx, buf),
            JoinedStream::B(b) => Pin::new(b).poll_write(cx, buf),
        }
    }

    /// Forwards vectored writes to the inner stream.
    ///
    /// Without this override, tokio's default implementation writes only the first
    /// non-empty buffer, even when the inner stream supports vectored writes.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<IoResult<usize>> {
        match self.get_mut() {
            JoinedStream::A(a) => Pin::new(a).poll_write_vectored(cx, bufs),
            JoinedStream::B(b) => Pin::new(b).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the inner stream has an efficient vectored-write implementation.
    ///
    /// Tokio's default returns `false`. That would hide the inner stream's capability
    /// from callers that choose a write strategy based on this flag.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        match self {
            JoinedStream::A(a) => a.is_write_vectored(),
            JoinedStream::B(b) => b.is_write_vectored(),
        }
    }

    /// Flushes whichever inner stream this value wraps.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        match self.get_mut() {
            JoinedStream::A(a) => Pin::new(a).poll_flush(cx),
            JoinedStream::B(b) => Pin::new(b).poll_flush(cx),
        }
    }

    /// Shuts down whichever inner stream this value wraps.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        match self.get_mut() {
            JoinedStream::A(a) => Pin::new(a).poll_shutdown(cx),
            JoinedStream::B(b) => Pin::new(b).poll_shutdown(cx),
        }
    }
}
/// A listener that combines two listeners, `a` and `b`, into one.
///
/// Binding it through `Listener::try_bind` binds `a` first, then `b`. The result is a
/// `JoinedAcceptor` that accepts connections from either side.
#[pin_project]
pub struct JoinedListener<A, B> {
/// First listener; bound before `b`.
#[pin]
a: A,
/// Second listener.
// NOTE(review): the `Listener` impl requires both listeners to be `Unpin`, so this
// structural pinning may be unnecessary — confirm before removing.
#[pin]
b: B,
}
impl<A, B> JoinedListener<A, B> {
#[inline]
pub fn new(a: A, b: B) -> Self {
JoinedListener { a, b }
}
}
impl<A, B> Listener for JoinedListener<A, B>
where
    A: Listener + Send + Unpin + 'static,
    B: Listener + Send + Unpin + 'static,
    A::Acceptor: Acceptor + Send + Unpin + 'static,
    B::Acceptor: Acceptor + Send + Unpin + 'static,
{
    type Acceptor = JoinedAcceptor<A::Acceptor, B::Acceptor>;

    /// Binds `a`, then `b`, and combines them into a single acceptor.
    ///
    /// Returns the first bind error encountered. The resulting acceptor's holdings are
    /// `a`'s holdings followed by `b`'s.
    async fn try_bind(self) -> crate::Result<Self::Acceptor> {
        let Self { a, b } = self;
        let first = a.try_bind().await?;
        let second = b.try_bind().await?;

        let mut holdings = first.holdings().to_vec();
        holdings.extend_from_slice(second.holdings());

        Ok(JoinedAcceptor {
            a: first,
            b: second,
            holdings,
        })
    }
}
/// Acceptor that waits on two inner acceptors at once.
///
/// Whichever inner accept `tokio::select!` picks first is returned, with its connection
/// wrapped in a `JoinedStream`.
pub struct JoinedAcceptor<A, B> {
/// First inner acceptor; its connections become `JoinedStream::A`.
a: A,
/// Second inner acceptor; its connections become `JoinedStream::B`.
b: B,
/// Holdings returned by `Acceptor::holdings`.
/// `JoinedListener::try_bind` fills this with `a`'s holdings followed by `b`'s.
/// `JoinedAcceptor::new` stores whatever it is given.
holdings: Vec<Holding>,
}
impl<A, B> JoinedAcceptor<A, B> {
pub fn new(a: A, b: B, holdings: Vec<Holding>) -> Self {
JoinedAcceptor { a, b, holdings }
}
}
impl<A, B> HttpConnection for JoinedStream<A, B>
where
    A: HttpConnection + Send,
    B: HttpConnection + Send,
{
    /// Serves the connection by delegating to whichever inner connection was accepted.
    async fn serve(
        self,
        handler: HyperHandler,
        builder: Arc<HttpBuilder>,
        graceful_stop_token: CancellationToken,
    ) -> IoResult<()> {
        match self {
            JoinedStream::A(a) => a.serve(handler, builder, graceful_stop_token).await,
            JoinedStream::B(b) => b.serve(handler, builder, graceful_stop_token).await,
        }
    }

    /// Returns the fusewire of the wrapped connection.
    ///
    /// The call is forwarded rather than answered with a constant `None`, so any fusewire
    /// the inner connection carries stays visible after wrapping. That fusewire is
    /// presumably created from the `fuse_factory` passed to `Acceptor::accept`.
    fn fusewire(&self) -> Option<ArcFusewire> {
        match self {
            JoinedStream::A(a) => a.fusewire(),
            JoinedStream::B(b) => b.fusewire(),
        }
    }
}
impl<A, B> Acceptor for JoinedAcceptor<A, B>
where
    A: Acceptor + Send + Unpin + 'static,
    B: Acceptor + Send + Unpin + 'static,
    A::Conn: HttpConnection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
    B::Conn: HttpConnection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    type Conn = JoinedStream<A::Conn, B::Conn>;

    /// Returns the stored holdings, as supplied when this acceptor was constructed.
    #[inline]
    fn holdings(&self) -> &[Holding] {
        self.holdings.as_slice()
    }

    /// Waits on both inner acceptors and returns the first result `tokio::select!` yields.
    ///
    /// The connection is tagged with the side it came from. An error from either side is
    /// returned unchanged, and the other, still-pending accept future is dropped.
    #[inline]
    async fn accept(&mut self, fuse_factory: Option<ArcFuseFactory>) -> IoResult<Accepted<Self::Conn>> {
        let Self { a, b, .. } = self;
        tokio::select! {
            from_a = a.accept(fuse_factory.clone()) => {
                from_a.map(|accepted| accepted.map_conn(JoinedStream::A))
            }
            from_b = b.accept(fuse_factory) => {
                from_b.map(|accepted| accepted.map_conn(JoinedStream::B))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    use super::*;
    use crate::conn::TcpListener;

    /// A connection made to either joined TCP listener must come out of the shared acceptor.
    #[tokio::test]
    async fn test_joined_listener() {
        let first_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6978));
        let second_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6979));
        let mut acceptor = TcpListener::new(first_addr)
            .join(TcpListener::new(second_addr))
            .bind()
            .await;

        // Client side: send one value through each listener.
        tokio::spawn(async move {
            let mut to_first = TcpStream::connect(first_addr).await.unwrap();
            to_first.write_i32(50).await.unwrap();
            let mut to_second = TcpStream::connect(second_addr).await.unwrap();
            to_second.write_i32(100).await.unwrap();
        });

        // The accept order is not fixed, so only the total of both values is checked.
        let mut total = 0;
        for _ in 0..2 {
            let Accepted { mut conn, .. } = acceptor.accept(None).await.unwrap();
            total += conn.read_i32().await.unwrap();
        }
        assert_eq!(total, 150);
    }
}