use std::sync::Arc;
use std::sync::atomic::AtomicBool;
#[cfg(feature = "wasm-runtime")]
use futures::channel::oneshot as futures_oneshot;
#[cfg(feature = "wasm-runtime")]
use futures::future::AbortHandle;
use parking_lot::Mutex;
use rivet_envoy_client::async_counter::AsyncCounter;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use crate::actor::task_types::UserTaskKind;
/// Kinds of work tracked for an actor.
///
/// Each kind maps to an `ActorWorkPolicy` through `ActorWorkKind::policy` and
/// to a string label through `ActorWorkKind::label`. The per-variant notes
/// below only restate what `policy` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActorWorkKind {
    /// Policy: blocks idle sleep; maps to `UserTaskKind::Action`.
    Action,
    /// Policy: blocks idle sleep; has no `UserTaskKind` mapping.
    KeepAwake,
    /// Policy: blocks idle sleep; has no `UserTaskKind` mapping. The only kind
    /// whose policy leaves `aborts_at_shutdown_deadline` unset.
    InternalKeepAwake,
    /// Policy: does not block idle sleep; maps to `UserTaskKind::WaitUntil`.
    WaitUntil,
    /// Policy: does not block idle sleep; maps to `UserTaskKind::RegisteredTask`.
    RegisteredTask,
    /// Policy: blocks idle sleep; maps to `UserTaskKind::WebSocketCallback`.
    WebSocketCallback,
    /// Policy: blocks idle sleep; maps to `UserTaskKind::DisconnectCallback`.
    DisconnectCallback,
}
/// Flags describing how one `ActorWorkKind` is treated, as returned by
/// `ActorWorkKind::policy`.
///
/// NOTE(review): the code that acts on these flags is outside this file
/// section; the field descriptions below are read off the field names and
/// should be confirmed against the consumers.
#[derive(Debug, Clone, Copy)]
pub struct ActorWorkPolicy {
    /// Presumably: in-flight work of this kind keeps the actor from idle-sleeping.
    pub blocks_idle_sleep: bool,
    /// Presumably: shutdown waits for this work during its grace period.
    /// Every kind currently sets this to `true`.
    pub drains_shutdown_grace: bool,
    /// Presumably: this work is aborted once the shutdown deadline is reached.
    pub aborts_at_shutdown_deadline: bool,
    /// The `UserTaskKind` this work maps to, or `None` when the kind has no
    /// user-task counterpart.
    pub user_task_kind: Option<UserTaskKind>,
}
impl ActorWorkKind {
    /// Returns the policy flags for this kind of work.
    ///
    /// Every kind sets `drains_shutdown_grace`; kinds differ only in whether
    /// they block idle sleep, whether they abort at the shutdown deadline, and
    /// which `UserTaskKind` (if any) they map to.
    pub fn policy(self) -> ActorWorkPolicy {
        // Tuple columns, in order:
        //   (blocks_idle_sleep, aborts_at_shutdown_deadline, user_task_kind)
        // The match stays exhaustive so a new variant must pick its flags.
        let (blocks_idle_sleep, aborts_at_shutdown_deadline, user_task_kind) = match self {
            Self::Action => (true, true, Some(UserTaskKind::Action)),
            Self::KeepAwake => (true, true, None),
            Self::InternalKeepAwake => (true, false, None),
            Self::WaitUntil => (false, true, Some(UserTaskKind::WaitUntil)),
            Self::RegisteredTask => (false, true, Some(UserTaskKind::RegisteredTask)),
            Self::WebSocketCallback => (true, true, Some(UserTaskKind::WebSocketCallback)),
            Self::DisconnectCallback => (true, true, Some(UserTaskKind::DisconnectCallback)),
        };

        ActorWorkPolicy {
            blocks_idle_sleep,
            // Shared by all kinds today, so it is not part of the table above.
            drains_shutdown_grace: true,
            aborts_at_shutdown_deadline,
            user_task_kind,
        }
    }

