1use std::sync::Arc;
25
26use chrono::Utc;
27use dashmap::DashMap;
28use sea_orm::{sea_query::OnConflict, ActiveValue, DatabaseConnection, EntityTrait};
29
30use crate::entity::{ActiveModel, Column, Entity};
31use crate::error::ProjectionError;
32use crate::key::ProjectionKey;
33use crate::projection::Projection;
34
35pub struct ProjectionRuntime<P: Projection> {
38 pub(crate) db: DatabaseConnection,
39 pub(crate) broadcaster: Arc<ferro_broadcast::Broadcaster>,
40 pub(crate) projection: P,
41 pub(crate) locks: DashMap<String, Arc<tokio::sync::Mutex<()>>>,
42}
43
44impl<P: Projection> ProjectionRuntime<P> {
45 pub fn new(
49 db: DatabaseConnection,
50 broadcaster: Arc<ferro_broadcast::Broadcaster>,
51 projection: P,
52 ) -> Self {
53 Self {
54 db,
55 broadcaster,
56 projection,
57 locks: DashMap::new(),
58 }
59 }
60
61 pub async fn read(&self, key: &ProjectionKey) -> Result<Option<P::State>, ProjectionError> {
65 let row = Entity::find_by_id((P::NAME.to_string(), key.0.clone()))
66 .one(&self.db)
67 .await?;
68 match row {
69 None => Ok(None),
70 Some(model) => {
71 let state: P::State = serde_json::from_value(model.state)?;
72 Ok(Some(state))
73 }
74 }
75 }
76
77 pub async fn read_required(&self, key: &ProjectionKey) -> Result<P::State, ProjectionError> {
81 self.read(key)
82 .await?
83 .ok_or_else(|| ProjectionError::StateNotFound {
84 name: P::NAME,
85 key: key.0.clone(),
86 })
87 }
88
89 pub async fn apply_event(&self, event: &P::Event) -> Result<(), ProjectionError> {
93 let key = self.projection.key(event);
95
96 let lock_arc: Arc<tokio::sync::Mutex<()>> = {
101 self.locks
102 .entry(key.0.clone())
103 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
104 .clone()
105 }; let _guard = lock_arc.lock().await;
107
108 let existing = Entity::find_by_id((P::NAME.to_string(), key.0.clone()))
110 .one(&self.db)
111 .await?;
112
113 let (mut state, new_version) = match existing {
114 Some(model) => {
115 let s: P::State = serde_json::from_value(model.state)?;
116 (s, model.version + 1)
117 }
118 None => (P::State::default(), 1_i64),
119 };
120
121 let delta = self.projection.apply(&mut state, event);
123
124 let state_json = serde_json::to_value(&state)?;
126 let now = Utc::now().naive_utc();
127 let am = ActiveModel {
128 projection_name: ActiveValue::Set(P::NAME.to_string()),
129 key: ActiveValue::Set(key.0.clone()),
130 state: ActiveValue::Set(state_json),
131 version: ActiveValue::Set(new_version),
132 updated_at: ActiveValue::Set(now),
133 };
134
135 Entity::insert(am)
136 .on_conflict(
137 OnConflict::columns([Column::ProjectionName, Column::Key])
138 .update_columns([Column::State, Column::Version, Column::UpdatedAt])
139 .to_owned(),
140 )
141 .exec(&self.db)
142 .await?;
143
144 let channel_name = format!("projection.{}.{}", P::NAME, key.as_str());
146 let event_name = self.projection.broadcast_event_name();
147 let send_result = ferro_broadcast::Broadcast::new(self.broadcaster.clone())
148 .channel(channel_name.clone())
149 .event(event_name)
150 .data(delta)
151 .send()
152 .await;
153
154 if let Err(e) = send_result {
155 tracing::warn!(
156 error = %e,
157 channel = %channel_name,
158 "projection broadcast failed; snapshot persisted"
159 );
160 return Err(ProjectionError::from(e));
161 }
162
163 Ok(())
165 }
166
167 pub fn register(self: Arc<Self>) {
176 let listener = crate::listener::ProjectionListener {
177 runtime: self.clone(),
178 };
179 ferro_events::global_dispatcher().listen::<P::Event, _>(listener);
180 }
181
182 pub async fn rebuild<I>(
203 &self,
204 key: &ProjectionKey,
205 events: I,
206 ) -> Result<P::State, ProjectionError>
207 where
208 I: IntoIterator<Item = P::Event>,
209 {
210 let lock_arc: Arc<tokio::sync::Mutex<()>> = {
212 self.locks
213 .entry(key.0.clone())
214 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
215 .clone()
216 }; let _guard = lock_arc.lock().await;
218
219 Entity::delete_by_id((P::NAME.to_string(), key.0.clone()))
221 .exec(&self.db)
222 .await?;
223
224 let mut state = P::State::default();
226 let mut count: i64 = 0;
227 for event in events {
228 let _delta = self.projection.apply(&mut state, &event);
229 count += 1;
230 }
231
232 if count == 0 {
234 return Ok(state);
235 }
236
237 let state_json = serde_json::to_value(&state)?;
239 let now = Utc::now().naive_utc();
240 let am = ActiveModel {
241 projection_name: ActiveValue::Set(P::NAME.to_string()),
242 key: ActiveValue::Set(key.0.clone()),
243 state: ActiveValue::Set(state_json),
244 version: ActiveValue::Set(count),
245 updated_at: ActiveValue::Set(now),
246 };
247 Entity::insert(am).exec(&self.db).await?;
248
249 let channel_name = format!("projection.{}.{}", P::NAME, key.as_str());
251 let send_result = ferro_broadcast::Broadcast::new(self.broadcaster.clone())
252 .channel(channel_name.clone())
253 .event("rebuild")
254 .data(state.clone())
255 .send()
256 .await;
257
258 if let Err(e) = send_result {
259 tracing::warn!(
260 error = %e,
261 channel = %channel_name,
262 "projection rebuild broadcast failed; snapshot persisted"
263 );
264 return Err(ProjectionError::from(e));
265 }
266
267 Ok(state)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::Database;
    use sea_orm_migration::MigratorTrait;
    use serde::{Deserialize, Serialize};

    /// Key that `CounterProjection` assigns to every event.
    const DEFAULT_KEY: &str = "default-key";

    #[derive(Clone, Serialize, Deserialize)]
    struct CounterEvent {
        delta: i32,
    }

    impl ferro_events::Event for CounterEvent {
        fn name(&self) -> &'static str {
            "CounterEvent"
        }
    }

    #[derive(Clone, Default, Serialize, Deserialize, PartialEq, Debug)]
    struct CounterState {
        total: i64,
    }

    #[derive(Clone, Serialize, Debug, PartialEq)]
    struct CounterDelta {
        new_total: i64,
    }

    /// Sums every event into a single key.
    struct CounterProjection;

    impl Projection for CounterProjection {
        type Event = CounterEvent;
        type State = CounterState;
        type Delta = CounterDelta;
        const NAME: &'static str = "test.counter";

        fn key(&self, _e: &Self::Event) -> ProjectionKey {
            ProjectionKey::new(DEFAULT_KEY)
        }

        fn apply(&self, state: &mut Self::State, event: &Self::Event) -> Self::Delta {
            state.total += event.delta as i64;
            CounterDelta {
                new_total: state.total,
            }
        }
    }

    /// Routes each event to key `k-<delta>`, so distinct deltas hit distinct keys.
    struct KeyedCounterProjection;

    impl Projection for KeyedCounterProjection {
        type Event = CounterEvent;
        type State = CounterState;
        type Delta = CounterDelta;
        const NAME: &'static str = "test.keyed_counter";

        fn key(&self, event: &Self::Event) -> ProjectionKey {
            ProjectionKey::new(format!("k-{}", event.delta))
        }

        fn apply(&self, state: &mut Self::State, event: &Self::Event) -> Self::Delta {
            state.total += event.delta as i64;
            CounterDelta {
                new_total: state.total,
            }
        }
    }

    struct TestMigrator;

    #[async_trait::async_trait]
    impl MigratorTrait for TestMigrator {
        fn migrations() -> Vec<Box<dyn sea_orm_migration::MigrationTrait>> {
            vec![Box::new(crate::migration::Migration)]
        }
    }

    /// Builds a runtime over a freshly migrated in-memory SQLite database.
    async fn fresh_runtime<P: Projection>(projection: P) -> ProjectionRuntime<P> {
        let conn = Database::connect("sqlite::memory:").await.expect("connect");
        TestMigrator::up(&conn, None).await.expect("migrate");
        let broadcaster = Arc::new(ferro_broadcast::Broadcaster::new());
        ProjectionRuntime::new(conn, broadcaster, projection)
    }

    /// Fetches the raw snapshot row for `key`, panicking if it does not exist.
    async fn fetch_row<P: Projection>(
        rt: &ProjectionRuntime<P>,
        key: &str,
    ) -> <Entity as EntityTrait>::Model {
        Entity::find_by_id((P::NAME.to_string(), key.to_string()))
            .one(&rt.db)
            .await
            .expect("query")
            .expect("row")
    }

    /// The runtime can be shared across tasks behind an `Arc`.
    #[tokio::test]
    async fn new_returns_owned_runtime_arc_is_send_sync() {
        let rt = fresh_runtime(CounterProjection).await;
        let arc: Arc<ProjectionRuntime<CounterProjection>> = Arc::new(rt);
        fn assert_send_sync<T: Send + Sync>(_: &T) {}
        assert_send_sync(&arc);
    }

    /// The first event for a key stores its state at version 1.
    #[tokio::test]
    async fn apply_event_initial_writes_version_1() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 5 })
            .await
            .expect("apply");

        let key = ProjectionKey::new(DEFAULT_KEY);
        let state = rt.read(&key).await.expect("read").expect("state");
        assert_eq!(state.total, 5);

        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 1);
    }

    /// A second event folds into the stored state and bumps the version.
    #[tokio::test]
    async fn apply_event_second_call_folds_and_bumps_version() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 5 })
            .await
            .expect("first apply");
        rt.apply_event(&CounterEvent { delta: 3 })
            .await
            .expect("second apply");

        let key = ProjectionKey::new(DEFAULT_KEY);
        let state = rt.read(&key).await.expect("read").expect("state");
        assert_eq!(state.total, 8);

        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 2);
    }

    /// Each new key starts from `State::default()` independently.
    #[tokio::test]
    async fn apply_event_new_key_initializes_from_default() {
        let rt = fresh_runtime(KeyedCounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 7 })
            .await
            .expect("apply key 7");
        rt.apply_event(&CounterEvent { delta: 9 })
            .await
            .expect("apply key 9");

        let s7 = rt
            .read(&ProjectionKey::new("k-7"))
            .await
            .expect("read 7")
            .expect("state 7");
        let s9 = rt
            .read(&ProjectionKey::new("k-9"))
            .await
            .expect("read 9")
            .expect("state 9");
        assert_eq!(s7.total, 7);
        assert_eq!(s9.total, 9);
    }

    #[tokio::test]
    async fn read_returns_none_for_absent_key() {
        let rt = fresh_runtime(CounterProjection).await;
        let key = ProjectionKey::new("absent");
        let r = rt.read(&key).await.expect("read");
        assert!(r.is_none());
    }

    #[tokio::test]
    async fn read_returns_some_after_apply() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("apply");
        let r = rt
            .read(&ProjectionKey::new(DEFAULT_KEY))
            .await
            .expect("read");
        assert!(r.is_some());
    }

    /// `read_required` reports the projection name and key it could not find.
    #[tokio::test]
    async fn read_required_returns_state_not_found_for_absent() {
        let rt = fresh_runtime(CounterProjection).await;
        let key = ProjectionKey::new("absent");
        let err = rt.read_required(&key).await.expect_err("should err");
        match err {
            ProjectionError::StateNotFound { name, key: k } => {
                assert_eq!(name, CounterProjection::NAME);
                assert_eq!(k, "absent");
            }
            other => panic!("expected StateNotFound, got {other:?}"),
        }
    }

    /// The version counts every apply to the same key.
    #[tokio::test]
    async fn version_increments_per_apply_same_key() {
        let rt = fresh_runtime(CounterProjection).await;
        for _ in 0..5 {
            rt.apply_event(&CounterEvent { delta: 1 })
                .await
                .expect("apply");
        }
        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 5);
    }

    /// Every apply refreshes `updated_at`.
    #[tokio::test]
    async fn updated_at_advances_per_apply() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("first");
        let t1 = fetch_row(&rt, DEFAULT_KEY).await.updated_at;

        // Give the wall clock time to move so the timestamps are distinguishable.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("second");
        let row2 = fetch_row(&rt, DEFAULT_KEY).await;
        assert!(row2.updated_at > t1, "updated_at must advance");
    }

    /// Distinct keys get distinct lock entries.
    #[tokio::test]
    async fn cross_key_apply_does_not_share_lock() {
        let rt = fresh_runtime(KeyedCounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("k-1");
        rt.apply_event(&CounterEvent { delta: 2 })
            .await
            .expect("k-2");
        rt.apply_event(&CounterEvent { delta: 3 })
            .await
            .expect("k-3");

        assert_eq!(rt.locks.len(), 3);
    }

    /// Replaying a history via `rebuild` yields the same state as applying it.
    #[tokio::test]
    async fn rebuild_three_events_equals_three_sequential_applies() {
        let rt_a = fresh_runtime(CounterProjection).await;
        for d in [3, 5, 7] {
            rt_a.apply_event(&CounterEvent { delta: d })
                .await
                .expect("apply");
        }
        let state_a = rt_a
            .read(&ProjectionKey::new(DEFAULT_KEY))
            .await
            .expect("read a")
            .expect("state a");

        let rt_b = fresh_runtime(CounterProjection).await;
        let events: Vec<CounterEvent> = vec![
            CounterEvent { delta: 3 },
            CounterEvent { delta: 5 },
            CounterEvent { delta: 7 },
        ];
        let state_b = rt_b
            .rebuild(&ProjectionKey::new(DEFAULT_KEY), events)
            .await
            .expect("rebuild");

        assert_eq!(state_a, state_b);
        assert_eq!(state_a.total, 15);
    }

    /// Rebuilding from an empty history removes the snapshot.
    #[tokio::test]
    async fn rebuild_empty_deletes_row_and_returns_default() {
        let rt = fresh_runtime(CounterProjection).await;
        rt.apply_event(&CounterEvent { delta: 7 })
            .await
            .expect("seed");
        let pre = rt
            .read(&ProjectionKey::new(DEFAULT_KEY))
            .await
            .expect("read pre")
            .expect("state pre");
        assert_eq!(pre.total, 7);

        let after = rt
            .rebuild(&ProjectionKey::new(DEFAULT_KEY), Vec::<CounterEvent>::new())
            .await
            .expect("rebuild empty");
        assert_eq!(after.total, 0);

        let post = rt
            .read(&ProjectionKey::new(DEFAULT_KEY))
            .await
            .expect("read post");
        assert!(post.is_none(), "row should be wiped");
    }

    /// Rebuilding over an existing snapshot overwrites its state, resets the
    /// version to the replayed event count, and later applies continue from there.
    #[tokio::test]
    async fn rebuild_replaces_existing_snapshot_and_resets_version() {
        let rt = fresh_runtime(CounterProjection).await;
        let key = ProjectionKey::new(DEFAULT_KEY);
        for _ in 0..4 {
            rt.apply_event(&CounterEvent { delta: 10 })
                .await
                .expect("seed");
        }
        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 4);

        let rebuilt = rt
            .rebuild(
                &key,
                vec![CounterEvent { delta: 2 }, CounterEvent { delta: 3 }],
            )
            .await
            .expect("rebuild");
        assert_eq!(rebuilt.total, 5);
        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 2);
        assert_eq!(rt.read_required(&key).await.expect("read"), rebuilt);

        rt.apply_event(&CounterEvent { delta: 1 })
            .await
            .expect("apply after rebuild");
        assert_eq!(fetch_row(&rt, DEFAULT_KEY).await.version, 3);
        assert_eq!(rt.read_required(&key).await.expect("read").total, 6);
    }
}