Skip to main content

mem7_vector/
upstash.rs

1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11/// Upstash Vector REST API client.
12pub struct UpstashVectorIndex {
13    client: Client,
14    base_url: String,
15    token: String,
16    namespace: String,
17}
18
19impl UpstashVectorIndex {
20    pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21        Self {
22            client: Client::new(),
23            base_url: base_url.trim_end_matches('/').to_string(),
24            token: token.to_string(),
25            namespace: namespace.to_string(),
26        }
27    }
28
29    fn url(&self, endpoint: &str) -> String {
30        if self.namespace.is_empty() {
31            format!("{}/{endpoint}", self.base_url)
32        } else {
33            format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34        }
35    }
36
37    async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38        &self,
39        endpoint: &str,
40        body: &T,
41    ) -> Result<R> {
42        let url = self.url(endpoint);
43        debug!(url = %url, "upstash request");
44
45        let resp = self
46            .client
47            .post(&url)
48            .header("Authorization", format!("Bearer {}", self.token))
49            .json(body)
50            .send()
51            .await?;
52
53        if !resp.status().is_success() {
54            let status = resp.status();
55            let text = resp.text().await.unwrap_or_default();
56            return Err(Mem7Error::VectorStore(format!(
57                "Upstash HTTP {status}: {text}"
58            )));
59        }
60
61        resp.json()
62            .await
63            .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64    }
65}
66
67// --- Upstash REST API types ---
68
69#[derive(Serialize)]
70struct UpsertEntry {
71    id: String,
72    vector: Vec<f32>,
73    metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78    vector: Vec<f32>,
79    #[serde(rename = "topK")]
80    top_k: usize,
81    #[serde(rename = "includeMetadata")]
82    include_metadata: bool,
83    #[serde(rename = "includeVectors")]
84    include_vectors: bool,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct FetchRequest {
91    ids: Vec<String>,
92    #[serde(rename = "includeMetadata")]
93    include_metadata: bool,
94    #[serde(rename = "includeVectors")]
95    include_vectors: bool,
96}
97
98#[derive(Serialize)]
99struct RangeRequest {
100    cursor: String,
101    limit: usize,
102    #[serde(rename = "includeMetadata")]
103    include_metadata: bool,
104    #[serde(rename = "includeVectors")]
105    include_vectors: bool,
106}
107
108#[derive(Deserialize)]
109struct UpstashResponse<T> {
110    result: T,
111}
112
113#[derive(Deserialize)]
114struct QueryResultEntry {
115    id: String,
116    score: f32,
117    metadata: Option<serde_json::Value>,
118}
119
120#[derive(Deserialize)]
121struct FetchResultEntry {
122    id: String,
123    vector: Option<Vec<f32>>,
124    metadata: Option<serde_json::Value>,
125}
126
127#[derive(Deserialize)]
128struct RangeResult {
129    #[serde(rename = "nextCursor")]
130    next_cursor: String,
131    vectors: Vec<FetchResultEntry>,
132}
133
134fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
135    match val {
136        serde_json::Value::Object(map) => {
137            let filtered: serde_json::Map<String, serde_json::Value> = map
138                .iter()
139                .filter(|(_, v)| !v.is_null())
140                .map(|(k, v)| (k.clone(), v.clone()))
141                .collect();
142            serde_json::Value::Object(filtered)
143        }
144        other => other.clone(),
145    }
146}
147
148fn build_filter(filter: &MemoryFilter) -> Option<String> {
149    let mut conditions = Vec::new();
150    if let Some(ref uid) = filter.user_id {
151        conditions.push(format!("user_id = '{uid}'"));
152    }
153    if let Some(ref aid) = filter.agent_id {
154        conditions.push(format!("agent_id = '{aid}'"));
155    }
156    if let Some(ref rid) = filter.run_id {
157        conditions.push(format!("run_id = '{rid}'"));
158    }
159    if conditions.is_empty() {
160        None
161    } else {
162        Some(conditions.join(" AND "))
163    }
164}
165
166fn matches_filter_local(metadata: &serde_json::Value, filter: &MemoryFilter) -> bool {
167    if let Some(ref uid) = filter.user_id
168        && metadata.get("user_id").and_then(|v| v.as_str()) != Some(uid.as_str())
169    {
170        return false;
171    }
172    if let Some(ref aid) = filter.agent_id
173        && metadata.get("agent_id").and_then(|v| v.as_str()) != Some(aid.as_str())
174    {
175        return false;
176    }
177    if let Some(ref rid) = filter.run_id
178        && metadata.get("run_id").and_then(|v| v.as_str()) != Some(rid.as_str())
179    {
180        return false;
181    }
182    true
183}
184
185#[async_trait]
186impl VectorIndex for UpstashVectorIndex {
187    async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
188        let entries = vec![UpsertEntry {
189            id: id.to_string(),
190            vector: vector.to_vec(),
191            metadata: strip_nulls(&payload),
192        }];
193        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
194        Ok(())
195    }
196
197    async fn search(
198        &self,
199        query: &[f32],
200        limit: usize,
201        filters: Option<&MemoryFilter>,
202    ) -> Result<Vec<VectorSearchResult>> {
203        let req = QueryRequest {
204            vector: query.to_vec(),
205            top_k: limit,
206            include_metadata: true,
207            include_vectors: false,
208            filter: filters.and_then(build_filter),
209        };
210        let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
211
212        Ok(resp
213            .result
214            .into_iter()
215            .filter_map(|entry| {
216                let id = Uuid::parse_str(&entry.id).ok()?;
217                Some(VectorSearchResult {
218                    id,
219                    score: entry.score,
220                    payload: entry.metadata.unwrap_or(serde_json::Value::Null),
221                })
222            })
223            .collect())
224    }
225
226    async fn delete(&self, id: &Uuid) -> Result<()> {
227        let ids = vec![id.to_string()];
228        let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
229        Ok(())
230    }
231
232    async fn update(
233        &self,
234        id: &Uuid,
235        vector: Option<&[f32]>,
236        payload: Option<serde_json::Value>,
237    ) -> Result<()> {
238        // Upstash upsert replaces the entry; we need the full vector + metadata.
239        let existing = self.get(id).await?;
240
241        let (vec_data, meta_data) = match existing {
242            Some((v, m)) => (v, m),
243            None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
244        };
245
246        let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
247        let final_meta = payload
248            .map(|p| strip_nulls(&p))
249            .unwrap_or_else(|| strip_nulls(&meta_data));
250
251        let entries = vec![UpsertEntry {
252            id: id.to_string(),
253            vector: final_vec,
254            metadata: final_meta,
255        }];
256        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
257        Ok(())
258    }
259
260    async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
261        let ids = vec![id.to_string()];
262
263        let resp: UpstashResponse<Vec<FetchResultEntry>> = self
264            .post(
265                "fetch",
266                &FetchRequest {
267                    ids,
268                    include_metadata: true,
269                    include_vectors: true,
270                },
271            )
272            .await?;
273
274        Ok(resp.result.into_iter().next().map(|entry| {
275            (
276                entry.vector.unwrap_or_default(),
277                entry.metadata.unwrap_or(serde_json::Value::Null),
278            )
279        }))
280    }
281
282    async fn list(
283        &self,
284        filters: Option<&MemoryFilter>,
285        limit: Option<usize>,
286    ) -> Result<Vec<(Uuid, serde_json::Value)>> {
287        let mut all_results = Vec::new();
288        let mut cursor = "0".to_string();
289        let page_size = 100;
290
291        loop {
292            let req = RangeRequest {
293                cursor,
294                limit: page_size,
295                include_metadata: true,
296                include_vectors: false,
297            };
298            let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
299
300            for entry in resp.result.vectors {
301                if let Ok(id) = Uuid::parse_str(&entry.id) {
302                    let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
303                    let passes = filters
304                        .map(|f| matches_filter_local(&metadata, f))
305                        .unwrap_or(true);
306                    if passes {
307                        all_results.push((id, metadata));
308                    }
309                }
310            }
311
312            if let Some(lim) = limit
313                && all_results.len() >= lim
314            {
315                all_results.truncate(lim);
316                break;
317            }
318
319            if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
320                break;
321            }
322            cursor = resp.result.next_cursor;
323        }
324
325        all_results.sort_by(|a, b| a.0.cmp(&b.0));
326        Ok(all_results)
327    }
328
329    async fn reset(&self) -> Result<()> {
330        let url = self.url("reset");
331        let resp = self
332            .client
333            .post(&url)
334            .header("Authorization", format!("Bearer {}", self.token))
335            .send()
336            .await?;
337
338        if !resp.status().is_success() {
339            let status = resp.status();
340            let text = resp.text().await.unwrap_or_default();
341            return Err(Mem7Error::VectorStore(format!(
342                "Upstash reset HTTP {status}: {text}"
343            )));
344        }
345        Ok(())
346    }
347}