use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use crate::time::Instant as StdInstant;
use crate::time::sleep;
use anyhow::{Context, Result};
use rivetkit_actor_persist::{generated::v4 as persist_v4, versioned as persist_versioned};
#[cfg(not(feature = "wasm-runtime"))]
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
#[cfg(test)]
use tokio::time::timeout;
use tracing::Instrument;
use crate::actor::context::ActorContext;
use crate::actor::keys::{LAST_PUSHED_ALARM_KEY, PERSIST_DATA_KEY, make_connection_key};
use crate::actor::kv::APPLY_BATCH_CHUNK_SIZE;
use crate::actor::messages::StateDelta;
use crate::actor::persist::{
decode_latest_with_embedded_version, encode_latest_with_embedded_version,
};
use crate::actor::task::LifecycleEvent;
use crate::actor::task_types::StateMutationReason;
use crate::error::ActorRuntime;
#[cfg(feature = "wasm-runtime")]
use crate::runtime::RuntimeSpawner;
use crate::types::SaveStateOpts;
/// Schema version embedded in every encoded last-pushed-alarm record
/// (passed to `encode_latest_with_embedded_version` below).
const LAST_PUSHED_ALARM_VERSION: u16 = 1;
/// Scheduled event record in the current (`generated::v4`) persisted schema.
pub type PersistedScheduleEvent = persist_v4::ScheduleEvent;
/// Persisted actor record in the current (`generated::v4`) schema.
pub type PersistedActor = persist_v4::Actor;
/// Serializes `actor` with `rivetkit_actor_persist::CURRENT_VERSION` embedded
/// in the payload.
///
/// # Errors
/// Propagates any failure from `encode_latest_with_embedded_version`.
pub(crate) fn encode_persisted_actor(actor: &PersistedActor) -> Result<Vec<u8>> {
    // The encoder takes ownership, so hand it a private copy.
    let owned = actor.clone();
    let version = rivetkit_actor_persist::CURRENT_VERSION;
    encode_latest_with_embedded_version::<persist_versioned::Actor>(
        owned,
        version,
        "persisted actor",
    )
}
/// Decodes a payload carrying an embedded schema version into a
/// [`PersistedActor`].
///
/// # Errors
/// Propagates any failure from `decode_latest_with_embedded_version`.
pub(crate) fn decode_persisted_actor(payload: &[u8]) -> Result<PersistedActor> {
    decode_latest_with_embedded_version::<persist_versioned::Actor>(payload, "persisted actor")
}
/// Serializes the last-pushed alarm timestamp (or `None`) with
/// `LAST_PUSHED_ALARM_VERSION` embedded in the payload.
///
/// # Errors
/// Propagates any failure from `encode_latest_with_embedded_version`.
pub(crate) fn encode_last_pushed_alarm(alarm_ts: Option<i64>) -> Result<Vec<u8>> {
    let label = "last pushed alarm";
    encode_latest_with_embedded_version::<persist_versioned::LastPushedAlarm>(
        alarm_ts,
        LAST_PUSHED_ALARM_VERSION,
        label,
    )
}
/// Decodes a last-pushed alarm payload back into its optional timestamp.
///
/// # Errors
/// Propagates any failure from `decode_latest_with_embedded_version`.
pub(crate) fn decode_last_pushed_alarm(payload: &[u8]) -> Result<Option<i64>> {
    let label = "last pushed alarm";
    decode_latest_with_embedded_version::<persist_versioned::LastPushedAlarm>(payload, label)
}
/// Options for [`ActorContext::request_save`] and the other save-request APIs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct RequestSaveOpts {
/// Save without the throttle delay. An immediate request is forwarded to the
/// lifecycle task even if a non-immediate request is already outstanding.
pub immediate: bool,
/// Upper bound, in milliseconds from the time of the request, on how long
/// the save may be deferred. The earliest such deadline across outstanding
/// requests wins, and a request carrying it is always forwarded.
pub max_wait_ms: Option<u32>,
}
/// A delayed background save spawned by `ActorContext::schedule_save`.
pub(super) struct PendingSave {
/// When the spawned task will attempt to persist; used to keep only the
/// earliest scheduled save.
scheduled_at: StdInstant,
/// Task handle, kept so the save can be aborted when superseded or cleared.
handle: JoinHandle<()>,
}
/// RAII guard returned by [`ActorContext::begin_on_state_change`].
///
/// Creating it calls `on_state_change_started`; dropping it calls
/// `on_state_change_finished` exactly once.
pub struct OnStateChangeGuard {
// `Some` until the guard is dropped; `take()` in `Drop` prevents a double finish.
ctx: Option<ActorContext>,
}
impl OnStateChangeGuard {
    /// Marks an `on_state_change` callback as in flight on `ctx` and returns
    /// the guard that marks it finished when dropped.
    fn new(ctx: ActorContext) -> Self {
        ctx.on_state_change_started();
        let ctx = Some(ctx);
        Self { ctx }
    }
}
impl Drop for OnStateChangeGuard {
    fn drop(&mut self) {
        // Balance the `on_state_change_started` call made in `new`; `take()`
        // guarantees the finish hook runs at most once.
        let Some(ctx) = self.ctx.take() else {
            return;
        };
        ctx.on_state_change_finished();
    }
}
impl ActorContext {
/// Returns a copy of the actor's current in-memory state bytes.
pub fn state(&self) -> Vec<u8> {
    let snapshot = self.0.current_state.read();
    snapshot.to_vec()
}
/// Persists the actor record to KV if it has unsaved changes.
///
/// With `opts.immediate`, any pending delayed save is cancelled and the write
/// happens right away. Otherwise the call first sleeps for whatever remains
/// of the save interval since the last save, then writes. On success
/// `record_state_updated` is called. If nothing is dirty this is a no-op.
///
/// # Errors
/// Returns the encoding / KV error from `persist_if_dirty`.
pub(crate) async fn persist_state(&self, opts: SaveStateOpts) -> Result<()> {
    if !self.is_dirty() {
        return Ok(());
    }
    if opts.immediate {
        self.clear_pending_save();
    } else {
        let throttle = self.compute_save_delay(None);
        if !throttle.is_zero() {
            sleep(throttle).await;
        }
    }
    self.persist_if_dirty().await?;
    self.record_state_updated();
    Ok(())
}
pub fn set_state_initial(&self, state: Vec<u8>) {
self.set_initial_state(state);
}
/// Asks the lifecycle task to persist state, without waiting for the save.
///
/// On non-wasm32 targets a failed request is logged at `warn`; on wasm32
/// failures are ignored (see `request_save_best_effort`).
pub fn request_save(&self, opts: RequestSaveOpts) {
    #[cfg(not(target_arch = "wasm32"))]
    {
        let outcome = self.request_save_with_revision(opts);
        if let Err(error) = outcome {
            tracing::warn!(?error, "failed to request actor state save");
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        self.request_save_best_effort(opts);
    }
}
/// wasm32 variant of `request_save` that swallows failures.
///
/// Delegates to `request_save_with_revision`, which does exactly the same
/// bookkeeping and dedup/send logic:
/// - revision bump;
/// - `on_request_save` hooks;
/// - requested/immediate flags;
/// - `max_wait_ms` deadline.
///
/// The only difference is that a missing lifecycle sender or a closed channel
/// is ignored instead of reported. Sharing the implementation keeps the two
/// paths from drifting apart.
#[cfg(target_arch = "wasm32")]
fn request_save_best_effort(&self, opts: RequestSaveOpts) {
    // Best effort by design: a missing or closed lifecycle channel is not an error here.
    let _ = self.request_save_with_revision(opts);
}
/// Requests a save and resolves once `save_completed_revision` has reached
/// the revision assigned to this request.
///
/// # Errors
/// Fails if the request could not be handed to the lifecycle task.
pub async fn request_save_and_wait(&self, opts: RequestSaveOpts) -> Result<()> {
    let target_revision = self.request_save_with_revision(opts)?;
    self.wait_for_save_request(target_revision).await;
    Ok(())
}
/// Snapshots the current save-request revision and forwards `deltas` to
/// `save_state_with_revision` under that revision.
pub async fn save_state(&self, deltas: Vec<StateDelta>) -> Result<()> {
    let revision = self.save_request_revision();
    self.save_state_with_revision(deltas, revision).await
}
/// Records a save request and, unless it is redundant, forwards it to the
/// lifecycle task as `LifecycleEvent::SaveRequested`.
///
/// Returns the new save-request revision (incremented on every call), which
/// can be passed to `wait_for_save_request`. Hooks registered through
/// `on_request_save` run synchronously before any flags are updated.
///
/// A request is redundant (flags updated, no event sent) when all of these hold:
/// - no `max_wait_ms` is given;
/// - a save is already requested;
/// - it does not upgrade an outstanding non-immediate request to immediate.
///
/// # Errors
/// Returns `ActorRuntime::NotConfigured` when no lifecycle sender is set or
/// the receiving side has been dropped.
pub(crate) fn request_save_with_revision(&self, opts: RequestSaveOpts) -> Result<u64> {
let immediate = opts.immediate;
// `fetch_add` returns the old value, so `+ 1` is this call's unique revision.
let save_request_revision = self.0.save_request_revision.fetch_add(1, Ordering::SeqCst) + 1;
self.notify_request_save_hooks(opts);
// `swap` both sets the flag and reports whether it was already set.
let already_requested = self.0.save_requested.swap(true, Ordering::SeqCst);
let immediate_already_requested = if immediate {
self.0.save_requested_immediate.swap(true, Ordering::SeqCst)
} else {
self.0.save_requested_immediate.load(Ordering::SeqCst)
};
// Keep the earliest deadline across all outstanding requests.
if let Some(max_wait_ms) = opts.max_wait_ms {
let deadline = StdInstant::now() + Duration::from_millis(u64::from(max_wait_ms));
let mut requested_deadline = self.0.save_requested_within_deadline.lock();
*requested_deadline = Some(match *requested_deadline {
Some(existing) => existing.min(deadline),
None => deadline,
});
}
// NOTE(review): the flags and deadline above stay set when this returns
// Err. In this file they are only cleared by `finish_save_request` (called
// from `apply_state_deltas` / `load_persisted`). So a request made before the
// sender is configured can make later non-immediate requests look redundant
// (never sent) — confirm this is intended.
let Some(sender) = self.lifecycle_event_sender() else {
return Err(ActorRuntime::NotConfigured {
component: "lifecycle events".to_owned(),
}
.build());
};
// Redundant: an equivalent SaveRequested event is already outstanding.
if opts.max_wait_ms.is_none()
&& already_requested
&& (!immediate || immediate_already_requested)
{
return Ok(save_request_revision);
}
sender
.send(LifecycleEvent::SaveRequested { immediate })
.map(|()| save_request_revision)
// A send error means the receiver was dropped.
.map_err(|_| {
ActorRuntime::NotConfigured {
component: "lifecycle events".to_owned(),
}
.build()
})
}
pub(crate) async fn wait_for_save_request(&self, save_request_revision: u64) {
loop {
if self.0.save_completed_revision.load(Ordering::SeqCst) >= save_request_revision {
return;
}
self.0.save_completion.notified().await;
}
}
/// Whether a save request is currently outstanding.
pub(crate) fn save_requested(&self) -> bool {
    let flag = &self.0.save_requested;
    flag.load(Ordering::SeqCst)
}
/// Whether an outstanding save request asked for an immediate save.
pub(crate) fn save_requested_immediate(&self) -> bool {
    let flag = &self.0.save_requested_immediate;
    flag.load(Ordering::SeqCst)
}
pub(crate) fn save_deadline(&self, immediate: bool) -> StdInstant {
self.compute_save_deadline(immediate)
}
/// Computes when the next save should run.
///
/// Returns "now" if `immediate` is set or an immediate save is already
/// requested. Otherwise returns the earlier of:
/// - the throttle deadline (now + remaining save interval);
/// - any `max_wait_ms` deadline recorded by outstanding save requests.
pub(crate) fn compute_save_deadline(&self, immediate: bool) -> StdInstant {
    if immediate || self.save_requested_immediate() {
        return StdInstant::now();
    }
    let throttled = StdInstant::now() + self.compute_save_delay(None);
    let requested = *self.0.save_requested_within_deadline.lock();
    requested.map_or(throttled, |deadline| throttled.min(deadline))
}
/// Latest save-request revision handed out by `request_save_with_revision`.
pub(crate) fn save_request_revision(&self) -> u64 {
    let revision = &self.0.save_request_revision;
    revision.load(Ordering::SeqCst)
}
/// Applies a batch of state deltas and writes them to KV.
///
/// Delta handling:
/// - `StateDelta::ActorState` replaces the persisted actor state (the last one
///   in the batch wins). `current_state` is updated once the write succeeds.
/// - `StateDelta::ConnHibernation` becomes a put on the connection's key.
/// - `StateDelta::ConnHibernationRemoved` becomes a delete on that key.
///
/// The encoded actor record is written under `PERSIST_DATA_KEY` when the batch
/// changes the actor state *or* the record is already dirty (e.g. after
/// `set_input` / `set_scheduled_events`). This method cancels any delayed
/// background save and clears the dirty flag on success. Writing only the
/// connection keys would therefore silently drop those pending record changes.
///
/// Puts and deletes are issued in chunks of `APPLY_BATCH_CHUNK_SIZE`. On
/// success, the request identified by `save_request_revision` is marked
/// completed. Its outstanding flags are cleared if no newer request arrived.
///
/// An empty batch performs no KV write and still completes the request. It
/// leaves any delayed background save in place, so dirty record changes are
/// still written later.
///
/// # Errors
/// Returns an error if encoding the actor record or any KV batch fails. In
/// that case the request is not marked completed.
pub(crate) async fn apply_state_deltas(
    &self,
    deltas: Vec<StateDelta>,
    save_request_revision: u64,
) -> Result<()> {
    let delta_count = deltas.len();
    let delta_bytes: usize = deltas.iter().map(StateDelta::payload_len).sum();
    let current_revision = self.0.state_revision.load(Ordering::SeqCst);
    tracing::debug!(
        delta_count,
        delta_bytes,
        state_revision = current_revision,
        save_request_revision,
        "applying actor state deltas"
    );
    if deltas.is_empty() {
        // Nothing to write. The delayed background save (if any) is kept: it
        // is what will persist record changes that are still marked dirty.
        self.mark_save_request_completed(save_request_revision);
        self.finish_save_request(save_request_revision);
        tracing::debug!(
            delta_count,
            state_revision = current_revision,
            save_request_revision,
            "actor state deltas applied without kv write"
        );
        return Ok(());
    }
    // This batch supersedes the delayed background save. The dirty check
    // below folds that save's work (the full actor record) into this write.
    self.clear_pending_save();
    let (puts, deletes, next_state, revision, _write_guard) = {
        let _save_guard = self.0.save_guard.lock().await;
        let revision = self.0.state_revision.load(Ordering::SeqCst);
        let mut persisted = self.persisted();
        let mut next_state = None;
        let mut puts = Vec::new();
        let mut deletes = Vec::new();
        for delta in deltas {
            match delta {
                StateDelta::ActorState(bytes) => {
                    next_state = Some(bytes.clone());
                    persisted.state = bytes;
                }
                StateDelta::ConnHibernation { conn, bytes } => {
                    puts.push((make_connection_key(&conn), bytes));
                }
                StateDelta::ConnHibernationRemoved(conn) => {
                    deletes.push(make_connection_key(&conn));
                }
            }
        }
        // Write the actor record if this batch changed it, or if it already
        // has unsaved changes. Otherwise clearing the dirty flag below would
        // discard them.
        if next_state.is_some() || self.is_dirty() {
            let encoded =
                encode_persisted_actor(&persisted).context("encode persisted actor state")?;
            puts.push((PERSIST_DATA_KEY.to_vec(), encoded));
        }
        // Only publish the record back when this batch modified it.
        if next_state.is_some() {
            *self.0.persisted.write() = persisted;
        }
        // Taken while `save_guard` is held, so `wait_for_pending_writes`
        // sees this write once it acquires the guard.
        (puts, deletes, next_state, revision, self.begin_write())
    };
    // Send puts and deletes in lock-step chunks until both are exhausted.
    let mut put_chunks = puts.chunks(APPLY_BATCH_CHUNK_SIZE);
    let mut delete_chunks = deletes.chunks(APPLY_BATCH_CHUNK_SIZE);
    loop {
        let put_chunk = put_chunks.next().unwrap_or(&[]);
        let delete_chunk = delete_chunks.next().unwrap_or(&[]);
        if put_chunk.is_empty() && delete_chunk.is_empty() {
            break;
        }
        self.0
            .kv
            .apply_batch(put_chunk, delete_chunk)
            .await
            .context("persist actor state deltas to kv")?;
    }
    if let Some(state) = next_state {
        *self.0.current_state.write() = state;
    }
    *self.0.last_save_at.lock() = Some(StdInstant::now());
    // Only clear the dirty flag if nothing mutated the record during the write.
    if self.0.state_revision.load(Ordering::SeqCst) == revision {
        self.0.state_dirty.store(false, Ordering::SeqCst);
    }
    self.mark_save_request_completed(save_request_revision);
    self.finish_save_request(save_request_revision);
    tracing::debug!(
        delta_count,
        delta_bytes,
        state_revision = self.0.state_revision.load(Ordering::SeqCst),
        save_request_revision,
        "actor state deltas applied"
    );
    Ok(())
}
/// Waits until there is no tracked persist task and no in-flight state write.
///
/// Tracked persists (from `persist_now_tracked`) are awaited one at a time and
/// the loop restarts after each, since another may have been registered
/// meanwhile. The final check is made while holding `save_guard`. Writers call
/// `begin_write` inside that same guard, so any writer that has already passed
/// its dirty check is visible in `in_flight_state_writes`.
pub(crate) async fn wait_for_pending_writes(&self) {
loop {
if let Some(handle) = self.take_tracked_persist() {
// Join errors (panic / abort) are ignored; we only care that it ended.
let _ = handle.await;
continue;
}
let save_guard = self.0.save_guard.lock().await;
// A tracked persist may have been registered while we waited for the guard.
if self.has_tracked_persist() {
drop(save_guard);
continue;
}
if self.0.in_flight_state_writes.load(Ordering::SeqCst) == 0 {
return;
}
// Release the guard before waiting so in-flight writers can finish.
drop(save_guard);
self.wait_for_in_flight_writes().await;
}
}
pub(crate) async fn wait_for_pending_state_writes(&self) {
self.wait_for_pending_writes().await;
}
/// Marks an `on_state_change` callback as started and returns a guard that
/// marks it finished when dropped.
pub fn begin_on_state_change(&self) -> OnStateChangeGuard {
    let ctx = self.clone();
    OnStateChangeGuard::new(ctx)
}
/// Records the start of an `on_state_change` callback.
///
/// Bumps the in-flight count, holds a keep-awake reference, and resets the
/// sleep timer.
pub fn on_state_change_started(&self) {
    let inner = &self.0;
    inner.on_state_change_in_flight.fetch_add(1, Ordering::SeqCst);
    inner.sleep.work.keep_awake.increment();
    self.reset_sleep_timer();
}
/// Records the end of an `on_state_change` callback.
///
/// Decrements the in-flight count without underflowing, releases the
/// keep-awake reference, and resets the sleep timer. When the last callback
/// finishes, `on_state_change_idle` waiters are woken. An unmatched call only
/// logs a warning.
pub fn on_state_change_finished(&self) {
    let decremented = self.0.on_state_change_in_flight.fetch_update(
        Ordering::SeqCst,
        Ordering::SeqCst,
        |count| count.checked_sub(1),
    );
    let Ok(previous) = decremented else {
        tracing::warn!(
            actor_id = %self.actor_id(),
            "on_state_change finished without a matching start"
        );
        return;
    };
    self.0.sleep.work.keep_awake.decrement();
    if previous == 1 {
        // That was the last in-flight callback.
        self.0.on_state_change_idle.notify_waiters();
    }
    self.reset_sleep_timer();
}
/// Test helper: waits up to `timeout_duration` for all in-flight
/// `on_state_change` callbacks to finish.
///
/// Returns `true` if the in-flight count reached zero before the timeout.
#[cfg(test)]
// NOTE(review): presumably only some test builds call this, hence dead_code.
#[allow(dead_code)]
pub(crate) async fn wait_for_on_state_change_idle(&self, timeout_duration: Duration) -> bool {
if self.0.on_state_change_in_flight.load(Ordering::SeqCst) == 0 {
return true;
}
timeout(timeout_duration, async {
loop {
let idle = self.0.on_state_change_idle.notified();
tokio::pin!(idle);
// Register as a waiter *before* re-reading the count, so a
// `notify_waiters` in between is not lost.
idle.as_mut().enable();
if self.0.on_state_change_in_flight.load(Ordering::SeqCst) == 0 {
return;
}
idle.await;
}
})
.await
.is_ok()
}
pub fn persisted(&self) -> PersistedActor {
self.0.persisted.read().clone()
}
/// Replaces the in-memory record and current state with `persisted`.
///
/// This is treated as the clean, already-saved baseline:
/// - the dirty flag is cleared;
/// - outstanding save-request flags are cleared;
/// - an `InternalReplace` mutation is counted.
pub fn load_persisted(&self, persisted: PersistedActor) {
    let restored_state = persisted.state.clone();
    *self.0.persisted.write() = persisted;
    *self.0.current_state.write() = restored_state;
    self.0.state_dirty.store(false, Ordering::SeqCst);
    let latest_request = self.save_request_revision();
    self.finish_save_request(latest_request);
    self.0
        .metrics
        .inc_state_mutation(StateMutationReason::InternalReplace);
}
/// Sets the cached last-pushed alarm timestamp without touching KV.
pub(crate) fn load_last_pushed_alarm(&self, alarm_ts: Option<i64>) {
    let mut slot = self.0.last_pushed_alarm.write();
    *slot = alarm_ts;
}
/// Returns the cached last-pushed alarm timestamp.
pub(crate) fn last_pushed_alarm(&self) -> Option<i64> {
    let slot = self.0.last_pushed_alarm.read();
    *slot
}
/// Writes `alarm_ts` under `LAST_PUSHED_ALARM_KEY` and, once stored, updates
/// the cached value.
///
/// # Errors
/// Returns an error if encoding or the KV put fails; the cache is then left
/// unchanged.
pub(crate) async fn persist_last_pushed_alarm(&self, alarm_ts: Option<i64>) -> Result<()> {
    let payload = encode_last_pushed_alarm(alarm_ts).context("encode last pushed alarm")?;
    let kv = &self.0.kv;
    kv.put(LAST_PUSHED_ALARM_KEY, &payload)
        .await
        .context("persist last pushed alarm to kv")?;
    self.load_last_pushed_alarm(alarm_ts);
    Ok(())
}
pub(crate) fn set_initial_state(&self, state: Vec<u8>) {
*self.0.current_state.write() = state.clone();
self.0.persisted.write().state = state;
self.0.state_dirty.store(true, Ordering::SeqCst);
self.0.state_revision.fetch_add(1, Ordering::SeqCst);
}
/// Returns a copy of the persisted scheduled events.
pub fn scheduled_events(&self) -> Vec<PersistedScheduleEvent> {
    let persisted = self.0.persisted.read();
    persisted.scheduled_events.to_vec()
}
pub fn set_scheduled_events(&self, scheduled_events: Vec<PersistedScheduleEvent>) {
self.0.persisted.write().scheduled_events = scheduled_events;
self.0
.metrics
.inc_state_mutation(StateMutationReason::ScheduledEventsUpdate);
self.mark_dirty();
self.schedule_save(None);
}
/// Mutates the persisted scheduled events in place under the write lock.
///
/// Afterwards it counts a `ScheduledEventsUpdate` mutation, marks the record
/// dirty, and schedules a background save. Returns whatever `update` returns.
pub(crate) fn update_scheduled_events<R>(
    &self,
    update: impl FnOnce(&mut Vec<PersistedScheduleEvent>) -> R,
) -> R {
    // The write guard is a temporary of this statement, so it is released
    // before the follow-up bookkeeping below.
    let outcome = update(&mut self.0.persisted.write().scheduled_events);
    self.0
        .metrics
        .inc_state_mutation(StateMutationReason::ScheduledEventsUpdate);
    self.mark_dirty();
    self.schedule_save(None);
    outcome
}
/// Stores the actor input in the persisted record, marks it dirty, and
/// schedules a background save.
pub fn set_input(&self, input: Option<Vec<u8>>) {
    {
        let mut persisted = self.0.persisted.write();
        persisted.input = input;
    }
    self.0
        .metrics
        .inc_state_mutation(StateMutationReason::InputSet);
    self.mark_dirty();
    self.schedule_save(None);
}
/// Returns a copy of the persisted actor input, if any.
pub fn input(&self) -> Option<Vec<u8>> {
    self.0.persisted.read().input.as_ref().cloned()
}
/// Updates the persisted `has_initialized` flag.
///
/// When the value actually changes, it also counts the mutation, marks the
/// record dirty, and schedules a background save. Setting the same value is a
/// no-op.
pub fn set_has_initialized(&self, has_initialized: bool) {
    let previous = std::mem::replace(
        &mut self.0.persisted.write().has_initialized,
        has_initialized,
    );
    if previous == has_initialized {
        return;
    }
    self.0
        .metrics
        .inc_state_mutation(StateMutationReason::HasInitialized);
    self.mark_dirty();
    self.schedule_save(None);
}
/// Returns the persisted `has_initialized` flag.
pub fn has_initialized(&self) -> bool {
    let persisted = self.0.persisted.read();
    persisted.has_initialized
}
/// Starts a tracked, immediate persist labelled `shutdown_flush`.
pub fn flush_on_shutdown(&self) {
    const DESCRIPTION: &str = "shutdown_flush";
    self.persist_now_tracked(DESCRIPTION);
}
/// Registers a hook invoked synchronously with the options of every save request.
pub fn on_request_save(&self, hook: Box<dyn Fn(RequestSaveOpts) + Send + Sync>) {
    let shared = Arc::from(hook);
    self.0.request_save_hooks.write().push(shared);
}
/// Whether the persisted record has changes not yet written to KV.
fn is_dirty(&self) -> bool {
    let flag = &self.0.state_dirty;
    flag.load(Ordering::SeqCst)
}
/// Flags the persisted record as dirty and bumps `state_revision`.
///
/// In-flight writes compare against the revision and keep the dirty flag set
/// if it moved.
fn mark_dirty(&self) {
    let inner = &self.0;
    inner.state_dirty.store(true, Ordering::SeqCst);
    inner.state_revision.fetch_add(1, Ordering::SeqCst);
}
/// Returns a clone of the lifecycle event sender, if one is configured.
fn lifecycle_event_sender(&self) -> Option<mpsc::UnboundedSender<LifecycleEvent>> {
    let slot = self.0.lifecycle_events.read();
    (*slot).clone()
}
/// How long to wait before the next save, given the time since the last save,
/// the configured `state_save_interval`, and an optional `max_wait` cap.
///
/// With no previous save, the time since the last save counts as zero.
fn compute_save_delay(&self, max_wait: Option<Duration>) -> Duration {
    let last_save = *self.0.last_save_at.lock();
    let since_last_save = match last_save {
        Some(at) => at.elapsed(),
        None => Duration::ZERO,
    };
    throttled_save_delay(self.0.state_save_interval, since_last_save, max_wait)
}
/// Schedules a delayed background `persist_if_dirty` if the record is dirty.
///
/// The delay is `compute_save_delay(max_wait)`. Only the earliest scheduled
/// save is kept:
/// - if an existing pending save is due no later than the new one, nothing
///   happens;
/// - otherwise the new task is spawned and the existing one is aborted and
///   replaced.
///
/// Without a Tokio runtime (non-`wasm-runtime` builds) nothing can be spawned.
/// The call is then a no-op, and any existing pending save is left untouched
/// rather than aborted with nothing to replace it.
fn schedule_save(&self, max_wait: Option<Duration>) {
    if !self.is_dirty() {
        return;
    }
    let delay = self.compute_save_delay(max_wait);
    let scheduled_at = StdInstant::now() + delay;
    let mut pending_save = self.0.pending_save.lock();
    if let Some(existing) = pending_save.as_ref() {
        if existing.scheduled_at <= scheduled_at {
            // An earlier (or equal) save is queued; it persists whatever is
            // dirty when it runs.
            return;
        }
    }
    let state = self.clone();
    let task = async move {
        if !delay.is_zero() {
            sleep(delay).await;
        }
        state.take_pending_save();
        if let Err(error) = state.persist_if_dirty().await {
            tracing::error!(?error, "failed to persist actor state");
        }
    }
    .in_current_span();
    #[cfg(not(feature = "wasm-runtime"))]
    let handle = {
        let Ok(tokio_handle) = Handle::try_current() else {
            // No runtime: keep the existing pending save (if any) alive.
            return;
        };
        tokio_handle.spawn(task)
    };
    #[cfg(feature = "wasm-runtime")]
    let handle = RuntimeSpawner::spawn(task);
    // Abort the superseded save only once its replacement has been spawned.
    if let Some(superseded) = pending_save.replace(PendingSave {
        scheduled_at,
        handle,
    }) {
        superseded.handle.abort();
    }
}
fn clear_pending_save(&self) {
if let Some(pending_save) = self.take_pending_save() {
pending_save.handle.abort();
}
}
/// Spawns a tracked, immediate `persist_state`; `description` labels its logs.
///
/// Any delayed background save is cancelled first. The new task awaits the
/// previously tracked persist (if any) before saving, so tracked persists run
/// in order. Its handle replaces the previous one in `tracked_persist`, where
/// `wait_for_pending_writes` picks it up. Failures are logged, not returned.
///
/// Without a Tokio runtime (non-`wasm-runtime` builds) nothing is spawned.
/// The previously tracked handle is then left in place, so
/// `wait_for_pending_writes` still waits for it.
pub(crate) fn persist_now_tracked(&self, description: &'static str) {
    self.clear_pending_save();
    let state = self.clone();
    let mut tracked_persist = self.0.tracked_persist.lock();
    // Resolve the spawner *before* taking the previous handle. Bailing out
    // after `take()` would detach that persist and stop it being tracked.
    #[cfg(not(feature = "wasm-runtime"))]
    let Ok(tokio_handle) = Handle::try_current() else {
        tracing::warn!(
            description,
            "skipping tracked actor state persistence without runtime"
        );
        return;
    };
    let previous = tracked_persist.take();
    let task = async move {
        if let Some(previous) = previous {
            let _ = previous.await;
        }
        if let Err(error) = state.persist_state(SaveStateOpts { immediate: true }).await {
            tracing::error!(?error, description, "failed to persist actor state");
        }
    }
    .in_current_span();
    #[cfg(not(feature = "wasm-runtime"))]
    let handle = tokio_handle.spawn(task);
    #[cfg(feature = "wasm-runtime")]
    let handle = RuntimeSpawner::spawn(task);
    *tracked_persist = Some(handle);
}
/// Removes and returns the scheduled background save without aborting it.
fn take_pending_save(&self) -> Option<PendingSave> {
    let mut slot = self.0.pending_save.lock();
    slot.take()
}
/// Removes and returns the tracked persist handle, if any.
fn take_tracked_persist(&self) -> Option<JoinHandle<()>> {
    let mut slot = self.0.tracked_persist.lock();
    slot.take()
}
/// Whether a tracked persist handle is currently registered.
fn has_tracked_persist(&self) -> bool {
    matches!(*self.0.tracked_persist.lock(), Some(_))
}
/// Test hook: whether a tracked persist handle is currently registered.
#[cfg(test)]
pub(crate) fn tracked_persist_pending(&self) -> bool {
    self.0.tracked_persist.lock().is_some()
}
/// Writes the full encoded actor record under `PERSIST_DATA_KEY` if dirty.
///
/// The dirty flag is checked once cheaply and again under `save_guard`, since
/// another writer may have saved while we waited. Encoding happens under the
/// guard. The KV put does not, but it is counted in `in_flight_state_writes`
/// through the `begin_write` guard. The dirty flag is cleared only if no
/// mutation bumped `state_revision` while the write was in flight.
///
/// # Errors
/// Returns an error if encoding or the KV put fails; the record stays dirty.
async fn persist_if_dirty(&self) -> Result<()> {
if !self.is_dirty() {
return Ok(());
}
let (revision, encoded, _write_guard) = {
let _save_guard = self.0.save_guard.lock().await;
// Re-check: a concurrent writer may have saved while we waited for the guard.
if !self.is_dirty() {
return Ok(());
}
// Snapshot the revision this encoding corresponds to.
let revision = self.0.state_revision.load(Ordering::SeqCst);
let persisted = self.persisted();
let encoded =
encode_persisted_actor(&persisted).context("encode persisted actor state")?;
// `begin_write` is taken before `_save_guard` is released at the end of this block.
(revision, encoded, self.begin_write())
};
self.0
.kv
.put(PERSIST_DATA_KEY, &encoded)
.await
.context("persist actor state to kv")?;
*self.0.last_save_at.lock() = Some(StdInstant::now());
// A mutation during the put bumped the revision; leave the record dirty.
if self.0.state_revision.load(Ordering::SeqCst) == revision {
self.0.state_dirty.store(false, Ordering::SeqCst);
}
Ok(())
}
/// Counts a new in-flight state write and returns the guard that uncounts it on drop.
fn begin_write(&self) -> InFlightWrite {
    let ctx = self.clone();
    ctx.0.in_flight_state_writes.fetch_add(1, Ordering::SeqCst);
    InFlightWrite { ctx }
}
async fn wait_for_in_flight_writes(&self) {
loop {
if self.0.in_flight_state_writes.load(Ordering::SeqCst) == 0 {
return;
}
self.0.state_write_completion.notified().await;
}
}
/// Clears the outstanding save-request flags and deadline.
///
/// Only acts if `save_request_revision` is still the latest revision handed
/// out. If the revision has moved on, a newer request is outstanding and its
/// flags are kept.
fn finish_save_request(&self, save_request_revision: u64) {
    let inner = &self.0;
    if inner.save_request_revision.load(Ordering::SeqCst) != save_request_revision {
        return;
    }
    inner.save_requested.store(false, Ordering::SeqCst);
    inner
        .save_requested_immediate
        .store(false, Ordering::SeqCst);
    *inner.save_requested_within_deadline.lock() = None;
}
/// Raises `save_completed_revision` to at least `save_request_revision` and
/// wakes every current `save_completion` waiter.
fn mark_save_request_completed(&self, save_request_revision: u64) {
    let inner = &self.0;
    inner
        .save_completed_revision
        .fetch_max(save_request_revision, Ordering::SeqCst);
    inner.save_completion.notify_waiters();
}
/// Invokes every registered `on_request_save` hook with `opts`.
fn notify_request_save_hooks(&self, opts: RequestSaveOpts) {
    // Snapshot first, as a separate statement, so the read lock is released
    // before any hook runs. A hook could otherwise deadlock by registering
    // another hook.
    let snapshot = self.0.request_save_hooks.read().clone();
    snapshot.iter().for_each(|hook| hook(opts));
}
}
/// Guard for one in-flight state write. It is created by
/// `ActorContext::begin_write` and decrements `in_flight_state_writes` when
/// dropped.
struct InFlightWrite {
ctx: ActorContext,
}
impl Drop for InFlightWrite {
    fn drop(&mut self) {
        let inner = &self.ctx.0;
        let before = inner.in_flight_state_writes.fetch_sub(1, Ordering::SeqCst);
        if before != 1 {
            return;
        }
        // Last writer out: wake everyone currently waiting. `notify_one` also
        // covers a waiter that has not registered yet (presumably via tokio
        // `Notify`'s stored permit).
        inner.state_write_completion.notify_waiters();
        inner.state_write_completion.notify_one();
    }
}
/// Remaining wait before the next save.
///
/// Computed as `save_interval` minus `time_since_last_save`, floored at zero.
/// It is further capped by `max_wait` when one is given.
fn throttled_save_delay(
    save_interval: Duration,
    time_since_last_save: Duration,
    max_wait: Option<Duration>,
) -> Duration {
    let remaining = save_interval.saturating_sub(time_since_last_save);
    max_wait.map_or(remaining, |cap| remaining.min(cap))
}
// Unit tests for this module are loaded from `../../tests/state.rs` via
// `#[path]` and compiled only in test builds.
#[cfg(test)]
#[path = "../../tests/state.rs"]
mod tests;