//! `mem7_store` — semantic search over stored memories (`search.rs`).
1use mem7_core::{
2    GraphRelation, MemoryFilter, MemoryItem, SearchOptions, SearchResult, TaskType,
3    sort_by_score_desc,
4};
5use mem7_error::Result;
6use mem7_reranker::RerankDocument;
7use tracing::{debug, instrument, warn};
8use uuid::Uuid;
9
10use crate::constants::*;
11use crate::decay;
12use crate::engine::{MemoryEngine, graph_result_to_relation};
13use crate::payload::payload_to_memory_item;
14use crate::pipeline;
15use crate::rehearsal;
16
17impl MemoryEngine {
18    /// Search memories by semantic similarity.
19    ///
20    /// `filters` is an optional JSON object evaluated against `payload.metadata`
21    /// using the filter DSL (simple equality, operators, AND/OR/NOT).
22    ///
23    /// When a reranker is configured and `rerank` is `true`, the engine
24    /// over-fetches candidates by `top_k_multiplier` and then reranks them
25    /// down to `limit`.
26    #[allow(clippy::too_many_arguments)]
27    #[instrument(skip(self, filters))]
28    pub async fn search(
29        &self,
30        query: &str,
31        user_id: Option<&str>,
32        agent_id: Option<&str>,
33        run_id: Option<&str>,
34        limit: usize,
35        filters: Option<&serde_json::Value>,
36        rerank: bool,
37        threshold: Option<f32>,
38        task_type: Option<&str>,
39    ) -> Result<SearchResult> {
40        let opts = SearchOptions {
41            user_id,
42            agent_id,
43            run_id,
44            limit,
45            filters,
46            rerank,
47            threshold,
48            task_type,
49        };
50        self.search_with_options(query, &opts).await
51    }
52
53    /// Search memories using structured options.
54    pub async fn search_with_options(
55        &self,
56        query: &str,
57        opts: &SearchOptions<'_>,
58    ) -> Result<SearchResult> {
59        let context_cfg = self.config.context.as_ref().filter(|c| c.enabled);
60
61        let classify_future = async {
62            if context_cfg.is_some() && opts.task_type.is_none() {
63                pipeline::classify_query(self.llm.as_ref(), query).await
64            } else {
65                opts.task_type
66                    .map(TaskType::from_str_lossy)
67                    .unwrap_or_default()
68            }
69        };
70
71        let query_owned = vec![query.to_string()];
72        let embed_future = self.embedder.embed(&query_owned);
73
74        let (vecs, task_type) = tokio::join!(embed_future, classify_future);
75        let vecs = vecs?;
76        let query_vec = vecs.into_iter().next().unwrap_or_default();
77
78        debug!(?task_type, "classified query");
79
80        let filter = MemoryFilter {
81            metadata: opts.filters.cloned(),
82            ..MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id)
83        };
84
85        let should_rerank = opts.rerank && self.reranker.is_some();
86        let fetch_limit = if should_rerank {
87            let multiplier = self
88                .config
89                .reranker
90                .as_ref()
91                .map(|r| r.top_k_multiplier)
92                .unwrap_or(DEFAULT_RERANK_MULTIPLIER);
93            opts.limit * multiplier
94        } else {
95            opts.limit
96        };
97
98        let graph_filter = MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
99
100        let vector_future = self
101            .vector_index
102            .search(&query_vec, fetch_limit, Some(&filter));
103
104        let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
105        let graph_future = async {
106            match &self.graph_pipeline {
107                Some(gp) => gp.search(query, &graph_filter, opts.limit, decay_cfg).await,
108                None => Ok(Vec::new()),
109            }
110        };
111
112        let (results, graph_results) = tokio::join!(vector_future, graph_future);
113        let results = results?;
114        let graph_results = graph_results.unwrap_or_else(|e| {
115            warn!(error = %e, "graph search failed");
116            Vec::new()
117        });
118
119        let memories: Vec<MemoryItem> = if let (true, Some(reranker)) =
120            (should_rerank && !results.is_empty(), self.reranker.as_ref())
121        {
122            let docs: Vec<RerankDocument> = results
123                .iter()
124                .filter_map(|r| {
125                    r.payload
126                        .get("text")
127                        .and_then(|v| v.as_str())
128                        .map(|text| RerankDocument {
129                            id: r.id,
130                            text: text.to_string(),
131                            score: r.score,
132                            payload: r.payload.clone(),
133                        })
134                })
135                .collect();
136
137            match reranker.rerank(query, &docs, opts.limit).await {
138                Ok(reranked) => {
139                    debug!(count = reranked.len(), "reranked results");
140                    reranked
141                        .into_iter()
142                        .map(|r| {
143                            let mut item =
144                                payload_to_memory_item(r.id, &r.payload, Some(r.rerank_score));
145                            item.score = Some(r.rerank_score);
146                            item
147                        })
148                        .collect()
149                }
150                Err(e) => {
151                    warn!(error = %e, "reranking failed, using original results");
152                    results
153                        .into_iter()
154                        .take(opts.limit)
155                        .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
156                        .collect()
157                }
158            }
159        } else {
160            results
161                .into_iter()
162                .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
163                .collect()
164        };
165
166        let mut memories = if let Some(decay_cfg) = decay_cfg {
167            let mut decayed: Vec<MemoryItem> = memories
168                .into_iter()
169                .map(|mut item| {
170                    let age = decay::age_from_memory_item(
171                        item.last_accessed_at.as_deref(),
172                        &item.updated_at,
173                        &item.created_at,
174                    );
175                    item.score = item
176                        .score
177                        .map(|s| decay::apply_decay(s, age, item.access_count, decay_cfg));
178                    item
179                })
180                .collect();
181            sort_by_score_desc(&mut decayed, |m| m.score.unwrap_or(0.0));
182            decayed
183        } else {
184            memories
185        };
186
187        if let Some(ctx_cfg) = context_cfg {
188            let tt = task_type.as_str();
189            for item in &mut memories {
190                let mt = item.memory_type.as_deref().unwrap_or("factual");
191                let coeff = ctx_cfg.weight_for(mt, tt) as f32;
192                item.score = item.score.map(|s| s * coeff);
193            }
194            sort_by_score_desc(&mut memories, |m| m.score.unwrap_or(0.0));
195            debug!(task_type = tt, "applied context-aware scoring");
196        }
197
198        if let Some(thresh) = opts.threshold {
199            memories.retain(|m| m.score.unwrap_or(0.0) >= thresh);
200        }
201
202        let mut relations: Vec<GraphRelation> =
203            graph_results.iter().map(graph_result_to_relation).collect();
204
205        if let Some(ctx_cfg) = context_cfg {
206            let tt = task_type.as_str();
207            let coeff = ctx_cfg.weight_for("factual", tt) as f32;
208            for rel in &mut relations {
209                rel.score = rel.score.map(|s| s * coeff);
210            }
211        }
212
213        if self.config.decay.as_ref().is_some_and(|d| d.enabled)
214            && (!memories.is_empty() || !relations.is_empty())
215        {
216            let mem_ids: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
217            let rel_triples: Vec<(String, String, String)> = relations
218                .iter()
219                .map(|r| {
220                    (
221                        r.source.clone(),
222                        r.relationship.clone(),
223                        r.destination.clone(),
224                    )
225                })
226                .collect();
227            let rehearsal_filter =
228                MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
229
230            rehearsal::spawn_rehearsal(
231                self.vector_index.clone(),
232                self.graph_pipeline.as_ref().map(|gp| gp.store().clone()),
233                mem_ids,
234                rel_triples,
235                rehearsal_filter,
236            );
237        }
238
239        Ok(SearchResult {
240            memories,
241            relations,
242        })
243    }
244}