1use terraphim_types::{Document, SearchQuery};
2
3pub trait HaystackProvider {
4 type Error: std::fmt::Display + std::fmt::Debug + Send + Sync + 'static;
5
6 #[allow(async_fn_in_trait)]
7 async fn search(&self, query: &SearchQuery) -> Result<Vec<Document>, Self::Error>;
8}
9
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use terraphim_types::NormalizedTermValue;

    /// In-memory provider that returns the same fixed documents for every query.
    struct TestProvider {
        /// Documents handed back (cloned) by every call to `search`.
        documents: Vec<Document>,
    }

    impl TestProvider {
        /// Creates a provider that returns `documents` for every query.
        fn with_docs(documents: Vec<Document>) -> Self {
            Self { documents }
        }

        /// Creates a provider that returns no documents for any query.
        fn empty() -> Self {
            Self::with_docs(Vec::new())
        }
    }

    /// Error type shared by the test providers; wraps a human-readable message.
    #[derive(Debug)]
    struct TestProviderError(String);

    impl std::fmt::Display for TestProviderError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestProviderError: {}", self.0)
        }
    }

    impl HaystackProvider for TestProvider {
        type Error = TestProviderError;

        async fn search(&self, _query: &SearchQuery) -> Result<Vec<Document>, Self::Error> {
            // The query is ignored on purpose: these tests exercise the trait
            // plumbing, not relevance matching.
            Ok(self.documents.clone())
        }
    }

    /// Provider whose every search fails; used to test error propagation.
    struct FailingProvider;

    impl HaystackProvider for FailingProvider {
        type Error = TestProviderError;

        async fn search(&self, _query: &SearchQuery) -> Result<Vec<Document>, Self::Error> {
            Err(TestProviderError("search failed".to_string()))
        }
    }

    /// Builds a query with `term` as its search term and defaults elsewhere.
    fn make_query(term: &str) -> SearchQuery {
        SearchQuery {
            search_term: NormalizedTermValue::from(term),
            ..Default::default()
        }
    }

    /// Builds a document with the given `id` and `title` and defaults elsewhere.
    fn make_document(id: &str, title: &str) -> Document {
        Document {
            id: id.to_string(),
            title: title.to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_provider_returns_documents() {
        let provider = TestProvider::with_docs(vec![
            make_document("1", "First Result"),
            make_document("2", "Second Result"),
        ]);
        let results = provider.search(&make_query("test")).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].title, "First Result");
        assert_eq!(results[1].title, "Second Result");
    }

    #[tokio::test]
    async fn test_provider_returns_empty_results() {
        let provider = TestProvider::empty();
        let results = provider.search(&make_query("nothing")).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_provider_error_propagation() {
        let provider = FailingProvider;
        let err = provider
            .search(&make_query("test"))
            .await
            .expect_err("FailingProvider must return an error");
        // Pin the full Display output, not just a substring of it.
        assert_eq!(err.to_string(), "TestProviderError: search failed");
    }

    // Purely a compile-time check; no async runtime is needed.
    #[test]
    fn test_error_type_is_send_sync() {
        fn assert_send_sync<T: Send + Sync + 'static>() {}
        assert_send_sync::<TestProviderError>();
    }

    #[tokio::test]
    async fn test_provider_with_empty_search_term() {
        let provider = TestProvider::with_docs(vec![make_document("1", "Doc")]);
        let results = provider.search(&make_query("")).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_provider_with_special_characters_in_query() {
        let provider = TestProvider::with_docs(vec![make_document("1", "Doc")]);
        let results = provider
            .search(&make_query("test & <script>alert(1)</script>"))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_searches() {
        let provider = Arc::new(TestProvider::with_docs(vec![make_document("1", "Result")]));

        // `tokio::spawn` requires its future to be `Send + 'static`, so this
        // also checks at compile time that `TestProvider::search` can be
        // spawned as a task.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let provider = Arc::clone(&provider);
                tokio::spawn(async move {
                    provider.search(&make_query("concurrent")).await.unwrap()
                })
            })
            .collect();

        for handle in handles {
            let results = handle.await.expect("search task panicked");
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].title, "Result");
        }
    }
}