1use std::{
2 borrow::Cow,
3 env,
4 path::{Path, PathBuf},
5 str::FromStr,
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{Rng, distributions::Alphanumeric};
14use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote};
15use sqlx::{
16 Result, Row,
17 sqlite::{
18 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
19 SqliteSynchronous,
20 },
21};
22use time::OffsetDateTime;
23
24use crate::{
25 history::{HistoryId, HistoryStats},
26 utils::get_host_user,
27};
28
29use super::{
30 history::History,
31 ordering,
32 settings::{FilterMode, SearchMode, Settings},
33};
34
/// Describes where the current shell command is being run.
///
/// History queries use it to apply host, session, directory and workspace
/// filters. [`current_context`] builds one from the running environment.
pub struct Context {
    /// Shell session identifier (normally taken from `$ATUIN_SESSION`).
    pub session: String,
    /// Host/user string (normally from `get_host_user`).
    pub hostname: String,
    /// Host id as a string (in `current_context`, the `as_simple()` rendering
    /// of the configured host id).
    pub host_id: String,
    /// Working directory of the shell.
    pub cwd: String,
    /// Root of the enclosing git repository, if `cwd` is inside one.
    pub git_root: Option<PathBuf>,
}
42
/// Optional extra constraints for a history search.
///
/// `Default` yields "no extra filtering": every `Option` is `None` and both
/// flags are `false`.
#[derive(Default, Clone)]
pub struct OptFilters {
    // --- exit-status filters ---
    /// Only entries with exactly this exit code.
    pub exit: Option<i64>,
    /// Exclude entries with this exit code.
    pub exclude_exit: Option<i64>,

    // --- directory filters ---
    /// Only entries run in exactly this directory.
    pub cwd: Option<String>,
    /// Exclude entries run in this directory.
    pub exclude_cwd: Option<String>,

    // --- time window (free-form date strings) ---
    /// Only entries strictly before this date.
    pub before: Option<String>,
    /// Only entries strictly after this date.
    pub after: Option<String>,

