Skip to main content

mem7_vector/
upstash.rs

1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11/// Upstash Vector REST API client.
12pub struct UpstashVectorIndex {
13    client: Client,
14    base_url: String,
15    token: String,
16    namespace: String,
17}
18
19impl UpstashVectorIndex {
20    pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21        Self {
22            client: Client::new(),
23            base_url: base_url.trim_end_matches('/').to_string(),
24            token: token.to_string(),
25            namespace: namespace.to_string(),
26        }
27    }
28
29    fn url(&self, endpoint: &str) -> String {
30        if self.namespace.is_empty() {
31            format!("{}/{endpoint}", self.base_url)
32        } else {
33            format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34        }
35    }
36
37    async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38        &self,
39        endpoint: &str,
40        body: &T,
41    ) -> Result<R> {
42        let url = self.url(endpoint);
43        debug!(url = %url, "upstash request");
44
45        let resp = self
46            .client
47            .post(&url)
48            .header("Authorization", format!("Bearer {}", self.token))
49            .json(body)
50            .send()
51            .await?;
52
53        if !resp.status().is_success() {
54            let status = resp.status();
55            let text = resp.text().await.unwrap_or_default();
56            return Err(Mem7Error::VectorStore(format!(
57                "Upstash HTTP {status}: {text}"
58            )));
59        }
60
61        resp.json()
62            .await
63            .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64    }
65}
66
67// --- Upstash REST API types ---
68
69#[derive(Serialize)]
70struct UpsertEntry {
71    id: String,
72    vector: Vec<f32>,
73    metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78    vector: Vec<f32>,
79    #[serde(rename = "topK")]
80    top_k: usize,
81    #[serde(rename = "includeMetadata")]
82    include_metadata: bool,
83    #[serde(rename = "includeVectors")]
84    include_vectors: bool,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct FetchRequest {
91    ids: Vec<String>,
92    #[serde(rename = "includeMetadata")]
93    include_metadata: bool,
94    #[serde(rename = "includeVectors")]
95    include_vectors: bool,
96}
97
98#[derive(Serialize)]
99struct RangeRequest {
100    cursor: String,
101    limit: usize,
102    #[serde(rename = "includeMetadata")]
103    include_metadata: bool,
104    #[serde(rename = "includeVectors")]
105    include_vectors: bool,
106}
107
108#[derive(Deserialize)]
109struct UpstashResponse<T> {
110    result: T,
111}
112
113#[derive(Deserialize)]
114struct QueryResultEntry {
115    id: String,
116    score: f32,
117    metadata: Option<serde_json::Value>,
118}
119
120#[derive(Deserialize)]
121struct FetchResultEntry {
122    id: String,
123    vector: Option<Vec<f32>>,
124    metadata: Option<serde_json::Value>,
125}
126
127#[derive(Deserialize)]
128struct RangeResult {
129    #[serde(rename = "nextCursor")]
130    next_cursor: String,
131    vectors: Vec<FetchResultEntry>,
132}
133
134fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
135    match val {
136        serde_json::Value::Object(map) => {
137            let filtered: serde_json::Map<String, serde_json::Value> = map
138                .iter()
139                .filter(|(_, v)| !v.is_null())
140                .map(|(k, v)| (k.clone(), v.clone()))
141                .collect();
142            serde_json::Value::Object(filtered)
143        }
144        other => other.clone(),
145    }
146}
147
148fn build_filter(filter: &MemoryFilter) -> Option<String> {
149    let mut conditions = Vec::new();
150    if let Some(ref uid) = filter.user_id {
151        conditions.push(format!("user_id = '{uid}'"));
152    }
153    if let Some(ref aid) = filter.agent_id {
154        conditions.push(format!("agent_id = '{aid}'"));
155    }
156    if let Some(ref rid) = filter.run_id {
157        conditions.push(format!("run_id = '{rid}'"));
158    }
159
160    if let Some(ref meta) = filter.metadata
161        && let Some(obj) = meta.as_object()
162    {
163        for (key, cond) in obj {
164            if let Some(expr) = translate_metadata_condition(key, cond) {
165                conditions.push(expr);
166            }
167        }
168    }
169
170    if conditions.is_empty() {
171        None
172    } else {
173        Some(conditions.join(" AND "))
174    }
175}
176
177/// Translate a single metadata filter entry to an Upstash filter expression.
178/// Uses `metadata.{key}` path prefix for nested access.
179fn translate_metadata_condition(key: &str, condition: &serde_json::Value) -> Option<String> {
180    let field = format!("metadata.{key}");
181    match key {
182        "AND" => {
183            let parts: Vec<String> = condition
184                .as_array()?
185                .iter()
186                .filter_map(|f| {
187                    let obj = f.as_object()?;
188                    let sub: Vec<String> = obj
189                        .iter()
190                        .filter_map(|(k, v)| translate_metadata_condition(k, v))
191                        .collect();
192                    if sub.is_empty() {
193                        None
194                    } else {
195                        Some(sub.join(" AND "))
196                    }
197                })
198                .collect();
199            if parts.is_empty() {
200                None
201            } else {
202                Some(format!("({})", parts.join(" AND ")))
203            }
204        }
205        "OR" => {
206            let parts: Vec<String> = condition
207                .as_array()?
208                .iter()
209                .filter_map(|f| {
210                    let obj = f.as_object()?;
211                    let sub: Vec<String> = obj
212                        .iter()
213                        .filter_map(|(k, v)| translate_metadata_condition(k, v))
214                        .collect();
215                    if sub.is_empty() {
216                        None
217                    } else {
218                        Some(sub.join(" AND "))
219                    }
220                })
221                .collect();
222            if parts.is_empty() {
223                None
224            } else {
225                Some(format!("({})", parts.join(" OR ")))
226            }
227        }
228        "NOT" => {
229            let parts: Vec<String> = condition
230                .as_array()?
231                .iter()
232                .filter_map(|f| {
233                    let obj = f.as_object()?;
234                    let sub: Vec<String> = obj
235                        .iter()
236                        .filter_map(|(k, v)| translate_metadata_condition(k, v))
237                        .collect();
238                    if sub.is_empty() {
239                        None
240                    } else {
241                        Some(format!("NOT ({})", sub.join(" AND ")))
242                    }
243                })
244                .collect();
245            if parts.is_empty() {
246                None
247            } else {
248                Some(parts.join(" AND "))
249            }
250        }
251        _ => match condition {
252            serde_json::Value::Object(ops) => {
253                let parts: Vec<String> = ops
254                    .iter()
255                    .filter_map(|(op, val)| translate_operator(&field, op, val))
256                    .collect();
257                if parts.is_empty() {
258                    None
259                } else {
260                    Some(parts.join(" AND "))
261                }
262            }
263            serde_json::Value::String(s) => Some(format!("{field} = '{s}'")),
264            serde_json::Value::Number(n) => Some(format!("{field} = {n}")),
265            serde_json::Value::Bool(b) => Some(format!("{field} = {b}")),
266            _ => None,
267        },
268    }
269}
270
271fn translate_operator(field: &str, op: &str, val: &serde_json::Value) -> Option<String> {
272    match op {
273        "eq" => Some(format_equality(field, val)),
274        "ne" => Some(format!("{field} != {}", format_value(val))),
275        "gt" => Some(format!("{field} > {}", format_value(val))),
276        "gte" => Some(format!("{field} >= {}", format_value(val))),
277        "lt" => Some(format!("{field} < {}", format_value(val))),
278        "lte" => Some(format!("{field} <= {}", format_value(val))),
279        "in" => {
280            let items: Vec<String> = val.as_array()?.iter().map(format_value).collect();
281            Some(format!("{field} IN ({})", items.join(", ")))
282        }
283        "nin" => {
284            let items: Vec<String> = val.as_array()?.iter().map(format_value).collect();
285            Some(format!("{field} NOT IN ({})", items.join(", ")))
286        }
287        "contains" => {
288            let s = val.as_str()?;
289            Some(format!("{field} GLOB '*{s}*'"))
290        }
291        "icontains" => {
292            let s = val.as_str()?;
293            Some(format!("{field} GLOB '*{s}*'"))
294        }
295        _ => None,
296    }
297}
298
299fn format_equality(field: &str, val: &serde_json::Value) -> String {
300    format!("{field} = {}", format_value(val))
301}
302
303fn format_value(val: &serde_json::Value) -> String {
304    match val {
305        serde_json::Value::String(s) => format!("'{s}'"),
306        serde_json::Value::Number(n) => n.to_string(),
307        serde_json::Value::Bool(b) => b.to_string(),
308        other => other.to_string(),
309    }
310}
311
312fn matches_filter_local(payload: &serde_json::Value, filter: &MemoryFilter) -> bool {
313    crate::filter::matches_filter(payload, filter)
314}
315
316#[async_trait]
317impl VectorIndex for UpstashVectorIndex {
318    async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
319        let entries = vec![UpsertEntry {
320            id: id.to_string(),
321            vector: vector.to_vec(),
322            metadata: strip_nulls(&payload),
323        }];
324        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
325        Ok(())
326    }
327
328    async fn search(
329        &self,
330        query: &[f32],
331        limit: usize,
332        filters: Option<&MemoryFilter>,
333    ) -> Result<Vec<VectorSearchResult>> {
334        let req = QueryRequest {
335            vector: query.to_vec(),
336            top_k: limit,
337            include_metadata: true,
338            include_vectors: false,
339            filter: filters.and_then(build_filter),
340        };
341        let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
342
343        Ok(resp
344            .result
345            .into_iter()
346            .filter_map(|entry| {
347                let id = Uuid::parse_str(&entry.id).ok()?;
348                Some(VectorSearchResult {
349                    id,
350                    score: entry.score,
351                    payload: entry.metadata.unwrap_or(serde_json::Value::Null),
352                })
353            })
354            .collect())
355    }
356
357    async fn delete(&self, id: &Uuid) -> Result<()> {
358        let ids = vec![id.to_string()];
359        let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
360        Ok(())
361    }
362
363    async fn update(
364        &self,
365        id: &Uuid,
366        vector: Option<&[f32]>,
367        payload: Option<serde_json::Value>,
368    ) -> Result<()> {
369        // Upstash upsert replaces the entry; we need the full vector + metadata.
370        let existing = self.get(id).await?;
371
372        let (vec_data, meta_data) = match existing {
373            Some((v, m)) => (v, m),
374            None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
375        };
376
377        let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
378        let final_meta = payload
379            .map(|p| strip_nulls(&p))
380            .unwrap_or_else(|| strip_nulls(&meta_data));
381
382        let entries = vec![UpsertEntry {
383            id: id.to_string(),
384            vector: final_vec,
385            metadata: final_meta,
386        }];
387        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
388        Ok(())
389    }
390
391    async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
392        let ids = vec![id.to_string()];
393
394        let resp: UpstashResponse<Vec<FetchResultEntry>> = self
395            .post(
396                "fetch",
397                &FetchRequest {
398                    ids,
399                    include_metadata: true,
400                    include_vectors: true,
401                },
402            )
403            .await?;
404
405        Ok(resp.result.into_iter().next().map(|entry| {
406            (
407                entry.vector.unwrap_or_default(),
408                entry.metadata.unwrap_or(serde_json::Value::Null),
409            )
410        }))
411    }
412
413    async fn list(
414        &self,
415        filters: Option<&MemoryFilter>,
416        limit: Option<usize>,
417    ) -> Result<Vec<(Uuid, serde_json::Value)>> {
418        let mut all_results = Vec::new();
419        let mut cursor = "0".to_string();
420        let page_size = 100;
421
422        loop {
423            let req = RangeRequest {
424                cursor,
425                limit: page_size,
426                include_metadata: true,
427                include_vectors: false,
428            };
429            let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
430
431            for entry in resp.result.vectors {
432                if let Ok(id) = Uuid::parse_str(&entry.id) {
433                    let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
434                    let passes = filters
435                        .map(|f| matches_filter_local(&metadata, f))
436                        .unwrap_or(true);
437                    if passes {
438                        all_results.push((id, metadata));
439                    }
440                }
441            }
442
443            if let Some(lim) = limit
444                && all_results.len() >= lim
445            {
446                all_results.truncate(lim);
447                break;
448            }
449
450            if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
451                break;
452            }
453            cursor = resp.result.next_cursor;
454        }
455
456        all_results.sort_by(|a, b| a.0.cmp(&b.0));
457        Ok(all_results)
458    }
459
460    async fn reset(&self) -> Result<()> {
461        let url = self.url("reset");
462        let resp = self
463            .client
464            .post(&url)
465            .header("Authorization", format!("Bearer {}", self.token))
466            .send()
467            .await?;
468
469        if !resp.status().is_success() {
470            let status = resp.status();
471            let text = resp.text().await.unwrap_or_default();
472            return Err(Mem7Error::VectorStore(format!(
473                "Upstash reset HTTP {status}: {text}"
474            )));
475        }
476        Ok(())
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `MemoryFilter` whose only constraint is the given metadata expression.
    fn metadata_filter(metadata: serde_json::Value) -> MemoryFilter {
        MemoryFilter {
            metadata: Some(metadata),
            ..Default::default()
        }
    }

    #[test]
    fn test_build_filter_user_id_only() {
        let filter = MemoryFilter {
            user_id: Some("alice".into()),
            ..Default::default()
        };
        assert_eq!(build_filter(&filter).as_deref(), Some("user_id = 'alice'"));
    }

    #[test]
    fn test_build_filter_metadata_simple_eq() {
        let filter = metadata_filter(json!({"status": "active"}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("metadata.status = 'active'")
        );
    }

    #[test]
    fn test_build_filter_metadata_operators() {
        let cases = [
            (json!({"score": {"gt": 50}}), "metadata.score > 50"),
            (json!({"score": {"lte": 100}}), "metadata.score <= 100"),
        ];
        for (metadata, expected) in cases {
            assert_eq!(
                build_filter(&metadata_filter(metadata)).as_deref(),
                Some(expected)
            );
        }
    }

    #[test]
    fn test_build_filter_metadata_in() {
        let filter = metadata_filter(json!({"tag": {"in": ["rust", "python"]}}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("metadata.tag IN ('rust', 'python')")
        );
    }

    #[test]
    fn test_build_filter_metadata_contains() {
        let filter = metadata_filter(json!({"desc": {"contains": "hello"}}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("metadata.desc GLOB '*hello*'")
        );
    }

    #[test]
    fn test_build_filter_and_combinator() {
        let filter = metadata_filter(json!({"AND": [
            {"status": "active"},
            {"score": {"gt": 50}}
        ]}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("(metadata.status = 'active' AND metadata.score > 50)")
        );
    }

    #[test]
    fn test_build_filter_or_combinator() {
        let filter = metadata_filter(json!({"OR": [
            {"status": "active"},
            {"status": "pending"}
        ]}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("(metadata.status = 'active' OR metadata.status = 'pending')")
        );
    }

    #[test]
    fn test_build_filter_combined_first_class_and_metadata() {
        let filter = MemoryFilter {
            user_id: Some("alice".into()),
            metadata: Some(json!({"status": "active"})),
            ..Default::default()
        };
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("user_id = 'alice' AND metadata.status = 'active'")
        );
    }

    #[test]
    fn test_build_filter_no_conditions() {
        assert!(build_filter(&MemoryFilter::default()).is_none());
    }

    #[test]
    fn test_build_filter_ne_operator() {
        let filter = metadata_filter(json!({"status": {"ne": "deleted"}}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("metadata.status != 'deleted'")
        );
    }

    #[test]
    fn test_build_filter_not_combinator() {
        let filter = metadata_filter(json!({"NOT": [{"status": "deleted"}]}));
        assert_eq!(
            build_filter(&filter).as_deref(),
            Some("NOT (metadata.status = 'deleted')")
        );
    }
}