// resp-async 0.0.7 — asynchronous Redis protocol (RESP) parser.
// See the crate documentation for the server/router API exercised below.
use std::net::SocketAddr;
use std::time::Duration;

use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;

use resp_async::{Cmd, Extensions, PushHandle, Router, Server, ServerBuilder, Value, ValueDecoder};

/// Small tagged value stored in per-connection [`Extensions`]; handlers read
/// its byte back to prove the connection hook's data reaches them.
#[derive(Debug)]
struct Marker(
    /// Arbitrary tag echoed back to the client as an integer reply.
    u8,
);

/// Spawns a server for `app` on an OS-assigned loopback port.
///
/// `build` may customize the [`ServerBuilder`] before it starts serving.
///
/// Returns the bound address and a shutdown sender. Sending `()` (or dropping
/// the sender, which also resolves the receiver) triggers the server's
/// graceful shutdown. The server task is detached and its result is discarded.
async fn start_server_with<State, F>(
    app: Router<State>,
    build: F,
) -> (SocketAddr, oneshot::Sender<()>)
where
    State: Send + Sync + 'static,
    F: FnOnce(ServerBuilder) -> ServerBuilder + Send + 'static,
{
    // Port 0 lets the OS pick a free port, so concurrently running tests never clash.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = listener.local_addr().unwrap();
    let (stop_tx, stop_rx) = oneshot::channel::<()>();

    let server_task = async move {
        let builder = build(Server::bind(local.to_string()));
        let stop_signal = async move {
            let _ = stop_rx.await;
        };
        let _ = builder
            .with_graceful_shutdown(stop_signal)
            .serve_with_listener(listener, app)
            .await;
    };
    tokio::spawn(server_task);

    (local, stop_tx)
}

async fn start_server<State>(app: Router<State>) -> (SocketAddr, oneshot::Sender<()>)
where
    State: Send + Sync + 'static,
{
    start_server_with(app, |builder| builder).await
}

/// Builds a RESP command frame: an array of bulk strings holding `name`
/// followed by every entry of `args`, in order.
fn command(name: &str, args: &[&str]) -> Value {
    let bulk = |part: &str| Value::Bulk(Bytes::copy_from_slice(part.as_bytes()));
    let parts = std::iter::once(name).chain(args.iter().copied());
    Value::Array(parts.map(bulk).collect())
}

async fn read_value(
    stream: &mut TcpStream,
    rd: &mut BytesMut,
    decoder: &mut ValueDecoder,
) -> Value {
    loop {
        if let Some(value) = decoder.try_decode(rd).unwrap() {
            return value;
        }
        let n = stream.read_buf(rd).await.unwrap();
        if n == 0 {
            panic!("unexpected EOF");
        }
    }
}

#[tokio::test]
async fn responses_are_ordered_for_pipelining() {
    async fn ping() -> Value {
        Value::Simple(Bytes::from_static(b"PONG"))
    }

    let app: Router<()> = Router::new().route("PING", ping);
    let (addr, shutdown) = start_server(app).await;

    let mut stream = TcpStream::connect(addr).await.unwrap();
    let mut rd = BytesMut::with_capacity(1024);
    let mut decoder = ValueDecoder::default();
    let mut buf = BytesMut::new();
    command("PING", &[]).encode(&mut buf);
    command("PING", &[]).encode(&mut buf);
    stream.write_all(&buf).await.unwrap();

    let first = read_value(&mut stream, &mut rd, &mut decoder).await;
    let second = read_value(&mut stream, &mut rd, &mut decoder).await;

    assert_eq!(first, Value::Simple(Bytes::from_static(b"PONG")));
    assert_eq!(second, Value::Simple(Bytes::from_static(b"PONG")));

    let _ = shutdown.send(());
}

