websocket-sans-io 0.1.0

Low-level WebSocket library
Documentation
use std::net::SocketAddr;

use http_body_util::Empty;
use hyper::body::Bytes;
use rand::Rng;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use websocket_sans_io::{
    FrameInfo, Opcode, WebsocketFrameDecoder, WebsocketFrameEncoder, WebsocketFrameEvent,
};

#[path="../src/tokiort.rs"]
mod tokiort;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let addr: SocketAddr = "127.0.0.1:1234".parse()?;
    let s = tokio::net::TcpStream::connect(addr).await?;
    let s = tokiort::TokioIo::new(s);
    let b = hyper::client::conn::http1::Builder::new();
    let (mut sr, conn) = b.handshake::<_, Empty<Bytes>>(s).await?;
    tokio::spawn(async move {
        match conn.with_upgrades().await {
            Ok(()) => (),
            Err(e) => {
                eprintln!("Error from connection: {e}");
            }
        }
    });

    let rq = hyper::Request::builder()
        .uri("/")
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", "wmnj1sVQ7pHv1bVR/wraDw==")
        .body(Empty::new())?;

    let resp = sr.send_request(rq).await?;

    let upg = hyper::upgrade::on(resp).await?;
    let Ok(parts) = upg.downcast::<tokiort::TokioIo<tokio::net::TcpStream>>() else {
        return Err("Failed to downcast".into());
    };
    let mut s = parts.io.inner();
    let debt = parts.read_buf;

    let mut buf = Vec::<u8>::with_capacity(debt.len().max(4096));
    buf.extend_from_slice(&debt[..]);

    let mut frame_decoder = WebsocketFrameDecoder::new();
    let mut frame_encoder = WebsocketFrameEncoder::new();
    let mut bufptr = 0;

    let mut error = false;

    println!("Connected to a WebSocket");

    loop {
        let bufslice = &mut buf[bufptr..];
        let ret = frame_decoder.add_data(bufslice).unwrap();
        bufptr += ret.consumed_bytes;
        if let Some(ref ev) = ret.event {
            match ev {
                WebsocketFrameEvent::Start{frame_info: mut fi, ..} => {
                    if !fi.is_reasonable() {
                        println!("Unreasonable frame");
                        error = true;
                        break;
                    }
                    if fi.mask.is_some() {
                        println!("Masked frame while expected unmasked one");
                        error = true;
                        break;
                    }
                    println!(
                        "Frame {:?} with payload length {}{}",
                        fi.opcode,
                        fi.payload_length,
                        if fi.fin { "" } else { " (non-final)" }
                    );
                    if fi.opcode == Opcode::Ping {
                        fi.opcode = Opcode::Pong;
                    }
                    if fi.opcode == Opcode::ConnectionClose {
                        break;
                    }

                    fi.mask = Some(rand::thread_rng().gen());
                    let header = frame_encoder.start_frame(&fi);
                    s.write_all(&header[..]).await?;
                }
                WebsocketFrameEvent::PayloadChunk{original_opcode: _} => {
                    let payload_slice = &mut bufslice[0..ret.consumed_bytes];
                    frame_encoder.transform_frame_payload(payload_slice);
                    s.write_all(payload_slice).await?;
                }
                WebsocketFrameEvent::End{..} => (),
            }
        }
        if ret.consumed_bytes == 0 && ret.event.is_none() {
            bufptr = 0;
            buf.resize(buf.capacity(), 0);
            let n = s.read(&mut buf[..]).await?;
            if n == 0 {
                println!("EOF");
                error = true;
                break;
            }
            buf.resize(n, 0);
        }
    }

    let header = frame_encoder.start_frame(&FrameInfo {
        opcode: Opcode::ConnectionClose,
        payload_length: 2,
        mask: Some(rand::thread_rng().gen()),
        fin: true,
        reserved: 0,
    });
    s.write_all(&header[..]).await?;
    let mut last_buf : [u8; 2] = if error { 1002u16 } else { 1000u16 }.to_be_bytes();
    frame_encoder.transform_frame_payload(&mut last_buf[..]);
    s.write_all(&last_buf[..]).await?;

    println!("Finished");

    Ok(())
}