rusqlite/
hooks.rs

1//! Commit, Data Change and Rollback Notification Callbacks
2#![allow(non_camel_case_types)]
3
4use std::os::raw::{c_char, c_int, c_void};
5use std::panic::{catch_unwind, RefUnwindSafe};
6use std::ptr;
7
8use crate::ffi;
9
10use crate::{Connection, InnerConnection};
11
/// Action codes reported to an [`update_hook`](Connection::update_hook) callback.
///
/// The discriminant of each known variant is the corresponding `ffi` constant;
/// any other code is mapped to [`Action::UNKNOWN`] by the `From<i32>` conversion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
#[non_exhaustive]
// Variant names deliberately mirror the C constant names (`SQLITE_DELETE`, ...).
#[allow(clippy::upper_case_acronyms)]
pub enum Action {
    /// Unsupported / unexpected action
    UNKNOWN = -1,
    /// DELETE command
    SQLITE_DELETE = ffi::SQLITE_DELETE,
    /// INSERT command
    SQLITE_INSERT = ffi::SQLITE_INSERT,
    /// UPDATE command
    SQLITE_UPDATE = ffi::SQLITE_UPDATE,
}
27
28impl From<i32> for Action {
29    #[inline]
30    fn from(code: i32) -> Action {
31        match code {
32            ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
33            ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
34            ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
35            _ => Action::UNKNOWN,
36        }
37    }
38}
39
/// The context received by an authorizer hook.
///
/// All string fields borrow data for the lifetime `'c`. The authorizer closure
/// must accept `AuthContext<'r>` for every `'r` (see the `for<'r>` bound on
/// [`Connection::authorizer`]), so these borrows cannot escape a single call.
///
/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AuthContext<'c> {
    /// The action to be authorized.
    pub action: AuthAction<'c>,

    /// The database name, if applicable.
    /// `None` when SQLite passes a NULL pointer for it.
    pub database_name: Option<&'c str>,

    /// The inner-most trigger or view responsible for the access attempt.
    /// `None` if the access attempt was made by top-level SQL code.
    pub accessor: Option<&'c str>,
}
55
56/// Actions and arguments found within a statement during
57/// preparation.
58///
59/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
60#[derive(Clone, Copy, Debug, Eq, PartialEq)]
61#[non_exhaustive]
62#[allow(missing_docs)]
63pub enum AuthAction<'c> {
64    /// This variant is not normally produced by SQLite. You may encounter it
65    // if you're using a different version than what's supported by this library.
66    Unknown {
67        /// The unknown authorization action code.
68        code: i32,
69        /// The third arg to the authorizer callback.
70        arg1: Option<&'c str>,
71        /// The fourth arg to the authorizer callback.
72        arg2: Option<&'c str>,
73    },
74    CreateIndex {
75        index_name: &'c str,
76        table_name: &'c str,
77    },
78    CreateTable {
79        table_name: &'c str,
80    },
81    CreateTempIndex {
82        index_name: &'c str,
83        table_name: &'c str,
84    },
85    CreateTempTable {
86        table_name: &'c str,
87    },
88    CreateTempTrigger {
89        trigger_name: &'c str,
90        table_name: &'c str,
91    },
92    CreateTempView {
93        view_name: &'c str,
94    },
95    CreateTrigger {
96        trigger_name: &'c str,
97        table_name: &'c str,
98    },
99    CreateView {
100        view_name: &'c str,
101    },
102    Delete {
103        table_name: &'c str,
104    },
105    DropIndex {
106        index_name: &'c str,
107        table_name: &'c str,
108    },
109    DropTable {
110        table_name: &'c str,
111    },
112    DropTempIndex {
113        index_name: &'c str,
114        table_name: &'c str,
115    },
116    DropTempTable {
117        table_name: &'c str,
118    },
119    DropTempTrigger {
120        trigger_name: &'c str,
121        table_name: &'c str,
122    },
123    DropTempView {
124        view_name: &'c str,
125    },
126    DropTrigger {
127        trigger_name: &'c str,
128        table_name: &'c str,
129    },
130    DropView {
131        view_name: &'c str,
132    },
133    Insert {
134        table_name: &'c str,
135    },
136    Pragma {
137        pragma_name: &'c str,
138        /// The pragma value, if present (e.g., `PRAGMA name = value;`).
139        pragma_value: Option<&'c str>,
140    },
141    Read {
142        table_name: &'c str,
143        column_name: &'c str,
144    },
145    Select,
146    Transaction {
147        operation: TransactionOperation,
148    },
149    Update {
150        table_name: &'c str,
151        column_name: &'c str,
152    },
153    Attach {
154        filename: &'c str,
155    },
156    Detach {
157        database_name: &'c str,
158    },
159    AlterTable {
160        database_name: &'c str,
161        table_name: &'c str,
162    },
163    Reindex {
164        index_name: &'c str,
165    },
166    Analyze {
167        table_name: &'c str,
168    },
169    CreateVtable {
170        table_name: &'c str,
171        module_name: &'c str,
172    },
173    DropVtable {
174        table_name: &'c str,
175        module_name: &'c str,
176    },
177    Function {
178        function_name: &'c str,
179    },
180    Savepoint {
181        operation: TransactionOperation,
182        savepoint_name: &'c str,
183    },
184    Recursive,
185}
186
187impl<'c> AuthAction<'c> {
188    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
189        match (code, arg1, arg2) {
190            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
191                index_name,
192                table_name,
193            },
194            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
195            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
196                Self::CreateTempIndex {
197                    index_name,
198                    table_name,
199                }
200            }
201            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
202                Self::CreateTempTable { table_name }
203            }
204            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
205                Self::CreateTempTrigger {
206                    trigger_name,
207                    table_name,
208                }
209            }
210            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
211                Self::CreateTempView { view_name }
212            }
213            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
214                Self::CreateTrigger {
215                    trigger_name,
216                    table_name,
217                }
218            }
219            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
220            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
221            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
222                index_name,
223                table_name,
224            },
225            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
226            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
227                Self::DropTempIndex {
228                    index_name,
229                    table_name,
230                }
231            }
232            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
233                Self::DropTempTable { table_name }
234            }
235            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
236                Self::DropTempTrigger {
237                    trigger_name,
238                    table_name,
239                }
240            }
241            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
242            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
243                trigger_name,
244                table_name,
245            },
246            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
247            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
248            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
249                pragma_name,
250                pragma_value,
251            },
252            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
253                table_name,
254                column_name,
255            },
256            (ffi::SQLITE_SELECT, ..) => Self::Select,
257            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
258                operation: TransactionOperation::from_str(operation_str),
259            },
260            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
261                table_name,
262                column_name,
263            },
264            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
265            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
266            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
267                database_name,
268                table_name,
269            },
270            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
271            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
272            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
273                Self::CreateVtable {
274                    table_name,
275                    module_name,
276                }
277            }
278            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
279                table_name,
280                module_name,
281            },
282            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
283            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
284                operation: TransactionOperation::from_str(operation_str),
285                savepoint_name,
286            },
287            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
288            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
289        }
290    }
291}
292
/// Type-erased, heap-allocated authorizer callback.
///
/// `InnerConnection::authorizer` coerces the user's `Box<F>` into this form with
/// `as _` when the authorizer is registered successfully.
/// NOTE(review): presumably this is the element type of the `authorizer` field,
/// which is declared outside this file — confirm.
pub(crate) type BoxedAuthorizer =
    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
