1use std::ops::{Deref, DerefMut};
4
5#[cfg(feature = "mysql")]
6use sea_query::MysqlQueryBuilder;
7#[cfg(feature = "postgres")]
8use sea_query::PostgresQueryBuilder;
9#[cfg(feature = "sqlite")]
10use sea_query::SqliteQueryBuilder;
11use sea_query::{Cond, Expr, ExprTrait, Iden, IntoColumnRef, OnConflict, Query, SelectStatement};
12use sea_query_sqlx::SqlxBinder;
13use sqlx::{Database, Pool};
14use ulid::Ulid;
15
16use evento_core::{
17 cursor::{self, Args, Cursor, Edge, PageInfo, ReadResult, Value},
18 Executor, ReadAggregator, WriteError,
19};
20
/// `sea_query` identifiers for the event table and its columns.
///
/// `#[derive(Iden)]` maps variant names to snake_case SQL identifiers
/// (e.g. `AggregatorType` -> `aggregator_type`), so renaming a variant
/// changes the generated SQL.
#[derive(Iden, Clone)]
pub enum Event {
    /// The event table itself.
    Table,
    /// Event id (stored as a ULID string — see `FromRow for SqlEvent`).
    Id,
    /// Event name.
    Name,
    /// Type of the aggregator the event belongs to.
    AggregatorType,
    /// Id of the aggregator the event belongs to.
    AggregatorId,
    /// Aggregator version at which the event was written.
    Version,
    /// Encoded event payload.
    Data,
    /// bitcode-encoded event metadata (see `write` / `from_row`).
    Metadata,
    /// Optional routing key; NULL when the event has none.
    RoutingKey,
    /// Coarse timestamp component (stored as a signed integer column).
    Timestamp,
    /// Sub-second timestamp component — presumably the fractional part of
    /// `Timestamp`; confirm units against the writer.
    TimestampSubsec,
}
62
/// `sea_query` identifiers for the snapshot table and its columns.
///
/// Variant names map to snake_case SQL identifiers via `#[derive(Iden)]`.
#[derive(Iden)]
pub enum Snapshot {
    /// The snapshot table itself.
    Table,
    /// Aggregator id the snapshot belongs to.
    Id,
    /// Aggregator type the snapshot belongs to.
    Type,
    /// Cursor of the last event folded into the snapshot.
    Cursor,
    /// Aggregator revision the snapshot was built with; used to invalidate
    /// snapshots when the aggregator implementation changes.
    Revision,
    /// Encoded snapshot state.
    Data,
    /// Row creation time.
    CreatedAt,
    /// Last update time (bumped on upsert — see `save_snapshot`).
    UpdatedAt,
}
87
/// `sea_query` identifiers for the subscriber table and its columns.
///
/// Variant names map to snake_case SQL identifiers via `#[derive(Iden)]`.
#[derive(Iden)]
pub enum Subscriber {
    /// The subscriber table itself.
    Table,
    /// Unique subscriber key (conflict target of `upsert_subscriber`).
    Key,
    /// ULID (stored as a string) of the worker currently owning the key.
    WorkerId,
    /// Last acknowledged event cursor.
    Cursor,
    /// Consumer lag reported at the last acknowledgement.
    Lag,
    /// Whether the subscriber is enabled (see `is_subscriber_running`).
    Enabled,
    /// Row creation time.
    CreatedAt,
    /// Last update time (bumped on upsert/acknowledge).
    UpdatedAt,
}
119
/// MySQL-backed event store executor.
#[cfg(feature = "mysql")]
pub type MySql = Sql<sqlx::MySql>;

/// Read/write executor pair where both sides use MySQL.
#[cfg(feature = "mysql")]
pub type RwMySql = evento_core::Rw<MySql, MySql>;

/// PostgreSQL-backed event store executor.
#[cfg(feature = "postgres")]
pub type Postgres = Sql<sqlx::Postgres>;

/// Read/write executor pair where both sides use PostgreSQL.
#[cfg(feature = "postgres")]
pub type RwPostgres = evento_core::Rw<Postgres, Postgres>;

