use std::collections::HashMap;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::mpsc::{self, Sender};
use std::sync::{Arc, Mutex};
use super::packet::{Packet, read_packet, write_packet};
use super::stream::Stream;
/// Packet connection shared between a single writer and a background reader
/// thread (spawned by `Connection::new`) that routes each incoming packet to
/// the channel registered for its stream id.
pub struct Connection {
    /// Output side; the mutex is held for the whole `write_packet` call in
    /// `send_packet`, so concurrent sends do not interleave.
    writer: Mutex<Box<dyn Write + Send>>,
    /// Channel senders keyed by stream id; the reader thread looks up
    /// `packet.stream` here to deliver each incoming packet.
    stream_senders: Mutex<HashMap<u32, Sender<Packet>>>,
    /// Counter from which `new_stream` derives fresh (odd) stream ids.
    next_stream_id: AtomicU32,
    /// Set when the reader thread hits a read error or when
    /// `mark_server_exited` is called.
    server_exited: AtomicBool,
}
impl Connection {
    /// Creates a connection over `reader`/`writer` and spawns a background
    /// thread that reads packets from `reader` and dispatches them.
    ///
    /// Each packet read is sent to the channel registered for `packet.stream`.
    /// Packets for unregistered stream ids are discarded. When `read_packet`
    /// returns an error, the connection is marked as exited and every
    /// registered sender is dropped, so the corresponding receivers observe a
    /// disconnect. The thread then stops.
    ///
    /// The reader thread keeps only a weak reference to the connection. Once
    /// every external `Arc` (including those held by `Stream`s) is dropped,
    /// the connection and its `writer` are released. The thread itself is
    /// not kept alive to deliver packets to nobody: it stops after its next
    /// `read_packet` call returns.
    ///
    /// # Panics
    ///
    /// The reader thread panics if the `stream_senders` mutex is poisoned.
    pub fn new(mut reader: Box<dyn Read + Send>, writer: Box<dyn Write + Send>) -> Arc<Self> {
        let conn = Arc::new(Self {
            writer: Mutex::new(writer),
            stream_senders: Mutex::new(HashMap::new()),
            next_stream_id: AtomicU32::new(1),
            server_exited: AtomicBool::new(false),
        });
        // Weak rather than Arc: a strong reference owned by a detached thread
        // would form a cycle. That cycle keeps the connection (and its writer)
        // alive for as long as the peer keeps the read side open, even after
        // every user handle is gone.
        let weak_conn = Arc::downgrade(&conn);
        std::thread::spawn(move || loop {
            let result = read_packet(&mut reader);
            let conn = match weak_conn.upgrade() {
                Some(conn) => conn,
                // Connection already dropped: nobody is left to deliver to.
                None => break,
            };
            match result {
                Ok(packet) => {
                    let senders = conn.stream_senders.lock().unwrap();
                    if let Some(sender) = senders.get(&packet.stream) {
                        // Best effort: the receiving side may already be gone.
                        let _ = sender.send(packet);
                    }
                }
                Err(_) => {
                    // Store the flag *before* clearing; `register_stream`
                    // relies on this order (see the comment there).
                    conn.server_exited.store(true, Ordering::SeqCst);
                    conn.stream_senders.lock().unwrap().clear();
                    break;
                }
            }
        });
        conn
    }

    /// Returns a `Stream` bound to the fixed stream id `0`.
    pub fn control_stream(self: &Arc<Self>) -> Stream {
        self.register_stream(0)
    }

    /// Allocates a fresh odd stream id and returns a `Stream` bound to it.
    ///
    /// Ids are `(n << 1) | 1` for a counter `n` starting at 1, giving 3, 5,
    /// 7, …
    ///
    /// NOTE(review): the shift silently discards the counter's top bit, so
    /// after 2^31 allocations the ids repeat. A repeated id replaces the
    /// sender of any still-live stream with that id.
    pub fn new_stream(self: &Arc<Self>) -> Stream {
        let next = self.next_stream_id.fetch_add(1, Ordering::SeqCst);
        let stream_id = (next << 1) | 1;
        self.register_stream(stream_id)
    }

    /// Returns a `Stream` bound to a caller-supplied `stream_id`.
    pub fn connect_stream(self: &Arc<Self>, stream_id: u32) -> Stream {
        self.register_stream(stream_id)
    }

    /// Registers a fresh channel for `stream_id` and wraps its receiver in a
    /// `Stream`.
    ///
    /// Any sender already registered under `stream_id` is replaced. If the
    /// server has already exited, the new sender is removed again right away,
    /// so the returned stream's receiver is disconnected.
    ///
    /// # Panics
    ///
    /// Panics if the `stream_senders` mutex is poisoned.
    fn register_stream(self: &Arc<Self>, stream_id: u32) -> Stream {
        let (tx, rx) = mpsc::channel();
        let mut senders = self.stream_senders.lock().unwrap();
        senders.insert(stream_id, tx);
        // The flag is checked while the lock is held. The reader thread stores
        // the flag *before* taking this lock to clear the map, so one of two
        // things happens:
        // - we see the flag here and undo the insert, or
        // - the reader's clear runs after we release the lock.
        // Either way no sender outlives the exit.
        if self.server_has_exited() {
            senders.remove(&stream_id);
        }
        drop(senders);
        Stream::new(stream_id, Arc::clone(self), rx)
    }

    /// Removes the sender registered for `stream_id`. Does nothing if none
    /// is registered.
    ///
    /// # Panics
    ///
    /// Panics if the `stream_senders` mutex is poisoned.
    pub fn unregister_stream(&self, stream_id: u32) {
        self.stream_senders.lock().unwrap().remove(&stream_id);
    }

    /// Marks the server as exited.
    ///
    /// Unlike the reader thread's error path, this only sets the flag. It
    /// does not clear already-registered senders.
    pub fn mark_server_exited(&self) {
        self.server_exited.store(true, Ordering::SeqCst);
    }

    /// Returns whether the server has been marked as exited, either by the
    /// reader thread or by `mark_server_exited`.
    pub fn server_has_exited(&self) -> bool {
        self.server_exited.load(Ordering::SeqCst)
    }

    /// Builds the `ConnectionAborted` error reported once the server is gone.
    fn server_crashed_error() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            super::SERVER_CRASHED_MESSAGE,
        )
    }

    /// Writes `packet` while holding the writer lock.
    ///
    /// # Errors
    ///
    /// Returns the error from `write_packet`. If the server is already marked
    /// as exited when the write fails, that error is replaced by a
    /// `ConnectionAborted` error carrying `SERVER_CRASHED_MESSAGE`.
    ///
    /// # Panics
    ///
    /// Panics if the writer mutex is poisoned.
    pub fn send_packet(&self, packet: &Packet) -> std::io::Result<()> {
        let mut writer = self.writer.lock().unwrap();
        write_packet(&mut **writer, packet).map_err(|err| {
            if self.server_has_exited() {
                Self::server_crashed_error()
            } else {
                err
            }
        })
    }
}
#[cfg(test)]
#[path = "../../tests/embedded/protocol/connection_tests.rs"]
mod tests;