use std::sync::Arc;
use chrono::Utc;
use dashmap::DashMap;
use sea_orm::{sea_query::OnConflict, ActiveValue, DatabaseConnection, EntityTrait};
use crate::entity::{ActiveModel, Column, Entity};
use crate::error::ProjectionError;
use crate::key::ProjectionKey;
use crate::projection::Projection;
/// Table of per-key async mutexes, keyed by the raw projection-key string.
///
/// Entries are created lazily; nothing in this file ever removes one.
pub(crate) type KeyLocks = DashMap<String, Arc<tokio::sync::Mutex<()>>>;

/// Runtime that folds events of a [`Projection`] into JSON snapshot rows and
/// broadcasts the resulting changes.
pub struct ProjectionRuntime<P: Projection> {
    /// Connection used to read, upsert and delete snapshot rows.
    pub(crate) db: DatabaseConnection,
    /// Broadcaster used to publish per-key change notifications.
    pub(crate) broadcaster: Arc<ferro_broadcast::Broadcaster>,
    /// Projection whose `key`/`apply` logic drives every update.
    pub(crate) projection: P,
    /// Per-key mutexes that serialize read-fold-write cycles within this runtime.
    pub(crate) locks: KeyLocks,
}
impl<P: Projection> ProjectionRuntime<P> {
pub fn new(
db: DatabaseConnection,
broadcaster: Arc<ferro_broadcast::Broadcaster>,
projection: P,
) -> Self {
Self {
db,
broadcaster,
projection,
locks: DashMap::new(),
}
}
pub async fn read(&self, key: &ProjectionKey) -> Result<Option<P::State>, ProjectionError> {
let row = Entity::find_by_id((P::NAME.to_string(), key.0.clone()))
.one(&self.db)
.await?;
match row {
None => Ok(None),
Some(model) => {
let state: P::State = serde_json::from_value(model.state)?;
Ok(Some(state))
}
}
}
pub async fn read_required(&self, key: &ProjectionKey) -> Result<P::State, ProjectionError> {
self.read(key)
.await?
.ok_or_else(|| ProjectionError::StateNotFound {
name: P::NAME,
key: key.0.clone(),
})
}
pub async fn apply_event(&self, event: &P::Event) -> Result<(), ProjectionError> {
let key = self.projection.key(event);
let lock_arc: Arc<tokio::sync::Mutex<()>> = {
self.locks
.entry(key.0.clone())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
}; let _guard = lock_arc.lock().await;
let existing = Entity::find_by_id((P::NAME.to_string(), key.0.clone()))
.one(&self.db)
.await?;
let (mut state, new_version) = match existing {
Some(model) => {
let s: P::State = serde_json::from_value(model.state)?;
(s, model.version + 1)
}
None => (P::State::default(), 1_i64),
};
let delta = self.projection.apply(&mut state, event);
let state_json = serde_json::to_value(&state)?;
let now = Utc::now().naive_utc();
let am = ActiveModel {
projection_name: ActiveValue::Set(P::NAME.to_string()),
key: ActiveValue::Set(key.0.clone()),
state: ActiveValue::Set(state_json),
version: ActiveValue::Set(new_version),
updated_at: ActiveValue::Set(now),
};
Entity::insert(am)
.on_conflict(
OnConflict::columns([Column::ProjectionName, Column::Key])
.update_columns([Column::State, Column::Version, Column::UpdatedAt])
.to_owned(),
)
.exec(&self.db)
.await?;
let channel_name = format!("projection.{}.{}", P::NAME, key.as_str());
let event_name = self.projection.broadcast_event_name();
let send_result = ferro_broadcast::Broadcast::new(self.broadcaster.clone())
.channel(channel_name.clone())
.event(event_name)
.data(delta)
.send()
.await;
if let Err(e) = send_result {
tracing::warn!(
error = %e,
channel = %channel_name,
"projection broadcast failed; snapshot persisted"
);
return Err(ProjectionError::from(e));
}
Ok(())
}
pub fn register(self: Arc<Self>) {
let listener = crate::listener::ProjectionListener {
runtime: self.clone(),
};
ferro_events::global_dispatcher().listen::<P::Event, _>(listener);
}
pub async fn rebuild<I>(
&self,
key: &ProjectionKey,
events: I,
) -> Result<P::State, ProjectionError>
where
I: IntoIterator<Item = P::Event>,
{
let lock_arc: Arc<tokio::sync::Mutex<()>> = {
self.locks
.entry(key.0.clone())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
}; let _guard = lock_arc.lock().await;
Entity::delete_by_id((P::NAME.to_string(), key.0.clone()))
.exec(&self.db)
.await?;
let mut state = P::State::default();
let mut count: i64 = 0;
for event in events {
let _delta = self.projection.apply(&mut state, &event);
count += 1;
}
if count == 0 {
return Ok(state);
}
let state_json = serde_json::to_value(&state)?;
let now = Utc::now().naive_utc();
let am = ActiveModel {
projection_name: ActiveValue::Set(P::NAME.to_string()),
key: ActiveValue::Set(key.0.clone()),
state: ActiveValue::Set(state_json),
version: ActiveValue::Set(count),
updated_at: ActiveValue::Set(now),
};
Entity::insert(am).exec(&self.db).await?;
let channel_name = format!("projection.{}.{}", P::NAME, key.as_str());
let send_result = ferro_broadcast::Broadcast::new(self.broadcaster.clone())
.channel(channel_name.clone())
.event("rebuild")
.data(state.clone())
.send()
.await;
if let Err(e) = send_result {
tracing::warn!(
error = %e,
channel = %channel_name,
"projection rebuild broadcast failed; snapshot persisted"
);
return Err(ProjectionError::from(e));
}
Ok(state)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::Database;
    use sea_orm_migration::MigratorTrait;
    use serde::{Deserialize, Serialize};

    /// Test event: adds `delta` to the running total.
    #[derive(Clone, Serialize, Deserialize)]
    struct CounterEvent {
        delta: i32,
    }

    impl ferro_events::Event for CounterEvent {
        fn name(&self) -> &'static str {
            "CounterEvent"
        }
    }

