use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicI64, Ordering};
use log::*;
use chashmap::CHashMap;
use chrono::Utc;
use tokio::io::WriteHalf;
use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio::prelude::task::Task;
use crate::errors::*;
use crate::init::CONFIG;
use crate::init::SERVER;
use crate::message::response::*;
use crate::message::stomp_message::{StompMessage, Header};
use crate::parser::ParserState;
use crate::session::mq::{Mq, SessionMessage};
use crate::session::reader::{Reader, ReadKiller};
use crate::session::subscription::Subscription;
use crate::session::writer::Writer;
use crate::web_socket::ws_ports::{ws_is_trusted_port, ws_is_web_port};
use std::ops::Deref;
use crate::downstream::DownstreamConnector;
/// Grace period, in milliseconds, added on top of the read heart-beat interval
/// before a peer that has sent nothing is considered timed out.
const READ_TIMEOUT_MARGIN: i64 = 15_000;
/// Session flag: the connection has been upgraded to WebSockets (see `ws_upgrade`).
pub const FLAG_WEB_SOCKETS: u64 = 1 << 0;
/// Session flag: the connection arrived on a web port or via a WebSocket upgrade.
pub const FLAG_WEB: u64 = 1 << 1;
/// Session flag: the connection arrived on a trusted (admin) port.
pub const FLAG_ADMIN: u64 = 1 << 2;
/// Session flag: the session is a downstream connection; port-based flags are not applied.
pub const FLAG_DOWNSTREAM: u64 = 1 << 3;
/// Per-connection session state: identity, flags, the outbound message queue,
/// socket halves, heart-beat bookkeeping, subscriptions and pending acks.
///
/// NOTE: fields are dropped in declaration order; keep that in mind before reordering.
pub struct StompSession {
// Server-assigned session id (from `SERVER.new_session()`).
id: usize,
// Logged-in user name; `None` until `set_login` is called.
pub user: Option<String>,
// Bitmask of `FLAG_*` values; read with `get_flag`, set with `set_flag`.
flags: u64,
// Outbound message queue consumed by the writer side.
mq: Arc<RwLock<Mq>>,
// Write side of the socket; set by `split`, shut down and cleared on shutdown/kill/write termination.
write_half: Option<Arc<RwLock<WriteHalf<TcpStream>>>>,
// Handle used to stop the reader; cleared once the reader has been killed.
read_killer: Option<Arc<RwLock<ReadKiller>>>,
// Optional downstream connector, notified from `read_terminated`.
pub(crate) downstream_connector: Option<Arc<RwLock<DownstreamConnector>>>,
// Task captured by the first `timeout` poll; notified on shutdown/kill so the timer can end.
timeout_task: Option<Task>,
// Heart-beat intervals in milliseconds (they are added to millisecond timestamps).
pub heart_beat_read: u32,
pub heart_beat_write: u32,
// Millisecond UTC timestamps of the last read / write activity.
last_read: AtomicI64,
last_write: AtomicI64,
// Set once shutdown or kill has started.
shutdown: bool,
// Active subscriptions owned by this session.
subscriptions: Vec<Arc<RwLock<Subscription>>>,
// Message id -> destination id for messages awaiting a client ACK/NACK.
pending_acks: CHashMap<usize, usize>,
}
impl StompSession {
    /// Creates a session with a fresh server-assigned id, an empty outbound queue and
    /// the configured heart-beat defaults. Both activity timestamps start at "now".
    pub fn new() -> StompSession {
        let mut session = StompSession {
            id: SERVER.new_session(),
            user: None,
            flags: 0,
            mq: Arc::new(RwLock::new(Mq::new())),
            write_half: None,
            read_killer: None,
            downstream_connector: None,
            timeout_task: None,
            heart_beat_read: 0,
            heart_beat_write: 0,
            last_read: AtomicI64::new(Utc::now().timestamp_millis()),
            last_write: AtomicI64::new(Utc::now().timestamp_millis()),
            shutdown: false,
            subscriptions: vec![],
            pending_acks: CHashMap::new(),
        };
        session.set_heart_beat_defaults();
        session
    }

    /// Server-assigned session id.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Name of the logged-in user, or `""` when nobody has logged in yet.
    pub fn user(&self) -> &str {
        match &self.user {
            Some(user) => user.as_str(),
            None => "",
        }
    }

