1use std::collections::{HashMap, HashSet};
9use std::path::{Path, PathBuf};
10
11use anyhow::{Result, anyhow};
12use frankensqlite::Connection as FrankenConnection;
13use frankensqlite::compat::{ConnectionExt, RowExt};
14use half::f16;
15
16pub use frankensearch::index::{Quantization, SearchParams, VectorIndex, VectorIndexWriter};
17
18use crate::search::query::SearchFilters;
19use crate::sources::provenance::{LOCAL_SOURCE_ID, SourceFilter, SourceKind};
20use crate::storage::sqlite::FrankenStorage;
21
/// Name of the subdirectory of the data directory that holds vector index
/// files; [`vector_index_path`] builds paths beneath it.
pub const VECTOR_INDEX_DIR: &str = "vector_index";
24
/// Role code for messages whose role string is `"user"`.
pub const ROLE_USER: u8 = 0;
/// Role code for `"assistant"` messages; `"agent"` maps here as well.
pub const ROLE_ASSISTANT: u8 = 1;
/// Role code for messages whose role string is `"system"`.
pub const ROLE_SYSTEM: u8 = 2;
/// Role code for messages whose role string is `"tool"`.
pub const ROLE_TOOL: u8 = 3;

/// Maps a role name to its compact numeric role code.
///
/// Matching is exact and case-sensitive. `"agent"` is accepted as an alias for
/// `"assistant"`. Any other string yields `None`.
#[must_use]
pub fn role_code_from_str(role: &str) -> Option<u8> {
    let code = match role {
        "user" => ROLE_USER,
        "assistant" | "agent" => ROLE_ASSISTANT,
        "system" => ROLE_SYSTEM,
        "tool" => ROLE_TOOL,
        _ => return None,
    };
    Some(code)
}
43
44pub fn parse_role_codes<I, S>(roles: I) -> Result<HashSet<u8>>
50where
51 I: IntoIterator<Item = S>,
52 S: AsRef<str>,
53{
54 let mut out = HashSet::new();
55 for role in roles {
56 let role_str = role.as_ref();
57 let code =
58 role_code_from_str(role_str).ok_or_else(|| anyhow!("unknown role: {role_str}"))?;
59 out.insert(code);
60 }
61 Ok(out)
62}
63
64#[must_use]
66pub fn vector_index_path(data_dir: &Path, embedder_id: &str) -> PathBuf {
67 data_dir
68 .join(VECTOR_INDEX_DIR)
69 .join(format!("index-{embedder_id}.fsvi"))
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq)]
74pub struct SemanticDocId {
75 pub message_id: u64,
76 pub chunk_idx: u8,
77 pub agent_id: u32,
78 pub workspace_id: u32,
79 pub source_id: u32,
80 pub role: u8,
81 pub created_at_ms: i64,
82 pub content_hash: Option<[u8; 32]>,
83}
84
85impl SemanticDocId {
86 #[must_use]
94 pub fn to_doc_id_string(&self) -> String {
95 let capacity = 2 + (7 * 20) + 6 + if self.content_hash.is_some() { 65 } else { 0 };
99 let mut out = String::with_capacity(capacity);
100 let mut buf = itoa::Buffer::new();
101 out.push_str("m|");
102 out.push_str(buf.format(self.message_id));
103 out.push('|');
104 out.push_str(buf.format(self.chunk_idx));
105 out.push('|');
106 out.push_str(buf.format(self.agent_id));
107 out.push('|');
108 out.push_str(buf.format(self.workspace_id));
109 out.push('|');
110 out.push_str(buf.format(self.source_id));
111 out.push('|');
112 out.push_str(buf.format(self.role));
113 out.push('|');
114 out.push_str(buf.format(self.created_at_ms));
115 if let Some(hash) = self.content_hash {
116 out.push('|');
117 let mut hex_buf = [0u8; 64];
121 hex::encode_to_slice(hash, &mut hex_buf)
122 .expect("32 bytes encode to exactly 64 hex chars");
123 out.push_str(std::str::from_utf8(&hex_buf).expect("hex output is always valid ASCII"));
124 }
125 out
126 }
127}
128
129#[must_use]
134pub fn parse_semantic_doc_id(doc_id: &str) -> Option<SemanticDocId> {
135 let rest = doc_id.strip_prefix("m|")?;
141 let mut parts = rest.splitn(8, '|');
142 let parsed = SemanticDocId {
143 message_id: parts.next()?.parse().ok()?,
144 chunk_idx: parts.next()?.parse().ok()?,
145 agent_id: parts.next()?.parse().ok()?,
146 workspace_id: parts.next()?.parse().ok()?,
147 source_id: parts.next()?.parse().ok()?,
148 role: parts.next()?.parse().ok()?,
149 created_at_ms: parts.next()?.parse().ok()?,
150 content_hash: parts.next().and_then(|hash_hex| {
151 if hash_hex.len() != 64 {
152 return None;
153 }
154 let mut hash = [0u8; 32];
155 hex::decode_to_slice(hash_hex, &mut hash).ok()?;
156 Some(hash)
157 }),
158 };
159
160 Some(parsed)
161}
162
/// The subset of a semantic doc id needed to evaluate a `SemanticFilter`.
///
/// Produced by [`parse_semantic_doc_id_filter_view`]. The message id, chunk
/// index and content hash are not decoded.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SemanticDocIdFilterView {
    pub agent_id: u32,
    pub workspace_id: u32,
    pub source_id: u32,
    pub role: u8,
    pub created_at_ms: i64,
}

