rustrails-support 0.1.1

Core utilities (ActiveSupport equivalent)
Documentation
use std::{cell::RefCell, future::Future};

use tokio::runtime::{Builder, Handle, Runtime};

/// Panic message used when a runtime helper is called on a thread that has
/// not registered a runtime handle via `init_runtime`.
const RUNTIME_NOT_INITIALIZED: &str = concat!(
    "rustrails_support::runtime::init_runtime() must be called on this thread ",
    "before using runtime helpers"
);

thread_local! {
    /// Per-thread slot holding the runtime handle registered by `init_runtime`.
    ///
    /// Starts as `None` on every thread. Only a `Handle` is stored here; the
    /// owning `Runtime` is returned to (and must be kept alive by) the caller
    /// of `init_runtime`.
    static RT_HANDLE: RefCell<Option<Handle>> = const { RefCell::new(None) };
}

/// Runs `f` with the Tokio runtime handle registered for the current thread.
///
/// The handle is cloned out of the thread-local slot and the `RefCell` borrow
/// is released *before* `f` is invoked. `f` may drive arbitrary user code on
/// this thread (e.g. a future passed to `block_on`). If that code re-enters
/// this module and needs a mutable borrow of the slot (as `init_runtime`
/// does), holding the borrow across `f` would make it panic with
/// `BorrowMutError`. Tokio's `Handle` is internally reference-counted, so the
/// clone is cheap.
///
/// # Panics
///
/// Panics with `RUNTIME_NOT_INITIALIZED` when `init_runtime` has not been
/// called on the current thread.
fn with_handle<R>(f: impl FnOnce(&Handle) -> R) -> R {
    // The `Ref` guard is dropped at the end of this closure, before `f` runs.
    let handle = RT_HANDLE
        .with(|cell| cell.borrow().clone())
        .unwrap_or_else(|| panic!("{RUNTIME_NOT_INITIALIZED}"));
    f(&handle)
}

/// Builds a multi-threaded Tokio runtime (I/O and time drivers enabled via
/// `enable_all`) and registers its handle for the calling thread.
///
/// Only a [`Handle`] is stored thread-locally. Ownership of the [`Runtime`]
/// passes to the caller, who must keep it alive for as long as the
/// thread-local helpers are used on this thread. Calling this again on the
/// same thread replaces the previously registered handle.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be built.
pub fn init_runtime() -> Runtime {
    let runtime = Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap_or_else(|error| panic!("failed to build Tokio runtime: {error}"));

    RT_HANDLE.with(|slot| {
        *slot.borrow_mut() = Some(runtime.handle().clone());
    });

    runtime
}

/// Runs a future to completion on the thread-local Tokio runtime.
///
/// If the calling thread is already inside the context of a multi-threaded
/// Tokio runtime (the thread-local runtime itself, or any other multi-threaded
/// runtime), the nested `block_on` is wrapped in
/// [`tokio::task::block_in_place`] so that Tokio allows this thread to block.
/// In any other situation the future is driven directly with
/// [`Handle::block_on`].
///
/// # Panics
///
/// - When the current thread has not been initialized with [`init_runtime`].
/// - When called from the async context of a current-thread Tokio runtime,
///   where blocking the thread is not permitted.
/// - When `future` itself panics.
pub fn block_on<F: Future>(future: F) -> F::Output {
    with_handle(|handle| {
        // `block_in_place` is only valid on multi-threaded runtimes. Outside any
        // runtime, or inside a current-thread runtime, use a plain `block_on`.
        let inside_multi_thread_runtime = Handle::try_current()
            .map(|current| {
                matches!(
                    current.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::MultiThread
                )
            })
            .unwrap_or(false);

        if inside_multi_thread_runtime {
            tokio::task::block_in_place(|| handle.block_on(future))
        } else {
            handle.block_on(future)
        }
    })
}

