use crate::prelude::{asserted_short_name, cross_os_fd, ConId, Framer};
use bytes::{Bytes, BytesMut};
use byteserde::utils::hex::to_hex_pretty;
use log::{debug, log_enabled};
use std::fmt::Display;
use std::io::{ErrorKind, Read, Write};
use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::{io::Error, net::TcpStream};
/// Byte count returned by a stream read that `FrameReader::read_frame` treats as end of stream.
const EOF: usize = 0;
/// Reads frames from a [`TcpStream`], using the framer `F` to cut complete frames out of an
/// internal byte buffer.
///
/// `MAX_MSG_SIZE` is both the initial capacity of the internal buffer and the size of the
/// scratch buffer used for each read from the stream.
#[derive(Debug)]
pub struct FrameReader<F: Framer, const MAX_MSG_SIZE: usize> {
/// Connection identifier; included in error messages and in the `Display` output.
pub(crate) con_id: ConId,
/// Stream that bytes are read from and that is shut down on EOF, error and drop.
pub(crate) stream_reader: TcpStream,
/// Bytes received from the stream that have not yet been returned as a frame.
buffer: BytesMut,
/// Marker tying the reader to its framer type `F` without storing a value of it.
phantom: std::marker::PhantomData<F>,
}
impl<F: Framer, const MAX_MSG_SIZE: usize> FrameReader<F, MAX_MSG_SIZE> {
    /// Creates a reader over `reader`, pre-allocating an internal buffer of `MAX_MSG_SIZE` bytes.
    pub fn new(con_id: ConId, reader: TcpStream) -> FrameReader<F, MAX_MSG_SIZE> {
        Self {
            con_id,
            stream_reader: reader,
            buffer: BytesMut::with_capacity(MAX_MSG_SIZE),
            phantom: std::marker::PhantomData,
        }
    }

    /// Returns the next frame from the stream.
    ///
    /// The framer `F` is first asked for a frame from the bytes already buffered; only when it
    /// has none is the stream read (up to `MAX_MSG_SIZE` bytes per read) and the loop repeated.
    ///
    /// # Returns
    /// * `Ok(Some(frame))` - a frame produced by `F::get_frame`.
    /// * `Ok(None)` - the stream reached EOF and no unconsumed bytes were left in the buffer.
    ///
    /// # Errors
    /// * [`ErrorKind::ConnectionReset`] - the stream reached EOF while the buffer still held
    ///   bytes that do not form a frame.
    /// * Any other read error is returned with its original kind; the message carries the
    ///   connection id and a hex dump of the residual buffer. Reads failing with
    ///   [`ErrorKind::Interrupted`] are retried instead of being reported.
    ///
    /// In both the EOF and the error case the write direction of the stream is shut down
    /// before returning.
    ///
    /// # Panics
    /// Panics if shutting the stream down fails with anything other than
    /// [`ErrorKind::NotConnected`].
    #[inline]
    pub fn read_frame(&mut self) -> Result<Option<Bytes>, Error> {
        loop {
            if let Some(bytes) = F::get_frame(&mut self.buffer) {
                return Ok(Some(bytes));
            }

            // Zero-initialised scratch space: materialising an uninitialised `[u8; N]` and
            // handing it to `Read::read` is undefined behaviour.
            let mut buf = [0_u8; MAX_MSG_SIZE];

            // `Interrupted` is a transient condition; retry the read rather than tearing the
            // connection down.
            let read_result = loop {
                match self.stream_reader.read(&mut buf) {
                    Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                    other => break other,
                }
            };

            match read_result {
                Ok(EOF) => {
                    self.shutdown(Shutdown::Write, "read_frame EOF");
                    if self.buffer.is_empty() {
                        return Ok(None);
                    }
                    // The peer went away in the middle of a frame.
                    let msg = format!(
                        "{} {}::read_frame connection reset by peer, residual buf:\n{}",
                        self.con_id,
                        asserted_short_name!("FrameReader", Self),
                        to_hex_pretty(&self.buffer[..])
                    );
                    return Err(Error::new(ErrorKind::ConnectionReset, msg));
                }
                Ok(len) => self.buffer.extend_from_slice(&buf[..len]),
                Err(e) => {
                    self.shutdown(Shutdown::Write, "read_frame error");
                    let msg = format!(
                        "{} {}::read_frame caused by: [{}] residual buf:\n{}",
                        self.con_id,
                        asserted_short_name!("FrameReader", Self),
                        e,
                        to_hex_pretty(&self.buffer[..])
                    );
                    return Err(Error::new(e.kind(), msg));
                }
            }
        }
    }

