1use crate::{
4 client::AgentDBClient,
5 errors::Result,
6 schema::{Observation, Order, ReflexionTrace, Signal},
7};
8use chrono::Utc;
9use serde::{Serialize, Deserialize};
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Serialize)]
14pub struct VectorQuery {
15 pub collection: String,
16 pub embedding: Vec<f32>,
17 pub k: usize,
18
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub filters: Option<Vec<Filter>>,
21
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub min_score: Option<f32>,
24}
25
26impl VectorQuery {
27 pub fn new(collection: String, embedding: Vec<f32>, k: usize) -> Self {
28 Self {
29 collection,
30 embedding,
31 k,
32 filters: None,
33 min_score: None,
34 }
35 }
36
37 pub fn with_filter(mut self, filter: Filter) -> Self {
38 self.filters.get_or_insert_with(Vec::new).push(filter);
39 self
40 }
41
42 pub fn with_min_score(mut self, score: f32) -> Self {
43 self.min_score = Some(score);
44 self
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(tag = "op", rename_all = "lowercase")]
51pub enum Filter {
52 Eq {
53 field: String,
54 value: serde_json::Value,
55 },
56 Ne {
57 field: String,
58 value: serde_json::Value,
59 },
60 Gt {
61 field: String,
62 value: serde_json::Value,
63 },
64 Gte {
65 field: String,
66 value: serde_json::Value,
67 },
68 Lt {
69 field: String,
70 value: serde_json::Value,
71 },
72 Lte {
73 field: String,
74 value: serde_json::Value,
75 },
76 In {
77 field: String,
78 values: Vec<serde_json::Value>,
79 },
80 And {
81 filters: Vec<Filter>,
82 },
83 Or {
84 filters: Vec<Filter>,
85 },
86}
87
88impl Filter {
89 pub fn eq(field: impl Into<String>, value: impl Serialize) -> Self {
90 Self::Eq {
91 field: field.into(),
92 value: serde_json::to_value(value).unwrap(),
93 }
94 }
95
96 pub fn gte(field: impl Into<String>, value: impl Serialize) -> Self {
97 Self::Gte {
98 field: field.into(),
99 value: serde_json::to_value(value).unwrap(),
100 }
101 }
102
103 pub fn lte(field: impl Into<String>, value: impl Serialize) -> Self {
104 Self::Lte {
105 field: field.into(),
106 value: serde_json::to_value(value).unwrap(),
107 }
108 }
109
110 pub fn and(filters: Vec<Filter>) -> Self {
111 Self::And { filters }
112 }
113
114 pub fn or(filters: Vec<Filter>) -> Self {
115 Self::Or { filters }
116 }
117}
118
119impl AgentDBClient {
121 pub async fn find_similar_conditions(
124 &self,
125 current: &Observation,
126 k: usize,
127 time_window_hours: Option<i64>,
128 ) -> Result<Vec<Observation>> {
129 let mut query = VectorQuery::new("observations".to_string(), current.embedding.clone(), k)
130 .with_filter(Filter::eq("symbol", ¤t.symbol));
131
132 if let Some(hours) = time_window_hours {
133 let cutoff = current.timestamp_us - (hours * 3600 * 1_000_000);
134 query = query.with_filter(Filter::gte("timestamp_us", cutoff));
135 }
136
137 self.vector_search(query).await
138 }
139
140 pub async fn get_signals_by_strategy(
143 &self,
144 strategy_id: &str,
145 min_confidence: f64,
146 limit: usize,
147 ) -> Result<Vec<Signal>> {
148 let query = MetadataQuery {
149 collection: "signals".to_string(),
150 filters: vec![
151 Filter::eq("strategy_id", strategy_id),
152 Filter::gte("confidence", min_confidence),
153 ],
154 limit: Some(limit),
155 sort_by: Some(SortBy {
156 field: "confidence".to_string(),
157 order: SortOrder::Desc,
158 }),
159 };
160
161 self.metadata_search(query).await
162 }
163
164 pub async fn find_similar_decisions(&self, signal: &Signal, k: usize) -> Result<Vec<Signal>> {
167 let query = VectorQuery::new("signals".to_string(), signal.embedding.clone(), k)
168 .with_filter(Filter::eq("symbol", &signal.symbol));
169
170 self.vector_search(query).await
171 }
172
173 pub async fn get_top_strategies(
176 &self,
177 min_score: f64,
178 limit: usize,
179 ) -> Result<Vec<(String, f64)>> {
180 let query = MetadataQuery {
181 collection: "reflexion_traces".to_string(),
182 filters: vec![Filter::gte("verdict.score", min_score)],
183 limit: Some(limit * 10), sort_by: Some(SortBy {
185 field: "verdict.sharpe".to_string(),
186 order: SortOrder::Desc,
187 }),
188 };
189
190 let traces: Vec<ReflexionTrace> = self.metadata_search(query).await?;
191
192 let mut strategy_scores = std::collections::HashMap::new();
194
195 for trace in traces {
196 let score = trace.verdict.score;
199 strategy_scores
200 .entry("strategy_placeholder".to_string())
201 .or_insert_with(Vec::new)
202 .push(score);
203 }
204
205 let mut results: Vec<(String, f64)> = strategy_scores
206 .into_iter()
207 .map(|(strategy, scores)| {
208 let avg = scores.iter().sum::<f64>() / scores.len() as f64;
209 (strategy, avg)
210 })
211 .collect();
212
213 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
214 results.truncate(limit);
215
216 Ok(results)
217 }
218
219 pub async fn get_observations_in_range(
222 &self,
223 symbol: &str,
224 start_us: i64,
225 end_us: i64,
226 ) -> Result<Vec<Observation>> {
227 let query = MetadataQuery {
228 collection: "observations".to_string(),
229 filters: vec![
230 Filter::eq("symbol", symbol),
231 Filter::gte("timestamp_us", start_us),
232 Filter::lte("timestamp_us", end_us),
233 ],
234 limit: Some(10000),
235 sort_by: Some(SortBy {
236 field: "timestamp_us".to_string(),
237 order: SortOrder::Asc,
238 }),
239 };
240
241 self.metadata_search(query).await
242 }
243
244 pub async fn get_orders_for_signal(&self, signal_id: Uuid) -> Result<Vec<Order>> {
247 let query = MetadataQuery {
248 collection: "orders".to_string(),
249 filters: vec![Filter::eq("signal_id", signal_id.to_string())],
250 limit: Some(100),
251 sort_by: Some(SortBy {
252 field: "timestamps.created_us".to_string(),
253 order: SortOrder::Asc,
254 }),
255 };
256
257 self.metadata_search(query).await
258 }
259
260 pub async fn get_recent_signals(&self, symbol: &str, limit: usize) -> Result<Vec<Signal>> {
263 let cutoff = Utc::now().timestamp_micros() - (24 * 3600 * 1_000_000); let query = MetadataQuery {
266 collection: "signals".to_string(),
267 filters: vec![
268 Filter::eq("symbol", symbol),
269 Filter::gte("timestamp_us", cutoff),
270 ],
271 limit: Some(limit),
272 sort_by: Some(SortBy {
273 field: "timestamp_us".to_string(),
274 order: SortOrder::Desc,
275 }),
276 };
277
278 self.metadata_search(query).await
279 }
280}
281
282#[derive(Debug, Clone, Serialize)]
284pub struct MetadataQuery {
285 pub collection: String,
286 pub filters: Vec<Filter>,
287
288 #[serde(skip_serializing_if = "Option::is_none")]
289 pub limit: Option<usize>,
290
291 #[serde(skip_serializing_if = "Option::is_none")]
292 pub sort_by: Option<SortBy>,
293}
294
295#[derive(Debug, Clone, Serialize)]
296pub struct SortBy {
297 pub field: String,
298 pub order: SortOrder,
299}
300
301#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
302#[serde(rename_all = "lowercase")]
303pub enum SortOrder {
304 Asc,
305 Desc,
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_builder() {
        let filter = Filter::eq("symbol", "AAPL");

        match filter {
            Filter::Eq { field, value } => {
                assert_eq!(field, "symbol");
                assert_eq!(value, serde_json::json!("AAPL"));
            }
            _ => panic!("Wrong filter type"),
        }
    }

    #[test]
    fn test_and_filter() {
        let filter = Filter::and(vec![
            Filter::eq("symbol", "AAPL"),
            Filter::gte("confidence", 0.8),
        ]);

        match filter {
            Filter::And { filters } => {
                assert_eq!(filters.len(), 2);
            }
            _ => panic!("Wrong filter type"),
        }
    }

    #[test]
    fn test_vector_query_builder() {
        let query = VectorQuery::new("observations".to_string(), vec![0.1, 0.2, 0.3], 10)
            .with_filter(Filter::eq("symbol", "AAPL"))
            .with_min_score(0.8);

        assert_eq!(query.collection, "observations");
        assert_eq!(query.k, 10);
        assert!(query.filters.is_some());
        assert_eq!(query.min_score, Some(0.8));
    }

    /// Pins the wire format: internally tagged by `"op"` with lowercase variant names.
    #[test]
    fn test_filter_wire_format() {
        let leaf = serde_json::to_value(Filter::gte("confidence", 0.8)).unwrap();
        assert_eq!(
            leaf,
            serde_json::json!({"op": "gte", "field": "confidence", "value": 0.8})
        );

        let nested = serde_json::to_value(Filter::and(vec![Filter::eq("symbol", "AAPL")]))
            .unwrap();
        assert_eq!(
            nested,
            serde_json::json!({
                "op": "and",
                "filters": [{"op": "eq", "field": "symbol", "value": "AAPL"}]
            })
        );
    }

    #[test]
    fn test_filter_deserializes_from_tagged_json() {
        let raw = serde_json::json!({"op": "lte", "field": "timestamp_us", "value": 42});
        let filter: Filter = serde_json::from_value(raw).unwrap();

        match filter {
            Filter::Lte { field, value } => {
                assert_eq!(field, "timestamp_us");
                assert_eq!(value, serde_json::json!(42));
            }
            _ => panic!("Wrong filter type"),
        }
    }

    #[test]
    fn test_vector_query_omits_unset_options() {
        let query = VectorQuery::new("observations".to_string(), vec![0.5, 0.25], 3);
        let value = serde_json::to_value(&query).unwrap();

        assert!(value.get("filters").is_none());
        assert!(value.get("min_score").is_none());
        assert_eq!(value["collection"], serde_json::json!("observations"));
        assert_eq!(value["k"], serde_json::json!(3));
    }

    #[test]
    fn test_metadata_query_serialization() {
        let query = MetadataQuery {
            collection: "signals".to_string(),
            filters: vec![],
            limit: None,
            sort_by: Some(SortBy {
                field: "timestamp_us".to_string(),
                order: SortOrder::Asc,
            }),
        };
        let value = serde_json::to_value(&query).unwrap();

        assert!(value.get("limit").is_none());
        assert_eq!(value["filters"], serde_json::json!([]));
        assert_eq!(
            value["sort_by"],
            serde_json::json!({"field": "timestamp_us", "order": "asc"})
        );
    }
}