1use chrono::Utc;
4use mem_embed::Embedder;
5use mem_graph::GraphStore;
6use mem_types::*;
7use mem_vec::VecStore;
8use std::collections::HashMap;
9use uuid::Uuid;
10
/// A memory cube built from three pluggable backends: a graph store, a vector
/// store and an embedder.
///
/// The struct itself places no bounds on its type parameters; the bounds each
/// backend must satisfy are declared on the `impl` blocks.
pub struct NaiveMemCube<G, V, E> {
    /// Graph backend that holds memory nodes and the edges between them.
    pub graph: G,
    /// Vector backend that holds one embedding item per stored memory id.
    pub vec_store: V,
    /// Produces embedding vectors for memory text and search queries.
    pub embedder: E,
    /// Scope given to a new memory when the request does not name a
    /// recognised one.
    pub default_scope: String,
}
19
20impl<G, V, E> NaiveMemCube<G, V, E>
21where
22 G: GraphStore + Send + Sync,
23 V: VecStore + Send + Sync,
24 E: Embedder + Send + Sync,
25{
26 pub fn new(graph: G, vec_store: V, embedder: E) -> Self {
27 Self {
28 graph,
29 vec_store,
30 embedder,
31 default_scope: "LongTermMemory".to_string(),
32 }
33 }
34
35 fn node_owner(metadata: &HashMap<String, serde_json::Value>) -> &str {
36 metadata
37 .get("user_name")
38 .and_then(|v| v.as_str())
39 .unwrap_or("")
40 }
41
42 fn parse_cursor(cursor: Option<&str>) -> Result<usize, MemCubeError> {
43 match cursor {
44 Some(c) => c
45 .parse::<usize>()
46 .map_err(|_| MemCubeError::BadRequest("invalid graph cursor".to_string())),
47 None => Ok(0),
48 }
49 }
50
51 fn normalize_scope(scope: &str) -> Option<&'static str> {
52 let normalized = scope
53 .trim()
54 .to_ascii_lowercase()
55 .replace([' ', '-', '_'], "");
56 match normalized.as_str() {
57 "workingmemory" | "working" | "shortterm" | "shorttermmemory" | "stm" | "recent" => {
58 Some(MemoryScope::WorkingMemory.as_str())
59 }
60 "usermemory" | "user" | "midterm" | "midtermmemory" | "profile" | "preference" => {
61 Some(MemoryScope::UserMemory.as_str())
62 }
63 "longtermmemory" | "longterm" | "ltm" => Some(MemoryScope::LongTermMemory.as_str()),
64 _ => None,
65 }
66 }
67
68 fn resolve_scope(req: &ApiAddRequest, default_scope: &str) -> String {
69 let from_info = req.info.as_ref().and_then(|info| {
70 info.get("scope")
71 .or_else(|| info.get("memory_scope"))
72 .and_then(|v| v.as_str())
73 });
74 from_info
75 .and_then(Self::normalize_scope)
76 .unwrap_or(default_scope)
77 .to_string()
78 }
79
80 fn bucket_name_for_scope(scope: &str) -> Option<&'static str> {
81 match scope {
82 "WorkingMemory" => Some("short_term"),
83 "UserMemory" => Some("mid_term"),
84 "LongTermMemory" => Some("long_term"),
85 _ => None,
86 }
87 }
88}
89
90#[async_trait::async_trait]
91impl<G, V, E> MemCube for NaiveMemCube<G, V, E>
92where
93 G: GraphStore + Send + Sync,
94 V: VecStore + Send + Sync,
95 E: Embedder + Send + Sync,
96{
97 async fn add_memories(&self, req: &ApiAddRequest) -> Result<MemoryResponse, MemCubeError> {
98 let content = req.content_to_store().ok_or_else(|| {
99 MemCubeError::Other("no messages or memory_content in request".to_string())
100 })?;
101 let cube_ids = req.writable_cube_ids();
102 let user_name = cube_ids.first().map(String::as_str).unwrap_or(&req.user_id);
103 let scope = Self::resolve_scope(req, &self.default_scope);
104
105 let id = Uuid::new_v4().to_string();
106 let embedding = self.embedder.embed(&content).await?;
107 let mut metadata = HashMap::new();
108 metadata.insert(
109 "scope".to_string(),
110 serde_json::Value::String(scope.clone()),
111 );
112 metadata.insert(
113 "created_at".to_string(),
114 serde_json::Value::String(Utc::now().to_rfc3339()),
115 );
116 if let Some(ref session_id) = req.session_id {
117 metadata.insert(
118 "session_id".to_string(),
119 serde_json::Value::String(session_id.clone()),
120 );
121 }
122 if let Some(ref task_id) = req.task_id {
123 metadata.insert(
124 "task_id".to_string(),
125 serde_json::Value::String(task_id.clone()),
126 );
127 }
128 if let Some(ref custom_tags) = req.custom_tags {
129 metadata.insert("custom_tags".to_string(), serde_json::json!(custom_tags));
130 }
131 if let Some(ref chat_history) = req.chat_history {
132 metadata.insert("chat_history".to_string(), serde_json::json!(chat_history));
133 }
134 if let Some(ref info) = req.info {
135 for (k, v) in info {
136 metadata.insert(k.clone(), v.clone());
137 }
138 metadata.insert(
139 "scope".to_string(),
140 serde_json::Value::String(scope.clone()),
141 );
142 }
143
144 let node = MemoryNode {
145 id: id.clone(),
146 memory: content.clone(),
147 metadata: metadata.clone(),
148 embedding: Some(embedding.clone()),
149 };
150 self.graph
151 .add_nodes_batch(&[node], Some(user_name))
152 .await
153 .map_err(MemCubeError::Graph)?;
154
155 if let Some(relations) = req.relations.as_ref() {
156 if !relations.is_empty() {
157 let mut edges = Vec::new();
158 for rel in relations {
159 let mut base_metadata = rel.metadata.clone();
160 base_metadata.insert(
161 "created_at".to_string(),
162 serde_json::Value::String(Utc::now().to_rfc3339()),
163 );
164 match rel.direction {
165 GraphDirection::Outbound => {
166 edges.push(MemoryEdge {
167 id: Uuid::new_v4().to_string(),
168 from: id.clone(),
169 to: rel.memory_id.clone(),
170 relation: rel.relation.clone(),
171 metadata: base_metadata.clone(),
172 });
173 }
174 GraphDirection::Inbound => {
175 edges.push(MemoryEdge {
176 id: Uuid::new_v4().to_string(),
177 from: rel.memory_id.clone(),
178 to: id.clone(),
179 relation: rel.relation.clone(),
180 metadata: base_metadata.clone(),
181 });
182 }
183 GraphDirection::Both => {
184 edges.push(MemoryEdge {
185 id: Uuid::new_v4().to_string(),
186 from: id.clone(),
187 to: rel.memory_id.clone(),
188 relation: rel.relation.clone(),
189 metadata: base_metadata.clone(),
190 });
191 edges.push(MemoryEdge {
192 id: Uuid::new_v4().to_string(),
193 from: rel.memory_id.clone(),
194 to: id.clone(),
195 relation: rel.relation.clone(),
196 metadata: base_metadata.clone(),
197 });
198 }
199 }
200 }
201 if let Err(e) = self.graph.add_edges_batch(&edges, Some(user_name)).await {
202 let _ = self.graph.delete_node(&id, Some(user_name)).await;
204 return Err(MemCubeError::Graph(e));
205 }
206 }
207 }
208
209 let payload = {
210 let mut p = HashMap::new();
211 p.insert(
212 "mem_cube_id".to_string(),
213 serde_json::Value::String(user_name.to_string()),
214 );
215 p.insert(
216 "memory_type".to_string(),
217 serde_json::Value::String("text_mem".to_string()),
218 );
219 p.insert("scope".to_string(), serde_json::Value::String(scope));
220 p
221 };
222 let item = VecStoreItem {
223 id: id.clone(),
224 vector: embedding,
225 payload,
226 };
227 if let Err(e) = self.vec_store.add(&[item], None).await {
228 let _ = self.graph.delete_node(&id, Some(user_name)).await;
230 return Err(MemCubeError::Vec(e));
231 }
232
233 let data = vec![serde_json::json!({ "id": id, "memory": content })];
234 Ok(MemoryResponse {
235 code: 200,
236 message: "Memory added successfully".to_string(),
237 data: Some(data),
238 })
239 }
240
241 async fn search_memories(
242 &self,
243 req: &ApiSearchRequest,
244 ) -> Result<SearchResponse, MemCubeError> {
245 let cube_ids = req.readable_cube_ids();
246 let user_name = cube_ids.first().map(String::as_str).unwrap_or(&req.user_id);
247
248 let query_vector = self.embedder.embed(&req.query).await?;
249 let top_k = req.top_k as usize;
250
251 let mut filter = req.filter.clone().unwrap_or_default();
252 filter.insert(
254 "mem_cube_id".to_string(),
255 serde_json::Value::String(user_name.to_string()),
256 );
257
258 let mut hits = self
259 .vec_store
260 .search(&query_vector, top_k, Some(&filter), None)
261 .await
262 .map_err(MemCubeError::Vec)?;
263 if req.relativity > 0.0 {
264 hits.retain(|h| h.score >= req.relativity);
265 }
266
267 let ids: Vec<String> = hits.iter().map(|h| h.id.clone()).collect();
268 if ids.is_empty() {
269 return Ok(SearchResponse {
270 code: 200,
271 message: "Search completed successfully".to_string(),
272 data: Some(SearchResponseData {
273 text_mem: vec![MemoryBucket {
274 name: Some("all".to_string()),
275 memories: vec![],
276 total_nodes: Some(0),
277 }],
278 pref_mem: vec![],
279 }),
280 });
281 }
282
283 let nodes = self
284 .graph
285 .get_nodes(&ids, false)
286 .await
287 .map_err(MemCubeError::Graph)?;
288
289 let memories: Vec<MemoryItem> = nodes
290 .into_iter()
291 .filter(|n| {
292 n.metadata
293 .get("state")
294 .and_then(|v| v.as_str())
295 .unwrap_or("active")
296 != "tombstone"
297 })
298 .map(|n| {
299 let mut meta = n.metadata.clone();
300 if let Some(score) = hits.iter().find(|h| h.id == n.id).map(|h| h.score) {
301 meta.insert(
302 "relativity".to_string(),
303 serde_json::Value::Number(
304 serde_json::Number::from_f64(score)
305 .unwrap_or(serde_json::Number::from(0)),
306 ),
307 );
308 }
309 MemoryItem {
310 id: n.id,
311 memory: n.memory,
312 metadata: meta,
313 }
314 })
315 .collect();
316
317 let all_bucket = MemoryBucket {
318 name: Some("all".to_string()),
319 total_nodes: Some(memories.len()),
320 memories: memories.clone(),
321 };
322 let mut text_mem = vec![all_bucket];
323 for scope in [
324 MemoryScope::WorkingMemory.as_str(),
325 MemoryScope::UserMemory.as_str(),
326 MemoryScope::LongTermMemory.as_str(),
327 ] {
328 let scoped: Vec<MemoryItem> = memories
329 .iter()
330 .filter(|m| {
331 m.metadata
332 .get("scope")
333 .and_then(|v| v.as_str())
334 .map(|s| s == scope)
335 .unwrap_or(false)
336 })
337 .cloned()
338 .collect();
339 if scoped.is_empty() {
340 continue;
341 }
342 text_mem.push(MemoryBucket {
343 name: Self::bucket_name_for_scope(scope).map(str::to_string),
344 total_nodes: Some(scoped.len()),
345 memories: scoped,
346 });
347 }
348 Ok(SearchResponse {
349 code: 200,
350 message: "Search completed successfully".to_string(),
351 data: Some(SearchResponseData {
352 text_mem,
353 pref_mem: vec![],
354 }),
355 })
356 }
357
358 async fn update_memory(
359 &self,
360 req: &UpdateMemoryRequest,
361 ) -> Result<UpdateMemoryResponse, MemCubeError> {
362 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
363 let id = &req.memory_id;
364
365 let existing = self
366 .graph
367 .get_node(id, false)
368 .await
369 .map_err(MemCubeError::Graph)?;
370 let node =
371 existing.ok_or_else(|| MemCubeError::NotFound(format!("memory not found: {}", id)))?;
372 let node_owner = Self::node_owner(&node.metadata);
373 if node_owner != user_name {
374 return Err(MemCubeError::NotFound(format!("memory not found: {}", id)));
375 }
376 let mut payload_scope = node
377 .metadata
378 .get("scope")
379 .and_then(|v| v.as_str())
380 .unwrap_or(MemoryScope::LongTermMemory.as_str())
381 .to_string();
382
383 let mut fields = HashMap::new();
384 if let Some(ref memory) = req.memory {
385 fields.insert(
386 "memory".to_string(),
387 serde_json::Value::String(memory.clone()),
388 );
389 }
390 let mut scope_changed = false;
391 if let Some(ref meta) = req.metadata {
392 for (k, v) in meta {
393 if k == "scope" {
394 if let Some(raw_scope) = v.as_str() {
395 if let Some(normalized_scope) = Self::normalize_scope(raw_scope) {
396 payload_scope = normalized_scope.to_string();
397 fields.insert(
398 "scope".to_string(),
399 serde_json::Value::String(payload_scope.clone()),
400 );
401 scope_changed = true;
402 } else {
403 return Err(MemCubeError::BadRequest(format!(
404 "invalid scope value: {}",
405 raw_scope
406 )));
407 }
408 } else {
409 return Err(MemCubeError::BadRequest(
410 "scope must be a string".to_string(),
411 ));
412 }
413 } else {
414 fields.insert(k.clone(), v.clone());
415 }
416 }
417 }
418 fields.insert(
419 "updated_at".to_string(),
420 serde_json::Value::String(Utc::now().to_rfc3339()),
421 );
422
423 if fields.len() > 1 || req.memory.is_some() {
424 self.graph
425 .update_node(id, &fields, Some(user_name))
426 .await
427 .map_err(MemCubeError::Graph)?;
428 }
429
430 if req.memory.is_some() || scope_changed {
431 let embedding = if let Some(ref new_memory) = req.memory {
432 self.embedder.embed(new_memory).await?
433 } else {
434 let ids = vec![id.to_string()];
435 let mut existing_items = self
436 .vec_store
437 .get_by_ids(&ids, None)
438 .await
439 .map_err(MemCubeError::Vec)?;
440 if let Some(existing_item) = existing_items.pop() {
441 existing_item.vector
442 } else {
443 self.embedder.embed(&node.memory).await?
444 }
445 };
446 let payload = {
447 let mut p = HashMap::new();
448 p.insert(
449 "mem_cube_id".to_string(),
450 serde_json::Value::String(user_name.to_string()),
451 );
452 p.insert(
453 "memory_type".to_string(),
454 serde_json::Value::String("text_mem".to_string()),
455 );
456 p.insert(
457 "scope".to_string(),
458 serde_json::Value::String(payload_scope),
459 );
460 p
461 };
462 let item = VecStoreItem {
463 id: id.to_string(),
464 vector: embedding,
465 payload,
466 };
467 self.vec_store
468 .upsert(&[item], None)
469 .await
470 .map_err(MemCubeError::Vec)?;
471 }
472
473 let data = vec![serde_json::json!({ "id": id, "updated": true })];
474 Ok(UpdateMemoryResponse {
475 code: 200,
476 message: "Memory updated successfully".to_string(),
477 data: Some(data),
478 })
479 }
480
481 async fn forget_memory(
482 &self,
483 req: &ForgetMemoryRequest,
484 ) -> Result<ForgetMemoryResponse, MemCubeError> {
485 let id = &req.memory_id;
486 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
487
488 let existing = self
489 .graph
490 .get_node(id, false)
491 .await
492 .map_err(MemCubeError::Graph)?;
493 let node =
494 existing.ok_or_else(|| MemCubeError::NotFound(format!("memory not found: {}", id)))?;
495 let node_owner = Self::node_owner(&node.metadata);
496 if node_owner != user_name {
497 return Err(MemCubeError::NotFound(format!("memory not found: {}", id)));
498 }
499
500 if req.soft {
501 let mut fields = HashMap::new();
502 fields.insert(
503 "state".to_string(),
504 serde_json::Value::String("tombstone".to_string()),
505 );
506 fields.insert(
507 "updated_at".to_string(),
508 serde_json::Value::String(Utc::now().to_rfc3339()),
509 );
510 self.graph
511 .update_node(id, &fields, Some(user_name))
512 .await
513 .map_err(MemCubeError::Graph)?;
514 self.vec_store
515 .delete(&[id.to_string()], None)
516 .await
517 .map_err(MemCubeError::Vec)?;
518 } else {
519 self.graph
520 .delete_node(id, Some(user_name))
521 .await
522 .map_err(MemCubeError::Graph)?;
523 self.vec_store
524 .delete(&[id.to_string()], None)
525 .await
526 .map_err(MemCubeError::Vec)?;
527 }
528 let data = vec![serde_json::json!({ "id": id, "forgotten": true })];
529 Ok(ForgetMemoryResponse {
530 code: 200,
531 message: "Memory forgotten successfully".to_string(),
532 data: Some(data),
533 })
534 }
535
536 async fn get_memory(&self, req: &GetMemoryRequest) -> Result<GetMemoryResponse, MemCubeError> {
537 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
538 let node_opt = self
539 .graph
540 .get_node(&req.memory_id, false)
541 .await
542 .map_err(MemCubeError::Graph)?;
543 let node = match node_opt {
544 Some(n) => n,
545 None => {
546 return Ok(GetMemoryResponse {
547 code: 404,
548 message: "Memory not found".to_string(),
549 data: None,
550 });
551 }
552 };
553 let node_user = Self::node_owner(&node.metadata);
554 if node_user != user_name {
555 return Ok(GetMemoryResponse {
556 code: 404,
557 message: "Memory not found".to_string(),
558 data: None,
559 });
560 }
561 let state = node
562 .metadata
563 .get("state")
564 .and_then(|v| v.as_str())
565 .unwrap_or("active");
566 if state == "tombstone" && !req.include_deleted {
567 return Ok(GetMemoryResponse {
568 code: 404,
569 message: "Memory not found".to_string(),
570 data: None,
571 });
572 }
573 let item = MemoryItem {
574 id: node.id,
575 memory: node.memory,
576 metadata: node.metadata,
577 };
578 Ok(GetMemoryResponse {
579 code: 200,
580 message: "Success".to_string(),
581 data: Some(item),
582 })
583 }
584
585 async fn graph_neighbors(
586 &self,
587 req: &GraphNeighborsRequest,
588 ) -> Result<GraphNeighborsResponse, MemCubeError> {
589 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
590 let offset = Self::parse_cursor(req.cursor.as_deref())?;
591 let source = self
592 .graph
593 .get_node(&req.memory_id, false)
594 .await
595 .map_err(MemCubeError::Graph)?
596 .ok_or_else(|| {
597 MemCubeError::NotFound(format!("memory not found: {}", req.memory_id))
598 })?;
599 if Self::node_owner(&source.metadata) != user_name {
600 return Err(MemCubeError::NotFound(format!(
601 "memory not found: {}",
602 req.memory_id
603 )));
604 }
605 let source_state = source
606 .metadata
607 .get("state")
608 .and_then(|v| v.as_str())
609 .unwrap_or("active");
610 if source_state == "tombstone" && !req.include_deleted {
611 return Err(MemCubeError::NotFound(format!(
612 "memory not found: {}",
613 req.memory_id
614 )));
615 }
616
617 let neighbors = self
618 .graph
619 .get_neighbors(
620 &req.memory_id,
621 req.relation.as_deref(),
622 req.direction,
623 usize::MAX,
624 req.include_embedding,
625 Some(user_name),
626 )
627 .await
628 .map_err(|e| {
629 let msg = e.to_string();
630 if msg.contains("not found") || msg.contains("access denied") {
631 MemCubeError::NotFound(format!("memory not found: {}", req.memory_id))
632 } else {
633 MemCubeError::Graph(e)
634 }
635 })?;
636
637 let all_items: Vec<GraphNeighborItem> = neighbors
638 .into_iter()
639 .filter(|n| {
640 if req.include_deleted {
641 return true;
642 }
643 n.node
644 .metadata
645 .get("state")
646 .and_then(|v| v.as_str())
647 .unwrap_or("active")
648 != "tombstone"
649 })
650 .map(|n| GraphNeighborItem {
651 edge: n.edge,
652 memory: MemoryItem {
653 id: n.node.id,
654 memory: n.node.memory,
655 metadata: n.node.metadata,
656 },
657 })
658 .collect();
659
660 let limit = req.limit as usize;
661 let items: Vec<GraphNeighborItem> =
662 all_items.iter().skip(offset).take(limit).cloned().collect();
663 let next_cursor = if offset + items.len() < all_items.len() {
664 Some((offset + items.len()).to_string())
665 } else {
666 None
667 };
668
669 Ok(GraphNeighborsResponse {
670 code: 200,
671 message: "Success".to_string(),
672 data: Some(GraphNeighborsData { items, next_cursor }),
673 })
674 }
675
676 async fn graph_path(&self, req: &GraphPathRequest) -> Result<GraphPathResponse, MemCubeError> {
677 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
678 let path = self
679 .graph
680 .shortest_path(
681 &req.source_memory_id,
682 &req.target_memory_id,
683 req.relation.as_deref(),
684 req.direction,
685 req.max_depth as usize,
686 req.include_deleted,
687 Some(user_name),
688 )
689 .await
690 .map_err(|e| {
691 let msg = e.to_string();
692 if msg.contains("not found") || msg.contains("access denied") {
693 MemCubeError::NotFound(format!(
694 "memory not found: {} or {}",
695 req.source_memory_id, req.target_memory_id
696 ))
697 } else {
698 MemCubeError::Graph(e)
699 }
700 })?
701 .ok_or_else(|| {
702 MemCubeError::NotFound(format!(
703 "path not found: {} -> {}",
704 req.source_memory_id, req.target_memory_id
705 ))
706 })?;
707
708 let nodes = self
709 .graph
710 .get_nodes(&path.node_ids, false)
711 .await
712 .map_err(MemCubeError::Graph)?;
713 let items: Vec<MemoryItem> = nodes
714 .into_iter()
715 .map(|n| MemoryItem {
716 id: n.id,
717 memory: n.memory,
718 metadata: n.metadata,
719 })
720 .collect();
721
722 Ok(GraphPathResponse {
723 code: 200,
724 message: "Success".to_string(),
725 data: Some(GraphPathData {
726 hops: path.edges.len() as u32,
727 nodes: items,
728 edges: path.edges,
729 }),
730 })
731 }
732
733 async fn graph_paths(
734 &self,
735 req: &GraphPathsRequest,
736 ) -> Result<GraphPathsResponse, MemCubeError> {
737 if req.top_k_paths == 0 {
738 return Err(MemCubeError::BadRequest(
739 "top_k_paths must be greater than 0".to_string(),
740 ));
741 }
742 let user_name = req.mem_cube_id.as_deref().unwrap_or(req.user_id.as_str());
743 let paths = self
744 .graph
745 .find_paths(
746 &req.source_memory_id,
747 &req.target_memory_id,
748 req.relation.as_deref(),
749 req.direction,
750 req.max_depth as usize,
751 req.top_k_paths as usize,
752 req.include_deleted,
753 Some(user_name),
754 )
755 .await
756 .map_err(|e| {
757 let msg = e.to_string();
758 if msg.contains("not found") || msg.contains("access denied") {
759 MemCubeError::NotFound(format!(
760 "memory not found: {} or {}",
761 req.source_memory_id, req.target_memory_id
762 ))
763 } else {
764 MemCubeError::Graph(e)
765 }
766 })?;
767 if paths.is_empty() {
768 return Err(MemCubeError::NotFound(format!(
769 "path not found: {} -> {}",
770 req.source_memory_id, req.target_memory_id
771 )));
772 }
773
774 let mut out = Vec::with_capacity(paths.len());
775 for path in paths {
776 let nodes = self
777 .graph
778 .get_nodes(&path.node_ids, false)
779 .await
780 .map_err(MemCubeError::Graph)?;
781 let items: Vec<MemoryItem> = nodes
782 .into_iter()
783 .map(|n| MemoryItem {
784 id: n.id,
785 memory: n.memory,
786 metadata: n.metadata,
787 })
788 .collect();
789 out.push(GraphPathData {
790 hops: path.edges.len() as u32,
791 nodes: items,
792 edges: path.edges,
793 });
794 }
795
796 Ok(GraphPathsResponse {
797 code: 200,
798 message: "Success".to_string(),
799 data: Some(out),
800 })
801 }
802}