atuin-daemon 18.13.6

The daemon crate for Atuin
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
//! Search component.
//!
//! Provides fuzzy search over command history using the Nucleo search library
//! with frecency-based ranking and dynamic filtering.

use std::{pin::Pin, sync::Arc};

use atuin_client::database::Database;
use eyre::Result;
use tokio::sync::RwLock;
use tokio_stream::Stream;
use tonic::{Request, Response, Status, Streaming};
use tracing::{Instrument, Level, debug, info, instrument, span, trace};
use uuid::Uuid;

use crate::{
    daemon::{Component, DaemonHandle},
    events::DaemonEvent,
    search::{
        FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse,
        search_server::{Search as SearchSvc, SearchServer},
    },
};

/// Number of history rows requested per page when streaming the history
/// table out of the database into the search index.
const PAGE_SIZE: usize = 5_000;

/// Result limit passed to the index for every search query.
const RESULTS_LIMIT: u32 = 200;

/// How often to rebuild the frecency map (in seconds).
const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60;

/// Search component - provides fuzzy search over command history.
///
/// Responsibilities:
/// - owns a deduplicated search index with frecency ranking
/// - populates that index from the database when started
/// - keeps the index current as history events arrive
/// - exposes the Search gRPC service backed by the same index
pub struct SearchComponent {
    /// Index shared with the gRPC service; swapped wholesale on rebuilds.
    index: Arc<RwLock<SearchIndex>>,
    /// Daemon handle; `None` until the component is started.
    handle: RwLock<Option<DaemonHandle>>,
    /// Background task performing the initial history load.
    loader_handle: Option<tokio::task::JoinHandle<()>>,
    /// Background task periodically refreshing the frecency map.
    frecency_handle: Option<tokio::task::JoinHandle<()>>,
}

impl SearchComponent {
    /// Create a new search component with an empty index.
    ///
    /// Nothing is loaded until the component is started.
    pub fn new() -> Self {
        Self {
            index: Arc::new(RwLock::new(SearchIndex::new())),
            handle: RwLock::new(None),
            loader_handle: None,
            frecency_handle: None,
        }
    }

    /// Get the gRPC service for this component.
    ///
    /// The service shares this component's index, so index updates made by the
    /// component are visible to later searches.
    pub fn grpc_service(&self) -> SearchServer<SearchGrpcService> {
        SearchServer::new(SearchGrpcService {
            index: self.index.clone(),
        })
    }

    /// Rebuild the entire search index from the database.
    ///
    /// The replacement index is built off to the side and swapped in only once
    /// every page has loaded and its frecency map has been computed; until then
    /// searches keep using the existing index.
    ///
    /// # Errors
    ///
    /// Returns an error if the component has not been started, or if loading any
    /// page of history fails. In the latter case the existing index is kept
    /// instead of being replaced by a partially loaded one.
    async fn rebuild_index(&self) -> Result<()> {
        // Clone the handle out so the lock is not held across the (potentially
        // long) database scan below.
        let handle = self
            .handle
            .read()
            .await
            .clone()
            .ok_or_else(|| eyre::eyre!("component not initialized"))?;

        info!("Rebuilding search index from database");

        let new_index = SearchIndex::new();

        let db = handle.history_db().clone();
        let mut pager = db.all_paged(PAGE_SIZE, false, true);
        while let Some(histories) = pager
            .next()
            .await
            .map_err(|e| eyre::eyre!("failed to load history during rebuild: {}", e))?
        {
            info!(
                "Loading {} history entries into search index",
                histories.len()
            );
            new_index.add_histories(&histories);
        }

        // Mirror the initial load in `start`: compute frecency for the new index
        // before it becomes visible, instead of ranking without it until the next
        // periodic refresh.
        let settings = handle.settings().await;
        new_index.rebuild_frecency(&settings.search).await;

        info!(
            "Search index rebuild complete; {} unique commands",
            new_index.command_count()
        );

        // Replace the old index with the fully built new one.
        *self.index.write().await = new_index;
        Ok(())
    }
}

impl Default for SearchComponent {
    fn default() -> Self {
        Self::new()
    }
}

