Skip to main content

strev_postgres/
subscriber.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use bytes::Bytes;
6use serde_json::Value;
7use sqlx::{PgPool, Row};
8use strev::{CloseError, Message, MessageStream, Metadata, SubscribeError, Topic};
9use tokio::sync::mpsc::Sender;
10
11use crate::schema::ensure_schema;
12
13pub struct PostgresSubscriberConfig {
14    pub pool: PgPool,
15    pub consumer_group: String,
16    pub poll_interval: Duration,
17    pub batch_size: i64,
18    pub buffer_size: usize,
19}
20
21impl PostgresSubscriberConfig {
22    pub fn new(pool: PgPool, consumer_group: impl Into<String>) -> Self {
23        Self {
24            pool,
25            consumer_group: consumer_group.into(),
26            poll_interval: Duration::from_millis(200),
27            batch_size: 100,
28            buffer_size: 64,
29        }
30    }
31}
32
33pub struct PostgresSubscriber {
34    config: Arc<PostgresSubscriberConfig>,
35}
36
37impl PostgresSubscriber {
38    pub fn new(config: PostgresSubscriberConfig) -> Self {
39        Self {
40            config: Arc::new(config),
41        }
42    }
43}
44
45#[async_trait]
46impl strev::Subscriber for PostgresSubscriber {
47    async fn subscribe(&self, topic: &Topic) -> Result<MessageStream, SubscribeError> {
48        let config = self.config.clone();
49        let topic = topic.as_str().to_string();
50
51        ensure_schema(&config.pool)
52            .await
53            .map_err(|e| SubscribeError::Backend(Box::new(e)))?;
54
55        sqlx::query(
56            "INSERT INTO strev_offsets (consumer_group, topic, last_id) VALUES ($1, $2, 0) ON CONFLICT DO NOTHING",
57        )
58        .bind(&config.consumer_group)
59        .bind(&topic)
60        .execute(&config.pool)
61        .await
62        .map_err(|e| SubscribeError::Backend(Box::new(e)))?;
63
64        let (sender, stream) = MessageStream::channel(config.buffer_size);
65
66        tokio::spawn(async move {
67            loop {
68                if sender.is_closed() {
69                    break;
70                }
71
72                match poll_once(&config, &topic, &sender).await {
73                    Ok(count) if count > 0 => continue,
74                    Ok(_) => tokio::time::sleep(config.poll_interval).await,
75                    Err(_) => tokio::time::sleep(config.poll_interval).await,
76                }
77            }
78        });
79
80        Ok(stream)
81    }
82
83    async fn close(&mut self) -> Result<(), CloseError> {
84        Ok(())
85    }
86}
87
88async fn poll_once(
89    config: &PostgresSubscriberConfig,
90    topic: &str,
91    sender: &Sender<Message>,
92) -> Result<usize, sqlx::Error> {
93    let mut tx = config.pool.begin().await?;
94
95    let locked = sqlx::query(
96        "SELECT last_id FROM strev_offsets WHERE consumer_group = $1 AND topic = $2 FOR UPDATE SKIP LOCKED",
97    )
98    .bind(&config.consumer_group)
99    .bind(topic)
100    .fetch_optional(&mut *tx)
101    .await?;
102
103    let last_id: i64 = match locked {
104        Some(row) => row.try_get("last_id")?,
105        None => {
106            tx.rollback().await?;
107            return Ok(0);
108        }
109    };
110
111    let rows = sqlx::query(
112        "SELECT id, payload, metadata FROM strev_messages WHERE topic = $1 AND id > $2 ORDER BY id ASC LIMIT $3",
113    )
114    .bind(topic)
115    .bind(last_id)
116    .bind(config.batch_size)
117    .fetch_all(&mut *tx)
118    .await?;
119
120    if rows.is_empty() {
121        tx.rollback().await?;
122        return Ok(0);
123    }
124
125    let mut max_id = last_id;
126    for row in &rows {
127        let id: i64 = row.try_get("id")?;
128        let payload: Vec<u8> = row.try_get("payload")?;
129        let metadata_json: Value = row.try_get("metadata")?;
130
131        let mut metadata = Metadata::new();
132        if let Value::Object(map) = metadata_json {
133            for (key, value) in map {
134                if let Value::String(text) = value {
135                    metadata.set(key, text);
136                }
137            }
138        }
139
140        let message = Message::with_metadata(Bytes::from(payload), metadata);
141        if sender.send(message).await.is_err() {
142            tx.rollback().await?;
143            return Ok(0);
144        }
145
146        max_id = id;
147    }
148
149    sqlx::query("UPDATE strev_offsets SET last_id = $1 WHERE consumer_group = $2 AND topic = $3")
150        .bind(max_id)
151        .bind(&config.consumer_group)
152        .bind(topic)
153        .execute(&mut *tx)
154        .await?;
155
156    tx.commit().await?;
157    Ok(rows.len())
158}