use std::collections::{HashMap, HashSet};
use std::cell::Cell;
use std::result;
use queen_io::{
epoll::{Epoll, Token, Ready, EpollOpt}
};
use nson::{
Message, msg,
message_id::MessageId
};
use slab::Slab;
use rand::{SeedableRng, seq::SliceRandom, rngs::SmallRng};
use crate::Wire;
use crate::dict::*;
use crate::error::{ErrorCode, Result, SendError, RecvError};
use super::Hook;
/// Broker state: every connected client plus the indexes used to route
/// messages between them.
pub struct Slot {
/// Channel name -> slab tokens of the clients attached to it
/// (filled by `attach`, pruned by `detach` / `del_client`).
pub chans: HashMap<String, HashSet<usize>>,
/// Client id -> slab token, maintained by `auth` and `del_client`;
/// used to resolve `TO` targets and `CLIENT_ID` in `kill`.
pub client_ids: HashMap<MessageId, usize>,
/// Connected clients, keyed by the slab token that is also their epoll `Token`.
pub clients: Slab<Client>,
/// Number of `send_message` calls made through this slot.
pub send_messages: Cell<usize>,
/// Number of frames handed to `recv_message`.
pub recv_messages: Cell<usize>,
// RNG used to pick the single recipient of a `SHARE` delivery.
rand: SmallRng
}
/// One connected peer as seen by the slot.
pub struct Client {
/// Slab key of this client inside `Slot::clients` (also its epoll token).
pub token: usize,
/// Client id: generated in `Client::new`, may be replaced via `AUTH`'s `CLIENT_ID`.
pub id: MessageId,
/// Free-form label message supplied through `AUTH`'s `LABEL` field.
pub label: Message,
/// Set once an `AUTH` request has been accepted by the hook.
pub auth: bool,
/// Requested through `AUTH`'s `ROOT` field; required for `QUERY`, `KILL`
/// and for attaching to the `CLIENT_*` event channels.
pub root: bool,
/// Channel name -> labels this client attached with (empty set = no labels).
pub chans: HashMap<String, HashSet<String>>,
/// Messages successfully written through `Client::send`.
pub send_messages: Cell<usize>,
/// Messages successfully read through `Client::recv`.
pub recv_messages: Cell<usize>,
/// Underlying message transport.
pub wire: Wire<Message>
}
impl Slot {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
chans: HashMap::new(),
client_ids: HashMap::new(),
clients: Slab::new(),
send_messages: Cell::new(0),
recv_messages: Cell::new(0),
rand: SmallRng::from_entropy()
}
}
/// Admits a newly accepted `wire` as a client.
///
/// The client is created under the next free slab key and offered to
/// `hook.accept`; if accepted, an `{OK: 0}` greeting is sent. A rejected
/// client, or one whose greeting cannot be sent, has its wire closed and
/// is not stored (this is not an error).
///
/// # Errors
/// Returns the epoll registration error. In that case the wire is closed,
/// the same as for a rejected client, and the client is not stored.
pub(crate) fn add_client(
    &mut self,
    epoll: &Epoll,
    hook: &impl Hook,
    wire: Wire<Message>
) -> Result<()> {
    let entry = self.clients.vacant_entry();
    let client = Client::new(entry.key(), wire);

    // Short-circuits: the greeting is only sent to a client the hook accepted.
    let admitted = hook.accept(&client) && client.wire.send(msg!{OK: 0i32}).is_ok();
    if !admitted {
        client.wire.close();
        return Ok(())
    }

    if let Err(err) = epoll.add(&client.wire, Token(entry.key()), Ready::readable(), EpollOpt::level()) {
        // Previously the wire was only dropped here; close it explicitly like the rejection path.
        client.wire.close();
        return Err(err.into())
    }

    entry.insert(client);
    Ok(())
}
/// Removes the client at `token`. This is a no-op if no such client exists.
///
/// The client's wire is deregistered from `epoll`. The client is removed from
/// every channel it attached to, and channels left empty are dropped. Its id
/// mapping is released, `hook.remove` is called, and `CLIENT_BREAK` listeners
/// are notified.
///
/// # Errors
/// Returns the epoll deregistration error. The error is returned only after
/// all bookkeeping above has been done. The slab slot is already free and can
/// be handed to the next accepted client, so stale indexes must not survive.
pub(crate) fn del_client(
    &mut self,
    epoll: &Epoll,
    hook: &impl Hook,
    token: usize
) -> Result<()> {
    if !self.clients.contains(token) {
        return Ok(())
    }

    let client = self.clients.remove(token);
    // Deferred: propagating immediately would leave `token` in `chans`/`client_ids`.
    let deregistered = epoll.delete(&client.wire);

    for chan in client.chans.keys() {
        if let Some(ids) = self.chans.get_mut(chan) {
            ids.remove(&token);
            if ids.is_empty() {
                self.chans.remove(chan);
            }
        }
    }

    // Only release the id if it still maps to this client (an unauthenticated
    // client's id is never indexed, and must not evict someone else's entry).
    if self.client_ids.get(&client.id) == Some(&token) {
        self.client_ids.remove(&client.id);
    }

    hook.remove(&client);

    let event_message = msg!{
        CHAN: CLIENT_BREAK,
        CLIENT_ID: client.id,
        LABEL: client.label.clone(),
        ATTR: client.wire.attr().clone()
    };
    self.relay_root_message(hook, token, CLIENT_BREAK, event_message);

    deregistered?;
    Ok(())
}
/// Entry point for every message read from the client at `token`.
///
/// The slot-wide receive counter is bumped, and `hook.recv` may refuse or
/// rewrite the message. The message is then dispatched on its `CHAN` field.
/// Names starting with `_` are control commands handled by this slot.
/// Anything else goes through `relay_message`. Protocol errors are reported
/// back to the sender inside the echoed message. Only an error raised by
/// `kill` is returned as `Err`.
pub(crate) fn recv_message(
    &mut self,
    epoll: &Epoll,
    hook: &impl Hook,
    socket_id: &MessageId,
    token: usize,
    mut message: Message
) -> Result<()> {
    let received = self.recv_messages.get();
    self.recv_messages.set(received + 1);

    if !hook.recv(&self.clients[token], &mut message) {
        ErrorCode::RefuseReceiveMessage.insert(&mut message);
        self.send_message(hook, token, message);
        return Ok(());
    }

    let chan = match message.get_str(CHAN) {
        Ok(name) => name,
        Err(_) => {
            ErrorCode::CannotGetChanField.insert(&mut message);
            self.send_message(hook, token, message);
            return Ok(());
        }
    };

    // Ordinary channel names are pub/sub traffic.
    if !chan.starts_with('_') {
        let chan = chan.to_string();
        self.relay_message(hook, token, chan, message);
        return Ok(());
    }

    match chan {
        AUTH => self.auth(hook, socket_id, token, message),
        ATTACH => self.attach(hook, token, message),
        DETACH => self.detach(hook, token, message),
        PING => self.ping(hook, token, message),
        MINE => self.mine(hook, token, message),
        QUERY => self.query(hook, token, message),
        CUSTOM => self.custom(hook, token, message),
        CLIENT_KILL => return self.kill(epoll, hook, token, message),
        _ => {
            ErrorCode::UnsupportedChan.insert(&mut message);
            self.send_message(hook, token, message);
        }
    }

    Ok(())
}
/// Delivers `message` to the client at `token` after `hook.send` approves it.
///
/// The slot-wide send counter is bumped on every call, even when the client
/// is unknown or the hook refuses the message.
pub(crate) fn send_message(
    &self,
    hook: &impl Hook,
    token: usize,
    mut message: Message
) {
    let sent = self.send_messages.get();
    self.send_messages.set(sent + 1);

    let client = match self.clients.get(token) {
        Some(client) => client,
        None => return,
    };

    if hook.send(client, &mut message) {
        // Best-effort delivery: the send error is deliberately discarded here.
        let _ = client.send(message);
    }
}
/// Fans a server-generated event out on the system channel `chan`.
///
/// Every client attached to `chan` receives its own clone of `message`,
/// except the client at `token`. Each delivery passes through `hook.send`
/// exactly once, inside `send_message`. Previously the hook was also invoked
/// here first, so it ran twice per recipient.
pub(crate) fn relay_root_message(
    &self,
    hook: &impl Hook,
    token: usize,
    chan: &str,
    message: Message
) {
    if let Some(tokens) = self.chans.get(chan) {
        for &other_token in tokens {
            // Skip the originator; skip stale tokens so the counter is not bumped for them.
            if other_token == token || !self.clients.contains(other_token) {
                continue;
            }
            self.send_message(hook, other_token, message.clone());
        }
    }
}
#[allow(clippy::cognitive_complexity)]
pub(crate) fn relay_message(
&mut self,
hook: &impl Hook,
token: usize,
chan: String,
mut message: Message
) {
if !self.clients[token].auth {
ErrorCode::Unauthorized.insert(&mut message);
self.send_message(hook, token, message);
return
}
let success = hook.emit(&self.clients[token], &mut message);
if !success {
ErrorCode::Unauthorized.insert(&mut message);
self.send_message(hook, token, message);
return
}
let reply_message = if let Some(ack) = message.get(ACK) {
let mut reply_message = msg!{
CHAN: &chan,
ACK: ack.clone()
};
if let Ok(message_id) = message.get_message_id(ID) {
reply_message.insert(ID, message_id);
}
ErrorCode::OK.insert(&mut reply_message);
message.remove(ACK);
Some(reply_message)
} else {
None
};
let mut to_ids = vec![];
if let Some(to) = message.get(TO).cloned() {
if let Some(to_id) = to.as_message_id() {
if !self.client_ids.contains_key(to_id) {
ErrorCode::TargetClientIdNotExist.insert(&mut message);
message.insert(CLIENT_ID, to_id);
self.send_message(hook, token, message);
return
}
to_ids.push(to_id.clone());
} else if let Some(to_array) = to.as_array() {
let mut not_exist_ids = vec![];
for to in to_array {
if let Some(to_id) = to.as_message_id() {
if self.client_ids.contains_key(to_id) {
to_ids.push(to_id.clone());
} else {
not_exist_ids.push(to_id.clone());
}
} else {
ErrorCode::InvalidToFieldType.insert(&mut message);
self.send_message(hook, token, message);
return
}
}
if !not_exist_ids.is_empty() {
ErrorCode::TargetClientIdNotExist.insert(&mut message);
message.insert(CLIENT_ID, not_exist_ids);
self.send_message(hook, token, message);
return
}
} else {
ErrorCode::InvalidToFieldType.insert(&mut message);
self.send_message(hook, token, message);
return
}
}
let mut labels = HashSet::new();
if let Some(label) = message.get(LABEL) {
if let Some(label) = label.as_str() {
labels.insert(label.to_string());
} else if let Some(label_array) = label.as_array() {
for v in label_array {
if let Some(v) = v.as_str() {
labels.insert(v.to_string());
} else {
ErrorCode::InvalidLabelFieldType.insert(&mut message);
self.send_message(hook, token, message);
return
}
}
} else {
ErrorCode::InvalidLabelFieldType.insert(&mut message);
self.send_message(hook, token, message);
return
}
}
message.insert(FROM, &self.clients[token].id.clone());
macro_rules! send {
($self: ident, $hook: ident, $client: ident, $message: ident) => {
let success = $hook.push(&$client, &mut $message);
if success {
$self.send_message($hook, $client.token, $message.clone());
let event_message = msg!{
CHAN: CLIENT_RECV,
VALUE: $message.clone(),
TO: $client.id.clone()
};
let id = $client.token;
$self.relay_root_message($hook, id, CLIENT_RECV, event_message);
}
};
}
let mut no_consumers = true;
if !to_ids.is_empty() {
no_consumers = false;
if message.get_bool(SHARE).ok().unwrap_or(false) {
if to_ids.len() == 1 {
if let Some(client_id) = self.client_ids.get(&to_ids[0]) {
if let Some(client) = self.clients.get(*client_id) {
send!(self, hook, client, message);
}
}
} else if let Some(to) = to_ids.choose(&mut self.rand) {
if let Some(client_id) = self.client_ids.get(to) {
if let Some(client) = self.clients.get(*client_id) {
send!(self, hook, client, message);
}
}
}
} else {
for to in &to_ids {
if let Some(client_id) = self.client_ids.get(to) {
if let Some(client) = self.clients.get(*client_id) {
send!(self, hook, client, message);
}
}
}
}
} else if message.get_bool(SHARE).ok().unwrap_or(false) {
let mut array: Vec<usize> = Vec::new();
if let Some(ids) = self.chans.get(&chan) {
for client_id in ids.iter() {
if let Some(client) = self.clients.get(*client_id) {
if !labels.is_empty() {
let client_labels = client.chans.get(&chan).expect("It shouldn't be executed here!");
if (client_labels & &labels).is_empty() {
continue;
}
}
array.push(*client_id);
}
}
}
if !array.is_empty() {
no_consumers = false;
if array.len() == 1 {
if let Some(client) = self.clients.get(array[0]) {
send!(self, hook, client, message);
}
} else if let Some(id) = array.choose(&mut self.rand) {
if let Some(client) = self.clients.get(*id) {
send!(self, hook, client, message);
}
}
}
} else if let Some(ids) = self.chans.get(&chan) {
for client_id in ids.iter() {
if let Some(client) = self.clients.get(*client_id) {
if !labels.is_empty() {
let client_labels = client.chans.get(&chan).expect("It shouldn't be executed here!");
if !client_labels.iter().any(|l| labels.contains(l)) {
continue
}
}
no_consumers = false;
send!(self, hook, client, message);
}
}
}
if no_consumers {
message.remove(FROM);
ErrorCode::NoConsumers.insert(&mut message);
self.send_message(hook, token, message);
return
}
let event_message = msg!{
CHAN: CLIENT_SEND,
VALUE: message
};
self.relay_root_message(hook, token, CLIENT_SEND, event_message);
if let Some(reply_message) = reply_message {
self.send_message(hook, token, reply_message);
}
}
/// Handles an `AUTH` request from the client at `token`.
///
/// Optional request fields:
/// * `ROOT`: a bool.
/// * `LABEL`: a message.
/// * `CLIENT_ID`: a message id that must not belong to another client.
///
/// Without `CLIENT_ID`, the client keeps its current id, which is echoed back
/// in the reply. Requested values are applied provisionally and then
/// `hook.auth` decides. On any validation error or refusal, every field
/// touched so far, and the `client_ids` index, is rolled back to its previous
/// value, and the error code is sent back.
///
/// On success the client is marked authenticated. The reply carries `LABEL`
/// (if non-empty), `SOCKET_ID` and `OK`, and root listeners receive a
/// `CLIENT_READY` event.
pub(crate) fn auth(
    &mut self,
    hook: &impl Hook,
    socket_id: &MessageId,
    token: usize,
    mut message: Message
) {
    // The binding itself is never reassigned, so it needs no `mut`.
    let client = &mut self.clients[token];

    // Snapshot used to roll back on failure.
    struct TempSession {
        pub auth: bool,
        pub root: bool,
        pub id: MessageId,
        pub label: Message
    }

    let temp_session = TempSession {
        auth: client.auth,
        root: client.root,
        id: client.id.clone(),
        label: client.label.clone()
    };

    if let Some(s) = message.get(ROOT) {
        if let Some(s) = s.as_bool() {
            client.root = s;
        } else {
            ErrorCode::InvalidRootFieldType.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }
    }

    if let Some(label) = message.get(LABEL) {
        if let Some(label) = label.as_message() {
            client.label = label.clone();
        } else {
            client.root = temp_session.root;
            ErrorCode::InvalidLabelFieldType.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }
    }

    if let Some(client_id) = message.get(CLIENT_ID) {
        if let Some(client_id) = client_id.as_message_id() {
            if let Some(other_token) = self.client_ids.get(client_id) {
                if *other_token != token {
                    client.root = temp_session.root;
                    client.label = temp_session.label;
                    ErrorCode::DuplicateClientId.insert(&mut message);
                    self.send_message(hook, token, message);
                    return
                }
            }
            // Re-key the index from the old id to the requested one.
            self.client_ids.remove(&client.id);
            self.client_ids.insert(client_id.clone(), token);
            client.id = client_id.clone();
        } else {
            client.root = temp_session.root;
            client.label = temp_session.label;
            ErrorCode::InvalidClientIdFieldType.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }
    } else {
        message.insert(CLIENT_ID, client.id.clone());
        self.client_ids.insert(client.id.clone(), token);
    }

    let success = hook.auth(&*client, &mut message);
    if !success {
        // Undo the provisional state; a previously authenticated client keeps its old id mapping.
        self.client_ids.remove(&client.id);
        if temp_session.auth {
            self.client_ids.insert(temp_session.id.clone(), token);
        }
        client.auth = temp_session.auth;
        client.root = temp_session.root;
        client.id = temp_session.id;
        client.label = temp_session.label;
        ErrorCode::AuthenticationFailed.insert(&mut message);
        self.send_message(hook, token, message);
        return
    }

    client.auth = true;

    if !client.label.is_empty() {
        message.insert(LABEL, client.label.clone());
    }
    message.insert(SOCKET_ID, socket_id.clone());
    ErrorCode::OK.insert(&mut message);

    let event_message = msg!{
        CHAN: CLIENT_READY,
        ROOT: client.root,
        CLIENT_ID: client.id.clone(),
        LABEL: client.label.clone(),
        ATTR: client.wire.attr().clone()
    };

    self.send_message(hook, token, message);
    self.relay_root_message(hook, token, CLIENT_READY, event_message);
}
/// Handles an `ATTACH` request by subscribing the client at `token` to a channel.
///
/// The channel name comes from `VALUE`. The optional `LABEL` field (a string
/// or an array of strings) adds labels to the client's label set for that
/// channel. The request is rejected in these cases:
/// * the client is not authenticated (`Unauthorized`);
/// * a `CLIENT_*` event channel is requested by a non-root client (`Unauthorized`);
/// * `LABEL` is malformed (`InvalidLabelFieldType`);
/// * `hook.attach` refuses (`Unauthorized`);
/// * `VALUE` is missing (`CannotGetValueField`).
///
/// On success, root listeners receive `CLIENT_ATTACH` and the requester
/// receives `OK`. A reply is always sent.
pub(crate) fn attach(
    &mut self,
    hook: &impl Hook,
    token: usize,
    mut message: Message
) {
    // Reads the optional LABEL field. Returns the empty set when it is
    // absent, and `None` when it is neither a string nor an array of strings.
    fn requested_labels(message: &Message) -> Option<HashSet<String>> {
        let mut labels = HashSet::new();
        if let Some(field) = message.get(LABEL) {
            if let Some(single) = field.as_str() {
                labels.insert(single.to_string());
            } else if let Some(items) = field.as_array() {
                for item in items {
                    labels.insert(item.as_str()?.to_string());
                }
            } else {
                return None;
            }
        }
        Some(labels)
    }

    if !self.clients[token].auth {
        ErrorCode::Unauthorized.insert(&mut message);
        self.send_message(hook, token, message);
        return
    }

    let chan = match message.get_str(VALUE).map(ToOwned::to_owned) {
        Ok(chan) => chan,
        Err(_) => {
            ErrorCode::CannotGetValueField.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }
    };

    // Server event channels are reserved for root clients.
    let reserved = matches!(
        chan.as_str(),
        CLIENT_READY | CLIENT_BREAK | CLIENT_ATTACH | CLIENT_DETACH | CLIENT_SEND | CLIENT_RECV
    );
    if reserved && !self.clients[token].root {
        ErrorCode::Unauthorized.insert(&mut message);
        self.send_message(hook, token, message);
        return
    }

    let labels = match requested_labels(&message) {
        Some(labels) => labels,
        None => {
            ErrorCode::InvalidLabelFieldType.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }
    };

    if !hook.attach(&self.clients[token], &mut message, &chan, &labels) {
        ErrorCode::Unauthorized.insert(&mut message);
        self.send_message(hook, token, message);
        return
    }

    let mut event_message = msg!{
        CHAN: CLIENT_ATTACH,
        VALUE: &chan,
        CLIENT_ID: self.clients[token].id.clone()
    };
    if let Some(label) = message.get(LABEL) {
        event_message.insert(LABEL, label.clone());
    }

    // Keep both indexes in sync: channel -> tokens, and client -> channel labels.
    self.chans.entry(chan.clone()).or_insert_with(HashSet::new).insert(token);
    self.clients[token].chans.entry(chan).or_insert_with(HashSet::new).extend(labels);

    self.relay_root_message(hook, token, CLIENT_ATTACH, event_message);
    ErrorCode::OK.insert(&mut message);
    self.send_message(hook, token, message);
}
/// Handles a `DETACH` request for the client at `token` on the channel in `VALUE`.
///
/// * With no `LABEL` field, the whole subscription is dropped. The channel
///   entry itself is removed once no client is left on it.
/// * With `LABEL` (a string or an array of strings), only those labels are
///   removed from the client's label set, and the subscription stays.
///
/// The request requires an authenticated client and is subject to
/// `hook.detach`. On success, root listeners receive `CLIENT_DETACH`. A reply
/// is always sent back to the requester: `OK`, or an error code
/// (`Unauthorized`, `InvalidLabelFieldType`, `CannotGetValueField`).
pub(crate) fn detach(
    &mut self,
    hook: &impl Hook,
    token: usize,
    mut message: Message
) {
    if !self.clients[token].auth {
        ErrorCode::Unauthorized.insert(&mut message);
        self.send_message(hook, token, message);
        return
    }

    if let Ok(chan) = message.get_str(VALUE).map(ToOwned::to_owned) {
        let mut labels = HashSet::new();
        if let Some(label) = message.get(LABEL) {
            if let Some(label) = label.as_str() {
                labels.insert(label.to_string());
            } else if let Some(label_array) = label.as_array() {
                for v in label_array {
                    if let Some(v) = v.as_str() {
                        labels.insert(v.to_string());
                    } else {
                        ErrorCode::InvalidLabelFieldType.insert(&mut message);
                        self.send_message(hook, token, message);
                        return
                    }
                }
            } else {
                ErrorCode::InvalidLabelFieldType.insert(&mut message);
                self.send_message(hook, token, message);
                return
            }
        }

        let success = hook.detach(&self.clients[token], &mut message, &chan, &labels);
        if !success {
            ErrorCode::Unauthorized.insert(&mut message);
            self.send_message(hook, token, message);
            return
        }

        let mut event_message = msg!{
            CHAN: CLIENT_DETACH,
            VALUE: &chan,
            CLIENT_ID: self.clients[token].id.clone()
        };
        if let Some(label) = message.get(LABEL) {
            event_message.insert(LABEL, label.clone());
        }

        {
            let client = &mut self.clients[token];
            if labels.is_empty() {
                client.chans.remove(&chan);
                if let Some(ids) = self.chans.get_mut(&chan) {
                    ids.remove(&token);
                    if ids.is_empty() {
                        self.chans.remove(&chan);
                    }
                }
            } else if let Some(set) = client.chans.get_mut(&chan) {
                // Filter in place instead of rebuilding (and re-allocating) the set.
                set.retain(|label| !labels.contains(label));
            }
        }

        self.relay_root_message(hook, token, CLIENT_DETACH, event_message);
        ErrorCode::OK.insert(&mut message);
    } else {
        ErrorCode::CannotGetValueField.insert(&mut message);
    }

    self.send_message(hook, token, message);
}
/// Answers a `PING` request.
///
/// `hook.ping` is given a chance to inspect or amend the message, which is
/// then echoed back to the client at `token` with `OK`.
pub(crate) fn ping(&mut self, hook: &impl Hook, token: usize, mut message: Message) {
    let requester = &self.clients[token];
    hook.ping(requester, &mut message);

    ErrorCode::OK.insert(&mut message);
    self.send_message(hook, token, message);
}
/// Answers a `MINE` request with a snapshot of the requesting client.
///
/// The snapshot is placed in `VALUE` and covers: auth/root flags, channel
/// subscriptions with their labels, id, label, wire attributes and message
/// counters. `VALUE` is omitted if the token is unknown. The reply always
/// carries `OK`.
pub(crate) fn mine(&self, hook: &impl Hook, token: usize, mut message: Message) {
    if let Some(me) = self.clients.get(token) {
        let mut subscriptions = Message::new();
        for (name, label_set) in me.chans.iter() {
            subscriptions.insert(name, label_set.iter().collect::<Vec<&String>>());
        }

        let summary = msg!{
            AUTH: me.auth,
            ROOT: me.root,
            CHANS: subscriptions,
            CLIENT_ID: me.id.clone(),
            LABEL: me.label.clone(),
            ATTR: me.wire.attr().clone(),
            SEND_MESSAGES: me.send_messages.get() as u64,
            RECV_MESSAGES: me.recv_messages.get() as u64
        };
        message.insert(VALUE, summary);
    }

    ErrorCode::OK.insert(&mut message);
    self.send_message(hook, token, message);
}
/// Handles a `QUERY` request. Only authenticated root clients may query.
///
/// Answering is entirely up to `hook.query`; unauthorized callers get
/// `Unauthorized`. A reply is always sent.
pub(crate) fn query(&self, hook: &impl Hook, token: usize, mut message: Message) {
    let requester = &self.clients[token];
    let permitted = requester.auth && requester.root;

    if permitted {
        hook.query(&self, token, &mut message);
    } else {
        ErrorCode::Unauthorized.insert(&mut message);
    }

    self.send_message(hook, token, message);
}
/// Handles a `CUSTOM` request from an authenticated client.
///
/// The request is delegated to `hook.custom`; unauthenticated callers get
/// `Unauthorized`. A reply is always sent.
pub(crate) fn custom(&self, hook: &impl Hook, token: usize, mut message: Message) {
    let authenticated = self.clients[token].auth;

    if authenticated {
        hook.custom(&self, token, &mut message);
    } else {
        ErrorCode::Unauthorized.insert(&mut message);
    }

    self.send_message(hook, token, message);
}
/// Handles a `CLIENT_KILL` request: disconnects the client named by `CLIENT_ID`.
///
/// Only authenticated root clients may kill, and `hook.kill` must agree.
/// The requester is answered before the target is removed, so a client
/// killing itself still receives `OK`. Failures are reported with
/// `Unauthorized`, `CannotGetClientIdField`, `InvalidClientIdFieldType` or
/// `TargetClientIdNotExist`.
///
/// # Errors
/// Propagates any error from `del_client` for the target.
pub(crate) fn kill(
    &mut self,
    epoll: &Epoll,
    hook: &impl Hook,
    token: usize,
    mut message: Message
) -> Result<()> {
    {
        let client = &self.clients[token];
        if !client.auth || !client.root {
            ErrorCode::Unauthorized.insert(&mut message);
            self.send_message(hook, token, message);
            return Ok(())
        }
    }

    let success = hook.kill(&self.clients[token], &mut message);
    if !success {
        ErrorCode::Unauthorized.insert(&mut message);
        self.send_message(hook, token, message);
        return Ok(())
    }

    // Resolve the victim's token; each failure maps to its own error code.
    let target = match message.get(CLIENT_ID) {
        None => Err(ErrorCode::CannotGetClientIdField),
        Some(value) => match value.as_message_id() {
            None => Err(ErrorCode::InvalidClientIdFieldType),
            Some(client_id) => self
                .client_ids
                .get(client_id)
                .copied()
                .ok_or(ErrorCode::TargetClientIdNotExist),
        },
    };

    let remove_token = match target {
        Ok(remove_token) => remove_token,
        Err(code) => {
            code.insert(&mut message);
            self.send_message(hook, token, message);
            return Ok(())
        }
    };

    ErrorCode::OK.insert(&mut message);
    self.send_message(hook, token, message);

    self.del_client(epoll, hook, remove_token)
}
}
impl Client {
pub fn new(token: usize, wire: Wire<Message>) -> Self {
Self {
token,
id: MessageId::new(),
label: Message::new(),
auth: false,
root: false,
chans: HashMap::new(),
send_messages: Cell::new(0),
recv_messages: Cell::new(0),
wire
}
}
pub fn send(&self, message: Message) -> result::Result<(), SendError<Message>> {
self.wire.send(message).map(|m| {
self.send_messages.set(self.send_messages.get() + 1);
m
})
}
pub fn recv(&self) -> result::Result<Message, RecvError> {
self.wire.recv().map(|m| {
self.recv_messages.set(self.recv_messages.get() + 1);
m
})
}
}