1use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::{Value, json};
9
10use crate::access::AccessTracker;
11use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
12use crate::db::score_with_decay;
13use crate::graph::GraphStore;
14use crate::item::{Item, ItemFilters};
15use crate::retry::{RetryConfig, with_retry};
16use crate::{Database, ListScope, StoreScope};
17
18use super::protocol::{CallToolResult, Tool};
19use super::server::ServerContext;
20
21fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
24 tokio::spawn(async move {
25 let result = tokio::task::spawn(fut).await;
26 if let Err(e) = result {
27 tracing::error!("Background task '{}' panicked: {:?}", name, e);
28 }
29 });
30}
31
32pub fn get_tools() -> Vec<Tool> {
34 let store_schema = {
35 #[allow(unused_mut)]
36 let mut props = json!({
37 "content": {
38 "type": "string",
39 "description": "The content to store"
40 },
41 "scope": {
42 "type": "string",
43 "enum": ["project", "global"],
44 "default": "project",
45 "description": "Where to store: 'project' (current project) or 'global' (all projects)"
46 }
47 });
48
49 #[cfg(feature = "bench")]
50 {
51 props.as_object_mut().unwrap().insert(
52 "created_at".to_string(),
53 json!({
54 "type": "number",
55 "description": "Override creation timestamp (Unix seconds). Benchmark builds only."
56 }),
57 );
58 }
59
60 json!({
61 "type": "object",
62 "properties": props,
63 "required": ["content"]
64 })
65 };
66
67 vec![
68 Tool {
69 name: "store".to_string(),
70 description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
71 input_schema: store_schema,
72 },
73 Tool {
74 name: "recall".to_string(),
75 description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
76 input_schema: json!({
77 "type": "object",
78 "properties": {
79 "query": {
80 "type": "string",
81 "description": "What to search for (semantic search)"
82 },
83 "limit": {
84 "type": "number",
85 "default": 5,
86 "description": "Maximum number of results"
87 }
88 },
89 "required": ["query"]
90 }),
91 },
92 Tool {
93 name: "list".to_string(),
94 description: "List stored items.".to_string(),
95 input_schema: json!({
96 "type": "object",
97 "properties": {
98 "limit": {
99 "type": "number",
100 "default": 10,
101 "description": "Maximum number of results"
102 },
103 "scope": {
104 "type": "string",
105 "enum": ["project", "global", "all"],
106 "default": "project",
107 "description": "Which items to list: 'project', 'global', or 'all'"
108 }
109 }
110 }),
111 },
112 Tool {
113 name: "forget".to_string(),
114 description: "Delete a stored item by its ID.".to_string(),
115 input_schema: json!({
116 "type": "object",
117 "properties": {
118 "id": {
119 "type": "string",
120 "description": "The item ID to delete"
121 }
122 },
123 "required": ["id"]
124 }),
125 },
126 ]
127}
128
129#[derive(Debug, Deserialize)]
132pub struct StoreParams {
133 pub content: String,
134 #[serde(default)]
135 pub scope: Option<String>,
136 #[cfg(feature = "bench")]
138 #[serde(default)]
139 pub created_at: Option<i64>,
140}
141
142#[derive(Debug, Deserialize)]
143pub struct RecallParams {
144 pub query: String,
145 #[serde(default)]
146 pub limit: Option<usize>,
147}
148
149#[derive(Debug, Deserialize)]
150pub struct ListParams {
151 #[serde(default)]
152 pub limit: Option<usize>,
153 #[serde(default)]
154 pub scope: Option<String>,
155}
156
157#[derive(Debug, Deserialize)]
158pub struct ForgetParams {
159 pub id: String,
160}
161
/// Feature switches for `recall_pipeline`; every stage is enabled by default.
///
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq` so callers can log, copy and compare configs
/// (it is a plain bundle of flags).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecallConfig {
    /// Call `ensure_node_exists` for every search hit before any graph lookups.
    pub enable_graph_backfill: bool,
    /// Attach graph neighbours of the top hits as `graph_expanded` entries.
    pub enable_graph_expansion: bool,
    /// Attach frequently co-accessed items as `suggested` entries.
    pub enable_co_access: bool,
    /// Re-rank hits with decay and trust scoring (and record their raw similarities).
    pub enable_decay_scoring: bool,
    /// Gate for post-recall background maintenance.
    /// NOTE(review): not read by `recall_pipeline` itself; confirm which callers honour it.
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    /// All stages enabled.
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}
185
186pub struct RecallResult {
188 pub results: Vec<crate::item::SearchResult>,
189 pub graph_expanded: Vec<Value>,
190 pub suggested: Vec<Value>,
191 pub raw_similarities: std::collections::HashMap<String, f32>,
193}
194
195pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
198 let config = RetryConfig::default();
199 let args_for_retry = args.clone();
200
201 let result = with_retry(&config, || {
202 let ctx_ref = ctx;
203 let name_ref = name;
204 let args_clone = args_for_retry.clone();
205
206 async move {
207 let mut db = Database::open_with_embedder(
209 &ctx_ref.db_path,
210 ctx_ref.project_id.clone(),
211 ctx_ref.embedder.clone(),
212 )
213 .await
214 .map_err(|e| sanitize_err("Failed to open database", e))?;
215
216 let tracker = AccessTracker::open(&ctx_ref.access_db_path)
218 .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
219
220 let graph = GraphStore::open(&ctx_ref.access_db_path)
222 .map_err(|e| sanitize_err("Failed to open graph store", e))?;
223
224 let result = match name_ref {
225 "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
226 "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
227 "list" => execute_list(&mut db, args_clone).await,
228 "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
229 _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
230 };
231
232 if result.is_error.unwrap_or(false)
233 && let Some(content) = result.content.first()
234 && is_retryable_error(&content.text)
235 {
236 return Err(content.text.clone());
237 }
238
239 Ok(result)
240 }
241 })
242 .await;
243
244 match result {
245 Ok(call_result) => call_result,
246 Err(e) => {
247 tracing::error!("Operation failed after retries: {}", e);
248 CallToolResult::error("Operation failed after retries")
249 }
250 }
251}
252
/// Heuristically classifies an error message as transient (worth retrying).
///
/// The check is a case-insensitive substring match against a fixed list of patterns
/// (connection/timeout/lock/I/O and open/connect failures). Being a substring match, it can
/// produce false positives for messages that merely contain one of the words.
fn is_retryable_error(error_msg: &str) -> bool {
    // Stored already lowercased so no per-call allocation is needed for the patterns; only the
    // message itself is lowercased.
    const RETRYABLE_PATTERNS: [&str; 8] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lower = error_msg.to_lowercase();
    RETRYABLE_PATTERNS.iter().any(|&pattern| lower.contains(pattern))
}
270
271async fn execute_store(
274 db: &mut Database,
275 _tracker: &AccessTracker,
276 graph: &GraphStore,
277 ctx: &ServerContext,
278 args: Option<Value>,
279) -> CallToolResult {
280 let params: StoreParams = match args {
281 Some(v) => match serde_json::from_value(v) {
282 Ok(p) => p,
283 Err(e) => {
284 tracing::debug!("Parameter validation failed: {}", e);
285 return CallToolResult::error("Invalid parameters");
286 }
287 },
288 None => return CallToolResult::error("Missing parameters"),
289 };
290
291 const MAX_CONTENT_BYTES: usize = 1_000_000;
295 if params.content.len() > MAX_CONTENT_BYTES {
296 return CallToolResult::error(format!(
297 "Content too large: {} bytes (max {} bytes)",
298 params.content.len(),
299 MAX_CONTENT_BYTES
300 ));
301 }
302
303 let scope = params
305 .scope
306 .as_deref()
307 .map(|s| s.parse::<StoreScope>())
308 .transpose();
309
310 let scope = match scope {
311 Ok(s) => s.unwrap_or(StoreScope::Project),
312 Err(e) => return CallToolResult::error(e),
313 };
314
315 let mut item = Item::new(¶ms.content);
317
318 #[cfg(feature = "bench")]
320 if let Some(ts) = params.created_at {
321 if let Some(dt) = chrono::DateTime::from_timestamp(ts, 0) {
322 item = item.with_created_at(dt);
323 }
324 }
325
326 if scope == StoreScope::Project
328 && let Some(project_id) = db.project_id()
329 {
330 item = item.with_project_id(project_id);
331 }
332
333 match db.store_item(item).await {
334 Ok(store_result) => {
335 let new_id = store_result.id.clone();
336
337 let now = chrono::Utc::now().timestamp();
339 let project_id = db.project_id().map(|s| s.to_string());
340 if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
341 tracing::warn!("graph add_node failed: {}", e);
342 }
343
344 if !store_result.potential_conflicts.is_empty()
346 && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
347 {
348 for conflict in &store_result.potential_conflicts {
349 if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
350 {
351 tracing::warn!("enqueue consolidation failed: {}", e);
352 }
353 }
354 }
355
356 let mut result = json!({
357 "success": true,
358 "id": new_id,
359 "message": format!("Stored in {} scope", scope)
360 });
361
362 if !store_result.potential_conflicts.is_empty() {
363 let conflicts: Vec<Value> = store_result
364 .potential_conflicts
365 .iter()
366 .map(|c| {
367 json!({
368 "id": c.id,
369 "content": c.content,
370 "similarity": format!("{:.2}", c.similarity)
371 })
372 })
373 .collect();
374 result["potential_conflicts"] = json!(conflicts);
375 }
376
377 CallToolResult::success(
378 serde_json::to_string_pretty(&result)
379 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
380 )
381 }
382 Err(e) => sanitized_error("Failed to store item", e),
383 }
384}
385
386pub async fn recall_pipeline(
391 db: &mut Database,
392 tracker: &AccessTracker,
393 graph: &GraphStore,
394 query: &str,
395 limit: usize,
396 filters: ItemFilters,
397 config: &RecallConfig,
398) -> std::result::Result<RecallResult, String> {
399 let mut results = db
400 .search_items(query, limit, filters)
401 .await
402 .map_err(|e| format!("Search failed: {}", e))?;
403
404 if results.is_empty() {
405 return Ok(RecallResult {
406 results: Vec::new(),
407 graph_expanded: Vec::new(),
408 suggested: Vec::new(),
409 raw_similarities: std::collections::HashMap::new(),
410 });
411 }
412
413 if config.enable_graph_backfill {
415 for result in &results {
416 if let Err(e) = graph.ensure_node_exists(
417 &result.id,
418 result.project_id.as_deref(),
419 result.created_at.timestamp(),
420 ) {
421 tracing::warn!("ensure_node_exists failed: {}", e);
422 }
423 }
424 }
425
426 let mut raw_similarities: std::collections::HashMap<String, f32> =
428 std::collections::HashMap::new();
429 if config.enable_decay_scoring {
430 let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
431 let decay_data = tracker.get_decay_data(&item_ids).unwrap_or_default();
432 let edge_counts = graph.get_edge_counts(&item_ids).unwrap_or_default();
433 let now = chrono::Utc::now().timestamp();
434
435 for result in &mut results {
436 raw_similarities.insert(result.id.clone(), result.similarity);
437
438 let created_at = result.created_at.timestamp();
439 let (access_count, last_accessed, validation_count) = match decay_data.get(&result.id) {
440 Some(data) => (
441 data.access_count,
442 Some(data.last_accessed_at),
443 data.validation_count,
444 ),
445 None => (0, None, 0),
446 };
447
448 let base_score = score_with_decay(
449 result.similarity,
450 now,
451 created_at,
452 access_count,
453 last_accessed,
454 );
455
456 let edge_count = edge_counts.get(&result.id).copied().unwrap_or(0);
457 let trust_bonus =
458 1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;
459
460 result.similarity = (base_score * trust_bonus).min(1.0);
461 }
462
463 results.sort_by(|a, b| {
464 b.similarity
465 .partial_cmp(&a.similarity)
466 .unwrap_or(std::cmp::Ordering::Equal)
467 });
468 }
469
470 for result in &results {
472 let created_at = result.created_at.timestamp();
473 if let Err(e) = tracker.record_access(&result.id, created_at) {
474 tracing::warn!("record_access failed: {}", e);
475 }
476 }
477
478 let existing_ids: std::collections::HashSet<String> =
480 results.iter().map(|r| r.id.clone()).collect();
481
482 let mut graph_expanded = Vec::new();
483 if config.enable_graph_expansion {
484 let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
485 if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
486 let neighbor_info: Vec<(String, String)> = neighbors
488 .into_iter()
489 .filter(|(id, _, _)| !existing_ids.contains(id))
490 .map(|(id, rel_type, _)| (id, rel_type))
491 .collect();
492
493 let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
494 if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
495 let item_map: std::collections::HashMap<&str, &Item> =
496 items.iter().map(|item| (item.id.as_str(), item)).collect();
497
498 for (neighbor_id, rel_type) in &neighbor_info {
499 if let Some(item) = item_map.get(neighbor_id.as_str()) {
500 let sr = crate::item::SearchResult::from_item(item, 0.05);
501 let mut entry = json!({
502 "id": sr.id,
503 "similarity": "graph",
504 "created": sr.created_at.to_rfc3339(),
505 "graph_expanded": true,
506 "rel_type": rel_type,
507 });
508 let same_project = match (db.project_id(), item.project_id.as_deref()) {
510 (Some(current), Some(item_pid)) => current == item_pid,
511 (_, None) => true,
512 _ => false,
513 };
514 if same_project {
515 entry["content"] = json!(sr.content);
516 } else {
517 entry["cross_project"] = json!(true);
518 }
519 graph_expanded.push(entry);
520 }
521 }
522 }
523 }
524 }
525
526 let mut suggested = Vec::new();
528 if config.enable_co_access {
529 let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
530 if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
531 let co_info: Vec<(String, i64)> = co_accessed
532 .into_iter()
533 .filter(|(id, _)| !existing_ids.contains(id))
534 .collect();
535
536 let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
537 if let Ok(items) = db.get_items_batch(&co_ids).await {
538 let item_map: std::collections::HashMap<&str, &Item> =
539 items.iter().map(|item| (item.id.as_str(), item)).collect();
540
541 for (co_id, co_count) in &co_info {
542 if let Some(item) = item_map.get(co_id.as_str()) {
543 let same_project = match (db.project_id(), item.project_id.as_deref()) {
544 (Some(current), Some(item_pid)) => current == item_pid,
545 (_, None) => true,
546 _ => false,
547 };
548 let mut entry = json!({
549 "id": item.id,
550 "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
551 });
552 if same_project {
553 entry["content"] = json!(truncate(&item.content, 100));
554 } else {
555 entry["cross_project"] = json!(true);
556 }
557 suggested.push(entry);
558 }
559 }
560 }
561 }
562 }
563
564 Ok(RecallResult {
565 results,
566 graph_expanded,
567 suggested,
568 raw_similarities,
569 })
570}
571
572async fn execute_recall(
573 db: &mut Database,
574 tracker: &AccessTracker,
575 graph: &GraphStore,
576 ctx: &ServerContext,
577 args: Option<Value>,
578) -> CallToolResult {
579 let params: RecallParams = match args {
580 Some(v) => match serde_json::from_value(v) {
581 Ok(p) => p,
582 Err(e) => {
583 tracing::debug!("Parameter validation failed: {}", e);
584 return CallToolResult::error("Invalid parameters");
585 }
586 },
587 None => return CallToolResult::error("Missing parameters"),
588 };
589
590 const MAX_QUERY_BYTES: usize = 10_000;
594 if params.query.len() > MAX_QUERY_BYTES {
595 return CallToolResult::error(format!(
596 "Query too large: {} bytes (max {} bytes)",
597 params.query.len(),
598 MAX_QUERY_BYTES
599 ));
600 }
601
602 let limit = params.limit.unwrap_or(5).min(100);
603
604 let filters = ItemFilters::new();
605
606 let config = RecallConfig::default();
607
608 let recall_result =
609 match recall_pipeline(db, tracker, graph, ¶ms.query, limit, filters, &config).await {
610 Ok(r) => r,
611 Err(e) => {
612 tracing::error!("Recall failed: {}", e);
613 return CallToolResult::error("Search failed");
614 }
615 };
616
617 if recall_result.results.is_empty() {
618 return CallToolResult::success("No items found matching your query.");
619 }
620
621 let results = &recall_result.results;
622
623 let all_result_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
625 let neighbors_map = graph
626 .get_neighbors_mapped(&all_result_ids, 0.5)
627 .unwrap_or_default();
628
629 let formatted: Vec<Value> = results
630 .iter()
631 .map(|r| {
632 let mut obj = json!({
633 "id": r.id,
634 "content": r.content,
635 "similarity": format!("{:.2}", r.similarity),
636 "created": r.created_at.to_rfc3339(),
637 });
638
639 if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
641 && (raw_sim - r.similarity).abs() > 0.001
642 {
643 obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
644 }
645
646 if let Some(ref excerpt) = r.relevant_excerpt {
647 obj["relevant_excerpt"] = json!(excerpt);
648 }
649
650 if let Some(ref current_pid) = ctx.project_id
652 && let Some(ref item_pid) = r.project_id
653 && item_pid != current_pid
654 {
655 obj["cross_project"] = json!(true);
656 }
657
658 if let Some(related) = neighbors_map.get(&r.id)
660 && !related.is_empty()
661 {
662 obj["related_ids"] = json!(related);
663 }
664
665 obj
666 })
667 .collect();
668
669 let mut result_json = json!({
670 "count": results.len(),
671 "results": formatted
672 });
673
674 if !recall_result.graph_expanded.is_empty() {
675 result_json["graph_expanded"] = json!(recall_result.graph_expanded);
676 }
677
678 if !recall_result.suggested.is_empty() {
679 result_json["suggested"] = json!(recall_result.suggested);
680 }
681
682 spawn_consolidation(
684 Arc::new(ctx.db_path.clone()),
685 Arc::new(ctx.access_db_path.clone()),
686 ctx.project_id.clone(),
687 ctx.embedder.clone(),
688 ctx.consolidation_semaphore.clone(),
689 );
690
691 let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
693 let access_db_path = ctx.access_db_path.clone();
694 spawn_logged("co_access", async move {
695 if let Ok(g) = GraphStore::open(&access_db_path) {
696 if let Err(e) = g.record_co_access(&result_ids) {
697 tracing::warn!("record_co_access failed: {}", e);
698 }
699 } else {
700 tracing::warn!("co_access: failed to open graph store");
701 }
702 });
703
704 let run_count = ctx
706 .recall_count
707 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
708 if run_count % 10 == 9 {
709 let access_db_path = ctx.access_db_path.clone();
711 spawn_logged("clustering", async move {
712 if let Ok(g) = GraphStore::open(&access_db_path)
713 && let Ok(clusters) = g.detect_clusters()
714 {
715 for (a, b, c) in &clusters {
716 let label = format!("cluster-{}", &a[..8.min(a.len())]);
717 if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
718 tracing::warn!("cluster add_related_edge failed: {}", e);
719 }
720 if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
721 tracing::warn!("cluster add_related_edge failed: {}", e);
722 }
723 if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
724 tracing::warn!("cluster add_related_edge failed: {}", e);
725 }
726 }
727 if !clusters.is_empty() {
728 tracing::info!("Detected {} clusters", clusters.len());
729 }
730 }
731 });
732
733 let access_db_path2 = ctx.access_db_path.clone();
735 spawn_logged("consolidation_cleanup", async move {
736 if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
737 && let Err(e) = q.cleanup_processed()
738 {
739 tracing::warn!("consolidation queue cleanup failed: {}", e);
740 }
741 });
742 }
743
744 CallToolResult::success(
745 serde_json::to_string_pretty(&result_json)
746 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
747 )
748}
749
750async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
751 let params: ListParams =
752 args.and_then(|v| serde_json::from_value(v).ok())
753 .unwrap_or(ListParams {
754 limit: None,
755 scope: None,
756 });
757
758 let limit = params.limit.unwrap_or(10).min(100);
759
760 let filters = ItemFilters::new();
761
762 let scope = params
763 .scope
764 .as_deref()
765 .map(|s| s.parse::<ListScope>())
766 .transpose();
767
768 let scope = match scope {
769 Ok(s) => s.unwrap_or(ListScope::Project),
770 Err(e) => return CallToolResult::error(e),
771 };
772
773 match db.list_items(filters, Some(limit), scope).await {
774 Ok(items) => {
775 if items.is_empty() {
776 CallToolResult::success("No items stored yet.")
777 } else {
778 let formatted: Vec<Value> = items
779 .iter()
780 .map(|item| {
781 let content_preview = truncate(&item.content, 100);
782 let mut obj = json!({
783 "id": item.id,
784 "content": content_preview,
785 "created": item.created_at.to_rfc3339(),
786 });
787
788 if item.is_chunked {
789 obj["chunked"] = json!(true);
790 }
791
792 obj
793 })
794 .collect();
795
796 let result = json!({
797 "count": items.len(),
798 "items": formatted
799 });
800
801 CallToolResult::success(
802 serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
803 format!("{{\"error\": \"serialization failed: {}\"}}", e)
804 }),
805 )
806 }
807 }
808 Err(e) => sanitized_error("Failed to list items", e),
809 }
810}
811
812async fn execute_forget(
813 db: &mut Database,
814 graph: &GraphStore,
815 ctx: &ServerContext,
816 args: Option<Value>,
817) -> CallToolResult {
818 let params: ForgetParams = match args {
819 Some(v) => match serde_json::from_value(v) {
820 Ok(p) => p,
821 Err(e) => {
822 tracing::debug!("Parameter validation failed: {}", e);
823 return CallToolResult::error("Invalid parameters");
824 }
825 },
826 None => return CallToolResult::error("Missing parameters"),
827 };
828
829 if let Some(ref current_pid) = ctx.project_id {
831 match db.get_item(¶ms.id).await {
832 Ok(Some(item)) => {
833 if let Some(ref item_pid) = item.project_id
834 && item_pid != current_pid
835 {
836 return CallToolResult::error(format!(
837 "Cannot delete item {} from a different project",
838 params.id
839 ));
840 }
841 }
842 Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
843 Err(e) => {
844 return sanitized_error("Failed to look up item", e);
845 }
846 }
847 }
848
849 match db.delete_item(¶ms.id).await {
850 Ok(true) => {
851 if let Err(e) = graph.remove_node(¶ms.id) {
853 tracing::warn!("remove_node failed: {}", e);
854 }
855
856 let result = json!({
857 "success": true,
858 "message": format!("Deleted item: {}", params.id)
859 });
860 CallToolResult::success(
861 serde_json::to_string_pretty(&result)
862 .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
863 )
864 }
865 Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
866 Err(e) => sanitized_error("Failed to delete item", e),
867 }
868}
869
870fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
875 tracing::error!("{}: {}", context, err);
876 CallToolResult::error(context.to_string())
877}
878
879fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
881 tracing::error!("{}: {}", context, err);
882 context.to_string()
883}
884
/// Shortens `s` to at most `max_len` characters (Unicode scalar values, not bytes).
///
/// Strings that already fit are returned unchanged. Otherwise:
/// - for `max_len > 3`, the result is the first `max_len - 3` characters followed by `"..."`,
///   which is exactly `max_len` characters in total;
/// - for `max_len <= 3`, there is no room for an ellipsis, so the first `max_len` characters
///   are returned as they are.
fn truncate(s: &str, max_len: usize) -> String {
    // `nth(max_len)` is `None` exactly when `s` has at most `max_len` chars. Unlike
    // `chars().count()`, this stops after `max_len + 1` chars instead of scanning the whole
    // (potentially megabyte-sized) string.
    if s.chars().nth(max_len).is_none() {
        return s.to_string();
    }
    if max_len <= 3 {
        return s.chars().take(max_len).collect();
    }
    // `s` has more than `max_len` chars here, so this index always exists; the fallback is
    // purely defensive.
    let cut = s
        .char_indices()
        .nth(max_len - 3)
        .map_or(s.len(), |(i, _)| i);
    format!("{}...", &s[..cut])
}
900
// Unit tests for `truncate`, plus end-to-end tests that drive `execute_tool` against a
// database created in a fresh temporary directory.
#[cfg(test)]
mod tests {
    use super::*;

    // Boundary behaviour for tiny limits (no room for "...") and exact-fit inputs.
    #[test]
    fn test_truncate_small_max_len() {
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi");
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
    }

    // Truncation counts chars, so multi-byte characters must never be split.
    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
    }

    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::Semaphore;

    // Builds a `ServerContext` whose item database and access/graph database live in a fresh
    // temp dir, bound to a fixed test project id. The returned `TempDir` must be kept alive for
    // the duration of the test (dropping it removes the directory).
    // NOTE(review): constructs a real `crate::Embedder`, which is presumably why the end-to-end
    // tests below are `#[ignore]`d (run them with `cargo test -- --ignored`); confirm.
    async fn setup_test_context() -> (ServerContext, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let access_db_path = tmp.path().join("access.db");

        let embedder = Arc::new(crate::Embedder::new().unwrap());
        let project_id = Some("test-project-00000001".to_string());

        let ctx = ServerContext {
            db_path,
            access_db_path,
            project_id,
            embedder,
            cwd: PathBuf::from("."),
            consolidation_semaphore: Arc::new(Semaphore::new(1)),
            recall_count: std::sync::atomic::AtomicU64::new(0),
            rate_limit: Mutex::new(super::super::server::RateLimitState {
                window_start_ms: 0,
                count: 0,
            }),
        };

        (ctx, tmp)
    }

    // Stored content must be retrievable by a semantically similar query.
    #[tokio::test]
    #[ignore]
    async fn test_store_and_recall_roundtrip() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Rust is a systems programming language" })),
        )
        .await;
        assert!(
            store_result.is_error.is_none(),
            "Store should succeed: {:?}",
            store_result.content
        );

        let recall_result = execute_tool(
            &ctx,
            "recall",
            Some(json!({ "query": "systems programming language" })),
        )
        .await;
        assert!(recall_result.is_error.is_none(), "Recall should succeed");

        let text = &recall_result.content[0].text;
        assert!(
            text.contains("Rust is a systems programming language"),
            "Recall should return stored content, got: {}",
            text
        );
    }

    // Two project-scope stores must both show up in a project-scope listing.
    #[tokio::test]
    #[ignore]
    async fn test_store_and_list() {
        let (ctx, _tmp) = setup_test_context().await;

        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "First item for listing" })),
        )
        .await;
        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Second item for listing" })),
        )
        .await;

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        assert!(list_result.is_error.is_none(), "List should succeed");

        let text = &list_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert_eq!(parsed["count"], 2, "Should list 2 items");
    }

    // Storing identical content twice must report the first copy under `potential_conflicts`.
    #[tokio::test]
    #[ignore]
    async fn test_store_conflict_detection() {
        let (ctx, _tmp) = setup_test_context().await;

        execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;

        let result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "The quick brown fox jumps over the lazy dog" })),
        )
        .await;
        assert!(result.is_error.is_none(), "Store should succeed");

        let text = &result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert!(
            parsed.get("potential_conflicts").is_some(),
            "Should detect conflict for near-duplicate content, got: {}",
            text
        );
    }

    // `forget` on the id returned by `store` must leave the project listing empty.
    #[tokio::test]
    #[ignore]
    async fn test_forget_removes_item() {
        let (ctx, _tmp) = setup_test_context().await;

        let store_result = execute_tool(
            &ctx,
            "store",
            Some(json!({ "content": "Item to be forgotten" })),
        )
        .await;
        let text = &store_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        let item_id = parsed["id"].as_str().unwrap().to_string();

        let forget_result = execute_tool(&ctx, "forget", Some(json!({ "id": item_id }))).await;
        assert!(forget_result.is_error.is_none(), "Forget should succeed");

        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        let text = &list_result.content[0].text;
        assert!(
            text.contains("No items stored yet"),
            "Should have no items after forget, got: {}",
            text
        );
    }

    // Recalling from an empty store is a success with a "No items found" message, not an error.
    #[tokio::test]
    #[ignore]
    async fn test_recall_empty_db() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "recall", Some(json!({ "query": "anything" }))).await;
        assert!(
            result.is_error.is_none(),
            "Recall on empty DB should not error"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("No items found"),
            "Should indicate no items found, got: {}",
            text
        );
    }

    // 1,100,000 bytes exceeds the 1,000,000-byte MAX_CONTENT_BYTES limit in `execute_store`.
    #[tokio::test]
    #[ignore]
    async fn test_store_rejects_oversized_content() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_content = "x".repeat(1_100_000);
        let result = execute_tool(&ctx, "store", Some(json!({ "content": large_content }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized content"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    // 11,000 bytes exceeds the 10,000-byte MAX_QUERY_BYTES limit in `execute_recall`.
    #[tokio::test]
    #[ignore]
    async fn test_recall_rejects_oversized_query() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_query = "x".repeat(11_000);
        let result = execute_tool(&ctx, "recall", Some(json!({ "query": large_query }))).await;
        assert!(
            result.is_error == Some(true),
            "Should reject oversized query"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    // Both absent arguments and an object without the required `content` field must be errors.
    #[tokio::test]
    #[ignore]
    async fn test_store_missing_params() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "store", None).await;
        assert!(result.is_error == Some(true), "Should error with no params");

        let result = execute_tool(&ctx, "store", Some(json!({}))).await;
        assert!(
            result.is_error == Some(true),
            "Should error with missing content"
        );
    }
}