1use std::{
2 borrow::Cow,
3 env,
4 path::{Path, PathBuf},
5 str::FromStr,
6 time::Duration,
7};
8
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{Rng, distributions::Alphanumeric};
14use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote};
15use sqlx::{
16 Result, Row,
17 sqlite::{
18 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
19 SqliteSynchronous,
20 },
21};
22use time::OffsetDateTime;
23use uuid::Uuid;
24
25use crate::{
26 history::{HistoryId, HistoryStats},
27 utils::get_host_user,
28};
29
30use super::{
31 history::History,
32 ordering,
33 settings::{FilterMode, SearchMode, Settings},
34};
35
/// Ambient information about the shell a query is issued from.
///
/// Produced by [`current_context`]; the filter modes used by the database
/// queries in this file compare stored history rows against these values.
pub struct Context {
    /// Session identifier (taken from `$ATUIN_SESSION` by `current_context`).
    pub session: String,
    /// Working directory of the caller; compared against the `cwd` column.
    pub cwd: String,
    /// Value compared against the `hostname` column of history rows.
    pub hostname: String,
    /// Host ID rendered as a string by `current_context`.
    pub host_id: String,
    /// Git repository root for `cwd`, when one was found. The workspace
    /// filter uses it as a `cwd` prefix and falls back to `cwd` when `None`.
    pub git_root: Option<PathBuf>,
}
43
/// Optional refinements applied on top of a search query.
///
/// Every field defaults to "no restriction" (`None` / `false`).
#[derive(Clone, Default)]
pub struct OptFilters {
    /// Keep only entries whose exit code equals this value.
    pub exit: Option<i64>,
    /// Drop entries whose exit code equals this value.
    pub exclude_exit: Option<i64>,
    /// Keep only entries recorded in exactly this directory.
    pub cwd: Option<String>,
    /// Drop entries recorded in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Human-readable date; keep only entries with an earlier timestamp.
    pub before: Option<String>,
    /// Human-readable date; keep only entries with a later timestamp.
    pub after: Option<String>,
    /// Maximum number of rows to return.
    pub limit: Option<i64>,
    /// Number of rows to skip before returning results.
    pub offset: Option<i64>,
    /// Order by ascending timestamp instead of the default descending order.
    pub reverse: bool,
    /// When `false`, results are grouped by command so each command appears once.
    pub include_duplicates: bool,
}
57
58pub fn current_context() -> Context {
59 let Ok(session) = env::var("ATUIN_SESSION") else {
60 eprintln!(
61 "ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."
62 );
63 std::process::exit(1);
64 };
65 let hostname = get_host_user();
66 let cwd = utils::get_current_dir();
67 let host_id = Settings::host_id().expect("failed to load host ID");
68 let git_root = utils::in_git_repo(cwd.as_str());
69
70 Context {
71 session,
72 hostname,
73 cwd,
74 git_root,
75 host_id: host_id.0.as_simple().to_string(),
76 }
77}
78
79fn get_session_start_time(session_id: &str) -> Option<i64> {
80 if let Ok(uuid) = Uuid::parse_str(session_id)
81 && let Some(timestamp) = uuid.get_timestamp()
82 {
83 let (seconds, nanos) = timestamp.to_unix();
84 return Some(seconds as i64 * 1_000_000_000 + nanos as i64);
85 }
86 None
87}
88
89#[async_trait]
90pub trait Database: Send + Sync + 'static {
91 async fn save(&self, h: &History) -> Result<()>;
92 async fn save_bulk(&self, h: &[History]) -> Result<()>;
93
94 async fn load(&self, id: &str) -> Result<Option<History>>;
95 async fn list(
96 &self,
97 filters: &[FilterMode],
98 context: &Context,
99 max: Option<usize>,
100 unique: bool,
101 include_deleted: bool,
102 ) -> Result<Vec<History>>;
103 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
104
105 async fn update(&self, h: &History) -> Result<()>;
106 async fn history_count(&self, include_deleted: bool) -> Result<i64>;
107
108 async fn last(&self) -> Result<Option<History>>;
109 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
110
111 async fn delete(&self, h: History) -> Result<()>;
112 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
113 async fn deleted(&self) -> Result<Vec<History>>;
114
115 #[allow(clippy::too_many_arguments)]
119 async fn search(
120 &self,
121 search_mode: SearchMode,
122 filter: FilterMode,
123 context: &Context,
124 query: &str,
125 filter_options: OptFilters,
126 ) -> Result<Vec<History>>;
127
128 async fn query_history(&self, query: &str) -> Result<Vec<History>>;
129
130 async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
131
132 async fn stats(&self, h: &History) -> Result<HistoryStats>;
133
134 async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>;
135}
136
137#[derive(Debug, Clone)]
140pub struct Sqlite {
141 pub pool: SqlitePool,
142}
143
144impl Sqlite {
145 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
146 let path = path.as_ref();
147 debug!("opening sqlite database at {path:?}");
148
149 if utils::broken_symlink(path) {
150 eprintln!(
151 "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
152 );
153 std::process::exit(1);
154 }
155
156 if !path.exists()
157 && let Some(dir) = path.parent()
158 {
159 fs::create_dir_all(dir)?;
160 }
161
162 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
163 .journal_mode(SqliteJournalMode::Wal)
164 .optimize_on_close(true, None)
165 .synchronous(SqliteSynchronous::Normal)
166 .with_regexp()
167 .create_if_missing(true);
168
169 let pool = SqlitePoolOptions::new()
170 .acquire_timeout(Duration::from_secs_f64(timeout))
171 .connect_with(opts)
172 .await?;
173
174 Self::setup_db(&pool).await?;
175 Ok(Self { pool })
176 }
177
178 pub async fn sqlite_version(&self) -> Result<String> {
179 sqlx::query_scalar("SELECT sqlite_version()")
180 .fetch_one(&self.pool)
181 .await
182 }
183
184 async fn setup_db(pool: &SqlitePool) -> Result<()> {
185 debug!("running sqlite database setup");
186
187 sqlx::migrate!("./migrations").run(pool).await?;
188
189 Ok(())
190 }
191
192 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
193 sqlx::query(
194 "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at)
195 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
196 )
197 .bind(h.id.0.as_str())
198 .bind(h.timestamp.unix_timestamp_nanos() as i64)
199 .bind(h.duration)
200 .bind(h.exit)
201 .bind(h.command.as_str())
202 .bind(h.cwd.as_str())
203 .bind(h.session.as_str())
204 .bind(h.hostname.as_str())
205 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
206 .execute(&mut **tx)
207 .await?;
208
209 Ok(())
210 }
211
212 async fn delete_row_raw(
213 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
214 id: HistoryId,
215 ) -> Result<()> {
216 sqlx::query("delete from history where id = ?1")
217 .bind(id.0.as_str())
218 .execute(&mut **tx)
219 .await?;
220
221 Ok(())
222 }
223
224 fn query_history(row: SqliteRow) -> History {
225 let deleted_at: Option<i64> = row.get("deleted_at");
226
227 History::from_db()
228 .id(row.get("id"))
229 .timestamp(
230 OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
231 .unwrap(),
232 )
233 .duration(row.get("duration"))
234 .exit(row.get("exit"))
235 .command(row.get("command"))
236 .cwd(row.get("cwd"))
237 .session(row.get("session"))
238 .hostname(row.get("hostname"))
239 .deleted_at(
240 deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
241 )
242 .build()
243 .into()
244 }
245}
246
247#[async_trait]
248impl Database for Sqlite {
249 async fn save(&self, h: &History) -> Result<()> {
250 debug!("saving history to sqlite");
251 let mut tx = self.pool.begin().await?;
252 Self::save_raw(&mut tx, h).await?;
253 tx.commit().await?;
254
255 Ok(())
256 }
257
258 async fn save_bulk(&self, h: &[History]) -> Result<()> {
259 debug!("saving history to sqlite");
260
261 let mut tx = self.pool.begin().await?;
262
263 for i in h {
264 Self::save_raw(&mut tx, i).await?;
265 }
266
267 tx.commit().await?;
268
269 Ok(())
270 }
271
272 async fn load(&self, id: &str) -> Result<Option<History>> {
273 debug!("loading history item {}", id);
274
275 let res = sqlx::query("select * from history where id = ?1")
276 .bind(id)
277 .map(Self::query_history)
278 .fetch_optional(&self.pool)
279 .await?;
280
281 Ok(res)
282 }
283
284 async fn update(&self, h: &History) -> Result<()> {
285 debug!("updating sqlite history");
286
287 sqlx::query(
288 "update history
289 set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9
290 where id = ?1",
291 )
292 .bind(h.id.0.as_str())
293 .bind(h.timestamp.unix_timestamp_nanos() as i64)
294 .bind(h.duration)
295 .bind(h.exit)
296 .bind(h.command.as_str())
297 .bind(h.cwd.as_str())
298 .bind(h.session.as_str())
299 .bind(h.hostname.as_str())
300 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
301 .execute(&self.pool)
302 .await?;
303
304 Ok(())
305 }
306
307 async fn list(
309 &self,
310 filters: &[FilterMode],
311 context: &Context,
312 max: Option<usize>,
313 unique: bool,
314 include_deleted: bool,
315 ) -> Result<Vec<History>> {
316 debug!("listing history");
317
318 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
319 query.field("*").order_desc("timestamp");
320 if !include_deleted {
321 query.and_where_is_null("deleted_at");
322 }
323
324 let git_root = if let Some(git_root) = context.git_root.clone() {
325 git_root.to_str().unwrap_or("/").to_string()
326 } else {
327 context.cwd.clone()
328 };
329
330 let session_start = get_session_start_time(&context.session);
331
332 for filter in filters {
333 match filter {
334 FilterMode::Global => &mut query,
335 FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
336 FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
337 FilterMode::SessionPreload => {
338 query.and_where_eq("session", quote(&context.session));
339 if let Some(session_start) = session_start {
340 query.or_where_lt("timestamp", session_start);
341 }
342 &mut query
343 }
344 FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
345 FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
346 };
347 }
348
349 if unique {
350 query.group_by("command").having("max(timestamp)");
351 }
352
353 if let Some(max) = max {
354 query.limit(max);
355 }
356
357 let query = query.sql().expect("bug in list query. please report");
358
359 let res = sqlx::query(&query)
360 .map(Self::query_history)
361 .fetch_all(&self.pool)
362 .await?;
363
364 Ok(res)
365 }
366
367 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
368 debug!("listing history from {:?} to {:?}", from, to);
369
370 let res = sqlx::query(
371 "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
372 )
373 .bind(from.unix_timestamp_nanos() as i64)
374 .bind(to.unix_timestamp_nanos() as i64)
375 .map(Self::query_history)
376 .fetch_all(&self.pool)
377 .await?;
378
379 Ok(res)
380 }
381
382 async fn last(&self) -> Result<Option<History>> {
383 let res = sqlx::query(
384 "select * from history where duration >= 0 order by timestamp desc limit 1",
385 )
386 .map(Self::query_history)
387 .fetch_optional(&self.pool)
388 .await?;
389
390 Ok(res)
391 }
392
393 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
394 let res = sqlx::query(
395 "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
396 )
397 .bind(timestamp.unix_timestamp_nanos() as i64)
398 .bind(count)
399 .map(Self::query_history)
400 .fetch_all(&self.pool)
401 .await?;
402
403 Ok(res)
404 }
405
406 async fn deleted(&self) -> Result<Vec<History>> {
407 let res = sqlx::query("select * from history where deleted_at is not null")
408 .map(Self::query_history)
409 .fetch_all(&self.pool)
410 .await?;
411
412 Ok(res)
413 }
414
415 async fn history_count(&self, include_deleted: bool) -> Result<i64> {
416 let query = if include_deleted {
417 "select count(1) from history"
418 } else {
419 "select count(1) from history where deleted_at is null"
420 };
421
422 let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
423 Ok(res.0)
424 }
425
426 async fn search(
427 &self,
428 search_mode: SearchMode,
429 filter: FilterMode,
430 context: &Context,
431 query: &str,
432 filter_options: OptFilters,
433 ) -> Result<Vec<History>> {
434 let mut sql = SqlBuilder::select_from("history");
435
436 if !filter_options.include_duplicates {
437 sql.group_by("command").having("max(timestamp)");
438 }
439
440 if let Some(limit) = filter_options.limit {
441 sql.limit(limit);
442 }
443
444 if let Some(offset) = filter_options.offset {
445 sql.offset(offset);
446 }
447
448 if filter_options.reverse {
449 sql.order_asc("timestamp");
450 } else {
451 sql.order_desc("timestamp");
452 }
453
454 let git_root = if let Some(git_root) = context.git_root.clone() {
455 git_root.to_str().unwrap_or("/").to_string()
456 } else {
457 context.cwd.clone()
458 };
459
460 let session_start = get_session_start_time(&context.session);
461
462 match filter {
463 FilterMode::Global => &mut sql,
464 FilterMode::Host => {
465 sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
466 }
467 FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
468 FilterMode::SessionPreload => {
469 sql.and_where_eq("session", quote(&context.session));
470 if let Some(session_start) = session_start {
471 sql.or_where_lt("timestamp", session_start);
472 }
473 &mut sql
474 }
475 FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
476 FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
477 };
478
479 let orig_query = query;
480
481 let mut regexes = Vec::new();
482 match search_mode {
483 SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
484 _ => {
485 let mut is_or = false;
486 let mut regex = None;
487 for part in query.split_inclusive(' ') {
488 let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
489 (None, false) => {
490 if part.trim_end().is_empty() {
491 continue;
492 }
493 Cow::Owned(part.trim_end().replace('*', "%")) }
495 (None, true) => {
496 if part[2..].trim_end().ends_with('/') {
497 let end_pos = part.trim_end().len() - 1;
498 regexes.push(String::from(&part[2..end_pos]));
499 } else {
500 regex = Some(String::from(&part[2..]));
501 }
502 continue;
503 }
504 (Some(r), _) => {
505 if part.trim_end().ends_with('/') {
506 let end_pos = part.trim_end().len() - 1;
507 r.push_str(&part.trim_end()[..end_pos]);
508 regexes.push(regex.take().unwrap());
509 } else {
510 r.push_str(part);
511 }
512 continue;
513 }
514 };
515
516 let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
518 (true, "*")
519 } else {
520 (false, "%")
521 };
522
523 let (is_inverse, query_part) = match query_part.strip_prefix('!') {
524 Some(stripped) => (true, Cow::Borrowed(stripped)),
525 None => (false, query_part),
526 };
527
528 #[allow(clippy::if_same_then_else)]
529 let param = if query_part == "|" {
530 if !is_or {
531 is_or = true;
532 continue;
533 } else {
534 format!("{glob}|{glob}")
535 }
536 } else if let Some(term) = query_part.strip_prefix('^') {
537 format!("{term}{glob}")
538 } else if let Some(term) = query_part.strip_suffix('$') {
539 format!("{glob}{term}")
540 } else if let Some(term) = query_part.strip_prefix('\'') {
541 format!("{glob}{term}{glob}")
542 } else if is_inverse {
543 format!("{glob}{query_part}{glob}")
544 } else if search_mode == SearchMode::FullText {
545 format!("{glob}{query_part}{glob}")
546 } else {
547 query_part.split("").join(glob)
548 };
549
550 sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
551 is_or = false;
552 }
553 if let Some(r) = regex {
554 regexes.push(r);
555 }
556
557 &mut sql
558 }
559 };
560
561 for regex in regexes {
562 sql.and_where("command regexp ?".bind(®ex));
563 }
564
565 filter_options
566 .exit
567 .map(|exit| sql.and_where_eq("exit", exit));
568
569 filter_options
570 .exclude_exit
571 .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
572
573 filter_options
574 .cwd
575 .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
576
577 filter_options
578 .exclude_cwd
579 .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
580
581 filter_options.before.map(|before| {
582 interim::parse_date_string(
583 before.as_str(),
584 OffsetDateTime::now_utc(),
585 interim::Dialect::Uk,
586 )
587 .map(|before| {
588 sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
589 })
590 });
591
592 filter_options.after.map(|after| {
593 interim::parse_date_string(
594 after.as_str(),
595 OffsetDateTime::now_utc(),
596 interim::Dialect::Uk,
597 )
598 .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
599 });
600
601 sql.and_where_is_null("deleted_at");
602
603 let query = sql.sql().expect("bug in search query. please report");
604
605 let res = sqlx::query(&query)
606 .map(Self::query_history)
607 .fetch_all(&self.pool)
608 .await?;
609
610 Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
611 }
612
613 async fn query_history(&self, query: &str) -> Result<Vec<History>> {
614 let res = sqlx::query(query)
615 .map(Self::query_history)
616 .fetch_all(&self.pool)
617 .await?;
618
619 Ok(res)
620 }
621
622 async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
623 debug!("listing history");
624
625 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
626
627 query
628 .fields(&[
629 "id",
630 "max(timestamp) as timestamp",
631 "max(duration) as duration",
632 "exit",
633 "command",
634 "deleted_at",
635 "group_concat(cwd, ':') as cwd",
636 "group_concat(session) as session",
637 "group_concat(hostname, ',') as hostname",
638 "count(*) as count",
639 ])
640 .group_by("command")
641 .group_by("exit")
642 .and_where("deleted_at is null")
643 .order_desc("timestamp");
644
645 let query = query.sql().expect("bug in list query. please report");
646
647 let res = sqlx::query(&query)
648 .map(|row: SqliteRow| {
649 let count: i32 = row.get("count");
650 (Self::query_history(row), count)
651 })
652 .fetch_all(&self.pool)
653 .await?;
654
655 Ok(res)
656 }
657
658 async fn delete(&self, mut h: History) -> Result<()> {
661 let now = OffsetDateTime::now_utc();
662 h.command = rand::thread_rng()
663 .sample_iter(&Alphanumeric)
664 .take(32)
665 .map(char::from)
666 .collect(); h.deleted_at = Some(now); self.update(&h).await?; Ok(())
672 }
673
674 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
675 let mut tx = self.pool.begin().await?;
676
677 for id in ids {
678 Self::delete_row_raw(&mut tx, id.clone()).await?;
679 }
680
681 tx.commit().await?;
682
683 Ok(())
684 }
685
686 async fn stats(&self, h: &History) -> Result<HistoryStats> {
687 let mut prev = SqlBuilder::select_from("history");
689 prev.field("*")
690 .and_where("timestamp < ?1")
691 .and_where("session = ?2")
692 .order_by("timestamp", true)
693 .limit(1);
694
695 let mut next = SqlBuilder::select_from("history");
696 next.field("*")
697 .and_where("timestamp > ?1")
698 .and_where("session = ?2")
699 .order_by("timestamp", false)
700 .limit(1);
701
702 let mut total = SqlBuilder::select_from("history");
703 total.field("count(1)").and_where("command = ?1");
704
705 let mut average = SqlBuilder::select_from("history");
706 average.field("avg(duration)").and_where("command = ?1");
707
708 let mut exits = SqlBuilder::select_from("history");
709 exits
710 .fields(&["exit", "count(1) as count"])
711 .and_where("command = ?1")
712 .group_by("exit");
713
714 let mut day_of_week = SqlBuilder::select_from("history");
716 day_of_week
717 .fields(&[
718 "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
719 "count(1) as count",
720 ])
721 .and_where("command = ?1")
722 .group_by("day_of_week");
723
724 let mut duration_over_time = SqlBuilder::select_from("history");
729 duration_over_time
730 .fields(&[
731 "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
732 "avg(duration) as duration",
733 ])
734 .and_where("command = ?1")
735 .group_by("month_year")
736 .having("duration > 0");
737
738 let prev = prev.sql().expect("issue in stats previous query");
739 let next = next.sql().expect("issue in stats next query");
740 let total = total.sql().expect("issue in stats average query");
741 let average = average.sql().expect("issue in stats previous query");
742 let exits = exits.sql().expect("issue in stats exits query");
743 let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
744 let duration_over_time = duration_over_time
745 .sql()
746 .expect("issue in stats duration over time query");
747
748 let prev = sqlx::query(&prev)
749 .bind(h.timestamp.unix_timestamp_nanos() as i64)
750 .bind(&h.session)
751 .map(Self::query_history)
752 .fetch_optional(&self.pool)
753 .await?;
754
755 let next = sqlx::query(&next)
756 .bind(h.timestamp.unix_timestamp_nanos() as i64)
757 .bind(&h.session)
758 .map(Self::query_history)
759 .fetch_optional(&self.pool)
760 .await?;
761
762 let total: (i64,) = sqlx::query_as(&total)
763 .bind(&h.command)
764 .fetch_one(&self.pool)
765 .await?;
766
767 let average: (f64,) = sqlx::query_as(&average)
768 .bind(&h.command)
769 .fetch_one(&self.pool)
770 .await?;
771
772 let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
773 .bind(&h.command)
774 .fetch_all(&self.pool)
775 .await?;
776
777 let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
778 .bind(&h.command)
779 .fetch_all(&self.pool)
780 .await?;
781
782 let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
783 .bind(&h.command)
784 .fetch_all(&self.pool)
785 .await?;
786
787 let duration_over_time = duration_over_time
788 .iter()
789 .map(|f| (f.0.clone(), f.1.round() as i64))
790 .collect();
791
792 Ok(HistoryStats {
793 next,
794 previous: prev,
795 total: total.0 as u64,
796 average_duration: average.0 as u64,
797 exits,
798 day_of_week,
799 duration_over_time,
800 })
801 }
802
803 async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> {
804 let res = sqlx::query(
805 "SELECT * FROM (
806 SELECT *, ROW_NUMBER()
807 OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC)
808 AS rn
809 FROM history
810 ) sub
811 WHERE rn > ?1 and timestamp < ?2;
812 ",
813 )
814 .bind(dupkeep)
815 .bind(before)
816 .map(Self::query_history)
817 .fetch_all(&self.pool)
818 .await?;
819
820 Ok(res)
821 }
822}
823
/// Extension for building the pattern conditions used by fuzzy search.
trait SqlBuilderExt {
    /// Adds a `LIKE`/`GLOB` condition comparing `field` against `mask`.
    ///
    /// * `inverse` — negate the match (`NOT LIKE` / `NOT GLOB`).
    /// * `glob` — use `GLOB` instead of `LIKE`.
    /// * `is_or` — attach with `or_where` rather than `and_where`.
    fn fuzzy_condition<S: ToString, T: ToString>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self;
}
834
835impl SqlBuilderExt for SqlBuilder {
836 fn fuzzy_condition<S: ToString, T: ToString>(
838 &mut self,
839 field: S,
840 mask: T,
841 inverse: bool,
842 glob: bool,
843 is_or: bool,
844 ) -> &mut Self {
845 let mut cond = field.to_string();
846 if inverse {
847 cond.push_str(" NOT");
848 }
849 if glob {
850 cond.push_str(" GLOB '");
851 } else {
852 cond.push_str(" LIKE '");
853 }
854 cond.push_str(&esc(mask.to_string()));
855 cond.push('\'');
856 if is_or {
857 self.or_where(cond)
858 } else {
859 self.and_where(cond)
860 }
861 }
862}
863
#[cfg(test)]
mod test {
    use crate::settings::test_local_timeout;

