#![forbid(unsafe_code)]
use crc::{Crc, Digest, CRC_32_ISO_HDLC};
use fixed_buffer::FixedBuf;
use read_write_ext::ReadWriteExt;
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
use std::println;
use std::time::Duration;
// Shared CRC calculator built from the `CRC_32_ISO_HDLC` parameter set of the
// `crc` crate; `handle_crc32` creates a fresh digest from it for every request.
// NOTE(review): the name `X25` does not match the 32-bit ISO-HDLC parameter
// set used here — consider renaming (e.g. `CRC32`) once all users are known.
const X25: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
/// Answers a `HELLO` request by sending the greeting line `HI\n` to `writer`.
///
/// # Errors
///
/// Returns whatever I/O error `writer` reports while the greeting is written.
fn handle_hello<W: Write>(mut writer: W) -> Result<(), std::io::Error> {
    const GREETING: &[u8] = b"HI\n";
    writer.write_all(GREETING)
}
/// Adapter that exposes a borrowed CRC `Digest` as a `std::io::Write` sink.
///
/// Every byte written is fed into the digest; nothing is stored or forwarded.
struct DigestWriter<'a>(&'a mut Digest<'static, u32>);

impl Write for DigestWriter<'_> {
    /// Feeds all of `buf` into the digest and reports the whole slice as written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let Self(digest) = self;
        digest.update(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Handles a `CRC32 <len>` request: consumes `len` payload bytes from
/// `read_writer`, then writes their CRC back as one lowercase hexadecimal line.
///
/// The checksum is formatted with `{:x}`, so it is *not* zero-padded to eight
/// digits.
///
/// # Errors
///
/// * `ErrorKind::UnexpectedEof` if the stream ends before `len` bytes were
///   read. No response is written in that case, so the peer never receives a
///   checksum computed over a truncated payload.
/// * Any I/O error raised while reading the payload or writing the response.
fn handle_crc32<RW: Read + Write>(read_writer: &mut RW, len: u64) -> Result<(), std::io::Error> {
    let mut digest = X25.digest();
    // `take_rw(len)` bounds how much of the stream the copy below may consume,
    // so bytes belonging to the next request are left alone.
    let mut payload = read_writer.take_rw(len);
    // Stream the payload straight into the digest instead of buffering it.
    let copied = std::io::copy(&mut payload, &mut DigestWriter(&mut digest))?;
    // `copy` stops early at end-of-stream; a short count means the payload was cut off.
    if copied != len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            format!("CRC32 payload truncated: expected {len} bytes, got {copied}"),
        ));
    }
    let response = format!("{:x}\n", digest.finalize());
    read_writer.write_all(response.as_bytes())
}
/// A request line understood by the server.
#[derive(Debug, PartialEq)]
enum Request {
    /// `HELLO` — takes no argument.
    Hello,
    /// `CRC32 <len>` — carries the announced payload length in bytes.
    Crc32(u64),
}

impl Request {
    /// Parses a single request line.
    ///
    /// The line must match exactly: no line terminator is stripped here, and
    /// the method name is case-sensitive. Accepted forms are `HELLO` (with no
    /// argument at all) and `CRC32 <len>`, where `<len>` parses as a `u64` no
    /// larger than 1 MiB.
    ///
    /// Returns `None` for invalid UTF-8, unknown methods, a missing or
    /// superfluous argument, an unparsable length, or a length above the limit.
    pub fn parse(line_bytes: &[u8]) -> Option<Request> {
        // Largest payload length accepted for `CRC32`.
        const MAX_PAYLOAD_LEN: u64 = 1024 * 1024;

        let text = std::str::from_utf8(line_bytes).ok()?;
        // Split at the first space only: everything after it is the argument.
        match text.split_once(' ') {
            None if text == "HELLO" => Some(Request::Hello),
            Some(("CRC32", digits)) => digits
                .parse::<u64>()
                .ok()
                .filter(|&len| len <= MAX_PAYLOAD_LEN)
                .map(Request::Crc32),
            _ => None,
        }
    }
}
/// Serves one client connection: reads request lines in a loop and answers
/// each of them until `read_frame` reports that no further frame is available.
///
/// Lines that do not parse as a known request are answered with `ERROR\n`
/// and the connection stays open.
///
/// # Errors
///
/// Returns the first I/O error raised while reading a frame or while a
/// handler reads from or writes to the stream.
fn handle_conn(mut tcp_stream: TcpStream) -> Result<(), std::io::Error> {
    println!("SERVER handling connection");
    let mut buf: FixedBuf<4096> = FixedBuf::new();
    loop {
        let line_bytes = match buf.read_frame(&mut tcp_stream, fixed_buffer::deframe_line)? {
            Some(bytes) => bytes,
            // No more frames — presumably the peer closed its sending side.
            None => return Ok(()),
        };
        // Parsing yields an owned value, which ends the borrow of `buf` held
        // by `line_bytes` before `buf` is borrowed again below.
        let request = Request::parse(line_bytes);
        match request {
            None => tcp_stream.write_all(b"ERROR\n")?,
            Some(Request::Hello) => handle_hello(&mut tcp_stream)?,
            Some(Request::Crc32(len)) => {
                // Presumably payload bytes already sitting in `buf` are read
                // first, then the socket — confirm against `chain_after` docs.
                let mut read_writer = tcp_stream.chain_after(&mut buf);
                handle_crc32(&mut read_writer, len)?;
            }
        }
    }
}
/// End-to-end check of the line protocol over a real loopback TCP connection.
///
/// Starts the server on an ephemeral port, sends a `CRC32` request (with its
/// 4-byte payload) and a `HELLO` request in a single write, then verifies that
/// the server answers both, in order.
#[test]
fn main() {
    // Port 0 lets the OS pick a free port; the real address is read back below.
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap();
    let addr = listener.local_addr().unwrap();
    println!("SERVER listening on {addr}");
    // Accept loop: one detached thread per connection.
    std::thread::spawn(move || loop {
        match listener.accept() {
            Ok((tcp_stream, _addr)) => {
                std::thread::spawn(move || {
                    if let Err(e) = handle_conn(tcp_stream) {
                        // `NotFound` errors are deliberately not logged.
                        if e.kind() != ErrorKind::NotFound {
                            println!("SERVER error: {e:?}");
                        }
                    }
                });
            }
            Err(e) => {
                println!("SERVER error accepting connection: {e:?}");
                // Back off briefly so a persistent accept failure does not spin.
                std::thread::sleep(Duration::from_secs(1));
            }
        }
    });
    println!("CLIENT connecting");
    let mut tcp_stream = TcpStream::connect(addr).unwrap();
    println!("CLIENT sending two requests at the same time: CRC('aaaa') and HELLO");
    tcp_stream.write_all(b"CRC32 4\naaaaHELLO\n").unwrap();
    let mut response = String::new();
    // Half-close so the server sees end-of-stream and finishes the connection,
    // which in turn lets `read_to_string` return.
    tcp_stream.shutdown(Shutdown::Write).unwrap();
    tcp_stream.read_to_string(&mut response).unwrap();
    let lines: Vec<&str> = response.lines().collect();
    for line in &lines {
        println!("CLIENT got response {line:?}");
    }
    // Expected checksum comes from the one-shot API, independent of the
    // streaming path the server uses; formatted the same way as the server.
    let expected_crc = format!("{:x}", X25.checksum(b"aaaa"));
    assert_eq!(lines, vec![expected_crc.as_str(), "HI"]);
}