use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use std::future::Future;
use tokio::runtime::Runtime;
/// Process-wide Tokio runtime.
///
/// Written exactly once by `init_runtime` and read by `get_runtime`. Nothing in this module
/// clears or replaces it once set.
static GLOBAL_RUNTIME: OnceCell<Runtime> = OnceCell::new();
/// Settings used when building the global Tokio runtime.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Number of Tokio worker threads. Must be at least 1: Tokio's
    /// `Builder::worker_threads` panics when given 0.
    pub worker_threads: usize,
}

impl RuntimeConfig {
    /// Worker-thread count used by `RuntimeConfig::default()`.
    const DEFAULT_WORKER_THREADS: usize = 2;
}

impl Default for RuntimeConfig {
    /// A small two-worker runtime.
    fn default() -> Self {
        RuntimeConfig {
            worker_threads: Self::DEFAULT_WORKER_THREADS,
        }
    }
}
/// Build the process-wide multi-threaded Tokio runtime and install it in `GLOBAL_RUNTIME`.
///
/// The runtime has I/O and timer drivers enabled. Its threads are named `briefcase-python`.
///
/// # Errors
/// - `PyRuntimeError("Global runtime already initialized")` if a runtime has already been
///   installed, including when another thread wins the race while this call is building one.
/// - `PyValueError` if `config.worker_threads` is 0. Tokio's builder would otherwise panic.
/// - `PyRuntimeError("Failed to initialize async runtime: …")` if Tokio fails to build it.
pub fn init_runtime(config: RuntimeConfig) -> PyResult<()> {
    let already_initialized = || {
        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Global runtime already initialized")
    };

    // Cheap early exit: avoid spawning a whole worker pool only to drop it again.
    if GLOBAL_RUNTIME.get().is_some() {
        return Err(already_initialized());
    }

    // `Builder::worker_threads(0)` panics; report it to Python as a normal exception instead.
    if config.worker_threads == 0 {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "worker_threads must be at least 1",
        ));
    }

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(config.worker_threads)
        .enable_io()
        .enable_time()
        .thread_name("briefcase-python")
        .build()
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to initialize async runtime: {}",
                e
            ))
        })?;

    // Another thread may have installed a runtime since the check above. `set` is the
    // authoritative check; on failure it hands our runtime back, and it is dropped here.
    GLOBAL_RUNTIME
        .set(runtime)
        .map_err(|_| already_initialized())
}
/// Borrow the global runtime installed by `init_runtime`.
///
/// # Errors
/// Returns `PyRuntimeError` if no runtime has been installed yet.
pub fn get_runtime() -> PyResult<&'static Runtime> {
    match GLOBAL_RUNTIME.get() {
        Some(rt) => Ok(rt),
        None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
            "Global runtime not initialized. Call briefcase_ai.init() first",
        )),
    }
}
/// Drive `future` to completion on the global runtime, blocking the calling thread.
///
/// An `Ok` value is passed through unchanged. An `Err` is rendered via `Display` and raised
/// to Python as a `RuntimeError`.
///
/// # Errors
/// Returns `PyRuntimeError` if the global runtime is not initialized, or if the future
/// resolves to `Err`.
///
/// # Panics
/// Tokio's `Runtime::block_on` panics if the future panics or if this is called from within
/// an asynchronous execution context.
pub fn block_on_result<F, R, E>(future: F) -> PyResult<R>
where
    F: Future<Output = Result<R, E>> + Send,
    R: Send,
    E: std::fmt::Display + Send,
{
    let outcome = get_runtime()?.block_on(future);
    match outcome {
        Ok(value) => Ok(value),
        Err(err) => {
            let message = err.to_string();
            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(message))
        }
    }
}
/// Report whether `init_runtime` has successfully installed the global runtime.
pub fn is_initialized() -> bool {
    let installed = GLOBAL_RUNTIME.get();
    installed.is_some()
}
/// Placeholder for shutting down the global runtime. It is currently a no-op that always
/// returns `Ok(())`.
///
/// The runtime is stored in a `OnceCell` static that this function never takes or clears.
/// Rust statics are never dropped, so the runtime stays alive for the rest of the process.
// NOTE(review): callers may expect this to stop the worker threads — confirm the intended
// behavior; a real shutdown would need interior mutability around the stored runtime.
pub fn shutdown_runtime() -> PyResult<()> {
Ok(())
}
/// Method-syntax sugar for blocking on a fallible future from Python-facing code.
pub trait PythonAsyncExt<T> {
    /// Block on `self` using the global runtime.
    ///
    /// Any `Err` is converted into a Python `RuntimeError`.
    fn block_on_python(self) -> PyResult<T>;
}

/// Every `Send` future yielding `Result<T, E>` with a displayable error gets
/// `block_on_python` for free.
impl<Fut, T, E> PythonAsyncExt<T> for Fut
where
    Fut: Future<Output = Result<T, E>> + Send,
    T: Send,
    E: std::fmt::Display + Send,
{
    fn block_on_python(self) -> PyResult<T> {
        block_on_result::<Fut, T, E>(self)
    }
}
/// Method-syntax sugar for blocking on a future that yields a batch of independent results.
pub trait PythonAsyncVecExt<T> {
    /// Block on `self` using the global runtime and keep only the successful items,
    /// in their original order.
    fn block_on_python(self) -> PyResult<Vec<T>>;
}

impl<Fut, T, E> PythonAsyncVecExt<T> for Fut
where
    Fut: Future<Output = Vec<Result<T, E>>> + Send,
    T: Send,
    E: std::fmt::Display + Send,
{
    fn block_on_python(self) -> PyResult<Vec<T>> {
        let outcomes = get_runtime()?.block_on(self);
        // Failed entries are discarded without being reported. The only error this returns
        // is a missing runtime.
        // NOTE(review): the `E: Display` bound is unused here — confirm that dropping errors
        // silently is intended rather than surfacing or logging them.
        Ok(outcomes.into_iter().filter_map(Result::ok).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Make sure the shared global runtime exists, tolerating other tests having created it.
    ///
    /// Every test in this module uses the same process-wide `GLOBAL_RUNTIME`. The test
    /// harness runs tests in parallel and in unspecified order, so only the first
    /// `init_runtime` call can succeed. Losing that race is expected and ignored; what
    /// matters is that a runtime exists afterwards.
    fn ensure_runtime() {
        if !is_initialized() {
            let _ = init_runtime(RuntimeConfig::default());
        }
        assert!(is_initialized());
    }

    #[test]
    fn test_runtime_initialization() {
        ensure_runtime();
        assert!(get_runtime().is_ok());
        // Once installed, the runtime must not be replaced by a second initialization.
        assert!(init_runtime(RuntimeConfig::default()).is_err());
        let result = block_on_result(async { Ok::<i32, &str>(42) });
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_block_on_result() {
        ensure_runtime();
        let result = block_on_result(async { Ok::<i32, &str>(42) });
        assert_eq!(result.unwrap(), 42);
        let result = block_on_result(async { Err::<i32, &str>("test error") });
        assert!(result.is_err());
    }

    #[test]
    fn test_async_trait() {
        ensure_runtime();
        let future = async { Ok::<i32, &str>(42) };
        let result = future.block_on_python();
        assert_eq!(result.unwrap(), 42);
    }
}