Skip to main content

pg_queue/
request_response.rs

1use crate::errors::{PgQueueError, Result};
2use crate::listen::ListenerService;
3use crate::queue::{QueueName, QueueRepository};
4use serde::{de::DeserializeOwned, Serialize};
5use sqlx::PgPool;
6use std::time::Duration;
7use tokio::time::timeout;
8use uuid::Uuid;
9
10/// Service for request-response pattern using PostgreSQL.
11/// Uses queue for requests and LISTEN/NOTIFY for responses.
12#[derive(Clone)]
13pub struct RequestResponseService {
14    pool: PgPool,
15    queue: QueueRepository,
16}
17
18impl RequestResponseService {
19    pub fn new(pool: PgPool, queue: QueueRepository) -> Self {
20        Self { pool, queue }
21    }
22
23    /// Push a request to a queue and wait for a response.
24    ///
25    /// Sets up a LISTEN channel before pushing to avoid race conditions.
26    /// The processor must call `store_response()` with the matching request_id.
27    pub async fn push_and_wait<Req, Resp>(
28        &self,
29        queue: &QueueName,
30        request: &Req,
31        timeout_duration: Duration,
32    ) -> Result<Resp>
33    where
34        Req: Serialize,
35        Resp: DeserializeOwned,
36    {
37        let request_id = Uuid::new_v4();
38        let channel = format!("response_{}", request_id);
39
40        // Set up listener before pushing to avoid race condition
41        let mut listener = ListenerService::new(&self.pool).await?;
42        listener.listen(&channel).await?;
43
44        // Wrap request with ID for correlation
45        let wrapped = RequestWrapper {
46            request_id,
47            payload: serde_json::to_value(request)?,
48        };
49
50        // Push to queue
51        self.queue.push(queue, &wrapped).await?;
52
53        // Wait for response notification
54        match timeout(timeout_duration, listener.recv()).await {
55            Ok(Ok(_notification)) => {
56                let response = self.fetch_response::<Resp>(&request_id).await?;
57                Ok(response)
58            }
59            Ok(Err(e)) => Err(e),
60            Err(_) => Err(PgQueueError::Timeout),
61        }
62    }
63
64    /// Store a response for a request (called by the processor)
65    pub async fn store_response<T: Serialize>(
66        &self,
67        request_id: &Uuid,
68        response: &T,
69    ) -> Result<()> {
70        let json = serde_json::to_value(response)?;
71
72        sqlx::query("INSERT INTO request_responses (request_id, response) VALUES ($1, $2)")
73            .bind(request_id)
74            .bind(json)
75            .execute(&self.pool)
76            .await?;
77
78        Ok(())
79    }
80
81    /// Fetch a stored response
82    async fn fetch_response<T: DeserializeOwned>(&self, request_id: &Uuid) -> Result<T> {
83        let row: (serde_json::Value,) =
84            sqlx::query_as("SELECT response FROM request_responses WHERE request_id = $1")
85                .bind(request_id)
86                .fetch_one(&self.pool)
87                .await?;
88
89        let parsed: T = serde_json::from_value(row.0)?;
90        Ok(parsed)
91    }
92
93    /// Clean up old responses (housekeeping)
94    pub async fn cleanup_old_responses(&self, older_than: Duration) -> Result<u64> {
95        let cutoff = chrono::Utc::now()
96            - chrono::Duration::from_std(older_than)
97                .map_err(|e| PgQueueError::Listener(e.to_string()))?;
98
99        let result = sqlx::query("DELETE FROM request_responses WHERE created_at < $1")
100            .bind(cutoff)
101            .execute(&self.pool)
102            .await?;
103
104        Ok(result.rows_affected())
105    }
106}
107
108/// Wrapper for requests to include correlation ID
109#[derive(Debug, Serialize, serde::Deserialize)]
110pub struct RequestWrapper {
111    pub request_id: Uuid,
112    pub payload: serde_json::Value,
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// PostgreSQL identifiers (including LISTEN/NOTIFY channel names) are limited to
    /// NAMEDATALEN - 1 = 63 bytes by default.
    const MAX_PG_IDENTIFIER_LEN: usize = 63;

    #[test]
    fn test_request_wrapper_serialization() {
        let request_id = Uuid::new_v4();
        let payload = serde_json::json!({"action": "test"});

        let wrapper = RequestWrapper {
            request_id,
            payload: payload.clone(),
        };

        let serialized = serde_json::to_string(&wrapper).unwrap();
        let deserialized: RequestWrapper = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.request_id, request_id);
        assert_eq!(deserialized.payload, payload);
    }

    #[test]
    fn test_request_wrapper_wire_format() {
        // Processors read `request_id` out of the queued JSON to call
        // `store_response`, so field names and the UUID's string form are a contract.
        let request_id = Uuid::new_v4();
        let wrapper = RequestWrapper {
            request_id,
            payload: serde_json::json!({"action": "test"}),
        };

        let value = serde_json::to_value(&wrapper).unwrap();

        assert_eq!(value["request_id"], serde_json::json!(request_id.to_string()));
        assert_eq!(value["payload"]["action"], "test");
        assert_eq!(value.as_object().map(|fields| fields.len()), Some(2));
    }

    #[test]
    fn test_response_channel_format() {
        let request_id = Uuid::new_v4();
        let channel = format!("response_{}", request_id);

        // The suffix must round-trip back to the same request ID.
        let suffix = channel
            .strip_prefix("response_")
            .expect("channel must start with `response_`");
        assert_eq!(suffix.parse::<Uuid>().unwrap(), request_id);

        // Longer names would be rejected by pg_notify (or truncated by LISTEN).
        assert!(
            channel.len() <= MAX_PG_IDENTIFIER_LEN,
            "channel `{}` exceeds PostgreSQL's identifier limit",
            channel
        );
    }
}