Skip to main content

mem7_vector/
upstash.rs

1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11/// Upstash Vector REST API client.
12pub struct UpstashVectorIndex {
13    client: Client,
14    base_url: String,
15    token: String,
16    namespace: String,
17}
18
19impl UpstashVectorIndex {
20    pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21        Self {
22            client: Client::new(),
23            base_url: base_url.trim_end_matches('/').to_string(),
24            token: token.to_string(),
25            namespace: namespace.to_string(),
26        }
27    }
28
29    fn url(&self, endpoint: &str) -> String {
30        if self.namespace.is_empty() {
31            format!("{}/{endpoint}", self.base_url)
32        } else {
33            format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34        }
35    }
36
37    async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38        &self,
39        endpoint: &str,
40        body: &T,
41    ) -> Result<R> {
42        let url = self.url(endpoint);
43        debug!(url = %url, "upstash request");
44
45        let resp = self
46            .client
47            .post(&url)
48            .header("Authorization", format!("Bearer {}", self.token))
49            .json(body)
50            .send()
51            .await?;
52
53        if !resp.status().is_success() {
54            let status = resp.status();
55            let text = resp.text().await.unwrap_or_default();
56            return Err(Mem7Error::VectorStore(format!(
57                "Upstash HTTP {status}: {text}"
58            )));
59        }
60
61        resp.json()
62            .await
63            .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64    }
65}
66
67// --- Upstash REST API types ---
68
69#[derive(Serialize)]
70struct UpsertEntry {
71    id: String,
72    vector: Vec<f32>,
73    metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78    vector: Vec<f32>,
79    #[serde(rename = "topK")]
80    top_k: usize,
81    #[serde(rename = "includeMetadata")]
82    include_metadata: bool,
83    #[serde(rename = "includeVectors")]
84    include_vectors: bool,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct RangeRequest {
91    cursor: String,
92    limit: usize,
93    #[serde(rename = "includeMetadata")]
94    include_metadata: bool,
95    #[serde(rename = "includeVectors")]
96    include_vectors: bool,
97}
98
99#[derive(Deserialize)]
100struct UpstashResponse<T> {
101    result: T,
102}
103
104#[derive(Deserialize)]
105struct QueryResultEntry {
106    id: String,
107    score: f32,
108    metadata: Option<serde_json::Value>,
109}
110
111#[derive(Deserialize)]
112struct FetchResultEntry {
113    id: String,
114    vector: Option<Vec<f32>>,
115    metadata: Option<serde_json::Value>,
116}
117
118#[derive(Deserialize)]
119struct RangeResult {
120    #[serde(rename = "nextCursor")]
121    next_cursor: String,
122    vectors: Vec<FetchResultEntry>,
123}
124
125fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
126    match val {
127        serde_json::Value::Object(map) => {
128            let filtered: serde_json::Map<String, serde_json::Value> = map
129                .iter()
130                .filter(|(_, v)| !v.is_null())
131                .map(|(k, v)| (k.clone(), v.clone()))
132                .collect();
133            serde_json::Value::Object(filtered)
134        }
135        other => other.clone(),
136    }
137}
138
139fn build_filter(filter: &MemoryFilter) -> Option<String> {
140    let mut conditions = Vec::new();
141    if let Some(ref uid) = filter.user_id {
142        conditions.push(format!("user_id = '{uid}'"));
143    }
144    if let Some(ref aid) = filter.agent_id {
145        conditions.push(format!("agent_id = '{aid}'"));
146    }
147    if let Some(ref rid) = filter.run_id {
148        conditions.push(format!("run_id = '{rid}'"));
149    }
150    if conditions.is_empty() {
151        None
152    } else {
153        Some(conditions.join(" AND "))
154    }
155}
156
157fn matches_filter_local(metadata: &serde_json::Value, filter: &MemoryFilter) -> bool {
158    if let Some(ref uid) = filter.user_id
159        && metadata.get("user_id").and_then(|v| v.as_str()) != Some(uid.as_str())
160    {
161        return false;
162    }
163    if let Some(ref aid) = filter.agent_id
164        && metadata.get("agent_id").and_then(|v| v.as_str()) != Some(aid.as_str())
165    {
166        return false;
167    }
168    if let Some(ref rid) = filter.run_id
169        && metadata.get("run_id").and_then(|v| v.as_str()) != Some(rid.as_str())
170    {
171        return false;
172    }
173    true
174}
175
176#[async_trait]
177impl VectorIndex for UpstashVectorIndex {
178    async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
179        let entries = vec![UpsertEntry {
180            id: id.to_string(),
181            vector: vector.to_vec(),
182            metadata: strip_nulls(&payload),
183        }];
184        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
185        Ok(())
186    }
187
188    async fn search(
189        &self,
190        query: &[f32],
191        limit: usize,
192        filters: Option<&MemoryFilter>,
193    ) -> Result<Vec<VectorSearchResult>> {
194        let req = QueryRequest {
195            vector: query.to_vec(),
196            top_k: limit,
197            include_metadata: true,
198            include_vectors: false,
199            filter: filters.and_then(build_filter),
200        };
201        let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
202
203        Ok(resp
204            .result
205            .into_iter()
206            .filter_map(|entry| {
207                let id = Uuid::parse_str(&entry.id).ok()?;
208                Some(VectorSearchResult {
209                    id,
210                    score: entry.score,
211                    payload: entry.metadata.unwrap_or(serde_json::Value::Null),
212                })
213            })
214            .collect())
215    }
216
217    async fn delete(&self, id: &Uuid) -> Result<()> {
218        let ids = vec![id.to_string()];
219        let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
220        Ok(())
221    }
222
223    async fn update(
224        &self,
225        id: &Uuid,
226        vector: Option<&[f32]>,
227        payload: Option<serde_json::Value>,
228    ) -> Result<()> {
229        // Upstash upsert replaces the entry; we need the full vector + metadata.
230        let existing = self.get(id).await?;
231
232        let (vec_data, meta_data) = match existing {
233            Some((v, m)) => (v, m),
234            None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
235        };
236
237        let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
238        let final_meta = payload
239            .map(|p| strip_nulls(&p))
240            .unwrap_or_else(|| strip_nulls(&meta_data));
241
242        let entries = vec![UpsertEntry {
243            id: id.to_string(),
244            vector: final_vec,
245            metadata: final_meta,
246        }];
247        let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
248        Ok(())
249    }
250
251    async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
252        let url = if self.namespace.is_empty() {
253            format!("{}/fetch/{id}", self.base_url)
254        } else {
255            format!("{}/fetch/{id}?ns={}", self.base_url, self.namespace)
256        };
257
258        let resp = self
259            .client
260            .get(&url)
261            .header("Authorization", format!("Bearer {}", self.token))
262            .send()
263            .await?;
264
265        if !resp.status().is_success() {
266            let status = resp.status();
267            let text = resp.text().await.unwrap_or_default();
268            return Err(Mem7Error::VectorStore(format!(
269                "Upstash fetch HTTP {status}: {text}"
270            )));
271        }
272
273        let data: UpstashResponse<Option<FetchResultEntry>> = resp
274            .json()
275            .await
276            .map_err(|e| Mem7Error::VectorStore(format!("Upstash fetch parse error: {e}")))?;
277
278        Ok(data.result.map(|entry| {
279            (
280                entry.vector.unwrap_or_default(),
281                entry.metadata.unwrap_or(serde_json::Value::Null),
282            )
283        }))
284    }
285
286    async fn list(
287        &self,
288        filters: Option<&MemoryFilter>,
289        limit: Option<usize>,
290    ) -> Result<Vec<(Uuid, serde_json::Value)>> {
291        let mut all_results = Vec::new();
292        let mut cursor = "0".to_string();
293        let page_size = 100;
294
295        loop {
296            let req = RangeRequest {
297                cursor,
298                limit: page_size,
299                include_metadata: true,
300                include_vectors: false,
301            };
302            let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
303
304            for entry in resp.result.vectors {
305                if let Ok(id) = Uuid::parse_str(&entry.id) {
306                    let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
307                    let passes = filters
308                        .map(|f| matches_filter_local(&metadata, f))
309                        .unwrap_or(true);
310                    if passes {
311                        all_results.push((id, metadata));
312                    }
313                }
314            }
315
316            if let Some(lim) = limit
317                && all_results.len() >= lim
318            {
319                all_results.truncate(lim);
320                break;
321            }
322
323            if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
324                break;
325            }
326            cursor = resp.result.next_cursor;
327        }
328
329        all_results.sort_by(|a, b| a.0.cmp(&b.0));
330        Ok(all_results)
331    }
332
333    async fn reset(&self) -> Result<()> {
334        let url = self.url("reset");
335        let resp = self
336            .client
337            .post(&url)
338            .header("Authorization", format!("Bearer {}", self.token))
339            .send()
340            .await?;
341
342        if !resp.status().is_success() {
343            let status = resp.status();
344            let text = resp.text().await.unwrap_or_default();
345            return Err(Mem7Error::VectorStore(format!(
346                "Upstash reset HTTP {status}: {text}"
347            )));
348        }
349        Ok(())
350    }
351}