295
/// A transaction operation.
///
/// Decoded from the operation keyword SQLite passes alongside
/// [`AuthAction::Transaction`] and [`AuthAction::Savepoint`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum TransactionOperation {
    /// Any keyword other than `BEGIN`, `RELEASE` or `ROLLBACK` (matched exactly).
    Unknown,
    /// `BEGIN`
    Begin,
    /// `RELEASE`
    Release,
    /// `ROLLBACK`
    Rollback,
}
306
307impl TransactionOperation {
308    fn from_str(op_str: &str) -> Self {
309        match op_str {
310            "BEGIN" => Self::Begin,
311            "RELEASE" => Self::Release,
312            "ROLLBACK" => Self::Rollback,
313            _ => Self::Unknown,
314        }
315    }
316}
317
/// [`authorizer`](Connection::authorizer) return code
///
/// Each variant is translated into the integer result SQLite expects from an
/// authorizer: `SQLITE_OK`, `SQLITE_IGNORE` or `SQLITE_DENY` respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Authorization {
    /// Authorize the action.
    Allow,
    /// Don't allow access, but don't trigger an error either.
    Ignore,
    /// Trigger an error.
    Deny,
}
329
330impl Authorization {
331    fn into_raw(self) -> c_int {
332        match self {
333            Self::Allow => ffi::SQLITE_OK,
334            Self::Ignore => ffi::SQLITE_IGNORE,
335            Self::Deny => ffi::SQLITE_DENY,
336        }
337    }
338}
339
340impl Connection {
341    /// Register a callback function to be invoked whenever
342    /// a transaction is committed.
343    ///
344    /// The callback returns `true` to rollback.
345    #[inline]
346    pub fn commit_hook<F>(&self, hook: Option<F>)
347    where
348        F: FnMut() -> bool + Send + 'static,
349    {
350        self.db.borrow_mut().commit_hook(hook);
351    }
352
353    /// Register a callback function to be invoked whenever
354    /// a transaction is committed.
355    #[inline]
356    pub fn rollback_hook<F>(&self, hook: Option<F>)
357    where
358        F: FnMut() + Send + 'static,
359    {
360        self.db.borrow_mut().rollback_hook(hook);
361    }
362
363    /// Register a callback function to be invoked whenever
364    /// a row is updated, inserted or deleted in a rowid table.
365    ///
366    /// The callback parameters are:
367    ///
368    /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
369    /// `SQLITE_DELETE`),
370    /// - the name of the database ("main", "temp", ...),
371    /// - the name of the table that is updated,
372    /// - the ROWID of the row that is updated.
373    #[inline]
374    pub fn update_hook<F>(&self, hook: Option<F>)
375    where
376        F: FnMut(Action, &str, &str, i64) + Send + 'static,
377    {
378        self.db.borrow_mut().update_hook(hook);
379    }
380
381    /// Register a query progress callback.
382    ///
383    /// The parameter `num_ops` is the approximate number of virtual machine
384    /// instructions that are evaluated between successive invocations of the
385    /// `handler`. If `num_ops` is less than one then the progress handler
386    /// is disabled.
387    ///
388    /// If the progress callback returns `true`, the operation is interrupted.
389    pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
390    where
391        F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
392    {
393        self.db.borrow_mut().progress_handler(num_ops, handler);
394    }
395
396    /// Register an authorizer callback that's invoked
397    /// as a statement is being prepared.
398    #[inline]
399    pub fn authorizer<'c, F>(&self, hook: Option<F>)
400    where
401        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
402    {
403        self.db.borrow_mut().authorizer(hook);
404    }
405}
406
407impl InnerConnection {
408    #[inline]
409    pub fn remove_hooks(&mut self) {
410        self.update_hook(None::<fn(Action, &str, &str, i64)>);
411        self.commit_hook(None::<fn() -> bool>);
412        self.rollback_hook(None::<fn()>);
413        self.progress_handler(0, None::<fn() -> bool>);
414        self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
415    }
416
417    fn commit_hook<F>(&mut self, hook: Option<F>)
418    where
419        F: FnMut() -> bool + Send + 'static,
420    {
421        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
422        where
423            F: FnMut() -> bool,
424        {
425            let r = catch_unwind(|| {
426                let boxed_hook: *mut F = p_arg.cast::<F>();
427                (*boxed_hook)()
428            });
429            c_int::from(r.unwrap_or_default())
430        }
431
432        // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
433        // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
434        // `InnerConnection.free_boxed_hook`.
435        let free_commit_hook = if hook.is_some() {
436            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
437        } else {
438            None
439        };
440
441        let previous_hook = match hook {
442            Some(hook) => {
443                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
444                unsafe {
445                    ffi::sqlite3_commit_hook(
446                        self.db(),
447                        Some(call_boxed_closure::<F>),
448                        boxed_hook.cast(),
449                    )
450                }
451            }
452            _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
453        };
454        if !previous_hook.is_null() {
455            if let Some(free_boxed_hook) = self.free_commit_hook {
456                unsafe { free_boxed_hook(previous_hook) };
457            }
458        }
459        self.free_commit_hook = free_commit_hook;
460    }
461
462    fn rollback_hook<F>(&mut self, hook: Option<F>)
463    where
464        F: FnMut() + Send + 'static,
465    {
466        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
467        where
468            F: FnMut(),
469        {
470            drop(catch_unwind(|| {
471                let boxed_hook: *mut F = p_arg.cast::<F>();
472                (*boxed_hook)();
473            }));
474        }
475
476        let free_rollback_hook = if hook.is_some() {
477            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
478        } else {
479            None
480        };
481
482        let previous_hook = match hook {
483            Some(hook) => {
484                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
485                unsafe {
486                    ffi::sqlite3_rollback_hook(
487                        self.db(),
488                        Some(call_boxed_closure::<F>),
489                        boxed_hook.cast(),
490                    )
491                }
492            }
493            _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
494        };
495        if !previous_hook.is_null() {
496            if let Some(free_boxed_hook) = self.free_rollback_hook {
497                unsafe { free_boxed_hook(previous_hook) };
498            }
499        }
500        self.free_rollback_hook = free_rollback_hook;
501    }
502
503    fn update_hook<F>(&mut self, hook: Option<F>)
504    where
505        F: FnMut(Action, &str, &str, i64) + Send + 'static,
506    {
507        unsafe extern "C" fn call_boxed_closure<F>(
508            p_arg: *mut c_void,
509            action_code: c_int,
510            p_db_name: *const c_char,
511            p_table_name: *const c_char,
512            row_id: i64,
513        ) where
514            F: FnMut(Action, &str, &str, i64),
515        {
516            let action = Action::from(action_code);
517            drop(catch_unwind(|| {
518                let boxed_hook: *mut F = p_arg.cast::<F>();
519                (*boxed_hook)(
520                    action,
521                    expect_utf8(p_db_name, "database name"),
522                    expect_utf8(p_table_name, "table name"),
523                    row_id,
524                );
525            }));
526        }
527
528        let free_update_hook = if hook.is_some() {
529            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
530        } else {
531            None
532        };
533
534        let previous_hook = match hook {
535            Some(hook) => {
536                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
537                unsafe {
538                    ffi::sqlite3_update_hook(
539                        self.db(),
540                        Some(call_boxed_closure::<F>),
541                        boxed_hook.cast(),
542                    )
543                }
544            }
545            _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
546        };
547        if !previous_hook.is_null() {
548            if let Some(free_boxed_hook) = self.free_update_hook {
549                unsafe { free_boxed_hook(previous_hook) };
550            }
551        }
552        self.free_update_hook = free_update_hook;
553    }
554
555    fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
556    where
557        F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
558    {
559        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
560        where
561            F: FnMut() -> bool,
562        {
563            let r = catch_unwind(|| {
564                let boxed_handler: *mut F = p_arg.cast::<F>();
565                (*boxed_handler)()
566            });
567            c_int::from(r.unwrap_or_default())
568        }
569
570        if let Some(handler) = handler {
571            let boxed_handler = Box::new(handler);
572            unsafe {
573                ffi::sqlite3_progress_handler(
574                    self.db(),
575                    num_ops,
576                    Some(call_boxed_closure::<F>),
577                    &*boxed_handler as *const F as *mut _,
578                );
579            }
580            self.progress_handler = Some(boxed_handler);
581        } else {
582            unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) }
583            self.progress_handler = None;
584        };
585    }
586
587    fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
588    where
589        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
590    {
591        unsafe extern "C" fn call_boxed_closure<'c, F>(
592            p_arg: *mut c_void,
593            action_code: c_int,
594            param1: *const c_char,
595            param2: *const c_char,
596            db_name: *const c_char,
597            trigger_or_view_name: *const c_char,
598        ) -> c_int
599        where
600            F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
601        {
602            catch_unwind(|| {
603                let action = AuthAction::from_raw(
604                    action_code,
605                    expect_optional_utf8(param1, "authorizer param 1"),
606                    expect_optional_utf8(param2, "authorizer param 2"),
607                );
608                let auth_ctx = AuthContext {
609                    action,
610                    database_name: expect_optional_utf8(db_name, "database name"),
611                    accessor: expect_optional_utf8(
612                        trigger_or_view_name,
613                        "accessor (inner-most trigger or view)",
614                    ),
615                };
616                let boxed_hook: *mut F = p_arg.cast::<F>();
617                (*boxed_hook)(auth_ctx)
618            })
619            .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
620        }
621
622        let callback_fn = authorizer
623            .as_ref()
624            .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _);
625        let boxed_authorizer = authorizer.map(Box::new);
626
627        match unsafe {
628            ffi::sqlite3_set_authorizer(
629                self.db(),
630                callback_fn,
631                boxed_authorizer
632                    .as_ref()
633                    .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
634            )
635        } {
636            ffi::SQLITE_OK => {
637                self.authorizer = boxed_authorizer.map(|ba| ba as _);
638            }
639            err_code => {
640                // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
641                // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
642                // This library does not allow constructing a null db ptr, so if this branch
643                // is hit, something very bad has happened. Panicking instead of returning
644                // `Result` keeps this hook's API consistent with the others.
645                panic!("unexpectedly failed to set_authorizer: {}", unsafe {
646                    crate::error::error_from_handle(self.db(), err_code)
647                });
648            }
649        }
650    }
651}
652
/// Reclaims and drops a `Box<F>` that was leaked with `Box::into_raw` and
/// type-erased to `*mut c_void`.
///
/// The commit, rollback and update hook setters record this as the deferred
/// destructor for the closure they handed to SQLite.
///
/// # Safety
///
/// `p` must be non-null and originate from `Box::<F>::into_raw` for this exact
/// `F`. It must not be used again after this call.
unsafe fn free_boxed_hook<F>(p: *mut c_void) {
    let reclaimed: Box<F> = Box::from_raw(p as *mut F);
    drop(reclaimed);
}
656
657unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
658    expect_optional_utf8(p_str, description)
659        .unwrap_or_else(|| panic!("received empty {}", description))
660}
661
/// Borrows a C string received from SQLite as `&str`, mapping a NULL pointer to
/// `None`.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid
/// and unmodified for the caller-chosen lifetime `'a`.
///
/// # Panics
///
/// Panics if the string is not valid UTF-8. `description` names the argument in
/// the panic message.
unsafe fn expect_optional_utf8<'a>(
    p_str: *const c_char,
    description: &'static str,
) -> Option<&'a str> {
    if p_str.is_null() {
        None
    } else {
        let c_str: &'a std::ffi::CStr = std::ffi::CStr::from_ptr(p_str);
        let text = c_str
            .to_str()
            .unwrap_or_else(|_| panic!("received non-utf8 string as {}", description));
        Some(text)
    }
}
673
#[cfg(test)]
mod test {
    use super::Action;
    use crate::{Connection, Result};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test]
    fn test_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    // Removing a rollback-forcing commit hook must let commits succeed again.
    #[test]
    fn test_remove_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        db.commit_hook(Some(|| true));
        db.commit_hook(None::<fn() -> bool>);
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        Ok(())
    }

    #[test]
    fn test_rollback_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_progress_handler() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.progress_handler(
            1,
            Some(|| {
                CALLED.store(true, Ordering::Relaxed);
                false
            }),
        );
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    // A handler that mutates its own captured state across invocations; this
    // exercises the `FnMut` call path through the pointer handed to SQLite.
    #[test]
    fn test_progress_handler_mutable_state() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let mut count = 0usize;
        db.progress_handler(
            1,
            Some(move || {
                count += 1;
                CALLS.store(count, Ordering::Relaxed);
                false
            }),
        );
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLS.load(Ordering::Relaxed) > 1);
        Ok(())
    }

    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn handler() -> bool {
            true
        }

        db.progress_handler(1, Some(handler));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")
            .unwrap();

        let authorizer = move |ctx: AuthContext<'_>| match ctx.action {
            AuthAction::Read { column_name, .. } if column_name == "private" => {
                Authorization::Ignore
            }
            AuthAction::DropTable { .. } => Authorization::Deny,
            AuthAction::Pragma { .. } => panic!("shouldn't be called"),
            _ => Authorization::Allow,
        };

        db.authorizer(Some(authorizer));
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )
        .unwrap();
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })
        .unwrap();
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
        db.execute_batch("PRAGMA user_version=1").unwrap(); // Disallowed by first authorizer, but it's now removed.

        Ok(())
    }

    // An authorizer that mutates its captured state; each INSERT statement must
    // be observed, so the counter has to persist between callback invocations.
    #[test]
    fn test_authorizer_mutable_state() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;

        static INSERTS: AtomicUsize = AtomicUsize::new(0);
        let mut inserts = 0usize;
        db.authorizer(Some(move |ctx: AuthContext<'_>| {
            if let AuthAction::Insert { .. } = ctx.action {
                inserts += 1;
                INSERTS.store(inserts, Ordering::Relaxed);
            }
            Authorization::Allow
        }));
        db.execute_batch("INSERT INTO foo VALUES ('a'); INSERT INTO foo VALUES ('b');")?;
        assert!(INSERTS.load(Ordering::Relaxed) >= 2);
        Ok(())
    }
}