casbin_sqlx_watcher/
watcher.rs

1use casbin::{CoreApi, Enforcer, EventData, Watcher};
2use serde::{Deserialize, Serialize};
3use sqlx::PgPool;
4use sqlx::postgres::PgListener;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8/// A sqlx based Watcher for casbin policy changes.
9///
10/// The watcher is responsible for both notifying and listening to casbin policy changes.
11/// By default load_policy is called when any changes are received.
12/// The user can alter this behaviour via set_update_callback.
13///
14/// Since the Watcher trait doesn't supply the payload to the callback, you can access that
15/// via the last_message field.
16///
17/// Example:
18///
19/// ```rust
20/// use casbin::Watcher;
21/// use casbin_sqlx_watcher::SqlxWatcher;
22/// use sqlx::PgPool;
23///
24/// #[tokio::main]
25/// async fn main() {
26///    use std::sync::Arc;
27///    use casbin::{CoreApi, Enforcer};
28///    use tokio::sync::RwLock;
29///    let db = PgPool::connect(std::env::var("DATABASE_URL").unwrap_or_default().as_str()).await.unwrap();
30///    let mut watcher = SqlxWatcher::new(db.clone());
31///    let mut watcher_clone = watcher.clone();
32///
33///    let policy = sqlx_adapter::SqlxAdapter::new_with_pool(db.clone()).await.unwrap();
34///    let model = casbin::DefaultModel::from_str(include_str!("./resources/rbac_model.conf")).await.unwrap();
35///
36///    let enforcer = Arc::new(RwLock::new(Enforcer::new(model, policy).await.unwrap()));
37///
38///    tokio::task::spawn(async move {
39///       if let Err(err) = watcher_clone.listen(enforcer).await {
40///          eprintln!("casbin watcher failed: {}", err);
41///      }
42///    });
43///
44///
45///    watcher.set_update_callback(Box::new(|| {
46///       println!("casbin policy changed");
47///    }));
48///
49///    // This is not the recommended way to trigger changes, casbin will do that automatically.
50///    // But for illustration purposes, we can manually trigger a change.
51///    sqlx::query("NOTIFY casbin_policy_change").execute(&db).await.unwrap();
52///
53///    tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
54///    // output: casbin policy changed
55/// }
56#[derive(Clone)]
57pub struct SqlxWatcher {
58    db: PgPool,
59    /// The channel sender to send the callback to.
60    tx: Arc<RwLock<tokio::sync::mpsc::Sender<Box<dyn FnMut() + Send + Sync>>>>,
61    /// The channel receiver to read the callback from.
62    rc: Arc<RwLock<tokio::sync::mpsc::Receiver<Box<dyn FnMut() + Send + Sync>>>>,
63    /// The last message that was received.
64    /// This is in order to work around the limitation of the Watcher trait not providing the
65    /// payload to the callback.
66    last_message: Arc<RwLock<PolicyChange>>,
67    /// The instance id of this watcher. Used to ignore our own messages.
68    instance_id: String,
69    /// The channel to listen and notify on.
70    _channel: String,
71}
72
73#[derive(thiserror::Error, Debug)]
74pub enum Error {
75    #[error("sqlx error: {0}")]
76    Sqlx(#[from] sqlx::Error),
77    #[error("serde error: {0}")]
78    Serde(#[from] serde_json::Error),
79    #[error("casbin error: {0}")]
80    Casbin(#[from] casbin::Error),
81    #[error("general error: {0}")]
82    General(String),
83}
84
85pub const DEFAULT_NOTIFY_CHANNEL: &str = "casbin_policy_change";
86/// The maximum number of bytes that can be sent in a notification in postgres
87/// Not configurable, see
88/// https://www.postgresql.org/docs/current/sql-notify.html#:~:text=In%20the%20default%20configuration%20it,the%20key%20of%20the%20record.)
89const NOTIFY_MAX_BYTES: usize = 8000;
90
91pub type Result<T> = std::result::Result<T, Error>;
92
93impl SqlxWatcher {
94    pub fn new(db: PgPool) -> Self {
95        let (tx, rc) = tokio::sync::mpsc::channel(1);
96        Self {
97            db,
98            tx: Arc::new(RwLock::new(tx)),
99            rc: Arc::new(RwLock::new(rc)),
100            last_message: Arc::new(RwLock::new(PolicyChange::None)),
101            instance_id: uuid::Uuid::new_v4().to_string(),
102            _channel: DEFAULT_NOTIFY_CHANNEL.to_string(),
103        }
104    }
105
106    /// Set the channel to listen and notify on.
107    /// By default, the value of DEFAULT_NOTIFY_CHANNEL is used, which is "casbin_policy_change".
108    pub fn set_channel(&mut self, channel: &str) {
109        self._channel = channel.to_string();
110    }
111
112    /// Get the channel that is listened to and notified on.
113    pub fn channel(&self) -> String {
114        self._channel.clone()
115    }
116
117    fn is_own_message(&self, change: &PolicyChange) -> bool {
118        match change {
119            PolicyChange::AddPolicies(instance_id, _) => instance_id == &self.instance_id,
120            PolicyChange::RemovePolicies(instance_id, _) => instance_id == &self.instance_id,
121            PolicyChange::SavePolicy(instance_id, _) => instance_id == &self.instance_id,
122            PolicyChange::ClearPolicy(instance_id) => instance_id == &self.instance_id,
123            PolicyChange::ClearCache(instance_id) => instance_id == &self.instance_id,
124            PolicyChange::LoadPolicy(instance_id) => instance_id == &self.instance_id,
125            _ => false,
126        }
127    }
128
129    /// The main listen loop
130    ///
131    /// This listens to the postgres notification channel for casbin policy changes.
132    /// It also listens for updates to the callback function.
133    pub async fn listen(&mut self, enforcer: Arc<RwLock<Enforcer>>) -> Result<()> {
134        let mut listener = PgListener::connect_with(&self.db).await?;
135        listener.listen(&self._channel).await?;
136
137        {
138            // load the policy in case we missed anything during startup
139            enforcer.write().await.load_policy().await?;
140        }
141
142        let mut cb: Box<dyn FnMut() + Send + Sync> = Box::new(|| {
143            let cloned_enforcer = enforcer.clone();
144            tokio::task::spawn(async move {
145                if let Err(err) = cloned_enforcer.write().await.load_policy().await {
146                    log::error!("failed to reload policy: {}", err);
147                }
148            });
149        });
150
151        log::info!("casbin sqlx watcher started");
152
153        loop {
154            let mut rc = self.rc.write().await;
155            tokio::select! {
156                        n = listener.try_recv() => {
157                            if let Ok(n) = n {
158                                if let Some(notification) = n {
159
160                                    if notification.payload().is_empty() {
161                                        log::warn!("empty casbin policy change notification, doing full policy reload as fallback");
162                                        if let Err(e) = enforcer.write().await.load_policy().await {
163                                            log::error!("error while trying to reload whole policy: {}", e);
164                                        }
165                                        continue;
166                                    }
167
168                                    log::info!("received casbin policy change notification: {}", notification.payload());
169
170                                    let policy_change = serde_json::from_str::<PolicyChange>(notification.payload());
171
172                                    let result: Result<()> = match policy_change {
173                                        Ok(change) => {
174                                            match self.is_own_message(&change) {
175                                                false => {
176                                                    *self.last_message.write().await = change;
177                                                    cb();
178                                                    Ok(())
179                                                },
180                                                true => Ok(())
181                                            }
182
183                                        },
184                                        Err(orig_error) => {
185                                            log::info!("doing full policy reload as fallback");
186                                            if let Err(subsequent_error) = enforcer.write().await.load_policy().await {
187                                                Err(Error::General(format!("failed to apply policy {}\n    subsequent fallback reload error: {}", orig_error, subsequent_error)))
188                                            } else {
189                                                Err(orig_error.into())
190                                            }
191                                        }
192                                    };
193
194                                    if let Err(e) = result {
195                                        log::error!("error while applying casbin policy change: {}", e);
196                                    }
197
198
199                                }
200
201                            } else {
202                                log::error!("casbin listener connection lost, auto reconnecting");
203                            }
204                        },
205                new_cb = rc.recv() => {
206                    if let Some(new_cb) = new_cb {
207                        log::info!("casbin watcher callback set");
208                        cb = new_cb;
209                    }
210                },
211            }
212        }
213    }
214}
215
216#[derive(Debug, Serialize, Deserialize)]
217pub struct PolicyChangeData {
218    pub sec: String,
219    pub ptype: String,
220    pub vars: Vec<Vec<String>>,
221}
222
223impl PolicyChangeData {
224    #[allow(dead_code)]
225    fn flatten(self) -> Vec<Vec<String>> {
226        self.vars
227            .into_iter()
228            .map(|vars| [vec![self.sec.clone(), self.ptype.clone()], vars].concat())
229            .collect()
230    }
231}
232
233/// A serde friendly enum to represent the various policy changes that can be made.
234#[derive(Debug, Serialize, Deserialize)]
235pub enum PolicyChange {
236    None,
237    AddPolicies(String, PolicyChangeData),
238    RemovePolicies(String, PolicyChangeData),
239    SavePolicy(String, Vec<Vec<String>>),
240    ClearPolicy(String),
241    ClearCache(String),
242    LoadPolicy(String),
243}
244impl PolicyChange {
245    fn from(instance_id: String, value: EventData) -> Self {
246        match value {
247            EventData::AddPolicy(sec, ptype, vars) => PolicyChange::AddPolicies(
248                instance_id,
249                PolicyChangeData {
250                    sec,
251                    ptype,
252                    vars: vec![vars],
253                },
254            ),
255            EventData::AddPolicies(sec, ptype, vars) => {
256                PolicyChange::AddPolicies(instance_id, PolicyChangeData { sec, ptype, vars })
257            }
258            EventData::RemovePolicy(sec, ptype, vars) => PolicyChange::RemovePolicies(
259                instance_id,
260                PolicyChangeData {
261                    sec,
262                    ptype,
263                    vars: vec![vars],
264                },
265            ),
266            EventData::RemovePolicies(sec, ptype, vars) => {
267                PolicyChange::RemovePolicies(instance_id, PolicyChangeData { sec, ptype, vars })
268            }
269            EventData::RemoveFilteredPolicy(sec, ptype, vars) => {
270                PolicyChange::RemovePolicies(instance_id, PolicyChangeData { sec, ptype, vars })
271            }
272            EventData::SavePolicy(p) => PolicyChange::SavePolicy(instance_id, p),
273            EventData::ClearPolicy => PolicyChange::ClearPolicy(instance_id),
274            EventData::ClearCache => PolicyChange::ClearCache(instance_id),
275        }
276    }
277}
278
279impl Watcher for SqlxWatcher {
280    fn set_update_callback(&mut self, cb: Box<dyn FnMut() + Send + Sync>) {
281        let tx = self.tx.clone();
282        tokio::task::spawn(async move {
283            if let Err(e) = tx.write().await.send(cb).await {
284                log::error!("failed to send casbin watcher callback: {}", e);
285            }
286        });
287    }
288
289    fn update(&mut self, d: EventData) {
290        let db = self.db.clone();
291        let policy_change = PolicyChange::from(self.instance_id.clone(), d);
292        let serialized = serde_json::to_string(&policy_change).unwrap();
293
294        // if > 8000 bytes we resort to a full reload
295        let serialized = if serialized.len() > NOTIFY_MAX_BYTES {
296            log::warn!("policy change too large, resorting to full reload");
297            serde_json::to_string(&PolicyChange::LoadPolicy(self.instance_id.clone())).unwrap()
298        } else {
299            serialized
300        };
301
302        let channel = self._channel.clone();
303
304        tokio::task::spawn(async move {
305            if let Err(e) = sqlx::query!(
306                r#"
307                SELECT pg_notify($1, $2)
308            "#,
309                &channel,
310                serialized
311            )
312            .execute(&db)
313            .await
314            {
315                log::error!("failed to notify casbin policy change: {}", e);
316            }
317        });
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use casbin::Enforcer;
    use std::env;
    use std::sync::Arc;
    use tokio::sync::RwLock;
    use tokio::task::JoinHandle;

    /// Start a listening watcher with `cb` on a fresh, random channel.
    ///
    /// Returns the watcher (to share its channel and instance id), the handle of the spawned
    /// listen task, and the pool. Note that the pool connects to `DATABASE_URL` directly; the
    /// per-test pool supplied by `sqlx::test` is not used.
    async fn setup_listener(
        cb: Box<dyn FnMut() + Send + Sync>,
    ) -> (SqlxWatcher, JoinHandle<()>, PgPool) {
        let db = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut watcher = SqlxWatcher::new(db.clone());
        watcher.set_update_callback(cb);
        watcher.set_channel(&uuid::Uuid::new_v4().to_string());
        let mut watcher_clone = watcher.clone();

        let policy = sqlx_adapter::SqlxAdapter::new_with_pool(db.clone())
            .await
            .unwrap();
        let model = casbin::DefaultModel::from_str(include_str!("./resources/rbac_model.conf"))
            .await
            .unwrap();
        let enforcer = Arc::new(RwLock::new(Enforcer::new(model, policy).await.unwrap()));

        let handle = tokio::task::spawn(async move {
            if let Err(err) = watcher_clone.listen(enforcer).await {
                eprintln!("casbin watcher failed: {}", err);
            }
        });

        // Postgres only delivers notifications to sessions that have already executed LISTEN.
        // Give the spawned loop time to subscribe. Otherwise a test can miss the notification,
        // or test_should_ignore_own_messages can pass without anything being received at all.
        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        (watcher, handle, db)
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_notify_and_listen_basic(_: PgPool) {
        // create a channel to notify on messages
        let (tx_msg, mut rx_msg) = tokio::sync::mpsc::channel::<bool>(5);

        let (watcher, handle, db) = setup_listener(Box::new(move || {
            println!("casbin policy changed");
            let tx = tx_msg.clone();
            tokio::task::spawn(async move {
                tx.send(true).await.unwrap();
            });
        }))
        .await;

        let mut watcher2 = SqlxWatcher::new(db.clone());
        watcher2.set_channel(&watcher.channel());
        watcher2.update(EventData::SavePolicy(vec![]));

        // wait up to 5 seconds for a reply
        let found = tokio::time::timeout(tokio::time::Duration::from_secs(5), rx_msg.recv())
            .await
            .unwrap()
            .unwrap();
        handle.abort();
        assert!(found);
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_ignore_own_messages(_: PgPool) {
        // create a channel to notify on messages
        let (tx_msg, mut rx_msg) = tokio::sync::mpsc::channel::<bool>(5);

        let (mut watcher, handle, _db) = setup_listener(Box::new(move || {
            println!("casbin policy changed");
            let tx = tx_msg.clone();
            tokio::task::spawn(async move {
                tx.send(true).await.unwrap();
            });
        }))
        .await;

        watcher.update(EventData::SavePolicy(vec![]));

        // wait up to 1 second; no reply is expected since the message is our own
        let found = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx_msg.recv()).await;
        handle.abort();
        assert!(found.is_err());
    }

    #[sqlx::test(fixtures("base"))]
    async fn test_should_notify_and_listen_large(_: PgPool) {
        // create a channel to notify on messages
        let (tx_msg, mut rx_msg) = tokio::sync::mpsc::channel::<bool>(5);

        let (watcher, handle, db) = setup_listener(Box::new(move || {
            println!("casbin policy changed");
            let tx = tx_msg.clone();
            tokio::task::spawn(async move {
                tx.send(true).await.unwrap();
            });
        }))
        .await;

        let mut watcher2 = SqlxWatcher::new(db.clone());
        watcher2.set_channel(&watcher.channel());
        // far larger than a NOTIFY payload may be, so it is sent as a LoadPolicy request
        watcher2.update(EventData::SavePolicy(vec![vec!["a".to_string(); 8000]]));

        // wait up to 5 seconds for a reply
        let found = tokio::time::timeout(tokio::time::Duration::from_secs(5), rx_msg.recv())
            .await
            .unwrap()
            .unwrap();
        handle.abort();
        assert!(found);
    }
}
431}