1use casbin::{EventData, Watcher};
16use redis::{AsyncCommands, Client};
17use serde::{Deserialize, Serialize};
18use std::sync::{
19 atomic::{AtomicBool, Ordering},
20 Arc, Mutex,
21};
22use thiserror::Error;
23use tokio::sync::mpsc;
24use tokio::task::JoinHandle;
25use tokio_stream::StreamExt;
26
27#[derive(Error, Debug)]
30pub enum WatcherError {
31 #[error("Redis connection error: {0}")]
32 RedisConnection(#[from] redis::RedisError),
33
34 #[error("Serialization error: {0}")]
35 Serialization(#[from] serde_json::Error),
36
37 #[error("Callback not set")]
38 CallbackNotSet,
39
40 #[error("Watcher already closed")]
41 AlreadyClosed,
42
43 #[error("Configuration error: {0}")]
44 Configuration(String),
45
46 #[error("Runtime error: {0}")]
47 Runtime(String),
48}
49
50pub type Result<T> = std::result::Result<T, WatcherError>;
51
/// Boxed callback that receives the raw payload of an incoming notification.
type UpdateCallback = Box<dyn FnMut(String) + Send + Sync>;
/// Shared, lockable slot holding the callback; `None` until one is registered.
type CallbackArc = Arc<Mutex<Option<UpdateCallback>>>;
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
60#[serde(rename_all = "PascalCase")]
61pub enum UpdateType {
62 Update,
63 UpdateForAddPolicy,
64 UpdateForRemovePolicy,
65 UpdateForRemoveFilteredPolicy,
66 UpdateForSavePolicy,
67 UpdateForAddPolicies,
68 UpdateForRemovePolicies,
69 UpdateForUpdatePolicy,
70 UpdateForUpdatePolicies,
71}
72
73impl std::fmt::Display for UpdateType {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 UpdateType::Update => write!(f, "Update"),
77 UpdateType::UpdateForAddPolicy => write!(f, "UpdateForAddPolicy"),
78 UpdateType::UpdateForRemovePolicy => write!(f, "UpdateForRemovePolicy"),
79 UpdateType::UpdateForRemoveFilteredPolicy => write!(f, "UpdateForRemoveFilteredPolicy"),
80 UpdateType::UpdateForSavePolicy => write!(f, "UpdateForSavePolicy"),
81 UpdateType::UpdateForAddPolicies => write!(f, "UpdateForAddPolicies"),
82 UpdateType::UpdateForRemovePolicies => write!(f, "UpdateForRemovePolicies"),
83 UpdateType::UpdateForUpdatePolicy => write!(f, "UpdateForUpdatePolicy"),
84 UpdateType::UpdateForUpdatePolicies => write!(f, "UpdateForUpdatePolicies"),
85 }
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91#[serde(rename_all = "PascalCase")]
92pub struct Message {
93 pub method: UpdateType,
94 #[serde(rename = "ID")]
95 pub id: String,
96 #[serde(default, skip_serializing_if = "String::is_empty")]
97 pub sec: String,
98 #[serde(default, skip_serializing_if = "String::is_empty")]
99 pub ptype: String,
100 #[serde(default, skip_serializing_if = "Vec::is_empty")]
101 pub old_rule: Vec<String>,
102 #[serde(default, skip_serializing_if = "Vec::is_empty")]
103 pub old_rules: Vec<Vec<String>>,
104 #[serde(default, skip_serializing_if = "Vec::is_empty")]
105 pub new_rule: Vec<String>,
106 #[serde(default, skip_serializing_if = "Vec::is_empty")]
107 pub new_rules: Vec<Vec<String>>,
108 #[serde(default)]
109 pub field_index: i32,
110 #[serde(default, skip_serializing_if = "Vec::is_empty")]
111 pub field_values: Vec<String>,
112}
113
114impl Message {
115 pub fn new(method: UpdateType, id: String) -> Self {
116 Self {
117 method,
118 id,
119 sec: String::new(),
120 ptype: String::new(),
121 old_rule: Vec::new(),
122 old_rules: Vec::new(),
123 new_rule: Vec::new(),
124 new_rules: Vec::new(),
125 field_index: 0,
126 field_values: Vec::new(),
127 }
128 }
129
130 pub fn to_json(&self) -> Result<String> {
131 Ok(serde_json::to_string(self)?)
132 }
133
134 pub fn from_json(json: &str) -> Result<Self> {
135 Ok(serde_json::from_str(json)?)
136 }
137}
138
139fn event_data_to_message(event_data: &EventData, local_id: &str) -> Message {
143 match event_data {
144 EventData::AddPolicy(sec, ptype, rule) => {
145 let mut message = Message::new(UpdateType::UpdateForAddPolicy, local_id.to_string());
146 message.sec = sec.clone();
147 message.ptype = ptype.clone();
148 message.new_rule = rule.clone();
149 message
150 }
151 EventData::AddPolicies(sec, ptype, rules) => {
152 let mut message = Message::new(UpdateType::UpdateForAddPolicies, local_id.to_string());
153 message.sec = sec.clone();
154 message.ptype = ptype.clone();
155 message.new_rules = rules.clone();
156 message
157 }
158 EventData::RemovePolicy(sec, ptype, rule) => {
159 let mut message = Message::new(UpdateType::UpdateForRemovePolicy, local_id.to_string());
160 message.sec = sec.clone();
161 message.ptype = ptype.clone();
162 message.old_rule = rule.clone();
163 message
164 }
165 EventData::RemovePolicies(sec, ptype, rules) => {
166 let mut message =
167 Message::new(UpdateType::UpdateForRemovePolicies, local_id.to_string());
168 message.sec = sec.clone();
169 message.ptype = ptype.clone();
170 message.old_rules = rules.clone();
171 message
172 }
173 EventData::RemoveFilteredPolicy(sec, ptype, field_values) => {
174 let mut message = Message::new(
175 UpdateType::UpdateForRemoveFilteredPolicy,
176 local_id.to_string(),
177 );
178 message.sec = sec.clone();
179 message.ptype = ptype.clone();
180 if !field_values.is_empty() {
181 message.field_values = field_values[0].clone();
182 }
183 message
184 }
185 EventData::SavePolicy(_) => {
186 Message::new(UpdateType::UpdateForSavePolicy, local_id.to_string())
187 }
188 EventData::ClearPolicy => Message::new(UpdateType::Update, local_id.to_string()),
189 EventData::ClearCache => Message::new(UpdateType::Update, local_id.to_string()),
190 }
191}
192
193enum RedisClientWrapper {
197 Standalone(Client),
198 ClusterPubSub { pubsub_client: Client },
202}
203
204impl RedisClientWrapper {
205 async fn get_async_pubsub(&self) -> redis::RedisResult<redis::aio::PubSub> {
206 match self {
207 RedisClientWrapper::Standalone(client) => client.get_async_pubsub().await,
208 RedisClientWrapper::ClusterPubSub { pubsub_client } => {
209 pubsub_client.get_async_pubsub().await
211 }
212 }
213 }
214
215 async fn publish_message(&self, channel: &str, payload: String) -> redis::RedisResult<()> {
216 match self {
217 RedisClientWrapper::Standalone(client) => {
218 let mut conn = client.get_multiplexed_async_connection().await?;
219 let _: i32 = conn.publish(channel, payload).await?;
220 Ok(())
221 }
222 RedisClientWrapper::ClusterPubSub { pubsub_client } => {
223 let mut conn = pubsub_client.get_multiplexed_async_connection().await?;
227 let _: i32 = conn.publish(channel, payload).await?;
228 log::debug!("Published to cluster node via pubsub_client");
229 Ok(())
230 }
231 }
232 }
233}
234
235pub struct RedisWatcher {
238 client: Arc<RedisClientWrapper>,
239 options: crate::WatcherOptions,
240 callback: CallbackArc,
241 publish_tx: mpsc::UnboundedSender<Message>,
242 publish_task: Arc<Mutex<Option<JoinHandle<()>>>>,
243 subscription_task: Arc<Mutex<Option<JoinHandle<()>>>>,
244 is_closed: Arc<AtomicBool>,
245 subscription_ready: Arc<tokio::sync::Notify>,
246}
247
248impl RedisWatcher {
249 pub fn new(redis_url: &str, options: crate::WatcherOptions) -> Result<Self> {
251 let client = Arc::new(RedisClientWrapper::Standalone(Client::open(redis_url)?));
252
253 let (publish_tx, publish_rx) = mpsc::unbounded_channel::<Message>();
255
256 let is_closed = Arc::new(AtomicBool::new(false));
257 let subscription_ready = Arc::new(tokio::sync::Notify::new());
258
259 let publish_task = {
261 let client = client.clone();
262 let channel = options.channel.clone();
263 let is_closed = is_closed.clone();
264
265 tokio::spawn(async move {
266 Self::publish_worker(publish_rx, client, channel, is_closed).await
267 })
268 };
269
270 let watcher = Self {
271 client,
272 options,
273 callback: Arc::new(Mutex::new(None)),
274 publish_tx,
275 publish_task: Arc::new(Mutex::new(Some(publish_task))),
276 subscription_task: Arc::new(Mutex::new(None)),
277 is_closed,
278 subscription_ready,
279 };
280
281 watcher.start_subscription()?;
284
285 Ok(watcher)
286 }
287
288 pub fn new_cluster(cluster_urls: &str, options: crate::WatcherOptions) -> Result<Self> {
298 let urls: Vec<&str> = cluster_urls.split(',').map(|s| s.trim()).collect();
300 if urls.is_empty() {
301 return Err(WatcherError::Configuration(
302 "No cluster URLs provided".to_string(),
303 ));
304 }
305
306 let pubsub_url = urls[0];
310 let pubsub_client = Client::open(pubsub_url).map_err(|e| {
311 WatcherError::Configuration(format!("Failed to create pubsub client: {}", e))
312 })?;
313
314 log::warn!(
315 "Redis Cluster PubSub using fixed node: {} - ALL instances MUST use the SAME node!",
316 pubsub_url
317 );
318
319 let client = Arc::new(RedisClientWrapper::ClusterPubSub { pubsub_client });
320
321 let (publish_tx, publish_rx) = mpsc::unbounded_channel::<Message>();
323
324 let is_closed = Arc::new(AtomicBool::new(false));
325 let subscription_ready = Arc::new(tokio::sync::Notify::new());
326
327 let publish_task = {
329 let client = client.clone();
330 let channel = options.channel.clone();
331 let is_closed = is_closed.clone();
332
333 tokio::spawn(async move {
334 Self::publish_worker(publish_rx, client, channel, is_closed).await
335 })
336 };
337
338 let watcher = Self {
339 client,
340 options,
341 callback: Arc::new(Mutex::new(None)),
342 publish_tx,
343 publish_task: Arc::new(Mutex::new(Some(publish_task))),
344 subscription_task: Arc::new(Mutex::new(None)),
345 is_closed,
346 subscription_ready,
347 };
348
349 watcher.start_subscription()?;
352
353 Ok(watcher)
354 }
355
356 async fn publish_worker(
358 mut rx: mpsc::UnboundedReceiver<Message>,
359 client: Arc<RedisClientWrapper>,
360 channel: String,
361 is_closed: Arc<AtomicBool>,
362 ) {
363 while let Some(message) = rx.recv().await {
364 if is_closed.load(Ordering::Relaxed) {
365 break;
366 }
367
368 if let Ok(payload) = message.to_json() {
369 eprintln!(
370 "[RedisWatcher] Publishing message to channel {}: {}",
371 channel, payload
372 );
373
374 let mut retry_count = 0;
376 loop {
377 match client.publish_message(&channel, payload.clone()).await {
378 Ok(_) => {
379 eprintln!(
380 "[RedisWatcher] Successfully published message to channel: {}",
381 channel
382 );
383 break;
384 }
385 Err(e) => {
386 retry_count += 1;
387 eprintln!(
388 "[RedisWatcher] Failed to publish message (attempt {}): {}",
389 retry_count, e
390 );
391 if retry_count >= 3 {
392 eprintln!(
393 "[RedisWatcher] Failed to publish message after {} attempts: {}",
394 retry_count,
395 e
396 );
397 break;
398 }
399 tokio::time::sleep(tokio::time::Duration::from_millis(
400 100 * retry_count,
401 ))
402 .await;
403 }
404 }
405 }
406 } else {
407 eprintln!("[RedisWatcher] Failed to serialize message to JSON");
408 }
409 }
410 }
411
412 pub async fn wait_for_ready(&self) {
417 let timeout = tokio::time::Duration::from_secs(5);
419 let _ = tokio::time::timeout(timeout, self.subscription_ready.notified()).await;
420 }
421
422 fn publish_message(&self, message: &Message) -> Result<()> {
424 if self.is_closed.load(Ordering::Relaxed) {
425 return Err(WatcherError::AlreadyClosed);
426 }
427
428 self.publish_tx
429 .send(message.clone())
430 .map_err(|_| WatcherError::Runtime("Publish channel closed".to_string()))?;
431
432 Ok(())
433 }
434
435 fn start_subscription(&self) -> Result<()> {
437 if self.is_closed.load(Ordering::Relaxed) {
438 return Err(WatcherError::AlreadyClosed);
439 }
440
441 let callback = self.callback.clone();
442 let channel = self.options.channel.clone();
443 let local_id = self.options.local_id.clone();
444 let ignore_self = self.options.ignore_self;
445 let is_closed = self.is_closed.clone();
446 let client = self.client.clone();
447 let subscription_ready = self.subscription_ready.clone();
448
449 let handle = tokio::spawn(async move {
450 Self::subscription_worker(
451 client,
452 channel,
453 local_id,
454 ignore_self,
455 is_closed,
456 callback,
457 subscription_ready,
458 )
459 .await
460 });
461
462 *self.subscription_task.lock().unwrap() = Some(handle);
463 Ok(())
464 }
465
466 async fn subscription_worker(
468 client: Arc<RedisClientWrapper>,
469 channel: String,
470 local_id: String,
471 ignore_self: bool,
472 is_closed: Arc<AtomicBool>,
473 callback: CallbackArc,
474 subscription_ready: Arc<tokio::sync::Notify>,
475 ) {
476 let result = async {
477 let mut retry_count = 0;
479 let mut pubsub = loop {
480 if is_closed.load(Ordering::Relaxed) {
481 return Ok(());
482 }
483
484 match client.get_async_pubsub().await {
485 Ok(p) => break p,
486 Err(e) => {
487 retry_count += 1;
488 log::warn!(
489 "Failed to get async pubsub (attempt {}): {}",
490 retry_count,
491 e
492 );
493 if retry_count > 5 {
494 return Err(e);
495 }
496 tokio::time::sleep(tokio::time::Duration::from_millis(1000 * retry_count))
497 .await;
498 }
499 }
500 };
501
502 let mut subscribe_retry = 0;
504 loop {
505 if is_closed.load(Ordering::Relaxed) {
506 return Ok(());
507 }
508
509 match pubsub.subscribe(&channel).await {
510 Ok(_) => {
511 eprintln!(
512 "[RedisWatcher] Successfully subscribed to channel: {}",
513 channel
514 );
515 subscription_ready.notify_waiters();
517 break;
518 }
519 Err(e) => {
520 subscribe_retry += 1;
521 eprintln!(
522 "[RedisWatcher] Failed to subscribe to channel {} (attempt {}): {}",
523 channel, subscribe_retry, e
524 );
525 if subscribe_retry > 5 {
526 return Err(e);
527 }
528 tokio::time::sleep(tokio::time::Duration::from_millis(
529 500 * subscribe_retry,
530 ))
531 .await;
532 }
533 }
534 }
535
536 let mut stream = pubsub.on_message();
537
538 loop {
539 if is_closed.load(Ordering::Relaxed) {
541 break;
542 }
543
544 tokio::select! {
546 msg_opt = stream.next() => {
547 match msg_opt {
548 Some(msg) => {
549 let payload: String = msg.get_payload().unwrap_or_default();
550 eprintln!("[RedisWatcher] Received message on channel {}: {}", channel, payload);
551
552 if ignore_self {
554 if let Ok(parsed_msg) = Message::from_json(&payload) {
555 if parsed_msg.id == local_id {
556 eprintln!("[RedisWatcher] Ignoring self message from: {}", parsed_msg.id);
557 continue;
558 }
559 }
560 }
561
562 if let Ok(mut cb_guard) = callback.lock() {
564 if let Some(ref mut cb) = *cb_guard {
565 eprintln!("[RedisWatcher] Invoking callback for message");
566 cb(payload);
567 } else {
568 eprintln!("[RedisWatcher] Callback not set, message ignored");
569 }
570 } else {
571 eprintln!("[RedisWatcher] Failed to acquire callback lock");
572 }
573 }
574 None => {
575 eprintln!("[RedisWatcher] Pubsub stream ended");
577 break;
578 }
579 }
580 }
581 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
582 if is_closed.load(Ordering::Relaxed) {
584 break;
585 }
586 }
587 }
588 }
589
590 Ok::<(), redis::RedisError>(())
591 };
592
593 if let Err(e) = result.await {
594 log::error!("Subscription error: {}", e);
595 }
596 }
597}
598
599impl Watcher for RedisWatcher {
600 fn set_update_callback(&mut self, cb: Box<dyn FnMut(String) + Send + Sync>) {
601 eprintln!("[RedisWatcher] Setting update callback");
602 *self.callback.lock().unwrap() = Some(cb);
603
604 }
608
609 fn update(&mut self, d: EventData) {
610 let message = event_data_to_message(&d, &self.options.local_id);
611 eprintln!(
612 "[RedisWatcher] update() called with event: {:?}",
613 message.method
614 );
615 let _ = self.publish_message(&message);
616 }
617}
618
619impl Drop for RedisWatcher {
620 fn drop(&mut self) {
621 self.is_closed.store(true, Ordering::Relaxed);
623
624 if let Ok(mut handle_guard) = self.subscription_task.lock() {
626 if let Some(handle) = handle_guard.take() {
627 handle.abort();
628 }
629 }
630
631 if let Ok(mut handle_guard) = self.publish_task.lock() {
633 if let Some(handle) = handle_guard.take() {
634 handle.abort();
635 }
636 }
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// A message survives a JSON round trip.
    #[test]
    fn test_message_serialization() {
        let message = Message::new(UpdateType::Update, "test-id".to_string());
        let json = message.to_json().unwrap();
        let parsed = Message::from_json(&json).unwrap();
        assert_eq!(message.method, parsed.method);
        assert_eq!(message.id, parsed.id);
    }

    /// An `AddPolicy` event maps onto the matching message fields.
    #[test]
    fn test_event_data_conversion() {
        let event = EventData::AddPolicy(
            "p".to_string(),
            "p".to_string(),
            vec!["alice".to_string(), "data1".to_string(), "read".to_string()],
        );

        let message = event_data_to_message(&event, "test-id");
        assert_eq!(message.method, UpdateType::UpdateForAddPolicy);
        assert_eq!(message.sec, "p");
        assert_eq!(message.ptype, "p");
        assert_eq!(message.new_rule, vec!["alice", "data1", "read"]);
    }

    /// Pins the wire format: PascalCase keys, `ID` spelled in capitals, empty
    /// fields omitted and `FieldIndex` always present.
    #[test]
    fn test_message_wire_field_names() {
        let json = Message::new(UpdateType::Update, "test-id".to_string())
            .to_json()
            .unwrap();
        assert_eq!(
            json,
            r#"{"Method":"Update","ID":"test-id","FieldIndex":0}"#
        );
    }
}