1use std::sync::{Arc, Weak};
18use std::thread;
19use std::time::Duration;
20
21use crossbeam_deque as deque;
22use fnv::FnvHashMap;
23use log::{trace, debug};
24use num_cpus;
25use parking_lot::{Mutex, RwLock};
26use slab::Slab;
27use time::Duration as TimeDuration;
28use timer::{Guard as TimerGuard, Timer};
29
30use crate::{IoError, IoHandler};
31
/// Token a handler chooses to identify one of its timers.
pub type TimerToken = usize;
/// Slab index assigned to a handler when it is registered with the service.
pub type HandlerId = usize;

/// Presumably the size of the token range reserved for each handler
/// (not referenced by the timer code in this module) — TODO confirm.
pub const TOKENS_PER_HANDLER: usize = 16_384;
/// Upper bound on handler ids: worker threads only dispatch to ids `0..MAX_HANDLERS`.
const MAX_HANDLERS: usize = 8;
40
41pub struct IoContext<Message> where Message: Send + Sync + 'static {
43 handler: HandlerId,
44 shared: Arc<Shared<Message>>,
45}
46
47impl<Message> IoContext<Message> where Message: Send + Sync + 'static {
48 pub fn register_timer(&self, token: TimerToken, delay: Duration) -> Result<(), IoError> {
50 let channel = self.channel();
51
52 let msg = WorkTask::TimerTrigger {
53 handler_id: self.handler,
54 token,
55 };
56
57 let delay = TimeDuration::from_std(delay)
58 .map_err(|e| ::std::io::Error::new(::std::io::ErrorKind::Other, e))?;
59 let guard = self.shared.timer.lock().schedule_repeating(delay, move || {
60 channel.send_raw(msg.clone());
61 });
62
63 self.shared.timers.lock().insert(token, guard);
64
65 Ok(())
66 }
67
68 pub fn register_timer_once(&self, token: TimerToken, delay: Duration) -> Result<(), IoError> {
70 let channel = self.channel();
71
72 let msg = WorkTask::TimerTrigger {
73 handler_id: self.handler,
74 token,
75 };
76
77 let delay = TimeDuration::from_std(delay)
78 .map_err(|e| ::std::io::Error::new(::std::io::ErrorKind::Other, e))?;
79 let guard = self.shared.timer.lock().schedule_with_delay(delay, move || {
80 channel.send_raw(msg.clone());
81 });
82
83 self.shared.timers.lock().insert(token, guard);
84
85 Ok(())
86 }
87
88 pub fn clear_timer(&self, token: TimerToken) -> Result<(), IoError> {
90 self.shared.timers.lock().remove(&token);
91 Ok(())
92 }
93
94 pub fn message(&self, message: Message) -> Result<(), IoError> {
96 if let Some(ref channel) = *self.shared.channel.lock() {
97 channel.push(WorkTask::UserMessage(Arc::new(message)));
98 }
99 for thread in self.shared.threads.read().iter() {
100 thread.unpark();
101 }
102
103 Ok(())
104 }
105
106 pub fn channel(&self) -> IoChannel<Message> {
108 IoChannel { shared: Arc::downgrade(&self.shared) }
109 }
110
111 pub fn unregister_handler(&self) -> Result<(), IoError> {
113 self.shared.handlers.write().remove(self.handler);
114 Ok(())
115 }
116}
117
118pub struct IoChannel<Message> where Message: Send + Sync + 'static {
121 shared: Weak<Shared<Message>>,
122}
123
124impl<Message> Clone for IoChannel<Message> where Message: Send + Sync + 'static {
125 fn clone(&self) -> IoChannel<Message> {
126 IoChannel {
127 shared: self.shared.clone(),
128 }
129 }
130}
131
132impl<Message> IoChannel<Message> where Message: Send + Sync + 'static {
133 pub fn send(&self, message: Message) -> Result<(), IoError> {
135 if let Some(shared) = self.shared.upgrade() {
136 match *shared.channel.lock() {
137 Some(ref channel) => channel.push(WorkTask::UserMessage(Arc::new(message))),
138 None => self.send_sync(message)?
139 };
140
141 for thread in shared.threads.read().iter() {
142 thread.unpark();
143 }
144 }
145
146 Ok(())
147 }
148
149 pub fn send_sync(&self, message: Message) -> Result<(), IoError> {
151 if let Some(shared) = self.shared.upgrade() {
152 for id in 0 .. MAX_HANDLERS {
153 if let Some(h) = shared.handlers.read().get(id) {
154 let handler = h.clone();
155 let ctxt = IoContext { handler: id, shared: shared.clone() };
156 handler.message(&ctxt, &message);
157 }
158 }
159 }
160
161 Ok(())
162 }
163
164 fn send_raw(&self, message: WorkTask<Message>) {
166 if let Some(shared) = self.shared.upgrade() {
167 if let Some(ref channel) = *shared.channel.lock() {
168 channel.push(message);
169 }
170
171 for thread in shared.threads.read().iter() {
172 thread.unpark();
173 }
174 }
175 }
176
177 pub fn disconnected() -> IoChannel<Message> {
179 IoChannel {
180 shared: Weak::default(),
181 }
182 }
183}
184
185pub struct IoService<Message> where Message: Send + Sync + 'static {
188 thread_joins: Mutex<Vec<thread::JoinHandle<()>>>,
189 shared: Arc<Shared<Message>>,
190}
191
192struct Shared<Message> where Message: Send + Sync + 'static {
194 handlers: RwLock<Slab<Arc<dyn IoHandler<Message>>>>,
196 threads: RwLock<Vec<thread::Thread>>,
198 timer: Mutex<Timer>,
200 timers: Mutex<FnvHashMap<TimerToken, TimerGuard>>,
203 channel: Mutex<Option<deque::Worker<WorkTask<Message>>>>,
205}
206
207enum WorkTask<Message> where Message: Send + Sized {
209 Shutdown,
210 TimerTrigger {
211 handler_id: HandlerId,
212 token: TimerToken,
213 },
214 UserMessage(Arc<Message>)
215}
216
217impl<Message> Clone for WorkTask<Message> where Message: Send + Sized {
218 fn clone(&self) -> WorkTask<Message> {
219 match *self {
220 WorkTask::Shutdown => WorkTask::Shutdown,
221 WorkTask::TimerTrigger { handler_id, token } => WorkTask::TimerTrigger { handler_id, token },
222 WorkTask::UserMessage(ref msg) => WorkTask::UserMessage(msg.clone()),
223 }
224 }
225}
226
227impl<Message> IoService<Message> where Message: Send + Sync + 'static {
228 pub fn start() -> Result<IoService<Message>, IoError> {
230 let (tx, rx) = deque::fifo();
231
232 let shared = Arc::new(Shared {
233 handlers: RwLock::new(Slab::with_capacity(MAX_HANDLERS)),
234 threads: RwLock::new(Vec::new()),
235 timer: Mutex::new(Timer::new()),
236 timers: Mutex::new(FnvHashMap::default()),
237 channel: Mutex::new(Some(tx)),
238 });
239
240 let thread_joins = (0 .. num_cpus::get()).map(|_| {
241 let rx = rx.clone();
242 let shared = shared.clone();
243 thread::spawn(move || {
244 do_work(&shared, rx)
245 })
246 }).collect::<Vec<_>>();
247
248 *shared.threads.write() = thread_joins.iter().map(|t| t.thread().clone()).collect();
249
250 Ok(IoService {
251 thread_joins: Mutex::new(thread_joins),
252 shared,
253 })
254 }
255
256 pub fn stop(&mut self) {
258 trace!(target: "shutdown", "[IoService] Closing...");
259 self.shared.handlers.write().clear();
262 let channel = self.shared.channel.lock().take();
263 let mut thread_joins = self.thread_joins.lock();
264 if let Some(channel) = channel {
265 for _ in 0 .. thread_joins.len() {
266 channel.push(WorkTask::Shutdown);
267 }
268 }
269 for thread in thread_joins.drain(..) {
270 thread.thread().unpark();
271 thread.join().unwrap_or_else(|e| {
272 debug!(target: "shutdown", "Error joining IO service worker thread: {:?}", e);
273 });
274 }
275 trace!(target: "shutdown", "[IoService] Closed.");
276 }
277
278 pub fn register_handler(&self, handler: Arc<dyn IoHandler<Message>+Send>) -> Result<(), IoError> {
280 let id = self.shared.handlers.write().insert(handler.clone());
281 assert!(id <= MAX_HANDLERS, "Too many handlers registered");
282 let ctxt = IoContext { handler: id, shared: self.shared.clone() };
283 handler.initialize(&ctxt);
284 Ok(())
285 }
286
287 pub fn send_message(&self, message: Message) -> Result<(), IoError> {
289 if let Some(ref channel) = *self.shared.channel.lock() {
290 channel.push(WorkTask::UserMessage(Arc::new(message)));
291 }
292 for thread in self.shared.threads.read().iter() {
293 thread.unpark();
294 }
295 Ok(())
296 }
297
298 #[inline]
300 pub fn channel(&self) -> IoChannel<Message> {
301 IoChannel {
302 shared: Arc::downgrade(&self.shared)
303 }
304 }
305}
306
307impl<Message> Drop for IoService<Message> where Message: Send + Sync {
308 fn drop(&mut self) {
309 self.stop()
310 }
311}
312
313fn do_work<Message>(shared: &Arc<Shared<Message>>, rx: deque::Stealer<WorkTask<Message>>)
314 where Message: Send + Sync + 'static
315{
316 loop {
317 match rx.steal() {
318 deque::Steal::Retry => continue,
319 deque::Steal::Empty => thread::park(),
320 deque::Steal::Data(WorkTask::Shutdown) => break,
321 deque::Steal::Data(WorkTask::UserMessage(message)) => {
322 for id in 0 .. MAX_HANDLERS {
323 if let Some(handler) = shared.handlers.read().get(id) {
324 let ctxt = IoContext { handler: id, shared: shared.clone() };
325 handler.message(&ctxt, &message);
326 }
327 }
328 },
329 deque::Steal::Data(WorkTask::TimerTrigger { handler_id, token }) => {
330 if let Some(handler) = shared.handlers.read().get(handler_id) {
331 let ctxt = IoContext { handler: handler_id, shared: shared.clone() };
332 handler.timeout(&ctxt, token);
333 }
334 },
335 }
336 }
337}