#[tonic::async_trait]
impl Component for SearchComponent {
    fn name(&self) -> &'static str {
        "search"
    }

    async fn start(&mut self, handle: DaemonHandle) -> Result<()> {
        *self.handle.write().await = Some(handle.clone());

        // Spawn background task to load history into index
        let index = self.index.clone();
        let db = handle.history_db().clone();
        let handle_for_loader = handle.clone();

        self.loader_handle = Some(tokio::spawn(async move {
            info!(
                "Loading history into search index; page size = {}",
                PAGE_SIZE
            );
            let mut pager = db.all_paged(PAGE_SIZE, false, true);
            loop {
                match pager.next().await {
                    Ok(Some(histories)) => {
                        info!(
                            "Loading {} history entries into search index",
                            histories.len()
                        );
                        index.read().await.add_histories(&histories);
                    }
                    Ok(None) => {
                        info!(
                            "Initial history load complete; {} unique commands indexed",
                            index.read().await.command_count()
                        );
                        // Build initial frecency map with current settings
                        let settings = handle_for_loader.settings().await;
                        index.read().await.rebuild_frecency(&settings.search).await;
                        info!("Initial frecency map built");
                        break;
                    }
                    Err(e) => {
                        tracing::error!("Failed to load history: {}", e);
                        break;
                    }
                }
            }
        }));

        // Spawn background task to periodically refresh frecency
        let index_for_frecency = self.index.clone();
        let handle_for_frecency = handle.clone();
        self.frecency_handle = Some(tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(
                FRECENCY_REFRESH_INTERVAL_SECS,
            ));
            loop {
                interval.tick().await;
                trace!("Refreshing frecency map");
                let settings = handle_for_frecency.settings().await;
                index_for_frecency
                    .read()
                    .await
                    .rebuild_frecency(&settings.search)
                    .await;
            }
        }));

        tracing::info!("search component started");
        Ok(())
    }

    async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> {
        match event {
            DaemonEvent::RecordsAdded(records) => {
                debug!(
                    count = records.len(),
                    "Processing added records for search index"
                );

                let handle_guard = self.handle.read().await;
                if let Some(handle) = handle_guard.as_ref() {
                    let histories: Vec<_> = handle
                        .history_db()
                        .query_history(
                            format!(
                                "select * from history where id in ({})",
                                records
                                    .iter()
                                    .map(|record| record.0.to_string())
                                    .collect::<Vec<_>>()
                                    .join(",")
                            )
                            .as_str(),
                        )
                        .await
                        .unwrap_or_default();

                    span!(Level::TRACE, "inject_records", count = histories.len())
                        .in_scope(async || {
                            self.index.read().await.add_histories(&histories);
                        })
                        .await;
                }
            }
            DaemonEvent::HistoryStarted(history) => {
                debug!(id = %history.id, command = %history.command, "History started (no index action)");
            }
            DaemonEvent::HistoryEnded(history) => {
                span!(Level::TRACE, "inject_history_ended")
                    .in_scope(async || {
                        self.index.read().await.add_history(history);
                    })
                    .await;
            }
            DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => {
                info!("History store pruned or rebuilt, rebuilding search index");
                if let Err(e) = self.rebuild_index().await {
                    tracing::error!("Failed to rebuild search index: {}", e);
                }
            }
            DaemonEvent::HistoryDeleted { ids } => {
                info!(
                    count = ids.len(),
                    "History deleted, rebuilding search index"
                );
                // For now, just rebuild the entire index. A more efficient implementation
                // would remove specific items from the index.
                if let Err(e) = self.rebuild_index().await {
                    tracing::error!("Failed to rebuild search index: {}", e);
                }
            }
            DaemonEvent::SettingsReloaded => {
                info!("Settings reloaded, rebuilding frecency map with new multipliers");
                let handle_guard = self.handle.read().await;
                if let Some(handle) = handle_guard.as_ref() {
                    let settings = handle.settings().await;
                    self.index
                        .read()
                        .await
                        .rebuild_frecency(&settings.search)
                        .await;
                }
            }
            // Events we don't care about
            DaemonEvent::SyncCompleted { .. }
            | DaemonEvent::SyncFailed { .. }
            | DaemonEvent::ForceSync
            | DaemonEvent::ShutdownRequested => {}
        }
        Ok(())
    }

    async fn stop(&mut self) -> Result<()> {
        if let Some(handle) = self.loader_handle.take() {
            handle.abort();
        }
        if let Some(handle) = self.frecency_handle.take() {
            handle.abort();
        }
        tracing::info!("search component stopped");
        Ok(())
    }
}

/// gRPC front-end for the search index.
///
/// Created by `SearchComponent::grpc_service`, which hands it a clone of the
/// component's shared index.
pub struct SearchGrpcService {
    /// Index queried for every incoming search request.
    index: Arc<RwLock<SearchIndex>>,
}

#[tonic::async_trait]
impl SearchSvc for SearchGrpcService {
    type SearchStream = Pin<Box<dyn Stream<Item = Result<SearchResponse, Status>> + Send>>;

