Skip to main content

coding_agent_search/search/
vector_index.rs

1//! Vector index facade for cass.
2//!
3//! cass uses the frankensearch FSVI vector index format and search primitives
4//! (via the `frankensearch` crate). The older CVVI format has been retired.
5//!
6//! This module keeps cass-specific helpers (paths, role codes) in one place.
7
8use std::collections::{HashMap, HashSet};
9use std::path::{Path, PathBuf};
10
11use anyhow::{Result, anyhow};
12use frankensqlite::Connection as FrankenConnection;
13use frankensqlite::compat::{ConnectionExt, RowExt};
14use half::f16;
15
16pub use frankensearch::index::{Quantization, SearchParams, VectorIndex, VectorIndexWriter};
17
18use crate::search::query::SearchFilters;
19use crate::sources::provenance::{LOCAL_SOURCE_ID, SourceFilter, SourceKind};
20use crate::storage::sqlite::FrankenStorage;
21
/// Name of the subdirectory of the cass data dir that holds vector index files.
pub const VECTOR_INDEX_DIR: &str = "vector_index";

// Compact message-role codes. They are written into semantic doc_ids and
// compared by the semantic filter; changing the numbers would presumably
// invalidate doc_ids already written to existing indexes.
pub const ROLE_USER: u8 = 0;
pub const ROLE_ASSISTANT: u8 = 1;
pub const ROLE_SYSTEM: u8 = 2;
pub const ROLE_TOOL: u8 = 3;