#[tokio::test]
async fn push_messages_follow_responses() {
    /// The out-of-band message the handler pushes. It is also the exact value
    /// the client must receive after the command's reply.
    fn push_payload() -> Value {
        Value::Array(vec![
            Value::Bulk(Bytes::from_static(b"message")),
            Value::Bulk(Bytes::from_static(b"chan")),
            Value::Bulk(Bytes::from_static(b"payload")),
        ])
    }

    /// Pushes a message *before* returning its own reply. The test expects the
    /// reply on the wire first, then the push.
    async fn subscribe(_cmd: Cmd, push: PushHandle) -> Value {
        // The send result is ignored; a failed push shows up as the client
        // never receiving the second value.
        let _ = push.send(push_payload()).await;
        Value::Simple(Bytes::from_static(b"OK"))
    }

    let app: Router<()> = Router::new().route("SUB", subscribe);
    let (addr, shutdown) = start_server(app).await;

    let mut stream = TcpStream::connect(addr).await.unwrap();
    let mut rd = BytesMut::with_capacity(1024);
    let mut decoder = ValueDecoder::default();
    let mut buf = BytesMut::new();
    command("SUB", &[]).encode(&mut buf);
    stream.write_all(&buf).await.unwrap();

    let first = read_value(&mut stream, &mut rd, &mut decoder).await;
    let second = read_value(&mut stream, &mut rd, &mut decoder).await;

    assert_eq!(first, Value::Simple(Bytes::from_static(b"OK")));
    // Compare the whole payload, not just its shape, so a mangled or
    // truncated push message fails the test.
    assert_eq!(second, push_payload());

    let _ = shutdown.send(());
}

#[tokio::test]
async fn connection_extensions_are_visible() {
    async fn check(ext: Extensions) -> Value {
        match ext.get::<Marker>() {
            Some(marker) => Value::Integer(i64::from(marker.0)),
            None => Value::Error(Bytes::from_static(b"ERR missing extension")),
        }
    }

    let app: Router<()> = Router::new().route("EXT", check);
    let (addr, shutdown) = start_server_with(app, |builder| {
        builder.with_connection_extensions(|_info| {
            let mut ext = Extensions::default();
            ext.insert(Marker(7));
            ext
        })
    })
    .await;

    let mut stream = TcpStream::connect(addr).await.unwrap();
    let mut rd = BytesMut::with_capacity(1024);
    let mut decoder = ValueDecoder::default();
    let mut buf = BytesMut::new();
    command("EXT", &[]).encode(&mut buf);
    stream.write_all(&buf).await.unwrap();

    let response = read_value(&mut stream, &mut rd, &mut decoder).await;
    assert_eq!(response, Value::Integer(7));

    let _ = shutdown.send(());
}

#[tokio::test]
async fn quit_closes_connection() {
    /// QUIT handler that only replies `+OK`. The test expects the socket to be
    /// closed afterwards without the handler doing anything else.
    async fn quit() -> Value {
        Value::Simple(Bytes::from_static(b"OK"))
    }

    let app: Router<()> = Router::new().route("QUIT", quit);
    let (addr, shutdown) = start_server(app).await;

    let mut stream = TcpStream::connect(addr).await.unwrap();
    let mut pending = BytesMut::with_capacity(1024);
    let mut decoder = ValueDecoder::default();

    let mut request = BytesMut::new();
    command("QUIT", &[]).encode(&mut request);
    stream.write_all(&request).await.unwrap();

    let reply = read_value(&mut stream, &mut pending, &mut decoder).await;
    assert_eq!(reply, Value::Simple(Bytes::from_static(b"OK")));

    // After the reply the peer should close its end. A read must then report
    // EOF (0 bytes) within a second instead of blocking.
    let mut probe = [0u8; 1];
    let eof_read = tokio::time::timeout(Duration::from_secs(1), stream.read(&mut probe)).await;
    let bytes_read = eof_read.expect("read timeout").unwrap();
    assert_eq!(bytes_read, 0);

    let _ = shutdown.send(());
}

#[tokio::test]
async fn graceful_shutdown_stops_server() {
    let app: Router<()> = Router::new();
    let (addr, shutdown) = start_server(app).await;

    let _ = TcpStream::connect(addr).await.unwrap();
    let _ = shutdown.send(());
}