hyperdriver 0.10.0

The missing middle for Hyper - Servers and Clients with ergonomic APIs
Documentation
#![allow(missing_docs)]

use std::pin::pin;

use futures_util::StreamExt as _;
use hyper::upgrade::Upgraded;
use hyperdriver::{bridge::io::TokioIo, client::conn::transport::duplex::DuplexTransport};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};

const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

async fn server_for_client_upgrade(
) -> Result<DuplexTransport, Box<dyn std::error::Error + Send + Sync>> {
    let (tx, incoming) = hyperdriver::stream::duplex::pair();

    let acceptor: hyperdriver::server::conn::Acceptor =
        hyperdriver::server::conn::Acceptor::from(incoming);

    tokio::spawn(tokio::time::timeout(
        TIMEOUT,
        serve_one_h1_upgrade(acceptor),
    ));

    Ok(DuplexTransport::new(1024, tx))
}

async fn serve_one_h1_upgrade(
    acceptor: hyperdriver::server::conn::Acceptor,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut acceptor = pin!(acceptor);
    let stream = acceptor.next().await.ok_or("no connection")??;

    let service = hyper::service::service_fn(upgrade_svc);

    let conn =
        hyper::server::conn::http1::Builder::new().serve_connection(TokioIo::new(stream), service);

    conn.with_upgrades().await?;

    Ok(())
}

async fn upgrade_svc(
    mut request: http::Request<hyper::body::Incoming>,
) -> Result<http::Response<hyperdriver::Body>, Box<dyn std::error::Error + Send + Sync>> {
    if !request.headers().contains_key(http::header::UPGRADE) {
        return Ok(http::Response::builder()
            .status(http::StatusCode::BAD_REQUEST)
            .body(hyperdriver::Body::empty())?);
    }

    tokio::spawn(tokio::time::timeout(TIMEOUT, async move {
        let upgraded = hyper::upgrade::on(&mut request)
            .await
            .expect("[server] upgrade erorr");
        server_upgraded_io(upgraded)
            .await
            .expect("[server] upgraded protocol error");
    }));

    Ok(http::Response::builder()
        .status(http::StatusCode::SWITCHING_PROTOCOLS)
        .body(hyperdriver::Body::empty())?)
}

async fn server_upgraded_io(
    upgraded: Upgraded,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut upgraded = TokioIo::new(upgraded);
    let mut vec = vec![0; 5];
    upgraded.read_exact(&mut vec).await?;
    let req = String::from_utf8_lossy(&vec);
    println!("[server] client sent {req:?}");
    if req != "hello" {
        println!("[server] got unexpected request");
    }
    upgraded.write_all(b"world").await?;
    println!("[server] sent response");
    Ok(())
}

async fn clinet_upgraded_io(mut response: http::Response<hyperdriver::Body>) {
    if response.status() != http::StatusCode::SWITCHING_PROTOCOLS {
        panic!("Server didn't upgrade: {}", response.status());
    }

    let upgraded = hyper::upgrade::on(&mut response)
        .await
        .expect("upgrade error");
    let mut upgraded = TokioIo::new(upgraded);
    upgraded.write_all(b"hello").await.unwrap();
    println!("[client] sent hello");
    let mut vec = vec![0; 5];
    upgraded.read_exact(&mut vec).await.unwrap();
    let res = String::from_utf8_lossy(&vec);
    println!("[client] got {res:?}");
    assert_eq!(res, "world");
}

#[tokio::test]
async fn client_auto() {
    let transport = server_for_client_upgrade().await.unwrap();

    let mut client = hyperdriver::Client::builder()
        .with_auto_http()
        .with_default_pool()
        .with_transport(transport)
        .build();

    let request = http::Request::get("http://example.org")
        .header(http::header::UPGRADE, "test-hyperdriver")
        .body(hyperdriver::Body::empty())
        .unwrap();

    let response = client.request(request).await.unwrap();
    clinet_upgraded_io(response).await;
}

#[tokio::test]
async fn client_http1() {
    let transport = server_for_client_upgrade().await.unwrap();

    let mut client = hyperdriver::Client::builder()
        .with_protocol(hyper::client::conn::http1::Builder::new())
        .with_default_pool()
        .with_transport(transport)
        .build();

    let request = http::Request::get("http://example.org")
        .header(http::header::UPGRADE, "test-hyperdriver")
        .body(hyperdriver::Body::empty())
        .unwrap();

    let response = client.request(request).await.unwrap();

    clinet_upgraded_io(response).await;
}