    /// Returns `true` when every bit of `flag_mask` is set on this session.
    pub fn get_flag(&self, flag_mask: u64) -> bool {
        self.flags & flag_mask == flag_mask
    }

    /// Sets every bit of `flag_mask` on this session.
    pub fn set_flag(&mut self, flag_mask: u64) {
        self.flags |= flag_mask
    }

    /// Splits `sock` into a [`Reader`] and a [`Writer`] bound to `session`.
    ///
    /// For non-downstream sessions the local port decides the `FLAG_ADMIN` and
    /// `FLAG_WEB` flags. If the local address cannot be read (for example, the peer
    /// already hung up), the failure is logged and no port-based flags are granted.
    /// Previously this case panicked.
    pub(crate) fn split(&mut self, sock: TcpStream, session: Arc<RwLock<StompSession>>) -> (Reader, Writer) {
        if !self.get_flag(FLAG_DOWNSTREAM) {
            match sock.local_addr() {
                Ok(addr) => {
                    let port = addr.port();
                    if ws_is_trusted_port(port) {
                        self.set_flag(FLAG_ADMIN);
                    }
                    if ws_is_web_port(port) {
                        self.set_flag(FLAG_WEB);
                    }
                }
                Err(e) => warn!("unable to read local address id={}: {}", self.id, e),
            }
        }
        let (read_half, write_half) = sock.split();
        let write_half_lock = Arc::new(RwLock::new(write_half));
        self.set_write_half(write_half_lock.clone());
        (
            Reader::new(session.clone(), self.id(), read_half),
            Writer::new(session, self.id(), write_half_lock),
        )
    }

    fn set_write_half(&mut self, write_half: Arc<RwLock<WriteHalf<TcpStream>>>) {
        self.write_half = Some(write_half);
    }

    /// Stores the handle used later to stop the reader.
    pub(crate) fn set_read_killer(&mut self, read_killer: Arc<RwLock<ReadKiller>>) {
        self.read_killer = Some(read_killer);
    }

    /// Polls the outbound queue, collapsing its error type to `()`.
    ///
    /// # Panics
    /// Panics if the queue lock is poisoned.
    pub(crate) fn poll_mq(&self) -> Result<Async<()>, ()> {
        self.mq.read().unwrap().poll().map_err(|_| ())
    }

    /// Hands `task` to the outbound queue (presumably so it can be woken on new messages).
    pub(crate) fn set_mq_task(&self, task: Task) {
        self.mq.write().unwrap().set_task(task);
    }

    /// Takes the next queued outbound message, if any.
    pub(crate) fn pop(&self) -> Option<SessionMessage> {
        self.mq.write().unwrap().next()
    }

    /// Number of messages currently queued for sending.
    pub fn len(&self) -> usize {
        self.mq.read().unwrap().len()
    }

    /// Queues a client error frame.
    ///
    /// Returns `false`, without queuing anything, when the queue lock is currently
    /// held elsewhere (`try_write` is used, so this never blocks).
    pub fn send_client_error(&self, err: ClientError) -> bool {
        match self.mq.try_write() {
            Ok(mut mq) => {
                mq.push(get_response_error_client(err), vec![]);
                true
            }
            Err(_) => false,
        }
    }

    /// Queues a client error frame, then starts a graceful [`shutdown`](Self::shutdown).
    /// Returns whether the error frame was queued.
    pub fn send_client_error_fatal(&mut self, err: ClientError) -> bool {
        let sent = self.send_client_error(err);
        self.shutdown();
        sent
    }

    /// Queues `message` with no extra headers. Returns `false` if the queue lock was busy.
    pub fn send_message(&self, message: Arc<StompMessage>) -> bool {
        self.send_message_w_hdrs(message, vec![])
    }

    /// Queues `message` with extra per-delivery `headers`.
    /// Returns `false`, without queuing anything, if the queue lock was busy.
    pub fn send_message_w_hdrs(&self, message: Arc<StompMessage>, headers: Vec<Header>) -> bool {
        match self.mq.try_write() {
            Ok(mut mq) => {
                mq.push(message, headers);
                true
            }
            Err(_) => false,
        }
    }

