use crate::rpc::RPC;
use std::io::Write;
use std::mem::MaybeUninit;
use std::net::TcpStream;
use std::time::Duration;
/// Ten-millisecond backoff interval. Presumably used by callers to pause
/// between polls/retries (e.g. after a `ReadStatus::WouldBlock`) — confirm
/// against call sites.
pub const BACKOFF_DURATION: Duration = Duration::from_millis(10);

/// Loopback socket address (port 1337) for the RPC transport; presumably the
/// address both the listener binds and clients connect to — verify with callers.
pub const ADDR: &str = "127.0.0.1:1337";
/// Outcome of a single non-blocking attempt to obtain a framed message.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReadStatus {
    /// A whole frame was available; carries its payload with the 4-byte length
    /// prefix already stripped.
    Completed(Vec<u8>),
    /// New bytes were buffered, but they do not yet complete a frame.
    Progress,
    /// No complete frame is buffered and this attempt produced no new bytes
    /// (e.g. the socket reported `WouldBlock`); try again later.
    WouldBlock,
}
/// Serializes `msg` as MessagePack (via `rmp_serde::to_vec`) and sends it on
/// `stream` as a single length-prefixed frame using [`send_socket_frame`].
///
/// # Errors
///
/// * `InvalidData` if `msg` cannot be serialized (previously this panicked);
///   the serializer's error is kept as the source.
/// * Any error returned by [`send_socket_frame`].
pub fn send_socket_rpc(msg: RPC, stream: &mut TcpStream) -> std::io::Result<()> {
    let msgpack_bytes = rmp_serde::to_vec(&msg)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    send_socket_frame(&msgpack_bytes, stream)
}
pub fn send_socket_frame(msg: &[u8], stream: &mut TcpStream) -> std::io::Result<()> {
let len: u32 = msg.len().try_into().unwrap();
write_all_robust(stream, &len.to_be_bytes())?;
write_all_robust(stream, msg)?;
Ok(())
}
/// Writes every byte of `buf` to `stream`, tolerating a non-blocking socket.
///
/// Unlike `Write::write_all`, a `WouldBlock` error is not fatal: the thread
/// sleeps for 1 ms and retries. `Interrupted` is retried immediately. There is
/// no overall timeout, so a peer that never drains its receive buffer keeps
/// this call looping indefinitely.
///
/// # Errors
///
/// * `WriteZero` if the stream accepts zero bytes.
/// * Any other write error, returned unchanged.
fn write_all_robust(stream: &mut TcpStream, buf: &[u8]) -> std::io::Result<()> {
    use std::io::ErrorKind;

    const WOULD_BLOCK_PAUSE: Duration = Duration::from_millis(1);

    let mut pending = buf;
    loop {
        if pending.is_empty() {
            return Ok(());
        }

        let written = match stream.write(pending) {
            Ok(count) => count,
            Err(err) => match err.kind() {
                ErrorKind::Interrupted => continue,
                ErrorKind::WouldBlock => {
                    std::thread::sleep(WOULD_BLOCK_PAUSE);
                    continue;
                }
                _ => return Err(err),
            },
        };

        if written == 0 {
            return Err(std::io::Error::new(
                ErrorKind::WriteZero,
                "failed to write whole buffer",
            ));
        }
        pending = &pending[written..];
    }
}
/// Reassembles length-prefixed frames from a byte stream.
///
/// The wire format matches [`send_socket_frame`]: each frame is a 4-byte
/// big-endian `u32` payload length followed by exactly that many payload bytes.
/// Bytes accumulate across calls, so a frame may arrive split over several
/// reads, and one read may deliver several frames.
#[derive(Debug, Default)]
pub struct InFlightMessage {
    /// Received bytes not yet handed out; once at least four are present, the
    /// first four are the length header of the next frame.
    bytes: Vec<u8>,
}

