1use std::sync::{Arc, Mutex};
2use std::time;
3use std::any::Any;
4use {Context, InnerContext, ContextError};
5use futures::{Future, Poll, Async};
6use futures::task::{self, Task};
7
8#[derive(Clone)]
9pub struct WithCancel {
10 parent: Context,
11 canceled: Arc<Mutex<bool>>,
12 handle: Arc<Mutex<Option<Task>>>,
13}
14
15impl InnerContext for WithCancel {
16 fn deadline(&self) -> Option<time::Instant> {
17 None
18 }
19
20 fn value(&self) -> Option<&Any> {
21 None
22 }
23
24 fn parent(&self) -> Option<Context> {
25 self.parent.0.borrow().parent()
26 }
27}
28
29impl Future for WithCancel {
30 type Item = ();
31 type Error = ContextError;
32
33 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
34 if *self.canceled.lock().unwrap() {
35 Err(ContextError::Canceled)
36 } else {
37 self.parent.0.borrow_mut()
38 .poll()
39 .map(|r| {
40 if r == Async::NotReady {
41 let mut handle = self.handle.lock().unwrap();
44 let must_update = match *handle {
45 Some(ref task) if task.is_current() => false,
46 _ => true,
47 };
48 if must_update {
49 *handle = Some(task::park())
50 }
51 }
52 r
53 })
54 }
55 }
56}
57
58pub fn with_cancel(parent: Context) -> (Context, Box<Fn() + Send>) {
78 let canceled = Arc::new(Mutex::new(false));
79 let handle = Arc::new(Mutex::new(None));
80 let canceled_clone = canceled.clone();
81 let handle_clone = handle.clone();
82
83 let ctx = WithCancel {
84 parent: parent,
85 canceled: canceled,
86 handle: handle,
87 };
88 let cancel = Box::new(move || {
89 let mut canceled = canceled_clone.lock().unwrap();
90 *canceled = true;
91
92 if let Some(ref task) = *handle_clone.lock().unwrap() {
93 task.unpark();
94 }
95 });
96 (Context::new(ctx), cancel)
97}
98
#[cfg(test)]
mod test {
    use std::time::Duration;
    use std::thread;
    use tokio_timer::Timer;
    use with_cancel::with_cancel;
    use {background, ContextError};
    use futures::Future;

    /// Cancelling before waiting makes the context resolve with `Canceled`.
    #[test]
    fn cancel_test() {
        let (ctx, cancel) = with_cancel(background());
        cancel();

        assert_eq!(ctx.wait().unwrap_err(), ContextError::Canceled);
    }

    /// Cancelling a parent context is observed by its child.
    #[test]
    fn cancel_parent_test() {
        let (parent, cancel) = with_cancel(background());
        let (ctx, _) = with_cancel(parent);
        cancel();

        assert_eq!(ctx.wait().unwrap_err(), ContextError::Canceled);
    }

    /// A cancellation from another thread wins a race against a 2s timer.
    #[test]
    fn example_test() {
        let timer = Timer::default();

        let long_running_process = timer.sleep(Duration::from_secs(2));
        let (ctx, cancel) = with_cancel(background());

        let first = long_running_process
            .map_err(|_| ContextError::DeadlineExceeded)
            .select(ctx);

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            cancel();
        });

        match first.wait() {
            Err((err, _)) => assert_eq!(err, ContextError::Canceled),
            Ok(_) => panic!("expected the context to be canceled before the timer fired"),
        }
    }

    /// Clones share the cancellation state of the original context.
    #[test]
    fn clone_test() {
        let (ctx, cancel) = with_cancel(background());
        let clone = ctx.clone();
        cancel();

        assert_eq!(ctx.wait().unwrap_err(), ContextError::Canceled);
        assert_eq!(clone.wait().unwrap_err(), ContextError::Canceled);
    }
}