    /// Returns the snake_case string label for this kind.
    pub(crate) fn label(self) -> &'static str {
        match self {
            Self::Action => "action",
            Self::KeepAwake => "keep_awake",
            Self::InternalKeepAwake => "internal_keep_awake",
            Self::WaitUntil => "wait_until",
            Self::RegisteredTask => "registered_task",
            Self::WebSocketCallback => "websocket_callback",
            Self::DisconnectCallback => "disconnect_callback",
        }
    }
}
/// Bookkeeping for in-flight work: counters, shutdown task sets, notifiers
/// and teardown flags. Built by `WorkRegistry::new`.
pub(crate) struct WorkRegistry {
    /// Counter behind `keep_awake_guard`; registered with `idle_notify` in `new`.
    pub(crate) keep_awake: Arc<AsyncCounter>,
    /// Counter behind `internal_keep_awake_guard`; registered with `idle_notify` in `new`.
    pub(crate) internal_keep_awake: Arc<AsyncCounter>,
    /// Counter behind `websocket_callback_guard`; registered with `idle_notify` in `new`.
    pub(crate) websocket_callback: Arc<AsyncCounter>,
    /// Counter behind `disconnect_callback_guard`; registered with `idle_notify` in `new`.
    pub(crate) disconnect_callback: Arc<AsyncCounter>,
    /// Counter created in `new` without any notifier attached there.
    /// NOTE(review): presumably tracks work that shutdown must wait for — confirm.
    pub(crate) shutdown_counter: Arc<AsyncCounter>,
    /// Counter created in `new` without any notifier attached there.
    /// NOTE(review): presumably counts hooks dispatched by core — confirm.
    pub(crate) core_dispatched_hooks: Arc<AsyncCounter>,
    /// Set of spawned tasks, starting empty.
    /// NOTE(review): presumably the tasks that may be aborted at the shutdown
    /// deadline, in contrast to `unabortable_shutdown_tasks` — confirm.
    pub(crate) shutdown_tasks: Mutex<JoinSet<()>>,
    /// Second set of spawned tasks, starting empty.
    /// NOTE(review): presumably tasks that must not be aborted — confirm.
    pub(crate) unabortable_shutdown_tasks: Mutex<JoinSet<()>>,
    /// List of `LocalShutdownTask` entries; only compiled with the
    /// `wasm-runtime` feature.
    #[cfg(feature = "wasm-runtime")]
    pub(crate) local_shutdown_tasks: Mutex<Vec<LocalShutdownTask>>,
    /// Notifier registered on the four region counters via
    /// `register_zero_notify`.
    /// NOTE(review): presumably signaled when one of them reaches zero —
    /// confirm in `AsyncCounter`.
    pub(crate) idle_notify: Arc<Notify>,
    /// Separate notifier created in `new`; who signals and awaits it is not
    /// visible in this file section.
    pub(crate) activity_notify: Arc<Notify>,
    /// Starts as `false`.
    /// NOTE(review): presumably set once teardown begins — confirm.
    pub(crate) teardown_started: AtomicBool,
    /// Starts as `false`.
    /// NOTE(review): presumably set once the shutdown deadline passes — confirm.
    pub(crate) shutdown_deadline_reached: AtomicBool,
}
/// Entry stored in `WorkRegistry::local_shutdown_tasks`; only compiled with
/// the `wasm-runtime` feature.
///
/// NOTE(review): the code that spawns and reaps these entries is outside this
/// file section; the field notes are read off the types and names — confirm.
#[cfg(feature = "wasm-runtime")]
pub(crate) struct LocalShutdownTask {
    /// Handle from `futures::future::AbortHandle`, presumably used to cancel the task.
    pub(crate) abort_handle: AbortHandle,
    /// Receiving half of a `futures` oneshot channel, presumably resolved when
    /// the task completes.
    pub(crate) complete_rx: futures_oneshot::Receiver<()>,
    /// Same-named flag as `ActorWorkPolicy::aborts_at_shutdown_deadline`,
    /// presumably copied from the task's policy.
    pub(crate) aborts_at_shutdown_deadline: bool,
}
impl WorkRegistry {
    /// Builds an empty registry.
    ///
    /// The four region counters (`keep_awake`, `internal_keep_awake`,
    /// `websocket_callback`, `disconnect_callback`) are each registered with
    /// the shared `idle_notify` through `register_zero_notify`. The remaining
    /// counters get no notifier here, the task sets start empty, and both
    /// flags start as `false`.
    pub(crate) fn new() -> Self {
        let idle_notify = Arc::new(Notify::new());

        // Produces a fresh counter already registered with the idle notifier.
        let idle_wired_counter = || {
            let counter = Arc::new(AsyncCounter::new());
            counter.register_zero_notify(&idle_notify);
            counter
        };
        let keep_awake = idle_wired_counter();
        let internal_keep_awake = idle_wired_counter();
        let websocket_callback = idle_wired_counter();
        let disconnect_callback = idle_wired_counter();

        Self {
            keep_awake,
            internal_keep_awake,
            websocket_callback,
            disconnect_callback,
            shutdown_counter: Arc::new(AsyncCounter::new()),
            core_dispatched_hooks: Arc::new(AsyncCounter::new()),
            shutdown_tasks: Mutex::new(JoinSet::new()),
            unabortable_shutdown_tasks: Mutex::new(JoinSet::new()),
            #[cfg(feature = "wasm-runtime")]
            local_shutdown_tasks: Mutex::new(Vec::new()),
            idle_notify,
            activity_notify: Arc::new(Notify::new()),
            teardown_started: AtomicBool::new(false),
            shutdown_deadline_reached: AtomicBool::new(false),
        }
    }