    /// Serve a bidirectional search stream.
    ///
    /// Each incoming [`SearchRequest`] is answered, in order, with a
    /// [`SearchResponse`] that echoes its `query_id`. The response carries the
    /// ids of the matching history entries as raw UUID bytes; ids that do not
    /// parse as UUIDs are dropped.
    ///
    /// Requests are handled on a spawned task. That task stops when the client
    /// stops sending, when the response receiver is dropped, or when a
    /// transport error arrives; an error is forwarded to the client before
    /// stopping.
    #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")]
    async fn search(
        &self,
        request: Request<Streaming<SearchRequest>>,
    ) -> Result<Response<Self::SearchStream>, Status> {
        let mut in_stream = request.into_inner();
        let index = self.index.clone();

        // Create output channel
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<SearchResponse, Status>>(128);

        // Task that handles incoming requests and sends responses
        let task = async move {
            while let Some(req) = in_stream.message().await.transpose() {
                match req {
                    Ok(search_req) => {
                        let query = search_req.query;
                        let query_id = search_req.query_id;
                        let filter_mode: FilterMode = search_req
                            .filter_mode
                            .try_into()
                            .unwrap_or(FilterMode::Global);
                        let proto_context = search_req.context;

                        debug!(
                            "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}",
                            query,
                            query_id,
                            filter_mode.as_str_name(),
                            proto_context
                        );

                        // Convert proto FilterMode + context to IndexFilterMode
                        let index_filter = convert_filter_mode(filter_mode, &proto_context);

                        // Build QueryContext from proto context
                        let query_context = proto_context
                            .map(|ctx| QueryContext {
                                cwd: Some(with_trailing_slash(&ctx.cwd)),
                                git_root: ctx.git_root.map(|s| with_trailing_slash(&s)),
                                hostname: Some(ctx.hostname),
                                session_id: Some(ctx.session_id),
                            })
                            .unwrap_or_default();

                        // Perform the search. `instrument` (not `in_scope`) keeps the
                        // span entered for the whole awaited query, not just while the
                        // future is being constructed.
                        let history_ids = async {
                            let index = index.read().await;
                            index
                                .search(&query, index_filter, &query_context, RESULTS_LIMIT)
                                .await
                        }
                        .instrument(span!(
                            Level::TRACE,
                            "daemon_search_query",
                            %query,
                            query_id
                        ))
                        .await;

                        // Convert history IDs to bytes
                        let ids: Vec<Vec<u8>> = history_ids
                            .iter()
                            .filter_map(|id| {
                                Uuid::parse_str(id)
                                    .ok()
                                    .map(|uuid| uuid.as_bytes().to_vec())
                            })
                            .collect();

                        if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() {
                            break; // Client disconnected
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(e)).await;
                        break;
                    }
                }
            }
        };

        // Attach the task to the current `search_rpc` span so per-query spans nest
        // under the RPC that produced them.
        tokio::spawn(task.in_current_span());

        // Convert receiver to stream
        let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        Ok(Response::new(Box::pin(out_stream)))
    }
}

/// Convert proto FilterMode and context to IndexFilterMode.
///
/// Every non-`Global` mode needs the request context to know which directory,
/// workspace, host or session to filter on. When the context is absent, or
/// for any combination not listed below, this falls back to
/// `IndexFilterMode::Global`.
///
/// Special cases:
/// - `Workspace` without a git root degrades to a `Directory` filter on the
///   cwd.
/// - `SessionPreload` filters exactly like `Session`.
fn convert_filter_mode(
    mode: FilterMode,
    context: &Option<crate::search::SearchContext>,
) -> IndexFilterMode {
    match (mode, context) {
        (FilterMode::Directory, Some(ctx)) => {
            IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd))
        }
        (FilterMode::Workspace, Some(ctx)) => match ctx.git_root.as_deref() {
            Some(root) => IndexFilterMode::Workspace(with_trailing_slash(root)),
            // No git root: filter on the current directory instead.
            None => IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)),
        },
        (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()),
        (FilterMode::Session | FilterMode::SessionPreload, Some(ctx)) => {
            IndexFilterMode::Session(ctx.session_id.clone())
        }
        // Global, or a mode whose context was not supplied.
        _ => IndexFilterMode::Global,
    }
}

/// Return `s` ending in exactly the backslash it already had, or with one
/// appended if it did not end in `\`.
///
/// In this file it is applied to cwd and git-root paths before they are used
/// for filtering and query context.
#[cfg(windows)]
pub fn with_trailing_slash(s: &str) -> String {
    let mut owned = String::from(s);
    if !owned.ends_with('\\') {
        owned.push('\\');
    }
    owned
}

/// Return `s` ending in exactly the `/` it already had, or with one appended
/// if it did not end in `/`.
///
/// In this file it is applied to cwd and git-root paths before they are used
/// for filtering and query context.
#[cfg(not(windows))]
pub fn with_trailing_slash(s: &str) -> String {
    let mut owned = s.to_owned();
    if !owned.ends_with('/') {
        owned.push('/');
    }
    owned
}