nt_agentdb_client/
queries.rs

1// AgentDB Query Templates with <1ms performance targets
2
3use crate::{
4    client::AgentDBClient,
5    errors::Result,
6    schema::{Observation, Order, ReflexionTrace, Signal},
7};
8use chrono::Utc;
9use serde::{Serialize, Deserialize};
10use uuid::Uuid;
11
12/// Query builder for vector similarity search
13#[derive(Debug, Clone, Serialize)]
14pub struct VectorQuery {
15    pub collection: String,
16    pub embedding: Vec<f32>,
17    pub k: usize,
18
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub filters: Option<Vec<Filter>>,
21
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub min_score: Option<f32>,
24}
25
26impl VectorQuery {
27    pub fn new(collection: String, embedding: Vec<f32>, k: usize) -> Self {
28        Self {
29            collection,
30            embedding,
31            k,
32            filters: None,
33            min_score: None,
34        }
35    }
36
37    pub fn with_filter(mut self, filter: Filter) -> Self {
38        self.filters.get_or_insert_with(Vec::new).push(filter);
39        self
40    }
41
42    pub fn with_min_score(mut self, score: f32) -> Self {
43        self.min_score = Some(score);
44        self
45    }
46}
47
48/// Filter for metadata queries
49#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(tag = "op", rename_all = "lowercase")]
51pub enum Filter {
52    Eq {
53        field: String,
54        value: serde_json::Value,
55    },
56    Ne {
57        field: String,
58        value: serde_json::Value,
59    },
60    Gt {
61        field: String,
62        value: serde_json::Value,
63    },
64    Gte {
65        field: String,
66        value: serde_json::Value,
67    },
68    Lt {
69        field: String,
70        value: serde_json::Value,
71    },
72    Lte {
73        field: String,
74        value: serde_json::Value,
75    },
76    In {
77        field: String,
78        values: Vec<serde_json::Value>,
79    },
80    And {
81        filters: Vec<Filter>,
82    },
83    Or {
84        filters: Vec<Filter>,
85    },
86}
87
88impl Filter {
89    pub fn eq(field: impl Into<String>, value: impl Serialize) -> Self {
90        Self::Eq {
91            field: field.into(),
92            value: serde_json::to_value(value).unwrap(),
93        }
94    }
95
96    pub fn gte(field: impl Into<String>, value: impl Serialize) -> Self {
97        Self::Gte {
98            field: field.into(),
99            value: serde_json::to_value(value).unwrap(),
100        }
101    }
102
103    pub fn lte(field: impl Into<String>, value: impl Serialize) -> Self {
104        Self::Lte {
105            field: field.into(),
106            value: serde_json::to_value(value).unwrap(),
107        }
108    }
109
110    pub fn and(filters: Vec<Filter>) -> Self {
111        Self::And { filters }
112    }
113
114    pub fn or(filters: Vec<Filter>) -> Self {
115        Self::Or { filters }
116    }
117}
118
119/// Query templates for common operations
120impl AgentDBClient {
121    /// Find similar market conditions
122    /// Target: <1ms for k=10
123    pub async fn find_similar_conditions(
124        &self,
125        current: &Observation,
126        k: usize,
127        time_window_hours: Option<i64>,
128    ) -> Result<Vec<Observation>> {
129        let mut query = VectorQuery::new("observations".to_string(), current.embedding.clone(), k)
130            .with_filter(Filter::eq("symbol", &current.symbol));
131
132        if let Some(hours) = time_window_hours {
133            let cutoff = current.timestamp_us - (hours * 3600 * 1_000_000);
134            query = query.with_filter(Filter::gte("timestamp_us", cutoff));
135        }
136
137        self.vector_search(query).await
138    }
139
140    /// Get signals by strategy
141    /// Target: <1ms
142    pub async fn get_signals_by_strategy(
143        &self,
144        strategy_id: &str,
145        min_confidence: f64,
146        limit: usize,
147    ) -> Result<Vec<Signal>> {
148        let query = MetadataQuery {
149            collection: "signals".to_string(),
150            filters: vec![
151                Filter::eq("strategy_id", strategy_id),
152                Filter::gte("confidence", min_confidence),
153            ],
154            limit: Some(limit),
155            sort_by: Some(SortBy {
156                field: "confidence".to_string(),
157                order: SortOrder::Desc,
158            }),
159        };
160
161        self.metadata_search(query).await
162    }
163
164    /// Find similar trading decisions
165    /// Target: <1ms for k=10
166    pub async fn find_similar_decisions(&self, signal: &Signal, k: usize) -> Result<Vec<Signal>> {
167        let query = VectorQuery::new("signals".to_string(), signal.embedding.clone(), k)
168            .with_filter(Filter::eq("symbol", &signal.symbol));
169
170        self.vector_search(query).await
171    }
172
173    /// Get top performing strategies
174    /// Target: <50ms for 1000 traces
175    pub async fn get_top_strategies(
176        &self,
177        min_score: f64,
178        limit: usize,
179    ) -> Result<Vec<(String, f64)>> {
180        let query = MetadataQuery {
181            collection: "reflexion_traces".to_string(),
182            filters: vec![Filter::gte("verdict.score", min_score)],
183            limit: Some(limit * 10), // Get more to aggregate
184            sort_by: Some(SortBy {
185                field: "verdict.sharpe".to_string(),
186                order: SortOrder::Desc,
187            }),
188        };
189
190        let traces: Vec<ReflexionTrace> = self.metadata_search(query).await?;
191
192        // Aggregate by strategy (simplified)
193        let mut strategy_scores = std::collections::HashMap::new();
194
195        for trace in traces {
196            // Extract strategy from decision
197            // This is simplified - in practice, we'd join with signals table
198            let score = trace.verdict.score;
199            strategy_scores
200                .entry("strategy_placeholder".to_string())
201                .or_insert_with(Vec::new)
202                .push(score);
203        }
204
205        let mut results: Vec<(String, f64)> = strategy_scores
206            .into_iter()
207            .map(|(strategy, scores)| {
208                let avg = scores.iter().sum::<f64>() / scores.len() as f64;
209                (strategy, avg)
210            })
211            .collect();
212
213        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
214        results.truncate(limit);
215
216        Ok(results)
217    }
218
219    /// Get observations in time range
220    /// Target: <5ms for 1-hour window
221    pub async fn get_observations_in_range(
222        &self,
223        symbol: &str,
224        start_us: i64,
225        end_us: i64,
226    ) -> Result<Vec<Observation>> {
227        let query = MetadataQuery {
228            collection: "observations".to_string(),
229            filters: vec![
230                Filter::eq("symbol", symbol),
231                Filter::gte("timestamp_us", start_us),
232                Filter::lte("timestamp_us", end_us),
233            ],
234            limit: Some(10000),
235            sort_by: Some(SortBy {
236                field: "timestamp_us".to_string(),
237                order: SortOrder::Asc,
238            }),
239        };
240
241        self.metadata_search(query).await
242    }
243
244    /// Get orders for signal
245    /// Target: <1ms
246    pub async fn get_orders_for_signal(&self, signal_id: Uuid) -> Result<Vec<Order>> {
247        let query = MetadataQuery {
248            collection: "orders".to_string(),
249            filters: vec![Filter::eq("signal_id", signal_id.to_string())],
250            limit: Some(100),
251            sort_by: Some(SortBy {
252                field: "timestamps.created_us".to_string(),
253                order: SortOrder::Asc,
254            }),
255        };
256
257        self.metadata_search(query).await
258    }
259
260    /// Get recent signals
261    /// Target: <1ms
262    pub async fn get_recent_signals(&self, symbol: &str, limit: usize) -> Result<Vec<Signal>> {
263        let cutoff = Utc::now().timestamp_micros() - (24 * 3600 * 1_000_000); // Last 24 hours
264
265        let query = MetadataQuery {
266            collection: "signals".to_string(),
267            filters: vec![
268                Filter::eq("symbol", symbol),
269                Filter::gte("timestamp_us", cutoff),
270            ],
271            limit: Some(limit),
272            sort_by: Some(SortBy {
273                field: "timestamp_us".to_string(),
274                order: SortOrder::Desc,
275            }),
276        };
277
278        self.metadata_search(query).await
279    }
280}
281
282/// Metadata-only query
283#[derive(Debug, Clone, Serialize)]
284pub struct MetadataQuery {
285    pub collection: String,
286    pub filters: Vec<Filter>,
287
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub limit: Option<usize>,
290
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub sort_by: Option<SortBy>,
293}
294
295#[derive(Debug, Clone, Serialize)]
296pub struct SortBy {
297    pub field: String,
298    pub order: SortOrder,
299}
300
301#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
302#[serde(rename_all = "lowercase")]
303pub enum SortOrder {
304    Asc,
305    Desc,
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// `Filter::eq` stores the field name and the JSON-encoded value.
    #[test]
    fn test_filter_builder() {
        if let Filter::Eq { field, value } = Filter::eq("symbol", "AAPL") {
            assert_eq!(field, "symbol");
            assert_eq!(value, serde_json::json!("AAPL"));
        } else {
            panic!("Wrong filter type");
        }
    }

    /// `Filter::and` wraps its children unchanged.
    #[test]
    fn test_and_filter() {
        let children = vec![
            Filter::eq("symbol", "AAPL"),
            Filter::gte("confidence", 0.8),
        ];

        if let Filter::And { filters } = Filter::and(children) {
            assert_eq!(filters.len(), 2);
        } else {
            panic!("Wrong filter type");
        }
    }

    /// The builder methods populate the optional parts of a `VectorQuery`.
    #[test]
    fn test_vector_query_builder() {
        let embedding = vec![0.1, 0.2, 0.3];
        let query = VectorQuery::new(String::from("observations"), embedding, 10)
            .with_filter(Filter::eq("symbol", "AAPL"))
            .with_min_score(0.8);

        assert_eq!(query.collection, "observations");
        assert_eq!(query.k, 10);
        assert!(query.filters.is_some());
        assert_eq!(query.min_score, Some(0.8));
    }
}