    /// Increments `keep_awake` and returns the guard that decrements it on drop.
    pub(crate) fn keep_awake_guard(&self) -> RegionGuard {
        RegionGuard::new(Arc::clone(&self.keep_awake))
    }

    /// Increments `internal_keep_awake` and returns the guard that decrements
    /// it on drop.
    pub(crate) fn internal_keep_awake_guard(&self) -> RegionGuard {
        RegionGuard::new(Arc::clone(&self.internal_keep_awake))
    }

    /// Increments `websocket_callback` and returns the guard that decrements
    /// it on drop.
    pub(crate) fn websocket_callback_guard(&self) -> RegionGuard {
        RegionGuard::new(Arc::clone(&self.websocket_callback))
    }

    /// Increments `disconnect_callback` and returns the guard that decrements
    /// it on drop.
    pub(crate) fn disconnect_callback_guard(&self) -> RegionGuard {
        RegionGuard::new(Arc::clone(&self.disconnect_callback))
    }
}
impl Default for WorkRegistry {
fn default() -> Self {
Self::new()
}
}
/// Guard that owns one increment of an `AsyncCounter` and decrements the
/// counter when dropped.
pub(crate) struct RegionGuard {
    /// Counter decremented in `Drop`.
    counter: Arc<AsyncCounter>,
    /// Log label set by `with_log_fields`; when `Some`, dropping the guard
    /// emits a debug log line.
    log_kind: Option<&'static str>,
    /// Optional actor id set by `with_log_fields` and attached to the drop log.
    log_actor_id: Option<String>,
}
impl RegionGuard {
    /// Increments `counter` and returns a guard owning that increment.
    fn new(counter: Arc<AsyncCounter>) -> Self {
        counter.increment();
        Self::from_incremented(counter)
    }

    /// Wraps `counter` without incrementing it; the caller is expected to have
    /// performed the increment already. Dropping the guard still decrements.
    pub(crate) fn from_incremented(counter: Arc<AsyncCounter>) -> Self {
        Self {
            log_kind: None,
            log_actor_id: None,
            counter,
        }
    }

    /// Emits an "engaged" debug log with the current counter value and
    /// remembers `kind` and `actor_id` so the matching "disengaged" line is
    /// logged when the guard is dropped.
    pub(crate) fn with_log_fields(mut self, kind: &'static str, actor_id: Option<String>) -> Self {
        // The local names double as the emitted tracing field names.
        let count = self.counter.load();
        if let Some(actor_id) = actor_id.as_deref() {
            tracing::debug!(actor_id, kind, count, "sleep keep-awake engaged");
        } else {
            tracing::debug!(kind, count, "sleep keep-awake engaged");
        }

        self.log_actor_id = actor_id;
        self.log_kind = Some(kind);
        self
    }
}
impl Drop for RegionGuard {
    /// Decrements the counter; if log fields were attached, also logs the
    /// counter value read after the decrement.
    fn drop(&mut self) {
        // The decrement always happens, whether or not logging is enabled.
        self.counter.decrement();

        if let Some(kind) = self.log_kind {
            let count = self.counter.load();
            if let Some(actor_id) = self.log_actor_id.as_deref() {
                tracing::debug!(actor_id, kind, count, "sleep keep-awake disengaged");
            } else {
                tracing::debug!(kind, count, "sleep keep-awake disengaged");
            }
        }
    }
}
/// Alternate name for `RegionGuard`; both names refer to the same type.
pub(crate) type CountGuard = RegionGuard;
// Unit tests are compiled only for test builds and are loaded from a file
// outside this module's directory via the `#[path]` attribute.
#[cfg(test)]
#[path = "../../tests/work_registry.rs"]
mod tests;