/// SQLite-backed event store executor.
#[cfg(feature = "sqlite")]
pub type Sqlite = Sql<sqlx::Sqlite>;

/// Read/write executor pair where both sides use SQLite.
#[cfg(feature = "sqlite")]
pub type RwSqlite = evento_core::Rw<Sqlite, Sqlite>;
155
/// SQL-backed [`Executor`] implementation wrapping a sqlx connection [`Pool`].
pub struct Sql<DB: Database>(Pool<DB>);
194
195impl<DB: Database> Sql<DB> {
196 fn build_sqlx<S: SqlxBinder>(statement: S) -> (String, sea_query_sqlx::SqlxValues) {
197 match DB::NAME {
198 #[cfg(feature = "sqlite")]
199 "SQLite" => statement.build_sqlx(SqliteQueryBuilder),
200 #[cfg(feature = "mysql")]
201 "MySQL" => statement.build_sqlx(MysqlQueryBuilder),
202 #[cfg(feature = "postgres")]
203 "PostgreSQL" => statement.build_sqlx(PostgresQueryBuilder),
204 name => panic!("'{name}' not supported, consider using SQLite, PostgreSQL or MySQL"),
205 }
206 }
207}
208
209#[async_trait::async_trait]
210impl<DB> Executor for Sql<DB>
211where
212 DB: Database,
213 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
214 sea_query_sqlx::SqlxValues: for<'q> sqlx::IntoArguments<'q, DB>,
215 String: for<'r> sqlx::Decode<'r, DB> + sqlx::Type<DB>,
216 bool: for<'r> sqlx::Decode<'r, DB> + sqlx::Type<DB>,
217 Vec<u8>: for<'r> sqlx::Decode<'r, DB> + sqlx::Type<DB>,
218 usize: sqlx::ColumnIndex<DB::Row>,
219 SqlEvent: for<'r> sqlx::FromRow<'r, DB::Row>,
220{
221 async fn read(
222 &self,
223 aggregators: Option<Vec<ReadAggregator>>,
224 routing_key: Option<evento_core::RoutingKey>,
225 args: Args,
226 ) -> anyhow::Result<ReadResult<evento_core::Event>> {
227 let statement = Query::select()
228 .columns([
229 Event::Id,
230 Event::Name,
231 Event::AggregatorType,
232 Event::AggregatorId,
233 Event::Version,
234 Event::Data,
235 Event::Metadata,
236 Event::RoutingKey,
237 Event::Timestamp,
238 Event::TimestampSubsec,
239 ])
240 .from(Event::Table)
241 .conditions(
242 aggregators.is_some(),
243 |q| {
244 let Some(aggregators) = aggregators else {
245 return;
246 };
247
248 let mut cond = Cond::any();
249
250 for aggregator in aggregators {
251 let mut aggregator_cond = Cond::all()
252 .add(Expr::col(Event::AggregatorType).eq(aggregator.aggregator_type));
253
254 if let Some(id) = aggregator.aggregator_id {
255 aggregator_cond =
256 aggregator_cond.add(Expr::col(Event::AggregatorId).eq(id));
257 }
258
259 if let Some(name) = aggregator.name {
260 aggregator_cond = aggregator_cond.add(Expr::col(Event::Name).eq(name));
261 }
262
263 cond = cond.add(aggregator_cond);
264 }
265
266 q.and_where(cond.into());
267 },
268 |_| {},
269 )
270 .conditions(
271 matches!(routing_key, Some(evento_core::RoutingKey::Value(_))),
272 |q| {
273 if let Some(evento_core::RoutingKey::Value(Some(ref routing_key))) = routing_key
274 {
275 q.and_where(Expr::col(Event::RoutingKey).eq(routing_key));
276 }
277
278 if let Some(evento_core::RoutingKey::Value(None)) = routing_key {
279 q.and_where(Expr::col(Event::RoutingKey).is_null());
280 }
281 },
282 |_q| {},
283 )
284 .to_owned();
285
286 Ok(Reader::new(statement)
287 .args(args)
288 .execute::<_, SqlEvent, _>(&self.0)
289 .await?
290 .map(|e| e.0))
291 }
292
293 async fn get_subscriber_cursor(&self, key: String) -> anyhow::Result<Option<Value>> {
294 let statement = Query::select()
295 .columns([Subscriber::Cursor])
296 .from(Subscriber::Table)
297 .and_where(Expr::col(Subscriber::Key).eq(Expr::value(key)))
298 .limit(1)
299 .to_owned();
300
301 let (sql, values) = Self::build_sqlx(statement);
302
303 let Some((cursor,)) = sqlx::query_as_with::<DB, (Option<String>,), _>(&sql, values)
304 .fetch_optional(&self.0)
305 .await?
306 else {
307 return Ok(None);
308 };
309
310 Ok(cursor.map(|c| c.into()))
311 }
312
313 async fn is_subscriber_running(&self, key: String, worker_id: Ulid) -> anyhow::Result<bool> {
314 let statement = Query::select()
315 .columns([Subscriber::WorkerId, Subscriber::Enabled])
316 .from(Subscriber::Table)
317 .and_where(Expr::col(Subscriber::Key).eq(Expr::value(key)))
318 .limit(1)
319 .to_owned();
320
321 let (sql, values) = Self::build_sqlx(statement);
322
323 let (id, enabled) = sqlx::query_as_with::<DB, (String, bool), _>(&sql, values)
324 .fetch_one(&self.0)
325 .await?;
326
327 Ok(worker_id.to_string() == id && enabled)
328 }
329
330 async fn upsert_subscriber(&self, key: String, worker_id: Ulid) -> anyhow::Result<()> {
331 let statement = Query::insert()
332 .into_table(Subscriber::Table)
333 .columns([Subscriber::Key, Subscriber::WorkerId, Subscriber::Lag])
334 .values_panic([key.into(), worker_id.to_string().into(), 0.into()])
335 .on_conflict(
336 OnConflict::column(Subscriber::Key)
337 .update_columns([Subscriber::WorkerId])
338 .value(Subscriber::UpdatedAt, Expr::current_timestamp())
339 .to_owned(),
340 )
341 .to_owned();
342
343 let (sql, values) = Self::build_sqlx(statement);
344
345 sqlx::query_with::<DB, _>(&sql, values)
346 .execute(&self.0)
347 .await?;
348
349 Ok(())
350 }
351
352 async fn write(&self, events: Vec<evento_core::Event>) -> Result<(), WriteError> {
353 let mut statement = Query::insert()
354 .into_table(Event::Table)
355 .columns([
356 Event::Id,
357 Event::Name,
358 Event::Data,
359 Event::Metadata,
360 Event::AggregatorType,
361 Event::AggregatorId,
362 Event::Version,
363 Event::RoutingKey,
364 Event::Timestamp,
365 Event::TimestampSubsec,
366 ])
367 .to_owned();
368
369 for event in events {
370 let metadata = bitcode::encode(&event.metadata);
371 statement.values_panic([
372 event.id.to_string().into(),
373 event.name.into(),
374 event.data.into(),
375 metadata.into(),
376 event.aggregator_type.into(),
377 event.aggregator_id.into(),
378 event.version.into(),
379 event.routing_key.into(),
380 event.timestamp.into(),
381 event.timestamp_subsec.into(),
382 ]);
383 }
384
385 let (sql, values) = Self::build_sqlx(statement);
386
387 sqlx::query_with::<DB, _>(&sql, values)
388 .execute(&self.0)
389 .await
390 .map_err(|err| {
391 let err_str = err.to_string();
392 if err_str.contains("(code: 2067)") {
393 return WriteError::InvalidOriginalVersion;
394 }
395 if err_str.contains("1062 (23000): Duplicate entry") {
396 return WriteError::InvalidOriginalVersion;
397 }
398 if err_str.contains("duplicate key value violates unique constraint") {
399 return WriteError::InvalidOriginalVersion;
400 }
401 WriteError::Unknown(err.into())
402 })?;
403
404 Ok(())
405 }
406
407 async fn acknowledge(&self, key: String, cursor: Value, lag: u64) -> anyhow::Result<()> {
408 let statement = Query::update()
409 .table(Subscriber::Table)
410 .values([
411 (Subscriber::Cursor, cursor.0.into()),
412 (Subscriber::Lag, lag.into()),
413 (Subscriber::UpdatedAt, Expr::current_timestamp()),
414 ])
415 .and_where(Expr::col(Subscriber::Key).eq(key))
416 .to_owned();
417
418 let (sql, values) = Self::build_sqlx(statement);
419
420 sqlx::query_with::<DB, _>(&sql, values)
421 .execute(&self.0)
422 .await?;
423
424 Ok(())
425 }
426
427 async fn get_snapshot(
428 &self,
429 aggregator_type: String,
430 aggregator_revision: String,
431 id: String,
432 ) -> anyhow::Result<Option<(Vec<u8>, Value)>> {
433 let statement = Query::select()
434 .columns([Snapshot::Data, Snapshot::Cursor])
435 .from(Snapshot::Table)
436 .and_where(Expr::col(Snapshot::Type).eq(Expr::value(aggregator_type)))
437 .and_where(Expr::col(Snapshot::Id).eq(Expr::value(id)))
438 .and_where(Expr::col(Snapshot::Revision).eq(Expr::value(aggregator_revision)))
439 .limit(1)
440 .to_owned();
441
442 let (sql, values) = Self::build_sqlx(statement);
443
444 Ok(
445 sqlx::query_as_with::<DB, (Vec<u8>, String), _>(&sql, values)
446 .fetch_optional(&self.0)
447 .await
448 .map(|res| res.map(|(data, cursor)| (data, cursor.into())))?,
449 )
450 }
451
452 async fn save_snapshot(
453 &self,
454 aggregator_type: String,
455 aggregator_revision: String,
456 id: String,
457 data: Vec<u8>,
458 cursor: Value,
459 ) -> anyhow::Result<()> {
460 let statement = Query::insert()
461 .into_table(Snapshot::Table)
462 .columns([
463 Snapshot::Type,
464 Snapshot::Id,
465 Snapshot::Cursor,
466 Snapshot::Revision,
467 Snapshot::Data,
468 ])
469 .values_panic([
470 aggregator_type.into(),
471 id.to_string().into(),
472 cursor.to_string().into(),
473 aggregator_revision.into(),
474 data.into(),
475 ])
476 .on_conflict(
477 OnConflict::columns([Snapshot::Type, Snapshot::Id])
478 .update_columns([Snapshot::Data, Snapshot::Cursor, Snapshot::Revision])
479 .value(Snapshot::UpdatedAt, Expr::current_timestamp())
480 .to_owned(),
481 )
482 .to_owned();
483
484 let (sql, values) = Self::build_sqlx(statement);
485
486 sqlx::query_with::<DB, _>(&sql, values)
487 .execute(&self.0)
488 .await?;
489
490 Ok(())
491 }
492}
493
494impl<D: Database> Clone for Sql<D> {
495 fn clone(&self) -> Self {
496 Self(self.0.clone())
497 }
498}
499
500impl<D: Database> From<Pool<D>> for Sql<D> {
501 fn from(value: Pool<D>) -> Self {
502 Self(value)
503 }
504}
505
/// Cursor-based (keyset) pagination helper around a `sea_query`
/// [`SelectStatement`].
pub struct Reader {
    // Base SELECT to which cursor predicates, ORDER BY and LIMIT are added.
    statement: SelectStatement,
    // Pagination arguments (first/after or last/before).
    args: Args,
    // Requested logical ordering; the effective SQL order also depends on
    // whether pagination runs backward (see `is_order_desc`).
    order: cursor::Order,
}
551
552impl Reader {
553 pub fn new(statement: SelectStatement) -> Self {
555 Self {
556 statement,
557 args: Args::default(),
558 order: cursor::Order::Asc,
559 }
560 }
561
562 pub fn order(&mut self, order: cursor::Order) -> &mut Self {
564 self.order = order;
565
566 self
567 }
568
569 pub fn desc(&mut self) -> &mut Self {
571 self.order(cursor::Order::Desc)
572 }
573
574 pub fn args(&mut self, args: Args) -> &mut Self {
576 self.args = args;
577
578 self
579 }
580
581 pub fn backward(&mut self, last: u16, before: Option<Value>) -> &mut Self {
588 self.args(Args {
589 last: Some(last),
590 before,
591 ..Default::default()
592 })
593 }
594
595 pub fn forward(&mut self, first: u16, after: Option<Value>) -> &mut Self {
602 self.args(Args {
603 first: Some(first),
604 after,
605 ..Default::default()
606 })
607 }
608
609 pub async fn execute<'e, 'c: 'e, DB, O, E>(
622 &mut self,
623 executor: E,
624 ) -> anyhow::Result<ReadResult<O>>
625 where
626 DB: Database,
627 E: 'e + sqlx::Executor<'c, Database = DB>,
628 O: for<'r> sqlx::FromRow<'r, DB::Row>,
629 O: Cursor,
630 O: Send + Unpin,
631 O: Bind<Cursor = O>,
632 <<O as Bind>::I as IntoIterator>::IntoIter: DoubleEndedIterator,
633 <<O as Bind>::V as IntoIterator>::IntoIter: DoubleEndedIterator,
634 sea_query_sqlx::SqlxValues: for<'q> sqlx::IntoArguments<'q, DB>,
635 {
636 let limit = self.build_reader::<O, O>()?;
637
638 let (sql, values) = match DB::NAME {
639 #[cfg(feature = "sqlite")]
640 "SQLite" => self.statement.build_sqlx(SqliteQueryBuilder),
641 #[cfg(feature = "mysql")]
642 "MySQL" => self.build_sqlx(MysqlQueryBuilder),
643 #[cfg(feature = "postgres")]
644 "PostgreSQL" => self.build_sqlx(PostgresQueryBuilder),
645 name => panic!("'{name}' not supported, consider using SQLite, PostgreSQL or MySQL"),
646 };
647
648 let mut rows = sqlx::query_as_with::<DB, O, _>(&sql, values)
649 .fetch_all(executor)
650 .await?;
651
652 let has_more = rows.len() > limit as usize;
653 if has_more {
654 rows.pop();
655 }
656
657 let mut edges = vec![];
658 for node in rows.into_iter() {
659 edges.push(Edge {
660 cursor: node.serialize_cursor()?,
661 node,
662 });
663 }
664
665 if self.args.is_backward() {
666 edges = edges.into_iter().rev().collect();
667 }
668
669 let page_info = if self.args.is_backward() {
670 let start_cursor = edges.first().map(|e| e.cursor.clone());
671
672 PageInfo {
673 has_previous_page: has_more,
674 has_next_page: false,
675 start_cursor,
676 end_cursor: None,
677 }
678 } else {
679 let end_cursor = edges.last().map(|e| e.cursor.clone());
680 PageInfo {
681 has_previous_page: false,
682 has_next_page: has_more,
683 start_cursor: None,
684 end_cursor,
685 }
686 };
687
688 Ok(ReadResult { edges, page_info })
689 }
690
691 fn build_reader<O: Cursor, B: Bind<Cursor = O>>(&mut self) -> Result<u16, cursor::CursorError>
692 where
693 B::T: Clone,
694 <<B as Bind>::I as IntoIterator>::IntoIter: DoubleEndedIterator,
695 <<B as Bind>::V as IntoIterator>::IntoIter: DoubleEndedIterator,
696 {
697 let (limit, cursor) = self.args.get_info();
698
699 if let Some(cursor) = cursor.as_ref() {
700 self.build_reader_where::<O, B>(cursor)?;
701 }
702
703 self.build_reader_order::<B>();
704 self.limit((limit + 1).into());
705
706 Ok(limit)
707 }
708
709 fn build_reader_where<O, B>(&mut self, cursor: &Value) -> Result<(), cursor::CursorError>
710 where
711 O: Cursor,
712 B: Bind<Cursor = O>,
713 B::T: Clone,
714 <<B as Bind>::I as IntoIterator>::IntoIter: DoubleEndedIterator,
715 <<B as Bind>::V as IntoIterator>::IntoIter: DoubleEndedIterator,
716 {
717 let is_order_desc = self.is_order_desc();
718 let cursor = O::deserialize_cursor(cursor)?;
719 let colums = B::columns().into_iter().rev();
720 let values = B::values(cursor).into_iter().rev();
721
722 let mut expr = None::<Expr>;
723 for (col, value) in colums.zip(values) {
724 let current_expr = if is_order_desc {
725 Expr::col(col.clone()).lt(value.clone())
726 } else {
727 Expr::col(col.clone()).gt(value.clone())
728 };
729
730 let Some(ref prev_expr) = expr else {
731 expr = Some(current_expr.clone());
732 continue;
733 };
734
735 expr = Some(current_expr.or(Expr::col(col).eq(value).and(prev_expr.clone())));
736 }
737
738 self.and_where(expr.unwrap());
739
740 Ok(())
741 }
742
743 fn build_reader_order<O: Bind>(&mut self) {
744 let order = if self.is_order_desc() {
745 sea_query::Order::Desc
746 } else {
747 sea_query::Order::Asc
748 };
749
750 let colums = O::columns();
751 for col in colums {
752 self.order_by(col, order.clone());
753 }
754 }
755
756 fn is_order_desc(&self) -> bool {
757 matches!(
758 (&self.order, self.args.is_backward()),
759 (cursor::Order::Asc, true) | (cursor::Order::Desc, false)
760 )
761 }
762}
763
/// Lets [`SelectStatement`] methods be called directly on a [`Reader`].
impl Deref for Reader {
    type Target = SelectStatement;

