1use crate::context::ContextId;
7use crate::errors::SisterResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::Duration;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Query {
15 pub query_type: String,
17
18 #[serde(default)]
20 pub params: HashMap<String, serde_json::Value>,
21
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub limit: Option<usize>,
25
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub offset: Option<usize>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
32 pub context_id: Option<ContextId>,
33
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub context_ids: Option<Vec<ContextId>>,
37
38 #[serde(default)]
40 pub merge_results: bool,
41}
42
43impl Query {
44 pub fn new(query_type: impl Into<String>) -> Self {
46 Self {
47 query_type: query_type.into(),
48 params: HashMap::new(),
49 limit: None,
50 offset: None,
51 context_id: None,
52 context_ids: None,
53 merge_results: false,
54 }
55 }
56
57 pub fn param(mut self, key: impl Into<String>, value: impl Serialize) -> Self {
59 if let Ok(v) = serde_json::to_value(value) {
60 self.params.insert(key.into(), v);
61 }
62 self
63 }
64
65 pub fn limit(mut self, limit: usize) -> Self {
67 self.limit = Some(limit);
68 self
69 }
70
71 pub fn offset(mut self, offset: usize) -> Self {
73 self.offset = Some(offset);
74 self
75 }
76
77 pub fn in_context(mut self, context_id: ContextId) -> Self {
79 self.context_id = Some(context_id);
80 self
81 }
82
83 pub fn in_contexts(mut self, context_ids: Vec<ContextId>) -> Self {
85 self.context_ids = Some(context_ids);
86 self.merge_results = true;
87 self
88 }
89
90 pub fn get_param<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
92 self.params
93 .get(key)
94 .and_then(|v| serde_json::from_value(v.clone()).ok())
95 }
96
97 pub fn get_string(&self, key: &str) -> Option<String> {
99 self.get_param(key)
100 }
101
102 pub fn get_int(&self, key: &str) -> Option<i64> {
104 self.get_param(key)
105 }
106
107 pub fn get_bool(&self, key: &str) -> Option<bool> {
109 self.get_param(key)
110 }
111}
112
113impl Query {
115 pub fn list() -> Self {
117 Self::new("list")
118 }
119
120 pub fn search(text: impl Into<String>) -> Self {
122 Self::new("search").param("text", text.into())
123 }
124
125 pub fn recent(count: usize) -> Self {
127 Self::new("recent").limit(count)
128 }
129
130 pub fn related(item_id: impl Into<String>) -> Self {
132 Self::new("related").param("item_id", item_id.into())
133 }
134
135 pub fn temporal() -> Self {
137 Self::new("temporal")
138 }
139
140 pub fn get(item_id: impl Into<String>) -> Self {
142 Self::new("get").param("id", item_id.into())
143 }
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct QueryResult {
149 pub query: Query,
151
152 pub results: Vec<serde_json::Value>,
154
155 #[serde(skip_serializing_if = "Option::is_none")]
157 pub total_count: Option<usize>,
158
159 pub has_more: bool,
161
162 #[serde(with = "duration_millis")]
164 pub query_time: Duration,
165
166 #[serde(skip_serializing_if = "Option::is_none")]
168 pub queried_contexts: Option<Vec<ContextId>>,
169}
170
171impl QueryResult {
172 pub fn new(query: Query, results: Vec<serde_json::Value>, query_time: Duration) -> Self {
174 Self {
175 query,
176 total_count: Some(results.len()),
177 has_more: false,
178 results,
179 query_time,
180 queried_contexts: None,
181 }
182 }
183
184 pub fn empty(query: Query) -> Self {
186 Self {
187 query,
188 results: vec![],
189 total_count: Some(0),
190 has_more: false,
191 query_time: Duration::ZERO,
192 queried_contexts: None,
193 }
194 }
195
196 pub fn with_pagination(mut self, total: usize, has_more: bool) -> Self {
198 self.total_count = Some(total);
199 self.has_more = has_more;
200 self
201 }
202
203 pub fn with_contexts(mut self, contexts: Vec<ContextId>) -> Self {
205 self.queried_contexts = Some(contexts);
206 self
207 }
208
209 pub fn results_as<T: for<'de> Deserialize<'de>>(&self) -> Vec<T> {
211 self.results
212 .iter()
213 .filter_map(|v| serde_json::from_value(v.clone()).ok())
214 .collect()
215 }
216
217 pub fn is_empty(&self) -> bool {
219 self.results.is_empty()
220 }
221
222 pub fn len(&self) -> usize {
224 self.results.len()
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct QueryTypeInfo {
231 pub name: String,
233
234 pub description: String,
236
237 pub required_params: Vec<String>,
239
240 pub optional_params: Vec<String>,
242
243 #[serde(skip_serializing_if = "Option::is_none")]
245 pub example: Option<serde_json::Value>,
246}
247
248impl QueryTypeInfo {
249 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
250 Self {
251 name: name.into(),
252 description: description.into(),
253 required_params: vec![],
254 optional_params: vec![],
255 example: None,
256 }
257 }
258
259 pub fn required(mut self, params: Vec<&str>) -> Self {
260 self.required_params = params.into_iter().map(String::from).collect();
261 self
262 }
263
264 pub fn optional(mut self, params: Vec<&str>) -> Self {
265 self.optional_params = params.into_iter().map(String::from).collect();
266 self
267 }
268
269 pub fn example(mut self, example: impl Serialize) -> Self {
270 self.example = serde_json::to_value(example).ok();
271 self
272 }
273}
274
275pub trait Queryable {
277 fn query(&self, query: Query) -> SisterResult<QueryResult>;
279
280 fn supports_query(&self, query_type: &str) -> bool;
282
283 fn query_types(&self) -> Vec<QueryTypeInfo>;
285
286 fn search(&self, text: &str) -> SisterResult<QueryResult> {
288 self.query(Query::search(text))
289 }
290
291 fn recent(&self, count: usize) -> SisterResult<QueryResult> {
293 self.query(Query::recent(count))
294 }
295
296 fn list(&self, limit: usize, offset: usize) -> SisterResult<QueryResult> {
298 self.query(Query::list().limit(limit).offset(offset))
299 }
300}
301
302mod duration_millis {
304 use serde::{Deserialize, Deserializer, Serializer};
305 use std::time::Duration;
306
307 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
308 where
309 S: Serializer,
310 {
311 serializer.serialize_u64(duration.as_millis() as u64)
312 }
313
314 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
315 where
316 D: Deserializer<'de>,
317 {
318 let ms = u64::deserialize(deserializer)?;
319 Ok(Duration::from_millis(ms))
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls chain and every value set is readable afterwards.
    #[test]
    fn test_query_builder() {
        let built = Query::search("hello")
            .limit(10)
            .offset(5)
            .param("extra", "value");

        assert_eq!(built.query_type, "search");
        assert_eq!(built.limit, Some(10));
        assert_eq!(built.offset, Some(5));
        assert_eq!(built.get_string("text"), Some("hello".to_string()));
        assert_eq!(built.get_string("extra"), Some("value".to_string()));
    }

    /// The canned constructors set the expected type, limit and parameters.
    #[test]
    fn test_common_queries() {
        assert_eq!(Query::list().query_type, "list");

        let recent = Query::recent(5);
        assert_eq!(recent.query_type, "recent");
        assert_eq!(recent.limit, Some(5));

        let search = Query::search("test");
        assert_eq!(search.get_string("text"), Some("test".to_string()));
    }

    /// `with_pagination` overrides the defaults chosen by `QueryResult::new`.
    #[test]
    fn test_query_result() {
        let rows = vec![
            serde_json::json!({"id": "1"}),
            serde_json::json!({"id": "2"}),
        ];

        let base = QueryResult::new(Query::list(), rows, Duration::from_millis(10));
        let result = base.with_pagination(100, true);

        assert_eq!(result.len(), 2);
        assert!(result.has_more);
        assert_eq!(result.total_count, Some(100));
    }
}