use crate::Comms;
use base64ct::{Base64, Encoding};
use log::*;
use std::{
collections::HashMap,
fmt::Write as _,
io::{BufReader, prelude::*},
net::TcpStream,
sync::{Arc, Mutex, mpsc},
};
/// Request prefix for the leader query; answered with `+1` when `leader > 0`, else `+0`.
pub const CMD_CLDR: &str = "#";
/// Request prefix for a ping; a name following it is registered in the member table.
pub const CMD_PING: &str = "^";
/// Request prefix for a base64 payload forwarded to the leader channel.
pub const CMD_SEND: &str = "$";
/// Request prefix for a base64 payload handed to the broadcast channel.
pub const CMD_BCST: &str = "*";
pub fn handle_protocol(
id: usize,
mut stream: TcpStream,
leader: usize,
members: Arc<Mutex<HashMap<String, usize>>>,
tx_toleader: Vec<mpsc::Sender<Comms>>,
tx_broadcast: Vec<mpsc::Sender<Comms>>,
) {
let mut reader = BufReader::new(&stream);
let mut data = String::new();
reader.read_line(&mut data).unwrap();
debug!("[T{id}]: request: {data:?}");
if data.starts_with(CMD_CLDR) {
let mut ack = String::new();
if leader > 0 {
write!(&mut ack, "+1\n").unwrap();
} else {
write!(&mut ack, "+0\n").unwrap();
}
let _ = stream.write_all(ack.as_bytes());
return;
}
if data.starts_with(CMD_PING) {
if data.len() == 2 {
let mut ack = String::new();
write!(&mut ack, "+1\n").unwrap();
let _ = stream.write_all(ack.as_bytes());
return;
}
{
if let Ok(mut v) = members.lock() {
let name = &data[1..&data.len() - 1];
v.insert(name.to_string(), 0);
}
}
let mut all = String::new();
let mut ack = String::new();
{
if let Ok(v) = members.lock() {
for (k, _) in &*v {
write!(&mut all, "{},", k).unwrap();
}
}
}
all.pop(); write!(&mut ack, "+{}\n", all).unwrap();
let _ = stream.write_all(ack.as_bytes());
return;
}
if data.starts_with(CMD_SEND) {
if tx_toleader.len() == 0 {
let _ = stream.write_all("-send disabled\n".as_bytes());
return;
}
if leader == 0 {
let _ = stream.write_all("-not leader\n".as_bytes());
return;
}
let decoded = match Base64::decode_vec(&data[1..&data.len() - 1]) {
Ok(v) => v,
Err(e) => {
let mut err = String::new();
write!(&mut err, "-{e}\n").unwrap();
let _ = stream.write_all(err.as_bytes());
return;
}
};
let (tx, rx): (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>) = mpsc::channel();
if let Err(e) = tx_toleader[0].send(Comms::ToLeader { msg: decoded, tx }) {
let mut err = String::new();
write!(&mut err, "-{e}\n").unwrap();
let _ = stream.write_all(err.as_bytes());
return;
}
let mut rep = rx.recv().unwrap();
let mut ack = vec![b'+'];
ack.append(&mut rep);
ack.push(b'\n');
let _ = stream.write_all(&ack);
return;
}
if data.starts_with(CMD_BCST) {
if tx_broadcast.len() == 0 {
let _ = stream.write_all("-send disabled\n".as_bytes());
return;
}
let decoded = match Base64::decode_vec(&data[1..&data.len() - 1]) {
Ok(v) => v,
Err(e) => {
let mut err = String::new();
write!(&mut err, "-{e}\n").unwrap();
let _ = stream.write_all(err.as_bytes());
return;
}
};
let (tx, rx): (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>) = mpsc::channel();
if let Err(e) = tx_broadcast[0].send(Comms::Broadcast { msg: decoded, tx }) {
let mut err = String::new();
write!(&mut err, "-{e}\n").unwrap();
let _ = stream.write_all(err.as_bytes());
return;
}
let mut rep = rx.recv().unwrap();
let mut ack = vec![b'+'];
ack.append(&mut rep);
ack.push(b'\n');
let _ = stream.write_all(&ack);
return;
}
let _ = stream.write_all(b"-unknown\n");
}