    // --- paging and presentation ---
    /// Maximum number of rows to return.
    pub limit: Option<i64>,
    /// Number of rows to skip.
    pub offset: Option<i64>,
    /// Oldest-first instead of newest-first.
    pub reverse: bool,
    /// Keep repeated commands instead of collapsing them.
    pub include_duplicates: bool,
}
56
57pub fn current_context() -> Context {
58 let Ok(session) = env::var("ATUIN_SESSION") else {
59 eprintln!(
60 "ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."
61 );
62 std::process::exit(1);
63 };
64 let hostname = get_host_user();
65 let cwd = utils::get_current_dir();
66 let host_id = Settings::host_id().expect("failed to load host ID");
67 let git_root = utils::in_git_repo(cwd.as_str());
68
69 Context {
70 session,
71 hostname,
72 cwd,
73 git_root,
74 host_id: host_id.0.as_simple().to_string(),
75 }
76}
77
78#[async_trait]
79pub trait Database: Send + Sync + 'static {
80 async fn save(&self, h: &History) -> Result<()>;
81 async fn save_bulk(&self, h: &[History]) -> Result<()>;
82
83 async fn load(&self, id: &str) -> Result<Option<History>>;
84 async fn list(
85 &self,
86 filters: &[FilterMode],
87 context: &Context,
88 max: Option<usize>,
89 unique: bool,
90 include_deleted: bool,
91 ) -> Result<Vec<History>>;
92 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
93
94 async fn update(&self, h: &History) -> Result<()>;
95 async fn history_count(&self, include_deleted: bool) -> Result<i64>;
96
97 async fn last(&self) -> Result<Option<History>>;
98 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
99
100 async fn delete(&self, h: History) -> Result<()>;
101 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
102 async fn deleted(&self) -> Result<Vec<History>>;
103
104 #[allow(clippy::too_many_arguments)]
108 async fn search(
109 &self,
110 search_mode: SearchMode,
111 filter: FilterMode,
112 context: &Context,
113 query: &str,
114 filter_options: OptFilters,
115 ) -> Result<Vec<History>>;
116
117 async fn query_history(&self, query: &str) -> Result<Vec<History>>;
118
119 async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
120
121 async fn stats(&self, h: &History) -> Result<HistoryStats>;
122}
123
124#[derive(Debug, Clone)]
127pub struct Sqlite {
128 pub pool: SqlitePool,
129}
130
131impl Sqlite {
132 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
133 let path = path.as_ref();
134 debug!("opening sqlite database at {:?}", path);
135
136 if utils::broken_symlink(path) {
137 eprintln!(
138 "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
139 );
140 std::process::exit(1);
141 }
142
143 if !path.exists() {
144 if let Some(dir) = path.parent() {
145 fs::create_dir_all(dir)?;
146 }
147 }
148
149 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
150 .journal_mode(SqliteJournalMode::Wal)
151 .optimize_on_close(true, None)
152 .synchronous(SqliteSynchronous::Normal)
153 .with_regexp()
154 .create_if_missing(true);
155
156 let pool = SqlitePoolOptions::new()
157 .acquire_timeout(Duration::from_secs_f64(timeout))
158 .connect_with(opts)
159 .await?;
160
161 Self::setup_db(&pool).await?;
162 Ok(Self { pool })
163 }
164
165 pub async fn sqlite_version(&self) -> Result<String> {
166 sqlx::query_scalar("SELECT sqlite_version()")
167 .fetch_one(&self.pool)
168 .await
169 }
170
171 async fn setup_db(pool: &SqlitePool) -> Result<()> {
172 debug!("running sqlite database setup");
173
174 sqlx::migrate!("./migrations").run(pool).await?;
175
176 Ok(())
177 }
178
179 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
180 sqlx::query(
181 "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
182 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
183 )
184 .bind(h.id.0.as_str())
185 .bind(h.timestamp.unix_timestamp_nanos() as i64)
186 .bind(h.duration)
187 .bind(h.exit)
188 .bind(h.command.as_str())
189 .bind(h.cwd.as_str())
190 .bind(h.session.as_str())
191 .bind(h.hostname.as_str())
192 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
193 .execute(&mut **tx)
194 .await?;
195
196 Ok(())
197 }
198
199 async fn delete_row_raw(
200 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
201 id: HistoryId,
202 ) -> Result<()> {
203 sqlx::query("delete from history where id = ?1")
204 .bind(id.0.as_str())
205 .execute(&mut **tx)
206 .await?;
207
208 Ok(())
209 }
210
211 fn query_history(row: SqliteRow) -> History {
212 let deleted_at: Option<i64> = row.get("deleted_at");
213
214 History::from_db()
215 .id(row.get("id"))
216 .timestamp(
217 OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
218 .unwrap(),
219 )
220 .duration(row.get("duration"))
221 .exit(row.get("exit"))
222 .command(row.get("command"))
223 .cwd(row.get("cwd"))
224 .session(row.get("session"))
225 .hostname(row.get("hostname"))
226 .deleted_at(
227 deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
228 )
229 .build()
230 .into()
231 }
232}
233
234#[async_trait]
235impl Database for Sqlite {
236 async fn save(&self, h: &History) -> Result<()> {
237 debug!("saving history to sqlite");
238 let mut tx = self.pool.begin().await?;
239 Self::save_raw(&mut tx, h).await?;
240 tx.commit().await?;
241
242 Ok(())
243 }
244
245 async fn save_bulk(&self, h: &[History]) -> Result<()> {
246 debug!("saving history to sqlite");
247
248 let mut tx = self.pool.begin().await?;
249
250 for i in h {
251 Self::save_raw(&mut tx, i).await?;
252 }
253
254 tx.commit().await?;
255
256 Ok(())
257 }
258
259 async fn load(&self, id: &str) -> Result<Option<History>> {
260 debug!("loading history item {}", id);
261
262 let res = sqlx::query("select * from history where id = ?1")
263 .bind(id)
264 .map(Self::query_history)
265 .fetch_optional(&self.pool)
266 .await?;
267
268 Ok(res)
269 }
270
271 async fn update(&self, h: &History) -> Result<()> {
272 debug!("updating sqlite history");
273
274 sqlx::query(
275 "update history
276 set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
277 where id = ?1",
278 )
279 .bind(h.id.0.as_str())
280 .bind(h.timestamp.unix_timestamp_nanos() as i64)
281 .bind(h.duration)
282 .bind(h.exit)
283 .bind(h.command.as_str())
284 .bind(h.cwd.as_str())
285 .bind(h.session.as_str())
286 .bind(h.hostname.as_str())
287 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
288 .execute(&self.pool)
289 .await?;
290
291 Ok(())
292 }
293
294 async fn list(
296 &self,
297 filters: &[FilterMode],
298 context: &Context,
299 max: Option<usize>,
300 unique: bool,
301 include_deleted: bool,
302 ) -> Result<Vec<History>> {
303 debug!("listing history");
304
305 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
306 query.field("*").order_desc("timestamp");
307 if !include_deleted {
308 query.and_where_is_null("deleted_at");
309 }
310
311 let git_root = if let Some(git_root) = context.git_root.clone() {
312 git_root.to_str().unwrap_or("/").to_string()
313 } else {
314 context.cwd.clone()
315 };
316
317 for filter in filters {
318 match filter {
319 FilterMode::Global => &mut query,
320 FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
321 FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
322 FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
323 FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
324 };
325 }
326
327 if unique {
328 query.group_by("command").having("max(timestamp)");
329 }
330
331 if let Some(max) = max {
332 query.limit(max);
333 }
334
335 let query = query.sql().expect("bug in list query. please report");
336
337 let res = sqlx::query(&query)
338 .map(Self::query_history)
339 .fetch_all(&self.pool)
340 .await?;
341
342 Ok(res)
343 }
344
345 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
346 debug!("listing history from {:?} to {:?}", from, to);
347
348 let res = sqlx::query(
349 "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
350 )
351 .bind(from.unix_timestamp_nanos() as i64)
352 .bind(to.unix_timestamp_nanos() as i64)
353 .map(Self::query_history)
354 .fetch_all(&self.pool)
355 .await?;
356
357 Ok(res)
358 }
359
360 async fn last(&self) -> Result<Option<History>> {
361 let res = sqlx::query(
362 "select * from history where duration >= 0 order by timestamp desc limit 1",
363 )
364 .map(Self::query_history)
365 .fetch_optional(&self.pool)
366 .await?;
367
368 Ok(res)
369 }
370
371 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
372 let res = sqlx::query(
373 "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
374 )
375 .bind(timestamp.unix_timestamp_nanos() as i64)
376 .bind(count)
377 .map(Self::query_history)
378 .fetch_all(&self.pool)
379 .await?;
380
381 Ok(res)
382 }
383
384 async fn deleted(&self) -> Result<Vec<History>> {
385 let res = sqlx::query("select * from history where deleted_at is not null")
386 .map(Self::query_history)
387 .fetch_all(&self.pool)
388 .await?;
389
390 Ok(res)
391 }
392
393 async fn history_count(&self, include_deleted: bool) -> Result<i64> {
394 let query = if include_deleted {
395 "select count(1) from history"
396 } else {
397 "select count(1) from history where deleted_at is null"
398 };
399
400 let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
401 Ok(res.0)
402 }
403
404 async fn search(
405 &self,
406 search_mode: SearchMode,
407 filter: FilterMode,
408 context: &Context,
409 query: &str,
410 filter_options: OptFilters,
411 ) -> Result<Vec<History>> {
412 let mut sql = SqlBuilder::select_from("history");
413
414 if !filter_options.include_duplicates {
415 sql.group_by("command").having("max(timestamp)");
416 }
417
418 if let Some(limit) = filter_options.limit {
419 sql.limit(limit);
420 }
421
422 if let Some(offset) = filter_options.offset {
423 sql.offset(offset);
424 }
425
426 if filter_options.reverse {
427 sql.order_asc("timestamp");
428 } else {
429 sql.order_desc("timestamp");
430 }
431
432 let git_root = if let Some(git_root) = context.git_root.clone() {
433 git_root.to_str().unwrap_or("/").to_string()
434 } else {
435 context.cwd.clone()
436 };
437
438 match filter {
439 FilterMode::Global => &mut sql,
440 FilterMode::Host => {
441 sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
442 }
443 FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
444 FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
445 FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
446 };
447
448 let orig_query = query;
449
450 let mut regexes = Vec::new();
451 match search_mode {
452 SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
453 _ => {
454 let mut is_or = false;
455 let mut regex = None;
456 for part in query.split_inclusive(' ') {
457 let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
458 (None, false) => {
459 if part.trim_end().is_empty() {
460 continue;
461 }
462 Cow::Owned(part.trim_end().replace('*', "%")) }
464 (None, true) => {
465 if part[2..].trim_end().ends_with('/') {
466 let end_pos = part.trim_end().len() - 1;
467 regexes.push(String::from(&part[2..end_pos]));
468 } else {
469 regex = Some(String::from(&part[2..]));
470 }
471 continue;
472 }
473 (Some(r), _) => {
474 if part.trim_end().ends_with('/') {
475 let end_pos = part.trim_end().len() - 1;
476 r.push_str(&part.trim_end()[..end_pos]);
477 regexes.push(regex.take().unwrap());
478 } else {
479 r.push_str(part);
480 }
481 continue;
482 }
483 };
484
485 let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
487 (true, "*")
488 } else {
489 (false, "%")
490 };
491
492 let (is_inverse, query_part) = match query_part.strip_prefix('!') {
493 Some(stripped) => (true, Cow::Borrowed(stripped)),
494 None => (false, query_part),
495 };
496
497 #[allow(clippy::if_same_then_else)]
498 let param = if query_part == "|" {
499 if !is_or {
500 is_or = true;
501 continue;
502 } else {
503 format!("{glob}|{glob}")
504 }
505 } else if let Some(term) = query_part.strip_prefix('^') {
506 format!("{term}{glob}")
507 } else if let Some(term) = query_part.strip_suffix('$') {
508 format!("{glob}{term}")
509 } else if let Some(term) = query_part.strip_prefix('\'') {
510 format!("{glob}{term}{glob}")
511 } else if is_inverse {
512 format!("{glob}{query_part}{glob}")
513 } else if search_mode == SearchMode::FullText {
514 format!("{glob}{query_part}{glob}")
515 } else {
516 query_part.split("").join(glob)
517 };
518
519 sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
520 is_or = false;
521 }
522 if let Some(r) = regex {
523 regexes.push(r);
524 }
525
526 &mut sql
527 }
528 };
529
530 for regex in regexes {
531 sql.and_where("command regexp ?".bind(®ex));
532 }
533
534 filter_options
535 .exit
536 .map(|exit| sql.and_where_eq("exit", exit));
537
538 filter_options
539 .exclude_exit
540 .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
541
542 filter_options
543 .cwd
544 .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
545
546 filter_options
547 .exclude_cwd
548 .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
549
550 filter_options.before.map(|before| {
551 interim::parse_date_string(
552 before.as_str(),
553 OffsetDateTime::now_utc(),
554 interim::Dialect::Uk,
555 )
556 .map(|before| {
557 sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
558 })
559 });
560
561 filter_options.after.map(|after| {
562 interim::parse_date_string(
563 after.as_str(),
564 OffsetDateTime::now_utc(),
565 interim::Dialect::Uk,
566 )
567 .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
568 });
569
570 sql.and_where_is_null("deleted_at");
571
572 let query = sql.sql().expect("bug in search query. please report");
573
574 let res = sqlx::query(&query)
575 .map(Self::query_history)
576 .fetch_all(&self.pool)
577 .await?;
578
579 Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
580 }
581
582 async fn query_history(&self, query: &str) -> Result<Vec<History>> {
583 let res = sqlx::query(query)
584 .map(Self::query_history)
585 .fetch_all(&self.pool)
586 .await?;
587
588 Ok(res)
589 }
590
591 async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
592 debug!("listing history");
593
594 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
595
596 query
597 .fields(&[
598 "id",
599 "max(timestamp) as timestamp",
600 "max(duration) as duration",
601 "exit",
602 "command",
603 "deleted_at",
604 "group_concat(cwd, ':') as cwd",
605 "group_concat(session) as session",
606 "group_concat(hostname, ',') as hostname",
607 "count(*) as count",
608 ])
609 .group_by("command")
610 .group_by("exit")
611 .and_where("deleted_at is null")
612 .order_desc("timestamp");
613
614 let query = query.sql().expect("bug in list query. please report");
615
616 let res = sqlx::query(&query)
617 .map(|row: SqliteRow| {
618 let count: i32 = row.get("count");
619 (Self::query_history(row), count)
620 })
621 .fetch_all(&self.pool)
622 .await?;
623
624 Ok(res)
625 }
626
627 async fn delete(&self, mut h: History) -> Result<()> {
630 let now = OffsetDateTime::now_utc();
631 h.command = rand::thread_rng()
632 .sample_iter(&Alphanumeric)
633 .take(32)
634 .map(char::from)
635 .collect(); h.deleted_at = Some(now); self.update(&h).await?; Ok(())
641 }
642
643 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
644 let mut tx = self.pool.begin().await?;
645
646 for id in ids {
647 Self::delete_row_raw(&mut tx, id.clone()).await?;
648 }
649
650 tx.commit().await?;
651
652 Ok(())
653 }
654
655 async fn stats(&self, h: &History) -> Result<HistoryStats> {
656 let mut prev = SqlBuilder::select_from("history");
658 prev.field("*")
659 .and_where("timestamp < ?1")
660 .and_where("session = ?2")
661 .order_by("timestamp", true)
662 .limit(1);
663
664 let mut next = SqlBuilder::select_from("history");
665 next.field("*")
666 .and_where("timestamp > ?1")
667 .and_where("session = ?2")
668 .order_by("timestamp", false)
669 .limit(1);
670
671 let mut total = SqlBuilder::select_from("history");
672 total.field("count(1)").and_where("command = ?1");
673
674 let mut average = SqlBuilder::select_from("history");
675 average.field("avg(duration)").and_where("command = ?1");
676
677 let mut exits = SqlBuilder::select_from("history");
678 exits
679 .fields(&["exit", "count(1) as count"])
680 .and_where("command = ?1")
681 .group_by("exit");
682
683 let mut day_of_week = SqlBuilder::select_from("history");
685 day_of_week
686 .fields(&[
687 "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
688 "count(1) as count",
689 ])
690 .and_where("command = ?1")
691 .group_by("day_of_week");
692
693 let mut duration_over_time = SqlBuilder::select_from("history");
698 duration_over_time
699 .fields(&[
700 "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
701 "avg(duration) as duration",
702 ])
703 .and_where("command = ?1")
704 .group_by("month_year")
705 .having("duration > 0");
706
707 let prev = prev.sql().expect("issue in stats previous query");
708 let next = next.sql().expect("issue in stats next query");
709 let total = total.sql().expect("issue in stats average query");
710 let average = average.sql().expect("issue in stats previous query");
711 let exits = exits.sql().expect("issue in stats exits query");
712 let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
713 let duration_over_time = duration_over_time
714 .sql()
715 .expect("issue in stats duration over time query");
716
717 let prev = sqlx::query(&prev)
718 .bind(h.timestamp.unix_timestamp_nanos() as i64)
719 .bind(&h.session)
720 .map(Self::query_history)
721 .fetch_optional(&self.pool)
722 .await?;
723
724 let next = sqlx::query(&next)
725 .bind(h.timestamp.unix_timestamp_nanos() as i64)
726 .bind(&h.session)
727 .map(Self::query_history)
728 .fetch_optional(&self.pool)
729 .await?;
730
731 let total: (i64,) = sqlx::query_as(&total)
732 .bind(&h.command)
733 .fetch_one(&self.pool)
734 .await?;
735
736 let average: (f64,) = sqlx::query_as(&average)
737 .bind(&h.command)
738 .fetch_one(&self.pool)
739 .await?;
740
741 let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
742 .bind(&h.command)
743 .fetch_all(&self.pool)
744 .await?;
745
746 let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
747 .bind(&h.command)
748 .fetch_all(&self.pool)
749 .await?;
750
751 let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
752 .bind(&h.command)
753 .fetch_all(&self.pool)
754 .await?;
755
756 let duration_over_time = duration_over_time
757 .iter()
758 .map(|f| (f.0.clone(), f.1.round() as i64))
759 .collect();
760
761 Ok(HistoryStats {
762 next,
763 previous: prev,
764 total: total.0 as u64,
765 average_duration: average.0 as u64,
766 exits,
767 day_of_week,
768 duration_over_time,
769 })
770 }
771}
772
/// Extra condition builders used by fuzzy history search.
trait SqlBuilderExt {
    /// Add a `field [NOT] LIKE|GLOB 'mask'` condition, with `mask` escaped
    /// for use inside a SQL string literal.
    ///
    /// * `inverse` — negate the match with `NOT`.
    /// * `glob` — use `GLOB` instead of `LIKE`.
    /// * `is_or` — join the condition with `OR` instead of `AND`.
    fn fuzzy_condition<S, T>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self
    where
        S: ToString,
        T: ToString;
}
783
784impl SqlBuilderExt for SqlBuilder {
785 fn fuzzy_condition<S: ToString, T: ToString>(
787 &mut self,
788 field: S,
789 mask: T,
790 inverse: bool,
791 glob: bool,
792 is_or: bool,
793 ) -> &mut Self {
794 let mut cond = field.to_string();
795 if inverse {
796 cond.push_str(" NOT");
797 }
798 if glob {
799 cond.push_str(" GLOB '");
800 } else {
801 cond.push_str(" LIKE '");
802 }
803 cond.push_str(&esc(mask.to_string()));
804 cond.push('\'');
805 if is_or {
806 self.or_where(cond)
807 } else {
808 self.and_where(cond)
809 }
810 }
811}
812
#[cfg(test)]
mod test {
    use crate::settings::test_local_timeout;

