postgres_es/
view_repository.rs1use std::marker::PhantomData;
2
3use async_trait::async_trait;
4use cqrs_es::persist::{PersistenceError, ViewContext, ViewRepository};
5use cqrs_es::{Aggregate, View};
6use sqlx::postgres::PgRow;
7use sqlx::{Pool, Postgres, Row};
8
9use crate::error::PostgresAggregateError;
10
11pub struct PostgresViewRepository<V, A> {
13 insert_sql: String,
14 update_sql: String,
15 select_sql: String,
16 pool: Pool<Postgres>,
17 _phantom: PhantomData<(V, A)>,
18}
19
20impl<V, A> PostgresViewRepository<V, A>
21where
22 V: View<A>,
23 A: Aggregate,
24{
25 pub fn new(view_name: &str, pool: Pool<Postgres>) -> Self {
40 let insert_sql = format!(
41 "INSERT INTO {} (payload, version, view_id) VALUES ( $1, $2, $3 )",
42 view_name
43 );
44 let update_sql = format!(
45 "UPDATE {} SET payload= $1 , version= $2 WHERE view_id= $3",
46 view_name
47 );
48 let select_sql = format!(
49 "SELECT version,payload FROM {} WHERE view_id= $1",
50 view_name
51 );
52 Self {
53 insert_sql,
54 update_sql,
55 select_sql,
56 pool,
57 _phantom: Default::default(),
58 }
59 }
60}
61
62#[async_trait]
63impl<V, A> ViewRepository<V, A> for PostgresViewRepository<V, A>
64where
65 V: View<A>,
66 A: Aggregate,
67{
68 async fn load(&self, view_id: &str) -> Result<Option<V>, PersistenceError> {
69 let row: Option<PgRow> = sqlx::query(&self.select_sql)
70 .bind(view_id)
71 .fetch_optional(&self.pool)
72 .await
73 .map_err(PostgresAggregateError::from)?;
74 match row {
75 None => Ok(None),
76 Some(row) => {
77 let view = serde_json::from_value(row.get("payload"))?;
78 Ok(Some(view))
79 }
80 }
81 }
82
83 async fn load_with_context(
84 &self,
85 view_id: &str,
86 ) -> Result<Option<(V, ViewContext)>, PersistenceError> {
87 let row: Option<PgRow> = sqlx::query(&self.select_sql)
88 .bind(view_id)
89 .fetch_optional(&self.pool)
90 .await
91 .map_err(PostgresAggregateError::from)?;
92 match row {
93 None => Ok(None),
94 Some(row) => {
95 let version = row.get("version");
96 let view = serde_json::from_value(row.get("payload"))?;
97 let view_context = ViewContext::new(view_id.to_string(), version);
98 Ok(Some((view, view_context)))
99 }
100 }
101 }
102
103 async fn update_view(&self, view: V, context: ViewContext) -> Result<(), PersistenceError> {
104 let sql = match context.version {
105 0 => &self.insert_sql,
106 _ => &self.update_sql,
107 };
108 let version = context.version + 1;
109 let payload = serde_json::to_value(&view).map_err(PostgresAggregateError::from)?;
110 sqlx::query(sql.as_str())
111 .bind(payload)
112 .bind(version)
113 .bind(context.view_instance_id)
114 .execute(&self.pool)
115 .await
116 .map_err(PostgresAggregateError::from)?;
117 Ok(())
118 }
119}
120
#[cfg(test)]
mod test {
    use crate::testing::tests::{
        Created, TestAggregate, TestEvent, TestView, TEST_CONNECTION_STRING,
    };
    use crate::{default_postgress_pool, PostgresViewRepository};
    use cqrs_es::persist::{ViewContext, ViewRepository};

    /// Round-trips a view through insert, load and update against the database at
    /// `TEST_CONNECTION_STRING`, checking both the stored payload and the stored version.
    #[tokio::test]
    async fn test_valid_view_repository() {
        let pool = default_postgress_pool(TEST_CONNECTION_STRING).await;
        let repo = PostgresViewRepository::<TestView, TestAggregate>::new("test_view", pool);
        let test_view_id = uuid::Uuid::new_v4().to_string();

        // Nothing has been written for this freshly generated id yet.
        assert!(repo.load(&test_view_id).await.unwrap().is_none());
        assert!(repo
            .load_with_context(&test_view_id)
            .await
            .unwrap()
            .is_none());

        let view = TestView {
            events: vec![TestEvent::Created(Created {
                id: "just a test event for this view".to_string(),
            })],
        };
        // Context version 0 takes the INSERT path, which stores version 1.
        repo.update_view(view.clone(), ViewContext::new(test_view_id.clone(), 0))
            .await
            .unwrap();
        let (found, context) = repo
            .load_with_context(&test_view_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(found, view);
        assert_eq!(context.version, 1);
        let found = repo.load(&test_view_id).await.unwrap().unwrap();
        assert_eq!(found, view);

        let updated_view = TestView {
            events: vec![TestEvent::Created(Created {
                id: "a totally different view".to_string(),
            })],
        };
        // A non-zero context version takes the UPDATE path, which stores version 2.
        repo.update_view(updated_view.clone(), context)
            .await
            .unwrap();
        let (found, context) = repo
            .load_with_context(&test_view_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(found, updated_view);
        assert_eq!(context.version, 2);
        let found = repo.load(&test_view_id).await.unwrap().unwrap();
        assert_eq!(found, updated_view);
    }
}