use std::net::SocketAddr;
use futures::io::{BufReader, BufWriter};
use hyper::server::conn::http1;
use hyper::{body::Bytes, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use soketto::{
handshake::http::{is_upgrade_request, Server},
BoxedError,
};
use tokio_util::compat::TokioAsyncReadCompatExt;
/// Body type of every HTTP response returned by `handler` (an `http_body_util::Full` over `Bytes`).
type FullBody = http_body_util::Full<Bytes>;
/// Binds an HTTP/1 server on `127.0.0.1:3000` and serves every accepted TCP
/// connection with [`handler`] on its own tokio task. Upgrades are enabled on each
/// connection so that WebSocket handshakes can take over the socket.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound or its local address cannot
/// be read. Failures to accept or serve individual connections are only logged
/// and do not stop the accept loop.
#[tokio::main]
async fn main() -> Result<(), BoxedError> {
    env_logger::init();

    let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
    let listener = tokio::net::TcpListener::bind(addr).await?;
    // Propagate instead of panicking: this is an I/O call and `main` already returns `Result`.
    let local_addr = listener.local_addr()?;
    log::info!(
        "Listening on http://{:?} — connect and I'll echo back anything you send!",
        local_addr
    );

    loop {
        let stream = match listener.accept().await {
            Ok((stream, peer_addr)) => {
                log::info!("Accepting new connection: {peer_addr}");
                stream
            }
            Err(e) => {
                // Keep serving other clients; one failed accept is not fatal.
                log::error!("Accepting new connection failed: {e}");
                continue;
            }
        };

        tokio::spawn(async move {
            let io = TokioIo::new(stream);
            // `with_upgrades` is required for `hyper::upgrade::on` (used by the
            // WebSocket echo task) to obtain the upgraded socket.
            let conn = http1::Builder::new()
                .serve_connection(io, service_fn(handler))
                .with_upgrades();
            if let Err(err) = conn.await {
                log::error!("HTTP connection failed {err}");
            }
        });
    }
}
async fn handler(req: Request<hyper::body::Incoming>) -> Result<hyper::Response<FullBody>, BoxedError> {
if is_upgrade_request(&req) {
let mut server = Server::new();
#[cfg(feature = "deflate")]
{
let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server);
server.add_extension(Box::new(deflate));
}
match server.receive_request(&req) {
Ok(response) => {
tokio::spawn(async move {
if let Err(e) = websocket_echo_messages(server, req).await {
log::error!("Error upgrading to websocket connection: {}", e);
}
});
Ok(response.map(|()| FullBody::default()))
}
Err(e) => {
log::error!("Could not upgrade connection: {}", e);
Ok(Response::new(FullBody::from("Something went wrong upgrading!")))
}
}
} else {
Ok(Response::new(FullBody::from("Hello HTTP!")))
}
}
async fn websocket_echo_messages(server: Server, req: Request<hyper::body::Incoming>) -> Result<(), BoxedError> {
let stream = hyper::upgrade::on(req).await?;
let io = TokioIo::new(stream);
let stream = BufReader::new(BufWriter::new(io.compat()));
let (mut sender, mut receiver) = server.into_builder(stream).finish();
let mut message = Vec::new();
loop {
message.clear();
match receiver.receive_data(&mut message).await {
Ok(soketto::Data::Binary(n)) => {
assert_eq!(n, message.len());
sender.send_binary_mut(&mut message).await?;
sender.flush().await?
}
Ok(soketto::Data::Text(n)) => {
assert_eq!(n, message.len());
if let Ok(txt) = std::str::from_utf8(&message) {
sender.send_text(txt).await?;
sender.flush().await?
} else {
break;
}
}
Err(soketto::connection::Error::Closed) => break,
Err(e) => {
eprintln!("Websocket connection error: {}", e);
break;
}
}
}
Ok(())
}