atuin_client/
database.rs

1use std::{
2    borrow::Cow,
3    env,
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{Rng, distributions::Alphanumeric};
14use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote};
15use sqlx::{
16    Result, Row,
17    sqlite::{
18        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
19        SqliteSynchronous,
20    },
21};
22use time::OffsetDateTime;
23use uuid::Uuid;
24
25use crate::{
26    history::{HistoryId, HistoryStats},
27    utils::get_host_user,
28};
29
30use super::{
31    history::History,
32    ordering,
33    settings::{FilterMode, SearchMode, Settings},
34};
35
/// Where a command is being run from; history queries use it to scope results
/// by session, host, directory or git workspace.
#[derive(Debug, Clone)]
pub struct Context {
    /// Value of `$ATUIN_SESSION` for the current shell.
    pub session: String,
    /// Current working directory.
    pub cwd: String,
    /// Host identifier from `get_host_user()`; compared with the `hostname` column.
    pub hostname: String,
    /// Host UUID in simple (hyphen-less) form.
    pub host_id: String,
    /// Root of the enclosing git repository, if `cwd` is inside one.
    pub git_root: Option<PathBuf>,
}
43
/// Optional refinements applied by `Database::search` on top of the filter mode.
#[derive(Debug, Default, Clone)]
pub struct OptFilters {
    /// Only entries with this exit code.
    pub exit: Option<i64>,
    /// Exclude entries with this exit code.
    pub exclude_exit: Option<i64>,
    /// Only entries run in exactly this directory.
    pub cwd: Option<String>,
    /// Exclude entries run in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Only entries before this human-readable date (ignored if it cannot be parsed).
    pub before: Option<String>,
    /// Only entries after this human-readable date (ignored if it cannot be parsed).
    pub after: Option<String>,
    /// Maximum number of rows returned.
    pub limit: Option<i64>,
    /// Number of rows skipped before returning results.
    pub offset: Option<i64>,
    /// Return oldest entries first instead of newest first.
    pub reverse: bool,
    /// Keep every occurrence of a command instead of one row per command.
    pub include_duplicates: bool,
}
57
58pub fn current_context() -> Context {
59    let Ok(session) = env::var("ATUIN_SESSION") else {
60        eprintln!(
61            "ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."
62        );
63        std::process::exit(1);
64    };
65    let hostname = get_host_user();
66    let cwd = utils::get_current_dir();
67    let host_id = Settings::host_id().expect("failed to load host ID");
68    let git_root = utils::in_git_repo(cwd.as_str());
69
70    Context {
71        session,
72        hostname,
73        cwd,
74        git_root,
75        host_id: host_id.0.as_simple().to_string(),
76    }
77}
78
79fn get_session_start_time(session_id: &str) -> Option<i64> {
80    if let Ok(uuid) = Uuid::parse_str(session_id)
81        && let Some(timestamp) = uuid.get_timestamp()
82    {
83        let (seconds, nanos) = timestamp.to_unix();
84        return Some(seconds as i64 * 1_000_000_000 + nanos as i64);
85    }
86    None
87}
88
89#[async_trait]
90pub trait Database: Send + Sync + 'static {
91    async fn save(&self, h: &History) -> Result<()>;
92    async fn save_bulk(&self, h: &[History]) -> Result<()>;
93
94    async fn load(&self, id: &str) -> Result<Option<History>>;
95    async fn list(
96        &self,
97        filters: &[FilterMode],
98        context: &Context,
99        max: Option<usize>,
100        unique: bool,
101        include_deleted: bool,
102    ) -> Result<Vec<History>>;
103    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
104
105    async fn update(&self, h: &History) -> Result<()>;
106    async fn history_count(&self, include_deleted: bool) -> Result<i64>;
107
108    async fn last(&self) -> Result<Option<History>>;
109    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
110
111    async fn delete(&self, h: History) -> Result<()>;
112    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
113    async fn deleted(&self) -> Result<Vec<History>>;
114
115    // Yes I know, it's a lot.
116    // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless.
117    // Been debating maybe a DSL for search? eg "before:time limit:1 the query"
118    #[allow(clippy::too_many_arguments)]
119    async fn search(
120        &self,
121        search_mode: SearchMode,
122        filter: FilterMode,
123        context: &Context,
124        query: &str,
125        filter_options: OptFilters,
126    ) -> Result<Vec<History>>;
127
128    async fn query_history(&self, query: &str) -> Result<Vec<History>>;
129
130    async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
131
132    async fn stats(&self, h: &History) -> Result<HistoryStats>;
133
134    async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>;
135}
136
137// Intended for use on a developer machine and not a sync server.
138// TODO: implement IntoIterator
139#[derive(Debug, Clone)]
140pub struct Sqlite {
141    pub pool: SqlitePool,
142}
143
144impl Sqlite {
145    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
146        let path = path.as_ref();
147        debug!("opening sqlite database at {path:?}");
148
149        if utils::broken_symlink(path) {
150            eprintln!(
151                "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
152            );
153            std::process::exit(1);
154        }
155
156        if !path.exists()
157            && let Some(dir) = path.parent()
158        {
159            fs::create_dir_all(dir)?;
160        }
161
162        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
163            .journal_mode(SqliteJournalMode::Wal)
164            .optimize_on_close(true, None)
165            .synchronous(SqliteSynchronous::Normal)
166            .with_regexp()
167            .create_if_missing(true);
168
169        let pool = SqlitePoolOptions::new()
170            .acquire_timeout(Duration::from_secs_f64(timeout))
171            .connect_with(opts)
172            .await?;
173
174        Self::setup_db(&pool).await?;
175        Ok(Self { pool })
176    }
177
178    pub async fn sqlite_version(&self) -> Result<String> {
179        sqlx::query_scalar("SELECT sqlite_version()")
180            .fetch_one(&self.pool)
181            .await
182    }
183
184    async fn setup_db(pool: &SqlitePool) -> Result<()> {
185        debug!("running sqlite database setup");
186
187        sqlx::migrate!("./migrations").run(pool).await?;
188
189        Ok(())
190    }
191
192    async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
193        sqlx::query(
194            "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
195                values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
196        )
197        .bind(h.id.0.as_str())
198        .bind(h.timestamp.unix_timestamp_nanos() as i64)
199        .bind(h.duration)
200        .bind(h.exit)
201        .bind(h.command.as_str())
202        .bind(h.cwd.as_str())
203        .bind(h.session.as_str())
204        .bind(h.hostname.as_str())
205        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
206        .execute(&mut **tx)
207        .await?;
208
209        Ok(())
210    }
211
212    async fn delete_row_raw(
213        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
214        id: HistoryId,
215    ) -> Result<()> {
216        sqlx::query("delete from history where id = ?1")
217            .bind(id.0.as_str())
218            .execute(&mut **tx)
219            .await?;
220
221        Ok(())
222    }
223
224    fn query_history(row: SqliteRow) -> History {
225        let deleted_at: Option<i64> = row.get("deleted_at");
226
227        History::from_db()
228            .id(row.get("id"))
229            .timestamp(
230                OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
231                    .unwrap(),
232            )
233            .duration(row.get("duration"))
234            .exit(row.get("exit"))
235            .command(row.get("command"))
236            .cwd(row.get("cwd"))
237            .session(row.get("session"))
238            .hostname(row.get("hostname"))
239            .deleted_at(
240                deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
241            )
242            .build()
243            .into()
244    }
245}
246
247#[async_trait]
248impl Database for Sqlite {
249    async fn save(&self, h: &History) -> Result<()> {
250        debug!("saving history to sqlite");
251        let mut tx = self.pool.begin().await?;
252        Self::save_raw(&mut tx, h).await?;
253        tx.commit().await?;
254
255        Ok(())
256    }
257
258    async fn save_bulk(&self, h: &[History]) -> Result<()> {
259        debug!("saving history to sqlite");
260
261        let mut tx = self.pool.begin().await?;
262
263        for i in h {
264            Self::save_raw(&mut tx, i).await?;
265        }
266
267        tx.commit().await?;
268
269        Ok(())
270    }
271
272    async fn load(&self, id: &str) -> Result<Option<History>> {
273        debug!("loading history item {}", id);
274
275        let res = sqlx::query("select * from history where id = ?1")
276            .bind(id)
277            .map(Self::query_history)
278            .fetch_optional(&self.pool)
279            .await?;
280
281        Ok(res)
282    }
283
284    async fn update(&self, h: &History) -> Result<()> {
285        debug!("updating sqlite history");
286
287        sqlx::query(
288            "update history
289                set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
290                where id = ?1",
291        )
292        .bind(h.id.0.as_str())
293        .bind(h.timestamp.unix_timestamp_nanos() as i64)
294        .bind(h.duration)
295        .bind(h.exit)
296        .bind(h.command.as_str())
297        .bind(h.cwd.as_str())
298        .bind(h.session.as_str())
299        .bind(h.hostname.as_str())
300        .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
301        .execute(&self.pool)
302        .await?;
303
304        Ok(())
305    }
306
307    // make a unique list, that only shows the *newest* version of things
308    async fn list(
309        &self,
310        filters: &[FilterMode],
311        context: &Context,
312        max: Option<usize>,
313        unique: bool,
314        include_deleted: bool,
315    ) -> Result<Vec<History>> {
316        debug!("listing history");
317
318        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
319        query.field("*").order_desc("timestamp");
320        if !include_deleted {
321            query.and_where_is_null("deleted_at");
322        }
323
324        let git_root = if let Some(git_root) = context.git_root.clone() {
325            git_root.to_str().unwrap_or("/").to_string()
326        } else {
327            context.cwd.clone()
328        };
329
330        let session_start = get_session_start_time(&context.session);
331
332        for filter in filters {
333            match filter {
334                FilterMode::Global => &mut query,
335                FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
336                FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
337                FilterMode::SessionPreload => {
338                    query.and_where_eq("session", quote(&context.session));
339                    if let Some(session_start) = session_start {
340                        query.or_where_lt("timestamp", session_start);
341                    }
342                    &mut query
343                }
344                FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
345                FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
346            };
347        }
348
349        if unique {
350            query.group_by("command").having("max(timestamp)");
351        }
352
353        if let Some(max) = max {
354            query.limit(max);
355        }
356
357        let query = query.sql().expect("bug in list query. please report");
358
359        let res = sqlx::query(&query)
360            .map(Self::query_history)
361            .fetch_all(&self.pool)
362            .await?;
363
364        Ok(res)
365    }
366
367    async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
368        debug!("listing history from {:?} to {:?}", from, to);
369
370        let res = sqlx::query(
371            "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
372        )
373        .bind(from.unix_timestamp_nanos() as i64)
374        .bind(to.unix_timestamp_nanos() as i64)
375            .map(Self::query_history)
376        .fetch_all(&self.pool)
377        .await?;
378
379        Ok(res)
380    }
381
382    async fn last(&self) -> Result<Option<History>> {
383        let res = sqlx::query(
384            "select * from history where duration >= 0 order by timestamp desc limit 1",
385        )
386        .map(Self::query_history)
387        .fetch_optional(&self.pool)
388        .await?;
389
390        Ok(res)
391    }
392
393    async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
394        let res = sqlx::query(
395            "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
396        )
397        .bind(timestamp.unix_timestamp_nanos() as i64)
398        .bind(count)
399        .map(Self::query_history)
400        .fetch_all(&self.pool)
401        .await?;
402
403        Ok(res)
404    }
405
406    async fn deleted(&self) -> Result<Vec<History>> {
407        let res = sqlx::query("select * from history where deleted_at is not null")
408            .map(Self::query_history)
409            .fetch_all(&self.pool)
410            .await?;
411
412        Ok(res)
413    }
414
415    async fn history_count(&self, include_deleted: bool) -> Result<i64> {
416        let query = if include_deleted {
417            "select count(1) from history"
418        } else {
419            "select count(1) from history where deleted_at is null"
420        };
421
422        let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
423        Ok(res.0)
424    }
425
426    async fn search(
427        &self,
428        search_mode: SearchMode,
429        filter: FilterMode,
430        context: &Context,
431        query: &str,
432        filter_options: OptFilters,
433    ) -> Result<Vec<History>> {
434        let mut sql = SqlBuilder::select_from("history");
435
436        if !filter_options.include_duplicates {
437            sql.group_by("command").having("max(timestamp)");
438        }
439
440        if let Some(limit) = filter_options.limit {
441            sql.limit(limit);
442        }
443
444        if let Some(offset) = filter_options.offset {
445            sql.offset(offset);
446        }
447
448        if filter_options.reverse {
449            sql.order_asc("timestamp");
450        } else {
451            sql.order_desc("timestamp");
452        }
453
454        let git_root = if let Some(git_root) = context.git_root.clone() {
455            git_root.to_str().unwrap_or("/").to_string()
456        } else {
457            context.cwd.clone()
458        };
459
460        let session_start = get_session_start_time(&context.session);
461
462        match filter {
463            FilterMode::Global => &mut sql,
464            FilterMode::Host => {
465                sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
466            }
467            FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
468            FilterMode::SessionPreload => {
469                sql.and_where_eq("session", quote(&context.session));
470                if let Some(session_start) = session_start {
471                    sql.or_where_lt("timestamp", session_start);
472                }
473                &mut sql
474            }
475            FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
476            FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
477        };
478
479        let orig_query = query;
480
481        let mut regexes = Vec::new();
482        match search_mode {
483            SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
484            _ => {
485                let mut is_or = false;
486                let mut regex = None;
487                for part in query.split_inclusive(' ') {
488                    let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
489                        (None, false) => {
490                            if part.trim_end().is_empty() {
491                                continue;
492                            }
493                            Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
494                        }
495                        (None, true) => {
496                            if part[2..].trim_end().ends_with('/') {
497                                let end_pos = part.trim_end().len() - 1;
498                                regexes.push(String::from(&part[2..end_pos]));
499                            } else {
500                                regex = Some(String::from(&part[2..]));
501                            }
502                            continue;
503                        }
504                        (Some(r), _) => {
505                            if part.trim_end().ends_with('/') {
506                                let end_pos = part.trim_end().len() - 1;
507                                r.push_str(&part.trim_end()[..end_pos]);
508                                regexes.push(regex.take().unwrap());
509                            } else {
510                                r.push_str(part);
511                            }
512                            continue;
513                        }
514                    };
515
516                    // TODO smart case mode could be made configurable like in fzf
517                    let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
518                        (true, "*")
519                    } else {
520                        (false, "%")
521                    };
522
523                    let (is_inverse, query_part) = match query_part.strip_prefix('!') {
524                        Some(stripped) => (true, Cow::Borrowed(stripped)),
525                        None => (false, query_part),
526                    };
527
528                    #[allow(clippy::if_same_then_else)]
529                    let param = if query_part == "|" {
530                        if !is_or {
531                            is_or = true;
532                            continue;
533                        } else {
534                            format!("{glob}|{glob}")
535                        }
536                    } else if let Some(term) = query_part.strip_prefix('^') {
537                        format!("{term}{glob}")
538                    } else if let Some(term) = query_part.strip_suffix('$') {
539                        format!("{glob}{term}")
540                    } else if let Some(term) = query_part.strip_prefix('\'') {
541                        format!("{glob}{term}{glob}")
542                    } else if is_inverse {
543                        format!("{glob}{query_part}{glob}")
544                    } else if search_mode == SearchMode::FullText {
545                        format!("{glob}{query_part}{glob}")
546                    } else {
547                        query_part.split("").join(glob)
548                    };
549
550                    sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
551                    is_or = false;
552                }
553                if let Some(r) = regex {
554                    regexes.push(r);
555                }
556
557                &mut sql
558            }
559        };
560
561        for regex in regexes {
562            sql.and_where("command regexp ?".bind(&regex));
563        }
564
565        filter_options
566            .exit
567            .map(|exit| sql.and_where_eq("exit", exit));
568
569        filter_options
570            .exclude_exit
571            .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
572
573        filter_options
574            .cwd
575            .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
576
577        filter_options
578            .exclude_cwd
579            .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
580
581        filter_options.before.map(|before| {
582            interim::parse_date_string(
583                before.as_str(),
584                OffsetDateTime::now_utc(),
585                interim::Dialect::Uk,
586            )
587            .map(|before| {
588                sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
589            })
590        });
591
592        filter_options.after.map(|after| {
593            interim::parse_date_string(
594                after.as_str(),
595                OffsetDateTime::now_utc(),
596                interim::Dialect::Uk,
597            )
598            .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
599        });
600
601        sql.and_where_is_null("deleted_at");
602
603        let query = sql.sql().expect("bug in search query. please report");
604
605        let res = sqlx::query(&query)
606            .map(Self::query_history)
607            .fetch_all(&self.pool)
608            .await?;
609
610        Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
611    }
612
613    async fn query_history(&self, query: &str) -> Result<Vec<History>> {
614        let res = sqlx::query(query)
615            .map(Self::query_history)
616            .fetch_all(&self.pool)
617            .await?;
618
619        Ok(res)
620    }
621
622    async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
623        debug!("listing history");
624
625        let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
626
627        query
628            .fields(&[
629                "id",
630                "max(timestamp) as timestamp",
631                "max(duration) as duration",
632                "exit",
633                "command",
634                "deleted_at",
635                "group_concat(cwd, ':') as cwd",
636                "group_concat(session) as session",
637                "group_concat(hostname, ',') as hostname",
638                "count(*) as count",
639            ])
640            .group_by("command")
641            .group_by("exit")
642            .and_where("deleted_at is null")
643            .order_desc("timestamp");
644
645        let query = query.sql().expect("bug in list query. please report");
646
647        let res = sqlx::query(&query)
648            .map(|row: SqliteRow| {
649                let count: i32 = row.get("count");
650                (Self::query_history(row), count)
651            })
652            .fetch_all(&self.pool)
653            .await?;
654
655        Ok(res)
656    }
657
658    // deleted_at doesn't mean the actual time that the user deleted it,
659    // but the time that the system marks it as deleted
660    async fn delete(&self, mut h: History) -> Result<()> {
661        let now = OffsetDateTime::now_utc();
662        h.command = rand::thread_rng()
663            .sample_iter(&Alphanumeric)
664            .take(32)
665            .map(char::from)
666            .collect(); // overwrite with random string
667        h.deleted_at = Some(now); // delete it
668
669        self.update(&h).await?; // save it
670
671        Ok(())
672    }
673
674    async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
675        let mut tx = self.pool.begin().await?;
676
677        for id in ids {
678            Self::delete_row_raw(&mut tx, id.clone()).await?;
679        }
680
681        tx.commit().await?;
682
683        Ok(())
684    }
685
686    async fn stats(&self, h: &History) -> Result<HistoryStats> {
687        // We select the previous in the session by time
688        let mut prev = SqlBuilder::select_from("history");
689        prev.field("*")
690            .and_where("timestamp < ?1")
691            .and_where("session = ?2")
692            .order_by("timestamp", true)
693            .limit(1);
694
695        let mut next = SqlBuilder::select_from("history");
696        next.field("*")
697            .and_where("timestamp > ?1")
698            .and_where("session = ?2")
699            .order_by("timestamp", false)
700            .limit(1);
701
702        let mut total = SqlBuilder::select_from("history");
703        total.field("count(1)").and_where("command = ?1");
704
705        let mut average = SqlBuilder::select_from("history");
706        average.field("avg(duration)").and_where("command = ?1");
707
708        let mut exits = SqlBuilder::select_from("history");
709        exits
710            .fields(&["exit", "count(1) as count"])
711            .and_where("command = ?1")
712            .group_by("exit");
713
714        // rewrite the following with sqlbuilder
715        let mut day_of_week = SqlBuilder::select_from("history");
716        day_of_week
717            .fields(&[
718                "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
719                "count(1) as count",
720            ])
721            .and_where("command = ?1")
722            .group_by("day_of_week");
723
724        // Intentionally format the string with 01 hardcoded. We want the average runtime for the
725        // _entire month_, but will later parse it as a datetime for sorting
726        // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a
727        // string sort, which won't be correct.
728        let mut duration_over_time = SqlBuilder::select_from("history");
729        duration_over_time
730            .fields(&[
731                "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
732                "avg(duration) as duration",
733            ])
734            .and_where("command = ?1")
735            .group_by("month_year")
736            .having("duration > 0");
737
738        let prev = prev.sql().expect("issue in stats previous query");
739        let next = next.sql().expect("issue in stats next query");
740        let total = total.sql().expect("issue in stats average query");
741        let average = average.sql().expect("issue in stats previous query");
742        let exits = exits.sql().expect("issue in stats exits query");
743        let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
744        let duration_over_time = duration_over_time
745            .sql()
746            .expect("issue in stats duration over time query");
747
748        let prev = sqlx::query(&prev)
749            .bind(h.timestamp.unix_timestamp_nanos() as i64)
750            .bind(&h.session)
751            .map(Self::query_history)
752            .fetch_optional(&self.pool)
753            .await?;
754
755        let next = sqlx::query(&next)
756            .bind(h.timestamp.unix_timestamp_nanos() as i64)
757            .bind(&h.session)
758            .map(Self::query_history)
759            .fetch_optional(&self.pool)
760            .await?;
761
762        let total: (i64,) = sqlx::query_as(&total)
763            .bind(&h.command)
764            .fetch_one(&self.pool)
765            .await?;
766
767        let average: (f64,) = sqlx::query_as(&average)
768            .bind(&h.command)
769            .fetch_one(&self.pool)
770            .await?;
771
772        let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
773            .bind(&h.command)
774            .fetch_all(&self.pool)
775            .await?;
776
777        let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
778            .bind(&h.command)
779            .fetch_all(&self.pool)
780            .await?;
781
782        let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
783            .bind(&h.command)
784            .fetch_all(&self.pool)
785            .await?;
786
787        let duration_over_time = duration_over_time
788            .iter()
789            .map(|f| (f.0.clone(), f.1.round() as i64))
790            .collect();
791
792        Ok(HistoryStats {
793            next,
794            previous: prev,
795            total: total.0 as u64,
796            average_duration: average.0 as u64,
797            exits,
798            day_of_week,
799            duration_over_time,
800        })
801    }
802
803    async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> {
804        let res = sqlx::query(
805            "SELECT * FROM (
806                SELECT *, ROW_NUMBER()
807                  OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC)
808                  AS rn
809                  FROM history
810                ) sub
811              WHERE rn > ?1 and timestamp < ?2;
812            ",
813        )
814        .bind(dupkeep)
815        .bind(before)
816        .map(Self::query_history)
817        .fetch_all(&self.pool)
818        .await?;
819
820        Ok(res)
821    }
822}
823
/// Extra WHERE-clause helpers for [`SqlBuilder`] used by history search.
trait SqlBuilderExt {
    /// Appends `field [NOT] LIKE|GLOB '<escaped mask>'` to the WHERE clause,
    /// joined with `OR` when `is_or` is set and with `AND` otherwise.
    fn fuzzy_condition<S, T>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self
    where
        S: ToString,
        T: ToString;
}
834
835impl SqlBuilderExt for SqlBuilder {
836    /// adapted from the sql-builder *like functions
837    fn fuzzy_condition<S: ToString, T: ToString>(
838        &mut self,
839        field: S,
840        mask: T,
841        inverse: bool,
842        glob: bool,
843        is_or: bool,
844    ) -> &mut Self {
845        let mut cond = field.to_string();
846        if inverse {
847            cond.push_str(" NOT");
848        }
849        if glob {
850            cond.push_str(" GLOB '");
851        } else {
852            cond.push_str(" LIKE '");
853        }
854        cond.push_str(&esc(mask.to_string()));
855        cond.push('\'');
856        if is_or {
857            self.or_where(cond)
858        } else {
859            self.and_where(cond)
860        }
861    }
862}
863
#[cfg(test)]
mod test {
    use crate::settings::test_local_timeout;

