1use std::{
2 env,
3 path::{Path, PathBuf},
4 str::FromStr,
5 time::Duration,
6};
7
8use crate::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS};
9use async_trait::async_trait;
10use atuin_common::utils;
11use fs_err as fs;
12use itertools::Itertools;
13use rand::{Rng, distributions::Alphanumeric};
14use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote};
15use sqlx::{
16 Result, Row,
17 sqlite::{
18 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
19 SqliteSynchronous,
20 },
21};
22use time::OffsetDateTime;
23use uuid::Uuid;
24
25use crate::{
26 history::{HistoryId, HistoryStats},
27 utils::get_host_user,
28};
29
30use super::{
31 history::History,
32 ordering,
33 settings::{FilterMode, SearchMode, Settings},
34};
35
/// Snapshot of the environment a command runs in.
///
/// History listing and search use it to scope results by host, session,
/// directory or workspace.
#[derive(Clone)]
pub struct Context {
    /// Shell session identifier (`$ATUIN_SESSION` when built by `current_context`).
    pub session: String,
    /// Working directory the command runs in.
    pub cwd: String,
    /// Host/user string from `get_host_user()`.
    ///
    /// The author filter treats the part after the first `:` as the author.
    pub hostname: String,
    /// Machine identifier.
    ///
    /// This is a simple-format UUID string when built by `current_context`,
    /// and empty when built by `Context::from_history`.
    pub host_id: String,
    /// Git repository root reported by `utils::in_git_repo` for `cwd`, if any.
    pub git_root: Option<PathBuf>,
}
44
/// Optional extra constraints for `Database::search`.
///
/// `Default` yields "no extra constraints".
#[derive(Default, Clone)]
pub struct OptFilters {
    /// Only entries whose exit code equals this value.
    pub exit: Option<i64>,
    /// Skip entries whose exit code equals this value.
    pub exclude_exit: Option<i64>,
    /// Only entries recorded in exactly this directory.
    pub cwd: Option<String>,
    /// Skip entries recorded in exactly this directory.
    pub exclude_cwd: Option<String>,
    /// Only entries older than this human-readable date.
    ///
    /// Parsed with `interim` (UK dialect). An unparseable value is ignored.
    pub before: Option<String>,
    /// Only entries newer than this human-readable date.
    ///
    /// Parsed like `before`; an unparseable value is ignored.
    pub after: Option<String>,
    /// Maximum number of rows to return.
    pub limit: Option<i64>,
    /// Number of rows to skip.
    pub offset: Option<i64>,
    /// Oldest first instead of newest first.
    pub reverse: bool,
    /// Keep repeated commands instead of collapsing them to one row per command.
    pub include_duplicates: bool,
    /// Author filters, OR-ed together.
    ///
    /// Each entry is a literal author or one of the "all users" / "all agents"
    /// keywords.
    pub authors: Vec<String>,
}
60
61pub async fn current_context() -> eyre::Result<Context> {
62 let session = env::var("ATUIN_SESSION").map_err(|_| {
63 eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.")
64 })?;
65 let hostname = get_host_user();
66 let cwd = utils::get_current_dir();
67 let host_id = Settings::host_id().await?;
68 let git_root = utils::in_git_repo(cwd.as_str());
69
70 Ok(Context {
71 session,
72 hostname,
73 cwd,
74 git_root,
75 host_id: host_id.0.as_simple().to_string(),
76 })
77}
78
79impl Context {
80 pub fn from_history(entry: &History) -> Self {
81 Context {
82 session: entry.session.to_string(),
83 cwd: entry.cwd.to_string(),
84 hostname: entry.hostname.to_string(),
85 host_id: String::new(),
86 git_root: utils::in_git_repo(entry.cwd.as_str()),
87 }
88 }
89}
90
91fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) {
93 let mut conditions: Vec<String> = Vec::new();
94 let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", ");
95 let author_expr = "CASE \
96 WHEN author IS NULL OR trim(author) = '' THEN \
97 CASE \
98 WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \
99 ELSE hostname \
100 END \
101 ELSE author \
102 END";
103
104 for author in authors {
105 match author.as_str() {
106 AUTHOR_FILTER_ALL_USER => {
107 conditions.push(format!("{author_expr} NOT IN ({agent_list})"));
108 }
109 AUTHOR_FILTER_ALL_AGENT => {
110 conditions.push(format!("{author_expr} IN ({agent_list})"));
111 }
112 literal => {
113 conditions.push(format!("{author_expr} = {}", quote(literal)));
114 }
115 }
116 }
117
118 if !conditions.is_empty() {
119 sql.and_where(format!("({})", conditions.join(" OR ")));
120 }
121}
122
123fn get_session_start_time(session_id: &str) -> Option<i64> {
124 if let Ok(uuid) = Uuid::parse_str(session_id)
125 && let Some(timestamp) = uuid.get_timestamp()
126 {
127 let (seconds, nanos) = timestamp.to_unix();
128 return Some(seconds as i64 * 1_000_000_000 + nanos as i64);
129 }
130 None
131}
132
133#[async_trait]
134pub trait Database: Send + Sync + 'static {
135 async fn save(&self, h: &History) -> Result<()>;
136 async fn save_bulk(&self, h: &[History]) -> Result<()>;
137
138 async fn load(&self, id: &str) -> Result<Option<History>>;
139 async fn list(
140 &self,
141 filters: &[FilterMode],
142 context: &Context,
143 max: Option<usize>,
144 unique: bool,
145 include_deleted: bool,
146 ) -> Result<Vec<History>>;
147 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>;
148
149 async fn update(&self, h: &History) -> Result<()>;
150 async fn history_count(&self, include_deleted: bool) -> Result<i64>;
151
152 async fn last(&self) -> Result<Option<History>>;
153 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>;
154
155 async fn delete(&self, h: History) -> Result<()>;
156 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>;
157 async fn deleted(&self) -> Result<Vec<History>>;
158
159 #[allow(clippy::too_many_arguments)]
163 async fn search(
164 &self,
165 search_mode: SearchMode,
166 filter: FilterMode,
167 context: &Context,
168 query: &str,
169 filter_options: OptFilters,
170 ) -> Result<Vec<History>>;
171
172 async fn query_history(&self, query: &str) -> Result<Vec<History>>;
173
174 async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
175
176 fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged;
177
178 async fn stats(&self, h: &History) -> Result<HistoryStats>;
179
180 async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>;
181
182 fn clone_boxed(&self) -> Box<dyn Database + 'static>;
183}
184
185#[derive(Debug, Clone)]
188pub struct Sqlite {
189 pub pool: SqlitePool,
190}
191
192impl Sqlite {
193 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
194 let path = path.as_ref();
195 debug!("opening sqlite database at {path:?}");
196
197 if utils::broken_symlink(path) {
198 eprintln!(
199 "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
200 );
201 std::process::exit(1);
202 }
203
204 if !path.exists()
205 && let Some(dir) = path.parent()
206 {
207 fs::create_dir_all(dir)?;
208 }
209
210 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
211 .journal_mode(SqliteJournalMode::Wal)
212 .optimize_on_close(true, None)
213 .synchronous(SqliteSynchronous::Normal)
214 .with_regexp()
215 .create_if_missing(true);
216
217 let pool = SqlitePoolOptions::new()
218 .acquire_timeout(Duration::from_secs_f64(timeout))
219 .connect_with(opts)
220 .await?;
221
222 Self::setup_db(&pool).await?;
223 Ok(Self { pool })
224 }
225
226 pub async fn sqlite_version(&self) -> Result<String> {
227 sqlx::query_scalar("SELECT sqlite_version()")
228 .fetch_one(&self.pool)
229 .await
230 }
231
232 async fn setup_db(pool: &SqlitePool) -> Result<()> {
233 debug!("running sqlite database setup");
234
235 sqlx::migrate!("./migrations").run(pool).await?;
236
237 Ok(())
238 }
239
240 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
241 sqlx::query(
242 "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at)
243 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
244 )
245 .bind(h.id.0.as_str())
246 .bind(h.timestamp.unix_timestamp_nanos() as i64)
247 .bind(h.duration)
248 .bind(h.exit)
249 .bind(h.command.as_str())
250 .bind(h.cwd.as_str())
251 .bind(h.session.as_str())
252 .bind(h.hostname.as_str())
253 .bind(h.author.as_str())
254 .bind(h.intent.as_deref())
255 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
256 .execute(&mut **tx)
257 .await?;
258
259 Ok(())
260 }
261
262 async fn delete_row_raw(
263 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
264 id: HistoryId,
265 ) -> Result<()> {
266 sqlx::query("delete from history where id = ?1")
267 .bind(id.0.as_str())
268 .execute(&mut **tx)
269 .await?;
270
271 Ok(())
272 }
273
274 fn query_history(row: SqliteRow) -> History {
275 let deleted_at: Option<i64> = row.get("deleted_at");
276 let hostname: String = row.get("hostname");
277 let author: Option<String> = row.try_get("author").ok().flatten();
278 let author = author
279 .filter(|author| !author.trim().is_empty())
280 .unwrap_or_else(|| History::author_from_hostname(hostname.as_str()));
281 let intent: Option<String> = row.try_get("intent").ok().flatten();
282 let intent = intent.filter(|intent| !intent.trim().is_empty());
283
284 History::from_db()
285 .id(row.get("id"))
286 .timestamp(
287 OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128)
288 .unwrap(),
289 )
290 .duration(row.get("duration"))
291 .exit(row.get("exit"))
292 .command(row.get("command"))
293 .cwd(row.get("cwd"))
294 .session(row.get("session"))
295 .hostname(hostname)
296 .author(author)
297 .intent(intent)
298 .deleted_at(
299 deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()),
300 )
301 .build()
302 .into()
303 }
304}
305
306#[async_trait]
307impl Database for Sqlite {
308 async fn save(&self, h: &History) -> Result<()> {
309 debug!("saving history to sqlite");
310 let mut tx = self.pool.begin().await?;
311 Self::save_raw(&mut tx, h).await?;
312 tx.commit().await?;
313
314 Ok(())
315 }
316
317 async fn save_bulk(&self, h: &[History]) -> Result<()> {
318 debug!("saving history to sqlite");
319
320 let mut tx = self.pool.begin().await?;
321
322 for i in h {
323 Self::save_raw(&mut tx, i).await?;
324 }
325
326 tx.commit().await?;
327
328 Ok(())
329 }
330
331 async fn load(&self, id: &str) -> Result<Option<History>> {
332 debug!("loading history item {}", id);
333
334 let res = sqlx::query("select * from history where id = ?1")
335 .bind(id)
336 .map(Self::query_history)
337 .fetch_optional(&self.pool)
338 .await?;
339
340 Ok(res)
341 }
342
343 async fn update(&self, h: &History) -> Result<()> {
344 debug!("updating sqlite history");
345
346 sqlx::query(
347 "update history
348 set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11
349 where id = ?1",
350 )
351 .bind(h.id.0.as_str())
352 .bind(h.timestamp.unix_timestamp_nanos() as i64)
353 .bind(h.duration)
354 .bind(h.exit)
355 .bind(h.command.as_str())
356 .bind(h.cwd.as_str())
357 .bind(h.session.as_str())
358 .bind(h.hostname.as_str())
359 .bind(h.author.as_str())
360 .bind(h.intent.as_deref())
361 .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64))
362 .execute(&self.pool)
363 .await?;
364
365 Ok(())
366 }
367
368 async fn list(
370 &self,
371 filters: &[FilterMode],
372 context: &Context,
373 max: Option<usize>,
374 unique: bool,
375 include_deleted: bool,
376 ) -> Result<Vec<History>> {
377 debug!("listing history");
378
379 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
380 query.field("*").order_desc("timestamp");
381 if !include_deleted {
382 query.and_where_is_null("deleted_at");
383 }
384
385 let git_root = if let Some(git_root) = context.git_root.clone() {
386 git_root.to_str().unwrap_or("/").to_string()
387 } else {
388 context.cwd.clone()
389 };
390
391 let session_start = get_session_start_time(&context.session);
392
393 for filter in filters {
394 match filter {
395 FilterMode::Global => &mut query,
396 FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)),
397 FilterMode::Session => query.and_where_eq("session", quote(&context.session)),
398 FilterMode::SessionPreload => {
399 query.and_where_eq("session", quote(&context.session));
400 if let Some(session_start) = session_start {
401 query.or_where_lt("timestamp", session_start);
402 }
403 &mut query
404 }
405 FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)),
406 FilterMode::Workspace => query.and_where_like_left("cwd", &git_root),
407 };
408 }
409
410 if unique {
411 query.group_by("command").having("max(timestamp)");
412 }
413
414 if let Some(max) = max {
415 query.limit(max);
416 }
417
418 let query = query.sql().expect("bug in list query. please report");
419
420 let res = sqlx::query(&query)
421 .map(Self::query_history)
422 .fetch_all(&self.pool)
423 .await?;
424
425 Ok(res)
426 }
427
428 async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> {
429 debug!("listing history from {:?} to {:?}", from, to);
430
431 let res = sqlx::query(
432 "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
433 )
434 .bind(from.unix_timestamp_nanos() as i64)
435 .bind(to.unix_timestamp_nanos() as i64)
436 .map(Self::query_history)
437 .fetch_all(&self.pool)
438 .await?;
439
440 Ok(res)
441 }
442
443 async fn last(&self) -> Result<Option<History>> {
444 let res = sqlx::query(
445 "select * from history where duration >= 0 order by timestamp desc limit 1",
446 )
447 .map(Self::query_history)
448 .fetch_optional(&self.pool)
449 .await?;
450
451 Ok(res)
452 }
453
454 async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> {
455 let res = sqlx::query(
456 "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
457 )
458 .bind(timestamp.unix_timestamp_nanos() as i64)
459 .bind(count)
460 .map(Self::query_history)
461 .fetch_all(&self.pool)
462 .await?;
463
464 Ok(res)
465 }
466
467 async fn deleted(&self) -> Result<Vec<History>> {
468 let res = sqlx::query("select * from history where deleted_at is not null")
469 .map(Self::query_history)
470 .fetch_all(&self.pool)
471 .await?;
472
473 Ok(res)
474 }
475
476 async fn history_count(&self, include_deleted: bool) -> Result<i64> {
477 let query = if include_deleted {
478 "select count(1) from history"
479 } else {
480 "select count(1) from history where deleted_at is null"
481 };
482
483 let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?;
484 Ok(res.0)
485 }
486
487 async fn search(
488 &self,
489 search_mode: SearchMode,
490 filter: FilterMode,
491 context: &Context,
492 query: &str,
493 filter_options: OptFilters,
494 ) -> Result<Vec<History>> {
495 let mut sql = SqlBuilder::select_from("history");
496
497 if !filter_options.include_duplicates {
498 sql.group_by("command").having("max(timestamp)");
499 }
500
501 if let Some(limit) = filter_options.limit {
502 sql.limit(limit);
503 }
504
505 if let Some(offset) = filter_options.offset {
506 sql.offset(offset);
507 }
508
509 if filter_options.reverse {
510 sql.order_asc("timestamp");
511 } else {
512 sql.order_desc("timestamp");
513 }
514
515 let git_root = if let Some(git_root) = context.git_root.clone() {
516 git_root.to_str().unwrap_or("/").to_string()
517 } else {
518 context.cwd.clone()
519 };
520
521 let session_start = get_session_start_time(&context.session);
522
523 match filter {
524 FilterMode::Global => &mut sql,
525 FilterMode::Host => {
526 sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase()))
527 }
528 FilterMode::Session => sql.and_where_eq("session", quote(&context.session)),
529 FilterMode::SessionPreload => {
530 sql.and_where_eq("session", quote(&context.session));
531 if let Some(session_start) = session_start {
532 sql.or_where_lt("timestamp", session_start);
533 }
534 &mut sql
535 }
536 FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)),
537 FilterMode::Workspace => sql.and_where_like_left("cwd", git_root),
538 };
539
540 let orig_query = query;
541
542 let mut regexes = Vec::new();
543 match search_mode {
544 SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
545 _ => {
546 let mut is_or = false;
547 for token in QueryTokenizer::new(query) {
548 let (is_glob, glob) = if token.has_uppercase() {
550 (true, "*")
551 } else {
552 (false, "%")
553 };
554 let param = match token {
555 QueryToken::Regex(r) => {
556 regexes.push(String::from(r));
557 continue;
558 }
559 QueryToken::Or => {
560 if !is_or {
561 is_or = true;
562 continue;
563 } else {
564 format!("{glob}|{glob}")
565 }
566 }
567 QueryToken::MatchStart(term, _) => {
568 format!("{term}{glob}")
569 }
570 QueryToken::MatchEnd(term, _) => {
571 format!("{glob}{term}")
572 }
573 QueryToken::MatchFull(term, _) => {
574 format!("{glob}{term}{glob}")
575 }
576 QueryToken::Match(term, _) => {
577 if search_mode == SearchMode::FullText {
578 format!("{glob}{term}{glob}")
579 } else {
580 term.split("").join(glob)
581 }
582 }
583 };
584
585 sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or);
586 is_or = false;
587 }
588
589 &mut sql
590 }
591 };
592
593 for regex in regexes {
594 sql.and_where("command regexp ?".bind(®ex));
595 }
596
597 filter_options
598 .exit
599 .map(|exit| sql.and_where_eq("exit", exit));
600
601 filter_options
602 .exclude_exit
603 .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit));
604
605 filter_options
606 .cwd
607 .map(|cwd| sql.and_where_eq("cwd", quote(cwd)));
608
609 filter_options
610 .exclude_cwd
611 .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd)));
612
613 filter_options.before.map(|before| {
614 interim::parse_date_string(
615 before.as_str(),
616 OffsetDateTime::now_utc(),
617 interim::Dialect::Uk,
618 )
619 .map(|before| {
620 sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64))
621 })
622 });
623
624 filter_options.after.map(|after| {
625 interim::parse_date_string(
626 after.as_str(),
627 OffsetDateTime::now_utc(),
628 interim::Dialect::Uk,
629 )
630 .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64)))
631 });
632
633 if !filter_options.authors.is_empty() {
634 apply_author_filter(&mut sql, &filter_options.authors);
635 }
636
637 sql.and_where_is_null("deleted_at");
638
639 let query = sql.sql().expect("bug in search query. please report");
640
641 let res = sqlx::query(&query)
642 .map(Self::query_history)
643 .fetch_all(&self.pool)
644 .await?;
645
646 Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
647 }
648
649 async fn query_history(&self, query: &str) -> Result<Vec<History>> {
650 let res = sqlx::query(query)
651 .map(Self::query_history)
652 .fetch_all(&self.pool)
653 .await?;
654
655 Ok(res)
656 }
657
658 async fn all_with_count(&self) -> Result<Vec<(History, i32)>> {
659 debug!("listing history");
660
661 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
662
663 query
664 .fields(&[
665 "id",
666 "max(timestamp) as timestamp",
667 "max(duration) as duration",
668 "exit",
669 "command",
670 "deleted_at",
671 "null as author",
672 "null as intent",
673 "group_concat(cwd, ':') as cwd",
674 "group_concat(session) as session",
675 "group_concat(hostname, ',') as hostname",
676 "count(*) as count",
677 ])
678 .group_by("command")
679 .group_by("exit")
680 .and_where("deleted_at is null")
681 .order_desc("timestamp");
682
683 let query = query.sql().expect("bug in list query. please report");
684
685 let res = sqlx::query(&query)
686 .map(|row: SqliteRow| {
687 let count: i32 = row.get("count");
688 (Self::query_history(row), count)
689 })
690 .fetch_all(&self.pool)
691 .await?;
692
693 Ok(res)
694 }
695
696 fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged {
697 Paged::new(Box::new(self.clone()), page_size, include_deleted, unique)
698 }
699
700 async fn delete(&self, mut h: History) -> Result<()> {
703 let now = OffsetDateTime::now_utc();
704 h.command = rand::thread_rng()
705 .sample_iter(&Alphanumeric)
706 .take(32)
707 .map(char::from)
708 .collect(); h.deleted_at = Some(now); self.update(&h).await?; Ok(())
714 }
715
716 async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> {
717 let mut tx = self.pool.begin().await?;
718
719 for id in ids {
720 Self::delete_row_raw(&mut tx, id.clone()).await?;
721 }
722
723 tx.commit().await?;
724
725 Ok(())
726 }
727
728 async fn stats(&self, h: &History) -> Result<HistoryStats> {
729 let mut prev = SqlBuilder::select_from("history");
731 prev.field("*")
732 .and_where("timestamp < ?1")
733 .and_where("session = ?2")
734 .order_by("timestamp", true)
735 .limit(1);
736
737 let mut next = SqlBuilder::select_from("history");
738 next.field("*")
739 .and_where("timestamp > ?1")
740 .and_where("session = ?2")
741 .order_by("timestamp", false)
742 .limit(1);
743
744 let mut total = SqlBuilder::select_from("history");
745 total.field("count(1)").and_where("command = ?1");
746
747 let mut average = SqlBuilder::select_from("history");
748 average.field("avg(duration)").and_where("command = ?1");
749
750 let mut exits = SqlBuilder::select_from("history");
751 exits
752 .fields(&["exit", "count(1) as count"])
753 .and_where("command = ?1")
754 .group_by("exit");
755
756 let mut day_of_week = SqlBuilder::select_from("history");
758 day_of_week
759 .fields(&[
760 "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week",
761 "count(1) as count",
762 ])
763 .and_where("command = ?1")
764 .group_by("day_of_week");
765
766 let mut duration_over_time = SqlBuilder::select_from("history");
771 duration_over_time
772 .fields(&[
773 "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year",
774 "avg(duration) as duration",
775 ])
776 .and_where("command = ?1")
777 .group_by("month_year")
778 .having("duration > 0");
779
780 let prev = prev.sql().expect("issue in stats previous query");
781 let next = next.sql().expect("issue in stats next query");
782 let total = total.sql().expect("issue in stats average query");
783 let average = average.sql().expect("issue in stats previous query");
784 let exits = exits.sql().expect("issue in stats exits query");
785 let day_of_week = day_of_week.sql().expect("issue in stats day of week query");
786 let duration_over_time = duration_over_time
787 .sql()
788 .expect("issue in stats duration over time query");
789
790 let prev = sqlx::query(&prev)
791 .bind(h.timestamp.unix_timestamp_nanos() as i64)
792 .bind(&h.session)
793 .map(Self::query_history)
794 .fetch_optional(&self.pool)
795 .await?;
796
797 let next = sqlx::query(&next)
798 .bind(h.timestamp.unix_timestamp_nanos() as i64)
799 .bind(&h.session)
800 .map(Self::query_history)
801 .fetch_optional(&self.pool)
802 .await?;
803
804 let total: (i64,) = sqlx::query_as(&total)
805 .bind(&h.command)
806 .fetch_one(&self.pool)
807 .await?;
808
809 let average: (f64,) = sqlx::query_as(&average)
810 .bind(&h.command)
811 .fetch_one(&self.pool)
812 .await?;
813
814 let exits: Vec<(i64, i64)> = sqlx::query_as(&exits)
815 .bind(&h.command)
816 .fetch_all(&self.pool)
817 .await?;
818
819 let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week)
820 .bind(&h.command)
821 .fetch_all(&self.pool)
822 .await?;
823
824 let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time)
825 .bind(&h.command)
826 .fetch_all(&self.pool)
827 .await?;
828
829 let duration_over_time = duration_over_time
830 .iter()
831 .map(|f| (f.0.clone(), f.1.round() as i64))
832 .collect();
833
834 Ok(HistoryStats {
835 next,
836 previous: prev,
837 total: total.0 as u64,
838 average_duration: average.0 as u64,
839 exits,
840 day_of_week,
841 duration_over_time,
842 })
843 }
844
845 async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> {
846 let res = sqlx::query(
847 "SELECT * FROM (
848 SELECT *, ROW_NUMBER()
849 OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC)
850 AS rn
851 FROM history
852 ) sub
853 WHERE rn > ?1 and timestamp < ?2;
854 ",
855 )
856 .bind(dupkeep)
857 .bind(before)
858 .map(Self::query_history)
859 .fetch_all(&self.pool)
860 .await?;
861
862 Ok(res)
863 }
864
865 fn clone_boxed(&self) -> Box<dyn Database + 'static> {
866 Box::new(self.clone())
867 }
868}
869
870pub struct Paged {
871 database: Box<dyn Database + 'static>,
872 page_size: usize,
873 last_id: Option<String>,
874 include_deleted: bool,
875 unique: bool,
876}
877
878impl Paged {
879 pub fn new(
880 database: Box<dyn Database + 'static>,
881 page_size: usize,
882 include_deleted: bool,
883 unique: bool,
884 ) -> Self {
885 Self {
886 database,
887 page_size,
888 last_id: None,
889 include_deleted,
890 unique,
891 }
892 }
893
894 pub async fn next(&mut self) -> Result<Option<Vec<History>>> {
895 let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
896
897 query.field("*").order_desc("id");
898
899 if !self.include_deleted {
900 query.and_where_is_null("deleted_at");
901 }
902
903 if self.unique {
904 query
908 .group_by("command, cwd, hostname, session")
909 .having("max(timestamp)");
910 }
911
912 query.limit(self.page_size);
913
914 if let Some(last_id) = &self.last_id {
915 query.and_where_lt("id", quote(last_id));
916 }
917
918 let query = query.sql().expect("bug in list query. please report");
919 let res = self.database.query_history(&query).await?;
920
921 if res.is_empty() {
922 Ok(None)
923 } else {
924 self.last_id = Some(res.last().unwrap().id.0.clone());
925 Ok(Some(res))
926 }
927 }
928}
929
/// Extension for adding pattern-matching conditions to a query builder.
trait SqlBuilderExt {
    /// Add a `field [NOT] LIKE|GLOB 'mask'` condition.
    ///
    /// * `inverse` negates the match.
    /// * `glob` selects `GLOB` instead of `LIKE`.
    /// * `is_or` attaches the condition with OR instead of AND.
    fn fuzzy_condition<S: ToString, T: ToString>(
        &mut self,
        field: S,
        mask: T,
        inverse: bool,
        glob: bool,
        is_or: bool,
    ) -> &mut Self;
}
940
941impl SqlBuilderExt for SqlBuilder {
942 fn fuzzy_condition<S: ToString, T: ToString>(
944 &mut self,
945 field: S,
946 mask: T,
947 inverse: bool,
948 glob: bool,
949 is_or: bool,
950 ) -> &mut Self {
951 let mut cond = field.to_string();
952 if inverse {
953 cond.push_str(" NOT");
954 }
955 if glob {
956 cond.push_str(" GLOB '");
957 } else {
958 cond.push_str(" LIKE '");
959 }
960 cond.push_str(&esc(mask.to_string()));
961 cond.push('\'');
962 if is_or {
963 self.or_where(cond)
964 } else {
965 self.and_where(cond)
966 }
967 }
968}
969
970#[cfg(test)]
971mod test {
972 use crate::settings::test_local_timeout;
973
974 use super::*;
975 use std::time::{Duration, Instant};
976
977 async fn assert_search_eq(
978 db: &impl Database,
979 mode: SearchMode,
980 filter_mode: FilterMode,
981 query: &str,
982 expected: usize,
983 ) -> Result<Vec<History>> {
984 let context = Context {
985 hostname: "test:host".to_string(),
986 session: "beepboopiamasession".to_string(),
987 cwd: "/home/ellie".to_string(),
988 host_id: "test-host".to_string(),
989 git_root: None,
990 };
991
992 let results = db
993 .search(
994 mode,
995 filter_mode,
996 &context,
997 query,
998 OptFilters {
999 ..Default::default()
1000 },
1001 )
1002 .await?;
1003
1004 assert_eq!(
1005 results.len(),
1006 expected,
1007 "query \"{}\", commands: {:?}",
1008 query,
1009 results.iter().map(|a| &a.command).collect::<Vec<&String>>()
1010 );
1011 Ok(results)
1012 }
1013
1014 async fn assert_search_commands(
1015 db: &impl Database,
1016 mode: SearchMode,
1017 filter_mode: FilterMode,
1018 query: &str,
1019 expected_commands: Vec<&str>,
1020 ) {
1021 let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
1022 .await
1023 .unwrap();
1024 let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
1025 assert_eq!(commands, expected_commands);
1026 }
1027
1028 async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
1029 let mut captured: History = History::capture()
1030 .timestamp(OffsetDateTime::now_utc())
1031 .command(cmd)
1032 .cwd("/home/ellie")
1033 .build()
1034 .into();
1035
1036 captured.exit = 0;
1037 captured.duration = 1;
1038 captured.session = "beep boop".to_string();
1039 captured.hostname = "booop".to_string();
1040
1041 db.save(&captured).await
1042 }
1043
1044 #[tokio::test(flavor = "multi_thread")]
1045 async fn test_search_prefix() {
1046 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1047 .await
1048 .unwrap();
1049 new_history_item(&mut db, "ls /home/ellie").await.unwrap();
1050
1051 assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
1052 .await
1053 .unwrap();
1054 assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
1055 .await
1056 .unwrap();
1057 assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0)
1058 .await
1059 .unwrap();
1060 }
1061
1062 #[tokio::test(flavor = "multi_thread")]
1063 async fn test_search_fulltext() {
1064 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1065 .await
1066 .unwrap();
1067 new_history_item(&mut db, "ls /home/ellie").await.unwrap();
1068
1069 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
1070 .await
1071 .unwrap();
1072 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
1073 .await
1074 .unwrap();
1075 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1)
1076 .await
1077 .unwrap();
1078 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
1079 .await
1080 .unwrap();
1081
1082 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
1084 .await
1085 .unwrap();
1086 assert_search_eq(
1087 &db,
1088 SearchMode::FullText,
1089 FilterMode::Global,
1090 "r/ls / ie$",
1091 1,
1092 )
1093 .await
1094 .unwrap();
1095 assert_search_eq(
1096 &db,
1097 SearchMode::FullText,
1098 FilterMode::Global,
1099 "r/ls / !ie",
1100 0,
1101 )
1102 .await
1103 .unwrap();
1104 assert_search_eq(
1105 &db,
1106 SearchMode::FullText,
1107 FilterMode::Global,
1108 "meow r/ls/",
1109 0,
1110 )
1111 .await
1112 .unwrap();
1113 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
1114 .await
1115 .unwrap();
1116 assert_search_eq(
1117 &db,
1118 SearchMode::FullText,
1119 FilterMode::Global,
1120 "r//home//",
1121 1,
1122 )
1123 .await
1124 .unwrap();
1125 assert_search_eq(
1126 &db,
1127 SearchMode::FullText,
1128 FilterMode::Global,
1129 "r//home///",
1130 0,
1131 )
1132 .await
1133 .unwrap();
1134 assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
1135 .await
1136 .unwrap();
1137 assert_search_eq(
1138 &db,
1139 SearchMode::FullText,
1140 FilterMode::Global,
1141 "r/home.*e",
1142 1,
1143 )
1144 .await
1145 .unwrap();
1146 }
1147
1148 #[tokio::test(flavor = "multi_thread")]
1149 async fn test_search_fuzzy() {
1150 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1151 .await
1152 .unwrap();
1153 new_history_item(&mut db, "ls /home/ellie").await.unwrap();
1154 new_history_item(&mut db, "ls /home/frank").await.unwrap();
1155 new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
1156 new_history_item(&mut db, "/home/ellie/.bin/rustup")
1157 .await
1158 .unwrap();
1159
1160 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
1161 .await
1162 .unwrap();
1163 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
1164 .await
1165 .unwrap();
1166 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
1167 .await
1168 .unwrap();
1169 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
1170 .await
1171 .unwrap();
1172 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
1173 .await
1174 .unwrap();
1175 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
1176 .await
1177 .unwrap();
1178 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
1179 .await
1180 .unwrap();
1181 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
1182 .await
1183 .unwrap();
1184
1185 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
1187 .await
1188 .unwrap();
1189 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
1190 .await
1191 .unwrap();
1192 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
1193 .await
1194 .unwrap();
1195 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
1196 .await
1197 .unwrap();
1198 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
1199 .await
1200 .unwrap();
1201 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
1202 .await
1203 .unwrap();
1204
1205 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
1207 .await
1208 .unwrap();
1209 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
1210 .await
1211 .unwrap();
1212 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
1213 .await
1214 .unwrap();
1215 assert_search_eq(
1216 &db,
1217 SearchMode::Fuzzy,
1218 FilterMode::Global,
1219 "'frank | 'rustup",
1220 2,
1221 )
1222 .await
1223 .unwrap();
1224 assert_search_eq(
1225 &db,
1226 SearchMode::Fuzzy,
1227 FilterMode::Global,
1228 "'frank | 'rustup 'ls",
1229 1,
1230 )
1231 .await
1232 .unwrap();
1233
1234 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
1236 .await
1237 .unwrap();
1238
1239 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
1241 .await
1242 .unwrap();
1243 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
1244 .await
1245 .unwrap();
1246 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
1247 .await
1248 .unwrap();
1249 }
1250
1251 #[tokio::test(flavor = "multi_thread")]
1252 async fn test_search_reordered_fuzzy() {
1253 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1254 .await
1255 .unwrap();
1256 new_history_item(&mut db, "curl").await.unwrap();
1259 new_history_item(&mut db, "corburl").await.unwrap();
1260
1261 assert_search_commands(
1263 &db,
1264 SearchMode::Fuzzy,
1265 FilterMode::Global,
1266 "curl",
1267 vec!["curl", "corburl"],
1268 )
1269 .await;
1270
1271 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
1272 .await
1273 .unwrap();
1274 assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
1275 .await
1276 .unwrap();
1277 }
1278
1279 #[tokio::test(flavor = "multi_thread")]
1280 async fn test_paged_basic() {
1281 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1282 .await
1283 .unwrap();
1284
1285 for i in 0..5 {
1287 new_history_item(&mut db, &format!("command{}", i))
1288 .await
1289 .unwrap();
1290 }
1291
1292 let mut paged = db.all_paged(2, false, false);
1294
1295 let page1 = paged.next().await.unwrap();
1297 assert!(page1.is_some());
1298 assert_eq!(page1.unwrap().len(), 2);
1299
1300 let page2 = paged.next().await.unwrap();
1302 assert!(page2.is_some());
1303 assert_eq!(page2.unwrap().len(), 2);
1304
1305 let page3 = paged.next().await.unwrap();
1307 assert!(page3.is_some());
1308 assert_eq!(page3.unwrap().len(), 1);
1309
1310 let page4 = paged.next().await.unwrap();
1312 assert!(page4.is_none());
1313 }
1314
1315 #[tokio::test(flavor = "multi_thread")]
1316 async fn test_paged_empty() {
1317 let db = Sqlite::new("sqlite::memory:", test_local_timeout())
1318 .await
1319 .unwrap();
1320
1321 let mut paged = db.all_paged(10, false, false);
1323
1324 let page = paged.next().await.unwrap();
1326 assert!(page.is_none());
1327 }
1328
1329 #[tokio::test(flavor = "multi_thread")]
1330 async fn test_paged_unique() {
1331 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1332 .await
1333 .unwrap();
1334
1335 new_history_item(&mut db, "duplicate").await.unwrap();
1337 new_history_item(&mut db, "duplicate").await.unwrap();
1338 new_history_item(&mut db, "unique1").await.unwrap();
1339 new_history_item(&mut db, "unique2").await.unwrap();
1340
1341 let mut paged = db.all_paged(10, false, false);
1343 let page = paged.next().await.unwrap().unwrap();
1344 assert_eq!(page.len(), 4);
1345
1346 let mut paged_unique = db.all_paged(10, false, true);
1348 let page_unique = paged_unique.next().await.unwrap().unwrap();
1349 assert_eq!(page_unique.len(), 3);
1350 }
1351
1352 #[tokio::test(flavor = "multi_thread")]
1353 async fn test_paged_include_deleted() {
1354 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1355 .await
1356 .unwrap();
1357
1358 new_history_item(&mut db, "keep1").await.unwrap();
1360 new_history_item(&mut db, "keep2").await.unwrap();
1361 new_history_item(&mut db, "delete_me").await.unwrap();
1362
1363 let all = db
1365 .list(
1366 &[],
1367 &Context {
1368 hostname: "".to_string(),
1369 session: "".to_string(),
1370 cwd: "".to_string(),
1371 host_id: "".to_string(),
1372 git_root: None,
1373 },
1374 None,
1375 false,
1376 false,
1377 )
1378 .await
1379 .unwrap();
1380
1381 let to_delete = all
1382 .iter()
1383 .find(|h| h.command == "delete_me")
1384 .unwrap()
1385 .clone();
1386 db.delete(to_delete).await.unwrap();
1387
1388 let mut paged = db.all_paged(10, false, false);
1390 let page = paged.next().await.unwrap().unwrap();
1391 assert_eq!(page.len(), 2);
1392
1393 let mut paged_deleted = db.all_paged(10, true, false);
1395 let page_deleted = paged_deleted.next().await.unwrap().unwrap();
1396 assert_eq!(page_deleted.len(), 3);
1397 }
1398
1399 #[tokio::test(flavor = "multi_thread")]
1400 async fn test_search_bench_dupes() {
1401 let context = Context {
1402 hostname: "test:host".to_string(),
1403 session: "beepboopiamasession".to_string(),
1404 cwd: "/home/ellie".to_string(),
1405 host_id: "test-host".to_string(),
1406 git_root: None,
1407 };
1408
1409 let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
1410 .await
1411 .unwrap();
1412 for _i in 1..10000 {
1413 new_history_item(&mut db, "i am a duplicated command")
1414 .await
1415 .unwrap();
1416 }
1417 let start = Instant::now();
1418 let _results = db
1419 .search(
1420 SearchMode::Fuzzy,
1421 FilterMode::Global,
1422 &context,
1423 "",
1424 OptFilters {
1425 ..Default::default()
1426 },
1427 )
1428 .await
1429 .unwrap();
1430 let duration = start.elapsed();
1431
1432 assert!(duration < Duration::from_secs(15));
1433 }
1434}
1435
/// Splits a search query into `QueryToken`s.
///
/// Words are separated by single spaces. A word beginning with `r/` starts a
/// regex token that runs to the next `"/ "` (or to the end of the query).
/// The tokenizer is consumed through its `Iterator` implementation. It holds
/// only the borrowed query and a byte cursor, so cloning it is cheap and
/// yields an independent iterator from the same position.
#[derive(Debug, Clone)]
pub struct QueryTokenizer<'a> {
    /// The complete query being tokenized.
    query: &'a str,
    /// Byte offset into `query` where the next token begins.
    last_pos: usize,
}
1440
/// A single token produced by `QueryTokenizer`.
///
/// For the four term variants, the `bool` is `true` when the word was
/// prefixed with `!` (an inverted term). The string is the term with the
/// `!` and the variant's own operator character already stripped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryToken<'a> {
    /// A plain word with no operator.
    Match(&'a str, bool),
    /// A word written as `^term`.
    MatchStart(&'a str, bool),
    /// A word written as `term$`.
    MatchEnd(&'a str, bool),
    /// A word written as `'term` (presumably an exact, non-fuzzy match; confirm
    /// against the search query builder).
    MatchFull(&'a str, bool),
    /// A standalone `|` word, the OR operator between terms.
    Or,
    /// The pattern of an `r/pattern/` word, without the `r/` and closing `/`.
    Regex(&'a str),
}
1449
1450impl<'a> QueryToken<'a> {
1451 pub fn has_uppercase(&self) -> bool {
1452 match self {
1453 Self::Match(term, _)
1454 | Self::MatchStart(term, _)
1455 | Self::MatchEnd(term, _)
1456 | Self::MatchFull(term, _) => term.contains(char::is_uppercase),
1457 _ => false,
1458 }
1459 }
1460
1461 pub fn is_inverse(&self) -> bool {
1462 match self {
1463 Self::Match(_, inv)
1464 | Self::MatchStart(_, inv)
1465 | Self::MatchEnd(_, inv)
1466 | Self::MatchFull(_, inv) => *inv,
1467 _ => false,
1468 }
1469 }
1470}
1471
1472impl<'a> QueryTokenizer<'a> {
1473 pub fn new(query: &'a str) -> Self {
1474 Self { query, last_pos: 0 }
1475 }
1476}
1477
1478impl<'a> Iterator for QueryTokenizer<'a> {
1479 type Item = QueryToken<'a>;
1480 fn next(&mut self) -> Option<Self::Item> {
1481 let remaining = &self.query[self.last_pos..];
1482 if remaining.is_empty() {
1483 return None;
1484 }
1485
1486 if let Some(remaining) = remaining.strip_prefix("r/") {
1487 let (regex, next_pos) = if let Some(end) = remaining.find("/ ") {
1488 (&remaining[..end], self.last_pos + 2 + end + 2)
1489 } else if let Some(remaining) = remaining.strip_suffix('/') {
1490 (remaining, self.query.len())
1491 } else {
1492 (remaining, self.query.len())
1493 };
1494 self.last_pos = next_pos;
1495 Some(QueryToken::Regex(regex))
1496 } else {
1497 let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') {
1498 (&remaining[..sp], self.last_pos + sp + 1)
1499 } else {
1500 (remaining, self.query.len())
1501 };
1502 self.last_pos = next_pos;
1503
1504 if part == "|" {
1505 return Some(QueryToken::Or);
1506 }
1507
1508 let mut is_inverse = false;
1509 if let Some(s) = part.strip_prefix('!') {
1510 part = s;
1511 is_inverse = true;
1512 }
1513 let token = if let Some(s) = part.strip_prefix('^') {
1514 QueryToken::MatchStart(s, is_inverse)
1515 } else if let Some(s) = part.strip_suffix('$') {
1516 QueryToken::MatchEnd(s, is_inverse)
1517 } else if let Some(s) = part.strip_prefix('\'') {
1518 QueryToken::MatchFull(s, is_inverse)
1519 } else {
1520 QueryToken::Match(part, is_inverse)
1521 };
1522 Some(token)
1523 }
1524 }
1525}