use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};
use atuin_client::history::{History, is_known_agent};
use atuin_client::settings::Search;
use atuin_nucleo::{Injector, Nucleo, pattern};
use dashmap::{DashMap, mapref::entry::Entry};
use lasso::{Spur, ThreadedRodeo};
use time::OffsetDateTime;
use tokio::sync::RwLock;
use tracing::{Level, instrument};
use uuid::Uuid;
use crate::components::search::with_trailing_slash;
/// Parses a UUID string into its 16 raw bytes, or `None` if `s` is not a valid UUID.
fn parse_uuid_bytes(s: &str) -> Option<[u8; 16]> {
    let parsed = Uuid::parse_str(s).ok()?;
    Some(*parsed.as_bytes())
}
/// Renders 16 raw UUID bytes back into UUID text (the `Uuid` `Display` form).
fn format_uuid_bytes(bytes: &[u8; 16]) -> String {
    let uuid = Uuid::from_bytes(*bytes);
    uuid.hyphenated().to_string()
}
/// Usage statistics for one command, reduced to a ranking score by `FrecencyData::compute`.
#[derive(Debug, Clone, Default)]
pub struct FrecencyData {
    /// Number of recorded uses.
    pub count: u32,
    /// Unix timestamp (seconds) of the newest recorded use. It stays 0 (the `Default`)
    /// until a use with a positive timestamp is recorded.
    pub last_used: i64,
}
impl FrecencyData {
    /// Counts one more use. `last_used` only moves forward, so recording an older
    /// timestamp never makes the command look less recent.
    pub fn record_use(&mut self, timestamp: i64) {
        self.count += 1;
        self.last_used = self.last_used.max(timestamp);
    }

    /// Combines recency and frequency into one score.
    ///
    /// * recency: a tiered score from 100 (used within the current hour) down to 5
    ///   (older than 30 days), scaled by `recency_mul`;
    /// * frequency: `20 * ln(count)` capped at 100, scaled by `frequency_mul`.
    ///
    /// Returns 0 for a command that was never used. Timestamps in the future count as
    /// age 0. The rounded sum is converted with a saturating `as u32` cast.
    #[instrument(level = tracing::Level::TRACE, name = "index_frecency_compute")]
    pub fn compute(&self, now: i64, recency_mul: f64, frequency_mul: f64) -> u32 {
        if self.count == 0 {
            return 0;
        }
        let hours_since_use = (now - self.last_used).max(0) as u64 / 3600;
        let recency = Self::recency_score(hours_since_use);
        let frequency = (f64::from(self.count).ln() * 20.0).min(100.0);
        let weighted = (recency * recency_mul) + (frequency * frequency_mul);
        weighted.round() as u32
    }

    /// Maps whole hours since last use to the recency tier score.
    fn recency_score(age_hours: u64) -> f64 {
        // (inclusive upper bound in hours, score); anything past the last tier scores 5.
        const RECENCY_TIERS: [(u64, f64); 6] = [
            (0, 100.0),
            (6, 90.0),
            (24, 70.0),
            (72, 50.0),
            (168, 30.0),
            (720, 15.0),
        ];
        RECENCY_TIERS
            .iter()
            .find(|&&(max_hours, _)| age_hours <= max_hours)
            .map_or(5.0, |&(_, score)| score)
    }
}
/// Aggregated index entry for one distinct command string.
///
/// Tracks the newest invocation (its id is what `SearchIndex::search` returns), the
/// command's global frecency, and where it has been run. The location sets back the
/// `has_invocation_*` filter predicates.
pub struct CommandData {
    /// Raw UUID bytes of the newest invocation's history id.
    most_recent_id: [u8; 16],
    /// Unix timestamp of that newest invocation.
    most_recent_timestamp: i64,
    pub global_frecency: FrecencyData,
    /// Interned `with_trailing_slash(cwd)` keys for every invocation.
    directories: HashSet<Spur>,
    /// Interned hostnames for every invocation.
    hosts: HashSet<Spur>,
    /// Raw UUID bytes of every session the command was run in.
    sessions: HashSet<[u8; 16]>,
}
impl CommandData {
    /// Builds the aggregate for a command from its first observed invocation.
    ///
    /// Returns `None` when `history.id` or `history.session` is not a parseable UUID.
    pub fn new(history: &History, interner: &ThreadedRodeo) -> Option<Self> {
        let mut data = Self {
            most_recent_id: parse_uuid_bytes(&history.id.0)?,
            // Seeded with this entry's id and the smallest timestamp, so recording the
            // invocation below always leaves this entry as the most recent one.
            most_recent_timestamp: i64::MIN,
            global_frecency: FrecencyData::default(),
            directories: HashSet::new(),
            hosts: HashSet::new(),
            sessions: HashSet::new(),
        };
        data.add_invocation(history, interner).then_some(data)
    }

