reactive_pg/
watcher.rs

1/// Adapted from https://github.com/nothingisdead/pg-live-query/blob/master/watcher.js
2use async_trait::async_trait;
3use derive_new::new;
4use futures::StreamExt;
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use std::{fmt::Debug, sync::Arc};
7use tokio::{sync::RwLock, task::JoinHandle};
8
9use tokio_postgres::{
10    tls::NoTlsStream, types::Json, AsyncMessage, Client, Connection, NoTls, Socket,
11};
12
/// Name and column list of the temporary table that mirrors the watched query's result.
///
/// The `Default` value (empty name, no fields) is what a freshly built `WatcherCtx` holds
/// before the table has been set up.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TmpTable {
    /// Name of the temporary table.
    pub name: String,
    /// Column names of the temporary table, in the order they were collected.
    pub fields: Vec<String>,
}
18
19#[derive(new)]
20pub struct Watcher<T> {
21    ctx: Arc<RwLock<WatcherCtx<T>>>,
22}
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct JsonEvent {
26    pub id: i32,
27    pub op: i32,
28    pub data: serde_json::Value,
29}
30
31impl<T> Watcher<T>
32where
33    T: Sync + Send + 'static + DeserializeOwned,
34{
35    pub async fn start(
36        &mut self,
37        mut connection: Connection<Socket, NoTlsStream>,
38    ) -> JoinHandle<()> {
39        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
40
41        let handle = tokio::spawn(async move {
42            let mut stream = futures::stream::poll_fn(move |cx| connection.poll_message(cx));
43
44            while let Some(message) = stream.next().await {
45                let message = message
46                    .map_err(|e| {
47                        eprintln!("failed to get message from db: {}", e);
48                        e
49                    })
50                    .unwrap();
51
52                if let AsyncMessage::Notification(not) = &message {
53                    let updated_table_nb: u32 = serde_json::from_str(not.payload()).unwrap();
54
55                    tx.send(updated_table_nb)
56                        .await
57                        .map_err(|e| {
58                            eprintln!("failed to send message on channel: {:?}", e);
59                            e
60                        })
61                        .unwrap();
62                }
63            }
64        });
65
66        let ctx = Arc::clone(&self.ctx);
67
68        tokio::spawn(async move {
69            while rx.recv().await.is_some() {
70                ctx.write().await.handle_event().await;
71            }
72        });
73
74        self.ctx.write().await.start().await;
75
76        handle
77    }
78}
79
80#[derive(new)]
81pub struct WatcherCtx<T> {
82    cb: Box<dyn Fn(Vec<Event<T>>) + Sync + Send + 'static>,
83    query: String,
84    pub client: Client,
85
86    #[new(default)]
87    triggers: Vec<String>,
88
89    #[new(default)]
90    result_table: TmpTable,
91
92    #[new(default)]
93    source_tables: Vec<String>,
94
95    #[new(value = "true")]
96    first_run: bool,
97
98    phantom: std::marker::PhantomData<T>,
99}
100
101impl<T> WatcherCtx<T>
102where
103    T: Sync + Send + 'static + DeserializeOwned,
104{
105    pub async fn start(&mut self) {
106        self.client
107            .query("LISTEN __live_update;", &[])
108            .await
109            .unwrap();
110
111        self.init().await
112    }
113
114    pub async fn init(&mut self) {
115        self.collect_source_tables();
116        self.create_triggers().await;
117
118        // If no initial record, we can't infer the table schema. Skipping until the first event
119        if !self.setup_query_result_table().await {
120            return;
121        }
122
123        self.update_result_table().await;
124    }
125
126    pub async fn setup_query_result_table(&mut self) -> bool {
127        let tmp_table_name = "query_result";
128
129        let columns = self.get_query_result_columns().await;
130
131        if columns.is_empty() {
132            // Delay the setup until there is at least one record
133            return false;
134        }
135
136        let fields = columns.iter().map(|(name, _)| name.clone()).collect();
137
138        let columns_def = columns
139            .iter()
140            .map(|(name, t)| format!("{} {} NOT NULL", name, t))
141            .collect::<Vec<_>>()
142            .join(",\n");
143
144        let query = format!(
145            r#"
146            CREATE TEMP TABLE {} (
147                {}
148            )
149        "#,
150            tmp_table_name, columns_def,
151        );
152
153        self.client.execute(&query, &[]).await.unwrap();
154
155        self.result_table = TmpTable {
156            name: tmp_table_name.to_string(),
157            fields,
158        };
159
160        true
161    }
162
163    pub async fn get_query_result_columns(&self) -> Vec<(String, String)> {
164        let query = format!("SELECT * FROM ({}) q LIMIT 1", self.query);
165
166        let columns = if let Some(first) = self.client.query(&query, &[]).await.unwrap().get(0) {
167            first
168                .columns()
169                .iter()
170                .map(|c| (c.name().to_string(), c.type_().name().to_string()))
171                .collect()
172        } else {
173            vec![]
174        };
175
176        columns
177    }
178
179    pub fn collect_source_tables(&mut self) {
180        use sqlparser::dialect::PostgreSqlDialect;
181        let sql_ast =
182            sqlparser::parser::Parser::parse_sql(&PostgreSqlDialect {}, &self.query).unwrap();
183        use sqlparser::ast::{SetExpr, Statement, TableFactor};
184        let mut names = Vec::new();
185
186        for stmt in sql_ast {
187            match stmt {
188                Statement::Query(query) => match &*query.body {
189                    SetExpr::Select(select) => {
190                        for table in &select.from {
191                            match &table.relation {
192                                TableFactor::Table {
193                                    name,
194                                    alias: _,
195                                    args: _,
196                                    with_hints: _,
197                                } => names.push(name.to_string()),
198                                _ => {}
199                            }
200                        }
201                    }
202                    _ => {}
203                },
204                _ => {}
205            }
206        }
207
208        self.source_tables = names;
209    }
210
211    pub async fn create_triggers(&mut self) {
212        for (i, table_name) in self.source_tables.iter().enumerate() {
213            if !self.triggers.contains(&table_name) {
214                let trigger_name = &format!(r#""__live_update{}""#, table_name);
215                let l_key = i.to_string();
216
217                let drop_sql = format!(
218                    r#"
219                    DROP TRIGGER IF EXISTS
220                        {}
221                    ON
222                        {}
223                "#,
224                    trigger_name, table_name
225                );
226
227                self.client.execute(&drop_sql, &[]).await.unwrap();
228
229                let func_sql = format!(
230                    r#"
231                    CREATE OR REPLACE FUNCTION pg_temp.{}()
232                    RETURNS TRIGGER AS $$
233                        BEGIN
234                            EXECUTE pg_notify('__live_update', '{}');
235                        RETURN NULL;
236                        END;
237                    $$ LANGUAGE plpgsql
238                "#,
239                    trigger_name, l_key
240                );
241
242                self.client.execute(&func_sql, &[]).await.unwrap();
243
244                let create_sql = format!(
245                    "
246                    CREATE TRIGGER
247                        {}
248                    AFTER INSERT OR UPDATE OR DELETE OR TRUNCATE ON
249                        {}
250                    EXECUTE PROCEDURE pg_temp.{}()
251                ",
252                    trigger_name, table_name, trigger_name
253                );
254
255                self.client.execute(&create_sql, &[]).await.unwrap();
256
257                self.triggers.push(table_name.clone());
258            }
259        }
260    }
261
262    pub async fn update_result_table(&mut self) {
263        let i_table = "query_result";
264
265        let q_obj = self
266            .result_table
267            .fields
268            .iter()
269            .map(|name| format!("'{name}', i.{name}"))
270            .collect::<Vec<_>>()
271            .join(",");
272
273        let u_obj = self
274            .result_table
275            .fields
276            .iter()
277            .map(|name| format!("'{name}', u.{name}"))
278            .collect::<Vec<_>>()
279            .join(",");
280        let cols = self.result_table.fields.join(", ");
281        let set_cols = self
282            .result_table
283            .fields
284            .iter()
285            .map(|name| format!("{} = q.{}", name, name))
286            .collect::<Vec<_>>()
287            .join(", ");
288        let update_sql = format!(
289            "WITH
290                q AS (
291                    SELECT
292                        *,
293                        ROW_NUMBER() OVER() AS lol
294                    FROM
295                        ({}) t
296                ),
297                i AS (
298                    INSERT INTO {i_table} (
299                        {cols}
300                    )
301                    SELECT
302                        {cols}
303                    FROM
304                        q
305                    WHERE q.id NOT IN (
306                        SELECT id FROM {i_table}
307                    )
308                    RETURNING
309                        {i_table}.*
310                ),
311				d AS (
312					DELETE FROM
313						{i_table}
314					WHERE
315						NOT EXISTS(
316							SELECT
317								1
318							FROM
319								q
320							WHERE
321								q.id = {i_table}.id
322						)
323					RETURNING
324						{i_table}.id
325				),
326                u AS (
327					UPDATE {i_table} SET
328                        {set_cols}
329					FROM
330						q
331					WHERE
332						{i_table}.id = q.id
333					RETURNING
334						{i_table}.*
335				)
336            SELECT
337                jsonb_build_object(
338                    'id', i.id,
339                    'op', 1,
340                    'data', jsonb_build_object(
341                        {q_obj}
342                    )
343                ) AS c
344            FROM
345                i JOIN
346                q ON
347                    i.id = q.id
348            UNION ALL
349			SELECT
350				jsonb_build_object(
351					'id', u.id,
352					'op', 2,
353					'data', jsonb_build_object({u_obj})
354				) AS c
355			FROM
356				u JOIN
357				q ON
358					u.id = q.id
359			UNION ALL
360			SELECT
361				jsonb_build_object(
362					'id', d.id,
363					'op', 3,
364                    'data', jsonb_build_object('id', d.id)
365				) AS c
366			FROM
367				d
368        ",
369            self.query
370        );
371
372        let res: Vec<_> = self
373            .client
374            .query(&update_sql, &[])
375            .await
376            .unwrap_or_else(|err| {
377                panic!("error {}", err);
378            });
379
380        // Don't send the first result, as it will contains the whole query.
381        if self.first_run {
382            self.first_run = false;
383
384            return;
385        }
386
387        let res = res
388            .into_iter()
389            .map(|row| row.get("c"))
390            .collect::<Vec<Json<serde_json::Value>>>();
391
392        let json_value = res.iter().map(|json| json.0.clone()).collect::<Vec<_>>();
393        let json_events = json_value
394            .into_iter()
395            .map(|json| serde_json::from_value(json).unwrap())
396            .collect::<Vec<JsonEvent>>();
397
398        let events = json_events
399            .into_iter()
400            .map(|event| match event.op {
401                1 => Event::Insert(serde_json::from_value(event.data).unwrap()),
402                2 => Event::Update(serde_json::from_value(event.data).unwrap()),
403                3 => Event::Delete(event.id),
404                _ => unimplemented!(),
405            })
406            .collect::<Vec<Event<T>>>();
407
408        (self.cb)(events);
409    }
410
411    pub async fn handle_event(&mut self) {
412        // The table was previously empty, sending it all after setting up the triggers.
413        if self.result_table.fields.is_empty() {
414            self.first_run = false;
415
416            if !self.setup_query_result_table().await {
417                return;
418            }
419        }
420
421        self.update_result_table().await;
422    }
423}
424
425pub async fn watch<T>(
426    query: &str,
427    handler: Box<dyn Fn(Vec<Event<T>>) + Sync + Send + 'static>,
428) -> JoinHandle<()>
429where
430    T: Debug + Send + Sync + 'static + DeserializeOwned,
431{
432    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
433
434    let (client, connection) = tokio_postgres::connect(&db_url, NoTls).await.unwrap();
435
436    let mut watcher = Watcher::new(Arc::new(RwLock::new(WatcherCtx::new(
437        handler,
438        query.to_string(),
439        client,
440    ))));
441
442    watcher.start(connection).await
443}
444
445#[derive(Serialize, Deserialize, Debug, Clone)]
446pub enum Event<T> {
447    Insert(T),
448    Update(T),
449    Delete(i32),
450}
451
452#[async_trait]
453pub trait WatchableSql<F> {
454    async fn watch<T>(&self, handler: F) -> JoinHandle<()>
455    where
456        F: Fn(Vec<Event<T>>) + Sync + Send + 'static,
457        T: Debug + Send + Sync + 'static + DeserializeOwned;
458}
459
460#[async_trait]
461impl<F> WatchableSql<F> for str {
462    async fn watch<T>(&self, handler: F) -> JoinHandle<()>
463    where
464        F: Fn(Vec<Event<T>>) + Sync + Send + 'static,
465        T: Debug + Send + Sync + 'static + DeserializeOwned,
466    {
467        watch(self, Box::new(handler)).await
468    }
469}