1use async_trait::async_trait;
2use mem7_core::MemoryFilter;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7use uuid::Uuid;
8
9use crate::{VectorIndex, VectorSearchResult};
10
11pub struct UpstashVectorIndex {
13 client: Client,
14 base_url: String,
15 token: String,
16 namespace: String,
17}
18
19impl UpstashVectorIndex {
20 pub fn new(base_url: &str, token: &str, namespace: &str) -> Self {
21 Self {
22 client: Client::new(),
23 base_url: base_url.trim_end_matches('/').to_string(),
24 token: token.to_string(),
25 namespace: namespace.to_string(),
26 }
27 }
28
29 fn url(&self, endpoint: &str) -> String {
30 if self.namespace.is_empty() {
31 format!("{}/{endpoint}", self.base_url)
32 } else {
33 format!("{}/{endpoint}/{}", self.base_url, self.namespace)
34 }
35 }
36
37 async fn post<T: Serialize, R: for<'de> Deserialize<'de>>(
38 &self,
39 endpoint: &str,
40 body: &T,
41 ) -> Result<R> {
42 let url = self.url(endpoint);
43 debug!(url = %url, "upstash request");
44
45 let resp = self
46 .client
47 .post(&url)
48 .header("Authorization", format!("Bearer {}", self.token))
49 .json(body)
50 .send()
51 .await?;
52
53 if !resp.status().is_success() {
54 let status = resp.status();
55 let text = resp.text().await.unwrap_or_default();
56 return Err(Mem7Error::VectorStore(format!(
57 "Upstash HTTP {status}: {text}"
58 )));
59 }
60
61 resp.json()
62 .await
63 .map_err(|e| Mem7Error::VectorStore(format!("Upstash response parse error: {e}")))
64 }
65}
66
67#[derive(Serialize)]
70struct UpsertEntry {
71 id: String,
72 vector: Vec<f32>,
73 metadata: serde_json::Value,
74}
75
76#[derive(Serialize)]
77struct QueryRequest {
78 vector: Vec<f32>,
79 #[serde(rename = "topK")]
80 top_k: usize,
81 #[serde(rename = "includeMetadata")]
82 include_metadata: bool,
83 #[serde(rename = "includeVectors")]
84 include_vectors: bool,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 filter: Option<String>,
87}
88
89#[derive(Serialize)]
90struct FetchRequest {
91 ids: Vec<String>,
92 #[serde(rename = "includeMetadata")]
93 include_metadata: bool,
94 #[serde(rename = "includeVectors")]
95 include_vectors: bool,
96}
97
98#[derive(Serialize)]
99struct RangeRequest {
100 cursor: String,
101 limit: usize,
102 #[serde(rename = "includeMetadata")]
103 include_metadata: bool,
104 #[serde(rename = "includeVectors")]
105 include_vectors: bool,
106}
107
108#[derive(Deserialize)]
109struct UpstashResponse<T> {
110 result: T,
111}
112
113#[derive(Deserialize)]
114struct QueryResultEntry {
115 id: String,
116 score: f32,
117 metadata: Option<serde_json::Value>,
118}
119
120#[derive(Deserialize)]
121struct FetchResultEntry {
122 id: String,
123 vector: Option<Vec<f32>>,
124 metadata: Option<serde_json::Value>,
125}
126
127#[derive(Deserialize)]
128struct RangeResult {
129 #[serde(rename = "nextCursor")]
130 next_cursor: String,
131 vectors: Vec<FetchResultEntry>,
132}
133
134fn strip_nulls(val: &serde_json::Value) -> serde_json::Value {
135 match val {
136 serde_json::Value::Object(map) => {
137 let filtered: serde_json::Map<String, serde_json::Value> = map
138 .iter()
139 .filter(|(_, v)| !v.is_null())
140 .map(|(k, v)| (k.clone(), v.clone()))
141 .collect();
142 serde_json::Value::Object(filtered)
143 }
144 other => other.clone(),
145 }
146}
147
148fn build_filter(filter: &MemoryFilter) -> Option<String> {
149 let mut conditions = Vec::new();
150 if let Some(ref uid) = filter.user_id {
151 conditions.push(format!("user_id = '{uid}'"));
152 }
153 if let Some(ref aid) = filter.agent_id {
154 conditions.push(format!("agent_id = '{aid}'"));
155 }
156 if let Some(ref rid) = filter.run_id {
157 conditions.push(format!("run_id = '{rid}'"));
158 }
159 if conditions.is_empty() {
160 None
161 } else {
162 Some(conditions.join(" AND "))
163 }
164}
165
166fn matches_filter_local(metadata: &serde_json::Value, filter: &MemoryFilter) -> bool {
167 if let Some(ref uid) = filter.user_id
168 && metadata.get("user_id").and_then(|v| v.as_str()) != Some(uid.as_str())
169 {
170 return false;
171 }
172 if let Some(ref aid) = filter.agent_id
173 && metadata.get("agent_id").and_then(|v| v.as_str()) != Some(aid.as_str())
174 {
175 return false;
176 }
177 if let Some(ref rid) = filter.run_id
178 && metadata.get("run_id").and_then(|v| v.as_str()) != Some(rid.as_str())
179 {
180 return false;
181 }
182 true
183}
184
185#[async_trait]
186impl VectorIndex for UpstashVectorIndex {
187 async fn insert(&self, id: Uuid, vector: &[f32], payload: serde_json::Value) -> Result<()> {
188 let entries = vec![UpsertEntry {
189 id: id.to_string(),
190 vector: vector.to_vec(),
191 metadata: strip_nulls(&payload),
192 }];
193 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
194 Ok(())
195 }
196
197 async fn search(
198 &self,
199 query: &[f32],
200 limit: usize,
201 filters: Option<&MemoryFilter>,
202 ) -> Result<Vec<VectorSearchResult>> {
203 let req = QueryRequest {
204 vector: query.to_vec(),
205 top_k: limit,
206 include_metadata: true,
207 include_vectors: false,
208 filter: filters.and_then(build_filter),
209 };
210 let resp: UpstashResponse<Vec<QueryResultEntry>> = self.post("query", &req).await?;
211
212 Ok(resp
213 .result
214 .into_iter()
215 .filter_map(|entry| {
216 let id = Uuid::parse_str(&entry.id).ok()?;
217 Some(VectorSearchResult {
218 id,
219 score: entry.score,
220 payload: entry.metadata.unwrap_or(serde_json::Value::Null),
221 })
222 })
223 .collect())
224 }
225
226 async fn delete(&self, id: &Uuid) -> Result<()> {
227 let ids = vec![id.to_string()];
228 let _: UpstashResponse<serde_json::Value> = self.post("delete", &ids).await?;
229 Ok(())
230 }
231
232 async fn update(
233 &self,
234 id: &Uuid,
235 vector: Option<&[f32]>,
236 payload: Option<serde_json::Value>,
237 ) -> Result<()> {
238 let existing = self.get(id).await?;
240
241 let (vec_data, meta_data) = match existing {
242 Some((v, m)) => (v, m),
243 None => return Err(Mem7Error::NotFound(format!("vector entry {id}"))),
244 };
245
246 let final_vec = vector.map(|v| v.to_vec()).unwrap_or(vec_data);
247 let final_meta = payload
248 .map(|p| strip_nulls(&p))
249 .unwrap_or_else(|| strip_nulls(&meta_data));
250
251 let entries = vec![UpsertEntry {
252 id: id.to_string(),
253 vector: final_vec,
254 metadata: final_meta,
255 }];
256 let _: UpstashResponse<String> = self.post("upsert", &entries).await?;
257 Ok(())
258 }
259
260 async fn get(&self, id: &Uuid) -> Result<Option<(Vec<f32>, serde_json::Value)>> {
261 let ids = vec![id.to_string()];
262
263 let resp: UpstashResponse<Vec<FetchResultEntry>> = self
264 .post(
265 "fetch",
266 &FetchRequest {
267 ids,
268 include_metadata: true,
269 include_vectors: true,
270 },
271 )
272 .await?;
273
274 Ok(resp.result.into_iter().next().map(|entry| {
275 (
276 entry.vector.unwrap_or_default(),
277 entry.metadata.unwrap_or(serde_json::Value::Null),
278 )
279 }))
280 }
281
282 async fn list(
283 &self,
284 filters: Option<&MemoryFilter>,
285 limit: Option<usize>,
286 ) -> Result<Vec<(Uuid, serde_json::Value)>> {
287 let mut all_results = Vec::new();
288 let mut cursor = "0".to_string();
289 let page_size = 100;
290
291 loop {
292 let req = RangeRequest {
293 cursor,
294 limit: page_size,
295 include_metadata: true,
296 include_vectors: false,
297 };
298 let resp: UpstashResponse<RangeResult> = self.post("range", &req).await?;
299
300 for entry in resp.result.vectors {
301 if let Ok(id) = Uuid::parse_str(&entry.id) {
302 let metadata = entry.metadata.unwrap_or(serde_json::Value::Null);
303 let passes = filters
304 .map(|f| matches_filter_local(&metadata, f))
305 .unwrap_or(true);
306 if passes {
307 all_results.push((id, metadata));
308 }
309 }
310 }
311
312 if let Some(lim) = limit
313 && all_results.len() >= lim
314 {
315 all_results.truncate(lim);
316 break;
317 }
318
319 if resp.result.next_cursor.is_empty() || resp.result.next_cursor == "0" {
320 break;
321 }
322 cursor = resp.result.next_cursor;
323 }
324
325 all_results.sort_by(|a, b| a.0.cmp(&b.0));
326 Ok(all_results)
327 }
328
329 async fn reset(&self) -> Result<()> {
330 let url = self.url("reset");
331 let resp = self
332 .client
333 .post(&url)
334 .header("Authorization", format!("Bearer {}", self.token))
335 .send()
336 .await?;
337
338 if !resp.status().is_success() {
339 let status = resp.status();
340 let text = resp.text().await.unwrap_or_default();
341 return Err(Mem7Error::VectorStore(format!(
342 "Upstash reset HTTP {status}: {text}"
343 )));
344 }
345 Ok(())
346 }
347}