Skip to main content

atuin_daemon/components/
search.rs

//! Search component.
//!
//! Provides fuzzy search over command history using the Nucleo search library
//! with frecency-based ranking and dynamic filtering.
5
6use std::{pin::Pin, sync::Arc};
7
8use atuin_client::database::Database;
9use eyre::Result;
10use tokio::sync::RwLock;
11use tokio_stream::Stream;
12use tonic::{Request, Response, Status, Streaming};
13use tracing::{Level, debug, info, instrument, span, trace};
14use uuid::Uuid;
15
16use crate::{
17    daemon::{Component, DaemonHandle},
18    events::DaemonEvent,
19    search::{
20        FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse,
21        search_server::{Search as SearchSvc, SearchServer},
22    },
23};
24
/// Number of history entries requested per page when reading the history database.
const PAGE_SIZE: usize = 5_000;
/// Result limit passed to the index for each search query.
const RESULTS_LIMIT: u32 = 200;
/// How often to rebuild the frecency map (in seconds).
const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60;
29
30/// Search component - provides fuzzy search over command history.
31///
32/// This component:
33/// - Maintains a deduplicated search index with frecency ranking
34/// - Loads history from the database on startup
35/// - Updates the index when history events occur
36/// - Provides the Search gRPC service
37pub struct SearchComponent {
38    index: Arc<RwLock<SearchIndex>>,
39    handle: tokio::sync::RwLock<Option<DaemonHandle>>,
40    loader_handle: Option<tokio::task::JoinHandle<()>>,
41    frecency_handle: Option<tokio::task::JoinHandle<()>>,
42}
43
44impl SearchComponent {
45    /// Create a new search component.
46    pub fn new() -> Self {
47        Self {
48            index: Arc::new(RwLock::new(SearchIndex::new())),
49            handle: tokio::sync::RwLock::new(None),
50            loader_handle: None,
51            frecency_handle: None,
52        }
53    }
54
55    /// Get the gRPC service for this component.
56    pub fn grpc_service(&self) -> SearchServer<SearchGrpcService> {
57        SearchServer::new(SearchGrpcService {
58            index: self.index.clone(),
59        })
60    }
61
62    /// Rebuild the entire search index from the database.
63    async fn rebuild_index(&self) -> Result<()> {
64        let handle_guard = self.handle.read().await;
65        let handle = handle_guard
66            .as_ref()
67            .ok_or_else(|| eyre::eyre!("component not initialized"))?;
68
69        info!("Rebuilding search index from database");
70
71        // Create a new index
72        let new_index = SearchIndex::new();
73
74        // Load all history into the new index
75        let db = handle.history_db().clone();
76        let mut pager = db.all_paged(PAGE_SIZE, false, true);
77        loop {
78            match pager.next().await {
79                Ok(Some(histories)) => {
80                    info!(
81                        "Loading {} history entries into search index",
82                        histories.len()
83                    );
84                    new_index.add_histories(&histories);
85                }
86                Ok(None) => break,
87                Err(e) => {
88                    tracing::error!("Failed to load history during rebuild: {}", e);
89                    break;
90                }
91            }
92        }
93
94        info!(
95            "Search index rebuild complete; {} unique commands",
96            new_index.command_count()
97        );
98
99        // Replace the old index with the new one
100        *self.index.write().await = new_index;
101        Ok(())
102    }
103}
104
105impl Default for SearchComponent {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111#[tonic::async_trait]
112impl Component for SearchComponent {
113    fn name(&self) -> &'static str {
114        "search"
115    }
116
117    async fn start(&mut self, handle: DaemonHandle) -> Result<()> {
118        *self.handle.write().await = Some(handle.clone());
119
120        // Spawn background task to load history into index
121        let index = self.index.clone();
122        let db = handle.history_db().clone();
123        let handle_for_loader = handle.clone();
124
125        self.loader_handle = Some(tokio::spawn(async move {
126            info!(
127                "Loading history into search index; page size = {}",
128                PAGE_SIZE
129            );
130            let mut pager = db.all_paged(PAGE_SIZE, false, true);
131            loop {
132                match pager.next().await {
133                    Ok(Some(histories)) => {
134                        info!(
135                            "Loading {} history entries into search index",
136                            histories.len()
137                        );
138                        index.read().await.add_histories(&histories);
139                    }
140                    Ok(None) => {
141                        info!(
142                            "Initial history load complete; {} unique commands indexed",
143                            index.read().await.command_count()
144                        );
145                        // Build initial frecency map with current settings
146                        let settings = handle_for_loader.settings().await;
147                        index.read().await.rebuild_frecency(&settings.search).await;
148                        info!("Initial frecency map built");
149                        break;
150                    }
151                    Err(e) => {
152                        tracing::error!("Failed to load history: {}", e);
153                        break;
154                    }
155                }
156            }
157        }));
158
159        // Spawn background task to periodically refresh frecency
160        let index_for_frecency = self.index.clone();
161        let handle_for_frecency = handle.clone();
162        self.frecency_handle = Some(tokio::spawn(async move {
163            let mut interval = tokio::time::interval(std::time::Duration::from_secs(
164                FRECENCY_REFRESH_INTERVAL_SECS,
165            ));
166            loop {
167                interval.tick().await;
168                trace!("Refreshing frecency map");
169                let settings = handle_for_frecency.settings().await;
170                index_for_frecency
171                    .read()
172                    .await
173                    .rebuild_frecency(&settings.search)
174                    .await;
175            }
176        }));
177
178        tracing::info!("search component started");
179        Ok(())
180    }
181
182    async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> {
183        match event {
184            DaemonEvent::RecordsAdded(records) => {
185                debug!(
186                    count = records.len(),
187                    "Processing added records for search index"
188                );
189
190                let handle_guard = self.handle.read().await;
191                if let Some(handle) = handle_guard.as_ref() {
192                    let histories: Vec<_> = handle
193                        .history_db()
194                        .query_history(
195                            format!(
196                                "select * from history where id in ({})",
197                                records
198                                    .iter()
199                                    .map(|record| record.0.to_string())
200                                    .collect::<Vec<_>>()
201                                    .join(",")
202                            )
203                            .as_str(),
204                        )
205                        .await
206                        .unwrap_or_default();
207
208                    span!(Level::TRACE, "inject_records", count = histories.len())
209                        .in_scope(async || {
210                            self.index.read().await.add_histories(&histories);
211                        })
212                        .await;
213                }
214            }
215            DaemonEvent::HistoryStarted(history) => {
216                debug!(id = %history.id, command = %history.command, "History started (no index action)");
217            }
218            DaemonEvent::HistoryEnded(history) => {
219                span!(Level::TRACE, "inject_history_ended")
220                    .in_scope(async || {
221                        self.index.read().await.add_history(history);
222                    })
223                    .await;
224            }
225            DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => {
226                info!("History store pruned or rebuilt, rebuilding search index");
227                if let Err(e) = self.rebuild_index().await {
228                    tracing::error!("Failed to rebuild search index: {}", e);
229                }
230            }
231            DaemonEvent::HistoryDeleted { ids } => {
232                info!(
233                    count = ids.len(),
234                    "History deleted, rebuilding search index"
235                );
236                // For now, just rebuild the entire index. A more efficient implementation
237                // would remove specific items from the index.
238                if let Err(e) = self.rebuild_index().await {
239                    tracing::error!("Failed to rebuild search index: {}", e);
240                }
241            }
242            DaemonEvent::SettingsReloaded => {
243                info!("Settings reloaded, rebuilding frecency map with new multipliers");
244                let handle_guard = self.handle.read().await;
245                if let Some(handle) = handle_guard.as_ref() {
246                    let settings = handle.settings().await;
247                    self.index
248                        .read()
249                        .await
250                        .rebuild_frecency(&settings.search)
251                        .await;
252                }
253            }
254            // Events we don't care about
255            DaemonEvent::SyncCompleted { .. }
256            | DaemonEvent::SyncFailed { .. }
257            | DaemonEvent::ForceSync
258            | DaemonEvent::ShutdownRequested => {}
259        }
260        Ok(())
261    }
262
263    async fn stop(&mut self) -> Result<()> {
264        if let Some(handle) = self.loader_handle.take() {
265            handle.abort();
266        }
267        if let Some(handle) = self.frecency_handle.take() {
268            handle.abort();
269        }
270        tracing::info!("search component stopped");
271        Ok(())
272    }
273}
274
275/// The gRPC service implementation.
276pub struct SearchGrpcService {
277    index: Arc<RwLock<SearchIndex>>,
278}
279
280#[tonic::async_trait]
281impl SearchSvc for SearchGrpcService {
282    type SearchStream = Pin<Box<dyn Stream<Item = Result<SearchResponse, Status>> + Send>>;
283
284    #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")]
285    async fn search(
286        &self,
287        request: Request<Streaming<SearchRequest>>,
288    ) -> Result<Response<Self::SearchStream>, Status> {
289        let mut in_stream = request.into_inner();
290        let index = self.index.clone();
291
292        // Create output channel
293        let (tx, rx) = tokio::sync::mpsc::channel::<Result<SearchResponse, Status>>(128);
294
295        // Spawn task to handle incoming requests and send responses
296        tokio::spawn(async move {
297            while let Some(req) = in_stream.message().await.transpose() {
298                match req {
299                    Ok(search_req) => {
300                        let query = search_req.query;
301                        let query_id = search_req.query_id;
302                        let filter_mode: FilterMode = search_req
303                            .filter_mode
304                            .try_into()
305                            .unwrap_or(FilterMode::Global);
306                        let proto_context = search_req.context;
307
308                        debug!(
309                            "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}",
310                            query,
311                            query_id,
312                            filter_mode.as_str_name(),
313                            proto_context
314                        );
315
316                        // Convert proto FilterMode + context to IndexFilterMode
317                        let index_filter = convert_filter_mode(filter_mode, &proto_context);
318
319                        // Build QueryContext from proto context
320                        let query_context = proto_context
321                            .map(|ctx| QueryContext {
322                                cwd: Some(with_trailing_slash(&ctx.cwd)),
323                                git_root: ctx.git_root.map(|s| with_trailing_slash(&s)),
324                                hostname: Some(ctx.hostname),
325                                session_id: Some(ctx.session_id),
326                            })
327                            .unwrap_or_default();
328
329                        // Perform the search
330                        let history_ids =
331                            span!(Level::TRACE, "daemon_search_query", %query, query_id)
332                                .in_scope(|| async {
333                                    let index = index.read().await;
334                                    index
335                                        .search(&query, index_filter, &query_context, RESULTS_LIMIT)
336                                        .await
337                                })
338                                .await;
339
340                        // Convert history IDs to bytes
341                        let ids: Vec<Vec<u8>> = history_ids
342                            .iter()
343                            .filter_map(|id| {
344                                Uuid::parse_str(id)
345                                    .ok()
346                                    .map(|uuid| uuid.as_bytes().to_vec())
347                            })
348                            .collect();
349
350                        if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() {
351                            break; // Client disconnected
352                        }
353                    }
354                    Err(e) => {
355                        let _ = tx.send(Err(e)).await;
356                        break;
357                    }
358                }
359            }
360        });
361
362        // Convert receiver to stream
363        let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
364        Ok(Response::new(Box::pin(out_stream)))
365    }
366}
367
368/// Convert proto FilterMode and context to IndexFilterMode.
369fn convert_filter_mode(
370    mode: FilterMode,
371    context: &Option<crate::search::SearchContext>,
372) -> IndexFilterMode {
373    match (mode, context) {
374        (FilterMode::Global, _) => IndexFilterMode::Global,
375        (FilterMode::Directory, Some(ctx)) => {
376            IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd))
377        }
378        (FilterMode::Workspace, Some(ctx)) => {
379            if let Some(ref git_root) = ctx.git_root {
380                IndexFilterMode::Workspace(with_trailing_slash(git_root))
381            } else {
382                // Fall back to directory if no git root
383                IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd))
384            }
385        }
386        (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()),
387        (FilterMode::Session, Some(ctx)) => IndexFilterMode::Session(ctx.session_id.clone()),
388        (FilterMode::SessionPreload, Some(ctx)) => {
389            // SessionPreload is similar to Session - filter by session
390            IndexFilterMode::Session(ctx.session_id.clone())
391        }
392        // If no context provided, fall back to global
393        _ => IndexFilterMode::Global,
394    }
395}
396
/// Return `s` with exactly one platform path separator guaranteed at the end.
///
/// The separator is `\` on Windows and `/` everywhere else. A string that
/// already ends with it is returned unchanged (as an owned copy); otherwise
/// the separator is appended. An empty input yields just the separator.
pub fn with_trailing_slash(s: &str) -> String {
    let separator = std::path::MAIN_SEPARATOR;
    let mut out = String::with_capacity(s.len() + 1);
    out.push_str(s);
    if !s.ends_with(separator) {
        out.push(separator);
    }
    out
}