use rustls::{ServerConfig, ServerConnection, StreamOwned};
use std::{
collections::{HashMap, HashSet},
io::{self, Read, Write},
net::{TcpListener, TcpStream},
sync::{Arc, Mutex},
thread::{self, JoinHandle, sleep},
time::Duration,
};
use tracing::{error, info};
use crate::{
DummyRemote, Event, Hook, InitHook, StateMap, StateMapError,
event::EventType,
statemap::{StateMapKey, StateMapValue},
tcp::TcpStateServerResponse,
};
use super::{TcpStateServerRequest, TcpStateServerRequestWrapper};
#[derive(thiserror::Error, Debug)]
pub enum TcpStateServerError {
#[error("IO Error")]
IoError(#[from] io::Error),
#[error("RustLs Error")]
RustlsError(#[from] rustls::Error),
#[error("Decode Error")]
DecodeError(#[from] bincode::error::DecodeError),
#[error("Encode Error")]
EncodeError(#[from] bincode::error::EncodeError),
#[error("Mutex Poisoned Error: {0}")]
MutexPoisonError(String),
#[error("Unknown Error: {0}")]
UnknownError(String),
#[error("StateMapError Error: {0}")]
StateMapError(#[from] StateMapError),
}
/// Converts a lock failure (in this file, the error returned by `Mutex::lock`) into
/// [`TcpStateServerError::MutexPoisonError`], keeping only its message.
fn to_mutex_poison_err<T>(poison: T) -> TcpStateServerError
where
    T: ToString,
{
    let message = poison.to_string();
    TcpStateServerError::MutexPoisonError(message)
}
/// Converts an error that has no dedicated variant (e.g. an event deserialization failure)
/// into [`TcpStateServerError::UnknownError`], keeping only its message.
fn to_unknown_err<T>(cause: T) -> TcpStateServerError
where
    T: ToString,
{
    let message = cause.to_string();
    TcpStateServerError::UnknownError(message)
}
/// Shared list of event hooks; `TcpStateServer` runs each one against the target statemap
/// for every accepted `Event` request.
pub type HooksVec<K, T, E> = Arc<Vec<Box<dyn Hook<K, T, E>>>>;
/// Shared list of init hooks; `TcpStateServer` runs each one against a newly built statemap
/// during an `Init` request, before the statemap is registered.
pub type InitHooksVec<K, T> = Arc<Vec<Box<dyn InitHook<K, T>>>>;
/// TLS-protected TCP server that owns a set of hash-addressed statemaps and answers
/// length-prefixed, bincode-encoded, password-authenticated requests against them.
pub struct TcpStateServer<K, T, E> {
    /// Listening socket; `start_server` accepts connections on a clone of it.
    pub serv: TcpListener,
    /// Registered statemaps, keyed by the hash verified during their `Init` request.
    statemaps: HashMap<[u8; 32], Arc<Mutex<StateMap<K, T>>>>,
    /// Hooks run against the target statemap for each `Event` request.
    hooks: HooksVec<K, T, E>,
    /// Hooks run against a newly built statemap before it is registered.
    init_hooks: InitHooksVec<K, T>,
    /// rustls configuration used to create the TLS session for every accepted connection.
    tls_config: Arc<rustls::ServerConfig>,
    /// Secret that every incoming request must carry.
    password: Vec<u8>,
    /// Hashes whose `Init` is currently in progress; lets concurrent `Init` requests for the
    /// same hash skip rebuilding the statemap.
    processed_hashes: HashSet<[u8; 32]>, }
impl<K, T, E> TcpStateServer<K, T, E>
where
K: StateMapKey,
T: StateMapValue,
E: EventType,
{
pub fn from_tcp_listner(
listner: TcpListener,
tls_config: Arc<ServerConfig>,
hooks: HooksVec<K, T, E>,
password: Vec<u8>,
) -> Self {
Self {
serv: listner,
statemaps: HashMap::new(),
hooks,
init_hooks: Arc::new(Vec::new()),
tls_config,
password,
processed_hashes: HashSet::new(),
}
}
pub fn set_init_hooks(&mut self, init_hooks: InitHooksVec<K, T>) {
self.init_hooks = init_hooks;
}
fn new_buffer() -> Vec<u8> {
vec![0u8; 100_000]
}
pub fn start_server(self) -> Result<(), TcpStateServerError> {
let sock = self.serv.try_clone()?;
let server = Arc::new(Mutex::new(self));
loop {
let (stream, addr) = match sock.accept() {
Ok(v) => v,
Err(e) => {
error!(
"Error encountered when accepting a connection on socket: {sock:?}, error: {e}"
);
continue;
}
};
info!("Accepted connection from {addr}");
Self::handle_stream(server.clone(), stream);
}
}
pub fn handle_stream(
server: Arc<Mutex<Self>>,
stream: TcpStream,
) -> JoinHandle<Result<TcpStateServerResponse, TcpStateServerError>> {
thread::spawn(move || {
let server_lock = server.lock().map_err(to_mutex_poison_err)?;
let tls_config_arc = server_lock.tls_config.clone();
let server_password_clone = server_lock.password.clone();
drop(server_lock);
let mut stream = StreamOwned::new(ServerConnection::new(tls_config_arc)?, stream);
let mut len_buffer = [0u8; 4];
stream.read_exact(&mut len_buffer)?;
let len = u32::from_be_bytes(len_buffer) as usize;
let mut buffer = Self::new_buffer();
buffer.resize(len, 0);
stream.read_exact(&mut buffer)?;
let req: TcpStateServerRequestWrapper =
bincode::serde::decode_from_slice(&buffer, bincode::config::standard())?.0;
if server_password_clone != req.password {
sleep(Duration::from_micros(rand::random_range(100..100000)));
let resp_buffer = bincode::serde::encode_to_vec(
TcpStateServerResponse::IncorrectPassword,
bincode::config::standard(),
)?;
let _ = stream.write(&(resp_buffer.len() as u32).to_be_bytes())?;
let _ = stream.write(&resp_buffer)?;
drop(stream); return Ok(TcpStateServerResponse::IncorrectPassword);
}
let resp = Self::process_request(server, req.event)?;
let resp_buffer = bincode::serde::encode_to_vec(&resp, bincode::config::standard())?;
let _ = stream.write(&(resp_buffer.len() as u32).to_be_bytes())?;
let _ = stream.write(&resp_buffer)?;
Ok(resp)
})
}
fn process_request(
server: Arc<Mutex<Self>>,
request: TcpStateServerRequest,
) -> Result<TcpStateServerResponse, TcpStateServerError> {
let mut server_lock = server.lock().map_err(to_mutex_poison_err)?;
match request {
TcpStateServerRequest::GetUpdateId { hash } => {
if let Some(v) = server_lock.statemaps.get(&hash) {
Ok(TcpStateServerResponse::UpdateId {
update_id: v.lock().map_err(to_mutex_poison_err)?.get_update_id(),
})
} else {
Ok(TcpStateServerResponse::HashNotFound)
}
}
TcpStateServerRequest::Event { hash, evt_data } => {
if let Some(v) = server_lock.statemaps.get(&hash) {
let event: Event<E> = Event::deserialize(&evt_data).map_err(to_unknown_err)?;
let hooks_arc = server_lock.hooks.clone();
let statemap_arc = v.clone();
thread::spawn(move || {
let statemap_arc = statemap_arc;
info!("Processing event: {:?}", event);
for x in hooks_arc.iter() {
if let Err(e) = x.process_event(&statemap_arc, &event) {
error!("Error processing hook: {e}");
}
}
});
Ok(TcpStateServerResponse::EventInProcess)
} else {
Ok(TcpStateServerResponse::HashNotFound)
}
}
TcpStateServerRequest::GetDiff {
hash,
from_update_id,
upto_update_id,
} => {
if let Some(v) = server_lock.statemaps.get(&hash) {
let diff = v
.lock()
.map_err(to_mutex_poison_err)?
.get_diff(from_update_id, upto_update_id);
Ok(TcpStateServerResponse::Diff {
from_update_id: diff.from_update_id(),
upto_update_id: diff.upto_update_id(),
is_full: diff.is_full_update(),
encoded_diff: bincode::serde::encode_to_vec(
diff.get_diff(),
bincode::config::standard(),
)?,
})
} else {
Ok(TcpStateServerResponse::HashNotFound)
}
}
TcpStateServerRequest::Init { hash, init_data } => {
if server_lock.statemaps.contains_key(&hash) {
Ok(TcpStateServerResponse::InitSuccess)
} else {
let init_hooks = server_lock.init_hooks.clone();
if !server_lock.processed_hashes.insert(hash) {
return Ok(TcpStateServerResponse::InitSuccess);
}
drop(server_lock);
let mut stmap = StateMap::new(Arc::new(DummyRemote));
let raw_map: HashMap<K, T> =
bincode::serde::decode_from_slice(&init_data, bincode::config::standard())
.inspect_err(|_| {
if let Ok(mut server_lock) = server.lock() {
server_lock.processed_hashes.remove(&hash);
}
})?
.0;
for (k, v) in raw_map {
stmap
.push(k, v)
.expect("state map should not be frozen at this point");
}
stmap
.set_master(true)
.expect("state map should not yet be frozen when setting master to true");
stmap.freeze().inspect_err(|_| {
if let Ok(mut server_lock) = server.lock() {
server_lock.processed_hashes.remove(&hash);
}
})?;
if stmap.hash().unwrap() != &hash {
Ok(TcpStateServerResponse::HashMismatch)
} else {
let stmap_arc = Arc::new(Mutex::new(stmap));
for hook in init_hooks.iter() {
if let Err(e) = hook.process_init(stmap_arc.clone()) {
error!("Init hook errored out: {e}");
if let Ok(mut server_lock) = server.lock() {
server_lock.processed_hashes.remove(&hash);
}
return Ok(TcpStateServerResponse::InitFailure {
error_message: e.to_string(),
});
}
}
let mut server_lock = server.lock().map_err(to_mutex_poison_err)?;
server_lock.statemaps.insert(hash, stmap_arc.clone());
server_lock.processed_hashes.remove(&hash);
Ok(TcpStateServerResponse::InitSuccess)
}
}
}
}
}
}