    fn deref(&self) -> &Self::Target {
        &self.statement
    }
}
771
/// Mutable counterpart of the `Deref` impl; used internally by the builder
/// helpers (`and_where`, `order_by`, `limit`).
impl DerefMut for Reader {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.statement
    }
}
777
/// Binds a cursor type to the SQL columns and value expressions used for
/// keyset pagination in [`Reader`].
pub trait Bind {
    /// Column identifier type.
    type T: IntoColumnRef + Clone;
    /// Iterable of cursor columns, ordered most- to least-significant.
    type I: IntoIterator<Item = Self::T>;
    /// Iterable of value expressions; must align one-to-one with `columns()`.
    type V: IntoIterator<Item = Expr>;
    /// Cursor type these columns/values serialize.
    type Cursor: Cursor;

    /// Columns participating in the cursor ordering.
    fn columns() -> Self::I;
    /// Value expressions extracted from a deserialized cursor, in the same
    /// order as `columns()`.
    fn values(cursor: <<Self as Bind>::Cursor as Cursor>::T) -> Self::V;
}
810
impl evento_core::cursor::Cursor for SqlEvent {
    type T = evento_core::EventCursor;

    /// Builds the pagination cursor from the wrapped event's identity
    /// fields (id, version, timestamp, sub-second timestamp).
    fn serialize(&self) -> Self::T {
        evento_core::EventCursor {
            i: self.0.id.to_string(),
            v: self.0.version,
            t: self.0.timestamp,
            s: self.0.timestamp_subsec,
        }
    }
}
823
impl Bind for SqlEvent {
    type T = Event;
    type I = [Self::T; 4];
    type V = [Expr; 4];
    type Cursor = Self;

