Skip to main content

haystack_core/
lib.rs

1use terraphim_types::{Document, SearchQuery};
2
/// A source of documents that can be searched asynchronously.
///
/// An implementor receives a [`SearchQuery`] and produces a list of
/// [`Document`]s, or an implementor-defined error.
pub trait HaystackProvider {
    /// Error produced when a search fails.
    ///
    /// The bounds require the error to be printable (`Display` and `Debug`)
    /// and to be `Send + Sync + 'static`.
    type Error: std::fmt::Display + std::fmt::Debug + Send + Sync + 'static;

    /// Runs `query` against this provider.
    ///
    /// Returns the documents produced for the query on success, or
    /// `Self::Error` on failure.
    // The `async_fn_in_trait` lint warns that a public `async fn` in a trait
    // cannot state auto-trait bounds (such as `Send`) on the returned future.
    // NOTE(review): the allow suggests that limitation was accepted
    // deliberately; whether a caller gets a `Send` future depends on the
    // concrete implementor — confirm this is intended.
    #[allow(async_fn_in_trait)]
    async fn search(&self, query: &SearchQuery) -> Result<Vec<Document>, Self::Error>;
}
9
10#[cfg(test)]
11mod tests {
12    use super::*;
13    use terraphim_types::NormalizedTermValue;
14
15    /// A concrete test provider that returns pre-configured documents.
16    struct TestProvider {
17        documents: Vec<Document>,
18    }
19
20    impl TestProvider {
21        fn with_docs(documents: Vec<Document>) -> Self {
22            Self { documents }
23        }
24
25        fn empty() -> Self {
26            Self {
27                documents: Vec::new(),
28            }
29        }
30    }
31
32    /// Error type for the test provider.
33    #[derive(Debug)]
34    struct TestProviderError(String);
35
36    impl std::fmt::Display for TestProviderError {
37        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38            write!(f, "TestProviderError: {}", self.0)
39        }
40    }
41
42    impl HaystackProvider for TestProvider {
43        type Error = TestProviderError;
44
45        async fn search(&self, _query: &SearchQuery) -> Result<Vec<Document>, Self::Error> {
46            Ok(self.documents.clone())
47        }
48    }
49
50    /// A provider that always returns an error.
51    struct FailingProvider;
52
53    impl HaystackProvider for FailingProvider {
54        type Error = TestProviderError;
55
56        async fn search(&self, _query: &SearchQuery) -> Result<Vec<Document>, Self::Error> {
57            Err(TestProviderError("search failed".to_string()))
58        }
59    }
60
61    fn make_query(term: &str) -> SearchQuery {
62        SearchQuery {
63            search_term: NormalizedTermValue::from(term),
64            ..Default::default()
65        }
66    }
67
68    fn make_document(id: &str, title: &str) -> Document {
69        Document {
70            id: id.to_string(),
71            title: title.to_string(),
72            ..Default::default()
73        }
74    }
75
76    #[tokio::test]
77    async fn test_provider_returns_documents() {
78        let provider = TestProvider::with_docs(vec![
79            make_document("1", "First Result"),
80            make_document("2", "Second Result"),
81        ]);
82        let results = provider.search(&make_query("test")).await.unwrap();
83        assert_eq!(results.len(), 2);
84        assert_eq!(results[0].title, "First Result");
85        assert_eq!(results[1].title, "Second Result");
86    }
87
88    #[tokio::test]
89    async fn test_provider_returns_empty_results() {
90        let provider = TestProvider::empty();
91        let results = provider.search(&make_query("nothing")).await.unwrap();
92        assert!(results.is_empty());
93    }
94
95    #[tokio::test]
96    async fn test_provider_error_propagation() {
97        let provider = FailingProvider;
98        let result = provider.search(&make_query("test")).await;
99        assert!(result.is_err());
100        let err = result.unwrap_err();
101        assert!(err.to_string().contains("search failed"));
102    }
103
104    #[tokio::test]
105    async fn test_error_type_is_send_sync() {
106        fn assert_send_sync<T: Send + Sync + 'static>() {}
107        assert_send_sync::<TestProviderError>();
108    }
109
110    #[tokio::test]
111    async fn test_provider_with_empty_search_term() {
112        let provider = TestProvider::with_docs(vec![make_document("1", "Doc")]);
113        let results = provider.search(&make_query("")).await.unwrap();
114        assert_eq!(results.len(), 1);
115    }
116
117    #[tokio::test]
118    async fn test_provider_with_special_characters_in_query() {
119        let provider = TestProvider::with_docs(vec![make_document("1", "Doc")]);
120        let results = provider
121            .search(&make_query("test & <script>alert(1)</script>"))
122            .await
123            .unwrap();
124        assert_eq!(results.len(), 1);
125    }
126
127    #[tokio::test]
128    async fn test_concurrent_searches() {
129        let provider =
130            std::sync::Arc::new(TestProvider::with_docs(vec![make_document("1", "Result")]));
131
132        let mut handles = Vec::new();
133        for _ in 0..10 {
134            let p = provider.clone();
135            handles.push(tokio::spawn(async move {
136                p.search(&make_query("concurrent")).await.unwrap()
137            }));
138        }
139
140        for handle in handles {
141            let results = handle.await.unwrap();
142            assert_eq!(results.len(), 1);
143        }
144    }
145}