    /// Shuts down the `how` direction(s) of the stream, logging the outcome at debug level.
    ///
    /// A stream that is already disconnected ([`ErrorKind::NotConnected`]) is only logged.
    ///
    /// # Panics
    /// Panics on any other shutdown error.
    #[inline]
    fn shutdown(&mut self, how: Shutdown, reason: &str) {
        match self.stream_reader.shutdown(how) {
            Ok(_) => {
                if log_enabled!(log::Level::Debug) {
                    debug!("{}::shutdown how: {:?}, reason: {}", self, how, reason);
                }
            }
            Err(e) if e.kind() == ErrorKind::NotConnected => {
                if log_enabled!(log::Level::Debug) {
                    debug!("{}::shutdown while disconnected how: {:?}, reason: {}", self, how, reason);
                }
            }
            Err(e) => {
                panic!("{}::shutdown how: {:?}, reason: {}, caused by: [{}]", self, how, reason, e);
            }
        }
    }
}
impl<F: Framer, const MAX_MSG_SIZE: usize> Drop for FrameReader<F, MAX_MSG_SIZE> {
    /// Shuts down both directions of the underlying stream when the reader is dropped.
    fn drop(&mut self) {
        let both_directions = Shutdown::Both;
        self.shutdown(both_directions, "drop");
    }
}
impl<F: Framer, const MAX_MSG_SIZE: usize> Display for FrameReader<F, MAX_MSG_SIZE> {
    /// Writes the framer's short type name, the connection id, whether the local and peer
    /// addresses can currently be queried, and the stream's descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Last `::`-separated segment of the framer's type name.
        let framer_name = std::any::type_name::<F>().rsplit("::").next().unwrap_or("Unknown");
        let local_state = if self.stream_reader.local_addr().is_ok() { "connected" } else { "disconnected" };
        let peer_state = if self.stream_reader.peer_addr().is_ok() { "connected" } else { "disconnected" };
        write!(
            f,
            "FrameReader<{}> {{ {}, addr: {}, peer: {}, fd: {} }}",
            framer_name,
            self.con_id,
            local_state,
            peer_state,
            cross_os_fd!(&self.stream_reader),
        )
    }
}
/// Writes frames to a [`TcpStream`].
#[derive(Debug)]
pub struct FrameWriter {
/// Connection identifier; included in error messages and in the `Display` output.
pub(crate) con_id: ConId,
/// Stream that frames are written to and that is shut down on write error and drop.
pub(crate) stream_writer: TcpStream,
}
impl FrameWriter {
    /// Creates a writer that sends frames over `stream` and reports errors under `con_id`.
    pub fn new(con_id: ConId, stream: TcpStream) -> Self {
        let stream_writer = stream;
        Self { con_id, stream_writer }
    }

    /// Writes all of `bytes` to the stream.
    ///
    /// # Errors
    /// On a write failure the write direction of the stream is shut down and an error of the
    /// same kind is returned, with the connection id and the cause in its message.
    ///
    /// # Panics
    /// Panics if that shutdown fails with anything other than [`ErrorKind::NotConnected`].
    #[inline]
    pub fn write_frame(&mut self, bytes: &[u8]) -> Result<(), Error> {
        if let Err(cause) = self.stream_writer.write_all(bytes) {
            self.shutdown(Shutdown::Write, "write_frame error");
            let msg = format!("{} FrameWriter::write_frame caused by: [{}]", self.con_id, cause);
            return Err(Error::new(cause.kind(), msg));
        }
        Ok(())
    }

