mysql_es/
view_repository.rs1use std::marker::PhantomData;
2
3use cqrs_es::persist::{PersistenceError, ViewContext, ViewRepository};
4use cqrs_es::{Aggregate, View};
5use sqlx::mysql::MySqlRow;
6use sqlx::{MySql, Pool, Row};
7
8use crate::error::MysqlAggregateError;
9
10pub struct MysqlViewRepository<V, A> {
12 insert_sql: String,
13 update_sql: String,
14 select_sql: String,
15 pool: Pool<MySql>,
16 _phantom: PhantomData<(V, A)>,
17}
18
19impl<V, A> MysqlViewRepository<V, A>
20where
21 V: View<A>,
22 A: Aggregate,
23{
24 pub fn new(view_name: &str, pool: Pool<MySql>) -> Self {
39 let insert_sql =
40 format!("INSERT INTO {view_name} (payload, version, view_id) VALUES ( ?, ?, ? )");
41 let update_sql = format!("UPDATE {view_name} SET payload= ? , version= ? WHERE view_id= ?");
42 let select_sql = format!("SELECT version,payload FROM {view_name} WHERE view_id= ?");
43 Self {
44 insert_sql,
45 update_sql,
46 select_sql,
47 pool,
48 _phantom: PhantomData,
49 }
50 }
51}
52
53impl<V, A> ViewRepository<V, A> for MysqlViewRepository<V, A>
54where
55 V: View<A>,
56 A: Aggregate,
57{
58 async fn load(&self, view_id: &str) -> Result<Option<V>, PersistenceError> {
59 let row: Option<MySqlRow> = sqlx::query(&self.select_sql)
60 .bind(view_id)
61 .fetch_optional(&self.pool)
62 .await
63 .map_err(MysqlAggregateError::from)?;
64 match row {
65 None => Ok(None),
66 Some(row) => {
67 let view = serde_json::from_value(row.get("payload"))?;
68 Ok(Some(view))
69 }
70 }
71 }
72
73 async fn load_with_context(
74 &self,
75 view_id: &str,
76 ) -> Result<Option<(V, ViewContext)>, PersistenceError> {
77 let row: Option<MySqlRow> = sqlx::query(&self.select_sql)
78 .bind(view_id)
79 .fetch_optional(&self.pool)
80 .await
81 .map_err(MysqlAggregateError::from)?;
82 match row {
83 None => Ok(None),
84 Some(row) => {
85 let version = row.get("version");
86 let view = serde_json::from_value(row.get("payload"))?;
87 let view_context = ViewContext::new(view_id.to_string(), version);
88 Ok(Some((view, view_context)))
89 }
90 }
91 }
92
93 async fn update_view(&self, view: V, context: ViewContext) -> Result<(), PersistenceError> {
94 let sql = match context.version {
95 0 => &self.insert_sql,
96 _ => &self.update_sql,
97 };
98 let version = context.version + 1;
99 let payload = serde_json::to_value(&view).map_err(MysqlAggregateError::from)?;
100 sqlx::query(sql.as_str())
101 .bind(payload)
102 .bind(version)
103 .bind(context.view_instance_id)
104 .execute(&self.pool)
105 .await
106 .map_err(MysqlAggregateError::from)?;
107 Ok(())
108 }
109}
110
#[cfg(test)]
mod test {
    use crate::testing::tests::{
        Created, TestAggregate, TestEvent, TestView, TEST_CONNECTION_STRING,
    };
    use crate::{default_mysql_pool, MysqlViewRepository};
    use cqrs_es::persist::{ViewContext, ViewRepository};

    /// Round-trips a view through insert, load, update and load again against the MySQL
    /// instance at `TEST_CONNECTION_STRING`. Checks the payload and the stored version at
    /// each step.
    #[tokio::test]
    async fn test_valid_view_repository() {
        let pool = default_mysql_pool(TEST_CONNECTION_STRING).await;
        let repo = MysqlViewRepository::<TestView, TestAggregate>::new("test_view", pool);
        // A fresh id per run keeps this test independent of rows left by earlier runs.
        let test_view_id = uuid::Uuid::new_v4().to_string();

        // Nothing has been stored under this id yet.
        assert!(repo.load(&test_view_id).await.unwrap().is_none());

        let view = TestView {
            events: vec![TestEvent::Created(Created {
                id: "just a test event for this view".to_string(),
            })],
        };
        repo.update_view(view.clone(), ViewContext::new(test_view_id.to_string(), 0))
            .await
            .unwrap();
        let (found, context) = repo
            .load_with_context(&test_view_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(found, view);
        // The first write (context version 0) stores version 1.
        assert_eq!(context.version, 1);
        assert_eq!(context.view_instance_id, test_view_id);

        let updated_view = TestView {
            events: vec![TestEvent::Created(Created {
                id: "a totally different view".to_string(),
            })],
        };
        repo.update_view(updated_view.clone(), context)
            .await
            .unwrap();
        let found = repo.load(&test_view_id).await.unwrap().unwrap();
        assert_eq!(found, updated_view);

        // The update advanced the stored version by one.
        let (_, context) = repo
            .load_with_context(&test_view_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(context.version, 2);
    }
}