    use super::*;
    use std::time::{Duration, Instant};

    /// Returns the fixed search context shared by the tests in this module.
    fn test_context() -> Context {
        Context {
            hostname: "test:host".to_string(),
            session: "beepboopiamasession".to_string(),
            cwd: "/home/ellie".to_string(),
            host_id: "test-host".to_string(),
            git_root: None,
        }
    }

    /// Runs `query` against `db` with default filters and asserts that exactly
    /// `expected` rows come back.
    ///
    /// On mismatch the panic message lists the returned commands. The matching
    /// rows are returned so callers can make further assertions on them.
    ///
    /// # Errors
    /// Propagates any error returned by `Database::search`.
    async fn assert_search_eq(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected: usize,
    ) -> Result<Vec<History>> {
        let context = test_context();

        let results = db
            .search(
                mode,
                filter_mode,
                &context,
                query,
                OptFilters::default(),
            )
            .await?;

        assert_eq!(
            results.len(),
            expected,
            "query \"{}\", commands: {:?}",
            query,
            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
        );
        Ok(results)
    }

    /// Asserts that `query` returns exactly `expected_commands`, in that order.
    ///
    /// # Panics
    /// Panics if the search fails, the result count differs, or the order differs.
    async fn assert_search_commands(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected_commands: Vec<&str>,
    ) {
        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
            .await
            .unwrap();
        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
        assert_eq!(commands, expected_commands);
    }