    /// Number of active subscriptions.
    pub fn count_subscription(&self) -> usize {
        self.subscriptions.len()
    }

    pub(crate) fn add_subscription(&mut self, sub: Arc<RwLock<Subscription>>) {
        self.subscriptions.push(sub);
        debug!("subscription count {}", self.count_subscription());
    }

    /// Drops every subscription whose id equals `id`, without calling `unsubscribe` on it.
    pub(crate) fn remove_subscription(&mut self, id: u64) {
        self.subscriptions.retain(|sub| sub.read().unwrap().subscription_id() != id);
    }

    /// Unsubscribes the first subscription with id `id`, then removes all entries with that id.
    pub fn unsubscribe(&mut self, id: u64) {
        for sub in &self.subscriptions {
            let sub = sub.read().unwrap();
            if sub.subscription_id() == id {
                sub.unsubscribe();
                break;
            }
        }
        self.remove_subscription(id);
    }

    /// Unsubscribes and forgets every subscription of this session.
    pub fn unsubscribe_all(&mut self) {
        for sub in &self.subscriptions {
            sub.read().unwrap().unsubscribe();
        }
        self.subscriptions.clear();
    }

    /// Records that message `msg_id` from destination `destination_id` awaits an ACK/NACK.
    pub(crate) fn pending_ack(&self, msg_id: usize, destination_id: usize) {
        self.pending_acks.insert(msg_id, destination_id);
    }

    /// Removes and returns the destination recorded for `msg_id`.
    ///
    /// Removing the entry keeps `pending_acks` bounded. Previously entries were only
    /// read, so the map grew for the whole life of the session.
    fn take_destination_id(&self, msg_id: usize) -> Option<usize> {
        self.pending_acks.remove(&msg_id)
    }

    /// Handles a client ACK (`ack_nack == true`) or NACK (`false`) for `msg_id`.
    ///
    /// Returns `true` when the ack/nack was forwarded to a destination that is not in
    /// auto-ack mode. Returns `false` in these cases:
    /// - `msg_id` is not a number;
    /// - `msg_id` is not pending (including one already acked);
    /// - the destination is gone;
    /// - the destination auto-acks.
    pub fn ack(&self, ack_nack: bool, msg_id: &str) -> bool {
        let msg_id = match msg_id.parse::<usize>() {
            Ok(msg_id) => msg_id,
            Err(_) => return false,
        };
        let destination_id = match self.take_destination_id(msg_id) {
            Some(destination_id) => destination_id,
            None => return false,
        };
        let destination_lock = match SERVER.find_destination_by_id(&destination_id) {
            Some(destination_lock) => destination_lock,
            None => return false,
        };
        let destination = destination_lock.read().unwrap();
        if destination.auto_ack() {
            return false;
        }
        debug!("acking msg @ dest={} msg_id={}", destination.name(), msg_id);
        if ack_nack {
            destination.ack(msg_id);
        } else {
            destination.nack(msg_id);
        }
        true
    }

    /// Resets heart-beats to the configured defaults.
    ///
    /// The write interval is the midpoint of `[heart_beat_write_min, heart_beat_write_max]`.
    /// `saturating_sub` keeps a misconfigured `max < min` from underflowing (panic in
    /// debug, wrap-around in release); in that case the result is `min`.
    pub fn set_heart_beat_defaults(&mut self) {
        let write_min = CONFIG.heart_beat_write_min;
        let write_max = CONFIG.heart_beat_write_max;
        self.heart_beat_read = CONFIG.heart_beat_read;
        self.heart_beat_write = write_min + write_max.saturating_sub(write_min) / 2;
    }

    /// Records `name` as the logged-in user. Always returns `true`.
    pub fn set_login(&mut self, name: &str) -> bool {
        self.user = Some(String::from(name));
        info!("login usr={}", self.user());
        true
    }

    /// Whether a user has logged in on this session.
    pub fn is_authenticated(&self) -> bool {
        self.user.is_some()
    }

    /// Marks the session as upgraded to WebSockets (which also implies `FLAG_WEB`).
    pub(crate) fn ws_upgrade(&mut self) {
        self.set_flag(FLAG_WEB_SOCKETS);
        self.set_flag(FLAG_WEB);
    }

