use std::io::{Read, Write};
use std::net::TcpStream;
use crate::error::TraciError;
use crate::storage::Storage;
/// Size in bytes of the big-endian `u32` length prefix that starts every frame.
const LENGTH_LEN: usize = std::mem::size_of::<u32>();
pub struct TraciSocket {
stream: TcpStream,
}
impl TraciSocket {
pub fn connect(host: &str, port: u16) -> Result<Self, TraciError> {
let addr = format!("{host}:{port}");
let stream = TcpStream::connect(&addr)
.map_err(TraciError::Connection)?;
stream.set_nodelay(true)
.map_err(TraciError::Connection)?;
Ok(Self { stream })
}
pub fn send_exact(&mut self, storage: &Storage) -> Result<(), TraciError> {
let payload = storage.as_bytes();
let total_len = (LENGTH_LEN + payload.len()) as u32;
let header = total_len.to_be_bytes();
self.stream.write_all(&header).map_err(TraciError::Connection)?;
self.stream.write_all(payload).map_err(TraciError::Connection)?;
Ok(())
}
pub fn receive_exact(&mut self) -> Result<Storage, TraciError> {
let mut header = [0u8; LENGTH_LEN];
self.stream.read_exact(&mut header).map_err(TraciError::Connection)?;
let total_len = u32::from_be_bytes(header) as usize;
if total_len < LENGTH_LEN {
return Err(TraciError::Protocol(format!(
"Received message length {total_len} is smaller than header size {LENGTH_LEN}"
)));
}
let payload_len = total_len - LENGTH_LEN;
let mut payload = vec![0u8; payload_len];
self.stream.read_exact(&mut payload).map_err(TraciError::Connection)?;
Ok(Storage::from_bytes(payload))
}
pub fn close(&mut self) -> Result<(), TraciError> {
self.stream.shutdown(std::net::Shutdown::Both).map_err(TraciError::Connection)
}
}