open_coroutine/
lib.rs

#![deny(
    // The following are allowed by default lints according to
    // https://doc.rust-lang.org/rustc/lints/listing/allowed-by-default.html
    anonymous_parameters,
    bare_trait_objects,
    // elided_lifetimes_in_paths, // allow anonymous lifetime
    missing_copy_implementations,
    missing_debug_implementations,
    missing_docs, // TODO: add documents
    single_use_lifetimes, // TODO: fix lifetime names only used once
    trivial_casts, // TODO: remove trivial casts in code
    trivial_numeric_casts,
    // unreachable_pub is denied at the end of this list
    // unsafe_code,
    unstable_features,
    unused_extern_crates,
    unused_import_braces,
    unused_qualifications,
    unused_results,
    variant_size_differences,

    warnings, // treat all warnings as errors

    clippy::all,
    // clippy::restriction,
    clippy::pedantic,
    // clippy::nursery, // It's still under development
    clippy::cargo,
    unreachable_pub,
)]
#![allow(
    // Some explicitly allowed Clippy lints, must have clear reason to allow
    clippy::blanket_clippy_restriction_lints, // allow clippy::restriction
    clippy::implicit_return, // actually omitting the return keyword is idiomatic Rust code
    clippy::module_name_repetitions, // repetition of the module name in a struct name is no big deal
    clippy::multiple_crate_versions, // multiple versions of a dependency crate cannot be fixed here
    clippy::missing_errors_doc, // TODO: add error docs
    clippy::missing_panics_doc, // TODO: add panic docs
    clippy::panic_in_result_fn,
    clippy::shadow_same, // not too bad
    clippy::shadow_reuse, // not too bad
    clippy::exhaustive_enums,
    clippy::exhaustive_structs,
    clippy::indexing_slicing,
    clippy::separated_literal_suffix, // conflicts with clippy::unseparated_literal_suffix
    clippy::single_char_lifetime_names, // TODO: change lifetime names
    unknown_lints, // for windows nightly
    linker_messages, // for windows nightly
    unused_attributes, // for windows nightly
)]
//! User-facing API of open-coroutine: runtime lifecycle, tasks, stack growth
//! and TCP connect helpers.
//!
//! See <https://github.com/acl-dev/open-coroutine>.
52
53use open_coroutine_core::co_pool::task::UserTaskFunc;
54use open_coroutine_core::common::constants::SLICE;
55pub use open_coroutine_core::common::ordered_work_steal::DEFAULT_PRECEDENCE;
56pub use open_coroutine_core::config::Config;
57use open_coroutine_core::net::UserFunc;
58pub use open_coroutine_macros::*;
59use std::cmp::Ordering;
60use std::ffi::{c_int, c_longlong, c_uint, c_void};
61use std::io::{Error, ErrorKind};
62use std::marker::PhantomData;
63use std::net::{TcpStream, ToSocketAddrs};
64use std::ops::Deref;
65use std::time::Duration;
66
67extern "C" {
68    fn open_coroutine_init(config: Config) -> c_int;
69
70    fn open_coroutine_stop(secs: c_uint) -> c_int;
71
72    fn maybe_grow_stack(
73        red_zone: usize,
74        stack_size: usize,
75        f: UserFunc,
76        param: usize,
77    ) -> c_longlong;
78}
79
80#[allow(improper_ctypes)]
81extern "C" {
82    fn task_crate(
83        f: UserTaskFunc,
84        param: usize,
85        priority: c_longlong,
86    ) -> open_coroutine_core::net::join::JoinHandle;
87
88    fn task_join(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong;
89
90    fn task_cancel(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong;
91
92    fn task_timeout_join(
93        handle: &open_coroutine_core::net::join::JoinHandle,
94        ns_time: u64,
95    ) -> c_longlong;
96}
97
98/// Init the open-coroutine.
99pub fn init(config: Config) {
100    assert_eq!(
101        0,
102        unsafe { open_coroutine_init(config) },
103        "open-coroutine init failed !"
104    );
105    #[cfg(feature = "ci")]
106    open_coroutine_core::common::ci::init();
107}
108
109/// Shutdown the open-coroutine.
110pub fn shutdown() {
111    unsafe { _ = open_coroutine_stop(30) };
112}
113
/// Create a task.
///
/// - `task!(f, param, priority)` forwards all three arguments to [`crate_task`].
/// - `task!(f, param)` uses [`DEFAULT_PRECEDENCE`] as the priority.
///
/// Both forms evaluate to the [`JoinHandle`] returned by [`crate_task`].
#[macro_export]
macro_rules! task {
    ($f:expr, $param:expr, $priority:expr $(,)?) => {
        $crate::crate_task($f, $param, $priority)
    };
    ($f:expr, $param:expr $(,)?) => {
        $crate::task!($f, $param, $crate::DEFAULT_PRECEDENCE)
    };
}
124
125/// Create a task.
126pub fn crate_task<P: 'static, R: 'static, F: FnOnce(P) -> R>(
127    f: F,
128    param: P,
129    priority: c_longlong,
130) -> JoinHandle<R> {
131    extern "C" fn task_main<P: 'static, R: 'static, F: FnOnce(P) -> R>(input: usize) -> usize {
132        unsafe {
133            let ptr = &mut *((input as *mut c_void).cast::<(F, P)>());
134            let data = std::ptr::read_unaligned(ptr);
135            let result: &'static mut std::io::Result<R> = Box::leak(Box::new(
136                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (data.0)(data.1)))
137                    .map_err(|e| {
138                        Error::other(
139                            e.downcast_ref::<&'static str>()
140                                .map_or("task failed without message", |msg| *msg),
141                        )
142                    }),
143            ));
144            std::ptr::from_mut(result).cast::<c_void>() as usize
145        }
146    }
147    let inner = Box::leak(Box::new((f, param)));
148    unsafe {
149        task_crate(
150            task_main::<P, R, F>,
151            std::ptr::from_mut(inner).cast::<c_void>() as usize,
152            priority,
153        )
154        .into()
155    }
156}
157
158#[allow(missing_docs)]
159#[repr(C)]
160#[derive(Debug)]
161pub struct JoinHandle<R>(open_coroutine_core::net::join::JoinHandle, PhantomData<R>);
162
163#[allow(missing_docs)]
164impl<R> JoinHandle<R> {
165    pub fn timeout_join(&self, dur: Duration) -> std::io::Result<Option<R>> {
166        unsafe {
167            let ptr = task_timeout_join(self, dur.as_nanos().try_into().expect("overflow"));
168            match ptr.cmp(&0) {
169                Ordering::Less => Err(Error::other("timeout join failed")),
170                Ordering::Equal => Ok(None),
171                Ordering::Greater => Ok(Some((*Box::from_raw(ptr as *mut std::io::Result<R>))?)),
172            }
173        }
174    }
175
176    pub fn join(self) -> std::io::Result<Option<R>> {
177        unsafe {
178            let ptr = task_join(&self);
179            match ptr.cmp(&0) {
180                Ordering::Less => Err(Error::other("join failed")),
181                Ordering::Equal => Ok(None),
182                Ordering::Greater => Ok(Some((*Box::from_raw(ptr as *mut std::io::Result<R>))?)),
183            }
184        }
185    }
186
187    pub fn any_timeout_join(dur: Duration, slice: &[Self]) -> std::io::Result<Option<R>> {
188        if slice.is_empty() {
189            return Ok(None);
190        }
191        let timeout_time = open_coroutine_core::common::get_timeout_time(dur);
192        loop {
193            for handle in slice {
194                let left_time = timeout_time.saturating_sub(open_coroutine_core::common::now());
195                if 0 == left_time {
196                    return Err(Error::other("timeout join failed"));
197                }
198                if let Ok(x) = handle.timeout_join(Duration::from_nanos(left_time).min(SLICE)) {
199                    return Ok(x);
200                }
201            }
202        }
203    }
204
205    pub fn any_join<I: IntoIterator<Item = Self>>(iter: I) -> std::io::Result<Option<R>> {
206        let vec = Vec::from_iter(iter);
207        Self::any_timeout_join(Duration::MAX, &vec).inspect(|_| {
208            for handle in vec {
209                _ = handle.try_cancel();
210            }
211        })
212    }
213
214    pub fn try_cancel(self) -> std::io::Result<()> {
215        let r = unsafe { task_cancel(&self) };
216        match r.cmp(&0) {
217            Ordering::Equal => Ok(()),
218            _ => Err(Error::other("cancel failed")),
219        }
220    }
221}
222
223impl<R> From<open_coroutine_core::net::join::JoinHandle> for JoinHandle<R> {
224    fn from(val: open_coroutine_core::net::join::JoinHandle) -> Self {
225        Self(val, PhantomData)
226    }
227}
228
229impl<R> From<JoinHandle<R>> for open_coroutine_core::net::join::JoinHandle {
230    fn from(val: JoinHandle<R>) -> Self {
231        val.0
232    }
233}
234
235impl<R> Deref for JoinHandle<R> {
236    type Target = open_coroutine_core::net::join::JoinHandle;
237
238    fn deref(&self) -> &Self::Target {
239        &self.0
240    }
241}
242
/// Waiting for one of the tasks to be completed.
///
/// `any_timeout_join!(dur, h1, h2, ...)` expands to
/// [`JoinHandle::any_timeout_join`] over the listed handles. The handles are
/// placed in a temporary array, so no heap allocation is needed.
#[macro_export]
macro_rules! any_timeout_join {
    ($time:expr, $($x:expr),+ $(,)?) => {
        $crate::JoinHandle::any_timeout_join($time, &[$($x),+])
    };
}
250
/// Waiting for one of the tasks to be completed.
///
/// `any_join!(h1, h2, ...)` expands to [`JoinHandle::any_join`] over the
/// listed handles.
#[macro_export]
macro_rules! any_join {
    ($($handle:expr),+ $(,)?) => {
        $crate::JoinHandle::any_join([$($handle),+])
    };
}
258
259/// Grows the call stack if necessary.
260#[macro_export]
261macro_rules! maybe_grow {
262    ($red_zone:expr, $stack_size:expr, $f:expr $(,)?) => {
263        $crate::maybe_grow($red_zone, $stack_size, $f)
264    };
265    ($stack_size:literal, $f:expr $(,)?) => {
266        $crate::maybe_grow(
267            open_coroutine_core::common::default_red_zone(),
268            $stack_size,
269            $f,
270        )
271    };
272    ($f:expr $(,)?) => {
273        $crate::maybe_grow(
274            open_coroutine_core::common::default_red_zone(),
275            open_coroutine_core::common::constants::DEFAULT_STACK_SIZE,
276            $f,
277        )
278    };
279}
280
281/// Grows the call stack if necessary.
282pub fn maybe_grow<R: 'static, F: FnOnce() -> R>(
283    red_zone: usize,
284    stack_size: usize,
285    f: F,
286) -> std::io::Result<R> {
287    extern "C" fn execute_on_stack<R: 'static, F: FnOnce() -> R>(input: usize) -> usize {
288        unsafe {
289            let ptr = &mut *((input as *mut c_void).cast::<F>());
290            let data = std::ptr::read_unaligned(ptr);
291            let result: &'static mut R = Box::leak(Box::new(data()));
292            std::ptr::from_mut(result).cast::<c_void>() as usize
293        }
294    }
295    let inner = Box::leak(Box::new(f));
296    unsafe {
297        let ptr = maybe_grow_stack(
298            red_zone,
299            stack_size,
300            execute_on_stack::<R, F>,
301            std::ptr::from_mut(inner).cast::<c_void>() as usize,
302        );
303        if ptr < 0 {
304            return Err(Error::new(ErrorKind::InvalidInput, "grow stack failed"));
305        }
306        Ok(*Box::from_raw(
307            usize::try_from(ptr).expect("overflow") as *mut R
308        ))
309    }
310}
311
312/// Opens a TCP connection to a remote host.
313///
314/// `addr` is an address of the remote host. Anything which implements
315/// [`ToSocketAddrs`] trait can be supplied for the address; see this trait
316/// documentation for concrete examples.
317///
318/// If `addr` yields multiple addresses, `connect` will be attempted with
319/// each of the addresses until a connection is successful. If none of
320/// the addresses result in a successful connection, the error returned from
321/// the last connection attempt (the last address) is returned.
322///
323/// # Examples
324///
325/// Open a TCP connection to `127.0.0.1:8080`:
326///
327/// ```no_run
328/// if let Ok(stream) = open_coroutine::connect_timeout("127.0.0.1:8080", std::time::Duration::from_secs(3)) {
329///     println!("Connected to the server!");
330/// } else {
331///     println!("Couldn't connect to server...");
332/// }
333/// ```
334///
335/// Open a TCP connection to `127.0.0.1:8080`. If the connection fails, open
336/// a TCP connection to `127.0.0.1:8081`:
337///
338/// ```no_run
339/// use std::net::SocketAddr;
340///
341/// let addrs = [
342///     SocketAddr::from(([127, 0, 0, 1], 8080)),
343///     SocketAddr::from(([127, 0, 0, 1], 8081)),
344/// ];
345/// if let Ok(stream) = open_coroutine::connect_timeout(&addrs[..], std::time::Duration::from_secs(3)) {
346///     println!("Connected to the server!");
347/// } else {
348///     println!("Couldn't connect to server...");
349/// }
350/// ```
351pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> std::io::Result<TcpStream> {
352    let timeout_time = open_coroutine_core::common::get_timeout_time(timeout);
353    let mut last_err = None;
354    for addr in addr.to_socket_addrs()? {
355        loop {
356            let left_time = timeout_time.saturating_sub(open_coroutine_core::common::now());
357            if 0 == left_time {
358                break;
359            }
360            match TcpStream::connect_timeout(&addr, Duration::from_nanos(left_time).min(SLICE)) {
361                Ok(l) => return Ok(l),
362                Err(e) => last_err = Some(e),
363            }
364        }
365    }
366    Err(last_err.unwrap_or_else(|| {
367        Error::new(
368            ErrorKind::InvalidInput,
369            "could not resolve to any addresses",
370        )
371    }))
372}
373
#[cfg(test)]
mod tests {
    use crate::{init, shutdown, Config};

    #[test]
    fn test() {
        init(Config::single());

        // Race three trivial tasks; the outcome itself is not checked.
        let _ = any_join!(task!(|_| 1, ()), task!(|_| 2, ()), task!(|_| 3, ()));

        // This body is marked unreachable: the task is expected to be
        // cancelled before it runs.
        let doomed = task!(
            |_| {
                unreachable!("Try cancel!");
            },
            (),
        );
        doomed.try_cancel().expect("cancel failed");

        let hello = task!(
            |_| {
                println!("Hello, world!");
            },
            (),
        );
        let joined = hello.join().expect("join failed");
        assert_eq!(Some(()), joined);

        shutdown();
    }
}