//! `mem7_store/search.rs` — semantic search over stored memories for
//! [`MemoryEngine`].
1use mem7_core::{
2    GraphRelation, MemoryFilter, MemoryItem, SearchOptions, SearchResult, TaskType,
3    sort_by_score_desc,
4};
5use mem7_error::Result;
6use mem7_reranker::RerankDocument;
7use tracing::{debug, instrument, warn};
8use uuid::Uuid;
9
10use crate::constants::*;
11use crate::decay;
12use crate::engine::{MemoryEngine, graph_result_to_relation};
13use crate::payload::payload_to_memory_item;
14use crate::pipeline;
15use crate::rehearsal;
16use crate::require_scope;
17
18impl MemoryEngine {
19    /// Search memories by semantic similarity.
20    ///
21    /// `filters` is an optional JSON object evaluated against `payload.metadata`
22    /// using the filter DSL (simple equality, operators, AND/OR/NOT).
23    ///
24    /// When a reranker is configured and `rerank` is `true`, the engine
25    /// over-fetches candidates by `top_k_multiplier` and then reranks them
26    /// down to `limit`.
27    #[allow(clippy::too_many_arguments)]
28    #[instrument(skip(self, filters))]
29    pub async fn search(
30        &self,
31        query: &str,
32        user_id: Option<&str>,
33        agent_id: Option<&str>,
34        run_id: Option<&str>,
35        limit: usize,
36        filters: Option<&serde_json::Value>,
37        rerank: bool,
38        threshold: Option<f32>,
39        task_type: Option<&str>,
40    ) -> Result<SearchResult> {
41        let opts = SearchOptions {
42            user_id,
43            agent_id,
44            run_id,
45            limit,
46            filters,
47            rerank,
48            threshold,
49            task_type,
50        };
51        self.search_with_options(query, &opts).await
52    }
53
54    /// Search memories using structured options.
55    pub async fn search_with_options(
56        &self,
57        query: &str,
58        opts: &SearchOptions<'_>,
59    ) -> Result<SearchResult> {
60        require_scope("search", opts.user_id, opts.agent_id, opts.run_id)?;
61        let context_cfg = self.config.context.as_ref().filter(|c| c.enabled);
62
63        let classify_future = async {
64            if context_cfg.is_some() && opts.task_type.is_none() {
65                pipeline::classify_query(self.llm.as_ref(), query).await
66            } else {
67                opts.task_type
68                    .map(TaskType::from_str_lossy)
69                    .unwrap_or_default()
70            }
71        };
72
73        let query_owned = vec![query.to_string()];
74        let embed_future = self.embedder.embed(&query_owned);
75
76        let (vecs, task_type) = tokio::join!(embed_future, classify_future);
77        let vecs = vecs?;
78        let query_vec = vecs.into_iter().next().unwrap_or_default();
79
80        debug!(?task_type, "classified query");
81
82        let filter = MemoryFilter {
83            metadata: opts.filters.cloned(),
84            ..MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id)
85        };
86
87        let should_rerank = opts.rerank && self.reranker.is_some();
88        let fetch_limit = if should_rerank {
89            let multiplier = self
90                .config
91                .reranker
92                .as_ref()
93                .map(|r| r.top_k_multiplier)
94                .unwrap_or(DEFAULT_RERANK_MULTIPLIER);
95            opts.limit * multiplier
96        } else {
97            opts.limit
98        };
99
100        let graph_filter = MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
101
102        let vector_future = self
103            .vector_index
104            .search(&query_vec, fetch_limit, Some(&filter));
105
106        let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
107        let graph_future = async {
108            match &self.graph_pipeline {
109                Some(gp) => gp.search(query, &graph_filter, opts.limit, decay_cfg).await,
110                None => Ok(Vec::new()),
111            }
112        };
113
114        let (results, graph_results) = tokio::join!(vector_future, graph_future);
115        let results = results?;
116        let graph_results = graph_results.unwrap_or_else(|e| {
117            warn!(error = %e, "graph search failed");
118            Vec::new()
119        });
120
121        let memories: Vec<MemoryItem> = if let (true, Some(reranker)) =
122            (should_rerank && !results.is_empty(), self.reranker.as_ref())
123        {
124            let docs: Vec<RerankDocument> = results
125                .iter()
126                .filter_map(|r| {
127                    r.payload
128                        .get("text")
129                        .and_then(|v| v.as_str())
130                        .map(|text| RerankDocument {
131                            id: r.id,
132                            text: text.to_string(),
133                            score: r.score,
134                            payload: r.payload.clone(),
135                        })
136                })
137                .collect();
138
139            match reranker.rerank(query, &docs, opts.limit).await {
140                Ok(reranked) => {
141                    debug!(count = reranked.len(), "reranked results");
142                    reranked
143                        .into_iter()
144                        .map(|r| {
145                            let mut item =
146                                payload_to_memory_item(r.id, &r.payload, Some(r.rerank_score));
147                            item.score = Some(r.rerank_score);
148                            item
149                        })
150                        .collect()
151                }
152                Err(e) => {
153                    warn!(error = %e, "reranking failed, using original results");
154                    results
155                        .into_iter()
156                        .take(opts.limit)
157                        .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
158                        .collect()
159                }
160            }
161        } else {
162            results
163                .into_iter()
164                .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
165                .collect()
166        };
167
168        let mut memories = if let Some(decay_cfg) = decay_cfg {
169            let mut decayed: Vec<MemoryItem> = memories
170                .into_iter()
171                .map(|mut item| {
172                    let age = decay::age_from_memory_item(
173                        item.last_accessed_at.as_deref(),
174                        &item.updated_at,
175                        &item.created_at,
176                    );
177                    item.score = item
178                        .score
179                        .map(|s| decay::apply_decay(s, age, item.access_count, decay_cfg));
180                    item
181                })
182                .collect();
183            sort_by_score_desc(&mut decayed, |m| m.score.unwrap_or(0.0));
184            decayed
185        } else {
186            memories
187        };
188
189        if let Some(ctx_cfg) = context_cfg {
190            let tt = task_type.as_str();
191            for item in &mut memories {
192                let mt = item.memory_type.as_deref().unwrap_or("factual");
193                let coeff = ctx_cfg.weight_for(mt, tt) as f32;
194                item.score = item.score.map(|s| s * coeff);
195            }
196            sort_by_score_desc(&mut memories, |m| m.score.unwrap_or(0.0));
197            debug!(task_type = tt, "applied context-aware scoring");
198        }
199
200        if let Some(thresh) = opts.threshold {
201            memories.retain(|m| m.score.unwrap_or(0.0) >= thresh);
202        }
203
204        let mut relations: Vec<GraphRelation> =
205            graph_results.iter().map(graph_result_to_relation).collect();
206
207        if let Some(ctx_cfg) = context_cfg {
208            let tt = task_type.as_str();
209            let coeff = ctx_cfg.weight_for("factual", tt) as f32;
210            for rel in &mut relations {
211                rel.score = rel.score.map(|s| s * coeff);
212            }
213        }
214
215        if self.config.decay.as_ref().is_some_and(|d| d.enabled)
216            && (!memories.is_empty() || !relations.is_empty())
217        {
218            let mem_ids: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
219            let rel_triples: Vec<(String, String, String)> = relations
220                .iter()
221                .map(|r| {
222                    (
223                        r.source.clone(),
224                        r.relationship.clone(),
225                        r.destination.clone(),
226                    )
227                })
228                .collect();
229            let rehearsal_filter =
230                MemoryFilter::from_session(opts.user_id, opts.agent_id, opts.run_id);
231
232            rehearsal::spawn_rehearsal(
233                self.vector_index.clone(),
234                self.graph_pipeline.as_ref().map(|gp| gp.store().clone()),
235                mem_ids,
236                rel_triples,
237                rehearsal_filter,
238            );
239        }
240
241        Ok(SearchResult {
242            memories,
243            relations,
244        })
245    }
246}