Skip to main content

rusqlite/hooks/
mod.rs

1//! Commit, Data Change and Rollback Notification Callbacks
2#![expect(non_camel_case_types)]
3
4use std::ffi::{c_char, c_int, c_void, CStr};
5use std::panic::catch_unwind;
6use std::ptr;
7
8use crate::ffi;
9
10use crate::{error::decode_result_raw, Connection, InnerConnection, Result};
11
12#[cfg(feature = "preupdate_hook")]
13pub use preupdate_hook::*;
14
15#[cfg(feature = "preupdate_hook")]
16mod preupdate_hook;
17
/// Action Codes
///
/// The kind of row change reported to an [`update_hook`](Connection::update_hook)
/// callback. The known variants use SQLite's own action codes as discriminants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
#[non_exhaustive]
pub enum Action {
    /// Unsupported / unexpected action
    ///
    /// Returned by `Action::from` for any code other than the three below.
    UNKNOWN = -1,
    /// DELETE command
    SQLITE_DELETE = ffi::SQLITE_DELETE,
    /// INSERT command
    SQLITE_INSERT = ffi::SQLITE_INSERT,
    /// UPDATE command
    SQLITE_UPDATE = ffi::SQLITE_UPDATE,
}
32
33impl From<i32> for Action {
34    #[inline]
35    fn from(code: i32) -> Self {
36        match code {
37            ffi::SQLITE_DELETE => Self::SQLITE_DELETE,
38            ffi::SQLITE_INSERT => Self::SQLITE_INSERT,
39            ffi::SQLITE_UPDATE => Self::SQLITE_UPDATE,
40            _ => Self::UNKNOWN,
41        }
42    }
43}
44
/// The context received by an authorizer hook.
///
/// A fresh value is built for every authorizer invocation. All strings borrow from
/// SQLite-owned memory for the lifetime `'c` of that single call, so they cannot be
/// retained after the callback returns.
///
/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AuthContext<'c> {
    /// The action to be authorized.
    pub action: AuthAction<'c>,

    /// The database name, if applicable.
    ///
    /// `None` when SQLite supplied no database name for this action.
    pub database_name: Option<&'c str>,

    /// The inner-most trigger or view responsible for the access attempt.
    /// `None` if the access attempt was made by top-level SQL code.
    pub accessor: Option<&'c str>,
}
60
/// Actions and arguments found within a statement during
/// preparation.
///
/// Each variant corresponds to one of SQLite's authorizer action codes and carries
/// the operands SQLite supplies for it. The strings borrow from SQLite and live only
/// as long as the [`AuthContext`] they arrive in.
///
/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum AuthAction<'c> {
    /// This variant is not normally produced by SQLite. You may encounter it
    /// if you're using a different version than what's supported by this library.
    ///
    /// It is also produced when a known action code arrives without the operands
    /// this library expects for it; the raw values are preserved here.
    Unknown {
        /// The unknown authorization action code.
        code: i32,
        /// The third arg to the authorizer callback.
        arg1: Option<&'c str>,
        /// The fourth arg to the authorizer callback.
        arg2: Option<&'c str>,
    },
    /// `SQLITE_CREATE_INDEX`
    CreateIndex {
        index_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_TABLE`
    CreateTable {
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_TEMP_INDEX`
    CreateTempIndex {
        index_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_TEMP_TABLE`
    CreateTempTable {
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_TEMP_TRIGGER`
    CreateTempTrigger {
        trigger_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_TEMP_VIEW`
    CreateTempView {
        view_name: &'c str,
    },
    /// `SQLITE_CREATE_TRIGGER`
    CreateTrigger {
        trigger_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_VIEW`
    CreateView {
        view_name: &'c str,
    },
    /// `SQLITE_DELETE`
    Delete {
        table_name: &'c str,
    },
    /// `SQLITE_DROP_INDEX`
    DropIndex {
        index_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_DROP_TABLE`
    DropTable {
        table_name: &'c str,
    },
    /// `SQLITE_DROP_TEMP_INDEX`
    DropTempIndex {
        index_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_DROP_TEMP_TABLE`
    DropTempTable {
        table_name: &'c str,
    },
    /// `SQLITE_DROP_TEMP_TRIGGER`
    DropTempTrigger {
        trigger_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_DROP_TEMP_VIEW`
    DropTempView {
        view_name: &'c str,
    },
    /// `SQLITE_DROP_TRIGGER`
    DropTrigger {
        trigger_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_DROP_VIEW`
    DropView {
        view_name: &'c str,
    },
    /// `SQLITE_INSERT`
    Insert {
        table_name: &'c str,
    },
    /// `SQLITE_PRAGMA`
    Pragma {
        pragma_name: &'c str,
        /// The pragma value, if present (e.g., `PRAGMA name = value;`).
        pragma_value: Option<&'c str>,
    },
    /// `SQLITE_READ`: a column is read.
    Read {
        table_name: &'c str,
        column_name: &'c str,
    },
    /// `SQLITE_SELECT`
    Select,
    /// `SQLITE_TRANSACTION`
    Transaction {
        operation: TransactionOperation,
    },
    /// `SQLITE_UPDATE`: a column is written.
    Update {
        table_name: &'c str,
        column_name: &'c str,
    },
    /// `SQLITE_ATTACH`
    Attach {
        filename: &'c str,
    },
    /// `SQLITE_DETACH`
    Detach {
        database_name: &'c str,
    },
    /// `SQLITE_ALTER_TABLE`
    AlterTable {
        database_name: &'c str,
        table_name: &'c str,
    },
    /// `SQLITE_REINDEX`
    Reindex {
        index_name: &'c str,
    },
    /// `SQLITE_ANALYZE`
    Analyze {
        table_name: &'c str,
    },
    /// `SQLITE_CREATE_VTABLE`
    CreateVtable {
        table_name: &'c str,
        module_name: &'c str,
    },
    /// `SQLITE_DROP_VTABLE`
    DropVtable {
        table_name: &'c str,
        module_name: &'c str,
    },
    /// `SQLITE_FUNCTION`
    Function {
        function_name: &'c str,
    },
    /// `SQLITE_SAVEPOINT`
    Savepoint {
        operation: TransactionOperation,
        savepoint_name: &'c str,
    },
    /// `SQLITE_RECURSIVE`
    Recursive,
}
191
192impl<'c> AuthAction<'c> {
193    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
194        match (code, arg1, arg2) {
195            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
196                index_name,
197                table_name,
198            },
199            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
200            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
201                Self::CreateTempIndex {
202                    index_name,
203                    table_name,
204                }
205            }
206            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
207                Self::CreateTempTable { table_name }
208            }
209            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
210                Self::CreateTempTrigger {
211                    trigger_name,
212                    table_name,
213                }
214            }
215            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
216                Self::CreateTempView { view_name }
217            }
218            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
219                Self::CreateTrigger {
220                    trigger_name,
221                    table_name,
222                }
223            }
224            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
225            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
226            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
227                index_name,
228                table_name,
229            },
230            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
231            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
232                Self::DropTempIndex {
233                    index_name,
234                    table_name,
235                }
236            }
237            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
238                Self::DropTempTable { table_name }
239            }
240            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
241                Self::DropTempTrigger {
242                    trigger_name,
243                    table_name,
244                }
245            }
246            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
247            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
248                trigger_name,
249                table_name,
250            },
251            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
252            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
253            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
254                pragma_name,
255                pragma_value,
256            },
257            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
258                table_name,
259                column_name,
260            },
261            (ffi::SQLITE_SELECT, ..) => Self::Select,
262            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
263                operation: TransactionOperation::from_str(operation_str),
264            },
265            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
266                table_name,
267                column_name,
268            },
269            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
270            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
271            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
272                database_name,
273                table_name,
274            },
275            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
276            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
277            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
278                Self::CreateVtable {
279                    table_name,
280                    module_name,
281                }
282            }
283            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
284                table_name,
285                module_name,
286            },
287            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
288            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
289                operation: TransactionOperation::from_str(operation_str),
290                savepoint_name,
291            },
292            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
293            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
294        }
295    }
296}
297
298pub(crate) type BoxedAuthorizer =
299    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
300
/// A transaction or savepoint operation reported to the authorizer.
///
/// For `AuthAction::Transaction`, SQLite reports `BEGIN`, `COMMIT` or `ROLLBACK`.
/// For `AuthAction::Savepoint`, it reports `BEGIN`, `RELEASE` or `ROLLBACK`.
/// Any other string maps to [`TransactionOperation::Unknown`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum TransactionOperation {
    Unknown,
    Begin,
    Release,
    Rollback,
    /// `COMMIT` of a transaction (SQLite also reports `END` as `COMMIT`).
    Commit,
}

