use std::future::Future;
use std::sync::OnceLock;
use std::sync::atomic::Ordering;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
/// Returns the process-wide tokio runtime that runs every goroutine, building it on first use.
///
/// The runtime is multi-threaded with all tokio drivers enabled (`enable_all`). If the
/// `GOMAXPROCS` environment variable holds a positive integer, it becomes the worker-thread
/// count. Otherwise (unset, not a number, or `0`) tokio's default thread count is kept.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be built.
pub(crate) fn runtime() -> &'static Runtime {
    /// Worker-thread count requested via `GOMAXPROCS`, if it is a positive integer.
    ///
    /// Like Go, values that fail to parse or are below 1 are ignored. Zero must be filtered
    /// out explicitly, because `Builder::worker_threads` panics when given 0.
    fn gomaxprocs() -> Option<usize> {
        std::env::var("GOMAXPROCS")
            .ok()?
            .parse::<usize>()
            .ok()
            .filter(|&n| n > 0)
    }

    static RT: OnceLock<Runtime> = OnceLock::new();
    RT.get_or_init(|| {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.enable_all();
        if let Some(threads) = gomaxprocs() {
            builder.worker_threads(threads);
        }
        builder.build().expect("goish: failed to build tokio runtime")
    })
}
pub struct Goroutine {
handle: Option<JoinHandle<()>>,
}
impl Goroutine {
pub fn spawn<F>(f: F) -> Goroutine
where
F: Future<Output = ()> + Send + 'static,
{
crate::runtime::LIVE_GOROUTINES.fetch_add(1, Ordering::SeqCst);
let handle = runtime().spawn(async move {
struct Guard;
impl Drop for Guard {
fn drop(&mut self) {
crate::runtime::LIVE_GOROUTINES.fetch_sub(1, Ordering::SeqCst);
}
}
let _g = Guard;
f.await;
});
Goroutine { handle: Some(handle) }
}
#[allow(non_snake_case)]
pub fn Wait(mut self) -> crate::errors::error {
match self.handle.take() {
Some(h) => match runtime().block_on(h) {
Ok(()) => crate::errors::nil,
Err(e) => join_err_to_goish(e),
},
None => crate::errors::nil,
}
}
pub async fn wait(mut self) -> crate::errors::error {
match self.handle.take() {
Some(h) => match h.await {
Ok(()) => crate::errors::nil,
Err(e) => join_err_to_goish(e),
},
None => crate::errors::nil,
}
}
}
/// Converts a tokio `JoinError` from a finished goroutine into a goish `error`.
///
/// - A panic payload recognised by `crate::runtime::is_goexit_panic` is a deliberate goexit
///   and maps to `nil`.
/// - Any other panic maps to `"goroutine panicked"`.
/// - A task that was cancelled rather than panicking maps to `"goroutine cancelled"`.
fn join_err_to_goish(e: tokio::task::JoinError) -> crate::errors::error {
    match e.try_into_panic() {
        Ok(payload) if crate::runtime::is_goexit_panic(&payload) => crate::errors::nil,
        Ok(_) => crate::errors::New("goroutine panicked"),
        // The `JoinError` was not a panic, so the task was cancelled; don't report it as a panic.
        Err(_) => crate::errors::New("goroutine cancelled"),
    }
}
/// Go-style `go` statement: runs the given statements as a new goroutine.
///
/// The tokens are passed through `__macros::rewrite_go_body!` inside an `async move` block.
/// That future is handed to `Goroutine::spawn`, and the resulting `Goroutine` handle is the
/// value of the macro invocation.
#[macro_export]
macro_rules! go {
    ($($body:tt)*) => {
        $crate::goroutine::Goroutine::spawn(async move {
            $crate::__macros::rewrite_go_body!($($body)*);
        })
    };
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    /// A goroutine's side effects are visible after `Wait` returns `nil`.
    #[test]
    fn go_runs_and_wait_joins() {
        let log: Arc<Mutex<Vec<i32>>> = Arc::new(Mutex::new(Vec::new()));
        let log_clone = log.clone();
        let g = crate::go!{
            log_clone.lock().unwrap().push(42);
        };
        let err = g.Wait();
        assert!(err == crate::errors::nil);
        assert_eq!(*log.lock().unwrap(), vec![42]);
    }

    /// The async `wait` path joins the goroutine just like `Wait`.
    #[test]
    fn async_wait_joins() {
        let log: Arc<Mutex<Vec<i32>>> = Arc::new(Mutex::new(Vec::new()));
        let log_clone = log.clone();
        let g = crate::go!{
            log_clone.lock().unwrap().push(7);
        };
        let err = super::runtime().block_on(g.wait());
        assert!(err == crate::errors::nil);
        assert_eq!(*log.lock().unwrap(), vec![7]);
    }

    /// A producer goroutine fills a buffered channel that is drained after it finishes.
    #[test]
    fn go_with_channel_looks_sync() {
        let ch = crate::chan!(i64, 4);
        let producer = ch.clone();
        let g = crate::go!{
            for i in 1i64..=3 {
                producer.Send(i);
            }
        };
        // Check the producer's result instead of discarding it.
        assert!(g.Wait() == crate::errors::nil);
        let mut got: Vec<i64> = (0..3).map(|_| ch.Recv().0).collect();
        got.sort();
        assert_eq!(got, vec![1, 2, 3]);
    }

    /// A panic inside a goroutine surfaces as a non-nil error instead of unwinding the caller.
    #[test]
    fn panicking_goroutine_returns_error() {
        let g = crate::go!{
            panic!("boom");
        };
        let err = g.Wait();
        assert!(err != crate::errors::nil);
    }

    /// Many concurrent goroutines all deliver their values and join cleanly.
    #[test]
    fn ten_thousand_goroutines() {
        let ch = crate::chan!(i64, 10_000);
        let mut handles = Vec::with_capacity(10_000);
        for i in 0..10_000i64 {
            let c = ch.clone();
            handles.push(crate::go!{ c.Send(i); });
        }
        let sum: i64 = (0..10_000).map(|_| ch.Recv().0).sum();
        for h in handles {
            assert!(h.Wait() == crate::errors::nil);
        }
        assert_eq!(sum, (9999 * 10_000) / 2);
    }
}