warpdrive_proxy/cache/invalidation.rs
1//! Cache invalidation via PostgreSQL LISTEN/NOTIFY
2//!
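//! This module provides distributed cache invalidation using PostgreSQL's
//! LISTEN/NOTIFY feature. When one WarpDrive instance invalidates a cache entry,
//! all other instances are notified via PostgreSQL and can clear their local caches.
//!
//! This module implements the receiving side. [`InvalidationListener`] subscribes to
//! the notification channel and deletes each announced key from the local cache.
//! [`InvalidationMessage`] defines the notification payload. PostgreSQL delivers a
//! NOTIFY only to sessions that are listening at that moment. As a result,
//! invalidations published while a listener is disconnected are not received.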
6//!
7//! # Architecture
8//!
9//! ```text
//! Instance A                  Instance B
//! ┌────────┐                  ┌────────┐
//! │ DELETE │                  │        │
//! │  key   │                  │        │
//! └───┬────┘                  └───▲────┘
//!     │                           │
//!     │    ┌──────────────┐       │
//!     └───►│  PostgreSQL  │───────┘
//!          │    NOTIFY    │
//!          │  "key=foo"   │
//!          └──────────────┘
21//! ```
22//!
23//! # Example
24//!
25//! ```no_run
26//! use warpdrive::cache::coordinator::CacheCoordinator;
27//! use warpdrive::cache::invalidation::InvalidationListener;
28//! use warpdrive::config::Config;
29//!
30//! # async fn example() -> anyhow::Result<()> {
31//! let config = Config::from_env()?;
32//! let cache = CacheCoordinator::from_config(&config).await?;
33//!
34//! // Start listener in background (if PostgreSQL configured)
35//! if let Some(db_url) = &config.database_url {
36//! InvalidationListener::spawn(
37//! db_url.clone(),
38//! config.pg_channel_cache_invalidation.clone(),
39//! cache.clone(),
40//! );
41//! }
42//! # Ok(())
43//! # }
44//! ```
45
46use futures::StreamExt;
47use serde::{Deserialize, Serialize};
48use std::sync::Arc;
49use tokio::task::JoinHandle;
50use tracing::{error, info, warn};
51
52use crate::cache::Cache;
53use crate::cache::coordinator::CacheCoordinator;
54use crate::postgres::PgListener;
55
56/// Cache invalidation message format
57///
58/// Sent via PostgreSQL NOTIFY when a cache entry is invalidated.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct InvalidationMessage {
61 /// Cache key to invalidate
62 pub key: String,
63
64 /// Source instance ID (for debugging/metrics)
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub source: Option<String>,
67
68 /// Reason for invalidation (for debugging)
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub reason: Option<String>,
71}
72
73impl InvalidationMessage {
74 /// Create a new invalidation message
75 pub fn new(key: String) -> Self {
76 InvalidationMessage {
77 key,
78 source: None,
79 reason: None,
80 }
81 }
82
83 /// Create invalidation message with metadata
84 pub fn with_metadata(key: String, source: String, reason: String) -> Self {
85 InvalidationMessage {
86 key,
87 source: Some(source),
88 reason: Some(reason),
89 }
90 }
91}
92
93/// Cache invalidation listener
94///
95/// Subscribes to PostgreSQL NOTIFY events and invalidates local cache entries.
96pub struct InvalidationListener {
97 _handle: JoinHandle<()>,
98}
99
100impl InvalidationListener {
101 /// Spawn a cache invalidation listener in the background
102 ///
103 /// Creates a PostgreSQL listener and processes invalidation messages.
104 /// The listener runs in a background task and will continue until the
105 /// application exits or an unrecoverable error occurs.
106 ///
107 /// # Arguments
108 ///
109 /// * `database_url` - PostgreSQL connection string
110 /// * `channel` - Channel name to listen on (e.g., "warpdrive:cache:invalidate")
111 /// * `cache` - Cache coordinator to invalidate entries from
112 ///
113 /// # Panics
114 ///
115 /// Does not panic. Errors are logged but the listener will attempt to reconnect.
116 ///
117 /// # Example
118 ///
119 /// ```no_run
120 /// # use warpdrive::cache::coordinator::CacheCoordinator;
121 /// # use warpdrive::cache::invalidation::InvalidationListener;
122 /// # async fn example(cache: CacheCoordinator) {
123 /// let listener = InvalidationListener::spawn(
124 /// "postgresql://localhost/warpdrive".to_string(),
125 /// "warpdrive:cache:invalidate".to_string(),
126 /// cache,
127 /// );
128 /// # }
129 /// ```
130 pub fn spawn(
131 database_url: String,
132 channel: String,
133 cache: CacheCoordinator,
134 ) -> InvalidationListener {
135 let handle =
136 tokio::spawn(
137 async move { Self::listen_loop(database_url, channel, Arc::new(cache)).await },
138 );
139
140 InvalidationListener { _handle: handle }
141 }
142
143 /// Main listen loop
144 ///
145 /// Connects to PostgreSQL, subscribes to the channel, and processes messages.
146 /// Automatically reconnects on connection loss.
147 async fn listen_loop(database_url: String, channel: String, cache: Arc<CacheCoordinator>) {
148 info!(
149 database_url = %database_url,
150 channel = %channel,
151 "Starting cache invalidation listener"
152 );
153
154 loop {
155 match Self::try_listen(&database_url, &channel, Arc::clone(&cache)).await {
156 Ok(()) => {
157 warn!("Cache invalidation listener exited cleanly, restarting...");
158 }
159 Err(e) => {
160 error!(
161 error = %e,
162 "Cache invalidation listener error, retrying in 5s..."
163 );
164 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
165 }
166 }
167 }
168 }
169
170 /// Attempt to listen for invalidation messages
171 ///
172 /// Returns when connection is lost or an unrecoverable error occurs.
173 async fn try_listen(
174 database_url: &str,
175 channel: &str,
176 cache: Arc<CacheCoordinator>,
177 ) -> anyhow::Result<()> {
178 // Create listener
179 let mut listener = PgListener::from_url(database_url, vec![channel.to_string()]).await?;
180
181 info!(channel = %channel, "Cache invalidation listener connected");
182
183 // Process messages
184 let mut stream = listener.stream();
185 while let Some(notification) = stream.next().await {
186 if let Err(e) = Self::handle_notification(notification, Arc::clone(&cache)).await {
187 error!(
188 error = %e,
189 "Failed to handle cache invalidation notification"
190 );
191 // Continue processing other notifications
192 }
193 }
194
195 Ok(())
196 }
197
198 /// Handle a single invalidation notification
199 async fn handle_notification(
200 notification: crate::postgres::PgNotification,
201 cache: Arc<CacheCoordinator>,
202 ) -> anyhow::Result<()> {
203 // Parse message
204 let message: InvalidationMessage = notification.parse_payload()?;
205
206 info!(
207 key = %message.key,
208 source = ?message.source,
209 reason = ?message.reason,
210 "Received cache invalidation notification"
211 );
212
213 // Invalidate cache entry
214 cache.delete(&message.key).await?;
215
216 // Update metrics
217 crate::metrics::CACHE_INVALIDATIONS
218 .with_label_values(&["pg_notify"])
219 .inc();
220
221 Ok(())
222 }
223}
224
225#[cfg(test)]
226mod tests {
227 use super::*;
228
229 #[test]
230 fn test_invalidation_message_new() {
231 let msg = InvalidationMessage::new("test_key".to_string());
232 assert_eq!(msg.key, "test_key");
233 assert!(msg.source.is_none());
234 assert!(msg.reason.is_none());
235 }
236
237 #[test]
238 fn test_invalidation_message_with_metadata() {
239 let msg = InvalidationMessage::with_metadata(
240 "test_key".to_string(),
241 "instance-1".to_string(),
242 "expired".to_string(),
243 );
244 assert_eq!(msg.key, "test_key");
245 assert_eq!(msg.source, Some("instance-1".to_string()));
246 assert_eq!(msg.reason, Some("expired".to_string()));
247 }
248
249 #[test]
250 fn test_invalidation_message_serialization() {
251 let msg = InvalidationMessage::new("test_key".to_string());
252 let json = serde_json::to_string(&msg).unwrap();
253 let parsed: InvalidationMessage = serde_json::from_str(&json).unwrap();
254 assert_eq!(parsed.key, "test_key");
255 }
256
257 #[test]
258 fn test_invalidation_message_with_metadata_serialization() {
259 let msg = InvalidationMessage::with_metadata(
260 "test_key".to_string(),
261 "instance-1".to_string(),
262 "expired".to_string(),
263 );
264 let json = serde_json::to_string(&msg).unwrap();
265 let parsed: InvalidationMessage = serde_json::from_str(&json).unwrap();
266 assert_eq!(parsed.key, "test_key");
267 assert_eq!(parsed.source, Some("instance-1".to_string()));
268 assert_eq!(parsed.reason, Some("expired".to_string()));
269 }
270}