    use super::*;
    use std::time::{Duration, Instant};

    /// Context shared by every search issued from these tests.
    fn test_context() -> Context {
        Context {
            hostname: "test:host".to_string(),
            session: "beepboopiamasession".to_string(),
            cwd: "/home/ellie".to_string(),
            host_id: "test-host".to_string(),
            git_root: None,
        }
    }

    /// Opens a fresh in-memory database.
    async fn memory_db() -> Sqlite {
        Sqlite::new("sqlite::memory:", test_local_timeout())
            .await
            .unwrap()
    }

    /// Runs a search and asserts that it yields exactly `expected` rows.
    async fn assert_search_eq(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected: usize,
    ) -> Result<Vec<History>> {
        let context = test_context();

        let results = db
            .search(mode, filter_mode, &context, query, OptFilters::default())
            .await?;

        assert_eq!(
            results.len(),
            expected,
            "query \"{}\", commands: {:?}",
            query,
            results.iter().map(|a| &a.command).collect::<Vec<&String>>()
        );
        Ok(results)
    }

    /// Runs a search and asserts that the returned commands match
    /// `expected_commands`, in order.
    async fn assert_search_commands(
        db: &impl Database,
        mode: SearchMode,
        filter_mode: FilterMode,
        query: &str,
        expected_commands: Vec<&str>,
    ) {
        let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
            .await
            .unwrap();
        let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
        assert_eq!(commands, expected_commands);
    }