impl TransactionOperation {
    /// Parse the operation string SQLite passes as the authorizer's third argument.
    ///
    /// Matching is exact and case-sensitive, as SQLite always passes upper-case names.
    fn from_str(op_str: &str) -> Self {
        match op_str {
            "BEGIN" => Self::Begin,
            "COMMIT" => Self::Commit,
            "RELEASE" => Self::Release,
            "ROLLBACK" => Self::Rollback,
            _ => Self::Unknown,
        }
    }
}
322
/// [`authorizer`](Connection::authorizer) return code
///
/// Tells SQLite how to treat the action described by an [`AuthContext`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Authorization {
    /// Authorize the action.
    Allow,
    /// Don't allow access, but don't trigger an error either.
    ///
    /// Per the SQLite documentation, for a column read (`AuthAction::Read`) the
    /// column's value is replaced by NULL.
    Ignore,
    /// Trigger an error.
    ///
    /// The statement being prepared fails.
    Deny,
}
334
335impl Authorization {
336    fn into_raw(self) -> c_int {
337        match self {
338            Self::Allow => ffi::SQLITE_OK,
339            Self::Ignore => ffi::SQLITE_IGNORE,
340            Self::Deny => ffi::SQLITE_DENY,
341        }
342    }
343}
344
345impl Connection {
346    /// Register a callback function to be invoked whenever
347    /// a transaction is committed.
348    ///
349    /// The callback returns `true` to rollback.
350    #[inline]
351    pub fn commit_hook<F>(&self, hook: Option<F>) -> Result<()>
352    where
353        F: FnMut() -> bool + Send + 'static,
354    {
355        self.db.borrow().check_owned()?;
356        self.db.borrow_mut().commit_hook(hook);
357        Ok(())
358    }
359
360    /// Register a callback function to be invoked whenever
361    /// a transaction is rolled back.
362    #[inline]
363    pub fn rollback_hook<F>(&self, hook: Option<F>) -> Result<()>
364    where
365        F: FnMut() + Send + 'static,
366    {
367        self.db.borrow().check_owned()?;
368        self.db.borrow_mut().rollback_hook(hook);
369        Ok(())
370    }
371
372    /// Register a callback function to be invoked whenever
373    /// a row is updated, inserted or deleted in a rowid table.
374    ///
375    /// The callback parameters are:
376    ///
377    /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
378    ///   `SQLITE_DELETE`),
379    /// - the name of the database ("main", "temp", ...),
380    /// - the name of the table that is updated,
381    /// - the ROWID of the row that is updated.
382    #[inline]
383    pub fn update_hook<F>(&self, hook: Option<F>) -> Result<()>
384    where
385        F: FnMut(Action, &str, &str, i64) + Send + 'static,
386    {
387        self.db.borrow().check_owned()?;
388        self.db.borrow_mut().update_hook(hook);
389        Ok(())
390    }
391
392    /// Register a callback that is invoked each time data is committed to a database in wal mode.
393    ///
394    /// A single database handle may have at most a single write-ahead log callback registered at one time.
395    /// Calling `wal_hook` replaces any previously registered write-ahead log callback.
396    /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma
397    /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings.
398    pub fn wal_hook(&self, hook: Option<fn(&Wal, c_int) -> Result<()>>) {
399        unsafe extern "C" fn wal_hook_callback(
400            client_data: *mut c_void,
401            db: *mut ffi::sqlite3,
402            db_name: *const c_char,
403            pages: c_int,
404        ) -> c_int {
405            let hook_fn: fn(&Wal, c_int) -> Result<()> = std::mem::transmute(client_data);
406            let wal = Wal { db, db_name };
407            catch_unwind(|| match hook_fn(&wal, pages) {
408                Ok(_) => ffi::SQLITE_OK,
409                Err(e) => e
410                    .sqlite_error()
411                    .map_or(ffi::SQLITE_ERROR, |x| x.extended_code),
412            })
413            .unwrap_or_default()
414        }
415        let c = self.db.borrow_mut();
416        unsafe {
417            ffi::sqlite3_wal_hook(
418                c.db(),
419                hook.as_ref().map(|_| wal_hook_callback as _),
420                hook.map_or_else(ptr::null_mut, |f| f as *mut c_void),
421            );
422        }
423    }
424
425    /// Register a query progress callback.
426    ///
427    /// The parameter `num_ops` is the approximate number of virtual machine
428    /// instructions that are evaluated between successive invocations of the
429    /// `handler`. If `num_ops` is less than one then the progress handler
430    /// is disabled.
431    ///
432    /// If the progress callback returns `true`, the operation is interrupted.
433    pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) -> Result<()>
434    where
435        F: FnMut() -> bool + Send + 'static,
436    {
437        self.db.borrow().check_owned()?;
438        self.db.borrow_mut().progress_handler(num_ops, handler);
439        Ok(())
440    }
441
442    /// Register an authorizer callback that's invoked
443    /// as a statement is being prepared.
444    #[inline]
445    pub fn authorizer<'c, F>(&self, hook: Option<F>) -> Result<()>
446    where
447        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
448    {
449        self.db.borrow().check_owned()?;
450        self.db.borrow_mut().authorizer(hook);
451        Ok(())
452    }
453}
454
455/// Checkpoint mode
456#[derive(Clone, Copy)]
457#[repr(i32)]
458#[non_exhaustive]
459pub enum CheckpointMode {
460    /// Do as much as possible w/o blocking
461    PASSIVE = ffi::SQLITE_CHECKPOINT_PASSIVE,
462    /// Wait for writers, then checkpoint
463    FULL = ffi::SQLITE_CHECKPOINT_FULL,
464    /// Like FULL but wait for readers
465    RESTART = ffi::SQLITE_CHECKPOINT_RESTART,
466    /// Like RESTART but also truncate WAL
467    TRUNCATE = ffi::SQLITE_CHECKPOINT_TRUNCATE,
468    /// Do no work at all
469    #[cfg(feature = "modern_sqlite")] // 3.51.0
470    NOOP = -1, //ffi::SQLITE_CHECKPOINT_NOOP,
471}
472
473/// Write-Ahead Log
474pub struct Wal {
475    db: *mut ffi::sqlite3,
476    db_name: *const c_char,
477}
478
479impl Wal {
480    /// Checkpoint a database
481    pub fn checkpoint(&self) -> Result<()> {
482        unsafe { decode_result_raw(self.db, ffi::sqlite3_wal_checkpoint(self.db, self.db_name)) }
483    }
484    /// Checkpoint a database
485    pub fn checkpoint_v2(&self, mode: CheckpointMode) -> Result<(c_int, c_int)> {
486        let mut n_log = 0;
487        let mut n_ckpt = 0;
488        unsafe {
489            decode_result_raw(
490                self.db,
491                ffi::sqlite3_wal_checkpoint_v2(
492                    self.db,
493                    self.db_name,
494                    mode as c_int,
495                    &mut n_log,
496                    &mut n_ckpt,
497                ),
498            )?
499        };
500        Ok((n_log, n_ckpt))
501    }
502
503    /// Name of the database that was written to
504    pub fn name(&self) -> &CStr {
505        unsafe { CStr::from_ptr(self.db_name) }
506    }
507}
508
impl InnerConnection {
    /// Unregister the update, commit, rollback, progress and authorizer hooks.
    ///
    /// Each setter is called with `None`. This clears the callback on the SQLite
    /// handle and drops the boxed closure held by `self`. The WAL hook is not reset
    /// here; it stores no closure on the connection.
    #[inline]
    pub fn remove_hooks(&mut self) {
        self.update_hook(None::<fn(Action, &str, &str, i64)>);
        self.commit_hook(None::<fn() -> bool>);
        self.rollback_hook(None::<fn()>);
        self.progress_handler(0, None::<fn() -> bool>);
        self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
    }

