use crate::statsig_global::StatsigGlobal;
use crate::StatsigErr;
use crate::{log_d, log_e};
use futures::future::join_all;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::{Builder, Handle, Runtime};
use tokio::sync::Notify;
use tokio::task::JoinHandle;
/// Log tag passed as the first argument to this module's `log_d!` / `log_e!` calls.
const TAG: &str = "StatsigRuntime";
/// Registry key for a spawned task: the caller-supplied tag paired with the id tokio
/// assigned to the task.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct TaskId {
    /// Label the task was spawned under.
    tag: String,
    /// Identifier tokio assigned to the spawned task.
    tokio_id: tokio::task::Id,
}
/// Spawns and tracks Statsig background tasks on a tokio runtime.
pub struct StatsigRuntime {
    /// Join handles of tasks started through `spawn`, keyed by tag + tokio task id.
    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
    /// Shared with every spawned task; `shutdown` calls `notify_waiters` on it.
    shutdown_notify: Arc<Notify>,
    /// Read by each spawned task before it runs its body.
    is_shutdown: Arc<AtomicBool>,
}
impl StatsigRuntime {
    /// Creates a new `StatsigRuntime` with an empty task registry.
    ///
    /// When the caller is not already running inside a tokio runtime, this first tries to
    /// make sure the runtime held by `StatsigGlobal` exists.
    #[must_use]
    pub fn get_runtime() -> Arc<StatsigRuntime> {
        create_runtime_if_required();
        Arc::new(StatsigRuntime {
            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
            shutdown_notify: Arc::new(Notify::new()),
            is_shutdown: Arc::new(AtomicBool::new(false)),
        })
    }