    /// Cursor ordering: timestamp, then sub-second, then version, then id.
    /// This order must stay aligned with `values` below.
    fn columns() -> Self::I {
        [
            Event::Timestamp,
            Event::TimestampSubsec,
            Event::Version,
            Event::Id,
        ]
    }

    /// Cursor values in the same order as `columns`.
    fn values(cursor: <<Self as Bind>::Cursor as Cursor>::T) -> Self::V {
        [
            cursor.t.into(),
            cursor.s.into(),
            cursor.v.into(),
            cursor.i.into(),
        ]
    }
}
848
/// Builds an [`evento_core::Evento`] directly from a SQLite executor.
#[cfg(feature = "sqlite")]
impl From<Sqlite> for evento_core::Evento {
    fn from(value: Sqlite) -> Self {
        evento_core::Evento::new(value)
    }
}

/// Same as above, cloning the (cheap, pool-handle-backed) executor.
#[cfg(feature = "sqlite")]
impl From<&Sqlite> for evento_core::Evento {
    fn from(value: &Sqlite) -> Self {
        evento_core::Evento::new(value.clone())
    }
}

/// Builds an [`evento_core::Evento`] directly from a MySQL executor.
#[cfg(feature = "mysql")]
impl From<MySql> for evento_core::Evento {
    fn from(value: MySql) -> Self {
        evento_core::Evento::new(value)
    }
}

