use std::sync::{Arc, Weak};
use std::thread;
use std::time::Duration;
use crossbeam_deque as deque;
use fnv::FnvHashMap;
use log::{trace, debug};
use num_cpus;
use parking_lot::{Mutex, RwLock};
use slab::Slab;
use time::Duration as TimeDuration;
use timer::{Guard as TimerGuard, Timer};
use crate::{IoError, IoHandler};
/// Caller-chosen identifier that distinguishes one of a handler's timers from another.
pub type TimerToken = usize;
/// Slot index of a registered handler inside the service's handler slab.
pub type HandlerId = usize;
/// Size of the token range reserved per handler.
/// NOTE(review): not referenced by the code visible here — presumably kept for API parity
/// with another backend; confirm before relying on it.
pub const TOKENS_PER_HANDLER: usize = 16_384;
/// Number of handler slots the dispatch loops scan (`0..MAX_HANDLERS`).
const MAX_HANDLERS: usize = 8;
/// Handle given to a handler callback: it knows which handler slot it belongs to and
/// holds a strong reference to the service's shared state.
pub struct IoContext<Message>
where
    Message: Send + Sync + 'static,
{
    /// Slab slot of the handler this context was created for.
    handler: HandlerId,
    /// State shared with the worker threads and all other handles.
    shared: Arc<Shared<Message>>,
}
impl<Message> IoContext<Message> where Message: Send + Sync + 'static {
    /// Schedules a repeating timer that posts a `TimerTrigger` for this handler every `delay`.
    ///
    /// Registering an already-used `token` replaces the stored guard for that token.
    ///
    /// # Errors
    /// Returns an error if `delay` cannot be converted to the timer's duration type.
    pub fn register_timer(&self, token: TimerToken, delay: Duration) -> Result<(), IoError> {
        self.schedule_timer(token, delay, true)
    }

    /// Schedules a one-shot timer that posts a `TimerTrigger` for this handler after `delay`.
    ///
    /// # Errors
    /// Returns an error if `delay` cannot be converted to the timer's duration type.
    pub fn register_timer_once(&self, token: TimerToken, delay: Duration) -> Result<(), IoError> {
        self.schedule_timer(token, delay, false)
    }

    /// Shared implementation of the two `register_timer*` methods.
    ///
    /// NOTE(review): `timers` is keyed by `token` alone, so two handlers that use the same
    /// token overwrite (and `clear_timer` removes) each other's guards, and timers are not
    /// dropped on `unregister_handler`. Namespacing the key by handler id would fix this.
    fn schedule_timer(&self, token: TimerToken, delay: Duration, repeating: bool) -> Result<(), IoError> {
        let delay = TimeDuration::from_std(delay)
            .map_err(|e| ::std::io::Error::new(::std::io::ErrorKind::Other, e))?;
        let channel = self.channel();
        let msg = WorkTask::TimerTrigger {
            handler_id: self.handler,
            token,
        };
        // The callback only holds a weak channel, so a pending timer does not keep the
        // service alive; `send_raw` is a no-op once the service is gone.
        let callback = move || channel.send_raw(msg.clone());
        let guard = if repeating {
            self.shared.timer.lock().schedule_repeating(delay, callback)
        } else {
            self.shared.timer.lock().schedule_with_delay(delay, callback)
        };
        self.shared.timers.lock().insert(token, guard);
        Ok(())
    }

    /// Drops the stored guard for `token`, if any.
    /// Presumably dropping the guard cancels the timer (timer crate semantics) — confirm.
    pub fn clear_timer(&self, token: TimerToken) -> Result<(), IoError> {
        self.shared.timers.lock().remove(&token);
        Ok(())
    }

    /// Queues `message` for the worker threads and wakes them.
    /// The message is silently dropped if the service has been stopped.
    pub fn message(&self, message: Message) -> Result<(), IoError> {
        if let Some(ref channel) = *self.shared.channel.lock() {
            channel.push(WorkTask::UserMessage(Arc::new(message)));
        }
        for thread in self.shared.threads.read().iter() {
            thread.unpark();
        }
        Ok(())
    }

    /// Returns a weak channel to the service this context belongs to.
    pub fn channel(&self) -> IoChannel<Message> {
        IoChannel { shared: Arc::downgrade(&self.shared) }
    }

    /// Removes this context's handler from the service.
    ///
    /// Safe to call more than once, or after the service has been stopped. `Slab::remove`
    /// panics on a vacant key, so the slot is checked first.
    /// NOTE(review): if the slot was vacated and then reused by a newly registered handler,
    /// this removes that newer handler.
    pub fn unregister_handler(&self) -> Result<(), IoError> {
        let mut handlers = self.shared.handlers.write();
        if handlers.contains(self.handler) {
            handlers.remove(self.handler);
        }
        Ok(())
    }
}
/// Cheap, cloneable sender into an `IoService`.
///
/// Holds only a weak reference, so it never keeps the service alive. Every operation
/// quietly does nothing once the service has been dropped.
pub struct IoChannel<Message>
where
    Message: Send + Sync + 'static,
{
    /// Weak link to the service state; `upgrade` fails after the service is gone.
    shared: Weak<Shared<Message>>,
}
impl<Message> Clone for IoChannel<Message> where Message: Send + Sync + 'static {
fn clone(&self) -> IoChannel<Message> {
IoChannel {
shared: self.shared.clone(),
}
}
}
impl<Message> IoChannel<Message> where Message: Send + Sync + 'static {
    /// Queues `message` for the worker threads and wakes them.
    ///
    /// If the service has been stopped (its work queue was taken), the message is
    /// delivered synchronously on the calling thread via [`IoChannel::send_sync`].
    /// Does nothing if the service no longer exists.
    pub fn send(&self, message: Message) -> Result<(), IoError> {
        let shared = match self.shared.upgrade() {
            Some(shared) => shared,
            None => return Ok(()),
        };
        // Decide under the queue lock, but release it before any synchronous delivery.
        // Handlers run by `send_sync` may send again, and the parking_lot mutex is not
        // reentrant, so holding the lock across that call could deadlock.
        let undelivered = match *shared.channel.lock() {
            Some(ref queue) => {
                queue.push(WorkTask::UserMessage(Arc::new(message)));
                None
            }
            None => Some(message),
        };
        if let Some(message) = undelivered {
            self.send_sync(message)?;
        }
        for thread in shared.threads.read().iter() {
            thread.unpark();
        }
        Ok(())
    }

    /// Invokes every registered handler's `message` callback with `message` on the calling
    /// thread. Does nothing if the service no longer exists.
    pub fn send_sync(&self, message: Message) -> Result<(), IoError> {
        if let Some(shared) = self.shared.upgrade() {
            for id in 0 .. MAX_HANDLERS {
                // Clone the handler out in its own statement so the read guard is dropped
                // before the callback runs. Inside an `if let` scrutinee the guard would live
                // through the whole body, and a handler that registers/unregisters (write lock)
                // or re-enters would deadlock.
                let handler = shared.handlers.read().get(id).cloned();
                if let Some(handler) = handler {
                    let ctxt = IoContext { handler: id, shared: shared.clone() };
                    handler.message(&ctxt, &message);
                }
            }
        }
        Ok(())
    }

    /// Pushes an already-built work item onto the queue (if still present) and wakes the
    /// workers. Does nothing if the service no longer exists.
    fn send_raw(&self, message: WorkTask<Message>) {
        if let Some(shared) = self.shared.upgrade() {
            if let Some(ref channel) = *shared.channel.lock() {
                channel.push(message);
            }
            for thread in shared.threads.read().iter() {
                thread.unpark();
            }
        }
    }

    /// Creates a channel that is not attached to any service; every send is a no-op.
    pub fn disconnected() -> IoChannel<Message> {
        IoChannel {
            shared: Weak::default(),
        }
    }
}
/// Owner of the worker thread pool and the shared handler and timer state.
/// Dropping it stops the pool.
pub struct IoService<Message>
where
    Message: Send + Sync + 'static,
{
    /// Join handles of the spawned workers; drained by `stop`.
    thread_joins: Mutex<Vec<thread::JoinHandle<()>>>,
    /// State shared with workers, contexts and channels.
    shared: Arc<Shared<Message>>,
}
/// State reachable from the service, its workers, every `IoContext` and every `IoChannel`.
struct Shared<Message>
where
    Message: Send + Sync + 'static,
{
    /// Registered handlers; the slab key is the handler's `HandlerId`.
    handlers: RwLock<Slab<Arc<dyn IoHandler<Message>>>>,
    /// `Thread` handles of the workers, used to unpark them after queuing work.
    threads: RwLock<Vec<thread::Thread>>,
    /// Timer used to schedule timer callbacks.
    timer: Mutex<Timer>,
    /// Guards of scheduled timers, keyed by token. Presumably dropping a guard cancels its
    /// timer (timer crate semantics) — this is what `clear_timer` relies on.
    timers: Mutex<FnvHashMap<TimerToken, TimerGuard>>,
    /// Producer side of the FIFO work queue; `None` once `stop` has taken it.
    channel: Mutex<Option<deque::Worker<WorkTask<Message>>>>,
}
/// One unit of work pulled off the queue by a worker thread.
enum WorkTask<Message>
where
    Message: Send + Sized,
{
    /// Makes the worker that receives it exit its loop.
    Shutdown,
    /// A timer fired for the given handler's token.
    TimerTrigger {
        handler_id: HandlerId,
        token: TimerToken,
    },
    /// A user message to broadcast to every registered handler.
    UserMessage(Arc<Message>),
}
impl<Message> Clone for WorkTask<Message> where Message: Send + Sized {
fn clone(&self) -> WorkTask<Message> {
match *self {
WorkTask::Shutdown => WorkTask::Shutdown,
WorkTask::TimerTrigger { handler_id, token } => WorkTask::TimerTrigger { handler_id, token },
WorkTask::UserMessage(ref msg) => WorkTask::UserMessage(msg.clone()),
}
}
}
impl<Message> IoService<Message> where Message: Send + Sync + 'static {
pub fn start() -> Result<IoService<Message>, IoError> {
let (tx, rx) = deque::fifo();
let shared = Arc::new(Shared {
handlers: RwLock::new(Slab::with_capacity(MAX_HANDLERS)),
threads: RwLock::new(Vec::new()),
timer: Mutex::new(Timer::new()),
timers: Mutex::new(FnvHashMap::default()),
channel: Mutex::new(Some(tx)),
});
let thread_joins = (0 .. num_cpus::get()).map(|_| {
let rx = rx.clone();
let shared = shared.clone();
thread::spawn(move || {
do_work(&shared, rx)
})
}).collect::<Vec<_>>();
*shared.threads.write() = thread_joins.iter().map(|t| t.thread().clone()).collect();
Ok(IoService {
thread_joins: Mutex::new(thread_joins),
shared,
})
}
pub fn stop(&mut self) {
trace!(target: "shutdown", "[IoService] Closing...");
self.shared.handlers.write().clear();
let channel = self.shared.channel.lock().take();
let mut thread_joins = self.thread_joins.lock();
if let Some(channel) = channel {
for _ in 0 .. thread_joins.len() {
channel.push(WorkTask::Shutdown);
}
}
for thread in thread_joins.drain(..) {
thread.thread().unpark();
thread.join().unwrap_or_else(|e| {
debug!(target: "shutdown", "Error joining IO service worker thread: {:?}", e);
});
}
trace!(target: "shutdown", "[IoService] Closed.");
}
pub fn register_handler(&self, handler: Arc<dyn IoHandler<Message>+Send>) -> Result<(), IoError> {
let id = self.shared.handlers.write().insert(handler.clone());
assert!(id <= MAX_HANDLERS, "Too many handlers registered");
let ctxt = IoContext { handler: id, shared: self.shared.clone() };
handler.initialize(&ctxt);
Ok(())
}
pub fn send_message(&self, message: Message) -> Result<(), IoError> {
if let Some(ref channel) = *self.shared.channel.lock() {
channel.push(WorkTask::UserMessage(Arc::new(message)));
}
for thread in self.shared.threads.read().iter() {
thread.unpark();
}
Ok(())
}
#[inline]
pub fn channel(&self) -> IoChannel<Message> {
IoChannel {
shared: Arc::downgrade(&self.shared)
}
}
}
impl<Message> Drop for IoService<Message> where Message: Send + Sync {
    /// Shuts the worker pool down by delegating to `stop`, which is a no-op on a second call.
    fn drop(&mut self) {
        self.stop();
    }
}
/// Worker-thread main loop.
///
/// Steals tasks from the shared FIFO queue and:
/// - parks when the queue is empty (senders unpark after pushing; spurious wake-ups just
///   loop again);
/// - exits on `Shutdown`;
/// - broadcasts a `UserMessage` to every handler slot in `0..MAX_HANDLERS`;
/// - routes a `TimerTrigger` to the handler in its slot, if that slot is still occupied.
fn do_work<Message>(shared: &Arc<Shared<Message>>, rx: deque::Stealer<WorkTask<Message>>)
where Message: Send + Sync + 'static
{
    loop {
        match rx.steal() {
            deque::Steal::Retry => continue,
            deque::Steal::Empty => thread::park(),
            deque::Steal::Data(WorkTask::Shutdown) => break,
            deque::Steal::Data(WorkTask::UserMessage(message)) => {
                for id in 0 .. MAX_HANDLERS {
                    // Clone the handler out in its own statement so the read guard is released
                    // before the callback. An `if let` scrutinee would keep it alive for the
                    // whole body, and a handler that (un)registers (write lock) or re-enters
                    // would deadlock.
                    let handler = shared.handlers.read().get(id).cloned();
                    if let Some(handler) = handler {
                        let ctxt = IoContext { handler: id, shared: shared.clone() };
                        handler.message(&ctxt, &message);
                    }
                }
            },
            deque::Steal::Data(WorkTask::TimerTrigger { handler_id, token }) => {
                // Same as above: do not hold the read lock across the callback.
                let handler = shared.handlers.read().get(handler_id).cloned();
                if let Some(handler) = handler {
                    let ctxt = IoContext { handler: handler_id, shared: shared.clone() };
                    handler.timeout(&ctxt, token);
                }
            },
        }
    }
}