    /// Install (`Some`) or remove (`None`) the commit hook on the SQLite handle.
    ///
    /// The closure is boxed and kept in `self.commit_hook`; SQLite only receives a raw
    /// pointer to it. The new callback is registered *before* the old box is replaced
    /// (and dropped), so SQLite is never left pointing at a freed closure. A panic in
    /// the closure is caught and treated as `false`, meaning the commit proceeds.
    ///
    /// The example below must fail to compile: the `'static` bound rejects a closure
    /// that borrows a local.
    ///
    /// ```compile_fail
    /// use rusqlite::{Connection, Result};
    /// fn main() -> Result<()> {
    ///     let db = Connection::open_in_memory()?;
    ///     {
    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
    ///         db.commit_hook(Some(|| {
    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
    ///             true
    ///         }));
    ///     }
    ///     assert!(db
    ///         .execute_batch(
    ///             "BEGIN;
    ///         CREATE TABLE foo (t TEXT);
    ///         COMMIT;",
    ///         )
    ///         .is_err());
    ///     Ok(())
    /// }
    /// ```
    fn commit_hook<F>(&mut self, hook: Option<F>)
    where
        F: FnMut() -> bool + Send + 'static,
    {
        // Trampoline handed to SQLite; `p_arg` is the pointer to the boxed `F` below.
        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
        where
            F: FnMut() -> bool,
        {
            let r = catch_unwind(|| {
                let boxed_hook: *mut F = p_arg.cast::<F>();
                (*boxed_hook)()
            });
            // `true` (non-zero) asks SQLite to roll back; a panic becomes `false` (0).
            c_int::from(r.unwrap_or_default())
        }
        let boxed_hook = hook.map(Box::new);
        unsafe {
            ffi::sqlite3_commit_hook(
                self.db(),
                boxed_hook.as_ref().map(|_| call_boxed_closure::<F> as _),
                boxed_hook
                    .as_ref()
                    .map_or_else(ptr::null_mut, |h| &**h as *const F as *mut _),
            )
        };
        // Overwriting the field drops the previous closure, which SQLite no longer
        // references after the call above.
        self.commit_hook = boxed_hook.map(|bh| bh as _);
    }