/// Parses only the filterable fields of an `m|...` semantic doc id.
///
/// Behaviour of the parser:
/// - The message-id and chunk-index segments are skipped without parsing, so
///   non-numeric values there are accepted.
/// - Anything after the timestamp segment is ignored.
/// - Returns `None` if the `m|` prefix is missing, a segment up to the
///   timestamp is absent, or one of the five filter fields fails to parse.
#[must_use]
pub(crate) fn parse_semantic_doc_id_filter_view(doc_id: &str) -> Option<SemanticDocIdFilterView> {
    // skip(2): message_id and chunk_idx are not needed for filtering.
    let mut segments = doc_id.strip_prefix("m|")?.splitn(8, '|').skip(2);
    Some(SemanticDocIdFilterView {
        agent_id: segments.next()?.parse::<u32>().ok()?,
        workspace_id: segments.next()?.parse::<u32>().ok()?,
        source_id: segments.next()?.parse::<u32>().ok()?,
        role: segments.next()?.parse::<u8>().ok()?,
        created_at_ms: segments.next()?.parse::<i64>().ok()?,
    })
}
203
/// Translates a set of string keys into their numeric ids using `map`.
///
/// Returns `None` (no restriction) when `keys` is empty. Keys missing from
/// `map` are dropped. A non-empty key set that resolves to nothing therefore
/// yields `Some` of an empty set, which no document id can satisfy.
fn map_filter_set(keys: &HashSet<String>, map: &HashMap<String, u32>) -> Option<HashSet<u32>> {
    (!keys.is_empty()).then(|| keys.iter().filter_map(|key| map.get(key).copied()).collect())
}
216
217fn source_id_hash(source_id: &str) -> u32 {
218 let mut hasher = crc32fast::Hasher::new();
219 hasher.update(source_id.as_bytes());
220 hasher.finalize()
221}
222
223#[derive(Debug, Clone, Default)]
225pub struct SemanticFilter {
226 pub agents: Option<HashSet<u32>>,
227 pub workspaces: Option<HashSet<u32>>,
228 pub sources: Option<HashSet<u32>>,
229 pub roles: Option<HashSet<u8>>,
230 pub created_from: Option<i64>,
231 pub created_to: Option<i64>,
232}
233
234impl SemanticFilter {
235 pub fn from_search_filters(filters: &SearchFilters, maps: &SemanticFilterMaps) -> Result<Self> {
236 let agents = map_filter_set(&filters.agents, &maps.agent_slug_to_id);
237 let workspaces = map_filter_set(&filters.workspaces, &maps.workspace_path_to_id);
238 let sources = maps.sources_from_filter(&filters.source_filter)?;
239
240 Ok(Self {
241 agents,
242 workspaces,
243 sources,
244 roles: None,
245 created_from: filters.created_from,
246 created_to: filters.created_to,
247 })
248 }
249
250 #[must_use]
251 pub fn is_unrestricted(&self) -> bool {
252 self.agents.is_none()
253 && self.workspaces.is_none()
254 && self.sources.is_none()
255 && self.roles.is_none()
256 && self.created_from.is_none()
257 && self.created_to.is_none()
258 }
259
260 #[must_use]
261 pub fn with_roles(mut self, roles: Option<HashSet<u8>>) -> Self {
262 self.roles = roles;
263 self
264 }
265}
266
267#[derive(Debug, Clone)]
270pub struct SemanticFilterMaps {
271 agent_slug_to_id: HashMap<String, u32>,
272 workspace_path_to_id: HashMap<String, u32>,
273 source_id_to_id: HashMap<String, u32>,
274 remote_source_ids: HashSet<u32>,
275}
276
277impl SemanticFilterMaps {
278 pub fn from_storage(storage: &FrankenStorage) -> Result<Self> {
279 Self::from_connection(storage.raw())
280 }
281
282 pub fn from_connection(conn: &FrankenConnection) -> Result<Self> {
283 let mut agent_slug_to_id = HashMap::new();
284 let agent_rows = conn.query_map_collect(
285 "SELECT id, slug FROM agents",
286 &[],
287 |row: &frankensqlite::Row| {
288 let id: i64 = row.get_typed(0)?;
289 let slug: String = row.get_typed(1)?;
290 Ok((id, slug))
291 },
292 )?;
293 for (id, slug) in agent_rows {
294 let id_u32 = u32::try_from(id).map_err(|_| anyhow!("agent id out of range"))?;
295 agent_slug_to_id.insert(slug, id_u32);
296 }
297
298 let mut workspace_path_to_id = HashMap::new();
299 let workspace_rows = conn.query_map_collect(
300 "SELECT id, path FROM workspaces",
301 &[],
302 |row: &frankensqlite::Row| {
303 let id: i64 = row.get_typed(0)?;
304 let path: String = row.get_typed(1)?;
305 Ok((id, path))
306 },
307 )?;
308 for (id, path) in workspace_rows {
309 let id_u32 = u32::try_from(id).map_err(|_| anyhow!("workspace id out of range"))?;
310 workspace_path_to_id.insert(path, id_u32);
311 }
312
313 let mut source_id_to_id = HashMap::new();
314 let mut remote_source_ids = HashSet::new();
315 let source_rows = conn.query_map_collect(
316 "SELECT id, kind FROM sources",
317 &[],
318 |row: &frankensqlite::Row| {
319 let id: String = row.get_typed(0)?;
320 let kind: String = row.get_typed(1)?;
321 Ok((id, kind))
322 },
323 )?;
324 for (id, kind) in source_rows {
325 let id_u32 = source_id_hash(&id);
326 if SourceKind::parse(&kind).is_none_or(|k| k.is_remote()) {
327 remote_source_ids.insert(id_u32);
328 }
329 source_id_to_id.insert(id, id_u32);
330 }
331
332 Ok(Self {
333 agent_slug_to_id,
334 workspace_path_to_id,
335 source_id_to_id,
336 remote_source_ids,
337 })
338 }
339
340 #[cfg(test)]
341 pub(crate) fn for_tests(
342 agent_slug_to_id: HashMap<String, u32>,
343 workspace_path_to_id: HashMap<String, u32>,
344 source_id_to_id: HashMap<String, u32>,
345 remote_source_ids: HashSet<u32>,
346 ) -> Self {
347 Self {
348 agent_slug_to_id,
349 workspace_path_to_id,
350 source_id_to_id,
351 remote_source_ids,
352 }
353 }
354
355 fn sources_from_filter(&self, filter: &SourceFilter) -> Result<Option<HashSet<u32>>> {
356 let result = match filter {
357 SourceFilter::All => None,
358 SourceFilter::Local => Some(HashSet::from([self.source_id(LOCAL_SOURCE_ID)])),
359 SourceFilter::Remote => Some(self.remote_source_ids.clone()),
360 SourceFilter::SourceId(id) => Some(HashSet::from([self.source_id(id)])),
361 };
362 Ok(result)
363 }
364
365 fn source_id(&self, source_id: &str) -> u32 {
366 self.source_id_to_id
367 .get(source_id)
368 .copied()
369 .unwrap_or_else(|| source_id_hash(source_id))
370 }
371}
372
/// One vector search hit, identified by message and chunk.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// Message the matching chunk belongs to.
    pub message_id: u64,
    /// Index of the matching chunk within that message.
    pub chunk_idx: u8,
    /// Score reported by the search for this hit.
    pub score: f32,
}
380
381impl frankensearch::core::filter::SearchFilter for SemanticFilter {
382 fn matches(&self, doc_id: &str, _metadata: Option<&serde_json::Value>) -> bool {
383 let Some(parsed) = parse_semantic_doc_id_filter_view(doc_id) else {
386 return false;
387 };
388
389 if let Some(agents) = &self.agents
390 && !agents.contains(&parsed.agent_id)
391 {
392 return false;
393 }
394 if let Some(workspaces) = &self.workspaces
395 && !workspaces.contains(&parsed.workspace_id)
396 {
397 return false;
398 }
399 if let Some(sources) = &self.sources
400 && !sources.contains(&parsed.source_id)
401 {
402 return false;
403 }
404 if let Some(roles) = &self.roles
405 && !roles.contains(&parsed.role)
406 {
407 return false;
408 }
409 if let Some(from) = self.created_from
410 && parsed.created_at_ms < from
411 {
412 return false;
413 }
414 if let Some(to) = self.created_to
415 && parsed.created_at_ms > to
416 {
417 return false;
418 }
419
420 true
421 }
422
423 fn matches_doc_id_hash(
424 &self,
425 _doc_id_hash: u64,
426 _metadata: Option<&serde_json::Value>,
427 ) -> Option<bool> {
428 None
429 }
430
431 fn name(&self) -> &str {
432 "cass_semantic_filter"
433 }
434}
435
/// Plain scalar `f32` dot product, used as the benchmark baseline.
///
/// Elements are paired with `zip`. On a length mismatch, the tail of the
/// longer slice is ignored rather than reported.
#[must_use]
pub fn dot_product_scalar_bench(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(&x, &y)| x * y).sum()
}
441
442#[must_use]
444pub fn dot_product_simd_bench(a: &[f32], b: &[f32]) -> f32 {
445 frankensearch::index::dot_product_f32_f32(a, b).expect("dot product inputs must match length")
446}
447
448#[must_use]
450pub fn dot_product_f16_scalar_bench(stored: &[f16], query: &[f32]) -> f32 {
451 stored.iter().zip(query).map(|(x, y)| x.to_f32() * y).sum()
452}
453
454#[must_use]
456pub fn dot_product_f16_simd_bench(stored: &[f16], query: &[f32]) -> f32 {
457 frankensearch::index::dot_product_f16_f32(stored, query)
458 .expect("dot product inputs must match length")
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc id fixture shared by the round-trip tests; only the hash varies.
    fn fixture(content_hash: Option<[u8; 32]>) -> SemanticDocId {
        SemanticDocId {
            message_id: 42,
            chunk_idx: 2,
            agent_id: 3,
            workspace_id: 7,
            source_id: 11,
            role: 1,
            created_at_ms: 1_700_000_000_000,
            content_hash,
        }
    }

    /// Encodes the fixture, parses it back, and checks every field survived.
    fn assert_roundtrip(content_hash: Option<[u8; 32]>) {
        let expected = fixture(content_hash);
        let encoded = expected.to_doc_id_string();
        let decoded = parse_semantic_doc_id(&encoded).expect("parse");
        assert_eq!(decoded, expected, "round-trip of {encoded}");
    }

    #[test]
    fn role_code_from_str_accepts_known_roles() {
        let expectations: [(&str, Option<u8>); 6] = [
            ("user", Some(ROLE_USER)),
            ("assistant", Some(ROLE_ASSISTANT)),
            ("agent", Some(ROLE_ASSISTANT)),
            ("system", Some(ROLE_SYSTEM)),
            ("tool", Some(ROLE_TOOL)),
            ("unknown", None),
        ];

        for &(name, want) in &expectations {
            assert_eq!(role_code_from_str(name), want, "{name}");
        }
    }

    #[test]
    fn parse_role_codes_rejects_unknown_roles() {
        let err = parse_role_codes(["user", "bogus"]).expect_err("bogus role must be rejected");
        assert!(err.to_string().contains("unknown role"));
    }

    #[test]
    fn vector_index_path_points_to_fsvi() {
        let path = vector_index_path(Path::new("/tmp/cass"), "fnv1a-384");
        assert!(path.ends_with("vector_index/index-fnv1a-384.fsvi"));
    }

    #[test]
    fn semantic_doc_id_roundtrip_with_hash() {
        assert_roundtrip(Some([0u8; 32]));
    }

    #[test]
    fn semantic_doc_id_roundtrip_without_hash() {
        assert_roundtrip(None);
    }
}