use std::{error::Error, fmt, io::ErrorKind};
use tokio::{
io::AsyncReadExt,
net::TcpStream,
runtime::Runtime,
sync::{mpsc, oneshot},
task::JoinHandle,
};
use crate::common::{write_data, MessageQueue};
/// Error returned by the client API when an operation is attempted while no
/// connection task is running (the client was never started, was stopped, or
/// its worker task has already finished).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NotConnectedError;

impl fmt::Display for NotConnectedError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("attempting to call not started client")
    }
}

impl Error for NotConnectedError {}
/// Event produced by the client's background tasks for the caller to consume.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The connection to the server has ended.
    Disconnect,
    /// One complete payload received from the server.
    Data(Vec<u8>),
}
/// Blocking front end for the TCP client.
///
/// Owns a Tokio runtime and uses it to drive the asynchronous
/// `ClientHandle` from synchronous code.
pub struct Client {
/// Handle to the running connection task; `None` until `start` is called
/// and again after `stop`.
///
/// Declared before `rt`, so it is dropped first and its `Drop` impl gets to
/// send the stop message before the runtime itself is dropped.
handle: Option<ClientHandle>,
/// Runtime on which the connection task is spawned and on which the
/// blocking calls of this type are executed.
rt: Runtime,
}
impl Client {
    /// Creates a client that is not yet connected.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be created.
    pub fn new() -> Self {
        Self {
            handle: None,
            rt: Runtime::new().expect("failed to create Tokio runtime"),
        }
    }

    /// Starts the background connection task for `addr`.
    ///
    /// This returns as soon as the task has been spawned; it does not wait
    /// for the connection to be established. If the client was already
    /// started, the previous handle is replaced and dropped.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context,
    /// because it blocks on the internal runtime.
    pub fn start(&mut self, addr: &str) {
        // `ClientHandle::new` calls `tokio::spawn`, which needs the runtime
        // context that `block_on` provides.
        let handle = self.rt.block_on(async { ClientHandle::new(addr) });
        self.handle = Some(handle);
    }

    /// Stops the client by dropping its handle, if any.
    pub fn stop(&mut self) {
        self.handle = None;
    }

    /// Queues `data` to be written to the server.
    ///
    /// # Errors
    ///
    /// Returns [`NotConnectedError`] if the client has not been started or
    /// the handle reports that the connection is no longer alive.
    pub fn send(&self, data: Vec<u8>) -> Result<(), NotConnectedError> {
        // The handle only checks the task state and pushes onto an unbounded
        // channel; neither needs a runtime context, so no `block_on` here.
        match &self.handle {
            Some(handle) => handle.send(data),
            None => Err(NotConnectedError),
        }
    }

    /// Returns the events collected since the previous call.
    ///
    /// # Errors
    ///
    /// Returns [`NotConnectedError`] if the client has not been started or
    /// the handle reports that it has nothing more to deliver.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context,
    /// because it blocks on the internal runtime.
    pub fn received(&mut self) -> Result<Vec<Event>, NotConnectedError> {
        match self.handle.as_mut() {
            // NOTE(review): the queue is drained inside the runtime context in
            // case `MessageQueue` relies on it — confirm and drop the
            // `block_on` if it does not.
            Some(handle) => self.rt.block_on(async { handle.received() }),
            None => Err(NotConnectedError),
        }
    }

    /// Returns `true` while the client is started and its connection task is
    /// still running.
    pub fn connected(&self) -> bool {
        // Checking whether the task has finished does not need a runtime
        // context, so this is a plain synchronous call.
        self.handle.as_ref().map_or(false, ClientHandle::connected)
    }
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
/// Command sent from a `ClientHandle` to its `ClientWorker` task.
enum ClientMessage {
/// Payload the worker should write to the server.
Write(Vec<u8>),
/// Ask the worker to leave its command loop and return.
Stop,
}
/// Owner-side handle to a spawned `ClientWorker` task.
///
/// Dropping the handle sends `ClientMessage::Stop` to the worker.
struct ClientHandle {
/// Event queue; a clone of it is handed to the worker in `new`, and
/// `received` flushes it.
queue: MessageQueue<Event>,
/// Sending side of the command channel to the worker.
tx: mpsc::UnboundedSender<ClientMessage>,
/// Join handle of the worker task; `connected` reports whether it has
/// finished.
handle: JoinHandle<()>,
}
impl ClientHandle {
    /// Spawns a worker task that connects to `addr` and returns a handle to it.
    ///
    /// Must be called from within a Tokio runtime context because it uses
    /// `tokio::spawn`. The connection itself is established by the spawned
    /// task, not by this function.
    fn new(addr: &str) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();
        let queue = MessageQueue::new();
        let mut worker = ClientWorker {
            queue: queue.clone(),
            rx,
        };
        // The task must own its data, so it gets an owned copy of the address.
        let addr = addr.to_owned();
        let handle = tokio::spawn(async move { worker.run(addr).await });
        Self { queue, tx, handle }
    }

