use std::{pin::Pin, sync::Arc};

use atuin_client::database::Database;
use eyre::Result;
use tokio::sync::RwLock;
use tokio_stream::Stream;
use tonic::{Request, Response, Status, Streaming};
use tracing::{Instrument, Level, debug, info, instrument, span, trace};
use uuid::Uuid;

use crate::{
    daemon::{Component, DaemonHandle},
    events::DaemonEvent,
    search::{
        FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse,
        search_server::{Search as SearchSvc, SearchServer},
    },
};
24
/// Number of history rows requested per page when (re)loading the index from the database.
const PAGE_SIZE: usize = 5_000;
/// Result limit passed to `SearchIndex::search` for every query.
const RESULTS_LIMIT: u32 = 200;
/// Period, in seconds, of the background task that rebuilds the frecency map.
const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60;
29
30pub struct SearchComponent {
38 index: Arc<RwLock<SearchIndex>>,
39 handle: tokio::sync::RwLock<Option<DaemonHandle>>,
40 loader_handle: Option<tokio::task::JoinHandle<()>>,
41 frecency_handle: Option<tokio::task::JoinHandle<()>>,
42}
43
44impl SearchComponent {
45 pub fn new() -> Self {
47 Self {
48 index: Arc::new(RwLock::new(SearchIndex::new())),
49 handle: tokio::sync::RwLock::new(None),
50 loader_handle: None,
51 frecency_handle: None,
52 }
53 }
54
55 pub fn grpc_service(&self) -> SearchServer<SearchGrpcService> {
57 SearchServer::new(SearchGrpcService {
58 index: self.index.clone(),
59 })
60 }
61
62 async fn rebuild_index(&self) -> Result<()> {
64 let handle_guard = self.handle.read().await;
65 let handle = handle_guard
66 .as_ref()
67 .ok_or_else(|| eyre::eyre!("component not initialized"))?;
68
69 info!("Rebuilding search index from database");
70
71 let new_index = SearchIndex::new();
73
74 let db = handle.history_db().clone();
76 let mut pager = db.all_paged(PAGE_SIZE, false, true);
77 loop {
78 match pager.next().await {
79 Ok(Some(histories)) => {
80 info!(
81 "Loading {} history entries into search index",
82 histories.len()
83 );
84 new_index.add_histories(&histories);
85 }
86 Ok(None) => break,
87 Err(e) => {
88 tracing::error!("Failed to load history during rebuild: {}", e);
89 break;
90 }
91 }
92 }
93
94 info!(
95 "Search index rebuild complete; {} unique commands",
96 new_index.command_count()
97 );
98
99 *self.index.write().await = new_index;
101 Ok(())
102 }
103}
104
105impl Default for SearchComponent {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[tonic::async_trait]
112impl Component for SearchComponent {
113 fn name(&self) -> &'static str {
114 "search"
115 }
116
117 async fn start(&mut self, handle: DaemonHandle) -> Result<()> {
118 *self.handle.write().await = Some(handle.clone());
119
120 let index = self.index.clone();
122 let db = handle.history_db().clone();
123 let handle_for_loader = handle.clone();
124
125 self.loader_handle = Some(tokio::spawn(async move {
126 info!(
127 "Loading history into search index; page size = {}",
128 PAGE_SIZE
129 );
130 let mut pager = db.all_paged(PAGE_SIZE, false, true);
131 loop {
132 match pager.next().await {
133 Ok(Some(histories)) => {
134 info!(
135 "Loading {} history entries into search index",
136 histories.len()
137 );
138 index.read().await.add_histories(&histories);
139 }
140 Ok(None) => {
141 info!(
142 "Initial history load complete; {} unique commands indexed",
143 index.read().await.command_count()
144 );
145 let settings = handle_for_loader.settings().await;
147 index.read().await.rebuild_frecency(&settings.search).await;
148 info!("Initial frecency map built");
149 break;
150 }
151 Err(e) => {
152 tracing::error!("Failed to load history: {}", e);
153 break;
154 }
155 }
156 }
157 }));
158
159 let index_for_frecency = self.index.clone();
161 let handle_for_frecency = handle.clone();
162 self.frecency_handle = Some(tokio::spawn(async move {
163 let mut interval = tokio::time::interval(std::time::Duration::from_secs(
164 FRECENCY_REFRESH_INTERVAL_SECS,
165 ));
166 loop {
167 interval.tick().await;
168 trace!("Refreshing frecency map");
169 let settings = handle_for_frecency.settings().await;
170 index_for_frecency
171 .read()
172 .await
173 .rebuild_frecency(&settings.search)
174 .await;
175 }
176 }));
177
178 tracing::info!("search component started");
179 Ok(())
180 }
181
182 async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> {
183 match event {
184 DaemonEvent::RecordsAdded(records) => {
185 debug!(
186 count = records.len(),
187 "Processing added records for search index"
188 );
189
190 let handle_guard = self.handle.read().await;
191 if let Some(handle) = handle_guard.as_ref() {
192 let histories: Vec<_> = handle
193 .history_db()
194 .query_history(
195 format!(
196 "select * from history where id in ({})",
197 records
198 .iter()
199 .map(|record| record.0.to_string())
200 .collect::<Vec<_>>()
201 .join(",")
202 )
203 .as_str(),
204 )
205 .await
206 .unwrap_or_default();
207
208 span!(Level::TRACE, "inject_records", count = histories.len())
209 .in_scope(async || {
210 self.index.read().await.add_histories(&histories);
211 })
212 .await;
213 }
214 }
215 DaemonEvent::HistoryStarted(history) => {
216 debug!(id = %history.id, command = %history.command, "History started (no index action)");
217 }
218 DaemonEvent::HistoryEnded(history) => {
219 span!(Level::TRACE, "inject_history_ended")
220 .in_scope(async || {
221 self.index.read().await.add_history(history);
222 })
223 .await;
224 }
225 DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => {
226 info!("History store pruned or rebuilt, rebuilding search index");
227 if let Err(e) = self.rebuild_index().await {
228 tracing::error!("Failed to rebuild search index: {}", e);
229 }
230 }
231 DaemonEvent::HistoryDeleted { ids } => {
232 info!(
233 count = ids.len(),
234 "History deleted, rebuilding search index"
235 );
236 if let Err(e) = self.rebuild_index().await {
239 tracing::error!("Failed to rebuild search index: {}", e);
240 }
241 }
242 DaemonEvent::SettingsReloaded => {
243 info!("Settings reloaded, rebuilding frecency map with new multipliers");
244 let handle_guard = self.handle.read().await;
245 if let Some(handle) = handle_guard.as_ref() {
246 let settings = handle.settings().await;
247 self.index
248 .read()
249 .await
250 .rebuild_frecency(&settings.search)
251 .await;
252 }
253 }
254 DaemonEvent::SyncCompleted { .. }
256 | DaemonEvent::SyncFailed { .. }
257 | DaemonEvent::ForceSync
258 | DaemonEvent::ShutdownRequested => {}
259 }
260 Ok(())
261 }
262
263 async fn stop(&mut self) -> Result<()> {
264 if let Some(handle) = self.loader_handle.take() {
265 handle.abort();
266 }
267 if let Some(handle) = self.frecency_handle.take() {
268 handle.abort();
269 }
270 tracing::info!("search component stopped");
271 Ok(())
272 }
273}
274
275pub struct SearchGrpcService {
277 index: Arc<RwLock<SearchIndex>>,
278}
279
280#[tonic::async_trait]
281impl SearchSvc for SearchGrpcService {
282 type SearchStream = Pin<Box<dyn Stream<Item = Result<SearchResponse, Status>> + Send>>;
283
284 #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")]
285 async fn search(
286 &self,
287 request: Request<Streaming<SearchRequest>>,
288 ) -> Result<Response<Self::SearchStream>, Status> {
289 let mut in_stream = request.into_inner();
290 let index = self.index.clone();
291
292 let (tx, rx) = tokio::sync::mpsc::channel::<Result<SearchResponse, Status>>(128);
294
295 tokio::spawn(async move {
297 while let Some(req) = in_stream.message().await.transpose() {
298 match req {
299 Ok(search_req) => {
300 let query = search_req.query;
301 let query_id = search_req.query_id;
302 let filter_mode: FilterMode = search_req
303 .filter_mode
304 .try_into()
305 .unwrap_or(FilterMode::Global);
306 let proto_context = search_req.context;
307
308 debug!(
309 "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}",
310 query,
311 query_id,
312 filter_mode.as_str_name(),
313 proto_context
314 );
315
316 let index_filter = convert_filter_mode(filter_mode, &proto_context);
318
319 let query_context = proto_context
321 .map(|ctx| QueryContext {
322 cwd: Some(with_trailing_slash(&ctx.cwd)),
323 git_root: ctx.git_root.map(|s| with_trailing_slash(&s)),
324 hostname: Some(ctx.hostname),
325 session_id: Some(ctx.session_id),
326 })
327 .unwrap_or_default();
328
329 let history_ids =
331 span!(Level::TRACE, "daemon_search_query", %query, query_id)
332 .in_scope(|| async {
333 let index = index.read().await;
334 index
335 .search(&query, index_filter, &query_context, RESULTS_LIMIT)
336 .await
337 })
338 .await;
339
340 let ids: Vec<Vec<u8>> = history_ids
342 .iter()
343 .filter_map(|id| {
344 Uuid::parse_str(id)
345 .ok()
346 .map(|uuid| uuid.as_bytes().to_vec())
347 })
348 .collect();
349
350 if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() {
351 break; }
353 }
354 Err(e) => {
355 let _ = tx.send(Err(e)).await;
356 break;
357 }
358 }
359 }
360 });
361
362 let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
364 Ok(Response::new(Box::pin(out_stream)))
365 }
366}
367
368fn convert_filter_mode(
370 mode: FilterMode,
371 context: &Option<crate::search::SearchContext>,
372) -> IndexFilterMode {
373 match (mode, context) {
374 (FilterMode::Global, _) => IndexFilterMode::Global,
375 (FilterMode::Directory, Some(ctx)) => {
376 IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd))
377 }
378 (FilterMode::Workspace, Some(ctx)) => {
379 if let Some(ref git_root) = ctx.git_root {
380 IndexFilterMode::Workspace(with_trailing_slash(git_root))
381 } else {
382 IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd))
384 }
385 }
386 (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()),
387 (FilterMode::Session, Some(ctx)) => IndexFilterMode::Session(ctx.session_id.clone()),
388 (FilterMode::SessionPreload, Some(ctx)) => {
389 IndexFilterMode::Session(ctx.session_id.clone())
391 }
392 _ => IndexFilterMode::Global,
394 }
395}
396
/// Returns `s` as an owned string guaranteed to end with a backslash (Windows separator).
#[cfg(windows)]
pub fn with_trailing_slash(s: &str) -> String {
    let mut owned = s.to_owned();
    if !owned.ends_with('\\') {
        owned.push('\\');
    }
    owned
}
405
/// Returns `s` as an owned string guaranteed to end with a forward slash.
#[cfg(not(windows))]
pub fn with_trailing_slash(s: &str) -> String {
    let needs_slash = !s.ends_with('/');
    let mut out = String::with_capacity(s.len() + usize::from(needs_slash));
    out.push_str(s);
    if needs_slash {
        out.push('/');
    }
    out
}