/// Spawns a task onto the thread-local Tokio runtime.
///
/// Panics when the current thread has not been initialized with [`init_runtime`].
pub fn spawn<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    with_handle(|handle| handle.spawn(future))
}

/// Reports whether [`init_runtime`] has registered a runtime handle on the
/// calling thread.
///
/// This only checks that a handle is stored. It does not check whether the
/// [`Runtime`] returned by [`init_runtime`] is still alive.
pub fn is_initialized() -> bool {
    RT_HANDLE.with(|slot| {
        let stored = slot.borrow();
        stored.is_some()
    })
}

#[cfg(test)]
mod tests {
    use std::{
        any::Any,
        sync::mpsc,
        thread,
        time::{Duration, Instant},
    };

    use super::{block_on, init_runtime, is_initialized, spawn};

    /// Runs `body` on a brand-new OS thread and returns its result.
    ///
    /// The runtime handle lives in a thread-local, so a fresh thread guarantees
    /// every test starts uninitialized. A panic on that thread is re-raised on
    /// the test thread so the harness reports it.
    fn run_isolated<R>(body: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        thread::spawn(body)
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }

    /// Extracts a human-readable message from a caught panic payload.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(owned) => *owned,
            Err(other) => match other.downcast_ref::<&str>() {
                Some(text) => (*text).to_owned(),
                None => "non-string panic payload".to_owned(),
            },
        }
    }

    #[test]
    fn init_runtime_sets_initialized_to_true() {
        run_isolated(|| {
            assert!(!is_initialized());
            let _guard = init_runtime();
            assert!(is_initialized());
        });
    }

    #[test]
    fn is_initialized_is_false_before_init() {
        run_isolated(|| assert!(!is_initialized()));
    }

    #[test]
    fn block_on_executes_simple_future() {
        run_isolated(|| {
            let _guard = init_runtime();
            let answer = block_on(async { 42 });
            assert_eq!(answer, 42);
        });
    }

    #[test]
    fn block_on_propagates_result_errors() {
        run_isolated(|| {
            let _guard = init_runtime();
            let outcome = block_on(async { Err::<(), &'static str>("boom") });
            assert_eq!(outcome, Err("boom"));
        });
    }

    #[test]
    fn block_on_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let caught = std::panic::catch_unwind(|| block_on(async { 42_i32 }))
                .expect_err("block_on should panic before init_runtime");
            panic_message(caught)
        });

        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_runs_task_to_completion() {
        run_isolated(|| {
            let _guard = init_runtime();
            let task = spawn(async { 7_i32 * 6 });
            let product = block_on(async move { task.await.expect("task should complete") });
            assert_eq!(product, 42);
        });
    }

    #[test]
    fn multiple_sequential_block_on_calls_work() {
        run_isolated(|| {
            let _guard = init_runtime();
            for expected in 1..=3 {
                assert_eq!(block_on(async move { expected }), expected);
            }
        });
    }

    #[test]
    fn block_on_supports_sleeping_futures() {
        run_isolated(|| {
            let _guard = init_runtime();
            let pause = Duration::from_millis(10);
            let started = Instant::now();
            // The `Sleep` must be created inside the runtime, hence the async block.
            block_on(async move {
                tokio::time::sleep(pause).await;
            });
            assert!(started.elapsed() >= pause);
        });
    }

    #[test]
    fn block_on_reenters_the_same_runtime_inside_async_context() {
        run_isolated(|| {
            let rt = init_runtime();
            let nested = rt.block_on(async { block_on(async { 21 * 2 }) });
            assert_eq!(nested, 42);
        });
    }

    #[test]
    fn spawn_can_signal_back_to_sync_code() {
        run_isolated(|| {
            let _guard = init_runtime();
            let (tx, rx) = mpsc::channel();
            let task = spawn(async move {
                tx.send("done").expect("channel send should succeed");
            });
            block_on(async move { task.await.expect("task should complete") });
            let received = rx.recv().expect("channel receive should succeed");
            assert_eq!(received, "done");
        });
    }
}