1use casbin::{CoreApi, Enforcer, EventData, Watcher};
2use serde::{Deserialize, Serialize};
3use sqlx::PgPool;
4use sqlx::postgres::PgListener;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8#[derive(Clone)]
57pub struct SqlxWatcher {
58 db: PgPool,
59 tx: Arc<RwLock<tokio::sync::mpsc::Sender<Box<dyn FnMut() + Send + Sync>>>>,
61 rc: Arc<RwLock<tokio::sync::mpsc::Receiver<Box<dyn FnMut() + Send + Sync>>>>,
63 last_message: Arc<RwLock<PolicyChange>>,
67 instance_id: String,
69 _channel: String,
71}
72
73#[derive(thiserror::Error, Debug)]
74pub enum Error {
75 #[error("sqlx error: {0}")]
76 Sqlx(#[from] sqlx::Error),
77 #[error("serde error: {0}")]
78 Serde(#[from] serde_json::Error),
79 #[error("casbin error: {0}")]
80 Casbin(#[from] casbin::Error),
81 #[error("general error: {0}")]
82 General(String),
83}
84
85pub const DEFAULT_NOTIFY_CHANNEL: &str = "casbin_policy_change";
86const NOTIFY_MAX_BYTES: usize = 8000;
90
91pub type Result<T> = std::result::Result<T, Error>;
92
93impl SqlxWatcher {
94 pub fn new(db: PgPool) -> Self {
95 let (tx, rc) = tokio::sync::mpsc::channel(1);
96 Self {
97 db,
98 tx: Arc::new(RwLock::new(tx)),
99 rc: Arc::new(RwLock::new(rc)),
100 last_message: Arc::new(RwLock::new(PolicyChange::None)),
101 instance_id: uuid::Uuid::new_v4().to_string(),
102 _channel: DEFAULT_NOTIFY_CHANNEL.to_string(),
103 }
104 }
105
106 pub fn set_channel(&mut self, channel: &str) {
109 self._channel = channel.to_string();
110 }
111
112 pub fn channel(&self) -> String {
114 self._channel.clone()
115 }
116
117 fn is_own_message(&self, change: &PolicyChange) -> bool {
118 match change {
119 PolicyChange::AddPolicies(instance_id, _) => instance_id == &self.instance_id,
120 PolicyChange::RemovePolicies(instance_id, _) => instance_id == &self.instance_id,
121 PolicyChange::SavePolicy(instance_id, _) => instance_id == &self.instance_id,
122 PolicyChange::ClearPolicy(instance_id) => instance_id == &self.instance_id,
123 PolicyChange::ClearCache(instance_id) => instance_id == &self.instance_id,
124 PolicyChange::LoadPolicy(instance_id) => instance_id == &self.instance_id,
125 _ => false,
126 }
127 }
128
129 pub async fn listen(&mut self, enforcer: Arc<RwLock<Enforcer>>) -> Result<()> {
134 let mut listener = PgListener::connect_with(&self.db).await?;
135 listener.listen(&self._channel).await?;
136
137 {
138 enforcer.write().await.load_policy().await?;
140 }
141
142 let mut cb: Box<dyn FnMut() + Send + Sync> = Box::new(|| {
143 let cloned_enforcer = enforcer.clone();
144 tokio::task::spawn(async move {
145 if let Err(err) = cloned_enforcer.write().await.load_policy().await {
146 log::error!("failed to reload policy: {}", err);
147 }
148 });
149 });
150
151 log::info!("casbin sqlx watcher started");
152
153 loop {
154 let mut rc = self.rc.write().await;
155 tokio::select! {
156 n = listener.try_recv() => {
157 if let Ok(n) = n {
158 if let Some(notification) = n {
159
160 if notification.payload().is_empty() {
161 log::warn!("empty casbin policy change notification, doing full policy reload as fallback");
162 if let Err(e) = enforcer.write().await.load_policy().await {
163 log::error!("error while trying to reload whole policy: {}", e);
164 }
165 continue;
166 }
167
168 log::info!("received casbin policy change notification: {}", notification.payload());
169
170 let policy_change = serde_json::from_str::<PolicyChange>(notification.payload());
171
172 let result: Result<()> = match policy_change {
173 Ok(change) => {
174 match self.is_own_message(&change) {
175 false => {
176 *self.last_message.write().await = change;
177 cb();
178 Ok(())
179 },
180 true => Ok(())
181 }
182
183 },
184 Err(orig_error) => {
185 log::info!("doing full policy reload as fallback");
186 if let Err(subsequent_error) = enforcer.write().await.load_policy().await {
187 Err(Error::General(format!("failed to apply policy {}\n subsequent fallback reload error: {}", orig_error, subsequent_error)))
188 } else {
189 Err(orig_error.into())
190 }
191 }
192 };
193
194 if let Err(e) = result {
195 log::error!("error while applying casbin policy change: {}", e);
196 }
197
198
199 }
200
201 } else {
202 log::error!("casbin listener connection lost, auto reconnecting");
203 }
204 },
205 new_cb = rc.recv() => {
206 if let Some(new_cb) = new_cb {
207 log::info!("casbin watcher callback set");
208 cb = new_cb;
209 }
210 },
211 }
212 }
213 }
214}
215
216#[derive(Debug, Serialize, Deserialize)]
217pub struct PolicyChangeData {
218 pub sec: String,
219 pub ptype: String,
220 pub vars: Vec<Vec<String>>,
221}
222
223impl PolicyChangeData {
224 #[allow(dead_code)]
225 fn flatten(self) -> Vec<Vec<String>> {
226 self.vars
227 .into_iter()
228 .map(|vars| [vec![self.sec.clone(), self.ptype.clone()], vars].concat())
229 .collect()
230 }
231}
232
233#[derive(Debug, Serialize, Deserialize)]
235pub enum PolicyChange {
236 None,
237 AddPolicies(String, PolicyChangeData),
238 RemovePolicies(String, PolicyChangeData),
239 SavePolicy(String, Vec<Vec<String>>),
240 ClearPolicy(String),
241 ClearCache(String),
242 LoadPolicy(String),
243}
244impl PolicyChange {
245 fn from(instance_id: String, value: EventData) -> Self {
246 match value {
247 EventData::AddPolicy(sec, ptype, vars) => PolicyChange::AddPolicies(
248 instance_id,
249 PolicyChangeData {
250 sec,
251 ptype,
252 vars: vec![vars],
253 },
254 ),
255 EventData::AddPolicies(sec, ptype, vars) => {
256 PolicyChange::AddPolicies(instance_id, PolicyChangeData { sec, ptype, vars })
257 }
258 EventData::RemovePolicy(sec, ptype, vars) => PolicyChange::RemovePolicies(
259 instance_id,
260 PolicyChangeData {
261 sec,
262 ptype,
263 vars: vec![vars],
264 },
265 ),
266 EventData::RemovePolicies(sec, ptype, vars) => {
267 PolicyChange::RemovePolicies(instance_id, PolicyChangeData { sec, ptype, vars })
268 }
269 EventData::RemoveFilteredPolicy(sec, ptype, vars) => {
270 PolicyChange::RemovePolicies(instance_id, PolicyChangeData { sec, ptype, vars })
271 }
272 EventData::SavePolicy(p) => PolicyChange::SavePolicy(instance_id, p),
273 EventData::ClearPolicy => PolicyChange::ClearPolicy(instance_id),
274 EventData::ClearCache => PolicyChange::ClearCache(instance_id),
275 }
276 }
277}
278
279impl Watcher for SqlxWatcher {
280 fn set_update_callback(&mut self, cb: Box<dyn FnMut() + Send + Sync>) {
281 let tx = self.tx.clone();
282 tokio::task::spawn(async move {
283 if let Err(e) = tx.write().await.send(cb).await {
284 log::error!("failed to send casbin watcher callback: {}", e);
285 }
286 });
287 }
288
289 fn update(&mut self, d: EventData) {
290 let db = self.db.clone();
291 let policy_change = PolicyChange::from(self.instance_id.clone(), d);
292 let serialized = serde_json::to_string(&policy_change).unwrap();
293
294 let serialized = if serialized.len() > NOTIFY_MAX_BYTES {
296 log::warn!("policy change too large, resorting to full reload");
297 serde_json::to_string(&PolicyChange::LoadPolicy(self.instance_id.clone())).unwrap()
298 } else {
299 serialized
300 };
301
302 let channel = self._channel.clone();
303
304 tokio::task::spawn(async move {
305 if let Err(e) = sqlx::query!(
306 r#"
307 SELECT pg_notify($1, $2)
308 "#,
309 &channel,
310 serialized
311 )
312 .execute(&db)
313 .await
314 {
315 log::error!("failed to notify casbin policy change: {}", e);
316 }
317 });
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use casbin::Enforcer;
    use std::env;
    use std::sync::Arc;
    use tokio::sync::RwLock;
    use tokio::task::JoinHandle;

    /// Builds an update callback that reports each invocation by sending `true` on `tx`.
    fn signalling_callback(
        tx: tokio::sync::mpsc::Sender<bool>,
    ) -> Box<dyn FnMut() + Send + Sync> {
        Box::new(move || {
            println!("casbin policy changed");
            let tx = tx.clone();
            tokio::task::spawn(async move {
                tx.send(true).await.unwrap();
            });
        })
    }

    /// Registers `cb` on a new watcher bound to a random channel and spawns `listen` on a clone
    /// of it. The enforcer uses the database at `DATABASE_URL`.
    ///
    /// Returns the watcher (same instance id as the listener), the listen task, and the pool.
    async fn setup_listener(
        cb: Box<dyn FnMut() + Send + Sync>,
    ) -> (SqlxWatcher, JoinHandle<()>, PgPool) {
        let url = env::var("DATABASE_URL").unwrap();
        let pool = PgPool::connect(&url).await.unwrap();

        let mut watcher = SqlxWatcher::new(pool.clone());
        watcher.set_update_callback(cb);
        watcher.set_channel(&uuid::Uuid::new_v4().to_string());
        let mut listening = watcher.clone();

        let adapter = sqlx_adapter::SqlxAdapter::new_with_pool(pool.clone())
            .await
            .unwrap();
        let model = casbin::DefaultModel::from_str(include_str!("./resources/rbac_model.conf"))
            .await
            .unwrap();
        let enforcer = Arc::new(RwLock::new(Enforcer::new(model, adapter).await.unwrap()));

        let handle = tokio::task::spawn(async move {
            if let Err(err) = listening.listen(enforcer).await {
                eprintln!("casbin watcher failed: {}", err);
            }
        });
        (watcher, handle, pool)
    }

    /// Publishes `event` from a second watcher (different instance id) on the listener's
    /// channel and asserts that the listener's callback fires within five seconds.
    async fn assert_delivered_from_peer(event: EventData) {
        let (tx_msg, mut rx_msg) = tokio::sync::mpsc::channel::<bool>(5);
        let (watcher, handle, db) = setup_listener(signalling_callback(tx_msg)).await;

        let mut peer = SqlxWatcher::new(db.clone());
        peer.set_channel(&watcher.channel());
        peer.update(event);

        let found = tokio::time::timeout(tokio::time::Duration::from_secs(5), rx_msg.recv())
            .await
            .unwrap()
            .unwrap();
        handle.abort();
        assert!(found);
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_notify_and_listen_basic(_: PgPool) {
        assert_delivered_from_peer(EventData::SavePolicy(vec![])).await;
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_ignore_own_messages(_: PgPool) {
        let (tx_msg, mut rx_msg) = tokio::sync::mpsc::channel::<bool>(5);
        let (mut watcher, handle, _db) = setup_listener(signalling_callback(tx_msg)).await;

        // Published under the listener's own instance id, so the callback must stay silent.
        watcher.update(EventData::SavePolicy(vec![]));

        let outcome =
            tokio::time::timeout(tokio::time::Duration::from_secs(1), rx_msg.recv()).await;
        handle.abort();
        assert!(outcome.is_err());
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_notify_and_listen_large(_: PgPool) {
        // Far above the NOTIFY payload limit, so this travels as a full-reload request.
        assert_delivered_from_peer(EventData::SavePolicy(vec![vec!["a".to_string(); 8000]]))
            .await;
    }
}