impl InFlightMessage {
    /// Size in bytes of the big-endian `u32` length prefix on every frame.
    const HEADER_LEN: usize = 4;

    /// Creates an empty reassembly buffer.
    pub fn new() -> InFlightMessage {
        Self::default()
    }

    /// Appends raw bytes to the reassembly buffer.
    pub fn add_bytes(&mut self, bytes: &[u8]) {
        self.bytes.extend_from_slice(bytes);
    }

    /// Tries to obtain the next complete frame, reading from `stream` at most once.
    ///
    /// If a whole frame is already buffered it is returned without touching the
    /// socket. Otherwise `stream` is switched to non-blocking mode (the setting
    /// stays on the socket after this call) and up to 1024 bytes are read and
    /// appended to the buffer.
    ///
    /// Returns [`ReadStatus::Completed`] with the payload (length prefix removed)
    /// once a frame is complete, [`ReadStatus::Progress`] if bytes arrived but the
    /// frame is still partial, and [`ReadStatus::WouldBlock`] if the socket had no
    /// data available.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` if the peer has closed the connection. Any partial frame
    ///   still buffered can never be completed. Previously a closed connection was
    ///   reported as `WouldBlock` forever.
    /// * Any error from `set_nonblocking`, or any error from `read` other than
    ///   `WouldBlock`. `Interrupted` reads are retried.
    pub fn read_stream(&mut self, stream: &mut TcpStream) -> std::io::Result<ReadStatus> {
        use std::io::Read;

        const READ_CHUNK_LEN: usize = 1024;

        // Hand out an already-buffered frame before touching the socket.
        if let Some(msg) = self.pop_msg() {
            return Ok(ReadStatus::Completed(msg));
        }

        stream.set_nonblocking(true)?;

        // A zero-initialized stack buffer: `Read::read` must only ever be given
        // initialized memory, and a local avoids aliasing `self` while the
        // received bytes are appended below.
        let mut chunk = [0u8; READ_CHUNK_LEN];
        let received = loop {
            match stream.read(&mut chunk) {
                // With a non-empty buffer, `Ok(0)` means end of stream: the peer
                // closed its side and no further bytes will arrive.
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "connection closed by peer",
                    ));
                }
                Ok(n) => break n,
                // A signal interrupted the syscall before any data moved; retry.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    return Ok(ReadStatus::WouldBlock);
                }
                Err(e) => return Err(e),
            }
        };

        self.add_bytes(&chunk[..received]);
        Ok(self
            .pop_msg()
            .map_or(ReadStatus::Progress, ReadStatus::Completed))
    }

    /// Removes and returns the payload of the first buffered frame, or `None` if
    /// its header or payload is not yet fully buffered.
    fn pop_msg(&mut self) -> Option<Vec<u8>> {
        // A length that does not fit in `usize` can never be satisfied by a `Vec`.
        let payload_len = usize::try_from(self.expected_length()?).ok()?;
        let frame_end = Self::HEADER_LEN.checked_add(payload_len)?;
        let payload = self.bytes.get(Self::HEADER_LEN..frame_end)?.to_vec();
        // Drop header + payload in one drain, i.e. a single shift of the tail.
        self.bytes.drain(..frame_end);
        Some(payload)
    }

    /// Payload length announced by the first buffered frame header, or `None`
    /// while fewer than four bytes are buffered.
    pub fn expected_length(&self) -> Option<u32> {
        let header: [u8; 4] = self.bytes.get(..Self::HEADER_LEN)?.try_into().ok()?;
        Some(u32::from_be_bytes(header))
    }

    /// Number of buffered bytes after the first frame header (0 until the header
    /// is complete).
    ///
    /// This counts every byte after the header, including bytes that belong to
    /// later frames, so it can exceed [`expected_length`](Self::expected_length).
    pub fn current_length(&self) -> usize {
        self.bytes.len().saturating_sub(Self::HEADER_LEN)
    }
}