// tasker_pgmq/client.rs
//! # Unified PGMQ Client
//!
//! This module provides a unified `PostgreSQL` Message Queue (PGMQ) client that combines
//! all functionality from both the original `PgmqNotifyClient` and the tasker-shared `PgmqClient`.
//! It includes both notification capabilities and comprehensive PGMQ operations.
use crate::{
    config::PgmqNotifyConfig,
    error::{PgmqNotifyError, Result},
    listener::PgmqNotifyListener,
    types::{ClientStatus, QueueMetrics},
};
use pgmq::{types::Message, PGMQueueExt};
use regex::Regex;
use sqlx::{PgPool, Row};
use std::collections::HashMap;
use tracing::{debug, error, info, instrument, warn};
18
19/// Unified PGMQ client with comprehensive functionality and notification capabilities
20#[derive(Debug, Clone)]
21pub struct PgmqClient {
22    /// Underlying PGMQ client
23    pgmq: PGMQueueExt,
24    /// Database connection pool for advanced operations and health checks
25    pool: sqlx::PgPool,
26    /// Configuration for notifications and queue naming
27    config: PgmqNotifyConfig,
28}
29
30impl PgmqClient {
31    /// Create new unified PGMQ client using connection string
32    pub async fn new(database_url: &str) -> Result<Self> {
33        Self::new_with_config(database_url, PgmqNotifyConfig::default()).await
34    }
35
36    /// Create new unified PGMQ client with custom configuration
37    pub async fn new_with_config(database_url: &str, config: PgmqNotifyConfig) -> Result<Self> {
38        info!("Connecting to pgmq using unified client");
39
40        let max_connections = 20;
41        let pgmq = PGMQueueExt::new(database_url.to_string(), max_connections).await?;
42        let pool = pgmq.connection.clone();
43
44        info!("Connected to pgmq using unified client");
45        Ok(Self { pgmq, pool, config })
46    }
47
48    /// Create new unified PGMQ client using existing connection pool (BYOP - Bring Your Own Pool)
49    pub async fn new_with_pool(pool: sqlx::PgPool) -> Self {
50        Self::new_with_pool_and_config(pool, PgmqNotifyConfig::default()).await
51    }
52
53    /// Create new unified PGMQ client with existing pool and custom configuration
54    pub async fn new_with_pool_and_config(pool: sqlx::PgPool, config: PgmqNotifyConfig) -> Self {
55        info!("Creating unified pgmq client with shared connection pool");
56
57        let pgmq = PGMQueueExt::new_with_pool(pool.clone()).await;
58
59        info!("Unified pgmq client created with shared pool");
60        Self { pgmq, pool, config }
61    }
62
63    /// Create queue if it doesn't exist
64    #[instrument(skip(self), fields(queue = %queue_name))]
65    pub async fn create_queue(&self, queue_name: &str) -> Result<()> {
66        debug!("๐Ÿ“‹ Creating queue: {}", queue_name);
67
68        let created = self.pgmq.create(queue_name).await?;
69
70        if created {
71            info!("Queue created: {}", queue_name);
72        } else {
73            debug!("Queue already exists: {}", queue_name);
74        }
75        Ok(())
76    }
77
78    /// Send generic JSON message to queue
79    #[instrument(skip(self, message), fields(queue = %queue_name))]
80    pub async fn send_json_message<T>(&self, queue_name: &str, message: &T) -> Result<i64>
81    where
82        T: serde::Serialize,
83    {
84        debug!("๐Ÿ“ค Sending JSON message to queue: {}", queue_name);
85
86        let serialized = serde_json::to_value(message)?;
87
88        // Use wrapper function for atomic message sending + notification
89        let message_id = sqlx::query_scalar!(
90            "SELECT pgmq_send_with_notify($1, $2, $3)",
91            queue_name,
92            &serialized,
93            0i32
94        )
95        .fetch_one(&self.pool)
96        .await?;
97
98        let message_id = message_id.ok_or_else(|| {
99            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
100        })?;
101
102        info!(
103            "JSON message sent to queue: {} with ID: {} (with notification)",
104            queue_name, message_id
105        );
106        Ok(message_id)
107    }
108
109    /// Send message with visibility timeout (delay)
110    #[instrument(skip(self, message), fields(queue = %queue_name, delay_seconds = %delay_seconds))]
111    pub async fn send_message_with_delay<T>(
112        &self,
113        queue_name: &str,
114        message: &T,
115        delay_seconds: u64,
116    ) -> Result<i64>
117    where
118        T: serde::Serialize,
119    {
120        debug!(
121            "๐Ÿ“ค Sending delayed message to queue: {} with delay: {}s",
122            queue_name, delay_seconds
123        );
124
125        let serialized = serde_json::to_value(message)?;
126
127        // Use wrapper function for atomic message sending + notification
128        let message_id = sqlx::query_scalar!(
129            "SELECT pgmq_send_with_notify($1, $2, $3)",
130            queue_name,
131            &serialized,
132            delay_seconds as i32
133        )
134        .fetch_one(&self.pool)
135        .await?;
136
137        let message_id = message_id.ok_or_else(|| {
138            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
139        })?;
140
141        info!(
142            "Delayed message sent to queue: {} with ID: {} (with notification)",
143            queue_name, message_id
144        );
145        Ok(message_id)
146    }
147
148    /// Read messages from queue (non-blocking)
149    ///
150    /// Uses `max_poll_seconds=0` to avoid blocking on empty queues. This is critical
151    /// for the fallback poller which iterates many queues sequentially โ€” blocking 5s
152    /// per empty queue creates unacceptable latency (25 queues ร— 5s = 125s per cycle).
153    #[instrument(skip(self), fields(queue = %queue_name, limit = ?limit))]
154    pub async fn read_messages(
155        &self,
156        queue_name: &str,
157        visibility_timeout: Option<i32>,
158        limit: Option<i32>,
159    ) -> Result<Vec<Message<serde_json::Value>>> {
160        debug!(
161            "๐Ÿ“ฅ Reading messages from queue: {} (limit: {:?})",
162            queue_name, limit
163        );
164
165        let vt = visibility_timeout.unwrap_or(30);
166        let messages = match limit {
167            Some(l) => self
168                .pgmq
169                .read_batch_with_poll::<serde_json::Value>(queue_name, vt, l, None, None)
170                .await?
171                .unwrap_or_default(),
172            None => match self.pgmq.read::<serde_json::Value>(queue_name, vt).await? {
173                Some(msg) => vec![msg],
174                None => vec![],
175            },
176        };
177
178        debug!(
179            "๐Ÿ“จ Read {} messages from queue: {}",
180            messages.len(),
181            queue_name
182        );
183        Ok(messages)
184    }
185
186    /// Read messages from queue with pop (single read and delete)
187    #[instrument(skip(self), fields(queue = %queue_name))]
188    pub async fn pop_message(
189        &self,
190        queue_name: &str,
191    ) -> Result<Option<Message<serde_json::Value>>> {
192        debug!("๐Ÿ“ฅ Popping message from queue: {}", queue_name);
193
194        let message = self.pgmq.pop::<serde_json::Value>(queue_name).await?;
195
196        if message.is_some() {
197            debug!("๐Ÿ“จ Popped message from queue: {}", queue_name);
198        } else {
199            debug!("๐Ÿ“ญ No messages available in queue: {}", queue_name);
200        }
201        Ok(message)
202    }
203
204    /// Read a specific message by ID using custom SQL function (for notification event handling)
205    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
206    pub async fn read_specific_message<T>(
207        &self,
208        queue_name: &str,
209        message_id: i64,
210        visibility_timeout: i32,
211    ) -> Result<Option<Message<T>>>
212    where
213        T: serde::de::DeserializeOwned,
214    {
215        debug!(
216            "๐Ÿ“ฅ Reading specific message {} from queue: {}",
217            message_id, queue_name
218        );
219
220        // Use the custom SQL function pgmq_read_specific_message for efficient specific message reading
221        let query = "SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq_read_specific_message($1, $2, $3)";
222
223        let row = sqlx::query(query)
224            .bind(queue_name)
225            .bind(message_id)
226            .bind(visibility_timeout)
227            .fetch_optional(&self.pool)
228            .await?;
229
230        if let Some(row) = row {
231            let msg_id: i64 = row.get("msg_id");
232            let read_ct: i32 = row.get("read_ct");
233            let enqueued_at: chrono::DateTime<chrono::Utc> = row.get("enqueued_at");
234            let vt: chrono::DateTime<chrono::Utc> = row.get("vt");
235            let message_json: serde_json::Value = row.get("message");
236
237            // Try to deserialize the message to the expected type
238            match serde_json::from_value::<T>(message_json) {
239                Ok(deserialized) => {
240                    let typed_message = Message {
241                        msg_id,
242                        read_ct,
243                        enqueued_at,
244                        vt,
245                        message: deserialized,
246                    };
247                    debug!("Found and deserialized specific message {}", message_id);
248                    Ok(Some(typed_message))
249                }
250                Err(e) => {
251                    warn!("Failed to deserialize message {}: {}", message_id, e);
252                    Err(PgmqNotifyError::Serialization(e))
253                }
254            }
255        } else {
256            debug!(
257                "๐Ÿ“ญ Specific message {} not found in queue: {}",
258                message_id, queue_name
259            );
260            Ok(None)
261        }
262    }
263
264    /// Delete message from queue
265    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
266    pub async fn delete_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
267        debug!(
268            "๐Ÿ—‘Deleting message {} from queue: {}",
269            message_id, queue_name
270        );
271
272        let _ = self.pgmq.delete(queue_name, message_id).await?;
273
274        debug!("Message deleted: {}", message_id);
275        Ok(())
276    }
277
278    /// Archive message (move to archive)
279    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id))]
280    pub async fn archive_message(&self, queue_name: &str, message_id: i64) -> Result<()> {
281        debug!(
282            "Archiving message {} from queue: {}",
283            message_id, queue_name
284        );
285
286        let _ = self.pgmq.archive(queue_name, message_id).await?;
287
288        debug!("Message archived: {}", message_id);
289        Ok(())
290    }
291
292    /// Set visibility timeout for a message
293    ///
294    /// Extends or resets the visibility timeout for a message. This is useful for:
295    /// - Heartbeat during long-running step processing
296    /// - Returning a message to the queue immediately (vt_seconds = 0)
297    ///
298    /// # Arguments
299    ///
300    /// * `queue_name` - The queue containing the message
301    /// * `message_id` - The message ID to update
302    /// * `vt_seconds` - New visibility timeout in seconds from now
303    ///
304    /// # Returns
305    ///
306    /// Returns `Ok(())` on success, or an error if the message doesn't exist.
307    #[instrument(skip(self), fields(queue = %queue_name, message_id = %message_id, vt_seconds = %vt_seconds))]
308    pub async fn set_visibility_timeout(
309        &self,
310        queue_name: &str,
311        message_id: i64,
312        vt_seconds: i32,
313    ) -> Result<()> {
314        debug!(
315            "Setting visibility timeout for message {} in queue {} to {} seconds",
316            message_id, queue_name, vt_seconds
317        );
318
319        // Use pgmq.set_vt SQL function directly for precise control
320        // The function signature is: pgmq.set_vt(queue_name text, msg_id bigint, vt_offset integer)
321        sqlx::query_scalar!(
322            "SELECT msg_id FROM pgmq.set_vt($1::text, $2::bigint, $3::integer)",
323            queue_name,
324            message_id,
325            vt_seconds
326        )
327        .fetch_optional(&self.pool)
328        .await?;
329
330        debug!(
331            "Visibility timeout set for message {} to {} seconds",
332            message_id, vt_seconds
333        );
334        Ok(())
335    }
336
337    /// Purge queue (delete all messages)
338    #[instrument(skip(self), fields(queue = %queue_name))]
339    pub async fn purge_queue(&self, queue_name: &str) -> Result<u64> {
340        warn!("๐Ÿงน Purging queue: {}", queue_name);
341
342        let purged_count = self.pgmq.purge_queue(queue_name).await?;
343
344        warn!(
345            "๐Ÿ—‘Purged {} messages from queue: {}",
346            purged_count, queue_name
347        );
348        Ok(purged_count as u64)
349    }
350
351    /// Drop queue completely
352    #[instrument(skip(self), fields(queue = %queue_name))]
353    pub async fn drop_queue(&self, queue_name: &str) -> Result<()> {
354        warn!("๐Ÿ’ฅ Dropping queue: {}", queue_name);
355
356        self.pgmq.drop_queue(queue_name).await?;
357
358        warn!("๐Ÿ—‘Queue dropped: {}", queue_name);
359        Ok(())
360    }
361
362    /// Get queue metrics/statistics
363    #[instrument(skip(self), fields(queue = %queue_name))]
364    pub async fn queue_metrics(&self, queue_name: &str) -> Result<QueueMetrics> {
365        debug!("Getting metrics for queue: {}", queue_name);
366
367        // Query actual pgmq metrics from the database using pgmq.metrics() function
368        let row = sqlx::query!(
369            "SELECT queue_length, oldest_msg_age_sec FROM pgmq.metrics($1)",
370            queue_name
371        )
372        .fetch_optional(&self.pool)
373        .await?;
374
375        if let Some(row) = row {
376            Ok(QueueMetrics {
377                queue_name: queue_name.to_string(),
378                message_count: row.queue_length.unwrap_or(0),
379                consumer_count: None,
380                oldest_message_age_seconds: row.oldest_msg_age_sec.map(i64::from),
381            })
382        } else {
383            // Queue doesn't exist or has no metrics
384            Ok(QueueMetrics {
385                queue_name: queue_name.to_string(),
386                message_count: 0,
387                consumer_count: None,
388                oldest_message_age_seconds: None,
389            })
390        }
391    }
392
393    /// Get reference to underlying connection pool for advanced operations
394    #[must_use]
395    pub fn pool(&self) -> &sqlx::PgPool {
396        &self.pool
397    }
398
399    /// Get reference to pgmq client for direct access
400    #[must_use]
401    pub fn pgmq(&self) -> &PGMQueueExt {
402        &self.pgmq
403    }
404
405    /// Get the configuration
406    #[must_use]
407    pub fn config(&self) -> &PgmqNotifyConfig {
408        &self.config
409    }
410
411    /// Check if this client has notification capabilities enabled
412    #[must_use]
413    pub fn has_notify_capabilities(&self) -> bool {
414        // Check if the config has triggers enabled which enable notifications
415        self.config.enable_triggers
416    }
417
418    /// Health check - verify database connectivity
419    #[instrument(skip(self))]
420    pub async fn health_check(&self) -> Result<bool> {
421        match sqlx::query!("SELECT 1 as health_check")
422            .fetch_one(&self.pool)
423            .await
424        {
425            Ok(_) => {
426                debug!("Health check passed");
427                Ok(true)
428            }
429            Err(e) => {
430                error!("Health check failed: {}", e);
431                Ok(false)
432            }
433        }
434    }
435
436    /// Get client status information
437    #[instrument(skip(self))]
438    pub async fn get_client_status(&self) -> Result<ClientStatus> {
439        let healthy = self.health_check().await.unwrap_or(false);
440
441        Ok(ClientStatus {
442            client_type: "pgmq-unified".to_string(),
443            connected: healthy,
444            connection_info: HashMap::from([
445                (
446                    "backend".to_string(),
447                    serde_json::Value::String("postgresql".to_string()),
448                ),
449                (
450                    "queue_type".to_string(),
451                    serde_json::Value::String("pgmq".to_string()),
452                ),
453                (
454                    "has_notifications".to_string(),
455                    serde_json::Value::Bool(true),
456                ),
457                (
458                    "pool_size".to_string(),
459                    serde_json::Value::Number(self.pool.size().into()),
460                ),
461            ]),
462            last_activity: Some(chrono::Utc::now()),
463        })
464    }
465
466    /// Extract namespace from queue name using configured pattern
467    #[must_use]
468    pub fn extract_namespace(&self, queue_name: &str) -> Option<String> {
469        let pattern = &self.config.queue_naming_pattern;
470        if let Ok(regex) = Regex::new(pattern) {
471            if let Some(captures) = regex.captures(queue_name) {
472                if let Some(namespace_match) = captures.name("namespace") {
473                    return Some(namespace_match.as_str().to_string());
474                }
475            }
476        }
477
478        // Fallback: assume queue name is "{namespace}_queue"
479        queue_name
480            .ends_with("_queue")
481            .then(|| queue_name.trim_end_matches("_queue").to_string())
482    }
483
484    /// Create a notify listener for this client
485    ///
486    /// # Arguments
487    /// * `buffer_size` - MPSC channel buffer size (TAS-51: bounded channels)
488    ///
489    /// # Note
490    /// TAS-51: Migrated from unbounded to bounded channel to prevent OOM during notification bursts.
491    /// Buffer size should come from configuration based on context:
492    /// - Orchestration: `config.mpsc_channels.orchestration.event_listeners.pgmq_event_buffer_size`
493    /// - Worker: `config.mpsc_channels.worker.event_listeners.pgmq_event_buffer_size`
494    pub async fn create_listener(&self, buffer_size: usize) -> Result<PgmqNotifyListener> {
495        PgmqNotifyListener::new(self.pool.clone(), self.config.clone(), buffer_size).await
496    }
497}
498
499/// Helper methods for common queue operations
500impl PgmqClient {
501    /// Process messages from namespace queue
502    #[instrument(skip(self), fields(namespace = %namespace, batch_size = %batch_size))]
503    pub async fn process_namespace_queue(
504        &self,
505        namespace: &str,
506        visibility_timeout: Option<i32>,
507        batch_size: i32,
508    ) -> Result<Vec<Message<serde_json::Value>>> {
509        let queue_name = format!("worker_{namespace}_queue");
510        self.read_messages(&queue_name, visibility_timeout, Some(batch_size))
511            .await
512    }
513
514    /// Complete message processing (delete from queue)
515    #[instrument(skip(self), fields(namespace = %namespace, message_id = %message_id))]
516    pub async fn complete_message(&self, namespace: &str, message_id: i64) -> Result<()> {
517        let queue_name = format!("worker_{namespace}_queue");
518        self.delete_message(&queue_name, message_id).await
519    }
520
521    /// Initialize standard namespace queues
522    #[instrument(skip(self, namespaces))]
523    pub async fn initialize_namespace_queues(&self, namespaces: &[&str]) -> Result<()> {
524        info!("Initializing {} namespace queues", namespaces.len());
525
526        for namespace in namespaces {
527            let queue_name = format!("worker_{namespace}_queue");
528            self.create_queue(&queue_name).await?;
529        }
530
531        info!("Initialized all namespace queues");
532        Ok(())
533    }
534
535    /// Send message within a transaction (for atomic operations)
536    #[instrument(skip(self, message, tx), fields(queue = %queue_name))]
537    pub async fn send_with_transaction<T>(
538        &self,
539        queue_name: &str,
540        message: &T,
541        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
542    ) -> Result<i64>
543    where
544        T: serde::Serialize,
545    {
546        debug!(
547            "๐Ÿ“ค Sending message within transaction to queue: {}",
548            queue_name
549        );
550
551        let serialized = serde_json::to_value(message)?;
552
553        // Use wrapper function within the transaction for atomic message sending + notification
554        let message_id = sqlx::query_scalar!(
555            "SELECT pgmq_send_with_notify($1, $2, $3)",
556            queue_name,
557            &serialized,
558            0i32
559        )
560        .fetch_one(&mut **tx)
561        .await?;
562
563        let message_id = message_id.ok_or_else(|| {
564            PgmqNotifyError::Generic(anyhow::anyhow!("Wrapper function returned NULL message ID"))
565        })?;
566
567        debug!(
568            "Message sent in transaction with id: {} (with notification)",
569            message_id
570        );
571        Ok(message_id)
572    }
573}
574
/// Factory for creating `PgmqClient` instances
///
/// Stateless convenience wrapper over the `PgmqClient` constructors.
#[derive(Debug)]
pub struct PgmqClientFactory;
578
579impl PgmqClientFactory {
580    /// Create new client from database URL
581    pub async fn create(database_url: &str) -> Result<PgmqClient> {
582        PgmqClient::new(database_url).await
583    }
584
585    /// Create new client with configuration
586    pub async fn create_with_config(
587        database_url: &str,
588        config: PgmqNotifyConfig,
589    ) -> Result<PgmqClient> {
590        PgmqClient::new_with_config(database_url, config).await
591    }
592
593    /// Create new client with existing pool
594    pub async fn create_with_pool(pool: PgPool) -> PgmqClient {
595        PgmqClient::new_with_pool(pool).await
596    }
597
598    /// Create new client with existing pool and configuration
599    pub async fn create_with_pool_and_config(pool: PgPool, config: PgmqNotifyConfig) -> PgmqClient {
600        PgmqClient::new_with_pool_and_config(pool, config).await
601    }
602}
603
604// Re-export for backward compatibility
605pub use PgmqClient as PgmqNotifyClient;
606pub use PgmqClientFactory as PgmqNotifyClientFactory;
607
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn test_pgmq_client_creation() {
        dotenv().ok();
        // This test requires a PostgreSQL database with the pgmq extension.
        // Skip in CI or when no database is available.
        // IDIOM FIX: previously looked up DATABASE_URL twice and used a
        // needless trailing `return`; let-else does both lookups in one.
        let Ok(database_url) = std::env::var("DATABASE_URL") else {
            println!("Skipping pgmq test - no DATABASE_URL provided");
            return;
        };

        match PgmqClient::new(&database_url).await {
            Ok(_) => println!("PgmqClient created successfully"),
            // Tolerate URL-parsing/connection errors so the test still passes
            // in environments without a proper database setup.
            Err(e) => println!("Skipping test due to client creation error: {e:?}"),
        }
    }

    #[test]
    fn test_namespace_extraction() {
        dotenv().ok();
        let config = PgmqNotifyConfig::new().with_queue_naming_pattern(r"(?P<namespace>\w+)_queue");

        // Test the pattern matching directly without needing a full client.
        let regex = regex::Regex::new(&config.queue_naming_pattern).unwrap();

        // Valid patterns capture the leading namespace.
        let captures = regex.captures("orders_queue").unwrap();
        assert_eq!(captures.name("namespace").unwrap().as_str(), "orders");

        let captures = regex.captures("inventory_queue").unwrap();
        assert_eq!(captures.name("namespace").unwrap().as_str(), "inventory");

        // A name without the `_queue` suffix must not match.
        assert!(regex.captures("invalid_name").is_none());
    }

    #[tokio::test]
    async fn test_shared_pool_pattern() {
        dotenv().ok();
        // Skip test if no database URL provided.
        let Ok(database_url) = std::env::var("DATABASE_URL") else {
            println!("Skipping shared pool test - no DATABASE_URL provided");
            return;
        };

        // Create a connection pool to share with the client.
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&database_url)
            .await
            .expect("Failed to create connection pool");

        // Create pgmq client with the shared pool (BYOP path).
        let client = PgmqClient::new_with_pool(pool.clone()).await;

        // Verify the client exposes the same pool.
        assert_eq!(client.pool().size(), pool.size());

        println!("Shared pool pattern working correctly");
    }
}