    /// Saves one history entry for `cmd` into `db`.
    ///
    /// The entry is timestamped "now" (UTC) and uses cwd `/home/ellie`, exit 0,
    /// duration 1, session `beep boop` and hostname `booop`.
    async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
        let mut captured: History = History::capture()
            .timestamp(OffsetDateTime::now_utc())
            .command(cmd)
            .cwd("/home/ellie")
            .build()
            .into();

        captured.exit = 0;
        captured.duration = 1;
        captured.session = "beep boop".to_string();
        captured.hostname = "booop".to_string();

        db.save(&captured).await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_prefix() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls  ", 0)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fulltext() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
            .await
            .unwrap();

        // regex
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/ls / ie$",
            1,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/ls / !ie",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "meow r/ls/",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r//home//",
            1,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r//home///",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/home.*e",
            1,
        )
        .await
        .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fuzzy() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
        new_history_item(&mut db, "ls /home/frank").await.unwrap();
        new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
        new_history_item(&mut db, "/home/ellie/.bin/rustup")
            .await
            .unwrap();

        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
            .await
            .unwrap();

        // single term operators
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
            .await
            .unwrap();

        // multiple terms
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "'frank | 'rustup",
            2,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "'frank | 'rustup 'ls",
            1,
        )
        .await
        .unwrap();

        // case matching
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
            .await
            .unwrap();

        // regex
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_reordered_fuzzy() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        // Ordering of results: "curl" should come first even though it was
        // inserted earlier than "corburl".

        new_history_item(&mut db, "curl").await.unwrap();
        new_history_item(&mut db, "corburl").await.unwrap();

        // if fuzzy reordering is on, it should come back in a more sensible order
        assert_search_commands(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "curl",
            vec!["curl", "corburl"],
        )
        .await;

        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_bench_dupes() {
        let context = test_context();

        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        // Insert 10,000 copies of the same command (the range is half-open).
        for _ in 0..10_000 {
            new_history_item(&mut db, "i am a duplicated command")
                .await
                .unwrap();
        }
        let start = Instant::now();
        let _results = db
            .search(
                SearchMode::Fuzzy,
                FilterMode::Global,
                &context,
                "",
                OptFilters::default(),
            )
            .await
            .unwrap();
        let duration = start.elapsed();

        assert!(
            duration < Duration::from_secs(15),
            "fuzzy search over duplicated history took {duration:?}"
        );
    }
}