    /// Saves a successful one-unit-long run of `cmd` recorded in `/home/ellie`.
    async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
        let mut captured: History = History::capture()
            .timestamp(OffsetDateTime::now_utc())
            .command(cmd)
            .cwd("/home/ellie")
            .build()
            .into();

        captured.exit = 0;
        captured.duration = 1;
        captured.session = "beep boop".to_string();
        captured.hostname = "booop".to_string();

        db.save(&captured).await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_prefix() {
        let mut db = memory_db().await;
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        let cases = [
            ("ls", 1),
            ("/home", 0),
            // Two trailing spaces: `ls  %` is not a prefix of "ls /home/ellie".
            // A single trailing space would match the stored command.
            ("ls  ", 0),
        ];
        for (query, expected) in cases {
            assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fulltext() {
        let mut db = memory_db().await;
        new_history_item(&mut db, "ls /home/ellie").await.unwrap();

        let cases = [
            ("ls", 1),
            ("/home", 1),
            ("ls ho", 1),
            ("hm", 0),
            // regex
            ("r/^ls ", 1),
            ("r/ls / ie$", 1),
            ("r/ls / !ie", 0),
            ("meow r/ls/", 0),
            ("r//hom/", 1),
            ("r//home//", 1),
            ("r//home///", 0),
            ("/home.*e", 0),
            ("r/home.*e", 1),
        ];
        for (query, expected) in cases {
            assert_search_eq(
                &db,
                SearchMode::FullText,
                FilterMode::Global,
                query,
                expected,
            )
            .await
            .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_fuzzy() {
        let mut db = memory_db().await;
        for cmd in [
            "ls /home/ellie",
            "ls /home/frank",
            "cd /home/Ellie",
            "/home/ellie/.bin/rustup",
        ] {
            new_history_item(&mut db, cmd).await.unwrap();
        }

        let cases = [
            ("ls /", 3),
            ("ls/", 2),
            ("l/h/", 2),
            ("/h/e", 3),
            ("/hmoe/", 0),
            ("ellie/home", 0),
            ("lsellie", 1),
            (" ", 4),
            // single term operators
            ("^ls", 2),
            ("'ls", 2),
            ("ellie$", 2),
            ("!^ls", 2),
            ("!ellie", 1),
            ("!ellie$", 2),
            // multiple terms
            ("ls !ellie", 1),
            ("^ls !e$", 1),
            ("home !^ls", 2),
            ("'frank | 'rustup", 2),
            ("'frank | 'rustup 'ls", 1),
            // case matching
            ("Ellie", 1),
            // regex
            ("r/^ls ", 2),
            ("r/[Ee]llie", 3),
            ("/h/e r/^ls ", 1),
        ];
        for (query, expected) in cases {
            assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_reordered_fuzzy() {
        let mut db = memory_db().await;
        new_history_item(&mut db, "curl").await.unwrap();
        new_history_item(&mut db, "corburl").await.unwrap();

        // The exact match is expected ahead of the looser fuzzy match.
        assert_search_commands(
            &db,
            SearchMode::Fuzzy,
            FilterMode::Global,
            "curl",
            vec!["curl", "corburl"],
        )
        .await;

        for (query, expected) in [("xxxx", 0), ("", 2)] {
            assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, query, expected)
                .await
                .unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_search_bench_dupes() {
        let context = test_context();

        let mut db = memory_db().await;
        for _i in 1..10000 {
            new_history_item(&mut db, "i am a duplicated command")
                .await
                .unwrap();
        }

        let start = Instant::now();
        let _results = db
            .search(
                SearchMode::Fuzzy,
                FilterMode::Global,
                &context,
                "",
                OptFilters::default(),
            )
            .await
            .unwrap();
        let duration = start.elapsed();

        assert!(duration < Duration::from_secs(15));
    }
}