#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use std::{
io::{self, Read, Result, Write},
net::{SocketAddr, TcpStream},
};
use tcp_lib::{read, write};
use tracing::{debug, instrument};
/// Standard-library (`std::net`) I/O handler for the `tcp_lib` read and
/// write flows.
///
/// The handler owns exactly one [`TcpStream`], and every flow step it
/// services is performed directly on that stream.
#[derive(Debug)]
pub struct Handler {
    /// Underlying TCP stream that flow buffers are read from and written to.
    stream: TcpStream,
}
impl Handler {
    /// Resolves `host`/`port`, opens a TCP connection and wraps it in a new
    /// [`Handler`].
    ///
    /// # Errors
    ///
    /// Propagates the I/O error returned by [`TcpStream::connect`] when the
    /// address cannot be resolved or no connection can be established.
    #[instrument("tcp/std/new", skip_all)]
    pub fn new(host: impl AsRef<str>, port: u16) -> Result<Self> {
        let host = host.as_ref();
        debug!(?host, port, "connecting TCP stream");

        // The "connected" event is only emitted on the success path.
        TcpStream::connect((host, port)).map(|stream| {
            debug!("connected");
            Self { stream }
        })
    }

    /// Performs a single `read` from the stream into the read flow's buffer,
    /// then records on the flow state how many bytes were read.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Read::read`]; in that case the byte count
    /// on the state is not updated.
    #[instrument(skip_all)]
    pub fn read(&mut self, mut flow: impl AsMut<read::State>) -> Result<()> {
        let state = flow.as_mut();
        let buffer = state.get_buffer_mut();
        let count = self.stream.read(buffer)?;
        state.set_bytes_count(count);
        Ok(())
    }

    /// Performs a single `write` of the write flow's buffer to the stream,
    /// then records on the flow state how many bytes were accepted.
    ///
    /// A single `write` call may accept only part of the buffer. The actual
    /// count is recorded rather than retried here, presumably so the flow
    /// can handle partial writes (TODO confirm against `tcp_lib`).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Write::write`]; in that case the byte
    /// count on the state is not updated.
    #[instrument(skip_all)]
    pub fn write(&mut self, mut flow: impl AsMut<write::State>) -> Result<()> {
        let state = flow.as_mut();
        let buffer = state.get_buffer();
        let count = self.stream.write(buffer)?;
        state.set_bytes_count(count);
        Ok(())
    }
}
impl From<TcpStream> for Handler {
fn from(stream: TcpStream) -> Self {
Self { stream }
}
}
impl TryFrom<SocketAddr> for Handler {
type Error = io::Error;
fn try_from(addr: SocketAddr) -> io::Result<Self> {
let host = addr.ip();
let port = addr.port();
debug!(?host, port, "connecting TCP stream");
let stream = TcpStream::connect(addr)?;
debug!("connected");
Ok(Self { stream })
}
}
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::{TcpListener, TcpStream},
        thread,
    };

    use tcp_lib::{read, write};

    use crate::Handler;

    /// Returns a connected `(client, server)` pair of loopback TCP streams.
    fn new_tcp_stream_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = thread::spawn(move || listener.accept().unwrap().0);
        let client = TcpStream::connect(addr).unwrap();
        let server = accept.join().unwrap();
        (client, server)
    }

    #[test]
    fn read() {
        let (mut client, server) = new_tcp_stream_pair();
        let mut handler = Handler::from(server);

        let written_bytes = b"data".to_vec();
        // `write` may send only part of the buffer; `write_all` guarantees
        // every byte is handed to the socket before the handler reads.
        client.write_all(&written_bytes).unwrap();

        let mut flow = read::Flow::new();
        let read_bytes: Vec<u8> = loop {
            match flow.next() {
                Ok(bytes) => {
                    break bytes.to_vec();
                }
                Err(read::Io::Read) => {
                    handler.read(&mut flow).unwrap();
                }
            }
        };

        assert_eq!(written_bytes, read_bytes)
    }

    #[test]
    fn read_chunks() {
        let (mut client, server) = new_tcp_stream_pair();
        let mut handler = Handler::from(server);

        let written_bytes = b"big data ended by dollar$".to_vec();
        client.write_all(&written_bytes).unwrap();

        // A tiny capacity forces the payload to arrive over several reads.
        let mut flow = read::Flow::with_capacity(3);
        let mut read_bytes = Vec::new();

        loop {
            let bytes = match flow.next() {
                Ok(bytes) => bytes.to_vec(),
                Err(read::Io::Read) => {
                    handler.read(&mut flow).unwrap();
                    continue;
                }
            };

            read_bytes.extend(bytes);

            // `$` marks the end of the payload.
            if let Some(b'$') = read_bytes.last() {
                break;
            }
        }

        assert_eq!(written_bytes, read_bytes);
    }

    #[test]
    fn write() {
        let (mut client, server) = new_tcp_stream_pair();
        let mut handler = Handler::from(server);

        let mut flow = write::Flow::new(b"data".to_vec());
        let written_bytes: Vec<u8> = loop {
            match flow.next() {
                Ok(bytes) => {
                    break bytes.to_vec();
                }
                Err(write::Io::Write) => {
                    handler.write(&mut flow).unwrap();
                }
            }
        };

        // A single `read` may return fewer than 4 bytes; `read_exact` waits
        // until the whole buffer is filled.
        let mut read_bytes = [0; 4];
        client.read_exact(&mut read_bytes).unwrap();

        assert_eq!(written_bytes, read_bytes)
    }
}