1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11pub struct UpstashVectorIndex {
13 client: Client,
14 base_url: String,
15 token: String,
16 namespace: String,
17}
18
19impl UpstashVectorIndex {
20 pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21 Self {
22 client: Client::new(),
23 base_url: base_url.trim_end_matches('/').to_string(),
24 token: token.to_string(),
25 namespace: namespace.to_string(),
26 }
27 }
28
29 fn url(&self, endpoint: &str) -> String {
30 if self.namespace.is_empty() {
31 format!("{}/{endpoint}", self.base_url)
32 } else {
33 format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34 }
35 }
36
37 async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38 &self,
39 endpoint: &str,
40 body: &T,
41 ) -> Result<R> {
42 let url = self.url(endpoint);
43 debug!(url = %url, "upstash request");
44
45 let resp = self
46 .client
47 .post(&url)
48 .header("Authorization", format!("Bearer {}", self.token))
49 .json(body)
50 .send()
51 .await?;
52
53 if !resp.status().is_success() {
54 let status = resp.status();
55 let text = resp.text().await.unwrap_or_default();
56 return Err(Mem7Error::VectorStore(format!(
57 "Upstash HTTP {status}: {text}"
58 )));
59 }
60
61 resp.json()
62 .await
63 .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64 }
65}
66
67#[derive(Serialize)]
70struct UpsertEntry {
71 id: String,
72 vector: Vec<f32>,
73 metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78 vector: Vec<f32>,
79 #[serde(rename = "topK")]
80 top_k: usize,
81 #[serde(rename = "includeMetadata")]
82 include_metadata: bool,
83 #[serde(rename = "includeVectors")]
84 include_vectors: bool,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct FetchRequest {
91 ids: Vec<String>,
92 #[serde(rename = "includeMetadata")]
93 include_metadata: bool,
94 #[serde(rename = "includeVectors")]
95 include_vectors: bool,
96}
97
98#[derive(Serialize)]
99struct RangeRequest {
100 cursor: String,
101 limit: usize,
102 #[serde(rename = "includeMetadata")]
103 include_metadata: bool,
104 #[serde(rename = "includeVectors")]
105 include_vectors: bool,
106}
107
108#[derive(Deserialize)]
109struct UpstashResponse<T> {
110 result: T,
111}
112
113#[derive(Deserialize)]
114struct QueryResultEntry {
115 id: String,
116 score: f32,
117 metadata: Option<serde_json::Value>,
118}
119
120#[derive(Deserialize)]
121struct FetchResultEntry {
122 id: String,
123 vector: Option<Vec<f32>>,
124 metadata: Option<serde_json::Value>,
125}
126
127#[derive(Deserialize)]
128struct RangeResult {
129 #[serde(rename = "nextCursor")]
130 next_cursor: String,
131 vectors: Vec<FetchResultEntry>,
132}
133
134fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
135 match val {
136 serde_json::Value::Object(map) => {
137 let filtered: serde_json::Map<String, serde_json::Value> = map
138 .iter()
139 .filter(|(_, v)| !v.is_null())
140 .map(|(k, v)| (k.clone(), v.clone()))
141 .collect();
142 serde_json::Value::Object(filtered)
143 }
144 other => other.clone(),
145 }
146}
147
148fn build_filter(filter: &MemoryFilter) -> Option<String> {
149 let mut conditions = Vec::new();
150 if let Some(ref uid) = filter.user_id {
151 conditions.push(format!("user_id = '{uid}'"));
152 }
153 if let Some(ref aid) = filter.agent_id {
154 conditions.push(format!("agent_id = '{aid}'"));
155 }
156 if let Some(ref rid) = filter.run_id {
157 conditions.push(format!("run_id = '{rid}'"));
158 }
159
160 if let Some(ref meta) = filter.metadata
161 && let Some(obj) = meta.as_object()
162 {
163 for (key, cond) in obj {
164 if let Some(expr) = translate_metadata_condition(key, cond) {
165 conditions.push(expr);
166 }
167 }
168 }
169
170 if conditions.is_empty() {
171 None
172 } else {
173 Some(conditions.join(" AND "))
174 }
175}
176
177fn translate_metadata_condition(key: &str, condition: &serde_json::Value) -> Option<String> {
180 let field = format!("metadata.{key}");
181 match key {
182 "AND" => {
183 let parts: Vec<String> = condition
184 .as_array()?
185 .iter()
186 .filter_map(|f| {
187 let obj = f.as_object()?;
188 let sub: Vec<String> = obj
189 .iter()
190 .filter_map(|(k, v)| translate_metadata_condition(k, v))
191 .collect();
192 if sub.is_empty() {
193 None
194 } else {
195 Some(sub.join(" AND "))
196 }
197 })
198 .collect();
199 if parts.is_empty() {
200 None
201 } else {
202 Some(format!("({})", parts.join(" AND ")))
203 }
204 }
205 "OR" => {
206 let parts: Vec<String> = condition
207 .as_array()?
208 .iter()
209 .filter_map(|f| {
210 let obj = f.as_object()?;
211 let sub: Vec<String> = obj
212 .iter()
213 .filter_map(|(k, v)| translate_metadata_condition(k, v))
214 .collect();
215 if sub.is_empty() {
216 None
217 } else {
218 Some(sub.join(" AND "))
219 }
220 })
221 .collect();
222 if parts.is_empty() {
223 None
224 } else {
225 Some(format!("({})", parts.join(" OR ")))
226 }
227 }
228 "NOT" => {
229 let parts: Vec<String> = condition
230 .as_array()?
231 .iter()
232 .filter_map(|f| {
233 let obj = f.as_object()?;
234 let sub: Vec<String> = obj
235 .iter()
236 .filter_map(|(k, v)| translate_metadata_condition(k, v))
237 .collect();
238 if sub.is_empty() {
239 None
240 } else {
241 Some(format!("NOT ({})", sub.join(" AND ")))
242 }
243 })
244 .collect();
245 if parts.is_empty() {
246 None
247 } else {
248 Some(parts.join(" AND "))
249 }
250 }
251 _ => match condition {
252 serde_json::Value::Object(ops) => {
253 let parts: Vec<String> = ops
254 .iter()
255 .filter_map(|(op, val)| translate_operator(&field, op, val))
256 .collect();
257 if parts.is_empty() {
258 None
259 } else {
260 Some(parts.join(" AND "))
261 }
262 }
263 serde_json::Value::String(s) => Some(format!("{field} = '{s}'")),
264 serde_json::Value::Number(n) => Some(format!("{field} = {n}")),
265 serde_json::Value::Bool(b) => Some(format!("{field} = {b}")),
266 _ => None,
267 },
268 }
269}
270
271fn translate_operator(field: &str, op: &str, val: &serde_json::Value) -> Option<String> {
272 match op {
273 "eq" => Some(format_equality(field, val)),
274 "ne" => Some(format!("{field} != {}", format_value(val))),
275 "gt" => Some(format!("{field} > {}", format_value(val))),
276 "gte" => Some(format!("{field} >= {}", format_value(val))),
277 "lt" => Some(format!("{field} < {}", format_value(val))),
278 "lte" => Some(format!("{field} <= {}", format_value(val))),
279 "in" => {
280 let items: Vec<String> = val.as_array()?.iter().map(format_value).collect();
281 Some(format!("{field} IN ({})", items.join(", ")))
282 }
283 "nin" => {
284 let items: Vec<String> = val.as_array()?.iter().map(format_value).collect();
285 Some(format!("{field} NOT IN ({})", items.join(", ")))
286 }
287 "contains" => {
288 let s = val.as_str()?;
289 Some(format!("{field} GLOB '*{s}*'"))
290 }
291 "icontains" => {
292 let s = val.as_str()?;
293 Some(format!("{field} GLOB '*{s}*'"))
294 }
295 _ => None,
296 }
297}
298
299fn format_equality(field: &str, val: &serde_json::Value) -> String {
300 format!("{field} = {}", format_value(val))
301}
302
303fn format_value(val: &serde_json::Value) -> String {
304 match val {
305 serde_json::Value::String(s) => format!("'{s}'"),
306 serde_json::Value::Number(n) => n.to_string(),
307 serde_json::Value::Bool(b) => b.to_string(),
308 other => other.to_string(),
309 }
310}
311
312fn matches_filter_local(payload: &serde_json::Value, filter: &MemoryFilter) -> bool {
313 crate::filter::matches_filter(payload, filter)
314}
315
316#[async_trait]
317impl VectorIndex for UpstashVectorIndex {
318 async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
319 let entries = vec![UpsertEntry {
320 id: id.to_string(),
321 vector: vector.to_vec(),
322 metadata: strip_nulls(&payload),
323 }];
324 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
325 Ok(())
326 }
327
328 async fn search(
329 &self,
330 query: &[f32],
331 limit: usize,
332 filters: Option<&MemoryFilter>,
333 ) -> Result<Vec<VectorSearchResult>> {
334 let req = QueryRequest {
335 vector: query.to_vec(),
336 top_k: limit,
337 include_metadata: true,
338 include_vectors: false,
339 filter: filters.and_then(build_filter),
340 };
341 let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
342
343 Ok(resp
344 .result
345 .into_iter()
346 .filter_map(|entry| {
347 let id = Uuid::parse_str(&entry.id).ok()?;
348 Some(VectorSearchResult {
349 id,
350 score: entry.score,
351 payload: entry.metadata.unwrap_or(serde_json::Value::Null),
352 })
353 })
354 .collect())
355 }
356
357 async fn delete(&self, id: &Uuid) -> Result<()> {
358 let ids = vec![id.to_string()];
359 let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
360 Ok(())
361 }
362
363 async fn update(
364 &self,
365 id: &Uuid,
366 vector: Option<&[f32]>,
367 payload: Option<serde_json::Value>,
368 ) -> Result<()> {
369 let existing = self.get(id).await?;
371
372 let (vec_data, meta_data) = match existing {
373 Some((v, m)) => (v, m),
374 None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
375 };
376
377 let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
378 let final_meta = payload
379 .map(|p| strip_nulls(&p))
380 .unwrap_or_else(|| strip_nulls(&meta_data));
381
382 let entries = vec![UpsertEntry {
383 id: id.to_string(),
384 vector: final_vec,
385 metadata: final_meta,
386 }];
387 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
388 Ok(())
389 }
390
391 async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
392 let ids = vec![id.to_string()];
393
394 let resp: UpstashResponse<Vec<FetchResultEntry>> = self
395 .post(
396 "fetch",
397 &FetchRequest {
398 ids,
399 include_metadata: true,
400 include_vectors: true,
401 },
402 )
403 .await?;
404
405 Ok(resp.result.into_iter().next().map(|entry| {
406 (
407 entry.vector.unwrap_or_default(),
408 entry.metadata.unwrap_or(serde_json::Value::Null),
409 )
410 }))
411 }
412
413 async fn list(
414 &self,
415 filters: Option<&MemoryFilter>,
416 limit: Option<usize>,
417 ) -> Result<Vec<(Uuid, serde_json::Value)>> {
418 let mut all_results = Vec::new();
419 let mut cursor = "0".to_string();
420 let page_size = 100;
421
422 loop {
423 let req = RangeRequest {
424 cursor,
425 limit: page_size,
426 include_metadata: true,
427 include_vectors: false,
428 };
429 let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
430
431 for entry in resp.result.vectors {
432 if let Ok(id) = Uuid::parse_str(&entry.id) {
433 let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
434 let passes = filters
435 .map(|f| matches_filter_local(&metadata, f))
436 .unwrap_or(true);
437 if passes {
438 all_results.push((id, metadata));
439 }
440 }
441 }
442
443 if let Some(lim) = limit
444 && all_results.len() >= lim
445 {
446 all_results.truncate(lim);
447 break;
448 }
449
450 if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
451 break;
452 }
453 cursor = resp.result.next_cursor;
454 }
455
456 all_results.sort_by(|a, b| a.0.cmp(&b.0));
457 Ok(all_results)
458 }
459
460 async fn reset(&self) -> Result<()> {
461 let url = self.url("reset");
462 let resp = self
463 .client
464 .post(&url)
465 .header("Authorization", format!("Bearer {}", self.token))
466 .send()
467 .await?;
468
469 if !resp.status().is_success() {
470 let status = resp.status();
471 let text = resp.text().await.unwrap_or_default();
472 return Err(Mem7Error::VectorStore(format!(
473 "Upstash reset HTTP {status}: {text}"
474 )));
475 }
476 Ok(())
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a filter that constrains nothing but custom metadata.
    fn metadata_filter(metadata: serde_json::Value) -> MemoryFilter {
        MemoryFilter {
            metadata: Some(metadata),
            ..Default::default()
        }
    }

    /// Asserts that `filter` renders to exactly `expected`.
    fn assert_renders(filter: &MemoryFilter, expected: &str) {
        assert_eq!(build_filter(filter).as_deref(), Some(expected));
    }

    #[test]
    fn test_build_filter_user_id_only() {
        let filter = MemoryFilter {
            user_id: Some("alice".into()),
            ..Default::default()
        };
        assert_renders(&filter, "user_id = 'alice'");
    }

    #[test]
    fn test_build_filter_metadata_simple_eq() {
        let filter = metadata_filter(json!({"status": "active"}));
        assert_renders(&filter, "metadata.status = 'active'");
    }

    #[test]
    fn test_build_filter_metadata_operators() {
        let cases = [
            (json!({"score": {"gt": 50}}), "metadata.score > 50"),
            (json!({"score": {"lte": 100}}), "metadata.score <= 100"),
        ];
        for (metadata, expected) in cases {
            assert_renders(&metadata_filter(metadata), expected);
        }
    }

    #[test]
    fn test_build_filter_metadata_in() {
        let filter = metadata_filter(json!({"tag": {"in": ["rust", "python"]}}));
        assert_renders(&filter, "metadata.tag IN ('rust', 'python')");
    }

    #[test]
    fn test_build_filter_metadata_contains() {
        let filter = metadata_filter(json!({"desc": {"contains": "hello"}}));
        assert_renders(&filter, "metadata.desc GLOB '*hello*'");
    }

    #[test]
    fn test_build_filter_and_combinator() {
        let filter = metadata_filter(json!({"AND": [
            {"status": "active"},
            {"score": {"gt": 50}}
        ]}));
        assert_renders(
            &filter,
            "(metadata.status = 'active' AND metadata.score > 50)",
        );
    }

    #[test]
    fn test_build_filter_or_combinator() {
        let filter = metadata_filter(json!({"OR": [
            {"status": "active"},
            {"status": "pending"}
        ]}));
        assert_renders(
            &filter,
            "(metadata.status = 'active' OR metadata.status = 'pending')",
        );
    }

    #[test]
    fn test_build_filter_combined_first_class_and_metadata() {
        let filter = MemoryFilter {
            user_id: Some("alice".into()),
            metadata: Some(json!({"status": "active"})),
            ..Default::default()
        };
        assert_renders(&filter, "user_id = 'alice' AND metadata.status = 'active'");
    }

    #[test]
    fn test_build_filter_no_conditions() {
        let rendered = build_filter(&MemoryFilter::default());
        assert_eq!(rendered, None);
    }

    #[test]
    fn test_build_filter_ne_operator() {
        let filter = metadata_filter(json!({"status": {"ne": "deleted"}}));
        assert_renders(&filter, "metadata.status != 'deleted'");
    }

    #[test]
    fn test_build_filter_not_combinator() {
        let filter = metadata_filter(json!({"NOT": [{"status": "deleted"}]}));
        assert_renders(&filter, "NOT (metadata.status = 'deleted')");
    }
}