1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11pub struct UpstashVectorIndex {
13 client: Client,
14 base_url: String,
15 token: String,
16 namespace: String,
17}
18
19impl UpstashVectorIndex {
20 pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21 Self {
22 client: Client::new(),
23 base_url: base_url.trim_end_matches('/').to_string(),
24 token: token.to_string(),
25 namespace: namespace.to_string(),
26 }
27 }
28
29 fn url(&self, endpoint: &str) -> String {
30 if self.namespace.is_empty() {
31 format!("{}/{endpoint}", self.base_url)
32 } else {
33 format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34 }
35 }
36
37 async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38 &self,
39 endpoint: &str,
40 body: &T,
41 ) -> Result<R> {
42 let url = self.url(endpoint);
43 debug!(url = %url, "upstash request");
44
45 let resp = self
46 .client
47 .post(&url)
48 .header("Authorization", format!("Bearer {}", self.token))
49 .json(body)
50 .send()
51 .await?;
52
53 if !resp.status().is_success() {
54 let status = resp.status();
55 let text = resp.text().await.unwrap_or_default();
56 return Err(Mem7Error::VectorStore(format!(
57 "Upstash HTTP {status}: {text}"
58 )));
59 }
60
61 resp.json()
62 .await
63 .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64 }
65}
66
67#[derive(Serialize)]
70struct UpsertEntry {
71 id: String,
72 vector: Vec<f32>,
73 metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78 vector: Vec<f32>,
79 #[serde(rename = "topK")]
80 top_k: usize,
81 #[serde(rename = "includeMetadata")]
82 include_metadata: bool,
83 #[serde(rename = "includeVectors")]
84 include_vectors: bool,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct RangeRequest {
91 cursor: String,
92 limit: usize,
93 #[serde(rename = "includeMetadata")]
94 include_metadata: bool,
95 #[serde(rename = "includeVectors")]
96 include_vectors: bool,
97}
98
99#[derive(Deserialize)]
100struct UpstashResponse<T> {
101 result: T,
102}
103
104#[derive(Deserialize)]
105struct QueryResultEntry {
106 id: String,
107 score: f32,
108 metadata: Option<serde_json::Value>,
109}
110
111#[derive(Deserialize)]
112struct FetchResultEntry {
113 id: String,
114 vector: Option<Vec<f32>>,
115 metadata: Option<serde_json::Value>,
116}
117
118#[derive(Deserialize)]
119struct RangeResult {
120 #[serde(rename = "nextCursor")]
121 next_cursor: String,
122 vectors: Vec<FetchResultEntry>,
123}
124
125fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
126 match val {
127 serde_json::Value::Object(map) => {
128 let filtered: serde_json::Map<String, serde_json::Value> = map
129 .iter()
130 .filter(|(_, v)| !v.is_null())
131 .map(|(k, v)| (k.clone(), v.clone()))
132 .collect();
133 serde_json::Value::Object(filtered)
134 }
135 other => other.clone(),
136 }
137}
138
139fn build_filter(filter: &MemoryFilter) -> Option<String> {
140 let mut conditions = Vec::new();
141 if let Some(ref uid) = filter.user_id {
142 conditions.push(format!("user_id = '{uid}'"));
143 }
144 if let Some(ref aid) = filter.agent_id {
145 conditions.push(format!("agent_id = '{aid}'"));
146 }
147 if let Some(ref rid) = filter.run_id {
148 conditions.push(format!("run_id = '{rid}'"));
149 }
150 if conditions.is_empty() {
151 None
152 } else {
153 Some(conditions.join(" AND "))
154 }
155}
156
157fn matches_filter_local(metadata: &serde_json::Value, filter: &MemoryFilter) -> bool {
158 if let Some(ref uid) = filter.user_id
159 && metadata.get("user_id").and_then(|v| v.as_str()) != Some(uid.as_str())
160 {
161 return false;
162 }
163 if let Some(ref aid) = filter.agent_id
164 && metadata.get("agent_id").and_then(|v| v.as_str()) != Some(aid.as_str())
165 {
166 return false;
167 }
168 if let Some(ref rid) = filter.run_id
169 && metadata.get("run_id").and_then(|v| v.as_str()) != Some(rid.as_str())
170 {
171 return false;
172 }
173 true
174}
175
176#[async_trait]
177impl VectorIndex for UpstashVectorIndex {
178 async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
179 let entries = vec![UpsertEntry {
180 id: id.to_string(),
181 vector: vector.to_vec(),
182 metadata: strip_nulls(&payload),
183 }];
184 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
185 Ok(())
186 }
187
188 async fn search(
189 &self,
190 query: &[f32],
191 limit: usize,
192 filters: Option<&MemoryFilter>,
193 ) -> Result<Vec<VectorSearchResult>> {
194 let req = QueryRequest {
195 vector: query.to_vec(),
196 top_k: limit,
197 include_metadata: true,
198 include_vectors: false,
199 filter: filters.and_then(build_filter),
200 };
201 let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
202
203 Ok(resp
204 .result
205 .into_iter()
206 .filter_map(|entry| {
207 let id = Uuid::parse_str(&entry.id).ok()?;
208 Some(VectorSearchResult {
209 id,
210 score: entry.score,
211 payload: entry.metadata.unwrap_or(serde_json::Value::Null),
212 })
213 })
214 .collect())
215 }
216
217 async fn delete(&self, id: &Uuid) -> Result<()> {
218 let ids = vec![id.to_string()];
219 let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
220 Ok(())
221 }
222
223 async fn update(
224 &self,
225 id: &Uuid,
226 vector: Option<&[f32]>,
227 payload: Option<serde_json::Value>,
228 ) -> Result<()> {
229 let existing = self.get(id).await?;
231
232 let (vec_data, meta_data) = match existing {
233 Some((v, m)) => (v, m),
234 None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
235 };
236
237 let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
238 let final_meta = payload
239 .map(|p| strip_nulls(&p))
240 .unwrap_or_else(|| strip_nulls(&meta_data));
241
242 let entries = vec![UpsertEntry {
243 id: id.to_string(),
244 vector: final_vec,
245 metadata: final_meta,
246 }];
247 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
248 Ok(())
249 }
250
251 async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
252 let url = if self.namespace.is_empty() {
253 format!("{}/fetch/{id}", self.base_url)
254 } else {
255 format!("{}/fetch/{id}?ns={}", self.base_url, self.namespace)
256 };
257
258 let resp = self
259 .client
260 .get(&url)
261 .header("Authorization", format!("Bearer {}", self.token))
262 .send()
263 .await?;
264
265 if !resp.status().is_success() {
266 let status = resp.status();
267 let text = resp.text().await.unwrap_or_default();
268 return Err(Mem7Error::VectorStore(format!(
269 "Upstash fetch HTTP {status}: {text}"
270 )));
271 }
272
273 let data: UpstashResponse<Option<FetchResultEntry>> = resp
274 .json()
275 .await
276 .map_err(|e| Mem7Error::VectorStore(format!("Upstash fetch parse error: {e}")))?;
277
278 Ok(data.result.map(|entry| {
279 (
280 entry.vector.unwrap_or_default(),
281 entry.metadata.unwrap_or(serde_json::Value::Null),
282 )
283 }))
284 }
285
286 async fn list(
287 &self,
288 filters: Option<&MemoryFilter>,
289 limit: Option<usize>,
290 ) -> Result<Vec<(Uuid, serde_json::Value)>> {
291 let mut all_results = Vec::new();
292 let mut cursor = "0".to_string();
293 let page_size = 100;
294
295 loop {
296 let req = RangeRequest {
297 cursor,
298 limit: page_size,
299 include_metadata: true,
300 include_vectors: false,
301 };
302 let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
303
304 for entry in resp.result.vectors {
305 if let Ok(id) = Uuid::parse_str(&entry.id) {
306 let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
307 let passes = filters
308 .map(|f| matches_filter_local(&metadata, f))
309 .unwrap_or(true);
310 if passes {
311 all_results.push((id, metadata));
312 }
313 }
314 }
315
316 if let Some(lim) = limit
317 && all_results.len() >= lim
318 {
319 all_results.truncate(lim);
320 break;
321 }
322
323 if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
324 break;
325 }
326 cursor = resp.result.next_cursor;
327 }
328
329 all_results.sort_by(|a, b| a.0.cmp(&b.0));
330 Ok(all_results)
331 }
332
333 async fn reset(&self) -> Result<()> {
334 let url = self.url("reset");
335 let resp = self
336 .client
337 .post(&url)
338 .header("Authorization", format!("Bearer {}", self.token))
339 .send()
340 .await?;
341
342 if !resp.status().is_success() {
343 let status = resp.status();
344 let text = resp.text().await.unwrap_or_default();
345 return Err(Mem7Error::VectorStore(format!(
346 "Upstash reset HTTP {status}: {text}"
347 )));
348 }
349 Ok(())
350 }
351}