1use std::{fmt, io};
2
3use distant_net::common::Request;
4use log::*;
5use tokio::sync::mpsc;
6use tokio::task::JoinHandle;
7
8use crate::client::{DistantChannel, DistantChannelExt};
9use crate::constants::CLIENT_SEARCHER_CAPACITY;
10use crate::protocol::{self, SearchId, SearchQuery, SearchQueryMatch};
11
/// Handle to a search started with [`Searcher::search`].
///
/// Matches produced by the search are retrieved one at a time with [`Searcher::next`].
pub struct Searcher {
    /// Channel the search request was sent on; also used to send the cancellation request.
    channel: DistantChannel,

    /// Identifier reported by the server in its `SearchStarted` response.
    id: SearchId,

    /// Query the search was started with.
    query: SearchQuery,

    /// Background task that forwards matches received from the server into `rx`.
    task: JoinHandle<()>,

    /// Receiving side of the match channel that `task` feeds.
    rx: mpsc::Receiver<SearchQueryMatch>,
}
20
21impl fmt::Debug for Searcher {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 f.debug_struct("Searcher")
24 .field("id", &self.id)
25 .field("query", &self.query)
26 .finish()
27 }
28}
29
30impl Searcher {
31 pub async fn search(mut channel: DistantChannel, query: SearchQuery) -> io::Result<Self> {
33 trace!("Searching using {query:?}",);
34
35 let mut mailbox = channel
37 .mail(Request::new(protocol::Msg::Single(
38 protocol::Request::Search {
39 query: query.clone(),
40 },
41 )))
42 .await?;
43
44 let (tx, rx) = mpsc::channel(CLIENT_SEARCHER_CAPACITY);
45
46 let mut queue: Vec<SearchQueryMatch> = Vec::new();
48 let mut search_id = None;
49 while let Some(res) = mailbox.next().await {
50 for data in res.payload.into_vec() {
51 match data {
52 protocol::Response::SearchResults { matches, .. } => {
54 queue.extend(matches);
55 }
56
57 protocol::Response::SearchStarted { id } => {
59 trace!("[Query {id}] Searcher has started");
60 search_id = Some(id);
61 }
62
63 protocol::Response::Error(x) => return Err(io::Error::from(x)),
65
66 x => {
68 return Err(io::Error::new(
69 io::ErrorKind::Other,
70 format!("Unexpected response: {x:?}"),
71 ))
72 }
73 }
74 }
75
76 if search_id.is_some() {
80 break;
81 }
82 }
83
84 let search_id = match search_id {
85 Some(id) => {
87 trace!("[Query {id}] Forwarding {} queued matches", queue.len());
88 for r#match in queue.drain(..) {
89 if tx.send(r#match).await.is_err() {
90 return Err(io::Error::new(
91 io::ErrorKind::Other,
92 format!("[Query {id}] Queue search match dropped"),
93 ));
94 }
95 }
96 id
97 }
98
99 None => {
102 return Err(io::Error::new(
103 io::ErrorKind::Other,
104 "Search query missing started confirmation",
105 ))
106 }
107 };
108
109 let task = tokio::spawn({
112 async move {
113 while let Some(res) = mailbox.next().await {
114 let mut done = false;
115
116 for data in res.payload.into_vec() {
117 match data {
118 protocol::Response::SearchResults { matches, .. } => {
119 if tx.is_closed() {
122 break;
123 }
124
125 for r#match in matches {
127 if let Err(x) = tx.send(r#match).await {
128 error!(
129 "[Query {search_id}] Searcher failed to send match {:?}",
130 x.0
131 );
132 break;
133 }
134 }
135 }
136
137 protocol::Response::SearchDone { .. } => {
139 trace!("[Query {search_id}] Searcher has finished");
140 done = true;
141 break;
142 }
143
144 _ => continue,
145 }
146 }
147
148 if done {
149 break;
150 }
151 }
152 }
153 });
154
155 Ok(Self {
156 id: search_id,
157 query,
158 channel,
159 task,
160 rx,
161 })
162 }
163
164 pub fn query(&self) -> &SearchQuery {
166 &self.query
167 }
168
169 pub fn is_active(&self) -> bool {
171 !self.task.is_finished()
172 }
173
174 pub async fn next(&mut self) -> Option<SearchQueryMatch> {
176 self.rx.recv().await
177 }
178
179 pub async fn cancel(&mut self) -> io::Result<()> {
181 trace!("[Query {}] Cancelling search", self.id);
182 self.channel.cancel_search(self.id).await?;
183
184 self.task.abort();
186
187 Ok(())
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::path::PathBuf;
    use std::sync::Arc;

    use distant_net::common::{FramedTransport, InmemoryTransport, Response};
    use distant_net::Client;
    use test_log::test;
    use tokio::sync::Mutex;

    use super::*;
    use crate::protocol::{
        SearchQueryCondition, SearchQueryMatchData, SearchQueryOptions, SearchQueryPathMatch,
        SearchQuerySubmatch, SearchQueryTarget,
    };
    use crate::DistantClient;

    /// Creates a client connected to an in-memory transport; the returned transport acts as the
    /// server side of the connection.
    fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
        let (t1, t2) = FramedTransport::pair(100);
        (t1, Client::spawn_inmemory(t2, Default::default()))
    }

    #[test(tokio::test)]
    async fn searcher_should_have_query_reflect_ongoing_query() {
        let (mut transport, session) = make_session();
        let test_query = SearchQuery {
            paths: vec![PathBuf::from("/some/test/path")],
            target: SearchQueryTarget::Path,
            condition: SearchQueryCondition::Regex {
                value: String::from("."),
            },
            options: SearchQueryOptions::default(),
        };

        // Start the search in the background since it waits for the started confirmation
        let search_task = {
            let test_query = test_query.clone();
            tokio::spawn(async move { Searcher::search(session.clone_channel(), test_query).await })
        };

        // Wait for the search request to arrive on the "server" side
        let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

        // Confirm that the search has started
        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Response::SearchStarted { id: rand::random() },
            ))
            .await
            .unwrap();

        // The searcher should report the query it was started with
        let searcher = search_task.await.unwrap().unwrap();
        assert_eq!(searcher.query(), &test_query);
    }

    #[test(tokio::test)]
    async fn searcher_should_support_getting_next_match() {
        let (mut transport, session) = make_session();
        let test_query = SearchQuery {
            paths: vec![PathBuf::from("/some/test/path")],
            target: SearchQueryTarget::Path,
            condition: SearchQueryCondition::Regex {
                value: String::from("."),
            },
            options: SearchQueryOptions::default(),
        };

        // Start the search in the background since it waits for the started confirmation
        let search_task =
            tokio::spawn(
                async move { Searcher::search(session.clone_channel(), test_query).await },
            );

        // Wait for the search request to arrive on the "server" side
        let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

        // Confirm that the search has started
        let id = rand::random::<SearchId>();
        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Response::SearchStarted { id },
            ))
            .await
            .unwrap();

        let mut searcher = search_task.await.unwrap().unwrap();

        // Send a batch containing two result responses (three matches in total)
        transport
            .write_frame_for(&Response::new(
                req.id,
                vec![
                    protocol::Response::SearchResults {
                        id,
                        matches: vec![
                            SearchQueryMatch::Path(SearchQueryPathMatch {
                                path: PathBuf::from("/some/path/1"),
                                submatches: vec![SearchQuerySubmatch {
                                    r#match: SearchQueryMatchData::Text("test match".to_string()),
                                    start: 3,
                                    end: 7,
                                }],
                            }),
                            SearchQueryMatch::Path(SearchQueryPathMatch {
                                path: PathBuf::from("/some/path/2"),
                                submatches: vec![SearchQuerySubmatch {
                                    r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                                    start: 88,
                                    end: 99,
                                }],
                            }),
                        ],
                    },
                    protocol::Response::SearchResults {
                        id,
                        matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                            path: PathBuf::from("/some/path/3"),
                            submatches: vec![SearchQuerySubmatch {
                                r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                                start: 5,
                                end: 9,
                            }],
                        })],
                    },
                ],
            ))
            .await
            .unwrap();

        // Matches should come out one at a time, in the order they were sent
        let m = searcher.next().await.expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/1"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match".to_string()),
                    start: 3,
                    end: 7,
                }],
            })
        );

        let m = searcher.next().await.expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/2"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                    start: 88,
                    end: 99,
                }],
            }),
        );

        let m = searcher.next().await.expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/3"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                    start: 5,
                    end: 9,
                }],
            })
        );
    }

    #[test(tokio::test)]
    async fn searcher_should_distinguish_match_events_and_only_receive_matches_for_itself() {
        let (mut transport, session) = make_session();

        let test_query = SearchQuery {
            paths: vec![PathBuf::from("/some/test/path")],
            target: SearchQueryTarget::Path,
            condition: SearchQueryCondition::Regex {
                value: String::from("."),
            },
            options: SearchQueryOptions::default(),
        };

        // Start the search in the background since it waits for the started confirmation
        let search_task =
            tokio::spawn(
                async move { Searcher::search(session.clone_channel(), test_query).await },
            );

        // Wait for the search request to arrive on the "server" side
        let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

        // Confirm that the search has started
        let id = rand::random();
        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Response::SearchStarted { id },
            ))
            .await
            .unwrap();

        let mut searcher = search_task.await.unwrap().unwrap();

        // Send a match addressed to this searcher's request
        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Response::SearchResults {
                    id,
                    matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/1"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match".to_string()),
                            start: 3,
                            end: 7,
                        }],
                    })],
                },
            ))
            .await
            .unwrap();

        // Send a match addressed to a different request id, which must not reach this searcher
        transport
            .write_frame_for(&Response::new(
                req.id.clone() + "1",
                protocol::Response::SearchResults {
                    id,
                    matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/2"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                            start: 88,
                            end: 99,
                        }],
                    })],
                },
            ))
            .await
            .unwrap();

        // Send another match addressed to this searcher's request
        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Response::SearchResults {
                    id,
                    matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/3"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                            start: 5,
                            end: 9,
                        }],
                    })],
                },
            ))
            .await
            .unwrap();

        // Only the matches for this searcher's request should be received
        let m = searcher.next().await.expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/1"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match".to_string()),
                    start: 3,
                    end: 7,
                }],
            })
        );

        let m = searcher.next().await.expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/3"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                    start: 5,
                    end: 9,
                }],
            })
        );
    }

    #[test(tokio::test)]
    async fn searcher_should_stop_receiving_events_if_cancelled() {
        let (mut transport, session) = make_session();

        let test_query = SearchQuery {
            paths: vec![PathBuf::from("/some/test/path")],
            target: SearchQueryTarget::Path,
            condition: SearchQueryCondition::Regex {
                value: String::from("."),
            },
            options: SearchQueryOptions::default(),
        };

        // Start the search in the background since it waits for the started confirmation
        let search_task =
            tokio::spawn(
                async move { Searcher::search(session.clone_channel(), test_query).await },
            );

        // Wait for the search request to arrive on the "server" side
        let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

        // Confirm that the search has started
        let id = rand::random::<SearchId>();
        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Response::SearchStarted { id },
            ))
            .await
            .unwrap();

        // Send two matches before cancelling
        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Response::SearchResults {
                    id,
                    matches: vec![
                        SearchQueryMatch::Path(SearchQueryPathMatch {
                            path: PathBuf::from("/some/path/1"),
                            submatches: vec![SearchQuerySubmatch {
                                r#match: SearchQueryMatchData::Text("test match".to_string()),
                                start: 3,
                                end: 7,
                            }],
                        }),
                        SearchQueryMatch::Path(SearchQueryPathMatch {
                            path: PathBuf::from("/some/path/2"),
                            submatches: vec![SearchQuerySubmatch {
                                r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                                start: 88,
                                end: 99,
                            }],
                        }),
                    ],
                },
            ))
            .await
            .unwrap();

        // Give the forwarding task time to push both matches into the searcher's channel
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Shared so that cancellation can run concurrently with this test acting as the server
        let searcher = Arc::new(Mutex::new(search_task.await.unwrap().unwrap()));

        let m = searcher
            .lock()
            .await
            .next()
            .await
            .expect("Searcher closed unexpectedly");
        assert_eq!(
            m,
            SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/1"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match".to_string()),
                    start: 3,
                    end: 7,
                }],
            }),
        );

        // Cancel in the background since it waits for the server's reply
        let searcher_2 = Arc::clone(&searcher);
        let cancel_task = tokio::spawn(async move { searcher_2.lock().await.cancel().await });

        // Answer the cancel request (kept separate from the search request it must not shadow)
        let cancel_req: Request<protocol::Request> =
            transport.read_frame_as().await.unwrap().unwrap();

        transport
            .write_frame_for(&Response::new(cancel_req.id, protocol::Response::Ok))
            .await
            .unwrap();

        cancel_task.await.unwrap().unwrap();

        // Send another match for the search itself; the cancelled searcher must not forward it
        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Response::SearchResults {
                    id,
                    matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/3"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                            start: 5,
                            end: 9,
                        }],
                    })],
                },
            ))
            .await
            .unwrap();

        // The match forwarded before cancellation is still available, then the stream ends
        assert_eq!(
            searcher.lock().await.next().await,
            Some(SearchQueryMatch::Path(SearchQueryPathMatch {
                path: PathBuf::from("/some/path/2"),
                submatches: vec![SearchQuerySubmatch {
                    r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                    start: 88,
                    end: 99,
                }],
            }))
        );
        assert_eq!(searcher.lock().await.next().await, None);
    }
}