    /// Shuts down the `how` direction(s) of the stream, logging the outcome at debug level.
    ///
    /// An already disconnected stream is only logged; any other shutdown error panics.
    fn shutdown(&mut self, how: Shutdown, reason: &str) {
        let outcome = self.stream_writer.shutdown(how);
        match outcome {
            Err(cause) if cause.kind() != ErrorKind::NotConnected => {
                panic!("{}::shutdown how: {:?}, reason: {}, caused by: [{}]", self, how, reason, cause);
            }
            Err(_) => {
                if log_enabled!(log::Level::Debug) {
                    debug!("{}::shutdown while disconnected how: {:?}, reason: {}", self, how, reason);
                }
            }
            Ok(()) => {
                if log_enabled!(log::Level::Debug) {
                    debug!("{}::shutdown how: {:?}, reason: {}", self, how, reason);
                }
            }
        }
    }
}
impl Drop for FrameWriter {
    /// Shuts down both directions of the underlying stream when the writer is dropped.
    fn drop(&mut self) {
        let both_directions = Shutdown::Both;
        self.shutdown(both_directions, "drop");
    }
}
impl Display for FrameWriter {
    /// Writes the connection id, whether the local and peer addresses can currently be
    /// queried, and the stream's descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let local_state = if self.stream_writer.local_addr().is_ok() { "connected" } else { "disconnected" };
        let peer_state = if self.stream_writer.peer_addr().is_ok() { "connected" } else { "disconnected" };
        write!(
            f,
            "FrameWriter {{ {}, addr: {}, peer: {}, fd: {} }}",
            self.con_id,
            local_state,
            peer_state,
            cross_os_fd!(&self.stream_writer),
        )
    }
}
/// Reader/writer pair returned by `into_split_framer`.
type FrameProcessor<F, const MAX_MSG_SIZE: usize> = (FrameReader<F, MAX_MSG_SIZE>, FrameWriter);
/// Splits `stream` into a [`FrameReader`] and a [`FrameWriter`].
///
/// The stream's local and peer addresses are recorded on `con_id` first. The reader gets a
/// `try_clone` of the stream and a clone of `con_id`; the writer gets the originals.
///
/// # Panics
/// Panics if the local or peer address cannot be obtained, or if the stream cannot be cloned.
pub fn into_split_framer<F: Framer, const MAX_MSG_SIZE: usize>(mut con_id: ConId, stream: TcpStream) -> FrameProcessor<F, MAX_MSG_SIZE> {
    let local = stream.local_addr().unwrap();
    con_id.set_local(local);
    let peer = stream.peer_addr().unwrap();
    con_id.set_peer(peer);

    let read_half = stream.try_clone().expect("Failed to try_clone TcpStream for FrameReader");
    let frame_reader = FrameReader::<F, MAX_MSG_SIZE>::new(con_id.clone(), read_half);
    let frame_writer = FrameWriter::new(con_id, stream);
    (frame_reader, frame_writer)
}
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use byteserde::utils::hex::to_hex_pretty;
    use links_core::{assert_error_kind_on_target_family, fmt_num, prelude::ConId, unittest::setup};
    use log::{error, info};
    use rand::Rng;
    use std::{
        net::{TcpListener, TcpStream},
        thread,
        time::Instant,
    };

    /// Sends `WRITE_N_TIMES` fixed-size frames from a client to a service thread and checks
    /// that every frame is received, then exercises one of the two close paths at random:
    /// dropping the writer (reader sees EOF) or dropping the reader (writer sees an error).
    #[test]
    fn test_reader() {
        setup::log::configure_level(log::LevelFilter::Info);
        const TEST_SEND_FRAME_SIZE: usize = 128;
        const WRITE_N_TIMES: usize = 100_000;
        pub type MsgFramer = FixedSizeFramer<TEST_SEND_FRAME_SIZE>;
        let send_frame = setup::data::random_bytes(TEST_SEND_FRAME_SIZE);
        info!("send_frame: \n{}", to_hex_pretty(send_frame));
        let addr = setup::net::rand_avail_addr_port();

        // Bind before spawning the service thread so the client's connect cannot race the
        // listener coming up; this replaces a fixed sleep on the client side.
        let listener = TcpListener::bind(addr).unwrap();
        let svc = thread::Builder::new()
            .name("Thread-Svc".to_owned())
            .spawn(move || {
                let (stream, _) = listener.accept().unwrap();
                // `_svc_writer` must stay bound for the whole loop: dropping a FrameWriter
                // shuts down both directions of the connection.
                let (mut svc_reader, _svc_writer) = into_split_framer::<MsgFramer, TEST_SEND_FRAME_SIZE>(ConId::svc(Some("unittest"), addr, None), stream);
                info!("svc: reader: {}", svc_reader);
                let mut frame_recv_count = 0_usize;
                loop {
                    match svc_reader.read_frame() {
                        Ok(Some(_)) => frame_recv_count += 1,
                        Ok(None) => {
                            info!("svc: read_frame is None, client closed connection");
                            break;
                        }
                        Err(e) => {
                            error!("Svc read_frame error: {}", e);
                            break;
                        }
                    }
                }
                frame_recv_count
            })
            .unwrap();

        let (mut clt_reader, mut clt_writer) = into_split_framer::<MsgFramer, TEST_SEND_FRAME_SIZE>(ConId::clt(Some("unittest"), None, addr), TcpStream::connect(addr).unwrap());
        info!("clt: {}", clt_writer);
        let mut frame_send_count = 0_usize;
        let start = Instant::now();
        for _ in 0..WRITE_N_TIMES {
            clt_writer.write_frame(send_frame).unwrap();
            frame_send_count += 1;
        }
        let elapsed = start.elapsed();

        // Randomly pick which half is dropped first so both close paths get coverage over runs.
        if rand::thread_rng().gen_range(1..=2) % 2 == 0 {
            info!("dropping clt_writer");
            drop(clt_writer);
            let opt = clt_reader.read_frame().unwrap();
            info!("clt_reader.read_frame() opt: {:?}", opt);
            assert_eq!(opt, None);
        } else {
            info!("dropping clt_reader");
            drop(clt_reader);
            let err = clt_writer.write_frame(send_frame).unwrap_err();
            info!("clt_writer.write_frame() err: {}", err);
            assert_error_kind_on_target_family!(err, std::io::ErrorKind::BrokenPipe);
        }

        let frame_recv_count = svc.join().unwrap();
        info!("frame_send_count: {}, frame_recv_count: {}", fmt_num!(frame_send_count), fmt_num!(frame_recv_count));
        info!("per send elapsed: {:?}, total elapsed: {:?} ", elapsed / WRITE_N_TIMES as u32, elapsed);
        assert_eq!(frame_send_count, frame_recv_count);
        assert_eq!(frame_send_count, WRITE_N_TIMES);
    }
}