    /// Install (`Some`) or remove (`None`) the rollback hook on the SQLite handle.
    ///
    /// Ownership and ordering follow [`commit_hook`](Self::commit_hook). A panic in
    /// the closure is caught and discarded.
    ///
    /// The example below must fail to compile: the `'static` bound rejects a closure
    /// that borrows a local.
    ///
    /// ```compile_fail
    /// use rusqlite::{Connection, Result};
    /// fn main() -> Result<()> {
    ///     let db = Connection::open_in_memory()?;
    ///     {
    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
    ///         db.rollback_hook(Some(|| {
    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
    ///         }));
    ///     }
    ///     assert!(db
    ///         .execute_batch(
    ///             "BEGIN;
    ///         CREATE TABLE foo (t TEXT);
    ///         ROLLBACK;",
    ///         )
    ///         .is_err());
    ///     Ok(())
    /// }
    /// ```
    fn rollback_hook<F>(&mut self, hook: Option<F>)
    where
        F: FnMut() + Send + 'static,
    {
        // Trampoline handed to SQLite; `p_arg` is the pointer to the boxed `F` below.
        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
        where
            F: FnMut(),
        {
            // The hook returns nothing to SQLite, so a caught panic is simply dropped.
            drop(catch_unwind(|| {
                let boxed_hook: *mut F = p_arg.cast::<F>();
                (*boxed_hook)();
            }));
        }

        let boxed_hook = hook.map(Box::new);
        unsafe {
            ffi::sqlite3_rollback_hook(
                self.db(),
                boxed_hook.as_ref().map(|_| call_boxed_closure::<F> as _),
                boxed_hook
                    .as_ref()
                    .map_or_else(ptr::null_mut, |h| &**h as *const F as *mut _),
            )
        };
        // Registered first, dropped second: see `commit_hook`.
        self.rollback_hook = boxed_hook.map(|bh| bh as _);
    }