/// Translate a message role string (as stored in SQLite / emitted by
/// connectors) into its compact `u8` code.
///
/// Matching is exact and case-sensitive. Both `"assistant"` and `"agent"`
/// map to [`ROLE_ASSISTANT`], because cass has historically used both
/// spellings for model responses. Any other string yields `None`.
#[must_use]
pub fn role_code_from_str(role: &str) -> Option<u8> {
    let code = match role {
        "user" => ROLE_USER,
        "assistant" | "agent" => ROLE_ASSISTANT,
        "system" => ROLE_SYSTEM,
        "tool" => ROLE_TOOL,
        _ => return None,
    };
    Some(code)
}
43
44/// Parse a list of role strings into a set of role codes.
45///
46/// # Errors
47///
48/// Returns an error if any role string is unknown.
49pub fn parse_role_codes<I, S>(roles: I) -> Result<HashSet<u8>>
50where
51    I: IntoIterator<Item = S>,
52    S: AsRef<str>,
53{
54    let mut out = HashSet::new();
55    for role in roles {
56        let role_str = role.as_ref();
57        let code =
58            role_code_from_str(role_str).ok_or_else(|| anyhow!("unknown role: {role_str}"))?;
59        out.insert(code);
60    }
61    Ok(out)
62}
63
64/// Path to the primary FSVI vector index for a given embedder.
65#[must_use]
66pub fn vector_index_path(data_dir: &Path, embedder_id: &str) -> PathBuf {
67    data_dir
68        .join(VECTOR_INDEX_DIR)
69        .join(format!("index-{embedder_id}.fsvi"))
70}
71
/// Fields packed into the doc_id string of each semantic vector record
/// stored in an FSVI index.
///
/// Encoded by [`SemanticDocId::to_doc_id_string`] and decoded by
/// [`parse_semantic_doc_id`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SemanticDocId {
    /// Id of the message the vector was embedded from.
    pub message_id: u64,
    /// Which chunk of the message this vector covers.
    pub chunk_idx: u8,
    /// Numeric agent id (matched against agent row ids by the filter).
    pub agent_id: u32,
    /// Numeric workspace id (matched against workspace row ids by the filter).
    pub workspace_id: u32,
    /// Numeric source id; the filter compares it against CRC32 hashes of
    /// source string ids.
    pub source_id: u32,
    /// One of the `ROLE_*` codes.
    pub role: u8,
    /// Creation timestamp in milliseconds (epoch assumed — TODO confirm).
    pub created_at_ms: i64,
    /// Optional 32-byte content hash, serialized as 64 hex characters.
    pub content_hash: Option<[u8; 32]>,
}
84
85impl SemanticDocId {
86    /// Encode this semantic vector record doc_id into the string form stored in FSVI.
87    ///
88    /// Hot-path encoder: runs once per embedded message during indexing and
89    /// for every search hit that goes through semantic lookup. Build the
90    /// output in a single pre-sized `String` with `itoa::Buffer` for the
91    /// integer fields instead of `format!`, which walks the formatter-trait
92    /// machinery per arg and grows its internal buffer on demand.
93    #[must_use]
94    pub fn to_doc_id_string(&self) -> String {
95        // Capacity estimate: "m|" (2) + seven integer fields up to 20 chars
96        // + six '|' separators + optional 64-hex hash + one '|' if present.
97        // Slight over-allocation is fine and avoids any realloc.
98        let capacity = 2 + (7 * 20) + 6 + if self.content_hash.is_some() { 65 } else { 0 };
99        let mut out = String::with_capacity(capacity);
100        let mut buf = itoa::Buffer::new();
101        out.push_str("m|");
102        out.push_str(buf.format(self.message_id));
103        out.push('|');
104        out.push_str(buf.format(self.chunk_idx));
105        out.push('|');
106        out.push_str(buf.format(self.agent_id));
107        out.push('|');
108        out.push_str(buf.format(self.workspace_id));
109        out.push('|');
110        out.push_str(buf.format(self.source_id));
111        out.push('|');
112        out.push_str(buf.format(self.role));
113        out.push('|');
114        out.push_str(buf.format(self.created_at_ms));
115        if let Some(hash) = self.content_hash {
116            out.push('|');
117            // Stack-buffered hex encode: avoids the 64-byte heap alloc that
118            // `hex::encode(hash)` performs internally. Hex output is pure
119            // ASCII so str::from_utf8 can't fail on the filled slice.
120            let mut hex_buf = [0u8; 64];
121            hex::encode_to_slice(hash, &mut hex_buf)
122                .expect("32 bytes encode to exactly 64 hex chars");
123            out.push_str(std::str::from_utf8(&hex_buf).expect("hex output is always valid ASCII"));
124        }
125        out
126    }
127}
128
129/// Parse a cass semantic doc_id string.
130///
131/// Accepts doc_ids with trailing segments (future expansion) and an optional
132/// 64-hex content hash suffix.
133#[must_use]
134pub fn parse_semantic_doc_id(doc_id: &str) -> Option<SemanticDocId> {
135    // Fast reject: every cass semantic doc_id starts with "m|". `strip_prefix`
136    // avoids the full iterator setup + first `.next()` comparison when the
137    // discriminator doesn't match. `splitn(8, '|')` caps the field scan at
138    // exactly the 7 required fields + a single tail holding the optional
139    // content hash (which itself never contains '|').
140    let rest = doc_id.strip_prefix("m|")?;
141    let mut parts = rest.splitn(8, '|');
142    let parsed = SemanticDocId {
143        message_id: parts.next()?.parse().ok()?,
144        chunk_idx: parts.next()?.parse().ok()?,
145        agent_id: parts.next()?.parse().ok()?,
146        workspace_id: parts.next()?.parse().ok()?,
147        source_id: parts.next()?.parse().ok()?,
148        role: parts.next()?.parse().ok()?,
149        created_at_ms: parts.next()?.parse().ok()?,
150        content_hash: parts.next().and_then(|hash_hex| {
151            if hash_hex.len() != 64 {
152                return None;
153            }
154            let mut hash = [0u8; 32];
155            hex::decode_to_slice(hash_hex, &mut hash).ok()?;
156            Some(hash)
157        }),
158    };
159
160    Some(parsed)
161}
162
/// Lean, filter-only view of a parsed semantic doc_id.
///
/// Carries just the fields `SemanticFilter::matches` inspects. It omits
/// `message_id`, `chunk_idx`, and `content_hash`. Skipping the hash avoids a
/// 64-char hex decode on every call, which matters because `matches` runs
/// once per HNSW-visited node during ANN traversal.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SemanticDocIdFilterView {
    pub agent_id: u32,
    pub workspace_id: u32,
    pub source_id: u32,
    pub role: u8,
    pub created_at_ms: i64,
}

