1use async_trait::async_trait;
3use derive_new::new;
4use futures::StreamExt;
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use std::{fmt::Debug, sync::Arc};
7use tokio::{sync::RwLock, task::JoinHandle};
8
9use tokio_postgres::{
10 tls::NoTlsStream, types::Json, AsyncMessage, Client, Connection, NoTls, Socket,
11};
12
/// Name and column list of the temporary table that stores the last observed query result.
#[derive(Debug, Clone, Default)]
pub struct TmpTable {
    /// Table name as used in the generated SQL.
    pub name: String,
    /// Column names, in the same order as the watched query's result columns.
    pub fields: Vec<String>,
}
18
/// Entry point that wires a PostgreSQL connection to a shared [`WatcherCtx`].
///
/// The context sits behind `Arc<RwLock<..>>` because [`Watcher::start`] hands a clone of it
/// to a spawned task that calls `handle_event` on notifications, while the original handle
/// is used to run the initial `WatcherCtx::start` setup.
#[derive(new)]
pub struct Watcher<T> {
    /// Shared watcher state (query, client, snapshot bookkeeping).
    ctx: Arc<RwLock<WatcherCtx<T>>>,
}
23
/// One change record produced by the diff query in `WatcherCtx::update_result_table`.
///
/// `op` is `1` for an insert, `2` for an update and `3` for a delete. `data` holds the
/// row's columns as a JSON object (only `{"id": ..}` for deletes).
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonEvent {
    /// Value of the row's `id` column.
    pub id: i32,
    /// Operation code: 1 = insert, 2 = update, 3 = delete.
    pub op: i32,
    /// Row payload, later deserialized into the watcher's row type for inserts/updates.
    pub data: serde_json::Value,
}
30
31impl<T> Watcher<T>
32where
33 T: Sync + Send + 'static + DeserializeOwned,
34{
35 pub async fn start(
36 &mut self,
37 mut connection: Connection<Socket, NoTlsStream>,
38 ) -> JoinHandle<()> {
39 let (tx, mut rx) = tokio::sync::mpsc::channel(100);
40
41 let handle = tokio::spawn(async move {
42 let mut stream = futures::stream::poll_fn(move |cx| connection.poll_message(cx));
43
44 while let Some(message) = stream.next().await {
45 let message = message
46 .map_err(|e| {
47 eprintln!("failed to get message from db: {}", e);
48 e
49 })
50 .unwrap();
51
52 if let AsyncMessage::Notification(not) = &message {
53 let updated_table_nb: u32 = serde_json::from_str(not.payload()).unwrap();
54
55 tx.send(updated_table_nb)
56 .await
57 .map_err(|e| {
58 eprintln!("failed to send message on channel: {:?}", e);
59 e
60 })
61 .unwrap();
62 }
63 }
64 });
65
66 let ctx = Arc::clone(&self.ctx);
67
68 tokio::spawn(async move {
69 while rx.recv().await.is_some() {
70 ctx.write().await.handle_event().await;
71 }
72 });
73
74 self.ctx.write().await.start().await;
75
76 handle
77 }
78}
79
/// State for one live query: the SQL being watched, the database client, and the
/// bookkeeping needed to diff successive results of the query.
#[derive(new)]
pub struct WatcherCtx<T> {
    /// Callback invoked with each batch of changes built by `update_result_table`.
    cb: Box<dyn Fn(Vec<Event<T>>) + Sync + Send + 'static>,
    /// The SQL query whose result set is watched.
    query: String,
    /// Client used for LISTEN, trigger installation and the diff queries.
    pub client: Client,

    /// Source tables on which a change trigger has already been installed.
    #[new(default)]
    triggers: Vec<String>,

    /// Temporary table holding the last observed snapshot of the query result.
    #[new(default)]
    result_table: TmpTable,

    /// Table names extracted from `query` by `collect_source_tables`.
    #[new(default)]
    source_tables: Vec<String>,

    /// While `true`, the next refresh only seeds the snapshot and does not invoke `cb`.
    #[new(value = "true")]
    first_run: bool,

