1use mem7_core::{
2 GraphRelation, MemoryFilter, MemoryItem, SearchOptions, SearchResult, TaskType,
3 sort_by_score_desc,
4};
5use mem7_error::Result;
6use mem7_reranker::RerankDocument;
7use tracing::{debug, instrument, warn};
8use uuid::Uuid;
9
10use crate::constants::*;
11use crate::decay;
12use crate::engine::{MemoryEngine, graph_result_to_relation};
13use crate::payload::payload_to_memory_item;
14use crate::pipeline;
15use crate::rehearsal;
16use crate::require_scope;
17
18impl MemoryEngine {
19 #[allow(clippy::too_many_arguments)]
28 #[instrument(skip(self, filters))]
29 pub async fn search(
30 &self,
31 query: &str,
32 user_id: Option<&str>,
33 agent_id: Option<&str>,
34 run_id: Option<&str>,
35 limit: usize,
36 filters: Option<&serde_json::Value>,
37 rerank: bool,
38 threshold: Option<f32>,
39 task_type: Option<&str>,
40 ) -> Result<SearchResult> {
41 let opts = SearchOptions {
42 user_id,
43 agent_id,
44 run_id,
45 limit,
46 filters,
47 rerank,
48 threshold,
49 task_type,
50 };
51 self.search_with_options(query, &opts).await
52 }
53
54 pub async fn search_with_options(
56 &self,
57 query: &str,
58 opts: &SearchOptions<'_>,
59 ) -> Result<SearchResult> {
60 require_scope("search", opts.user_id, opts.agent_id, opts.run_id)?;
61 let context_cfg = self.config.context.as_ref().filter(|c| c.enabled);
62
63 let classify_future = async {
64 if context_cfg.is_some() && opts.task_type.is_none() {
65 pipeline::classify_query(self.llm.as_ref(), query).await
66 } else {
67 opts.task_type
68 .map(TaskType::from_str_lossy)
69 .unwrap_or_default()
70 }
71 };
72
73 let query_owned = vec![query.to_string()];
74 let embed_future = self.embedder.embed(&query_owned);
75
76 let (vecs, task_type) = tokio::join!(embed_future, classify_future);
77 let vecs = vecs?;
78 let query_vec = vecs.into_iter().next().unwrap_or_default();
79
80 debug!(?task_type, "classified query");
81
82 let filter = MemoryFilter {
83 metadata: opts.filters.cloned(),
84 ..MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id)
85 };
86
87 let should_rerank = opts.rerank && self.reranker.is_some();
88 let fetch_limit = if should_rerank {
89 let multiplier = self
90 .config
91 .reranker
92 .as_ref()
93 .map(|r| r.top_k_multiplier)
94 .unwrap_or(DEFAULT_RERANK_MULTIPLIER);
95 opts.limit * multiplier
96 } else {
97 opts.limit
98 };
99
100 let graph_filter = MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
101
102 let vector_future = self
103 .vector_index
104 .search(&query_vec, fetch_limit, Some(&filter));
105
106 let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
107 let graph_future = async {
108 match &self.graph_pipeline {
109 Some(gp) => gp.search(query, &graph_filter, opts.limit, decay_cfg).await,
110 None => Ok(Vec::new()),
111 }
112 };
113
114 let (results, graph_results) = tokio::join!(vector_future, graph_future);
115 let results = results?;
116 let graph_results = graph_results.unwrap_or_else(|e| {
117 warn!(error = %e, "graph search failed");
118 Vec::new()
119 });
120
121 let memories: Vec<MemoryItem> = if let (true, Some(reranker)) =
122 (should_rerank && !results.is_empty(), self.reranker.as_ref())
123 {
124 let docs: Vec<RerankDocument> = results
125 .iter()
126 .filter_map(|r| {
127 r.payload
128 .get("text")
129 .and_then(|v| v.as_str())
130 .map(|text| RerankDocument {
131 id: r.id,
132 text: text.to_string(),
133 score: r.score,
134 payload: r.payload.clone(),
135 })
136 })
137 .collect();
138
139 match reranker.rerank(query, &docs, opts.limit).await {
140 Ok(reranked) => {
141 debug!(count = reranked.len(), "reranked results");
142 reranked
143 .into_iter()
144 .map(|r| {
145 let mut item =
146 payload_to_memory_item(r.id, &r.payload, Some(r.rerank_score));
147 item.score = Some(r.rerank_score);
148 item
149 })
150 .collect()
151 }
152 Err(e) => {
153 warn!(error = %e, "reranking failed, using original results");
154 results
155 .into_iter()
156 .take(opts.limit)
157 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
158 .collect()
159 }
160 }
161 } else {
162 results
163 .into_iter()
164 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
165 .collect()
166 };
167
168 let mut memories = if let Some(decay_cfg) = decay_cfg {
169 let mut decayed: Vec<MemoryItem> = memories
170 .into_iter()
171 .map(|mut item| {
172 let age = decay::age_from_memory_item(
173 item.last_accessed_at.as_deref(),
174 &item.updated_at,
175 &item.created_at,
176 );
177 item.score = item
178 .score
179 .map(|s| decay::apply_decay(s, age, item.access_count, decay_cfg));
180 item
181 })
182 .collect();
183 sort_by_score_desc(&mut decayed, |m| m.score.unwrap_or(0.0));
184 decayed
185 } else {
186 memories
187 };
188
189 if let Some(ctx_cfg) = context_cfg {
190 let tt = task_type.as_str();
191 for item in &mut memories {
192 let mt = item.memory_type.as_deref().unwrap_or("factual");
193 let coeff = ctx_cfg.weight_for(mt, tt) as f32;
194 item.score = item.score.map(|s| s * coeff);
195 }
196 sort_by_score_desc(&mut memories, |m| m.score.unwrap_or(0.0));
197 debug!(task_type = tt, "applied context-aware scoring");
198 }
199
200 if let Some(thresh) = opts.threshold {
201 memories.retain(|m| m.score.unwrap_or(0.0) >= thresh);
202 }
203
204 let mut relations: Vec<GraphRelation> =
205 graph_results.iter().map(graph_result_to_relation).collect();
206
207 if let Some(ctx_cfg) = context_cfg {
208 let tt = task_type.as_str();
209 let coeff = ctx_cfg.weight_for("factual", tt) as f32;
210 for rel in &mut relations {
211 rel.score = rel.score.map(|s| s * coeff);
212 }
213 }
214
215 if self.config.decay.as_ref().is_some_and(|d| d.enabled)
216 && (!memories.is_empty() || !relations.is_empty())
217 {
218 let mem_ids: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
219 let rel_triples: Vec<(String, String, String)> = relations
220 .iter()
221 .map(|r| {
222 (
223 r.source.clone(),
224 r.relationship.clone(),
225 r.destination.clone(),
226 )
227 })
228 .collect();
229 let rehearsal_filter =
230 MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
231
232 rehearsal::spawn_rehearsal(
233 self.vector_index.clone(),
234 self.graph_pipeline.as_ref().map(|gp| gp.store().clone()),
235 mem_ids,
236 rel_triples,
237 rehearsal_filter,
238 );
239 }
240
241 Ok(SearchResult {
242 memories,
243 relations,
244 })
245 }
246}