    /// Install (`Some`) or remove (`None`) the update hook on the SQLite handle.
    ///
    /// Ownership and ordering follow [`commit_hook`](Self::commit_hook). The raw
    /// action code is converted with `Action::from`, and the database and table names
    /// are decoded with `expect_utf8`. That helper panics on NULL or non-UTF-8 input;
    /// such a panic, like one raised by the closure itself, is caught and discarded.
    ///
    /// The example below must fail to compile: the `'static` bound rejects a closure
    /// that borrows a local.
    ///
    /// ```compile_fail
    /// use rusqlite::{Connection, Result};
    /// fn main() -> Result<()> {
    ///     let db = Connection::open_in_memory()?;
    ///     {
    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
    ///         db.update_hook(Some(|_, _: &str, _: &str, _| {
    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
    ///         }));
    ///     }
    ///     db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
    /// }
    /// ```
    fn update_hook<F>(&mut self, hook: Option<F>)
    where
        F: FnMut(Action, &str, &str, i64) + Send + 'static,
    {
        // Trampoline handed to SQLite; `p_arg` is the pointer to the boxed `F` below.
        unsafe extern "C" fn call_boxed_closure<F>(
            p_arg: *mut c_void,
            action_code: c_int,
            p_db_name: *const c_char,
            p_table_name: *const c_char,
            row_id: i64,
        ) where
            F: FnMut(Action, &str, &str, i64),
        {
            let action = Action::from(action_code);
            drop(catch_unwind(|| {
                let boxed_hook: *mut F = p_arg.cast::<F>();
                (*boxed_hook)(
                    action,
                    expect_utf8(p_db_name, "database name"),
                    expect_utf8(p_table_name, "table name"),
                    row_id,
                );
            }));
        }

        let boxed_hook = hook.map(Box::new);
        unsafe {
            ffi::sqlite3_update_hook(
                self.db(),
                boxed_hook.as_ref().map(|_| call_boxed_closure::<F> as _),
                boxed_hook
                    .as_ref()
                    .map_or_else(ptr::null_mut, |h| &**h as *const F as *mut _),
            )
        };
        // Registered first, dropped second: see `commit_hook`.
        self.update_hook = boxed_hook.map(|bh| bh as _);
    }