    /// Folds another invocation of the same command into this aggregate.
    ///
    /// Returns `false` and leaves `self` untouched when the entry's id or session is not
    /// a parseable UUID. Otherwise it bumps frecency, records directory, host and session,
    /// and adopts the entry as most recent if it is strictly newer. Then it returns `true`.
    pub fn add_invocation(&mut self, history: &History, interner: &ThreadedRodeo) -> bool {
        let (Some(history_id), Some(session)) = (
            parse_uuid_bytes(&history.id.0),
            parse_uuid_bytes(&history.session),
        ) else {
            return false;
        };

        let timestamp = history.timestamp.unix_timestamp();
        self.global_frecency.record_use(timestamp);

        let cwd_key = interner.get_or_intern(with_trailing_slash(&history.cwd));
        self.directories.insert(cwd_key);
        let host_key = interner.get_or_intern(&history.hostname);
        self.hosts.insert(host_key);
        self.sessions.insert(session);

        // Ties keep the previously recorded invocation.
        if timestamp > self.most_recent_timestamp {
            self.most_recent_timestamp = timestamp;
            self.most_recent_id = history_id;
        }
        true
    }

    /// Id of the newest invocation, rendered as UUID text.
    pub fn most_recent_id(&self) -> String {
        format_uuid_bytes(&self.most_recent_id)
    }

    /// Exact match of `dir` against recorded directories.
    ///
    /// Stored keys are `with_trailing_slash(cwd)`, so callers should pass the same form.
    pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool {
        match interner.get(dir) {
            Some(key) => self.directories.contains(&key),
            // Never interned means no invocation was ever recorded there.
            None => false,
        }
    }

    /// True if any recorded directory starts with `prefix` (plain string prefix test).
    pub fn has_invocation_in_workspace(&self, prefix: &str, interner: &ThreadedRodeo) -> bool {
        self.directories
            .iter()
            .map(|key| interner.resolve(key))
            .any(|dir| dir.starts_with(prefix))
    }

    /// Exact match of `hostname` against recorded hostnames.
    pub fn has_invocation_on_host(&self, hostname: &str, interner: &ThreadedRodeo) -> bool {
        let Some(key) = interner.get(hostname) else {
            return false;
        };
        self.hosts.contains(&key)
    }