    /// Whether `shutdown` or `kill` has been started.
    pub fn shutdown_pending(&self) -> bool {
        self.shutdown
    }

    /// Graceful shutdown.
    ///
    /// Unsubscribes everything and attempts to drain the outbound queue. If the queue
    /// is already empty, the socket's write side is closed now. Otherwise it is left
    /// open so the remaining messages can be written. The reader is killed and the
    /// timeout task is woken in either case.
    pub fn shutdown(&mut self) {
        info!("session shutdown usr={} id={}", self.user(), self.id);
        self.shutdown = true;
        self.unsubscribe_all();
        match self.mq.write() {
            Ok(mut mq) => match mq.drain() {
                Ok(_) => {
                    debug!("empty mq at drain id={}", self.id);
                    // Borrow only the field so the close happens while the mq lock is still held.
                    Self::close_write_half(&mut self.write_half);
                }
                Err(remaining) => {
                    debug!("session will shutdown when messages are drained: {} id={}", remaining, self.id);
                }
            },
            Err(_) => warn!("mq.drain() lock failed"),
        }
        if !self.kill_reader() {
            debug!("no read_killer?? '{}'", self.user());
        }
        self.notify_timeout_task();
        debug!("session shutdown complete id={}", self.id);
    }

    /// Immediate shutdown: closes the outbound queue, discarding anything still queued.
    /// It then closes the write side, kills the reader and wakes the timeout task.
    pub fn kill(&mut self) {
        info!("session kill usr={}", self.user());
        self.shutdown = true;
        self.unsubscribe_all();
        match self.mq.write() {
            Ok(mut mq) => {
                mq.close();
                Self::close_write_half(&mut self.write_half);
            }
            Err(_) => warn!("mq.close() lock failed"),
        }
        if !self.kill_reader() {
            debug!("no read_killer on kill");
        }
        self.notify_timeout_task();
    }

    /// Called when the writer has finished; closes and forgets the write side.
    pub(crate) fn write_terminated(&mut self) {
        Self::close_write_half(&mut self.write_half);
    }

    /// Called when the reader has finished.
    ///
    /// If the reader had not been killed yet, it is killed now. In that case the
    /// downstream connector (if any) is also told about the shutdown and then dropped.
    pub(crate) fn read_terminated(&mut self) {
        if self.kill_reader() {
            if let Some(connector) = self.downstream_connector.take() {
                connector.read().unwrap().session_shutdown(connector.clone(), self.id, self.flags);
            }
        }
    }

    /// Reports a parse failure to the client (best effort) and starts a graceful shutdown.
    pub(crate) fn read_error(&mut self, ps: ParserState) {
        let err = if ps == ParserState::BodyFlup {
            ClientError::BodyFlup
        } else if ps == ParserState::HdrFlup {
            ClientError::HdrFlup
        } else {
            ClientError::Syntax
        };
        self.send_client_error(err);
        self.shutdown();
    }

    /// A write failed: nothing more can be delivered, so kill the session outright.
    pub(crate) fn write_error(&mut self) {
        self.kill();
    }

    /// Periodic timeout check, driven by a timer future.
    ///
    /// The first poll captures the current task so shutdown/kill can wake it.
    /// Returns `Err(shutdown)` once the session is shutting down, which ends the
    /// timer. Starts a shutdown when nothing has been read for
    /// `heart_beat_read + READ_TIMEOUT_MARGIN` ms. Otherwise it wakes the outbound
    /// queue when a heart-beat write is due.
    pub(crate) fn timeout(&mut self) -> Result<(), tokio::timer::Error> {
        debug!("timeout polled");
        if self.timeout_task.is_none() {
            self.timeout_task = Some(futures::task::current());
        }
        if self.shutdown {
            debug!("timeout ending");
            self.timeout_task = None;
            return Err(tokio::timer::Error::shutdown());
        }
        let last_read = self.last_read.load(Ordering::Relaxed);
        if last_read + (self.heart_beat_read as i64) + READ_TIMEOUT_MARGIN < Utc::now().timestamp_millis() {
            warn!("read timeout at {}", self.heart_beat_read);
            self.shutdown();
            return Ok(());
        }
        if self.should_heart_beat() {
            if let Ok(mq) = self.mq.try_read() {
                mq.notify();
            }
        }
        Ok(())
    }