    /// Install (`Some`) or remove (`None`) the progress handler on the SQLite handle.
    ///
    /// `num_ops` is forwarded unchanged to `sqlite3_progress_handler`. Ownership and
    /// ordering follow [`commit_hook`](Self::commit_hook). A panic in the closure is
    /// caught and treated as `false`, so the operation is not interrupted.
    ///
    /// The example below must fail to compile: the `'static` bound rejects a closure
    /// that borrows a local.
    ///
    /// ```compile_fail
    /// use rusqlite::{Connection, Result};
    /// fn main() -> Result<()> {
    ///     let db = Connection::open_in_memory()?;
    ///     {
    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
    ///         db.progress_handler(
    ///             1,
    ///             Some(|| {
    ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
    ///                 true
    ///             }),
    ///         );
    ///     }
    ///     assert!(db
    ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
    ///         .is_err());
    ///     Ok(())
    /// }
    /// ```
    fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
    where
        F: FnMut() -> bool + Send + 'static,
    {
        // Trampoline handed to SQLite; `p_arg` is the pointer to the boxed `F` below.
        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
        where
            F: FnMut() -> bool,
        {
            let r = catch_unwind(|| {
                let boxed_handler: *mut F = p_arg.cast::<F>();
                (*boxed_handler)()
            });
            // `true` (non-zero) interrupts the operation; a panic becomes `false` (0).
            c_int::from(r.unwrap_or_default())
        }

        let boxed_handler = handler.map(Box::new);
        unsafe {
            ffi::sqlite3_progress_handler(
                self.db(),
                num_ops,
                boxed_handler.as_ref().map(|_| call_boxed_closure::<F> as _),
                boxed_handler
                    .as_ref()
                    .map_or_else(ptr::null_mut, |h| &**h as *const F as *mut _),
            )
        };
        // Registered first, dropped second: see `commit_hook`.
        self.progress_handler = boxed_handler.map(|bh| bh as _);
    }

