Skip to main content

kiromi_ai_memory/
query.rs

// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Public `Query` for `Memory::search`.
//!
//! Three modes:
//! - `Query::semantic(text)` — top-K cosine over vector indices.
//! - `Query::text(text)` — Tantivy parsed query over lexical indices.
//! - `Query::hybrid(text)` — RRF fusion of the two; alpha-weighted.
8
9use crate::memory::{MemoryId, MemoryRef};
10use crate::partition::PartitionPath;
11
/// Search mode: which kind of index answers the query.
///
/// Marked `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum QueryMode {
    /// Semantic only.
    Semantic,
    /// Lexical only.
    Text,
    /// Reciprocal-rank fusion of both. `alpha` weights the semantic side; the
    /// text side gets `1 - alpha`. Default `0.6`.
    Hybrid {
        /// Semantic weight in `[0, 1]`.
        alpha: f32,
    },
}
27
28/// A search query.
29#[derive(Debug, Clone)]
30#[non_exhaustive]
31pub struct Query {
32    pub(crate) text: String,
33    pub(crate) mode: QueryMode,
34    pub(crate) within: Option<PartitionPath>,
35    pub(crate) precomputed_embedding: Option<Vec<f32>>,
36    /// Plan 10: opt-in hierarchical descent.
37    pub(crate) hierarchical: bool,
38    /// Plan 10: minimum cosine score (after dedup) for a child node to be
39    /// descended into during hierarchical search. `None` means no pruning.
40    pub(crate) prune_threshold: Option<f32>,
41    /// Plan 10: at each level, descend into top `k * descend_factor` children.
42    pub(crate) descend_factor: u32,
43}
44
45impl Query {
46    /// Pure semantic search.
47    #[must_use]
48    pub fn semantic(text: impl Into<String>) -> Self {
49        Query {
50            text: text.into(),
51            mode: QueryMode::Semantic,
52            within: None,
53            precomputed_embedding: None,
54            hierarchical: false,
55            prune_threshold: None,
56            descend_factor: 4,
57        }
58    }
59
60    /// Pure lexical search.
61    #[must_use]
62    pub fn text(text: impl Into<String>) -> Self {
63        Query {
64            text: text.into(),
65            mode: QueryMode::Text,
66            within: None,
67            precomputed_embedding: None,
68            hierarchical: false,
69            prune_threshold: None,
70            descend_factor: 4,
71        }
72    }
73
74    /// Hybrid (RRF) search; default `alpha = 0.6`.
75    #[must_use]
76    pub fn hybrid(text: impl Into<String>) -> Self {
77        Query {
78            text: text.into(),
79            mode: QueryMode::Hybrid { alpha: 0.6 },
80            within: None,
81            precomputed_embedding: None,
82            hierarchical: false,
83            prune_threshold: None,
84            descend_factor: 4,
85        }
86    }
87
88    /// Plan 10: opt into hierarchical descent. The search starts at the
89    /// configured scope (or the tenant root if unset) and descends through
90    /// internal nodes, scoring child summaries at each level and pruning
91    /// branches whose top scores fall below [`Query::prune_threshold`].
92    /// Leaves are searched only after the parent's index has identified
93    /// them as relevant. Defaults to flat (whole-tree) search when off.
94    #[must_use]
95    pub fn hierarchical(mut self) -> Self {
96        self.hierarchical = true;
97        self
98    }
99
100    /// Plan 10: prune children whose mid-level summary score is below `t`.
101    /// Default is no pruning — every child whose score made the top-K of the
102    /// parent's index is descended into. Clamped to `[-1, 1]` (cosine range).
103    #[must_use]
104    pub fn prune_threshold(mut self, t: f32) -> Self {
105        self.prune_threshold = Some(t.clamp(-1.0, 1.0));
106        self
107    }
108
109    /// Plan 10: at each internal level, descend into the top
110    /// `k * descend_factor` children. Higher = wider beam = slower but more
111    /// recall; lower = narrower beam = faster but more risk of pruning the
112    /// right answer. Default 4. Clamped to `>= 1`.
113    #[must_use]
114    pub fn descend_factor(mut self, n: u32) -> Self {
115        self.descend_factor = n.max(1);
116        self
117    }
118
119    /// Whether this query is in hierarchical mode.
120    #[must_use]
121    pub fn is_hierarchical(&self) -> bool {
122        self.hierarchical
123    }
124
125    /// Override hybrid alpha. No effect on non-hybrid queries.
126    #[must_use]
127    pub fn alpha(mut self, alpha: f32) -> Self {
128        if let QueryMode::Hybrid { alpha: a } = &mut self.mode {
129            *a = alpha.clamp(0.0, 1.0);
130        }
131        self
132    }
133
134    /// Restrict the search to a partition (and its descendants).
135    #[must_use]
136    pub fn within(mut self, path: PartitionPath) -> Self {
137        self.within = Some(path);
138        self
139    }
140
141    /// Caller hands the engine a pre-computed query vector. Used by
142    /// caller-owned-models pathways (Apple Foundation Models, OpenAI proxies,
143    /// Swift FFI consumers) where the model runs outside the library on the
144    /// query side.
145    ///
146    /// When set, the engine bypasses its `Embedder` (if any) for the query
147    /// step. The vector's length must match `schema_meta.embedder_dims`;
148    /// mismatch surfaces from the underlying vector index. Has no effect on
149    /// `QueryMode::Text` (lexical search ignores the vector).
150    ///
151    /// See spec § 12 (caller-owned models) and § 12.13.
152    #[must_use]
153    pub fn with_embedding(mut self, vector: Vec<f32>) -> Self {
154        self.precomputed_embedding = Some(vector);
155        self
156    }
157
158    /// Borrow the precomputed query vector, if any. `None` means the engine
159    /// must invoke its configured `Embedder` to derive the query vector.
160    #[must_use]
161    pub fn precomputed_embedding(&self) -> Option<&[f32]> {
162        self.precomputed_embedding.as_deref()
163    }
164
165    /// Borrow the query text.
166    #[must_use]
167    pub fn text_str(&self) -> &str {
168        &self.text
169    }
170
171    /// Borrow the mode.
172    #[must_use]
173    pub fn mode(&self) -> &QueryMode {
174        &self.mode
175    }
176
177    /// Borrow the partition restriction (if any).
178    #[must_use]
179    pub fn scope(&self) -> Option<&PartitionPath> {
180        self.within.as_ref()
181    }
182}
183
184/// One hit in a search result. Score is mode-specific:
185/// - semantic: cosine similarity in `[-1, 1]` (mock embedder is unit-norm so usually `[0, 1]`).
186/// - text: Tantivy BM25-ish raw score.
187/// - hybrid: RRF score (sum of `alpha / (60 + rank)`).
188#[derive(Debug, Clone, PartialEq, serde::Serialize)]
189#[non_exhaustive]
190pub struct SearchHit {
191    /// Memory.
192    pub r#ref: MemoryRef,
193    /// Score (higher = better).
194    pub score: f32,
195}
196
197impl SearchHit {
198    /// Construct.
199    #[must_use]
200    pub fn new(id: MemoryId, partition: PartitionPath, score: f32) -> Self {
201        SearchHit {
202            r#ref: MemoryRef { id, partition },
203            score,
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the hybrid weight, failing the test for any other mode.
    fn hybrid_alpha(q: &Query) -> f32 {
        match q.mode() {
            QueryMode::Hybrid { alpha } => *alpha,
            other => panic!("expected hybrid, got {:?}", other),
        }
    }

    #[test]
    fn hybrid_clamps_alpha() {
        // Above the range: pinned to the upper bound.
        let q = Query::hybrid("x").alpha(2.0);
        assert!((hybrid_alpha(&q) - 1.0).abs() < f32::EPSILON);

        // Below the range: pinned to the lower bound.
        let q = Query::hybrid("x").alpha(-0.5);
        assert!(hybrid_alpha(&q).abs() < f32::EPSILON);
    }

    #[test]
    fn alpha_on_non_hybrid_is_noop() {
        let q = Query::semantic("x").alpha(0.1);
        assert!(matches!(q.mode(), QueryMode::Semantic));
    }

    #[test]
    fn prune_threshold_is_clamped_to_cosine_range() {
        assert_eq!(Query::semantic("x").prune_threshold(3.0).prune_threshold, Some(1.0));
        assert_eq!(Query::semantic("x").prune_threshold(-3.0).prune_threshold, Some(-1.0));
        assert_eq!(Query::semantic("x").prune_threshold(0.5).prune_threshold, Some(0.5));
    }

    #[test]
    fn descend_factor_has_floor_of_one() {
        assert_eq!(Query::semantic("x").descend_factor(0).descend_factor, 1);
        assert_eq!(Query::semantic("x").descend_factor(8).descend_factor, 8);
    }
}
231}