// base_coroutine/event_loop/mod.rs

1pub mod event;
2
3pub mod interest;
4
5mod selector;
6
7use crate::event_loop::event::Events;
8use crate::event_loop::interest::Interest;
9use crate::event_loop::selector::Selector;
10use crate::{Coroutine, Scheduler, UserFunc};
11use once_cell::sync::Lazy;
12use rayon::prelude::*;
13use std::collections::{HashMap, HashSet};
14use std::os::raw::c_void;
15use std::sync::atomic::{AtomicUsize, Ordering};
16use std::time::Duration;
17
18#[repr(C)]
19pub struct JoinHandle(pub &'static c_void);
20
21impl JoinHandle {
22    pub fn timeout_join(&self, dur: Duration) -> std::io::Result<usize> {
23        if self.0 as *const c_void as usize == 0 {
24            return Ok(0);
25        }
26        let timeout_time = timer_utils::get_timeout_time(dur);
27        let result = unsafe {
28            &*(self.0 as *const _ as *const Coroutine<&'static mut c_void, &'static mut c_void>)
29        };
30        while result.get_result().is_none() {
31            if timeout_time <= timer_utils::now() {
32                //timeout
33                return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout"));
34            }
35            EventLoop::round_robin_timeout_schedule(timeout_time)?;
36            if result.get_result().is_some() {
37                break;
38            }
39            let left_time = timeout_time.saturating_sub(timer_utils::now());
40            //等待事件到来
41            if let Err(e) = EventLoop::next().wait(Some(Duration::from_nanos(left_time))) {
42                match e.kind() {
43                    //maybe invoke by Monitor::signal(), just ignore this
44                    std::io::ErrorKind::Interrupted => continue,
45                    _ => return Err(e),
46                }
47            }
48        }
49        Ok(result.get_result().unwrap() as *mut c_void as usize)
50    }
51
52    pub fn join(self) -> std::io::Result<usize> {
53        if self.0 as *const c_void as usize == 0 {
54            return Ok(0);
55        }
56        let result = unsafe {
57            &*(self.0 as *const _ as *const Coroutine<&'static mut c_void, &'static mut c_void>)
58        };
59        while result.get_result().is_none() {
60            EventLoop::round_robin_schedule()?;
61            if result.get_result().is_some() {
62                break;
63            }
64            //等待事件到来
65            if let Err(e) = EventLoop::next().wait(Some(Duration::from_secs(1))) {
66                match e.kind() {
67                    //maybe invoke by Monitor::signal(), just ignore this
68                    std::io::ErrorKind::Interrupted => continue,
69                    _ => return Err(e),
70                }
71            }
72        }
73        Ok(result.get_result().unwrap() as *mut c_void as usize)
74    }
75}
76
77static mut READABLE_RECORDS: Lazy<HashSet<libc::c_int>> = Lazy::new(HashSet::new);
78
79static mut READABLE_TOKEN_RECORDS: Lazy<HashMap<libc::c_int, usize>> = Lazy::new(HashMap::new);
80
81static mut WRITABLE_RECORDS: Lazy<HashSet<libc::c_int>> = Lazy::new(HashSet::new);
82
83static mut WRITABLE_TOKEN_RECORDS: Lazy<HashMap<libc::c_int, usize>> = Lazy::new(HashMap::new);
84
85static mut INDEX: Lazy<AtomicUsize> = Lazy::new(|| AtomicUsize::new(0));
86
87static mut EVENT_LOOPS: Lazy<Box<[EventLoop]>> = Lazy::new(|| {
88    (0..num_cpus::get())
89        .map(|_| EventLoop::new().expect("init event loop failed!"))
90        .collect()
91});
92
93pub struct EventLoop<'a> {
94    selector: Selector,
95    scheduler: &'a mut Scheduler,
96}
97
98unsafe impl Send for EventLoop<'_> {}
99
100impl<'a> EventLoop<'a> {
101    fn new() -> std::io::Result<EventLoop<'a>> {
102        let scheduler = Box::leak(Box::new(Scheduler::new()));
103        Ok(EventLoop {
104            selector: Selector::new()?,
105            scheduler,
106        })
107    }
108
109    pub fn next() -> &'static mut EventLoop<'static> {
110        unsafe {
111            let index = INDEX.fetch_add(1, Ordering::SeqCst);
112            if index == usize::MAX {
113                INDEX.store(1, Ordering::SeqCst);
114            }
115            EVENT_LOOPS.get_mut(index % num_cpus::get()).unwrap()
116        }
117    }
118
119    fn next_scheduler() -> &'static mut Scheduler {
120        EventLoop::next().scheduler
121    }
122
123    pub fn submit(
124        f: UserFunc<&'static mut c_void, (), &'static mut c_void>,
125        param: &'static mut c_void,
126        size: usize,
127    ) -> std::io::Result<JoinHandle> {
128        EventLoop::next_scheduler()
129            .submit(f, param, size)
130            .map(|co| JoinHandle(unsafe { std::mem::transmute(co) }))
131    }
132
133    pub fn round_robin_schedule() -> std::io::Result<()> {
134        EventLoop::round_robin_timeout_schedule(u64::MAX)
135    }
136
137    pub fn round_robin_timed_schedule(timeout_time: u64) -> std::io::Result<()> {
138        loop {
139            if timeout_time <= timer_utils::now() {
140                return Ok(());
141            }
142            EventLoop::round_robin_timeout_schedule(timeout_time)?;
143        }
144    }
145
146    pub fn round_robin_timeout_schedule(timeout_time: u64) -> std::io::Result<()> {
147        let results: Vec<std::io::Result<()>> = (0..num_cpus::get())
148            .into_par_iter()
149            .map(|_| EventLoop::next_scheduler().try_timeout_schedule(timeout_time))
150            .collect();
151        for result in results {
152            result?;
153        }
154        Ok(())
155    }
156
157    pub fn round_robin_del_event(fd: libc::c_int) {
158        (0..num_cpus::get()).into_par_iter().for_each(|_| {
159            let _ = EventLoop::next().del_event(fd);
160        });
161    }
162
163    fn del_event(&mut self, fd: libc::c_int) -> std::io::Result<()> {
164        self.selector.deregister(fd)?;
165        unsafe {
166            READABLE_RECORDS.remove(&fd);
167            READABLE_TOKEN_RECORDS.remove(&fd);
168            WRITABLE_RECORDS.remove(&fd);
169            WRITABLE_TOKEN_RECORDS.remove(&fd);
170        }
171        Ok(())
172    }
173
174    pub fn round_robin_del_read_event(fd: libc::c_int) {
175        (0..num_cpus::get()).into_par_iter().for_each(|_| {
176            let _ = EventLoop::next().del_read_event(fd);
177        });
178    }
179
180    fn del_read_event(&mut self, fd: libc::c_int) -> std::io::Result<()> {
181        unsafe {
182            if READABLE_RECORDS.contains(&fd) {
183                if WRITABLE_RECORDS.contains(&fd) {
184                    //写事件不能删
185                    self.selector.reregister(
186                        fd,
187                        WRITABLE_TOKEN_RECORDS.remove(&fd).unwrap_or(0),
188                        Interest::WRITABLE,
189                    )?;
190                    READABLE_RECORDS.remove(&fd);
191                } else {
192                    self.del_event(fd)?;
193                }
194            }
195        }
196        Ok(())
197    }
198
199    pub fn round_robin_del_write_event(fd: libc::c_int) {
200        (0..num_cpus::get()).into_par_iter().for_each(|_| {
201            let _ = EventLoop::next().del_write_event(fd);
202        });
203    }
204
205    fn del_write_event(&mut self, fd: libc::c_int) -> std::io::Result<()> {
206        unsafe {
207            if WRITABLE_RECORDS.contains(&fd) {
208                if READABLE_RECORDS.contains(&fd) {
209                    //读事件不能删
210                    self.selector.reregister(
211                        fd,
212                        READABLE_TOKEN_RECORDS.remove(&fd).unwrap_or(0),
213                        Interest::READABLE,
214                    )?;
215                    WRITABLE_RECORDS.remove(&fd);
216                } else {
217                    self.del_event(fd)?;
218                }
219            }
220        }
221        Ok(())
222    }
223
224    fn build_token() -> usize {
225        if let Some(co) = Coroutine::<&'static mut c_void, &'static mut c_void>::current() {
226            co.get_id()
227        } else {
228            0
229        }
230    }
231
232    pub fn add_read_event(&mut self, fd: libc::c_int) -> std::io::Result<()> {
233        let token = <EventLoop<'a>>::build_token();
234        self.selector.register(fd, token, Interest::READABLE)?;
235        unsafe {
236            READABLE_RECORDS.insert(fd);
237            READABLE_TOKEN_RECORDS.insert(fd, token);
238        }
239        Ok(())
240    }
241
242    pub fn add_write_event(&mut self, fd: libc::c_int) -> std::io::Result<()> {
243        let token = <EventLoop<'a>>::build_token();
244        self.selector.register(fd, token, Interest::WRITABLE)?;
245        unsafe {
246            WRITABLE_RECORDS.insert(fd);
247            WRITABLE_TOKEN_RECORDS.insert(fd, token);
248        }
249        Ok(())
250    }
251
252    fn wait(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
253        //fixme 这里应该只调1次scheduler.syscall,实际由于外层的loop,可能会调用多次
254        self.scheduler.syscall();
255        let mut events = Events::with_capacity(1024);
256        self.selector.select(&mut events, timeout)?;
257        for event in events.iter() {
258            let fd = event.fd();
259            let token = event.token();
260            unsafe {
261                let _ = self.scheduler.resume(token);
262                if event.is_readable() {
263                    READABLE_TOKEN_RECORDS.remove(&fd);
264                }
265                if event.is_writable() {
266                    WRITABLE_TOKEN_RECORDS.remove(&fd);
267                }
268            }
269        }
270        Ok(())
271    }
272
273    pub fn wait_read_event(
274        &mut self,
275        fd: libc::c_int,
276        timeout: Option<Duration>,
277    ) -> std::io::Result<()> {
278        self.add_read_event(fd)?;
279        self.wait(timeout)
280    }
281
282    pub fn wait_write_event(
283        &mut self,
284        fd: libc::c_int,
285        timeout: Option<Duration>,
286    ) -> std::io::Result<()> {
287        self.add_write_event(fd)?;
288        self.wait(timeout)
289    }
290}
291
#[cfg(test)]
mod tests {
    use crate::{EventLoop, Yielder};
    use std::os::raw::c_void;
    use std::time::Duration;