    /// Install (`Some`) or remove (`None`) the authorizer on the SQLite handle.
    ///
    /// For each callback, the raw arguments are decoded into an [`AuthContext`] and
    /// passed to the boxed closure. A panic is caught and reported to SQLite as
    /// `SQLITE_ERROR`, which fails the statement being prepared. This includes a panic
    /// from `expect_optional_utf8` on a non-UTF-8 argument. The boxed closure is
    /// stored on `self` only after `sqlite3_set_authorizer` reports success.
    ///
    /// The example below must fail to compile: the `'static` bound rejects a closure
    /// that borrows a local.
    ///
    /// ```compile_fail
    /// use rusqlite::{Connection, Result};
    /// fn main() -> Result<()> {
    ///     let db = Connection::open_in_memory()?;
    ///     {
    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
    ///         db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| {
    ///             called.store(true, std::sync::atomic::Ordering::Relaxed);
    ///             rusqlite::hooks::Authorization::Deny
    ///         }));
    ///     }
    ///     assert!(db
    ///         .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
    ///         .is_err());
    ///     Ok(())
    /// }
    /// ```
    fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
    where
        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
    {
        // Trampoline handed to SQLite; `p_arg` is the pointer to the boxed `F` below.
        // NOTE(review): `'c` only instantiates this trampoline. `F` is higher-ranked,
        // so the borrowed strings cannot escape a single callback.
        unsafe extern "C" fn call_boxed_closure<'c, F>(
            p_arg: *mut c_void,
            action_code: c_int,
            param1: *const c_char,
            param2: *const c_char,
            db_name: *const c_char,
            trigger_or_view_name: *const c_char,
        ) -> c_int
        where
            F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
        {
            catch_unwind(|| {
                let action = AuthAction::from_raw(
                    action_code,
                    expect_optional_utf8(param1, "authorizer param 1"),
                    expect_optional_utf8(param2, "authorizer param 2"),
                );
                let auth_ctx = AuthContext {
                    action,
                    database_name: expect_optional_utf8(db_name, "database name"),
                    accessor: expect_optional_utf8(
                        trigger_or_view_name,
                        "accessor (inner-most trigger or view)",
                    ),
                };
                let boxed_hook: *mut F = p_arg.cast::<F>();
                (*boxed_hook)(auth_ctx)
            })
            // Fail closed: a panic must not be mistaken for permission.
            .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
        }

        let boxed_authorizer = authorizer.map(Box::new);

        match unsafe {
            ffi::sqlite3_set_authorizer(
                self.db(),
                boxed_authorizer
                    .as_ref()
                    .map(|_| call_boxed_closure::<'c, F> as _),
                boxed_authorizer
                    .as_ref()
                    .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
            )
        } {
            ffi::SQLITE_OK => {
                // Registered first, dropped second: see `commit_hook`.
                self.authorizer = boxed_authorizer.map(|ba| ba as _);
            }
            err_code => {
                // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
                // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
                // This library does not allow constructing a null db ptr, so if this branch
                // is hit, something very bad has happened. Panicking instead of returning
                // `Result` keeps this hook's API consistent with the others.
                panic!("unexpectedly failed to set_authorizer: {}", unsafe {
                    crate::error::error_from_handle(self.db(), err_code)
                });
            }
        }
    }
}
795
/// Borrow a non-null C string handed to us by SQLite as `&str`.
///
/// # Panics
///
/// Panics if `p_str` is null or the string is not valid UTF-8. `description` names
/// the value in the panic message.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid and
/// unmodified for the caller-chosen lifetime `'a`.
unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
    // `None` here means the pointer was NULL; an empty string arrives as `Some("")`.
    expect_optional_utf8(p_str, description)
        .unwrap_or_else(|| panic!("received null {description}"))
}