    #[derive(Clone, Default, Serialize, Deserialize, PartialEq, Debug)]
    struct CounterState {
        total: i64,
    }

    #[derive(Clone, Serialize, Debug, PartialEq)]
    struct CounterDelta {
        new_total: i64,
    }

    /// Routes every event to the single key `"default-key"`.
    struct CounterProjection;

    impl Projection for CounterProjection {
        type Event = CounterEvent;
        type State = CounterState;
        type Delta = CounterDelta;
        const NAME: &'static str = "test.counter";

        fn key(&self, _e: &Self::Event) -> ProjectionKey {
            ProjectionKey::new("default-key")
        }

        fn apply(&self, state: &mut Self::State, event: &Self::Event) -> Self::Delta {
            state.total += event.delta as i64;
            CounterDelta {
                new_total: state.total,
            }
        }
    }

    /// Routes each event to `"k-{delta}"`, so distinct deltas hit distinct keys.
    struct KeyedCounterProjection;

    impl Projection for KeyedCounterProjection {
        type Event = CounterEvent;
        type State = CounterState;
        type Delta = CounterDelta;
        const NAME: &'static str = "test.keyed_counter";

        fn key(&self, event: &Self::Event) -> ProjectionKey {
            ProjectionKey::new(format!("k-{}", event.delta))
        }

        fn apply(&self, state: &mut Self::State, event: &Self::Event) -> Self::Delta {
            state.total += event.delta as i64;
            CounterDelta {
                new_total: state.total,
            }
        }
    }

    struct TestMigrator;

    #[async_trait::async_trait]
    impl MigratorTrait for TestMigrator {
        fn migrations() -> Vec<Box<dyn sea_orm_migration::MigrationTrait>> {
            vec![Box::new(crate::migration::Migration)]
        }
    }

    /// Builds a runtime over a fresh, migrated in-memory SQLite database.
    async fn fresh_runtime<P: Projection>(projection: P) -> ProjectionRuntime<P> {
        let conn = Database::connect("sqlite::memory:").await.expect("connect");
        TestMigrator::up(&conn, None).await.expect("migrate");
        let broadcaster = Arc::new(ferro_broadcast::Broadcaster::new());
        ProjectionRuntime::new(conn, broadcaster, projection)
    }

    /// Fetches the raw snapshot row for `(P::NAME, key)`, panicking if it is missing.
    async fn stored_row<P: Projection>(
        rt: &ProjectionRuntime<P>,
        key: &str,
    ) -> <Entity as EntityTrait>::Model {
        Entity::find_by_id((P::NAME.to_string(), key.to_string()))
            .one(&rt.db)
            .await
            .expect("query")
            .expect("row")
    }

    #[tokio::test]
    async fn new_returns_owned_runtime_arc_is_send_sync() {
        let rt = fresh_runtime(CounterProjection).await;
        let arc: Arc<ProjectionRuntime<CounterProjection>> = Arc::new(rt);
        fn assert_send_sync<T: Send + Sync>(_: &T) {}
        assert_send_sync(&arc);
    }

