atuin_client/
database.rs

1use std::{
2    borrow::Cow,
3    env,
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{Rng, distributions::Alphanumeric};
14use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote};
15use sqlx::{
16    Result, Row,
17    sqlite::{
18        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
19        SqliteSynchronous,
20    },
21};
22use time::OffsetDateTime;
23
24use crate::{
25    history::{HistoryId, HistoryStats},
26    utils::get_host_user,
27};
28
29use super::{
30    history::History,
31    ordering,
32    settings::{FilterMode, SearchMode, Settings},
33};
34
/// Ambient information about where a command is being recorded or searched from.
///
/// The filter modes use these values to scope queries to the current host, session,
/// directory or workspace. See [`current_context`] for how it is populated at runtime.
#[derive(Debug, Clone)]
pub struct Context {
    /// Identifier of the current shell session (taken from `$ATUIN_SESSION`).
    pub session: String,
    /// Current working directory.
    pub cwd: String,
    /// Host/user identifier, as produced by `get_host_user`.
    pub hostname: String,
    /// Host ID from the settings, rendered with `as_simple()`.
    pub host_id: String,
    /// Root of the git repository containing `cwd`, if any.
    pub git_root: Option<PathBuf>,
}
42
/// Optional, additive restrictions applied on top of a search query.
///
/// Every field defaults to "no restriction" (`None` / `false`).
#[derive(Debug, Default, Clone)]
pub struct OptFilters {
    /// Only return entries whose exit code equals this value.
    pub exit: Option<i64>,
    /// Skip entries whose exit code equals this value.
    pub exclude_exit: Option<i64>,
    /// Only return entries recorded in exactly this directory.
    pub cwd: Option<String>,
    /// Skip entries recorded in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Only return entries older than this date expression. The search parses it
    /// with `interim` and ignores the filter when the expression cannot be parsed.
    pub before: Option<String>,
    /// Only return entries newer than this date expression (parsed like `before`).
    pub after: Option<String>,
    /// Maximum number of rows to return.
    pub limit: Option<i64>,
    /// Number of rows to skip before returning results.
    pub offset: Option<i64>,
    /// Order by ascending timestamp instead of the default descending order.
    pub reverse: bool,
    /// When `false`, results are grouped by command so each command appears once.
    pub include_duplicates: bool,
}
56
57pub fn current_context() -> Context {
58    let Ok(session) = env::var("ATUIN_SESSION") else {
59        eprintln!(
60            "ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."
61        );
62        std::process::exit(1);
63    };
64    let hostname = get_host_user();
65    let cwd = utils::get_current_dir();
66    let host_id = Settings::host_id().expect("failed to load host ID");
67    let git_root = utils::in_git_repo(cwd.as_str());
68
69    Context {
70        session,
71        hostname,
72        cwd,
73        git_root,
74        host_id: host_id.0.as_simple().to_string(),
75    }
76}
77
78#[async_trait]
79pub trait Database: Send + Sync + 'static {
80    async fn save(&self, h: &History) -> Result<()>;
81    async fn save_bulk(&self, h: &[History]) -> Result<()>;
82
83    async fn load(&self, id: &str) -> Result<Option<History>>;
84    async fn list(
85        &self,
86        filters: &[FilterMode],
87        context: &Context,
88        max: Option<usize>,
89        unique: bool,
90        include_deleted: bool,
91    ) -> Result<Vec<History>>;
92    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
93
94    async fn update(&self, h: &History) -> Result<()>;
95    async fn history_count(&self, include_deleted: bool) -> Result<i64>;
96
97    async fn last(&self) -> Result<Option<History>>;
98    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
99
100    async fn delete(&self, h: History) -> Result<()>;
101    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
102    async fn deleted(&self) -> Result<Vec<History>>;
103
104    // Yes I know, it's a lot.
105    // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless.
106    // Been debating maybe a DSL for search? eg "before:time limit:1 the query"
107    #[allow(clippy::too_many_arguments)]
108    async fn search(
109        &self,
110        search_mode: SearchMode,
111        filter: FilterMode,
112        context: &Context,
113        query: &str,
114        filter_options: OptFilters,
115    ) -> Result<Vec<History>>;
116
117    async fn query_history(&self, query: &str) -> Result<Vec<History>>;
118
119    async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
120
121    async fn stats(&self, h: &History) -> Result<HistoryStats>;
122}
123
124// Intended for use on a developer machine and not a sync server.
125// TODO: implement IntoIterator
126#[derive(Debug, Clone)]
127pub struct Sqlite {
128    pub pool: SqlitePool,
129}
130
131impl Sqlite {
132    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
133        let path = path.as_ref();
134        debug!("opening sqlite database at {:?}", path);
135
136        if utils::broken_symlink(path) {
137            eprintln!(
138                "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
139            );
140            std::process::exit(1);
141        }
142
143        if !path.exists() {
144            if let Some(dir) = path.parent() {
145                fs::create_dir_all(dir)?;
146            }
147        }
148
149        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
150            .journal_mode(SqliteJournalMode::Wal)
151            .optimize_on_close(true, None)
152            .synchronous(SqliteSynchronous::Normal)
153            .with_regexp()
154            .create_if_missing(true);
155
156        let pool = SqlitePoolOptions::new()
157            .acquire_timeout(Duration::from_secs_f64(timeout))
158            .connect_with(opts)
159            .await?;
160
161        Self::setup_db(&pool).await?;
162        Ok(Self { pool })
163    }
164
165    pub async fn sqlite_version(&self) -> Result<String> {
166        sqlx::query_scalar("SELECT sqlite_version()")
167            .fetch_one(&self.pool)
168            .await
169    }
170
171    async fn setup_db(pool: &SqlitePool) -> Result<()> {
172        debug!("running sqlite database setup");
173
174        sqlx::migrate!("./migrations").run(pool).await?;
175
176        Ok(())
177    }
178
179    async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
180        sqlx::query(
181            "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
182                values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
183        )
184        .bind(h.id.0.as_str())
185        .bind(h.timestamp.unix_timestamp_nanos() as i64)
186        .bind(h.duration)
187        .bind(h.exit)
188        .bind(h.command.as_str())
189        .bind(h.cwd.as_str())
190        .bind(h.session.as_str())
191        .bind(h.hostname.as_str())
192        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
193        .execute(&mut **tx)
194        .await?;
195
196        Ok(())
197    }
198
199    async fn delete_row_raw(
200        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
201        id: HistoryId,
202    ) -> Result<()> {
203        sqlx::query("delete from history where id = ?1")
204            .bind(id.0.as_str())
205            .execute(&mut **tx)
206            .await?;
207
208        Ok(())
209    }
210
211    fn query_history(row: SqliteRow) -> History {
212        let deleted_at: Option<i64> = row.get("deleted_at");
213
214        History::from_db()
215            .id(row.get("id"))
216            .timestamp(
217                OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
218                    .unwrap(),
219            )
220            .duration(row.get("duration"))
221            .exit(row.get("exit"))
222            .command(row.get("command"))
223            .cwd(row.get("cwd"))
224            .session(row.get("session"))
225            .hostname(row.get("hostname"))
226            .deleted_at(
227                deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
228            )
229            .build()
230            .into()
231    }
232}
233
234#[async_trait]
235impl Database for Sqlite {
236    async fn save(&self, h: &History) -> Result<()> {
237        debug!("saving history to sqlite");
238        let mut tx = self.pool.begin().await?;
239        Self::save_raw(&mut tx, h).await?;
240        tx.commit().await?;
241
242        Ok(())
243    }
244
245    async fn save_bulk(&self, h: &[History]) -> Result<()> {
246        debug!("saving history to sqlite");
247
248        let mut tx = self.pool.begin().await?;
249
250        for i in h {
251            Self::save_raw(&mut tx, i).await?;
252        }
253
254        tx.commit().await?;
255
256        Ok(())
257    }
258
259    async fn load(&self, id: &str) -> Result<Option<History>> {
260        debug!("loading history item {}", id);
261
262        let res = sqlx::query("select * from history where id = ?1")
263            .bind(id)
264            .map(Self::query_history)
265            .fetch_optional(&self.pool)
266            .await?;
267
268        Ok(res)
269    }
270
271    async fn update(&self, h: &History) -> Result<()> {
272        debug!("updating sqlite history");
273
274        sqlx::query(
275            "update history
276                set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
277                where id = ?1",
278        )
279        .bind(h.id.0.as_str())
280        .bind(h.timestamp.unix_timestamp_nanos() as i64)
281        .bind(h.duration)
282        .bind(h.exit)
283        .bind(h.command.as_str())
284        .bind(h.cwd.as_str())
285        .bind(h.session.as_str())
286        .bind(h.hostname.as_str())
287        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
288        .execute(&self.pool)
289        .await?;
290
291        Ok(())
292    }
293
294    // make a unique list, that only shows the *newest* version of things
295    async fn list(
296        &self,
297        filters: &[FilterMode],
298        context: &Context,
299        max: Option<usize>,
300        unique: bool,
301        include_deleted: bool,
302    ) -> Result<Vec<History>> {
303        debug!("listing history");
304
305        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
306        query.field("*").order_desc("timestamp");
307        if !include_deleted {
308            query.and_where_is_null("deleted_at");
309        }
310
311        let git_root = if let Some(git_root) = context.git_root.clone() {
312            git_root.to_str().unwrap_or("/").to_string()
313        } else {
314            context.cwd.clone()
315        };
316
317        for filter in filters {
318            match filter {
319                FilterMode::Global => &mut query,
320                FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
321                FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
322                FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
323                FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
324            };
325        }
326
327        if unique {
328            query.group_by("command").having("max(timestamp)");
329        }
330
331        if let Some(max) = max {
332            query.limit(max);
333        }
334
335        let query = query.sql().expect("bug in list query. please report");
336
337        let res = sqlx::query(&query)
338            .map(Self::query_history)
339            .fetch_all(&self.pool)
340            .await?;
341
342        Ok(res)
343    }
344
345    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
346        debug!("listing history from {:?} to {:?}", from, to);
347
348        let res = sqlx::query(
349            "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
350        )
351        .bind(from.unix_timestamp_nanos() as i64)
352        .bind(to.unix_timestamp_nanos() as i64)
353            .map(Self::query_history)
354        .fetch_all(&self.pool)
355        .await?;
356
357        Ok(res)
358    }
359
360    async fn last(&self) -> Result<Option<History>> {
361        let res = sqlx::query(
362            "select * from history where duration >= 0 order by timestamp desc limit 1",
363        )
364        .map(Self::query_history)
365        .fetch_optional(&self.pool)
366        .await?;
367
368        Ok(res)
369    }
370
371    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
372        let res = sqlx::query(
373            "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
374        )
375        .bind(timestamp.unix_timestamp_nanos() as i64)
376        .bind(count)
377        .map(Self::query_history)
378        .fetch_all(&self.pool)
379        .await?;
380
381        Ok(res)
382    }
383
384    async fn deleted(&self) -> Result<Vec<History>> {
385        let res = sqlx::query("select * from history where deleted_at is not null")
386            .map(Self::query_history)
387            .fetch_all(&self.pool)
388            .await?;
389
390        Ok(res)
391    }
392
393    async fn history_count(&self, include_deleted: bool) -> Result<i64> {
394        let query = if include_deleted {
395            "select count(1) from history"
396        } else {
397            "select count(1) from history where deleted_at is null"
398        };
399
400        let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
401        Ok(res.0)
402    }
403
404    async fn search(
405        &self,
406        search_mode: SearchMode,
407        filter: FilterMode,
408        context: &Context,
409        query: &str,
410        filter_options: OptFilters,
411    ) -> Result<Vec<History>> {
412        let mut sql = SqlBuilder::select_from("history");
413
414        if !filter_options.include_duplicates {
415            sql.group_by("command").having("max(timestamp)");
416        }
417
418        if let Some(limit) = filter_options.limit {
419            sql.limit(limit);
420        }
421
422        if let Some(offset) = filter_options.offset {
423            sql.offset(offset);
424        }
425
426        if filter_options.reverse {
427            sql.order_asc("timestamp");
428        } else {
429            sql.order_desc("timestamp");
430        }
431
432        let git_root = if let Some(git_root) = context.git_root.clone() {
433            git_root.to_str().unwrap_or("/").to_string()
434        } else {
435            context.cwd.clone()
436        };
437
438        match filter {
439            FilterMode::Global => &mut sql,
440            FilterMode::Host => {
441                sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
442            }
443            FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
444            FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
445            FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
446        };
447
448        let orig_query = query;
449
450        let mut regexes = Vec::new();
451        match search_mode {
452            SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
453            _ => {
454                let mut is_or = false;
455                let mut regex = None;
456                for part in query.split_inclusive(' ') {
457                    let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
458                        (None, false) => {
459                            if part.trim_end().is_empty() {
460                                continue;
461                            }
462                            Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
463                        }
464                        (None, true) => {
465                            if part[2..].trim_end().ends_with('/') {
466                                let end_pos = part.trim_end().len() - 1;
467                                regexes.push(String::from(&part[2..end_pos]));
468                            } else {
469                                regex = Some(String::from(&part[2..]));
470                            }
471                            continue;
472                        }
473                        (Some(r), _) => {
474                            if part.trim_end().ends_with('/') {
475                                let end_pos = part.trim_end().len() - 1;
476                                r.push_str(&part.trim_end()[..end_pos]);
477                                regexes.push(regex.take().unwrap());
478                            } else {
479                                r.push_str(part);
480                            }
481                            continue;
482                        }
483                    };
484
485                    // TODO smart case mode could be made configurable like in fzf
486                    let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
487                        (true, "*")
488                    } else {
489                        (false, "%")
490                    };
491
492                    let (is_inverse, query_part) = match query_part.strip_prefix('!') {
493                        Some(stripped) => (true, Cow::Borrowed(stripped)),
494                        None => (false, query_part),
495                    };
496
497                    #[allow(clippy::if_same_then_else)]
498                    let param = if query_part == "|" {
499                        if !is_or {
500                            is_or = true;
501                            continue;
502                        } else {
503                            format!("{glob}|{glob}")
504                        }
505                    } else if let Some(term) = query_part.strip_prefix('^') {
506                        format!("{term}{glob}")
507                    } else if let Some(term) = query_part.strip_suffix('$') {
508                        format!("{glob}{term}")
509                    } else if let Some(term) = query_part.strip_prefix('\'') {
510                        format!("{glob}{term}{glob}")
511                    } else if is_inverse {
512                        format!("{glob}{query_part}{glob}")
513                    } else if search_mode == SearchMode::FullText {
514                        format!("{glob}{query_part}{glob}")
515                    } else {
516                        query_part.split("").join(glob)
517                    };
518
519                    sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
520                    is_or = false;
521                }
522                if let Some(r) = regex {
523                    regexes.push(r);
524                }
525
526                &mut sql
527            }
528        };
529
530        for regex in regexes {
531            sql.and_where("command regexp ?".bind(&regex));
532        }
533
534        filter_options
535            .exit
536            .map(|exit| sql.and_where_eq("exit", exit));
537
538        filter_options
539            .exclude_exit
540            .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
541
542        filter_options
543            .cwd
544            .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
545
546        filter_options
547            .exclude_cwd
548            .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
549
550        filter_options.before.map(|before| {
551            interim::parse_date_string(
552                before.as_str(),
553                OffsetDateTime::now_utc(),
554                interim::Dialect::Uk,
555            )
556            .map(|before| {
557                sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
558            })
559        });
560
561        filter_options.after.map(|after| {
562            interim::parse_date_string(
563                after.as_str(),
564                OffsetDateTime::now_utc(),
565                interim::Dialect::Uk,
566            )
567            .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
568        });
569
570        sql.and_where_is_null("deleted_at");
571
572        let query = sql.sql().expect("bug in search query. please report");
573
574        let res = sqlx::query(&query)
575            .map(Self::query_history)
576            .fetch_all(&self.pool)
577            .await?;
578
579        Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
580    }
581
582    async fn query_history(&self, query: &str) -> Result<Vec<History>> {
583        let res = sqlx::query(query)
584            .map(Self::query_history)
585            .fetch_all(&self.pool)
586            .await?;
587
588        Ok(res)
589    }
590
591    async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
592        debug!("listing history");
593
594        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
595
596        query
597            .fields(&[
598                "id",
599                "max(timestamp) as timestamp",
600                "max(duration) as duration",
601                "exit",
602                "command",
603                "deleted_at",
604                "group_concat(cwd, ':') as cwd",
605                "group_concat(session) as session",
606                "group_concat(hostname, ',') as hostname",
607                "count(*) as count",
608            ])
609            .group_by("command")
610            .group_by("exit")
611            .and_where("deleted_at is null")
612            .order_desc("timestamp");
613
614        let query = query.sql().expect("bug in list query. please report");
615
616        let res = sqlx::query(&query)
617            .map(|row: SqliteRow| {
618                let count: i32 = row.get("count");
619                (Self::query_history(row), count)
620            })
621            .fetch_all(&self.pool)
622            .await?;
623
624        Ok(res)
625    }
626
627    // deleted_at doesn't mean the actual time that the user deleted it,
628    // but the time that the system marks it as deleted
629    async fn delete(&self, mut h: History) -> Result<()> {
630        let now = OffsetDateTime::now_utc();
631        h.command = rand::thread_rng()
632            .sample_iter(&Alphanumeric)
633            .take(32)
634            .map(char::from)
635            .collect(); // overwrite with random string
636        h.deleted_at = Some(now); // delete it
637
638        self.update(&h).await?; // save it
639
640        Ok(())
641    }
642
643    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
644        let mut tx = self.pool.begin().await?;
645
646        for id in ids {
647            Self::delete_row_raw(&mut tx, id.clone()).await?;
648        }
649
650        tx.commit().await?;
651
652        Ok(())
653    }
654
655    async fn stats(&self, h: &History) -> Result<HistoryStats> {
656        // We select the previous in the session by time
657        let mut prev = SqlBuilder::select_from("history");
658        prev.field("*")
659            .and_where("timestamp < ?1")
660            .and_where("session = ?2")
661            .order_by("timestamp", true)
662            .limit(1);
663
664        let mut next = SqlBuilder::select_from("history");
665        next.field("*")
666            .and_where("timestamp > ?1")
667            .and_where("session = ?2")
668            .order_by("timestamp", false)
669            .limit(1);
670
671        let mut total = SqlBuilder::select_from("history");
672        total.field("count(1)").and_where("command = ?1");
673
674        let mut average = SqlBuilder::select_from("history");
675        average.field("avg(duration)").and_where("command = ?1");
676
677        let mut exits = SqlBuilder::select_from("history");
678        exits
679            .fields(&["exit", "count(1) as count"])
680            .and_where("command = ?1")
681            .group_by("exit");
682
683        // rewrite the following with sqlbuilder
684        let mut day_of_week = SqlBuilder::select_from("history");
685        day_of_week
686            .fields(&[
687                "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
688                "count(1) as count",
689            ])
690            .and_where("command = ?1")
691            .group_by("day_of_week");
692
693        // Intentionally format the string with 01 hardcoded. We want the average runtime for the
694        // _entire month_, but will later parse it as a datetime for sorting
695        // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a
696        // string sort, which won't be correct.
697        let mut duration_over_time = SqlBuilder::select_from("history");
698        duration_over_time
699            .fields(&[
700                "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
701                "avg(duration) as duration",
702            ])
703            .and_where("command = ?1")
704            .group_by("month_year")
705            .having("duration > 0");
706
707        let prev = prev.sql().expect("issue in stats previous query");
708        let next = next.sql().expect("issue in stats next query");
709        let total = total.sql().expect("issue in stats average query");
710        let average = average.sql().expect("issue in stats previous query");
711        let exits = exits.sql().expect("issue in stats exits query");
712        let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
713        let duration_over_time = duration_over_time
714            .sql()
715            .expect("issue in stats duration over time query");
716
717        let prev = sqlx::query(&prev)
718            .bind(h.timestamp.unix_timestamp_nanos() as i64)
719            .bind(&h.session)
720            .map(Self::query_history)
721            .fetch_optional(&self.pool)
722            .await?;
723
724        let next = sqlx::query(&next)
725            .bind(h.timestamp.unix_timestamp_nanos() as i64)
726            .bind(&h.session)
727            .map(Self::query_history)
728            .fetch_optional(&self.pool)
729            .await?;
730
731        let total: (i64,) = sqlx::query_as(&total)
732            .bind(&h.command)
733            .fetch_one(&self.pool)
734            .await?;
735
736        let average: (f64,) = sqlx::query_as(&average)
737            .bind(&h.command)
738            .fetch_one(&self.pool)
739            .await?;
740
741        let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
742            .bind(&h.command)
743            .fetch_all(&self.pool)
744            .await?;
745
746        let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
747            .bind(&h.command)
748            .fetch_all(&self.pool)
749            .await?;
750
751        let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
752            .bind(&h.command)
753            .fetch_all(&self.pool)
754            .await?;
755
756        let duration_over_time = duration_over_time
757            .iter()
758            .map(|f| (f.0.clone(), f.1.round() as i64))
759            .collect();
760
761        Ok(HistoryStats {
762            next,
763            previous: prev,
764            total: total.0 as u64,
765            average_duration: average.0 as u64,
766            exits,
767            day_of_week,
768            duration_over_time,
769        })
770    }
771}
772
/// Extension for query builders that adds a single `LIKE` / `GLOB` pattern condition.
trait SqlBuilderExt {
    /// Adds a condition matching `field` against the pattern `mask`.
    ///
    /// * `inverse` - negate the match.
    /// * `glob` - use `GLOB` instead of `LIKE`.
    /// * `is_or` - combine with the existing conditions using OR instead of AND.
    fn fuzzy_condition<S, T>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self
    where
        S: ToString,
        T: ToString;
}
783
784impl SqlBuilderExt for SqlBuilder {
785    /// adapted from the sql-builder *like functions
786    fn fuzzy_condition<S: ToString, T: ToString>(
787        &mut self,
788        field: S,
789        mask: T,
790        inverse: bool,
791        glob: bool,
792        is_or: bool,
793    ) -> &mut Self {
794        let mut cond = field.to_string();
795        if inverse {
796            cond.push_str(" NOT");
797        }
798        if glob {
799            cond.push_str(" GLOB '");
800        } else {
801            cond.push_str(" LIKE '");
802        }
803        cond.push_str(&esc(mask.to_string()));
804        cond.push('\'');
805        if is_or {
806            self.or_where(cond)
807        } else {
808            self.and_where(cond)
809        }
810    }
811}
812
813#[cfg(test)]
814mod test {
815    use crate::settings::test_local_timeout;
816
817    use super::*;
818    use std::time::{Duration, Instant};
819
820    async fn assert_search_eq<'a>(
821        db: &impl Database,
822        mode: SearchMode,
823        filter_mode: FilterMode,
824        query: &str,
825        expected: usize,
826    ) -> Result<Vec<History>> {
827        let context = Context {
828            hostname: "test:host".to_string(),
829            session: "beepboopiamasession".to_string(),
830            cwd: "/home/ellie".to_string(),
831            host_id: "test-host".to_string(),
832            git_root: None,
833        };
834
835        let results = db
836            .search(
837                mode,
838                filter_mode,
839                &context,
840                query,
841                OptFilters {
842                    ..Default::default()
843                },
844            )
845            .await?;
846
847        assert_eq!(
848            results.len(),
849            expected,
850            "query \"{}\", commands: {:?}",
851            query,
852            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
853        );
854        Ok(results)
855    }
856
857    async fn assert_search_commands(
858        db: &impl Database,
859        mode: SearchMode,
860        filter_mode: FilterMode,
861        query: &str,
862        expected_commands: Vec<&str>,
863    ) {
864        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
865            .await
866            .unwrap();
867        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
868        assert_eq!(commands, expected_commands);
869    }
870
871    async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
872        let mut captured: History = History::capture()
873            .timestamp(OffsetDateTime::now_utc())
874            .command(cmd)
875            .cwd("/home/ellie")
876            .build()
877            .into();
878
879        captured.exit = 0;
880        captured.duration = 1;
881        captured.session = "beep boop".to_string();
882        captured.hostname = "booop".to_string();
883
884        db.save(&captured).await
885    }
886
887    #[tokio::test(flavor = "multi_thread")]
888    async fn test_search_prefix() {
889        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
890            .await
891            .unwrap();
892        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
893
894        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
895            .await
896            .unwrap();
897        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
898            .await
899            .unwrap();
900        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls  ", 0)
901            .await
902            .unwrap();
903    }
904
905    #[tokio::test(flavor = "multi_thread")]
906    async fn test_search_fulltext() {
907        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
908            .await
909            .unwrap();
910        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
911
912        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
913            .await
914            .unwrap();
915        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
916            .await
917            .unwrap();
918        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1)
919            .await
920            .unwrap();
921        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
922            .await
923            .unwrap();
924
925        // regex
926        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
927            .await
928            .unwrap();
929        assert_search_eq(
930            &db,
931            SearchMode::FullText,
932            FilterMode::Global,
933            "r/ls / ie$",
934            1,
935        )
936        .await
937        .unwrap();
938        assert_search_eq(
939            &db,
940            SearchMode::FullText,
941            FilterMode::Global,
942            "r/ls / !ie",
943            0,
944        )
945        .await
946        .unwrap();
947        assert_search_eq(
948            &db,
949            SearchMode::FullText,
950            FilterMode::Global,
951            "meow r/ls/",
952            0,
953        )
954        .await
955        .unwrap();
956        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
957            .await
958            .unwrap();
959        assert_search_eq(
960            &db,
961            SearchMode::FullText,
962            FilterMode::Global,
963            "r//home//",
964            1,
965        )
966        .await
967        .unwrap();
968        assert_search_eq(
969            &db,
970            SearchMode::FullText,
971            FilterMode::Global,
972            "r//home///",
973            0,
974        )
975        .await
976        .unwrap();
977        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
978            .await
979            .unwrap();
980        assert_search_eq(
981            &db,
982            SearchMode::FullText,
983            FilterMode::Global,
984            "r/home.*e",
985            1,
986        )
987        .await
988        .unwrap();
989    }
990
991    #[tokio::test(flavor = "multi_thread")]
992    async fn test_search_fuzzy() {
993        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
994            .await
995            .unwrap();
996        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
997        new_history_item(&mut db, "ls /home/frank").await.unwrap();
998        new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
999        new_history_item(&mut db, "/home/ellie/.bin/rustup")
1000            .await
1001            .unwrap();
1002
1003        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
1004            .await
1005            .unwrap();
1006        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
1007            .await
1008            .unwrap();
1009        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
1010            .await
1011            .unwrap();
1012        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
1013            .await
1014            .unwrap();
1015        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
1016            .await
1017            .unwrap();
1018        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
1019            .await
1020            .unwrap();
1021        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
1022            .await
1023            .unwrap();
1024        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
1025            .await
1026            .unwrap();
1027
1028        // single term operators
1029        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
1030            .await
1031            .unwrap();
1032        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
1033            .await
1034            .unwrap();
1035        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
1036            .await
1037            .unwrap();
1038        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
1039            .await
1040            .unwrap();
1041        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
1042            .await
1043            .unwrap();
1044        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
1045            .await
1046            .unwrap();
1047
1048        // multiple terms
1049        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
1050            .await
1051            .unwrap();
1052        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
1053            .await
1054            .unwrap();
1055        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
1056            .await
1057            .unwrap();
1058        assert_search_eq(
1059            &db,
1060            SearchMode::Fuzzy,
1061            FilterMode::Global,
1062            "'frank | 'rustup",
1063            2,
1064        )
1065        .await
1066        .unwrap();
1067        assert_search_eq(
1068            &db,
1069            SearchMode::Fuzzy,
1070            FilterMode::Global,
1071            "'frank | 'rustup 'ls",
1072            1,
1073        )
1074        .await
1075        .unwrap();
1076
1077        // case matching
1078        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
1079            .await
1080            .unwrap();
1081
1082        // regex
1083        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
1084            .await
1085            .unwrap();
1086        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
1087            .await
1088            .unwrap();
1089        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
1090            .await
1091            .unwrap();
1092    }
1093
1094    #[tokio::test(flavor = "multi_thread")]
1095    async fn test_search_reordered_fuzzy() {
1096        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1097            .await
1098            .unwrap();
1099        // test ordering of results: we should choose the first, even though it happened longer ago.
1100
1101        new_history_item(&mut db, "curl").await.unwrap();
1102        new_history_item(&mut db, "corburl").await.unwrap();
1103
1104        // if fuzzy reordering is on, it should come back in a more sensible order
1105        assert_search_commands(
1106            &db,
1107            SearchMode::Fuzzy,
1108            FilterMode::Global,
1109            "curl",
1110            vec!["curl", "corburl"],
1111        )
1112        .await;
1113
1114        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
1115            .await
1116            .unwrap();
1117        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
1118            .await
1119            .unwrap();
1120    }
1121
1122    #[tokio::test(flavor = "multi_thread")]
1123    async fn test_search_bench_dupes() {
1124        let context = Context {
1125            hostname: "test:host".to_string(),
1126            session: "beepboopiamasession".to_string(),
1127            cwd: "/home/ellie".to_string(),
1128            host_id: "test-host".to_string(),
1129            git_root: None,
1130        };
1131
1132        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1133            .await
1134            .unwrap();
1135        for _i in 1..10000 {
1136            new_history_item(&mut db, "i am a duplicated command")
1137                .await
1138                .unwrap();
1139        }
1140        let start = Instant::now();
1141        let _results = db
1142            .search(
1143                SearchMode::Fuzzy,
1144                FilterMode::Global,
1145                &context,
1146                "",
1147                OptFilters {
1148                    ..Default::default()
1149                },
1150            )
1151            .await
1152            .unwrap();
1153        let duration = start.elapsed();
1154
1155        assert!(duration < Duration::from_secs(15));
1156    }
1157}