use casbin::{EventData, Watcher};
use redis::{AsyncCommands, Client};
use serde::{Deserialize, Serialize};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
use std::thread::{self, JoinHandle};
use thiserror::Error;
use tokio::runtime::Runtime;
use tokio_stream::StreamExt;
/// Errors produced by the Redis watcher and by [`Message`] (de)serialization.
#[derive(Error, Debug)]
pub enum WatcherError {
    /// A Redis client/connection operation failed. Also produced when `Client::open`
    /// rejects a connection URL.
    #[error("Redis connection error: {0}")]
    RedisConnection(#[from] redis::RedisError),
    /// JSON encoding or decoding of a [`Message`] failed.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// An update callback was required but none had been set.
    /// NOTE(review): not constructed anywhere in this file — confirm it is still used.
    #[error("Callback not set")]
    CallbackNotSet,
    /// The watcher's close flag is set, so it can no longer publish or subscribe.
    #[error("Watcher already closed")]
    AlreadyClosed,
    /// Invalid watcher configuration, e.g. no usable cluster URLs.
    #[error("Configuration error: {0}")]
    Configuration(String),
    /// A Tokio runtime error, e.g. the internal runtime could not be built.
    #[error("Runtime error: {0}")]
    Runtime(String),
}
/// Result type used throughout the watcher, carrying a [`WatcherError`].
pub type Result<T> = std::result::Result<T, WatcherError>;
/// User-supplied handler, invoked with the raw payload string of each received message.
type UpdateCallback = Box<dyn FnMut(String) + Send + Sync>;
/// Shared callback slot: `None` until a callback is installed, read by the subscription thread.
type CallbackArc = Arc<Mutex<Option<UpdateCallback>>>;
/// Kind of policy change carried in a [`Message`]'s `Method` field.
///
/// Variant names are already PascalCase, so each variant serializes as its own name
/// (e.g. `"UpdateForAddPolicy"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "PascalCase")]
pub enum UpdateType {
    /// Generic "reload" notification; emitted for `ClearPolicy` and `ClearCache` events.
    Update,
    UpdateForAddPolicy,
    UpdateForRemovePolicy,
    UpdateForRemoveFilteredPolicy,
    UpdateForSavePolicy,
    UpdateForAddPolicies,
    UpdateForRemovePolicies,
    // NOTE(review): the two variants below are never produced by `event_data_to_message`;
    // presumably they exist for wire compatibility with other publishers — confirm.
    UpdateForUpdatePolicy,
    UpdateForUpdatePolicies,
}
impl std::fmt::Display for UpdateType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
UpdateType::Update => write!(f, "Update"),
UpdateType::UpdateForAddPolicy => write!(f, "UpdateForAddPolicy"),
UpdateType::UpdateForRemovePolicy => write!(f, "UpdateForRemovePolicy"),
UpdateType::UpdateForRemoveFilteredPolicy => write!(f, "UpdateForRemoveFilteredPolicy"),
UpdateType::UpdateForSavePolicy => write!(f, "UpdateForSavePolicy"),
UpdateType::UpdateForAddPolicies => write!(f, "UpdateForAddPolicies"),
UpdateType::UpdateForRemovePolicies => write!(f, "UpdateForRemovePolicies"),
UpdateType::UpdateForUpdatePolicy => write!(f, "UpdateForUpdatePolicy"),
UpdateType::UpdateForUpdatePolicies => write!(f, "UpdateForUpdatePolicies"),
}
}
}
/// JSON payload exchanged between watchers on the Redis channel.
///
/// Field names are PascalCase on the wire (with `id` as `"ID"`). Empty string/list fields
/// are omitted when serializing and default to empty when missing on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct Message {
    /// What kind of change this message describes.
    pub method: UpdateType,
    /// Sender's local id; receivers with `ignore_self` drop messages carrying their own id.
    #[serde(rename = "ID")]
    pub id: String,
    /// Policy section of the changed rule(s).
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub sec: String,
    /// Policy type of the changed rule(s).
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub ptype: String,
    /// Single removed rule (`UpdateForRemovePolicy`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub old_rule: Vec<String>,
    /// Removed rules (`UpdateForRemovePolicies`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub old_rules: Vec<Vec<String>>,
    /// Single added rule (`UpdateForAddPolicy`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub new_rule: Vec<String>,
    /// Added rules (`UpdateForAddPolicies`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub new_rules: Vec<Vec<String>>,
    /// Filter start index; always serialized. Outgoing messages built in this file leave it at 0.
    #[serde(default)]
    pub field_index: i32,
    /// Filter values (`UpdateForRemoveFilteredPolicy`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub field_values: Vec<String>,
}
impl Message {
pub fn new(method: UpdateType, id: String) -> Self {
Self {
method,
id,
sec: String::new(),
ptype: String::new(),
old_rule: Vec::new(),
old_rules: Vec::new(),
new_rule: Vec::new(),
new_rules: Vec::new(),
field_index: 0,
field_values: Vec::new(),
}
}
pub fn to_json(&self) -> Result<String> {
Ok(serde_json::to_string(self)?)
}
pub fn from_json(json: &str) -> Result<Self> {
Ok(serde_json::from_str(json)?)
}
}
fn event_data_to_message(event_data: &EventData, local_id: &str) -> Message {
match event_data {
EventData::AddPolicy(sec, ptype, rule) => {
let mut message = Message::new(UpdateType::UpdateForAddPolicy, local_id.to_string());
message.sec = sec.clone();
message.ptype = ptype.clone();
message.new_rule = rule.clone();
message
}
EventData::AddPolicies(sec, ptype, rules) => {
let mut message = Message::new(UpdateType::UpdateForAddPolicies, local_id.to_string());
message.sec = sec.clone();
message.ptype = ptype.clone();
message.new_rules = rules.clone();
message
}
EventData::RemovePolicy(sec, ptype, rule) => {
let mut message = Message::new(UpdateType::UpdateForRemovePolicy, local_id.to_string());
message.sec = sec.clone();
message.ptype = ptype.clone();
message.old_rule = rule.clone();
message
}
EventData::RemovePolicies(sec, ptype, rules) => {
let mut message =
Message::new(UpdateType::UpdateForRemovePolicies, local_id.to_string());
message.sec = sec.clone();
message.ptype = ptype.clone();
message.old_rules = rules.clone();
message
}
EventData::RemoveFilteredPolicy(sec, ptype, field_values) => {
let mut message = Message::new(
UpdateType::UpdateForRemoveFilteredPolicy,
local_id.to_string(),
);
message.sec = sec.clone();
message.ptype = ptype.clone();
if !field_values.is_empty() {
message.field_values = field_values[0].clone();
}
message
}
EventData::SavePolicy(_) => {
Message::new(UpdateType::UpdateForSavePolicy, local_id.to_string())
}
EventData::ClearPolicy => Message::new(UpdateType::Update, local_id.to_string()),
EventData::ClearCache => Message::new(UpdateType::Update, local_id.to_string()),
}
}
/// Redis connection handle for the configured deployment type.
enum RedisClientWrapper {
    /// Single server, used for both publishing and subscribing.
    Standalone(Client),
    /// Redis Cluster: `cluster` is used for publishing. `pubsub_client` is a plain client
    /// opened on the first configured cluster URL and carries the subscription.
    ClusterWithPubSub {
        cluster: Box<redis::cluster::ClusterClient>,
        pubsub_client: Client,
    },
}
impl RedisClientWrapper {
    /// Opens a fresh async pub/sub connection.
    ///
    /// Subscriptions always go through a plain [`Client`]: the standalone client itself,
    /// or the dedicated `pubsub_client` in cluster mode.
    async fn get_async_pubsub(&self) -> redis::RedisResult<redis::aio::PubSub> {
        let subscriber = match self {
            Self::Standalone(client) => client,
            Self::ClusterWithPubSub { pubsub_client, .. } => pubsub_client,
        };
        subscriber.get_async_pubsub().await
    }