    #[tokio::test]
    async fn apply_event_initial_writes_version_1() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 5 })
            .await
            .expect("apply");
        let key = ProjectionKey::new("default-key");
        let state = rt.read(&key).await.expect("read").expect("state");
        assert_eq!(state.total, 5);
        assert_eq!(stored_row(&rt, "default-key").await.version, 1);
    }

    #[tokio::test]
    async fn apply_event_second_call_folds_and_bumps_version() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 5 })
            .await
            .expect("first apply");
        rt.apply_event(&CounterEvent { delta: 3 })
            .await
            .expect("second apply");
        let key = ProjectionKey::new("default-key");
        let state = rt.read(&key).await.expect("read").expect("state");
        assert_eq!(state.total, 8);
        assert_eq!(stored_row(&rt, "default-key").await.version, 2);
    }

    #[tokio::test]
    async fn apply_event_new_key_initializes_from_default() {
        let rt = fresh_runtime(KeyedCounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 7 })
            .await
            .expect("apply key 7");
        rt.apply_event(&CounterEvent { delta: 9 })
            .await
            .expect("apply key 9");
        let s7 = rt
            .read(&ProjectionKey::new("k-7"))
            .await
            .expect("read 7")
            .expect("state 7");
        let s9 = rt
            .read(&ProjectionKey::new("k-9"))
            .await
            .expect("read 9")
            .expect("state 9");
        assert_eq!(s7.total, 7);
        assert_eq!(s9.total, 9);
    }

    #[tokio::test]
    async fn read_returns_none_for_absent_key() {
        let rt = fresh_runtime(CounterProjection).await;
        let key = ProjectionKey::new("absent");
        let r = rt.read(&key).await.expect("read");
        assert!(r.is_none());
    }

    #[tokio::test]
    async fn read_returns_some_after_apply() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("apply");
        let r = rt
            .read(&ProjectionKey::new("default-key"))
            .await
            .expect("read");
        assert!(r.is_some());
    }

    #[tokio::test]
    async fn read_required_returns_state_not_found_for_absent() {
        let rt = fresh_runtime(CounterProjection).await;
        let key = ProjectionKey::new("absent");
        let err = rt.read_required(&key).await.expect_err("should err");
        match err {
            ProjectionError::StateNotFound { name, key: k } => {
                assert_eq!(name, CounterProjection::NAME);
                assert_eq!(k, "absent");
            }
            other => panic!("expected StateNotFound, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn version_increments_per_apply_same_key() {
        let rt = fresh_runtime(CounterProjection).await;
        for _ in 0..5 {
            rt.apply_event(&CounterEvent { delta: 1 })
                .await
                .expect("apply");
        }
        assert_eq!(stored_row(&rt, "default-key").await.version, 5);
    }

    #[tokio::test]
    async fn updated_at_advances_per_apply() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("first");
        let t1 = stored_row(&rt, "default-key").await.updated_at;
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("second");
        let row2 = stored_row(&rt, "default-key").await;
        assert!(row2.updated_at > t1, "updated_at must advance");
    }

    #[tokio::test]
    async fn cross_key_apply_does_not_share_lock() {
        let rt = fresh_runtime(KeyedCounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("k-1");
        rt.apply_event(&CounterEvent { delta: 2 })
            .await
            .expect("k-2");
        rt.apply_event(&CounterEvent { delta: 3 })
            .await
            .expect("k-3");
        assert_eq!(rt.locks.len(), 3);
    }

    #[tokio::test]
    async fn rebuild_three_events_equals_three_sequential_applies() {
        let rt_a = fresh_runtime(CounterProjection).await;
        for d in [3, 5, 7] {
            rt_a.apply_event(&CounterEvent { delta: d })
                .await
                .expect("apply");
        }
        let state_a = rt_a
            .read(&ProjectionKey::new("default-key"))
            .await
            .expect("read a")
            .expect("state a");
        let rt_b = fresh_runtime(CounterProjection).await;
        let events: Vec<CounterEvent> = vec![
            CounterEvent { delta: 3 },
            CounterEvent { delta: 5 },
            CounterEvent { delta: 7 },
        ];
        let state_b = rt_b
            .rebuild(&ProjectionKey::new("default-key"), events)
            .await
            .expect("rebuild");
        assert_eq!(state_a, state_b);
        assert_eq!(state_a.total, 15);
    }

    #[tokio::test]
    async fn rebuild_replaces_existing_row_and_resets_version() {
        let rt = fresh_runtime(CounterProjection).await;
        for _ in 0..4 {
            rt.apply_event(&CounterEvent { delta: 10 })
                .await
                .expect("seed");
        }
        assert_eq!(stored_row(&rt, "default-key").await.version, 4);
        let rebuilt = rt
            .rebuild(
                &ProjectionKey::new("default-key"),
                vec![CounterEvent { delta: 2 }, CounterEvent { delta: 3 }],
            )
            .await
            .expect("rebuild");
        assert_eq!(rebuilt.total, 5);
        // Version is the number of rebuilt events, not a continuation of the old one.
        assert_eq!(stored_row(&rt, "default-key").await.version, 2);
        let persisted = rt
            .read(&ProjectionKey::new("default-key"))
            .await
            .expect("read")
            .expect("state");
        assert_eq!(persisted, rebuilt);
    }

    #[tokio::test]
    async fn rebuild_empty_deletes_row_and_returns_default() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 7 })
            .await
            .expect("seed");
        let pre = rt
            .read(&ProjectionKey::new("default-key"))
            .await
            .expect("read pre")
            .expect("state pre");
        assert_eq!(pre.total, 7);
        let after = rt
            .rebuild(
                &ProjectionKey::new("default-key"),
                Vec::<CounterEvent>::new(),
            )
            .await
            .expect("rebuild empty");
        assert_eq!(after.total, 0);
        let post = rt
            .read(&ProjectionKey::new("default-key"))
            .await
            .expect("read post");
        assert!(post.is_none(), "row should be wiped");
    }
}