/// Same as above, cloning the (cheap, pool-handle-backed) executor.
#[cfg(feature = "mysql")]
impl From<&MySql> for evento_core::Evento {
    fn from(value: &MySql) -> Self {
        evento_core::Evento::new(value.clone())
    }
}

/// Builds an [`evento_core::Evento`] directly from a PostgreSQL executor.
#[cfg(feature = "postgres")]
impl From<Postgres> for evento_core::Evento {
    fn from(value: Postgres) -> Self {
        evento_core::Evento::new(value)
    }
}

/// Same as above, cloning the (cheap, pool-handle-backed) executor.
#[cfg(feature = "postgres")]
impl From<&Postgres> for evento_core::Evento {
    fn from(value: &Postgres) -> Self {
        evento_core::Evento::new(value.clone())
    }
}
890
/// Newtype over [`evento_core::Event`] so the sqlx `FromRow` and cursor
/// traits can be implemented in this crate.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SqlEvent(pub evento_core::Event);
893
894impl<R: sqlx::Row> sqlx::FromRow<'_, R> for SqlEvent
895where
896 i32: sqlx::Type<R::Database> + for<'r> sqlx::Decode<'r, R::Database>,
897 Vec<u8>: sqlx::Type<R::Database> + for<'r> sqlx::Decode<'r, R::Database>,
898 String: sqlx::Type<R::Database> + for<'r> sqlx::Decode<'r, R::Database>,
899 i64: sqlx::Type<R::Database> + for<'r> sqlx::Decode<'r, R::Database>,
900 for<'r> &'r str: sqlx::Type<R::Database> + sqlx::Decode<'r, R::Database>,
901 for<'r> &'r str: sqlx::ColumnIndex<R>,
902{
903 fn from_row(row: &R) -> Result<Self, sqlx::Error> {
904 let timestamp: i64 = sqlx::Row::try_get(row, "timestamp")?;
905 let timestamp_subsec: i64 = sqlx::Row::try_get(row, "timestamp_subsec")?;
906 let version: i32 = sqlx::Row::try_get(row, "version")?;
907 let metadata: Vec<u8> = sqlx::Row::try_get(row, "metadata")?;
908 let metadata: evento_core::metadata::Metadata =
909 bitcode::decode(&metadata).map_err(|e| sqlx::Error::Decode(e.into()))?;
910
911 Ok(SqlEvent(evento_core::Event {
912 id: Ulid::from_string(sqlx::Row::try_get(row, "id")?)
913 .map_err(|err| sqlx::Error::InvalidArgument(err.to_string()))?,
914 aggregator_id: sqlx::Row::try_get(row, "aggregator_id")?,
915 aggregator_type: sqlx::Row::try_get(row, "aggregator_type")?,
916 version: version as u16,
917 name: sqlx::Row::try_get(row, "name")?,
918 routing_key: sqlx::Row::try_get(row, "routing_key")?,
919 data: sqlx::Row::try_get(row, "data")?,
920 timestamp: timestamp as u64,
921 timestamp_subsec: timestamp_subsec as u32,
922 metadata,
923 }))
924 }
925}