1use mem7_core::{
2 GraphRelation, MemoryFilter, MemoryItem, SearchOptions, SearchResult, TaskType,
3 sort_by_score_desc,
4};
5use mem7_error::Result;
6use mem7_reranker::RerankDocument;
7use tracing::{debug, instrument, warn};
8use uuid::Uuid;
9
10use crate::constants::*;
11use crate::decay;
12use crate::engine::{MemoryEngine, graph_result_to_relation};
13use crate::payload::payload_to_memory_item;
14use crate::pipeline;
15use crate::rehearsal;
16
17impl MemoryEngine {
18 #[allow(clippy::too_many_arguments)]
27 #[instrument(skip(self, filters))]
28 pub async fn search(
29 &self,
30 query: &str,
31 user_id: Option<&str>,
32 agent_id: Option<&str>,
33 run_id: Option<&str>,
34 limit: usize,
35 filters: Option<&serde_json::Value>,
36 rerank: bool,
37 threshold: Option<f32>,
38 task_type: Option<&str>,
39 ) -> Result<SearchResult> {
40 let opts = SearchOptions {
41 user_id,
42 agent_id,
43 run_id,
44 limit,
45 filters,
46 rerank,
47 threshold,
48 task_type,
49 };
50 self.search_with_options(query, &opts).await
51 }
52
53 pub async fn search_with_options(
55 &self,
56 query: &str,
57 opts: &SearchOptions<'_>,
58 ) -> Result<SearchResult> {
59 let context_cfg = self.config.context.as_ref().filter(|c| c.enabled);
60
61 let classify_future = async {
62 if context_cfg.is_some() && opts.task_type.is_none() {
63 pipeline::classify_query(self.llm.as_ref(), query).await
64 } else {
65 opts.task_type
66 .map(TaskType::from_str_lossy)
67 .unwrap_or_default()
68 }
69 };
70
71 let query_owned = vec![query.to_string()];
72 let embed_future = self.embedder.embed(&query_owned);
73
74 let (vecs, task_type) = tokio::join!(embed_future, classify_future);
75 let vecs = vecs?;
76 let query_vec = vecs.into_iter().next().unwrap_or_default();
77
78 debug!(?task_type, "classified query");
79
80 let filter = MemoryFilter {
81 metadata: opts.filters.cloned(),
82 ..MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id)
83 };
84
85 let should_rerank = opts.rerank && self.reranker.is_some();
86 let fetch_limit = if should_rerank {
87 let multiplier = self
88 .config
89 .reranker
90 .as_ref()
91 .map(|r| r.top_k_multiplier)
92 .unwrap_or(DEFAULT_RERANK_MULTIPLIER);
93 opts.limit * multiplier
94 } else {
95 opts.limit
96 };
97
98 let graph_filter = MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
99
100 let vector_future = self
101 .vector_index
102 .search(&query_vec, fetch_limit, Some(&filter));
103
104 let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
105 let graph_future = async {
106 match &self.graph_pipeline {
107 Some(gp) => gp.search(query, &graph_filter, opts.limit, decay_cfg).await,
108 None => Ok(Vec::new()),
109 }
110 };
111
112 let (results, graph_results) = tokio::join!(vector_future, graph_future);
113 let results = results?;
114 let graph_results = graph_results.unwrap_or_else(|e| {
115 warn!(error = %e, "graph search failed");
116 Vec::new()
117 });
118
119 let memories: Vec<MemoryItem> = if let (true, Some(reranker)) =
120 (should_rerank && !results.is_empty(), self.reranker.as_ref())
121 {
122 let docs: Vec<RerankDocument> = results
123 .iter()
124 .filter_map(|r| {
125 r.payload
126 .get("text")
127 .and_then(|v| v.as_str())
128 .map(|text| RerankDocument {
129 id: r.id,
130 text: text.to_string(),
131 score: r.score,
132 payload: r.payload.clone(),
133 })
134 })
135 .collect();
136
137 match reranker.rerank(query, &docs, opts.limit).await {
138 Ok(reranked) => {
139 debug!(count = reranked.len(), "reranked results");
140 reranked
141 .into_iter()
142 .map(|r| {
143 let mut item =
144 payload_to_memory_item(r.id, &r.payload, Some(r.rerank_score));
145 item.score = Some(r.rerank_score);
146 item
147 })
148 .collect()
149 }
150 Err(e) => {
151 warn!(error = %e, "reranking failed, using original results");
152 results
153 .into_iter()
154 .take(opts.limit)
155 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
156 .collect()
157 }
158 }
159 } else {
160 results
161 .into_iter()
162 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
163 .collect()
164 };
165
166 let mut memories = if let Some(decay_cfg) = decay_cfg {
167 let mut decayed: Vec<MemoryItem> = memories
168 .into_iter()
169 .map(|mut item| {
170 let age = decay::age_from_memory_item(
171 item.last_accessed_at.as_deref(),
172 &item.updated_at,
173 &item.created_at,
174 );
175 item.score = item
176 .score
177 .map(|s| decay::apply_decay(s, age, item.access_count, decay_cfg));
178 item
179 })
180 .collect();
181 sort_by_score_desc(&mut decayed, |m| m.score.unwrap_or(0.0));
182 decayed
183 } else {
184 memories
185 };
186
187 if let Some(ctx_cfg) = context_cfg {
188 let tt = task_type.as_str();
189 for item in &mut memories {
190 let mt = item.memory_type.as_deref().unwrap_or("factual");
191 let coeff = ctx_cfg.weight_for(mt, tt) as f32;
192 item.score = item.score.map(|s| s * coeff);
193 }
194 sort_by_score_desc(&mut memories, |m| m.score.unwrap_or(0.0));
195 debug!(task_type = tt, "applied context-aware scoring");
196 }
197
198 if let Some(thresh) = opts.threshold {
199 memories.retain(|m| m.score.unwrap_or(0.0) >= thresh);
200 }
201
202 let mut relations: Vec<GraphRelation> =
203 graph_results.iter().map(graph_result_to_relation).collect();
204
205 if let Some(ctx_cfg) = context_cfg {
206 let tt = task_type.as_str();
207 let coeff = ctx_cfg.weight_for("factual", tt) as f32;
208 for rel in &mut relations {
209 rel.score = rel.score.map(|s| s * coeff);
210 }
211 }
212
213 if self.config.decay.as_ref().is_some_and(|d| d.enabled)
214 && (!memories.is_empty() || !relations.is_empty())
215 {
216 let mem_ids: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
217 let rel_triples: Vec<(String, String, String)> = relations
218 .iter()
219 .map(|r| {
220 (
221 r.source.clone(),
222 r.relationship.clone(),
223 r.destination.clone(),
224 )
225 })
226 .collect();
227 let rehearsal_filter =
228 MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
229
230 rehearsal::spawn_rehearsal(
231 self.vector_index.clone(),
232 self.graph_pipeline.as_ref().map(|gp| gp.store().clone()),
233 mem_ids,
234 rel_triples,
235 rehearsal_filter,
236 );
237 }
238
239 Ok(SearchResult {
240 memories,
241 relations,
242 })
243 }
244}