redis_watcher/
watcher.rs

1// Copyright 2025 The Casbin Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use casbin::{EventData, Watcher};
16use redis::{AsyncCommands, Client};
17use serde::{Deserialize, Serialize};
18use std::sync::{
19    atomic::{AtomicBool, Ordering},
20    Arc, Mutex,
21};
22use thiserror::Error;
23use tokio::sync::mpsc;
24use tokio::task::JoinHandle;
25use tokio_stream::StreamExt;
26
27// ========== Error Types ==========
28
/// Errors reported by the Redis watcher.
///
/// The `RedisConnection` and `Serialization` variants carry `#[from]`, so the
/// `?` operator converts `redis::RedisError` and `serde_json::Error` into this
/// type automatically.
#[derive(Error, Debug)]
pub enum WatcherError {
    /// An error returned by the `redis` crate.
    #[error("Redis connection error: {0}")]
    RedisConnection(#[from] redis::RedisError),

    /// A JSON encoding or decoding error returned by `serde_json`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// No update callback is registered.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Callback not set")]
    CallbackNotSet,

    /// The operation was attempted after the watcher's closed flag was set.
    #[error("Watcher already closed")]
    AlreadyClosed,

    /// Invalid configuration, e.g. an unusable cluster URL list.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// A runtime failure, e.g. the internal publish queue is closed.
    #[error("Runtime error: {0}")]
    Runtime(String),
}
49
/// Convenience alias for a `Result` whose error type is [`WatcherError`].
pub type Result<T> = std::result::Result<T, WatcherError>;

// Type aliases to reduce complexity
/// Boxed update callback; the subscription worker calls it with the raw
/// payload string of each received message.
type UpdateCallback = Box<dyn FnMut(String) + Send + Sync>;
/// Shared, lockable slot holding the callback once one has been registered.
type CallbackArc = Arc<Mutex<Option<UpdateCallback>>>;
55
56// ========== Message Types ==========
57
/// Message types for communication between watcher instances
///
/// Serialized by serde under the PascalCase variant name (for example
/// `"UpdateForAddPolicy"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "PascalCase")]
pub enum UpdateType {
    /// Generic update; `event_data_to_message` emits it for `ClearPolicy` and
    /// `ClearCache` events.
    Update,
    /// A single rule was added (emitted for `EventData::AddPolicy`).
    UpdateForAddPolicy,
    /// A single rule was removed (emitted for `EventData::RemovePolicy`).
    UpdateForRemovePolicy,
    /// Rules were removed by filter (emitted for `EventData::RemoveFilteredPolicy`).
    UpdateForRemoveFilteredPolicy,
    /// The policy was saved (emitted for `EventData::SavePolicy`).
    UpdateForSavePolicy,
    /// Several rules were added (emitted for `EventData::AddPolicies`).
    UpdateForAddPolicies,
    /// Several rules were removed (emitted for `EventData::RemovePolicies`).
    UpdateForRemovePolicies,
    /// Not produced by `event_data_to_message` in this file; it can still be
    /// deserialized from incoming messages.
    UpdateForUpdatePolicy,
    /// Not produced by `event_data_to_message` in this file; it can still be
    /// deserialized from incoming messages.
    UpdateForUpdatePolicies,
}
72
73impl std::fmt::Display for UpdateType {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            UpdateType::Update => write!(f, "Update"),
77            UpdateType::UpdateForAddPolicy => write!(f, "UpdateForAddPolicy"),
78            UpdateType::UpdateForRemovePolicy => write!(f, "UpdateForRemovePolicy"),
79            UpdateType::UpdateForRemoveFilteredPolicy => write!(f, "UpdateForRemoveFilteredPolicy"),
80            UpdateType::UpdateForSavePolicy => write!(f, "UpdateForSavePolicy"),
81            UpdateType::UpdateForAddPolicies => write!(f, "UpdateForAddPolicies"),
82            UpdateType::UpdateForRemovePolicies => write!(f, "UpdateForRemovePolicies"),
83            UpdateType::UpdateForUpdatePolicy => write!(f, "UpdateForUpdatePolicy"),
84            UpdateType::UpdateForUpdatePolicies => write!(f, "UpdateForUpdatePolicies"),
85        }
86    }
87}
88
/// Message structure for Redis pub/sub communication
///
/// Serde renames the fields to PascalCase on the wire (`Method`, `OldRule`,
/// ...), except `id`, which is explicitly renamed to `ID`. Every optional
/// field falls back to its default when absent from the incoming JSON, and
/// empty strings/vectors are left out when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct Message {
    /// Kind of policy change being announced.
    pub method: UpdateType,
    /// Identifier of the sending watcher; compared against the local id to
    /// filter out self-sent messages.
    #[serde(rename = "ID")]
    pub id: String,
    /// Policy section (first string of the originating event).
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub sec: String,
    /// Policy type (second string of the originating event).
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub ptype: String,
    /// Rule that was removed (set for `UpdateForRemovePolicy`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub old_rule: Vec<String>,
    /// Rules that were removed (set for `UpdateForRemovePolicies`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub old_rules: Vec<Vec<String>>,
    /// Rule that was added (set for `UpdateForAddPolicy`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub new_rule: Vec<String>,
    /// Rules that were added (set for `UpdateForAddPolicies`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub new_rules: Vec<Vec<String>>,
    /// Always serialized (no `skip_serializing_if`); nothing in this file
    /// sets it to a value other than 0.
    #[serde(default)]
    pub field_index: i32,
    /// Values attached to an `UpdateForRemoveFilteredPolicy` message.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub field_values: Vec<String>,
}
113
114impl Message {
115    pub fn new(method: UpdateType, id: String) -> Self {
116        Self {
117            method,
118            id,
119            sec: String::new(),
120            ptype: String::new(),
121            old_rule: Vec::new(),
122            old_rules: Vec::new(),
123            new_rule: Vec::new(),
124            new_rules: Vec::new(),
125            field_index: 0,
126            field_values: Vec::new(),
127        }
128    }
129
130    pub fn to_json(&self) -> Result<String> {
131        Ok(serde_json::to_string(self)?)
132    }
133
134    pub fn from_json(json: &str) -> Result<Self> {
135        Ok(serde_json::from_str(json)?)
136    }
137}
138
139// ========== Helper Functions ==========
140
141/// Convert EventData to Message for publishing
142fn event_data_to_message(event_data: &EventData, local_id: &str) -> Message {
143    match event_data {
144        EventData::AddPolicy(sec, ptype, rule) => {
145            let mut message = Message::new(UpdateType::UpdateForAddPolicy, local_id.to_string());
146            message.sec = sec.clone();
147            message.ptype = ptype.clone();
148            message.new_rule = rule.clone();
149            message
150        }
151        EventData::AddPolicies(sec, ptype, rules) => {
152            let mut message = Message::new(UpdateType::UpdateForAddPolicies, local_id.to_string());
153            message.sec = sec.clone();
154            message.ptype = ptype.clone();
155            message.new_rules = rules.clone();
156            message
157        }
158        EventData::RemovePolicy(sec, ptype, rule) => {
159            let mut message = Message::new(UpdateType::UpdateForRemovePolicy, local_id.to_string());
160            message.sec = sec.clone();
161            message.ptype = ptype.clone();
162            message.old_rule = rule.clone();
163            message
164        }
165        EventData::RemovePolicies(sec, ptype, rules) => {
166            let mut message =
167                Message::new(UpdateType::UpdateForRemovePolicies, local_id.to_string());
168            message.sec = sec.clone();
169            message.ptype = ptype.clone();
170            message.old_rules = rules.clone();
171            message
172        }
173        EventData::RemoveFilteredPolicy(sec, ptype, field_values) => {
174            let mut message = Message::new(
175                UpdateType::UpdateForRemoveFilteredPolicy,
176                local_id.to_string(),
177            );
178            message.sec = sec.clone();
179            message.ptype = ptype.clone();
180            if !field_values.is_empty() {
181                message.field_values = field_values[0].clone();
182            }
183            message
184        }
185        EventData::SavePolicy(_) => {
186            Message::new(UpdateType::UpdateForSavePolicy, local_id.to_string())
187        }
188        EventData::ClearPolicy => Message::new(UpdateType::Update, local_id.to_string()),
189        EventData::ClearCache => Message::new(UpdateType::Update, local_id.to_string()),
190    }
191}
192
193// ========== Redis Client Wrapper ==========
194
/// Wrapper to support both standalone and cluster Redis
///
/// Both variants hold a plain single-node `Client`; the variant only records
/// which constructor created the watcher.
enum RedisClientWrapper {
    /// Client for a standalone Redis server.
    Standalone(Client),
    // For Cluster mode, we use a single node connection for pubsub
    // Redis Cluster PubSub messages don't propagate across nodes,
    // so all instances must connect to the same node for pub/sub
    /// Client pinned to one cluster node, used for both publish and subscribe.
    ClusterPubSub { pubsub_client: Client },
}
203
204impl RedisClientWrapper {
205    async fn get_async_pubsub(&self) -> redis::RedisResult<redis::aio::PubSub> {
206        match self {
207            RedisClientWrapper::Standalone(client) => client.get_async_pubsub().await,
208            RedisClientWrapper::ClusterPubSub { pubsub_client } => {
209                // Use the dedicated pubsub client for cluster mode
210                pubsub_client.get_async_pubsub().await
211            }
212        }
213    }
214
215    async fn publish_message(&self, channel: &str, payload: String) -> redis::RedisResult<()> {
216        match self {
217            RedisClientWrapper::Standalone(client) => {
218                let mut conn = client.get_multiplexed_async_connection().await?;
219                let _: i32 = conn.publish(channel, payload).await?;
220                Ok(())
221            }
222            RedisClientWrapper::ClusterPubSub { pubsub_client } => {
223                // For Redis Cluster, we need to publish to the same node where PubSub is subscribed
224                // because PubSub messages don't propagate across cluster nodes
225                // Use the pubsub_client (single node) for both publishing and subscribing
226                let mut conn = pubsub_client.get_multiplexed_async_connection().await?;
227                let _: i32 = conn.publish(channel, payload).await?;
228                log::debug!("Published to cluster node via pubsub_client");
229                Ok(())
230            }
231        }
232    }
233}
234
235// ========== Redis Watcher Implementation ==========
236
/// Casbin [`Watcher`] that exchanges policy-change messages over a Redis
/// pub/sub channel.
///
/// Dropping the watcher sets the closed flag and aborts both background tasks.
pub struct RedisWatcher {
    /// Redis client shared by the publish and subscription workers.
    client: Arc<RedisClientWrapper>,
    /// Watcher configuration; this file reads its `channel`, `local_id` and
    /// `ignore_self` fields.
    options: crate::WatcherOptions,
    /// Callback invoked with the payload of each received message, once set.
    callback: CallbackArc,
    /// Sending half of the queue drained by the publish worker.
    publish_tx: mpsc::UnboundedSender<Message>,
    /// Handle of the background publish task (taken and aborted on drop).
    publish_task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Handle of the background subscription task (taken and aborted on drop).
    subscription_task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Set to `true` on drop; polled by the workers and checked before
    /// publishing or subscribing.
    is_closed: Arc<AtomicBool>,
    /// Signalled by the subscription worker once the channel subscription
    /// has succeeded.
    subscription_ready: Arc<tokio::sync::Notify>,
}
247
248impl RedisWatcher {
249    /// Create a new Redis watcher for standalone Redis
250    pub fn new(redis_url: &str, options: crate::WatcherOptions) -> Result<Self> {
251        let client = Arc::new(RedisClientWrapper::Standalone(Client::open(redis_url)?));
252
253        // Create publish channel
254        let (publish_tx, publish_rx) = mpsc::unbounded_channel::<Message>();
255
256        let is_closed = Arc::new(AtomicBool::new(false));
257        let subscription_ready = Arc::new(tokio::sync::Notify::new());
258
259        // Spawn publish task
260        let publish_task = {
261            let client = client.clone();
262            let channel = options.channel.clone();
263            let is_closed = is_closed.clone();
264
265            tokio::spawn(async move {
266                Self::publish_worker(publish_rx, client, channel, is_closed).await
267            })
268        };
269
270        let watcher = Self {
271            client,
272            options,
273            callback: Arc::new(Mutex::new(None)),
274            publish_tx,
275            publish_task: Arc::new(Mutex::new(Some(publish_task))),
276            subscription_task: Arc::new(Mutex::new(None)),
277            is_closed,
278            subscription_ready,
279        };
280
281        // Start subscription immediately like Go version does
282        // This ensures the watcher is ready to receive messages before any publishes happen
283        watcher.start_subscription()?;
284
285        Ok(watcher)
286    }
287
288    /// Create a new Redis watcher for Redis Cluster
289    ///
290    /// Note: Redis Cluster PubSub messages don't propagate between nodes.
291    /// All instances MUST connect to the SAME node for pub/sub to work.
292    /// This method uses the first URL as the fixed PubSub node.
293    ///
294    /// # Arguments
295    /// * `cluster_urls` - Comma-separated Redis URLs (first URL used for PubSub)
296    /// * `options` - Watcher configuration options
297    pub fn new_cluster(cluster_urls: &str, options: crate::WatcherOptions) -> Result<Self> {
298        // Parse cluster URLs
299        let urls: Vec<&str> = cluster_urls.split(',').map(|s| s.trim()).collect();
300        if urls.is_empty() {
301            return Err(WatcherError::Configuration(
302                "No cluster URLs provided".to_string(),
303            ));
304        }
305
306        // For Redis Cluster PubSub: use the first node for both publish and subscribe
307        // This ensures messages are sent and received on the same node
308        // since PubSub messages don't propagate across cluster nodes
309        let pubsub_url = urls[0];
310        let pubsub_client = Client::open(pubsub_url).map_err(|e| {
311            WatcherError::Configuration(format!("Failed to create pubsub client: {}", e))
312        })?;
313
314        log::warn!(
315            "Redis Cluster PubSub using fixed node: {} - ALL instances MUST use the SAME node!",
316            pubsub_url
317        );
318
319        let client = Arc::new(RedisClientWrapper::ClusterPubSub { pubsub_client });
320
321        // Create publish channel
322        let (publish_tx, publish_rx) = mpsc::unbounded_channel::<Message>();
323
324        let is_closed = Arc::new(AtomicBool::new(false));
325        let subscription_ready = Arc::new(tokio::sync::Notify::new());
326
327        // Spawn publish task
328        let publish_task = {
329            let client = client.clone();
330            let channel = options.channel.clone();
331            let is_closed = is_closed.clone();
332
333            tokio::spawn(async move {
334                Self::publish_worker(publish_rx, client, channel, is_closed).await
335            })
336        };
337
338        let watcher = Self {
339            client,
340            options,
341            callback: Arc::new(Mutex::new(None)),
342            publish_tx,
343            publish_task: Arc::new(Mutex::new(Some(publish_task))),
344            subscription_task: Arc::new(Mutex::new(None)),
345            is_closed,
346            subscription_ready,
347        };
348
349        // Start subscription immediately like Go version does
350        // This ensures the watcher is ready to receive messages before any publishes happen
351        watcher.start_subscription()?;
352
353        Ok(watcher)
354    }
355
356    /// Background worker for publishing messages
357    async fn publish_worker(
358        mut rx: mpsc::UnboundedReceiver<Message>,
359        client: Arc<RedisClientWrapper>,
360        channel: String,
361        is_closed: Arc<AtomicBool>,
362    ) {
363        while let Some(message) = rx.recv().await {
364            if is_closed.load(Ordering::Relaxed) {
365                break;
366            }
367
368            if let Ok(payload) = message.to_json() {
369                eprintln!(
370                    "[RedisWatcher] Publishing message to channel {}: {}",
371                    channel, payload
372                );
373
374                // Retry publishing with exponential backoff
375                let mut retry_count = 0;
376                loop {
377                    match client.publish_message(&channel, payload.clone()).await {
378                        Ok(_) => {
379                            eprintln!(
380                                "[RedisWatcher] Successfully published message to channel: {}",
381                                channel
382                            );
383                            break;
384                        }
385                        Err(e) => {
386                            retry_count += 1;
387                            eprintln!(
388                                "[RedisWatcher] Failed to publish message (attempt {}): {}",
389                                retry_count, e
390                            );
391                            if retry_count >= 3 {
392                                eprintln!(
393                                    "[RedisWatcher] Failed to publish message after {} attempts: {}",
394                                    retry_count,
395                                    e
396                                );
397                                break;
398                            }
399                            tokio::time::sleep(tokio::time::Duration::from_millis(
400                                100 * retry_count,
401                            ))
402                            .await;
403                        }
404                    }
405                }
406            } else {
407                eprintln!("[RedisWatcher] Failed to serialize message to JSON");
408            }
409        }
410    }
411
412    /// Wait for subscription to be ready (similar to Go's WaitGroup.Wait())
413    ///
414    /// This ensures that the watcher is fully subscribed before publishing messages.
415    /// Recommended to call this after creating the watcher and before any policy operations.
416    pub async fn wait_for_ready(&self) {
417        // Wait with timeout
418        let timeout = tokio::time::Duration::from_secs(5);
419        let _ = tokio::time::timeout(timeout, self.subscription_ready.notified()).await;
420    }
421
422    /// Publish message to Redis channel
423    fn publish_message(&self, message: &Message) -> Result<()> {
424        if self.is_closed.load(Ordering::Relaxed) {
425            return Err(WatcherError::AlreadyClosed);
426        }
427
428        self.publish_tx
429            .send(message.clone())
430            .map_err(|_| WatcherError::Runtime("Publish channel closed".to_string()))?;
431
432        Ok(())
433    }
434
435    /// Start subscription to Redis channel
436    fn start_subscription(&self) -> Result<()> {
437        if self.is_closed.load(Ordering::Relaxed) {
438            return Err(WatcherError::AlreadyClosed);
439        }
440
441        let callback = self.callback.clone();
442        let channel = self.options.channel.clone();
443        let local_id = self.options.local_id.clone();
444        let ignore_self = self.options.ignore_self;
445        let is_closed = self.is_closed.clone();
446        let client = self.client.clone();
447        let subscription_ready = self.subscription_ready.clone();
448
449        let handle = tokio::spawn(async move {
450            Self::subscription_worker(
451                client,
452                channel,
453                local_id,
454                ignore_self,
455                is_closed,
456                callback,
457                subscription_ready,
458            )
459            .await
460        });
461
462        *self.subscription_task.lock().unwrap() = Some(handle);
463        Ok(())
464    }
465
466    /// Background worker for subscription
467    async fn subscription_worker(
468        client: Arc<RedisClientWrapper>,
469        channel: String,
470        local_id: String,
471        ignore_self: bool,
472        is_closed: Arc<AtomicBool>,
473        callback: CallbackArc,
474        subscription_ready: Arc<tokio::sync::Notify>,
475    ) {
476        let result = async {
477            // Retry connection with backoff
478            let mut retry_count = 0;
479            let mut pubsub = loop {
480                if is_closed.load(Ordering::Relaxed) {
481                    return Ok(());
482                }
483
484                match client.get_async_pubsub().await {
485                    Ok(p) => break p,
486                    Err(e) => {
487                        retry_count += 1;
488                        log::warn!(
489                            "Failed to get async pubsub (attempt {}): {}",
490                            retry_count,
491                            e
492                        );
493                        if retry_count > 5 {
494                            return Err(e);
495                        }
496                        tokio::time::sleep(tokio::time::Duration::from_millis(1000 * retry_count))
497                            .await;
498                    }
499                }
500            };
501
502            // Subscribe with retry
503            let mut subscribe_retry = 0;
504            loop {
505                if is_closed.load(Ordering::Relaxed) {
506                    return Ok(());
507                }
508
509                match pubsub.subscribe(&channel).await {
510                    Ok(_) => {
511                        eprintln!(
512                            "[RedisWatcher] Successfully subscribed to channel: {}",
513                            channel
514                        );
515                        // Notify that subscription is ready (similar to Go's WaitGroup.Done())
516                        subscription_ready.notify_waiters();
517                        break;
518                    }
519                    Err(e) => {
520                        subscribe_retry += 1;
521                        eprintln!(
522                            "[RedisWatcher] Failed to subscribe to channel {} (attempt {}): {}",
523                            channel, subscribe_retry, e
524                        );
525                        if subscribe_retry > 5 {
526                            return Err(e);
527                        }
528                        tokio::time::sleep(tokio::time::Duration::from_millis(
529                            500 * subscribe_retry,
530                        ))
531                        .await;
532                    }
533                }
534            }
535
536            let mut stream = pubsub.on_message();
537
538            loop {
539                // Check if closed before waiting for next message
540                if is_closed.load(Ordering::Relaxed) {
541                    break;
542                }
543
544                // Use tokio::select! to check for shutdown while waiting
545                tokio::select! {
546                    msg_opt = stream.next() => {
547                        match msg_opt {
548                            Some(msg) => {
549                                let payload: String = msg.get_payload().unwrap_or_default();
550                                eprintln!("[RedisWatcher] Received message on channel {}: {}", channel, payload);
551
552                                // Parse message and check if we should ignore it
553                                if ignore_self {
554                                    if let Ok(parsed_msg) = Message::from_json(&payload) {
555                                        if parsed_msg.id == local_id {
556                                            eprintln!("[RedisWatcher] Ignoring self message from: {}", parsed_msg.id);
557                                            continue;
558                                        }
559                                    }
560                                }
561
562                                // Call callback
563                                if let Ok(mut cb_guard) = callback.lock() {
564                                    if let Some(ref mut cb) = *cb_guard {
565                                        eprintln!("[RedisWatcher] Invoking callback for message");
566                                        cb(payload);
567                                    } else {
568                                        eprintln!("[RedisWatcher] Callback not set, message ignored");
569                                    }
570                                } else {
571                                    eprintln!("[RedisWatcher] Failed to acquire callback lock");
572                                }
573                            }
574                            None => {
575                                // Stream ended
576                                eprintln!("[RedisWatcher] Pubsub stream ended");
577                                break;
578                            }
579                        }
580                    }
581                    _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
582                        // Periodic check for shutdown
583                        if is_closed.load(Ordering::Relaxed) {
584                            break;
585                        }
586                    }
587                }
588            }
589
590            Ok::<(), redis::RedisError>(())
591        };
592
593        if let Err(e) = result.await {
594            log::error!("Subscription error: {}", e);
595        }
596    }
597}
598
599impl Watcher for RedisWatcher {
600    fn set_update_callback(&mut self, cb: Box<dyn FnMut(String) + Send + Sync>) {
601        eprintln!("[RedisWatcher] Setting update callback");
602        *self.callback.lock().unwrap() = Some(cb);
603
604        // Note: Unlike the old implementation, we don't restart subscription here
605        // because subscription is already started in new()/new_cluster()
606        // This matches the Go implementation where subscribe() is called in NewWatcher()
607    }
608
609    fn update(&mut self, d: EventData) {
610        let message = event_data_to_message(&d, &self.options.local_id);
611        eprintln!(
612            "[RedisWatcher] update() called with event: {:?}",
613            message.method
614        );
615        let _ = self.publish_message(&message);
616    }
617}
618
619impl Drop for RedisWatcher {
620    fn drop(&mut self) {
621        // Signal closure first
622        self.is_closed.store(true, Ordering::Relaxed);
623
624        // Abort subscription task
625        if let Ok(mut handle_guard) = self.subscription_task.lock() {
626            if let Some(handle) = handle_guard.take() {
627                handle.abort();
628            }
629        }
630
631        // Abort publish task
632        if let Ok(mut handle_guard) = self.publish_task.lock() {
633            if let Some(handle) = handle_guard.take() {
634                handle.abort();
635            }
636        }
637    }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// A message keeps its method and id across a JSON round trip.
    #[test]
    fn test_message_serialization() {
        let original = Message::new(UpdateType::Update, String::from("test-id"));
        let encoded = original.to_json().unwrap();
        let decoded = Message::from_json(&encoded).unwrap();

        assert_eq!(original.method, decoded.method);
        assert_eq!(original.id, decoded.id);
    }

    /// An `AddPolicy` event maps onto an `UpdateForAddPolicy` message that
    /// carries the section, policy type and new rule.
    #[test]
    fn test_event_data_conversion() {
        let rule: Vec<String> = ["alice", "data1", "read"]
            .iter()
            .map(|part| part.to_string())
            .collect();
        let event = EventData::AddPolicy(String::from("p"), String::from("p"), rule);

        let message = event_data_to_message(&event, "test-id");

        assert_eq!(message.method, UpdateType::UpdateForAddPolicy);
        assert_eq!(message.sec, "p");
        assert_eq!(message.ptype, "p");
        assert_eq!(message.new_rule, vec!["alice", "data1", "read"]);
    }

    // Note: Integration tests that require Redis are in watcher_test.rs
}