    use super::*;
    use std::time::{Duration, Instant};

    /// Run `db.search` with a fixed test context and default filter options.
    ///
    /// Asserts that exactly `expected` entries come back and returns them so
    /// callers can inspect them further.
    async fn assert_search_eq(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected: usize,
    ) -> Result<Vec<History>> {
        let context = Context {
            hostname: "test:host".to_string(),
            session: "beepboopiamasession".to_string(),
            cwd: "/home/ellie".to_string(),
            host_id: "test-host".to_string(),
            git_root: None,
        };

        let results = db
            .search(mode, filter_mode, &context, query, OptFilters::default())
            .await?;

        assert_eq!(
            results.len(),
            expected,
            "query \"{}\", commands: {:?}",
            query,
            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
        );
        Ok(results)
    }

    /// Like [`assert_search_eq`], but also checks the exact commands and their
    /// order.
    async fn assert_search_commands(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected_commands: Vec<&str>,
    ) {
        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
            .await
            .unwrap();
        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
        assert_eq!(commands, expected_commands);
    }

    /// Save a successful, 1ns-long entry for `cmd` run in `/home/ellie` now.
    async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
        let mut captured: History = History::capture()
            .timestamp(OffsetDateTime::now_utc())
            .command(cmd)
            .cwd("/home/ellie")
            .build()
            .into();