    /// Whether more than `heart_beat_write` ms have passed since the last write.
    pub(crate) fn should_heart_beat(&self) -> bool {
        let last_write = self.last_write.load(Ordering::Relaxed);
        last_write + (self.heart_beat_write as i64) < Utc::now().timestamp_millis()
    }

    /// Milliseconds until the nearer of two deadlines: the next heart-beat write, or
    /// the read timeout. Negative when that deadline has already passed.
    pub fn next_timeout(&self) -> i64 {
        let next_write = self.last_write.load(Ordering::Relaxed) + (self.heart_beat_write as i64);
        let next_read = self.last_read.load(Ordering::Relaxed) + (self.heart_beat_read as i64) + READ_TIMEOUT_MARGIN;
        next_read.min(next_write) - Utc::now().timestamp_millis()
    }

    /// Records read activity (resets the read-timeout clock).
    pub(crate) fn read_something(&self) {
        self.last_read.store(Utc::now().timestamp_millis(), Ordering::Relaxed);
    }

    /// Records write activity (resets the heart-beat clock).
    pub(crate) fn wrote_something(&self) {
        self.last_write.store(Utc::now().timestamp_millis(), Ordering::Relaxed);
    }

    /// Shuts down and forgets the socket's write side, if present.
    ///
    /// Takes the field rather than `&mut self` so callers can use it while holding the
    /// mq lock. If the write-half lock is busy, the shutdown call is skipped, but the
    /// handle is still dropped.
    fn close_write_half(write_half: &mut Option<Arc<RwLock<WriteHalf<TcpStream>>>>) {
        if let Some(half_lock) = write_half.take() {
            if let Ok(mut half) = half_lock.try_write() {
                half.shutdown().ok();
            }
        }
    }

    /// Kills the reader if a killer handle is still held; returns whether one was.
    fn kill_reader(&mut self) -> bool {
        match self.read_killer.take() {
            Some(read_killer) => {
                read_killer.write().unwrap().kill();
                true
            }
            None => false,
        }
    }

    /// Wakes the task captured by `timeout`, so it observes the shutdown flag.
    fn notify_timeout_task(&self) {
        if let Some(timeout_task) = &self.timeout_task {
            timeout_task.notify();
        }
    }
}
impl Drop for StompSession {
    /// Logs the drop and tells the server this session is gone.
    fn drop(&mut self) {
        let session_id = self.id;
        let user_name = self.user();
        debug!("dropped session for '{}' id={}", user_name, session_id);
        SERVER.drop_session();
    }
}
/// Receiver of session-shutdown notifications.
///
/// NOTE(review): `StompSession::read_terminated` calls `DownstreamConnector::session_shutdown`
/// with an extra `Arc` argument, so that call does not go through this trait — confirm
/// whether this trait is still used anywhere.
pub trait ShutdownListener {
    /// Called with the id and `FLAG_*` bitmask of the session that shut down.
    fn session_shutdown(&self, id: usize, flags: u64);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time};

    #[test]
    fn test_new() {
        let mut session = StompSession::new();
        // Defaults come from CONFIG; these values assume the test configuration.
        assert_eq!(120000, session.heart_beat_read);
        assert_eq!(120000, session.heart_beat_write);

        session.set_flag(FLAG_ADMIN);
        assert!(session.get_flag(FLAG_ADMIN));
        assert!(!session.get_flag(FLAG_WEB));
        assert!(!session.get_flag(FLAG_WEB_SOCKETS));

        assert!(!session.should_heart_beat());
        session.heart_beat_write = 10;
        thread::sleep(time::Duration::from_millis(11));
        assert!(session.should_heart_beat());
        session.wrote_something();
        assert!(!session.should_heart_beat());
        thread::sleep(time::Duration::from_millis(11));
        assert!(session.should_heart_beat());
        session.wrote_something();
        assert!(!session.should_heart_beat());

        // The 10 ms write deadline is much nearer than the read deadline, so
        // next_timeout is bounded by it and shrinks as time passes.
        assert!(session.next_timeout() <= 10);
        thread::sleep(time::Duration::from_millis(5));
        assert!(session.next_timeout() <= 5);
    }
}