    /// Returns a handle to the tokio runtime that tasks should be spawned on.
    ///
    /// The runtime the caller is currently running inside takes precedence. Otherwise the
    /// runtime held by `StatsigGlobal` is used, and it is created on first use.
    ///
    /// # Errors
    /// Returns `StatsigErr::LockFailure` if the global runtime lock cannot be acquired
    /// within 5 seconds.
    pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
        if let Ok(handle) = Handle::try_current() {
            return Ok(handle);
        }
        let global = StatsigGlobal::get();
        let mut guard = global
            .tokio_runtime
            .try_lock_for(Duration::from_secs(5))
            .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
        let runtime = guard.get_or_insert_with(|| Arc::new(create_new_runtime()));
        Ok(runtime.handle().clone())
    }

    /// Returns the number of tasks currently tracked in the registry.
    ///
    /// Returns 0 (after logging an error) if the registry lock cannot be acquired within
    /// 5 seconds.
    pub fn get_num_active_tasks(&self) -> usize {
        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
            Some(lock) => lock.len(),
            None => {
                log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
                0
            }
        }
    }

    /// Shuts the runtime down.
    ///
    /// It marks the runtime as shut down and notifies waiters on the shared shutdown
    /// `Notify`. It then aborts every tracked task and clears the registry. Tasks spawned
    /// after this call still get a tokio task, but they skip their body.
    pub fn shutdown(&self) {
        // Set the flag first so that any spawned task that has not started yet skips its body.
        self.is_shutdown
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.shutdown_notify.notify_waiters();
        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
            Some(mut lock) => {
                for (_, task) in lock.drain() {
                    task.abort();
                }
            }
            None => {
                log_e!(TAG, "Failed to lock spawned tasks for shutdown");
            }
        }
    }

    /// Spawns `task` on the Statsig tokio runtime and tracks it under `tag`.
    ///
    /// `task` receives the shared shutdown `Notify`. The spawned task removes its own
    /// registry entry when it finishes. If the runtime has already been shut down, the
    /// spawned task skips `task` entirely but still removes its entry.
    ///
    /// Returns the tokio id of the spawned task. Pass it with `tag` to `await_join_handle`.
    ///
    /// # Errors
    /// Returns any error produced by `get_handle`.
    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
    where
        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let tag_string = tag.to_string();
        let shutdown_notify = Arc::clone(&self.shutdown_notify);
        let spawned_tasks = Arc::clone(&self.spawned_tasks);
        let is_shutdown = Arc::clone(&self.is_shutdown);
        log_d!(TAG, "Spawning task {}", tag);
        let runtime_handle = self.get_handle()?;

        // Take the registry lock *before* spawning and hold it until the handle is inserted.
        // Otherwise a task that finishes immediately could run its self-removal first, and
        // its handle would then be inserted afterwards and never cleaned up.
        let mut registry = self.spawned_tasks.try_lock_for(Duration::from_secs(5));

        let handle = runtime_handle.spawn(async move {
            let task_id = tokio::task::id();
            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
                log_d!(
                    TAG,
                    "Skipping task {}.{} because runtime is shut down",
                    tag_string,
                    task_id
                );
            } else {
                log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
                task(shutdown_notify).await;
            }
            // Runs on both paths so a skipped task does not leave a stale entry behind.
            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
        });

        Ok(Self::insert_join_handle(
            registry.as_deref_mut(),
            tag,
            handle,
        ))
    }

    /// Waits for every tracked task spawned under `tag` to finish.
    ///
    /// Matching handles are removed from the registry before waiting, and their join results
    /// are discarded. If the registry lock cannot be acquired within 5 seconds, this logs an
    /// error and returns immediately.
    pub async fn await_tasks_with_tag(&self, tag: &str) {
        let handles: Vec<JoinHandle<()>> =
            match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
                Some(mut lock) => {
                    let keys: Vec<TaskId> =
                        lock.keys().filter(|key| key.tag == tag).cloned().collect();
                    // The keys were read under this same guard, so every removal finds its entry.
                    keys.iter().filter_map(|key| lock.remove(key)).collect()
                }
                None => {
                    log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
                    return;
                }
            };
        // The guard was dropped at the end of the statement above, so it is not held across
        // this await.
        join_all(handles).await;
    }

    /// Waits for the task identified by (`tag`, `handle_id`) to finish.
    ///
    /// The handle is removed from the registry before waiting.
    ///
    /// # Errors
    /// Returns `StatsigErr::ThreadFailure` in three cases: no such task is tracked, the
    /// registry lock cannot be acquired within 5 seconds, or the task ends with a join error.
    pub async fn await_join_handle(
        &self,
        tag: &str,
        handle_id: &tokio::task::Id,
    ) -> Result<(), StatsigErr> {
        let task_id = TaskId {
            tag: tag.to_string(),
            tokio_id: *handle_id,
        };
        let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
            Some(mut lock) => match lock.remove(&task_id) {
                Some(handle) => handle,
                None => {
                    return Err(StatsigErr::ThreadFailure(
                        "No running task found".to_string(),
                    ));
                }
            },
            None => {
                log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
                return Err(StatsigErr::ThreadFailure(
                    "Failed to lock spawned tasks".to_string(),
                ));
            }
        };
        handle
            .await
            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
        Ok(())
    }

    /// Returns `(tag, tokio task id)` pairs for every tracked task.
    ///
    /// Returns an empty list (after logging an error) if the registry lock cannot be
    /// acquired within 5 seconds.
    pub fn get_running_task_ids(&self) -> Vec<(String, String)> {
        let tasks = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
            Some(lock) => lock,
            None => {
                log_e!(TAG, "Failed to lock spawned tasks for get_running_task_ids");
                return Vec::new();
            }
        };
        tasks
            .keys()
            .map(|key| (key.tag.clone(), key.tokio_id.to_string()))
            .collect()
    }

    /// Records `handle` in the already-locked registry under (`tag`, its tokio id).
    ///
    /// Pass `None` when the lock could not be acquired. The failure is logged and the
    /// handle is dropped, which does not cancel the task. Returns the task's tokio id.
    fn insert_join_handle(
        registry: Option<&mut HashMap<TaskId, JoinHandle<()>>>,
        tag: &str,
        handle: JoinHandle<()>,
    ) -> tokio::task::Id {
        let handle_id = handle.id();
        match registry {
            Some(tasks) => {
                let task_id = TaskId {
                    tag: tag.to_string(),
                    tokio_id: handle_id,
                };
                tasks.insert(task_id, handle);
            }
            None => {
                log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
            }
        }
        handle_id
    }
}
/// Builds a new tokio runtime for Statsig's background work.
///
/// On non-wasm targets this is a multi-threaded runtime with 5 worker threads named
/// "statsig". On wasm it is a current-thread runtime. Both are built with `enable_all()`.
///
/// # Panics
/// Panics if tokio fails to build the runtime.
pub fn create_new_runtime() -> Runtime {
    #[cfg(not(target_family = "wasm"))]
    return Builder::new_multi_thread()
        .worker_threads(5)
        .thread_name("statsig")
        .enable_all()
        .build()
        .expect("Failed to create a tokio Runtime");

    #[cfg(target_family = "wasm")]
    return Builder::new_current_thread()
        .thread_name("statsig")
        .enable_all()
        .build()
        .expect("Failed to create a tokio Runtime (single-threaded for wasm)");
}
/// Drops the registry entry for the task identified by (`tag`, `handle_id`).
///
/// If the registry lock cannot be acquired within 5 seconds, the failure is logged and
/// the entry is left in place.
fn remove_join_handle_with_id(
    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
    tag: String,
    handle_id: &tokio::task::Id,
) {
    let key = TaskId {
        tag,
        tokio_id: *handle_id,
    };
    if let Some(mut registry) = spawned_tasks.try_lock_for(Duration::from_secs(5)) {
        registry.remove(&key);
    } else {
        log_e!(
            TAG,
            "Failed to lock spawned tasks for remove_join_handle_with_id"
        );
    }
}
fn create_runtime_if_required() {
if Handle::try_current().is_ok() {
log_d!(TAG, "External tokio runtime found");
return;
}
let global = StatsigGlobal::get();
let mut lock = global
.tokio_runtime
.try_lock_for(Duration::from_secs(5))
.expect("Failed to lock owned tokio runtime");
match lock.as_ref() {
Some(_) => {
log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
}
None => {
log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
let rt = Arc::new(create_new_runtime());
lock.replace(rt);
}
};
}
/// Runs `shutdown` when a `StatsigRuntime` is dropped.
///
/// `shutdown` aborts every task still tracked in the registry.
impl Drop for StatsigRuntime {
    fn drop(&mut self) {
        Self::shutdown(self);
    }
}