    /// Passes a plain integer through the coroutine's `&'static mut c_void` parameter slot.
    fn as_param(raw: usize) -> &'static mut c_void {
        unsafe { std::mem::transmute::<usize, &'static mut c_void>(raw) }
    }

    /// Defines a coroutine entry point that prints `$banner` and returns its input unchanged.
    macro_rules! echo_coroutine {
        ($name:ident, $banner:expr) => {
            extern "C" fn $name(
                _yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
                input: &'static mut c_void,
            ) -> &'static mut c_void {
                println!("{}", $banner);
                input
            }
        };
    }

    echo_coroutine!(f1, "[coroutine1] launched");
    echo_coroutine!(f2, "[coroutine2] launched");
    echo_coroutine!(f3, "[coroutine3] launched");

    #[test]
    fn join_test() {
        let first = EventLoop::submit(f1, as_param(1), 4096).expect("submit failed !");
        let second = EventLoop::submit(f2, as_param(2), 4096).expect("submit failed !");
        assert_eq!(first.join().unwrap(), 1);
        assert_eq!(second.join().unwrap(), 2);
    }

    #[test]
    fn timed_join_test() {
        let handle = EventLoop::submit(f3, as_param(3), 4096).expect("submit failed !");
        let timed_out = handle.timeout_join(Duration::from_nanos(0)).unwrap_err();
        assert_eq!(timed_out.kind(), std::io::ErrorKind::TimedOut);
        let joined = handle.timeout_join(Duration::from_secs(1)).unwrap();
        assert_eq!(joined, 3);
    }
}