        captured.exit = 0;
        captured.duration = 1;
        captured.session = "beep boop".to_string();
        captured.hostname = "booop".to_string();

        db.save(&captured).await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_prefix() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fulltext() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
            .await
            .unwrap();

        // Regex terms.
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/ls / ie$",
            1,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/ls / !ie",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "meow r/ls/",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r//home//",
            1,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r//home///",
            0,
        )
        .await
        .unwrap();
        assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::FullText,
            FilterMode::Global,
            "r/home.*e",
            1,
        )
        .await
        .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fuzzy() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();
        new_history_item(&mut db, "ls /home/frank").await.unwrap();
        new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
        new_history_item(&mut db, "/home/ellie/.bin/rustup")
            .await
            .unwrap();

        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
            .await
            .unwrap();

        // Anchors, exact terms and inversion.
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
            .await
            .unwrap();

        // Combined terms and OR.
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
            .await
            .unwrap();
        assert_search_eq(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "'frank | 'rustup",
            2,
        )
        .await
        .unwrap();
        assert_search_eq(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "'frank | 'rustup 'ls",
            1,
        )
        .await
        .unwrap();

        // Uppercase switches to case-sensitive matching.
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
            .await
            .unwrap();

        // Regex terms.
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_reordered_fuzzy() {
        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        new_history_item(&mut db, "curl").await.unwrap();
        new_history_item(&mut db, "corburl").await.unwrap();

        assert_search_commands(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "curl",
            vec!["curl", "corburl"],
        )
        .await;

        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
            .await
            .unwrap();
        assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_bench_dupes() {
        let context = Context {
            hostname: "test:host".to_string(),
            session: "beepboopiamasession".to_string(),
            cwd: "/home/ellie".to_string(),
            host_id: "test-host".to_string(),
            git_root: None,
        };

        let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap();
        for _i in 1..10000 {
            new_history_item(&mut db, "i am a duplicated command")
                .await
                .unwrap();
        }
        let start = Instant::now();
        let _results = db
            .search(
                SearchMode::Fuzzy,
                FilterMode::Global,
                &context,
                "",
                OptFilters::default(),
            )
            .await
            .unwrap();
        let duration = start.elapsed();

        assert!(duration < Duration::from_secs(15));
    }
}