    /// Returns the events queued since the previous call.
    ///
    /// Events that were queued right before the worker finished (typically
    /// the final data and `Event::Disconnect`) are still delivered once;
    /// only when the worker has finished *and* nothing is pending does this
    /// report an error.
    ///
    /// # Errors
    ///
    /// Returns [`NotConnectedError`] if the worker task has finished and the
    /// queue is empty.
    fn received(&mut self) -> Result<Vec<Event>, NotConnectedError> {
        let connected = self.connected();
        let events = self.queue.flush();
        if connected || !events.is_empty() {
            Ok(events)
        } else {
            Err(NotConnectedError)
        }
    }

    /// Queues `data` for the worker to write to the server.
    ///
    /// # Errors
    ///
    /// Returns [`NotConnectedError`] if the worker task has finished or its
    /// command channel is already closed.
    fn send(&self, data: Vec<u8>) -> Result<(), NotConnectedError> {
        if !self.connected() {
            return Err(NotConnectedError);
        }
        // A send error means the receiving worker is gone, i.e. the data can
        // never be written; report that instead of pretending it succeeded.
        self.tx
            .send(ClientMessage::Write(data))
            .map_err(|_| NotConnectedError)
    }

    /// Returns `true` while the worker task has not finished.
    fn connected(&self) -> bool {
        !self.handle.is_finished()
    }
}
/// Asks the worker task to stop when its handle goes away.
impl Drop for ClientHandle {
    fn drop(&mut self) {
        // Best effort: the result is discarded on purpose, so a failed send
        // is simply ignored.
        drop(self.tx.send(ClientMessage::Stop));
    }
}
/// State owned by the spawned connection task.
struct ClientWorker {
/// Queue onto which received data and the disconnect event are pushed.
queue: MessageQueue<Event>,
/// Receiving side of the command channel fed by `ClientHandle`.
rx: mpsc::UnboundedReceiver<ClientMessage>,
}
impl ClientWorker {
async fn run(&mut self, addr: String) {
let conn = TcpStream::connect(addr).await.unwrap();
let (mut read_half, mut write_half) = conn.into_split();
println!("Connected to server");
let mut q = self.queue.clone();
let (stop_tx, mut stop_rx) = oneshot::channel();
tokio::spawn(async move {
loop {
let mut len_buf = [0u8; 4];
match read_half.read_exact(len_buf.as_mut_slice()).await {
Ok(_) => {}
Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
Err(e) => {
eprintln!("Error while reading: {}", e);
break;
}
}
let len = u32::from_le_bytes(len_buf);
let mut buf = vec![0u8; len as usize];
let n = match read_half.read_exact(&mut buf).await {
Ok(n) => n,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
Err(e) => {
eprintln!("Error while reading: {}", e);
break;
}
};
println!("Received {} bytes from server", n);
q.push(Event::Data(buf));
}
let _ = stop_tx.send(());
});
loop {
tokio::select! {
_ = &mut stop_rx => {
self.queue.push(Event::Disconnect);
println!("Disconnected from server");
return;
},
Some(msg) = self.rx.recv() => {
match msg {
ClientMessage::Write(mut data) => {
match write_data(&mut write_half, &mut data).await {
Ok(_) => println!("Wrote {} bytes to server", data.len()),
Err(e) => println!("Error while writing: {}", e),
}
},
ClientMessage::Stop => return
}
}
};
}
}
}