/// Parse only the filter-relevant fields of a cass semantic doc_id string.
///
/// Accepts exactly the doc_ids that [`parse_semantic_doc_id`] accepts: the
/// `m|` prefix followed by seven integer fields that parse into their field
/// types. `message_id` and `chunk_idx` are validated and then discarded, so a
/// filter built on this view never admits a record that the full parser would
/// reject. The optional content-hash tail is never inspected, which is what
/// makes this cheaper than the full parse when a hash is present.
#[must_use]
pub(crate) fn parse_semantic_doc_id_filter_view(doc_id: &str) -> Option<SemanticDocIdFilterView> {
    let rest = doc_id.strip_prefix("m|")?;
    let mut parts = rest.splitn(8, '|');
    // Validate (but do not keep) message_id and chunk_idx, mirroring the
    // full parser's acceptance rules. Both are cheap integer parses.
    let _message_id: u64 = parts.next()?.parse().ok()?;
    let _chunk_idx: u8 = parts.next()?.parse().ok()?;
    let agent_id: u32 = parts.next()?.parse().ok()?;
    let workspace_id: u32 = parts.next()?.parse().ok()?;
    let source_id: u32 = parts.next()?.parse().ok()?;
    let role: u8 = parts.next()?.parse().ok()?;
    let created_at_ms: i64 = parts.next()?.parse().ok()?;
    Some(SemanticDocIdFilterView {
        agent_id,
        workspace_id,
        source_id,
        role,
        created_at_ms,
    })
}
203
/// Resolve a set of human-readable filter keys to their numeric ids.
///
/// An empty `keys` set means "no constraint" and yields `None`. Otherwise the
/// result is `Some` of the ids of the keys found in `map`; unknown keys are
/// dropped. A set made only of unknown keys therefore becomes `Some(empty)`,
/// which a `SemanticFilter` treats as matching nothing.
fn map_filter_set(keys: &HashSet<String>, map: &HashMap<String, u32>) -> Option<HashSet<u32>> {
    (!keys.is_empty()).then(|| keys.iter().filter_map(|key| map.get(key).copied()).collect())
}
216
217fn source_id_hash(source_id: &str) -> u32 {
218    let mut hasher = crc32fast::Hasher::new();
219    hasher.update(source_id.as_bytes());
220    hasher.finalize()
221}
222
/// Semantic-search constraints, pre-resolved to the numeric ids stored in
/// doc_ids so they can be checked without string lookups.
///
/// Every field is optional, and `None` means "no constraint on this field".
/// A doc must satisfy every constraint that is set in order to match.
#[derive(Debug, Clone, Default)]
pub struct SemanticFilter {
    /// Allowed agent ids.
    pub agents: Option<HashSet<u32>>,
    /// Allowed workspace ids.
    pub workspaces: Option<HashSet<u32>>,
    /// Allowed source ids (CRC32 hashes of source string ids).
    pub sources: Option<HashSet<u32>>,
    /// Allowed role codes (`ROLE_*`).
    pub roles: Option<HashSet<u8>>,
    /// Inclusive lower bound on `created_at_ms`.
    pub created_from: Option<i64>,
    /// Inclusive upper bound on `created_at_ms`.
    pub created_to: Option<i64>,
}
233
234impl SemanticFilter {
235    pub fn from_search_filters(filters: &SearchFilters, maps: &SemanticFilterMaps) -> Result<Self> {
236        let agents = map_filter_set(&filters.agents, &maps.agent_slug_to_id);
237        let workspaces = map_filter_set(&filters.workspaces, &maps.workspace_path_to_id);
238        let sources = maps.sources_from_filter(&filters.source_filter)?;
239
240        Ok(Self {
241            agents,
242            workspaces,
243            sources,
244            roles: None,
245            created_from: filters.created_from,
246            created_to: filters.created_to,
247        })
248    }
249
250    #[must_use]
251    pub fn is_unrestricted(&self) -> bool {
252        self.agents.is_none()
253            && self.workspaces.is_none()
254            && self.sources.is_none()
255            && self.roles.is_none()
256            && self.created_from.is_none()
257            && self.created_to.is_none()
258    }
259
260    #[must_use]
261    pub fn with_roles(mut self, roles: Option<HashSet<u8>>) -> Self {
262        self.roles = roles;
263        self
264    }
265}
266
/// Lookup tables that turn human-facing filter values (agent slug, workspace
/// path, source id) into the compact numeric ids embedded in semantic doc_ids.
#[derive(Debug, Clone)]
pub struct SemanticFilterMaps {
    /// Agent slug -> agent row id.
    agent_slug_to_id: HashMap<String, u32>,
    /// Workspace path -> workspace row id.
    workspace_path_to_id: HashMap<String, u32>,
    /// Source string id -> CRC32 of that id.
    source_id_to_id: HashMap<String, u32>,
    /// Hashed ids of sources whose kind is remote or unrecognized.
    remote_source_ids: HashSet<u32>,
}
276
277impl SemanticFilterMaps {
278    pub fn from_storage(storage: &FrankenStorage) -> Result<Self> {
279        Self::from_connection(storage.raw())
280    }
281
282    pub fn from_connection(conn: &FrankenConnection) -> Result<Self> {
283        let mut agent_slug_to_id = HashMap::new();
284        let agent_rows = conn.query_map_collect(
285            "SELECT id, slug FROM agents",
286            &[],
287            |row: &frankensqlite::Row| {
288                let id: i64 = row.get_typed(0)?;
289                let slug: String = row.get_typed(1)?;
290                Ok((id, slug))
291            },
292        )?;
293        for (id, slug) in agent_rows {
294            let id_u32 = u32::try_from(id).map_err(|_| anyhow!("agent id out of range"))?;
295            agent_slug_to_id.insert(slug, id_u32);
296        }
297
298        let mut workspace_path_to_id = HashMap::new();
299        let workspace_rows = conn.query_map_collect(
300            "SELECT id, path FROM workspaces",
301            &[],
302            |row: &frankensqlite::Row| {
303                let id: i64 = row.get_typed(0)?;
304                let path: String = row.get_typed(1)?;
305                Ok((id, path))
306            },
307        )?;
308        for (id, path) in workspace_rows {
309            let id_u32 = u32::try_from(id).map_err(|_| anyhow!("workspace id out of range"))?;
310            workspace_path_to_id.insert(path, id_u32);
311        }
312
313        let mut source_id_to_id = HashMap::new();
314        let mut remote_source_ids = HashSet::new();
315        let source_rows = conn.query_map_collect(
316            "SELECT id, kind FROM sources",
317            &[],
318            |row: &frankensqlite::Row| {
319                let id: String = row.get_typed(0)?;
320                let kind: String = row.get_typed(1)?;
321                Ok((id, kind))
322            },
323        )?;
324        for (id, kind) in source_rows {
325            let id_u32 = source_id_hash(&id);
326            if SourceKind::parse(&kind).is_none_or(|k| k.is_remote()) {
327                remote_source_ids.insert(id_u32);
328            }
329            source_id_to_id.insert(id, id_u32);
330        }
331
332        Ok(Self {
333            agent_slug_to_id,
334            workspace_path_to_id,
335            source_id_to_id,
336            remote_source_ids,
337        })
338    }
339
340    #[cfg(test)]
341    pub(crate) fn for_tests(
342        agent_slug_to_id: HashMap<String, u32>,
343        workspace_path_to_id: HashMap<String, u32>,
344        source_id_to_id: HashMap<String, u32>,
345        remote_source_ids: HashSet<u32>,
346    ) -> Self {
347        Self {
348            agent_slug_to_id,
349            workspace_path_to_id,
350            source_id_to_id,
351            remote_source_ids,
352        }
353    }
354
355    fn sources_from_filter(&self, filter: &SourceFilter) -> Result<Option<HashSet<u32>>> {
356        let result = match filter {
357            SourceFilter::All => None,
358            SourceFilter::Local => Some(HashSet::from([self.source_id(LOCAL_SOURCE_ID)])),
359            SourceFilter::Remote => Some(self.remote_source_ids.clone()),
360            SourceFilter::SourceId(id) => Some(HashSet::from([self.source_id(id)])),
361        };
362        Ok(result)
363    }
364
365    fn source_id(&self, source_id: &str) -> u32 {
366        self.source_id_to_id
367            .get(source_id)
368            .copied()
369            .unwrap_or_else(|| source_id_hash(source_id))
370    }
371}
372
/// One semantic search hit after collapsing chunks: the best chunk found for
/// a message.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// Message the hit belongs to.
    pub message_id: u64,
    /// Index of the chunk that produced this hit.
    pub chunk_idx: u8,
    /// Score of that chunk.
    pub score: f32,
}
380
381impl frankensearch::core::filter::SearchFilter for SemanticFilter {
382    fn matches(&self, doc_id: &str, _metadata: Option<&serde_json::Value>) -> bool {
383        // Use the filter-view parse: skips the expensive 64-byte hex decode
384        // of content_hash that the full parse runs on every call.
385        let Some(parsed) = parse_semantic_doc_id_filter_view(doc_id) else {
386            return false;
387        };
388
389        if let Some(agents) = &self.agents
390            && !agents.contains(&parsed.agent_id)
391        {
392            return false;
393        }
394        if let Some(workspaces) = &self.workspaces
395            && !workspaces.contains(&parsed.workspace_id)
396        {
397            return false;
398        }
399        if let Some(sources) = &self.sources
400            && !sources.contains(&parsed.source_id)
401        {
402            return false;
403        }
404        if let Some(roles) = &self.roles
405            && !roles.contains(&parsed.role)
406        {
407            return false;
408        }
409        if let Some(from) = self.created_from
410            && parsed.created_at_ms < from
411        {
412            return false;
413        }
414        if let Some(to) = self.created_to
415            && parsed.created_at_ms > to
416        {
417            return false;
418        }
419
420        true
421    }
422
423    fn matches_doc_id_hash(
424        &self,
425        _doc_id_hash: u64,
426        _metadata: Option<&serde_json::Value>,
427    ) -> Option<bool> {
428        None
429    }
430
431    fn name(&self) -> &str {
432        "cass_semantic_filter"
433    }
434}
435
/// Plain-iterator dot product, used as the scalar baseline in benchmarks.
///
/// Elements are paired with `zip`, so any extra elements in the longer slice
/// are ignored.
#[must_use]
pub fn dot_product_scalar_bench(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>()
}
441
442/// SIMD dot product benchmark helper (uses frankensearch's portable SIMD).
443#[must_use]
444pub fn dot_product_simd_bench(a: &[f32], b: &[f32]) -> f32 {
445    frankensearch::index::dot_product_f32_f32(a, b).expect("dot product inputs must match length")
446}
447
448/// Scalar dot product benchmark helper for f16 stored vectors vs f32 query.
449#[must_use]
450pub fn dot_product_f16_scalar_bench(stored: &[f16], query: &[f32]) -> f32 {
451    stored.iter().zip(query).map(|(x, y)| x.to_f32() * y).sum()
452}
453
454/// SIMD dot product benchmark helper for f16 stored vectors vs f32 query.
455#[must_use]
456pub fn dot_product_f16_simd_bench(stored: &[f16], query: &[f32]) -> f32 {
457    frankensearch::index::dot_product_f16_f32(stored, query)
458        .expect("dot product inputs must match length")
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed record shared by the doc_id round-trip tests.
    fn sample_doc_id(content_hash: Option<[u8; 32]>) -> SemanticDocId {
        SemanticDocId {
            message_id: 42,
            chunk_idx: 2,
            agent_id: 3,
            workspace_id: 7,
            source_id: 11,
            role: 1,
            created_at_ms: 1_700_000_000_000,
            content_hash,
        }
    }

    #[test]
    fn role_code_from_str_accepts_known_roles() {
        let table = [
            ("user", Some(ROLE_USER)),
            ("assistant", Some(ROLE_ASSISTANT)),
            // Legacy spelling for assistant responses.
            ("agent", Some(ROLE_ASSISTANT)),
            ("system", Some(ROLE_SYSTEM)),
            ("tool", Some(ROLE_TOOL)),
            ("unknown", None),
        ];

        for (input, want) in table {
            assert_eq!(role_code_from_str(input), want, "{input}");
        }
    }

    #[test]
    fn parse_role_codes_rejects_unknown_roles() {
        let message = parse_role_codes(["user", "bogus"])
            .expect_err("bogus role must be rejected")
            .to_string();
        assert!(message.contains("unknown role"));
    }

    #[test]
    fn vector_index_path_points_to_fsvi() {
        let path = vector_index_path(Path::new("/tmp/cass"), "fnv1a-384");
        assert!(path.ends_with("vector_index/index-fnv1a-384.fsvi"));
    }

    #[test]
    fn semantic_doc_id_roundtrip_with_hash() {
        let original = sample_doc_id(Some([0u8; 32]));
        let encoded = original.to_doc_id_string();
        let decoded = parse_semantic_doc_id(&encoded).expect("parse");
        // Whole-struct equality covers every field, including the hash.
        assert_eq!(decoded, original);
    }

    #[test]
    fn semantic_doc_id_roundtrip_without_hash() {
        let original = sample_doc_id(None);
        let decoded = parse_semantic_doc_id(&original.to_doc_id_string()).expect("parse");
        assert_eq!(decoded, original);
        assert_eq!(decoded.content_hash, None);
    }
}