    /// True if `session` parses as a UUID that was recorded; unparseable input never matches.
    pub fn has_invocation_in_session(&self, session: &str) -> bool {
        match parse_uuid_bytes(session) {
            Some(bytes) => self.sessions.contains(&bytes),
            None => false,
        }
    }
}
/// Restricts which commands `SearchIndex::search` may return.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexFilterMode {
    /// No restriction.
    Global,
    /// Commands run in exactly this directory. Directories are stored as
    /// `with_trailing_slash(cwd)`, so pass the path in that form.
    Directory(String),
    /// Commands run in any directory whose stored form starts with this prefix.
    Workspace(String),
    /// Commands run on exactly this hostname.
    Host(String),
    /// Commands run in this session. The string must parse as a UUID, otherwise
    /// nothing matches.
    Session(String),
}
/// Caller-side context for a search query.
///
/// NOTE(review): `SearchIndex::search` currently accepts this as `_context` and does not
/// read it; the field meanings below are inferred from names only — confirm with callers.
#[derive(Debug, Clone, Default)]
pub struct QueryContext {
    /// Presumably the shell's current working directory.
    pub cwd: Option<String>,
    /// Presumably the enclosing git repository root, if any.
    pub git_root: Option<String>,
    /// Presumably the current hostname.
    pub hostname: Option<String>,
    /// Presumably the current shell session id.
    pub session_id: Option<String>,
}
/// Immutable snapshot of precomputed frecency scores keyed by command text. It is shared
/// via `Arc` so each search's scorer can hold it without copying.
type FrecencyMap = Arc<HashMap<Arc<str>, u32>>;
/// In-memory fuzzy-search index over distinct commands.
///
/// When a command string is first seen, it is inserted into `commands` and pushed into
/// the nucleo matcher through `injector`. Searches return each matching command's newest
/// history id.
pub struct SearchIndex {
    /// Per-command aggregate data, keyed by the command text.
    commands: Arc<DashMap<Arc<str>, CommandData>>,
    /// The fuzzy matcher. Each search write-locks it because filter, scorer and pattern
    /// are set per query.
    nucleo: RwLock<Nucleo<String>>,
    /// Handle used to push newly seen commands into `nucleo`.
    injector: Injector<String>,
    /// Last snapshot produced by `rebuild_frecency`. While this is `None` (before the
    /// first rebuild), searches install no scorer.
    frecency_map: RwLock<Option<FrecencyMap>>,
    /// Interner shared by all `CommandData` for directory and hostname strings.
    interner: Arc<ThreadedRodeo>,
}
impl SearchIndex {
pub fn new() -> Self {
let nucleo_config = atuin_nucleo::Config::DEFAULT;
let nucleo = Nucleo::<String>::new(nucleo_config, Arc::new(|| {}), None, 1);
let injector = nucleo.injector();
Self {
commands: Arc::new(DashMap::new()),
nucleo: RwLock::new(nucleo),
injector,
frecency_map: RwLock::new(None),
interner: Arc::new(ThreadedRodeo::new()),
}
}
pub fn add_history(&self, history: &History) {
if is_known_agent(&history.author) {
return;
}
let command = history.command.as_str();
if let Some(mut entry) = self.commands.get_mut(command) {
entry.add_invocation(history, &self.interner);
} else {
let Some(data) = CommandData::new(history, &self.interner) else {
return; };
let command_arc: Arc<str> = command.into();
self.commands.insert(Arc::clone(&command_arc), data);
self.injector.push(command_arc.to_string(), |cmd, cols| {
cols[0] = cmd.clone().into();
});
}
}
pub fn add_histories(&self, histories: &[History]) {
for history in histories {
self.add_history(history);
}
}
pub fn command_count(&self) -> usize {
self.commands.len()
}
pub async fn nucleo_item_count(&self) -> u32 {
self.nucleo.read().await.snapshot().item_count()
}
#[instrument(skip_all, level = tracing::Level::TRACE, name = "index_search", fields(query = %query))]
pub async fn search(
&self,
query: &str,
filter_mode: IndexFilterMode,
_context: &QueryContext,
limit: u32,
) -> Vec<String> {
let mut nucleo = self.nucleo.write().await;
let frecency_map = self.frecency_map.read().await.clone();
let filter = self.build_filter(&filter_mode);
nucleo.set_filter(filter);
let scorer = Self::build_scorer(frecency_map);
nucleo.set_scorer(scorer);
nucleo.pattern.reparse(
0,
query,
pattern::CaseMatching::Smart,
pattern::Normalization::Smart,
false,
);
tracing::span!(Level::TRACE, "index_search_tick").in_scope(|| {
while nucleo.tick(10).running {}
});
let snapshot = nucleo.snapshot();
let matched_count = snapshot.matched_item_count().min(limit);
tracing::span!(Level::TRACE, "index_search_results").in_scope(|| {
snapshot
.matched_items(..matched_count)
.filter_map(|item| {
let cmd = item.data;
self.commands
.get(cmd.as_str())
.map(|data| data.most_recent_id())
})
.collect()
})
}
#[instrument(skip_all, level = tracing::Level::DEBUG, name = "rebuild_frecency")]
pub async fn rebuild_frecency(&self, search_settings: &Search) {
let now = OffsetDateTime::now_utc().unix_timestamp();
let mut frecency_map: HashMap<Arc<str>, u32> = HashMap::new();
let recency_mul = search_settings.recency_score_multiplier.max(0.0);
let frequency_mul = search_settings.frequency_score_multiplier.max(0.0);
let frecency_mul = search_settings.frecency_score_multiplier.max(0.0);
for entry in self.commands.iter() {
let frecency = entry
.global_frecency
.compute(now, recency_mul, frequency_mul);
let frecency = (frecency as f64 * frecency_mul).round() as u32;
frecency_map.insert(Arc::clone(entry.key()), frecency);
}
*self.frecency_map.write().await = Some(Arc::new(frecency_map));
}
fn build_filter(&self, mode: &IndexFilterMode) -> Option<atuin_nucleo::Filter<String>> {
if matches!(mode, IndexFilterMode::Global) {
return None;
}
let passing_commands: Arc<HashSet<String>> = {
let mut set = HashSet::new();
for entry in self.commands.iter() {
let passes = match mode {
IndexFilterMode::Global => unreachable!(),
IndexFilterMode::Directory(dir) => {
entry.has_invocation_in_dir(dir, &self.interner)
}
IndexFilterMode::Workspace(prefix) => {
entry.has_invocation_in_workspace(prefix, &self.interner)
}
IndexFilterMode::Host(hostname) => {
entry.has_invocation_on_host(hostname, &self.interner)
}
IndexFilterMode::Session(session) => entry.has_invocation_in_session(session),
};
if passes {
set.insert(entry.key().to_string());
}
}
Arc::new(set)
};
Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd)))
}
fn build_scorer(frecency_map: Option<FrecencyMap>) -> Option<atuin_nucleo::Scorer<String>> {
let map = frecency_map?;
Some(Arc::new(move |cmd: &String, fuzzy_score: u32| {
let frecency = map.get(cmd.as_str()).copied().unwrap_or(0);
fuzzy_score + frecency
}))
}
}
impl Default for SearchIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use time::macros::datetime;

    /// Builds an imported history entry with the given command, cwd and timestamp.
    fn make_history(command: &str, cwd: &str, timestamp: OffsetDateTime) -> History {
        History::import()
            .timestamp(timestamp)
            .command(command)
            .cwd(cwd)
            .build()
            .into()
    }

    #[test]
    fn frecency_data_compute() {
        let now = 1000000i64;
        let recent = FrecencyData {
            count: 5,
            last_used: now - 60,
        };
        assert!(recent.compute(now, 1.0, 1.0) > 100);
        let old = FrecencyData {
            count: 5,
            last_used: now - 86400 * 30,
        };
        assert!(old.compute(now, 1.0, 1.0) < recent.compute(now, 1.0, 1.0));
        let frequent_old = FrecencyData {
            count: 100,
            last_used: now - 86400 * 7,
        };
        assert!(frequent_old.compute(now, 1.0, 1.0) > 50);
    }

    #[test]
    fn frecency_data_compute_with_multipliers() {
        let now = 1000000i64;
        let data = FrecencyData {
            count: 5,
            last_used: now - 60,
        };
        let default_score = data.compute(now, 1.0, 1.0);
        let double_recency = data.compute(now, 2.0, 1.0);
        assert!(double_recency > default_score);
        let double_frequency = data.compute(now, 1.0, 2.0);
        assert!(double_frequency > default_score);
        let no_recency = data.compute(now, 0.0, 1.0);
        assert!(no_recency < default_score);
        let no_frequency = data.compute(now, 1.0, 0.0);
        assert!(no_frequency < default_score);
        let no_score = data.compute(now, 0.0, 0.0);
        assert_eq!(no_score, 0);
        let half_recency = data.compute(now, 0.5, 1.0);
        assert!(half_recency < default_score);
        assert!(half_recency > no_recency);
        let boost_recency = data.compute(now, 1.5, 1.0);
        assert!(boost_recency > default_score);
        assert!(boost_recency < double_recency);
    }

    #[test]
    fn command_data_add_invocation() {
        let interner = ThreadedRodeo::new();
        let (dir1, dir2) = if cfg!(windows) {
            ("C:\\Users\\User\\project", "C:\\Users\\User\\other")
        } else {
            ("/home/user/project", "/home/user/other")
        };
        let history1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC));
        let history2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC));
        let mut data = CommandData::new(&history1, &interner).unwrap();
        assert_eq!(data.global_frecency.count, 1);
        let id1 = data.most_recent_id();
        data.add_invocation(&history2, &interner);
        assert_eq!(data.global_frecency.count, 2);
        let id2 = data.most_recent_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn command_data_filters() {
        let interner = ThreadedRodeo::new();
        let (dir1, dir2) = if cfg!(windows) {
            ("C:\\Users\\User\\project", "C:\\Users\\User\\other")
        } else {
            ("/home/user/project", "/home/user/other")
        };
        let h1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC));
        let h2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC));
        let mut data = CommandData::new(&h1, &interner).unwrap();
        data.add_invocation(&h2, &interner);

        let (check1, check2, check3) = if cfg!(windows) {
            (
                with_trailing_slash("C:\\Users\\User\\project"),
                with_trailing_slash("C:\\Users\\User\\other"),
                with_trailing_slash("C:\\Users\\User\\missing"),
            )
        } else {
            (
                with_trailing_slash("/home/user/project"),
                with_trailing_slash("/home/user/other"),
                with_trailing_slash("/home/user/missing"),
            )
        };
        assert!(data.has_invocation_in_dir(&check1, &interner));
        assert!(data.has_invocation_in_dir(&check2, &interner));
        assert!(!data.has_invocation_in_dir(&check3, &interner));

        let (check1, check2, check3) = if cfg!(windows) {
            (
                with_trailing_slash("C:\\Users\\User"),
                with_trailing_slash("C:\\Users"),
                with_trailing_slash("C:\\Users\\User\\var"),
            )
        } else {
            (
                with_trailing_slash("/home/user"),
                with_trailing_slash("/home"),
                with_trailing_slash("/var"),
            )
        };
        assert!(data.has_invocation_in_workspace(&check1, &interner));
        assert!(data.has_invocation_in_workspace(&check2, &interner));
        assert!(!data.has_invocation_in_workspace(&check3, &interner));
    }

    #[tokio::test]
    async fn search_index_add_and_search() {
        let index = SearchIndex::new();
        let h1 = make_history(
            "git status",
            "/home/user/project",
            datetime!(2024-01-01 10:00 UTC),
        );
        let h2 = make_history(
            "git commit -m 'test'",
            "/home/user/project",
            datetime!(2024-01-01 10:05 UTC),
        );
        let h3 = make_history(
            "ls -la",
            "/home/user/other",
            datetime!(2024-01-01 10:10 UTC),
        );
        index.add_history(&h1);
        index.add_history(&h2);
        index.add_history(&h3);
        assert_eq!(index.command_count(), 3);

        let results = index
            .search("git", IndexFilterMode::Global, &QueryContext::default(), 10)
            .await;
        assert_eq!(results.len(), 2);

        let results = index
            .search(
                "",
                IndexFilterMode::Directory(with_trailing_slash("/home/user/project")),
                &QueryContext::default(),
                10,
            )
            .await;
        assert_eq!(results.len(), 2);
    }

    /// A command seen twice is stored and matched once, reports its newest invocation,
    /// and keeps the directories of both invocations.
    #[tokio::test]
    async fn search_index_merges_repeated_commands() {
        let index = SearchIndex::new();
        let older = make_history(
            "cargo build",
            "/home/user/project",
            datetime!(2024-01-01 10:00 UTC),
        );
        let newer = make_history(
            "cargo build",
            "/home/user/other",
            datetime!(2024-01-01 11:00 UTC),
        );
        index.add_history(&older);
        index.add_history(&newer);
        assert_eq!(index.command_count(), 1);

        let results = index
            .search("cargo", IndexFilterMode::Global, &QueryContext::default(), 10)
            .await;
        assert_eq!(results.len(), 1);
        // Compare as parsed bytes so the check does not depend on UUID text formatting.
        assert_eq!(
            parse_uuid_bytes(&results[0]),
            parse_uuid_bytes(&newer.id.0)
        );
        assert_eq!(index.nucleo_item_count().await, 1);

        for dir in ["/home/user/project", "/home/user/other"] {
            let results = index
                .search(
                    "",
                    IndexFilterMode::Directory(with_trailing_slash(dir)),
                    &QueryContext::default(),
                    10,
                )
                .await;
            assert_eq!(results.len(), 1, "expected a match in {dir}");
        }
    }
}