    /// Marks the struct as generic over the row type `T` passed to `cb`.
    phantom: std::marker::PhantomData<T>,
}
100
101impl<T> WatcherCtx<T>
102where
103 T: Sync + Send + 'static + DeserializeOwned,
104{
105 pub async fn start(&mut self) {
106 self.client
107 .query("LISTEN __live_update;", &[])
108 .await
109 .unwrap();
110
111 self.init().await
112 }
113
114 pub async fn init(&mut self) {
115 self.collect_source_tables();
116 self.create_triggers().await;
117
118 if !self.setup_query_result_table().await {
120 return;
121 }
122
123 self.update_result_table().await;
124 }
125
126 pub async fn setup_query_result_table(&mut self) -> bool {
127 let tmp_table_name = "query_result";
128
129 let columns = self.get_query_result_columns().await;
130
131 if columns.is_empty() {
132 return false;
134 }
135
136 let fields = columns.iter().map(|(name, _)| name.clone()).collect();
137
138 let columns_def = columns
139 .iter()
140 .map(|(name, t)| format!("{} {} NOT NULL", name, t))
141 .collect::<Vec<_>>()
142 .join(",\n");
143
144 let query = format!(
145 r#"
146 CREATE TEMP TABLE {} (
147 {}
148 )
149 "#,
150 tmp_table_name, columns_def,
151 );
152
153 self.client.execute(&query, &[]).await.unwrap();
154
155 self.result_table = TmpTable {
156 name: tmp_table_name.to_string(),
157 fields,
158 };
159
160 true
161 }
162
163 pub async fn get_query_result_columns(&self) -> Vec<(String, String)> {
164 let query = format!("SELECT * FROM ({}) q LIMIT 1", self.query);
165
166 let columns = if let Some(first) = self.client.query(&query, &[]).await.unwrap().get(0) {
167 first
168 .columns()
169 .iter()
170 .map(|c| (c.name().to_string(), c.type_().name().to_string()))
171 .collect()
172 } else {
173 vec![]
174 };
175
176 columns
177 }
178
179 pub fn collect_source_tables(&mut self) {
180 use sqlparser::dialect::PostgreSqlDialect;
181 let sql_ast =
182 sqlparser::parser::Parser::parse_sql(&PostgreSqlDialect {}, &self.query).unwrap();
183 use sqlparser::ast::{SetExpr, Statement, TableFactor};
184 let mut names = Vec::new();
185
186 for stmt in sql_ast {
187 match stmt {
188 Statement::Query(query) => match &*query.body {
189 SetExpr::Select(select) => {
190 for table in &select.from {
191 match &table.relation {
192 TableFactor::Table {
193 name,
194 alias: _,
195 args: _,
196 with_hints: _,
197 } => names.push(name.to_string()),
198 _ => {}
199 }
200 }
201 }
202 _ => {}
203 },
204 _ => {}
205 }
206 }
207
208 self.source_tables = names;
209 }
210
211 pub async fn create_triggers(&mut self) {
212 for (i, table_name) in self.source_tables.iter().enumerate() {
213 if !self.triggers.contains(&table_name) {
214 let trigger_name = &format!(r#""__live_update{}""#, table_name);
215 let l_key = i.to_string();
216
217 let drop_sql = format!(
218 r#"
219 DROP TRIGGER IF EXISTS
220 {}
221 ON
222 {}
223 "#,
224 trigger_name, table_name
225 );
226
227 self.client.execute(&drop_sql, &[]).await.unwrap();
228
229 let func_sql = format!(
230 r#"
231 CREATE OR REPLACE FUNCTION pg_temp.{}()
232 RETURNS TRIGGER AS $$
233 BEGIN
234 EXECUTE pg_notify('__live_update', '{}');
235 RETURN NULL;
236 END;
237 $$ LANGUAGE plpgsql
238 "#,
239 trigger_name, l_key
240 );
241
242 self.client.execute(&func_sql, &[]).await.unwrap();
243
244 let create_sql = format!(
245 "
246 CREATE TRIGGER
247 {}
248 AFTER INSERT OR UPDATE OR DELETE OR TRUNCATE ON
249 {}
250 EXECUTE PROCEDURE pg_temp.{}()
251 ",
252 trigger_name, table_name, trigger_name
253 );
254
255 self.client.execute(&create_sql, &[]).await.unwrap();
256
257 self.triggers.push(table_name.clone());
258 }
259 }
260 }
261
262 pub async fn update_result_table(&mut self) {
263 let i_table = "query_result";
264
265 let q_obj = self
266 .result_table
267 .fields
268 .iter()
269 .map(|name| format!("'{name}', i.{name}"))
270 .collect::<Vec<_>>()
271 .join(",");
272
273 let u_obj = self
274 .result_table
275 .fields
276 .iter()
277 .map(|name| format!("'{name}', u.{name}"))
278 .collect::<Vec<_>>()
279 .join(",");
280 let cols = self.result_table.fields.join(", ");
281 let set_cols = self
282 .result_table
283 .fields
284 .iter()
285 .map(|name| format!("{} = q.{}", name, name))
286 .collect::<Vec<_>>()
287 .join(", ");
288 let update_sql = format!(
289 "WITH
290 q AS (
291 SELECT
292 *,
293 ROW_NUMBER() OVER() AS lol
294 FROM
295 ({}) t
296 ),
297 i AS (
298 INSERT INTO {i_table} (
299 {cols}
300 )
301 SELECT
302 {cols}
303 FROM
304 q
305 WHERE q.id NOT IN (
306 SELECT id FROM {i_table}
307 )
308 RETURNING
309 {i_table}.*
310 ),
311 d AS (
312 DELETE FROM
313 {i_table}
314 WHERE
315 NOT EXISTS(
316 SELECT
317 1
318 FROM
319 q
320 WHERE
321 q.id = {i_table}.id
322 )
323 RETURNING
324 {i_table}.id
325 ),
326 u AS (
327 UPDATE {i_table} SET
328 {set_cols}
329 FROM
330 q
331 WHERE
332 {i_table}.id = q.id
333 RETURNING
334 {i_table}.*
335 )
336 SELECT
337 jsonb_build_object(
338 'id', i.id,
339 'op', 1,
340 'data', jsonb_build_object(
341 {q_obj}
342 )
343 ) AS c
344 FROM
345 i JOIN
346 q ON
347 i.id = q.id
348 UNION ALL
349 SELECT
350 jsonb_build_object(
351 'id', u.id,
352 'op', 2,
353 'data', jsonb_build_object({u_obj})
354 ) AS c
355 FROM
356 u JOIN
357 q ON
358 u.id = q.id
359 UNION ALL
360 SELECT
361 jsonb_build_object(
362 'id', d.id,
363 'op', 3,
364 'data', jsonb_build_object('id', d.id)
365 ) AS c
366 FROM
367 d
368 ",
369 self.query
370 );
371
372 let res: Vec<_> = self
373 .client
374 .query(&update_sql, &[])
375 .await
376 .unwrap_or_else(|err| {
377 panic!("error {}", err);
378 });
379
380 if self.first_run {
382 self.first_run = false;
383
384 return;
385 }
386
387 let res = res
388 .into_iter()
389 .map(|row| row.get("c"))
390 .collect::<Vec<Json<serde_json::Value>>>();
391
392 let json_value = res.iter().map(|json| json.0.clone()).collect::<Vec<_>>();
393 let json_events = json_value
394 .into_iter()
395 .map(|json| serde_json::from_value(json).unwrap())
396 .collect::<Vec<JsonEvent>>();
397
398 let events = json_events
399 .into_iter()
400 .map(|event| match event.op {
401 1 => Event::Insert(serde_json::from_value(event.data).unwrap()),
402 2 => Event::Update(serde_json::from_value(event.data).unwrap()),
403 3 => Event::Delete(event.id),
404 _ => unimplemented!(),
405 })
406 .collect::<Vec<Event<T>>>();
407
408 (self.cb)(events);
409 }
410
411 pub async fn handle_event(&mut self) {
412 if self.result_table.fields.is_empty() {
414 self.first_run = false;
415
416 if !self.setup_query_result_table().await {
417 return;
418 }
419 }
420
421 self.update_result_table().await;
422 }
423}
424
425pub async fn watch<T>(
426 query: &str,
427 handler: Box<dyn Fn(Vec<Event<T>>) + Sync + Send + 'static>,
428) -> JoinHandle<()>
429where
430 T: Debug + Send + Sync + 'static + DeserializeOwned,
431{
432 let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
433
434 let (client, connection) = tokio_postgres::connect(&db_url, NoTls).await.unwrap();
435
436 let mut watcher = Watcher::new(Arc::new(RwLock::new(WatcherCtx::new(
437 handler,
438 query.to_string(),
439 client,
440 ))));
441
442 watcher.start(connection).await
443}
444
/// A change to the watched query's result set, delivered to the watch callback.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Event<T> {
    /// A row whose `id` was not in the previous result; carries the row's values.
    Insert(T),
    /// A row (matched by its `id` column) that was already in the result; carries its
    /// current values.
    Update(T),
    /// The `id` of a row that is no longer in the result.
    Delete(i32),
}
451
/// Extension trait that lets a SQL string be watched directly, e.g.
/// `"SELECT id, name FROM users".watch(handler).await`.
///
/// `F` is the handler type, which receives batches of [`Event`]s.
#[async_trait]
pub trait WatchableSql<F> {
    /// Starts watching `self` as a query and calls `handler` with each batch of changes.
    ///
    /// The `str` implementation delegates to the free function `watch`.
    async fn watch<T>(&self, handler: F) -> JoinHandle<()>
    where
        F: Fn(Vec<Event<T>>) + Sync + Send + 'static,
        T: Debug + Send + Sync + 'static + DeserializeOwned;
}
459
460#[async_trait]
461impl<F> WatchableSql<F> for str {
462 async fn watch<T>(&self, handler: F) -> JoinHandle<()>
463 where
464 F: Fn(Vec<Event<T>>) + Sync + Send + 'static,
465 T: Debug + Send + Sync + 'static + DeserializeOwned,
466 {
467 watch(self, Box::new(handler)).await
468 }
469}