/// Borrow a possibly-null C string handed to us by SQLite as `&str`.
///
/// Returns `None` when `p_str` is null.
///
/// # Panics
///
/// Panics if the string is not valid UTF-8. `description` names the value in the
/// panic message.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid and
/// unmodified for the caller-chosen lifetime `'a`.
unsafe fn expect_optional_utf8<'a>(
    p_str: *const c_char,
    description: &'static str,
) -> Option<&'a str> {
    if p_str.is_null() {
        return None;
    }
    let decoded = CStr::from_ptr(p_str)
        .to_str()
        .unwrap_or_else(|_| panic!("received non-utf8 string as {description}"));
    Some(decoded)
}
813
#[cfg(test)]
mod test {
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::Action;
    use crate::{Connection, Result, MAIN_DB};
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn test_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }))?;
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A plain `fn` returning `true` turns the COMMIT into a rollback, so the batch fails.
    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook))?;
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_rollback_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }))?;
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }))?;
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_progress_handler() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.progress_handler(
            1,
            Some(|| {
                CALLED.store(true, Ordering::Relaxed);
                false
            }),
        )?;
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A handler that always returns `true` interrupts the batch.
    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn handler() -> bool {
            true
        }

        db.progress_handler(1, Some(handler))?;
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")?;

        let authorizer = move |ctx: AuthContext<'_>| match ctx.action {
            AuthAction::Read {
                column_name: "private",
                ..
            } => Authorization::Ignore,
            AuthAction::DropTable { .. } => Authorization::Deny,
            AuthAction::Pragma { .. } => panic!("shouldn't be called"),
            _ => Authorization::Allow,
        };

        db.authorizer(Some(authorizer))?;
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )?;
        // `Ignore` on a column read makes SQLite return NULL for that column.
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })?;
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>)?;
        // The first authorizer panicked on PRAGMA, which would fail the statement;
        // it has been removed, so this now succeeds.
        db.execute_batch("PRAGMA user_version=1")?;

        Ok(())
    }

    #[cfg_attr(
        all(target_family = "wasm", target_os = "unknown"),
        ignore = "no filesystem on this platform"
    )]
    #[test]
    fn wal_hook() -> Result<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("wal-hook.db3");

        let db = Connection::open(&path)?;
        let journal_mode: String =
            db.pragma_update_and_check(None, "journal_mode", "wal", |row| row.get(0))?;
        assert_eq!(journal_mode, "wal");

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.wal_hook(Some(|wal, pages| {
            assert_eq!(wal.name(), MAIN_DB);
            assert!(pages > 0);
            CALLED.store(true, Ordering::Relaxed);
            wal.checkpoint()
        }));
        db.execute_batch("CREATE TABLE x(c);")?;
        assert!(CALLED.load(Ordering::Relaxed));

        // Replacing the hook: a TRUNCATE checkpoint leaves an empty WAL.
        db.wal_hook(Some(|wal, pages| {
            assert!(pages > 0);
            let (log, ckpt) = wal.checkpoint_v2(super::CheckpointMode::TRUNCATE)?;
            assert_eq!(log, 0);
            assert_eq!(ckpt, 0);
            Ok(())
        }));
        db.execute_batch("CREATE TABLE y(c);")?;

        db.wal_hook(None);
        Ok(())
    }

    /// Dropping a non-owning `Connection` must not tear down hooks that the owning
    /// connection registered on the same handle.
    #[test]
    fn test_non_owning_hooks_cleanup() -> Result<()> {
        let conn = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        CALLED.store(false, Ordering::Relaxed);
        conn.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }))?;

        let non_owning_conn = unsafe { Connection::from_handle(conn.handle()) }?;
        drop(non_owning_conn);

        conn.execute_batch("CREATE TABLE test(value)")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }
}