    /// Sends `PUBLISH channel payload` over a newly opened connection.
    ///
    /// The subscriber count returned by Redis is discarded.
    async fn publish_message(&self, channel: &str, payload: String) -> redis::RedisResult<()> {
        match self {
            Self::Standalone(client) => {
                let mut connection = client.get_multiplexed_async_connection().await?;
                let _: i32 = connection.publish(channel, payload).await?;
            }
            Self::ClusterWithPubSub { cluster, .. } => {
                let mut connection = cluster.get_async_connection().await?;
                let _: i32 = connection.publish(channel, payload).await?;
            }
        }
        Ok(())
    }
}
/// Casbin [`Watcher`] that broadcasts policy changes over a Redis pub/sub channel.
pub struct RedisWatcher {
    /// Runtime used by the synchronous methods to connect and publish.
    runtime: Arc<Runtime>,
    /// Standalone or cluster connection handle, shared with the subscription thread.
    client: Arc<RedisClientWrapper>,
    /// Provides the channel name, this instance's `local_id`, and the `ignore_self` flag.
    options: crate::WatcherOptions,
    /// Callback slot shared with the subscription thread; `None` until one is set.
    callback: CallbackArc,
    /// Join handle of the background subscription thread, once started.
    subscription_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Set on drop: stops the subscription loop and makes publishing return `AlreadyClosed`.
    is_closed: Arc<AtomicBool>,
}
impl RedisWatcher {
pub fn new(redis_url: &str, options: crate::WatcherOptions) -> Result<Self> {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| WatcherError::Runtime(e.to_string()))?,
);
let client = Arc::new(RedisClientWrapper::Standalone(Client::open(redis_url)?));
let client_clone = client.clone();
runtime.block_on(async move {
match &*client_clone {
RedisClientWrapper::Standalone(c) => {
let mut conn = c.get_multiplexed_async_connection().await?;
let _: String = redis::cmd("PING").query_async(&mut conn).await?;
}
RedisClientWrapper::ClusterWithPubSub { .. } => unreachable!(),
}
Ok::<(), WatcherError>(())
})?;
Ok(Self {
runtime,
client,
options,
callback: Arc::new(Mutex::new(None)),
subscription_handle: Arc::new(Mutex::new(None)),
is_closed: Arc::new(AtomicBool::new(false)),
})
}
pub fn new_cluster(cluster_urls: &str, options: crate::WatcherOptions) -> Result<Self> {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| WatcherError::Runtime(e.to_string()))?,
);
let urls: Vec<&str> = cluster_urls.split(',').map(|s| s.trim()).collect();
if urls.is_empty() {
return Err(WatcherError::Configuration(
"No cluster URLs provided".to_string(),
));
}
let cluster_client = redis::cluster::ClusterClient::builder(urls.clone())
.build()
.map_err(|e| {
WatcherError::Configuration(format!("Failed to build cluster client: {}", e))
})?;
let pubsub_client = Client::open(urls[0]).map_err(|e| {
WatcherError::Configuration(format!("Failed to create pubsub client: {}", e))
})?;
let client = Arc::new(RedisClientWrapper::ClusterWithPubSub {
cluster: Box::new(cluster_client),
pubsub_client,
});
let client_clone = client.clone();
runtime.block_on(async move {
match &*client_clone {
RedisClientWrapper::Standalone(_) => unreachable!(),
RedisClientWrapper::ClusterWithPubSub {
cluster,
pubsub_client,
} => {
let mut conn = cluster
.get_async_connection()
.await
.map_err(WatcherError::RedisConnection)?;
let _: String = redis::cmd("PING").query_async(&mut conn).await?;
let mut pubsub_conn = pubsub_client.get_multiplexed_async_connection().await?;
let _: String = redis::cmd("PING").query_async(&mut pubsub_conn).await?;
}
}
Ok::<(), WatcherError>(())
})?;
Ok(Self {
runtime,
client,
options,
callback: Arc::new(Mutex::new(None)),
subscription_handle: Arc::new(Mutex::new(None)),
is_closed: Arc::new(AtomicBool::new(false)),
})
}
fn publish_message(&self, message: &Message) -> Result<()> {
if self.is_closed.load(Ordering::Relaxed) {
return Err(WatcherError::AlreadyClosed);
}
let payload = message.to_json()?;
let client = self.client.clone();
let channel = self.options.channel.clone();
self.runtime.block_on(async move {
client.publish_message(&channel, payload).await?;
Ok::<(), WatcherError>(())
})?;
Ok(())
}
fn start_subscription(&self) -> Result<()> {
if self.is_closed.load(Ordering::Relaxed) {
return Err(WatcherError::AlreadyClosed);
}
let callback = self.callback.clone();
let channel = self.options.channel.clone();
let local_id = self.options.local_id.clone();
let ignore_self = self.options.ignore_self;
let is_closed = self.is_closed.clone();
let client = self.client.clone();
let handle = thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let result = async {
let mut pubsub = match client.get_async_pubsub().await {
Ok(p) => p,
Err(e) => {
log::error!("Failed to get async pubsub: {}", e);
return Err(e);
}
};
if let Err(e) = pubsub.subscribe(&channel).await {
log::error!("Failed to subscribe to channel {}: {}", channel, e);
return Err(e);
}
log::debug!("Successfully subscribed to channel: {}", channel);
let mut stream = pubsub.on_message();
loop {
if is_closed.load(Ordering::Relaxed) {
break;
}
tokio::select! {
msg_opt = stream.next() => {
match msg_opt {
Some(msg) => {
let payload: String = msg.get_payload().unwrap_or_default();
if ignore_self {
if let Ok(parsed_msg) = Message::from_json(&payload) {
if parsed_msg.id == local_id {
continue;
}
}
}
if let Ok(mut cb_guard) = callback.lock() {
if let Some(ref mut cb) = *cb_guard {
cb(payload);
}
}
}
None => {
log::debug!("Pubsub stream ended");
break;
}
}
}
_ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
if is_closed.load(Ordering::Relaxed) {
break;
}
}
}
}
Ok::<(), redis::RedisError>(())
};
if let Err(e) = result.await {
log::error!("Subscription error: {}", e);
}
});
});
*self.subscription_handle.lock().unwrap() = Some(handle);
Ok(())
}
}
impl Watcher for RedisWatcher {
fn set_update_callback(&mut self, cb: Box<dyn FnMut(String) + Send + Sync>) {
*self.callback.lock().unwrap() = Some(cb);
let _ = self.start_subscription();
}
fn update(&mut self, d: EventData) {
let message = event_data_to_message(&d, &self.options.local_id);
let _ = self.publish_message(&message);
}
}
impl Drop for RedisWatcher {
    /// Signals the subscription thread to stop and waits at most one second for it to exit.
    ///
    /// The previous version always slept a full second once a subscription existed. The
    /// subscription loop normally re-checks `is_closed` within about 100 ms, so this now
    /// returns as soon as the thread exits. The one-second cap still bounds the wait if
    /// the thread is stuck, e.g. still connecting or inside a slow callback.
    ///
    /// NOTE(review): `self.runtime` is dropped after this returns. Tokio panics if a
    /// runtime is dropped from within an asynchronous context — confirm watchers are not
    /// dropped inside a Tokio task.
    fn drop(&mut self) {
        self.is_closed.store(true, Ordering::Relaxed);
        let handle = self
            .subscription_handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(handle) = handle {
            // `JoinHandle::join` has no timeout, so join on a helper thread and wait for its
            // completion signal with one instead.
            let (done_tx, done_rx) = std::sync::mpsc::channel();
            std::thread::spawn(move || {
                let _ = handle.join();
                let _ = done_tx.send(());
            });
            let _ = done_rx.recv_timeout(std::time::Duration::from_millis(1000));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_serialization() {
        let message = Message::new(UpdateType::Update, "test-id".to_string());
        let json = message.to_json().unwrap();
        let parsed = Message::from_json(&json).unwrap();
        assert_eq!(message.method, parsed.method);
        assert_eq!(message.id, parsed.id);
    }

    /// Pins the JSON field names produced by the serde attributes, since other processes
    /// parse this payload: PascalCase names, `"ID"`, and empty optional fields omitted.
    #[test]
    fn test_message_wire_format() {
        let message = Message::new(UpdateType::Update, "test-id".to_string());
        assert_eq!(
            message.to_json().unwrap(),
            r#"{"Method":"Update","ID":"test-id","FieldIndex":0}"#
        );
    }

    /// A payload containing only some fields must parse, with the rest defaulted.
    #[test]
    fn test_message_parses_minimal_payload() {
        let parsed = Message::from_json(
            r#"{"Method":"UpdateForRemovePolicy","ID":"peer","Sec":"p","Ptype":"p","OldRule":["alice","data1","read"]}"#,
        )
        .unwrap();
        assert_eq!(parsed.method, UpdateType::UpdateForRemovePolicy);
        assert_eq!(parsed.id, "peer");
        assert_eq!(parsed.sec, "p");
        assert_eq!(parsed.old_rule, vec!["alice", "data1", "read"]);
        assert!(parsed.new_rules.is_empty());
        assert_eq!(parsed.field_index, 0);
    }

    #[test]
    fn test_event_data_conversion() {
        let event = EventData::AddPolicy(
            "p".to_string(),
            "p".to_string(),
            vec!["alice".to_string(), "data1".to_string(), "read".to_string()],
        );
        let message = event_data_to_message(&event, "test-id");
        assert_eq!(message.method, UpdateType::UpdateForAddPolicy);
        assert_eq!(message.sec, "p");
        assert_eq!(message.ptype, "p");
        assert_eq!(message.new_rule, vec!["alice", "data1", "read"]);
    }

    /// An empty rule list must not panic and must leave `field_values` empty.
    #[test]
    fn test_remove_filtered_policy_without_rules() {
        let event = EventData::RemoveFilteredPolicy("p".to_string(), "p".to_string(), vec![]);
        let message = event_data_to_message(&event, "test-id");
        assert_eq!(message.method, UpdateType::UpdateForRemoveFilteredPolicy);
        assert_eq!(message.sec, "p");
        assert!(message.field_values.is_empty());
    }

    #[test]
    fn test_clear_events_map_to_plain_update() {
        for event in vec![EventData::ClearPolicy, EventData::ClearCache] {
            let message = event_data_to_message(&event, "test-id");
            assert_eq!(message.method, UpdateType::Update);
            assert_eq!(message.id, "test-id");
            assert!(message.sec.is_empty());
        }
    }
}