1#![deny(
2 anonymous_parameters,
5 bare_trait_objects,
6 missing_copy_implementations,
8 missing_debug_implementations,
9 missing_docs, single_use_lifetimes, trivial_casts, trivial_numeric_casts,
13 unstable_features,
16 unused_extern_crates,
17 unused_import_braces,
18 unused_qualifications,
19 unused_results,
20 variant_size_differences,
21
22 warnings, clippy::all,
25 clippy::pedantic,
27 clippy::cargo,
29 unreachable_pub,
30)]
31#![allow(
32 clippy::blanket_clippy_restriction_lints, clippy::implicit_return, clippy::module_name_repetitions, clippy::multiple_crate_versions, clippy::missing_errors_doc, clippy::missing_panics_doc, clippy::panic_in_result_fn,
40 clippy::shadow_same, clippy::shadow_reuse, clippy::exhaustive_enums,
43 clippy::exhaustive_structs,
44 clippy::indexing_slicing,
45 clippy::separated_literal_suffix, clippy::single_char_lifetime_names, unknown_lints, linker_messages, unused_attributes, )]
51use open_coroutine_core::co_pool::task::UserTaskFunc;
54use open_coroutine_core::common::constants::SLICE;
55pub use open_coroutine_core::common::ordered_work_steal::DEFAULT_PRECEDENCE;
56pub use open_coroutine_core::config::Config;
57use open_coroutine_core::net::UserFunc;
58pub use open_coroutine_macros::*;
59use std::cmp::Ordering;
60use std::ffi::{c_int, c_longlong, c_uint, c_void};
61use std::io::{Error, ErrorKind};
62use std::marker::PhantomData;
63use std::net::{TcpStream, ToSocketAddrs};
64use std::ops::Deref;
65use std::time::Duration;
66
// FFI entry points resolved at link time (presumably exported by the open-coroutine
// runtime/hook library — confirm which artifact provides them).
extern "C" {
    // Called by `init`, which treats any non-zero return value as a failed initialisation.
    fn open_coroutine_init(config: Config) -> c_int;

    // Called by `shutdown` with `secs = 30`; the returned status is ignored there.
    fn open_coroutine_stop(secs: c_uint) -> c_int;

    // Used by `maybe_grow`: runs the trampoline `f` with `param`, presumably on a larger
    // stack when fewer than `red_zone` bytes remain (stacker-style convention — TODO
    // confirm). `maybe_grow` treats a negative return as failure and otherwise as the
    // address of the boxed value produced by the trampoline.
    fn maybe_grow_stack(
        red_zone: usize,
        stack_size: usize,
        f: UserFunc,
        param: usize,
    ) -> c_longlong;
}
79
// The core `JoinHandle` crosses this FFI boundary by value and by reference; it is
// presumably not a C-compatible type as far as the lint can tell, hence the allowance.
#[allow(improper_ctypes)]
extern "C" {
    // Used by `crate_task`: `f` is the `task_main` trampoline and `param` the address of
    // the boxed `(closure, argument)` pair it consumes; `priority` is forwarded as-is.
    fn task_crate(
        f: UserTaskFunc,
        param: usize,
        priority: c_longlong,
    ) -> open_coroutine_core::net::join::JoinHandle;

    // Used by `JoinHandle::join`: negative = failure, 0 = no result, positive = address
    // of the boxed `std::io::Result<R>` produced by the task trampoline.
    fn task_join(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong;

    // Used by `JoinHandle::try_cancel`, which treats only 0 as success.
    fn task_cancel(handle: &open_coroutine_core::net::join::JoinHandle) -> c_longlong;

    // Like `task_join`, but bounded by `ns_time`, which `JoinHandle::timeout_join` fills
    // with the requested duration in nanoseconds; same return-value convention.
    fn task_timeout_join(
        handle: &open_coroutine_core::net::join::JoinHandle,
        ns_time: u64,
    ) -> c_longlong;
}
97
98pub fn init(config: Config) {
100 assert_eq!(
101 0,
102 unsafe { open_coroutine_init(config) },
103 "open-coroutine init failed !"
104 );
105 #[cfg(feature = "ci")]
106 open_coroutine_core::common::ci::init();
107}
108
109pub fn shutdown() {
111 unsafe { _ = open_coroutine_stop(30) };
112}
113
/// Spawns `$func($arg)` as a task and evaluates to its [`JoinHandle`]; see [`crate_task`].
///
/// `task!(func, arg)` uses [`DEFAULT_PRECEDENCE`] as the priority, while
/// `task!(func, arg, priority)` forwards the given priority unchanged.
#[macro_export]
macro_rules! task {
    ( $func:expr , $arg:expr , $priority:expr $(,)? ) => {
        $crate::crate_task($func, $arg, $priority)
    };
    ( $func:expr , $arg:expr $(,)? ) => {
        $crate::crate_task($func, $arg, $crate::DEFAULT_PRECEDENCE)
    };
}
124
125pub fn crate_task<P: 'static, R: 'static, F: FnOnce(P) -> R>(
127 f: F,
128 param: P,
129 priority: c_longlong,
130) -> JoinHandle<R> {
131 extern "C" fn task_main<P: 'static, R: 'static, F: FnOnce(P) -> R>(input: usize) -> usize {
132 unsafe {
133 let ptr = &mut *((input as *mut c_void).cast::<(F, P)>());
134 let data = std::ptr::read_unaligned(ptr);
135 let result: &'static mut std::io::Result<R> = Box::leak(Box::new(
136 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (data.0)(data.1)))
137 .map_err(|e| {
138 Error::other(
139 e.downcast_ref::<&'static str>()
140 .map_or("task failed without message", |msg| *msg),
141 )
142 }),
143 ));
144 std::ptr::from_mut(result).cast::<c_void>() as usize
145 }
146 }
147 let inner = Box::leak(Box::new((f, param)));
148 unsafe {
149 task_crate(
150 task_main::<P, R, F>,
151 std::ptr::from_mut(inner).cast::<c_void>() as usize,
152 priority,
153 )
154 .into()
155 }
156}
157
/// Typed handle to a task spawned with [`crate_task`] (or the `task!` macro) whose
/// result has type `R`.
///
/// It wraps the runtime's `open_coroutine_core::net::join::JoinHandle`; the
/// `PhantomData<R>` field only records the result type used when joining. `#[repr(C)]`
/// presumably keeps the layout aligned with the wrapped handle — TODO confirm.
/// The derived `Debug` impl is only available when `R: Debug`.
#[allow(missing_docs)]
#[repr(C)]
#[derive(Debug)]
pub struct JoinHandle<R>(open_coroutine_core::net::join::JoinHandle, PhantomData<R>);
162
163#[allow(missing_docs)]
164impl<R> JoinHandle<R> {
165 pub fn timeout_join(&self, dur: Duration) -> std::io::Result<Option<R>> {
166 unsafe {
167 let ptr = task_timeout_join(self, dur.as_nanos().try_into().expect("overflow"));
168 match ptr.cmp(&0) {
169 Ordering::Less => Err(Error::other("timeout join failed")),
170 Ordering::Equal => Ok(None),
171 Ordering::Greater => Ok(Some((*Box::from_raw(ptr as *mut std::io::Result<R>))?)),
172 }
173 }
174 }
175
176 pub fn join(self) -> std::io::Result<Option<R>> {
177 unsafe {
178 let ptr = task_join(&self);
179 match ptr.cmp(&0) {
180 Ordering::Less => Err(Error::other("join failed")),
181 Ordering::Equal => Ok(None),
182 Ordering::Greater => Ok(Some((*Box::from_raw(ptr as *mut std::io::Result<R>))?)),
183 }
184 }
185 }
186
187 pub fn any_timeout_join(dur: Duration, slice: &[Self]) -> std::io::Result<Option<R>> {
188 if slice.is_empty() {
189 return Ok(None);
190 }
191 let timeout_time = open_coroutine_core::common::get_timeout_time(dur);
192 loop {
193 for handle in slice {
194 let left_time = timeout_time.saturating_sub(open_coroutine_core::common::now());
195 if 0 == left_time {
196 return Err(Error::other("timeout join failed"));
197 }
198 if let Ok(x) = handle.timeout_join(Duration::from_nanos(left_time).min(SLICE)) {
199 return Ok(x);
200 }
201 }
202 }
203 }
204
205 pub fn any_join<I: IntoIterator<Item = Self>>(iter: I) -> std::io::Result<Option<R>> {
206 let vec = Vec::from_iter(iter);
207 Self::any_timeout_join(Duration::MAX, &vec).inspect(|_| {
208 for handle in vec {
209 _ = handle.try_cancel();
210 }
211 })
212 }
213
214 pub fn try_cancel(self) -> std::io::Result<()> {
215 let r = unsafe { task_cancel(&self) };
216 match r.cmp(&0) {
217 Ordering::Equal => Ok(()),
218 _ => Err(Error::other("cancel failed")),
219 }
220 }
221}
222
223impl<R> From<open_coroutine_core::net::join::JoinHandle> for JoinHandle<R> {
224 fn from(val: open_coroutine_core::net::join::JoinHandle) -> Self {
225 Self(val, PhantomData)
226 }
227}
228
229impl<R> From<JoinHandle<R>> for open_coroutine_core::net::join::JoinHandle {
230 fn from(val: JoinHandle<R>) -> Self {
231 val.0
232 }
233}
234
235impl<R> Deref for JoinHandle<R> {
236 type Target = open_coroutine_core::net::join::JoinHandle;
237
238 fn deref(&self) -> &Self::Target {
239 &self.0
240 }
241}
242
/// Waits up to `$time` for the first of the given handles to produce a result; expands to
/// a call to `JoinHandle::any_timeout_join`.
#[macro_export]
macro_rules! any_timeout_join {
    ($time:expr, $($x:expr),+ $(,)?) => {
        // `any_timeout_join` only borrows the handles as a slice, so a stack array is
        // enough; the former `vec!` heap-allocated for nothing.
        $crate::JoinHandle::any_timeout_join($time, &[$($x),+])
    }
}
250
/// Waits for the first of the given handles to succeed, then tries to cancel every
/// handle; expands to a call to `JoinHandle::any_join`.
#[macro_export]
macro_rules! any_join {
    ($($handle:expr),+ $(,)?) => {
        $crate::JoinHandle::any_join([$($handle),+])
    }
}
258
259#[macro_export]
261macro_rules! maybe_grow {
262 ($red_zone:expr, $stack_size:expr, $f:expr $(,)?) => {
263 $crate::maybe_grow($red_zone, $stack_size, $f)
264 };
265 ($stack_size:literal, $f:expr $(,)?) => {
266 $crate::maybe_grow(
267 open_coroutine_core::common::default_red_zone(),
268 $stack_size,
269 $f,
270 )
271 };
272 ($f:expr $(,)?) => {
273 $crate::maybe_grow(
274 open_coroutine_core::common::default_red_zone(),
275 open_coroutine_core::common::constants::DEFAULT_STACK_SIZE,
276 $f,
277 )
278 };
279}
280
281pub fn maybe_grow<R: 'static, F: FnOnce() -> R>(
283 red_zone: usize,
284 stack_size: usize,
285 f: F,
286) -> std::io::Result<R> {
287 extern "C" fn execute_on_stack<R: 'static, F: FnOnce() -> R>(input: usize) -> usize {
288 unsafe {
289 let ptr = &mut *((input as *mut c_void).cast::<F>());
290 let data = std::ptr::read_unaligned(ptr);
291 let result: &'static mut R = Box::leak(Box::new(data()));
292 std::ptr::from_mut(result).cast::<c_void>() as usize
293 }
294 }
295 let inner = Box::leak(Box::new(f));
296 unsafe {
297 let ptr = maybe_grow_stack(
298 red_zone,
299 stack_size,
300 execute_on_stack::<R, F>,
301 std::ptr::from_mut(inner).cast::<c_void>() as usize,
302 );
303 if ptr < 0 {
304 return Err(Error::new(ErrorKind::InvalidInput, "grow stack failed"));
305 }
306 Ok(*Box::from_raw(
307 usize::try_from(ptr).expect("overflow") as *mut R
308 ))
309 }
310}
311
312pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> std::io::Result<TcpStream> {
352 let timeout_time = open_coroutine_core::common::get_timeout_time(timeout);
353 let mut last_err = None;
354 for addr in addr.to_socket_addrs()? {
355 loop {
356 let left_time = timeout_time.saturating_sub(open_coroutine_core::common::now());
357 if 0 == left_time {
358 break;
359 }
360 match TcpStream::connect_timeout(&addr, Duration::from_nanos(left_time).min(SLICE)) {
361 Ok(l) => return Ok(l),
362 Err(e) => last_err = Some(e),
363 }
364 }
365 }
366 Err(last_err.unwrap_or_else(|| {
367 Error::new(
368 ErrorKind::InvalidInput,
369 "could not resolve to any addresses",
370 )
371 }))
372}
373
#[cfg(test)]
mod tests {
    use crate::{init, shutdown};
    use open_coroutine_core::config::Config;

    /// End-to-end smoke test: start a single-worker runtime, race three tasks, cancel a
    /// task whose body would panic, join a unit-returning task, then shut down.
    #[test]
    fn test() {
        init(Config::single());

        // Only completion matters here, not which of the three values wins.
        _ = any_join!(task!(|_| 1, ()), task!(|_| 2, ()), task!(|_| 3, ()));

        // This closure panics via `unreachable!` if it ever runs; cancelling must succeed.
        let doomed = task!(
            |_| {
                unreachable!("Try cancel!");
            },
            (),
        );
        doomed.try_cancel().expect("cancel failed");

        let greeter = task!(
            |_| {
                println!("Hello, world!");
            },
            (),
        );
        let joined = greeter.join().expect("join failed");
        assert_eq!(Some(()), joined);

        shutdown();
    }
}