pg_queue/
request_response.rs1use crate::errors::{PgQueueError, Result};
2use crate::listen::ListenerService;
3use crate::queue::{QueueName, QueueRepository};
4use serde::{de::DeserializeOwned, Serialize};
5use sqlx::PgPool;
6use std::time::Duration;
7use tokio::time::timeout;
8use uuid::Uuid;
9
10#[derive(Clone)]
13pub struct RequestResponseService {
14 pool: PgPool,
15 queue: QueueRepository,
16}
17
18impl RequestResponseService {
19 pub fn new(pool: PgPool, queue: QueueRepository) -> Self {
20 Self { pool, queue }
21 }
22
23 pub async fn push_and_wait<Req, Resp>(
28 &self,
29 queue: &QueueName,
30 request: &Req,
31 timeout_duration: Duration,
32 ) -> Result<Resp>
33 where
34 Req: Serialize,
35 Resp: DeserializeOwned,
36 {
37 let request_id = Uuid::new_v4();
38 let channel = format!("response_{}", request_id);
39
40 let mut listener = ListenerService::new(&self.pool).await?;
42 listener.listen(&channel).await?;
43
44 let wrapped = RequestWrapper {
46 request_id,
47 payload: serde_json::to_value(request)?,
48 };
49
50 self.queue.push(queue, &wrapped).await?;
52
53 match timeout(timeout_duration, listener.recv()).await {
55 Ok(Ok(_notification)) => {
56 let response = self.fetch_response::<Resp>(&request_id).await?;
57 Ok(response)
58 }
59 Ok(Err(e)) => Err(e),
60 Err(_) => Err(PgQueueError::Timeout),
61 }
62 }
63
64 pub async fn store_response<T: Serialize>(
66 &self,
67 request_id: &Uuid,
68 response: &T,
69 ) -> Result<()> {
70 let json = serde_json::to_value(response)?;
71
72 sqlx::query("INSERT INTO request_responses (request_id, response) VALUES ($1, $2)")
73 .bind(request_id)
74 .bind(json)
75 .execute(&self.pool)
76 .await?;
77
78 Ok(())
79 }
80
81 async fn fetch_response<T: DeserializeOwned>(&self, request_id: &Uuid) -> Result<T> {
83 let row: (serde_json::Value,) =
84 sqlx::query_as("SELECT response FROM request_responses WHERE request_id = $1")
85 .bind(request_id)
86 .fetch_one(&self.pool)
87 .await?;
88
89 let parsed: T = serde_json::from_value(row.0)?;
90 Ok(parsed)
91 }
92
93 pub async fn cleanup_old_responses(&self, older_than: Duration) -> Result<u64> {
95 let cutoff = chrono::Utc::now()
96 - chrono::Duration::from_std(older_than)
97 .map_err(|e| PgQueueError::Listener(e.to_string()))?;
98
99 let result = sqlx::query("DELETE FROM request_responses WHERE created_at < $1")
100 .bind(cutoff)
101 .execute(&self.pool)
102 .await?;
103
104 Ok(result.rows_affected())
105 }
106}
107
108#[derive(Debug, Serialize, serde::Deserialize)]
110pub struct RequestWrapper {
111 pub request_id: Uuid,
112 pub payload: serde_json::Value,
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// A wrapper survives a JSON round trip unchanged.
    #[test]
    fn test_request_wrapper_serialization() {
        let request_id = Uuid::new_v4();
        let payload = serde_json::json!({"action": "test"});

        let wrapper = RequestWrapper {
            request_id,
            payload: payload.clone(),
        };

        let serialized = serde_json::to_string(&wrapper).unwrap();
        let deserialized: RequestWrapper = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.request_id, request_id);
        assert_eq!(deserialized.payload, payload);
    }

    /// The response channel embeds the request id and must fit Postgres'
    /// 63-byte identifier limit; a longer name would be truncated or rejected
    /// by LISTEN/NOTIFY.
    #[test]
    fn test_response_channel_format() {
        let request_id = Uuid::new_v4();
        let channel = format!("response_{}", request_id);
        assert!(channel.starts_with("response_"));
        assert!(channel.ends_with(&request_id.to_string()));
        assert!(channel.len() <= 63);
    }
}