Skip to main content

cqs/store/
helpers.rs

1//! Store helper types and embedding conversion functions
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5use thiserror::Error;
6
7use crate::embedder::Embedding;
8use crate::parser::{ChunkType, Language};
9
/// Schema version for database migrations
///
/// Increment this when changing the database schema. Store::open() checks this
/// against the stored version and returns StoreError::SchemaMismatch if different.
///
/// History:
/// - v10: Current (sentiment in embeddings, call graph, notes)
pub const CURRENT_SCHEMA_VERSION: i32 = 10;
/// Identifier of the embedding model whose vectors this index stores.
///
/// Used as the default `ModelInfo::name`.
pub const MODEL_NAME: &str = "intfloat/e5-base-v2";
/// Expected embedding dimensions — derived from crate::EMBEDDING_DIM
///
/// Stored embeddings must have exactly this many `f32` components
/// (`embedding_to_bytes` asserts it). The `as u32` cast assumes
/// `EMBEDDING_DIM` fits in a `u32`.
pub const EXPECTED_DIMENSIONS: u32 = crate::EMBEDDING_DIM as u32;
21
/// Errors returned by the index store.
///
/// `Database` and `Io` convert automatically from their source errors via
/// `#[from]`; the schema/model/dimension variants include a remedy
/// (usually `cqs index --force`) in their `Display` message.
#[derive(Error, Debug)]
pub enum StoreError {
    /// Error reported by sqlx / SQLite.
    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),
    /// Filesystem I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A file's modification time is earlier than the Unix epoch.
    #[error("System time error: file mtime before Unix epoch")]
    SystemTime,
    /// Other runtime failure, described by the contained message.
    #[error("Runtime error: {0}")]
    Runtime(String),
    /// Index schema version differs from the version cqs expects.
    ///
    /// Fields: where the mismatch was found (presumably the index path —
    /// confirm at call sites), the index's version, the expected version.
    #[error("Schema version mismatch in {0}: index is v{1}, cqs expects v{2}. Run 'cqs index --force' to rebuild.")]
    SchemaMismatch(String, i32, i32),
    /// Index was created by a newer cqs; carries that index's schema version.
    // NOTE: "Cq" (not "Cqs") in the variant name is public API; renaming would break callers.
    #[error("Index created by newer cqs version (schema v{0}). Please upgrade cqs.")]
    SchemaNewerThanCq(i32),
    /// No migration exists from the first schema version to the second.
    #[error("No migration path from schema v{0} to v{1}. Run 'cqs index --force' to rebuild.")]
    MigrationNotSupported(i32, i32),
    /// Index model name (first) differs from the current model name (second).
    #[error(
        "Model mismatch: index uses '{0}', current is '{1}'. Run 'cqs index --force' to re-embed."
    )]
    ModelMismatch(String, String),
    /// Index embedding dimension (first) differs from the expected one (second).
    #[error(
        "Dimension mismatch: index has {0}-dim embeddings, current model expects {1}. Run 'cqs index --force' to rebuild."
    )]
    DimensionMismatch(u32, u32),
}
47
/// Raw row from chunks table (crate-internal, used by search module)
///
/// Columns are kept as the strings SQLite returned; `ChunkSummary`'s
/// `From<ChunkRow>` impl parses them into typed values.
#[derive(Clone)]
pub(crate) struct ChunkRow {
    /// Unique chunk identifier (`id` column)
    pub id: String,
    /// Source file path as stored during indexing (`origin` column)
    pub origin: String,
    /// Language name, not yet parsed (`language` column)
    pub language: String,
    /// Chunk type name, not yet parsed (`chunk_type` column)
    pub chunk_type: String,
    /// Function/class/etc. name
    pub name: String,
    /// Signature or declaration text
    pub signature: String,
    /// Source code of the chunk
    pub content: String,
    /// Documentation comment, if any
    pub doc: Option<String>,
    /// Starting line (1-indexed); `from_row` clamps it with `clamp_line_number`
    pub line_start: u32,
    /// Ending line (1-indexed); `from_row` clamps it with `clamp_line_number`
    pub line_end: u32,
    /// Parent chunk ID, if this chunk has one
    pub parent_id: Option<String>,
}
63
64impl ChunkRow {
65    /// Construct from a SQLite row containing columns:
66    /// id, origin, language, chunk_type, name, signature, content, doc, line_start, line_end, parent_id
67    pub(crate) fn from_row(row: &sqlx::sqlite::SqliteRow) -> Self {
68        use sqlx::Row;
69        ChunkRow {
70            id: row.get("id"),
71            origin: row.get("origin"),
72            language: row.get("language"),
73            chunk_type: row.get("chunk_type"),
74            name: row.get("name"),
75            signature: row.get("signature"),
76            content: row.get("content"),
77            doc: row.get("doc"),
78            line_start: clamp_line_number(row.get::<i64, _>("line_start")),
79            line_end: clamp_line_number(row.get::<i64, _>("line_end")),
80            parent_id: row.get("parent_id"),
81        }
82    }
83}
84
/// Chunk metadata returned from search results
///
/// Contains all chunk information except the embedding vector.
/// Built from a `ChunkRow` via `From`; the row's `parent_id` is not carried over.
#[derive(Debug, Clone)]
pub struct ChunkSummary {
    /// Unique identifier
    pub id: String,
    /// Source file path (typically absolute, as stored during indexing)
    pub file: PathBuf,
    /// Programming language
    pub language: Language,
    /// Type of code element
    pub chunk_type: ChunkType,
    /// Name of the function/class/etc.
    pub name: String,
    /// Function signature or declaration
    pub signature: String,
    /// Full source code
    pub content: String,
    /// Documentation comment if present
    pub doc: Option<String>,
    /// Starting line number (1-indexed)
    pub line_start: u32,
    /// Ending line number (1-indexed)
    pub line_end: u32,
}
111
112impl From<ChunkRow> for ChunkSummary {
113    fn from(row: ChunkRow) -> Self {
114        let language = row.language.parse().unwrap_or_else(|_| {
115            tracing::warn!(
116                chunk_id = %row.id,
117                stored_value = %row.language,
118                "Failed to parse language from database, defaulting to Rust"
119            );
120            Language::Rust
121        });
122        let chunk_type = row.chunk_type.parse().unwrap_or_else(|_| {
123            tracing::warn!(
124                chunk_id = %row.id,
125                stored_value = %row.chunk_type,
126                "Failed to parse chunk_type from database, defaulting to Function"
127            );
128            ChunkType::Function
129        });
130        ChunkSummary {
131            id: row.id,
132            file: PathBuf::from(row.origin),
133            language,
134            chunk_type,
135            name: row.name,
136            signature: row.signature,
137            content: row.content,
138            doc: row.doc,
139            line_start: row.line_start,
140            line_end: row.line_end,
141        }
142    }
143}
144
/// A search result with similarity score
///
/// Pairs a matched code chunk with the score it was ranked by.
#[derive(Debug)]
pub struct SearchResult {
    /// The matching chunk
    pub chunk: ChunkSummary,
    /// Similarity score (0.0 to 1.0, higher is better)
    pub score: f32,
}
153
/// Caller information from the full call graph
///
/// Unlike ChunkSummary, this doesn't require a chunk to exist -
/// it captures callers from large functions that exceed chunk size limits.
#[derive(Debug, Clone)]
pub struct CallerInfo {
    /// Function name
    pub name: String,
    /// Source file path
    pub file: PathBuf,
    /// Line where function starts
    pub line: u32,
}
167
/// Caller with call-site context for impact analysis
///
/// Enriches CallerInfo with the specific line where the call occurs,
/// enabling snippet extraction without reading the source file.
/// `line` is where the caller's definition starts; `call_line` is the call site.
#[derive(Debug, Clone)]
pub struct CallerWithContext {
    /// Function name of the caller
    pub name: String,
    /// Source file path
    pub file: PathBuf,
    /// Line where the calling function starts
    pub line: u32,
    /// Line where the call to the target occurs
    pub call_line: u32,
}
183
/// In-memory call graph for BFS traversal
///
/// Built from a single scan of the `function_calls` table.
/// Both forward and reverse adjacency lists are included
/// to support trace (forward BFS) and impact/test-map (reverse BFS).
///
/// Derives `Debug` (for diagnostics, like the other public types here) and
/// `Clone` (a deep copy of both adjacency maps).
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// Forward edges: caller_name -> Vec<callee_name>
    pub forward: HashMap<String, Vec<String>>,
    /// Reverse edges: callee_name -> Vec<caller_name>
    pub reverse: HashMap<String, Vec<String>>,
}
195
/// Chunk identity for diff comparison
///
/// Minimal metadata needed to identify and match chunks across stores.
/// Does not include content or embeddings. String fields hold the raw
/// stored values (unparsed `chunk_type` / `language`).
#[derive(Debug, Clone)]
pub struct ChunkIdentity {
    /// Unique chunk identifier
    pub id: String,
    /// Source file path
    pub origin: String,
    /// Function/class/etc. name
    pub name: String,
    /// Type of code element (e.g., "function", "class")
    pub chunk_type: String,
    /// Starting line number (1-indexed)
    pub line_start: u32,
    /// Programming language
    pub language: String,
    /// Parent chunk ID (for windowed chunks)
    pub parent_id: Option<String>,
    /// Window index within parent (for long functions split into windows)
    pub window_idx: Option<u32>,
}
219
/// Note statistics (total count and categorized counts)
///
/// NOTE(review): how neutral-sentiment notes are counted (if at all outside
/// `total`) is decided by the query that fills this — confirm there.
#[derive(Debug, Clone)]
pub struct NoteStats {
    /// Total number of notes
    pub total: u64,
    /// Notes with negative sentiment (warnings)
    pub warnings: u64,
    /// Notes with positive sentiment (patterns)
    pub patterns: u64,
}
230
/// Note metadata returned from search results
///
/// The note counterpart of `ChunkSummary` (no embedding vector included).
#[derive(Debug, Clone)]
pub struct NoteSummary {
    /// Unique identifier
    pub id: String,
    /// Note content
    pub text: String,
    /// Sentiment: -1.0 to +1.0
    pub sentiment: f32,
    /// Mentioned code paths/functions
    pub mentions: Vec<String>,
}
243
/// A note search result with similarity score
///
/// Pairs a matched note with the score it was ranked by.
#[derive(Debug)]
pub struct NoteSearchResult {
    /// The matching note
    pub note: NoteSummary,
    /// Similarity score (0.0 to 1.0)
    pub score: f32,
}
252
/// Unified search result (code chunk or note)
///
/// Search can return both code chunks and notes. This enum allows
/// handling them uniformly while preserving type-specific data.
/// Both variants carry a `score`, exposed uniformly by `UnifiedResult::score`.
#[derive(Debug)]
pub enum UnifiedResult {
    /// A code chunk search result
    Code(SearchResult),
    /// A note search result
    Note(NoteSearchResult),
}
264
265impl UnifiedResult {
266    /// Get the similarity score
267    pub fn score(&self) -> f32 {
268        match self {
269            UnifiedResult::Code(r) => r.score,
270            UnifiedResult::Note(r) => r.score,
271        }
272    }
273}
274
275/// Filter and scoring options for search
276///
277/// All fields are optional. Unset filters match all chunks.
278/// Use `validate()` to check constraints before searching.
279pub struct SearchFilter {
280    /// Filter by programming language(s)
281    pub languages: Option<Vec<Language>>,
282    /// Filter by chunk type(s) (function, method, class, struct, enum, trait, interface, constant)
283    pub chunk_types: Option<Vec<ChunkType>>,
284    /// Filter by file path glob pattern (e.g., `src/**/*.rs`)
285    pub path_pattern: Option<String>,
286    /// Weight for name matching in hybrid search (0.0-1.0)
287    ///
288    /// 0.0 = pure embedding similarity (default)
289    /// 1.0 = pure name matching
290    /// 0.2 = recommended for balanced results
291    pub name_boost: f32,
292    /// Query text for name matching (required if name_boost > 0 or enable_rrf)
293    pub query_text: String,
294    /// Enable RRF (Reciprocal Rank Fusion) hybrid search
295    ///
296    /// When enabled, combines semantic search results with FTS5 keyword search
297    /// using the formula: score = Σ 1/(k + rank), where k=60.
298    /// This typically improves recall for identifier-heavy queries.
299    pub enable_rrf: bool,
300    /// Weight multiplier for note scores in unified search (0.0-1.0)
301    ///
302    /// 1.0 = notes scored equally with code (default)
303    /// 0.5 = notes scored at half weight
304    /// 0.0 = notes excluded from results
305    pub note_weight: f32,
306    /// When true, return only notes (skip code search entirely)
307    pub note_only: bool,
308}
309
310impl Default for SearchFilter {
311    fn default() -> Self {
312        Self {
313            languages: None,
314            chunk_types: None,
315            path_pattern: None,
316            name_boost: 0.0,
317            query_text: String::new(),
318            enable_rrf: false,
319            note_weight: 1.0, // Notes weighted equally by default
320            note_only: false,
321        }
322    }
323}
324
325impl SearchFilter {
326    /// Create a new SearchFilter with default values.
327    ///
328    /// Use struct literal syntax to customize:
329    /// ```ignore
330    /// let filter = SearchFilter {
331    ///     languages: Some(vec![Language::Rust]),
332    ///     path_pattern: Some("src/**/*.rs".to_string()),
333    ///     query_text: "retry logic".to_string(),
334    ///     ..SearchFilter::new()
335    /// };
336    /// ```
337    pub fn new() -> Self {
338        Self::default()
339    }
340
341    /// Set the query text (required for name_boost > 0 or enable_rrf).
342    pub fn with_query(mut self, query: impl Into<String>) -> Self {
343        self.query_text = query.into();
344        self
345    }
346
347    /// Validate filter constraints
348    ///
349    /// Returns Ok(()) if valid, or Err with description of what's wrong.
350    pub fn validate(&self) -> Result<(), &'static str> {
351        // name_boost must be in [0.0, 1.0] (NaN-safe: NaN is not contained in any range)
352        if !(0.0..=1.0).contains(&self.name_boost) {
353            return Err("name_boost must be between 0.0 and 1.0");
354        }
355
356        // note_weight must be in [0.0, 1.0] (NaN-safe)
357        if !(0.0..=1.0).contains(&self.note_weight) {
358            return Err("note_weight must be between 0.0 and 1.0");
359        }
360
361        // note_only with note_weight=0 is contradictory
362        if self.note_only && self.note_weight == 0.0 {
363            return Err("note_only=true with note_weight=0.0 is contradictory");
364        }
365
366        // query_text required when name_boost > 0 or enable_rrf
367        if (self.name_boost > 0.0 || self.enable_rrf) && self.query_text.is_empty() {
368            return Err("query_text required when name_boost > 0 or enable_rrf is true");
369        }
370
371        // path_pattern must be valid glob syntax if provided
372        if let Some(ref pattern) = self.path_pattern {
373            if pattern.len() > 500 {
374                return Err("path_pattern too long (max 500 chars)");
375            }
376            // Reject control characters (except tab/newline which glob might handle)
377            if pattern
378                .chars()
379                .any(|c| c.is_control() && c != '\t' && c != '\n')
380            {
381                return Err("path_pattern contains invalid control characters");
382            }
383            // Limit brace nesting depth to prevent exponential expansion
384            // e.g., "{a,{b,{c,{d,{e,...}}}}}" can cause O(2^n) expansion
385            const MAX_BRACE_DEPTH: usize = 10;
386            let mut depth = 0usize;
387            for c in pattern.chars() {
388                match c {
389                    '{' => {
390                        depth += 1;
391                        if depth > MAX_BRACE_DEPTH {
392                            return Err("path_pattern has too many nested braces (max 10 levels)");
393                        }
394                    }
395                    '}' => depth = depth.saturating_sub(1),
396                    _ => {}
397                }
398            }
399            if globset::Glob::new(pattern).is_err() {
400                return Err("path_pattern is not a valid glob pattern");
401            }
402        }
403
404        Ok(())
405    }
406}
407
/// Model metadata for index initialization
///
/// Its `Default` describes the built-in model (`MODEL_NAME`). Derives
/// `Debug` and `Clone` like the other public metadata types in this module.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Embedding model identifier
    pub name: String,
    /// Number of `f32` components per stored embedding
    pub dimensions: u32,
    /// Model version string
    pub version: String,
}
414
415impl Default for ModelInfo {
416    fn default() -> Self {
417        ModelInfo {
418            name: MODEL_NAME.to_string(),
419            dimensions: 769,          // 768 from model + 1 sentiment
420            version: "2".to_string(), // E5-base-v2
421        }
422    }
423}
424
/// Index statistics
///
/// Provides overview information about the indexed codebase.
/// Retrieved via `Store::stats()`.
#[derive(Debug)]
pub struct IndexStats {
    /// Total number of code chunks indexed
    pub total_chunks: u64,
    /// Number of unique source files
    pub total_files: u64,
    /// Chunk count grouped by programming language
    pub chunks_by_language: HashMap<Language, u64>,
    /// Chunk count grouped by element type (function, class, etc.)
    pub chunks_by_type: HashMap<ChunkType, u64>,
    /// Database file size in bytes
    pub index_size_bytes: u64,
    /// ISO 8601 timestamp when index was created
    pub created_at: String,
    /// ISO 8601 timestamp of last update
    pub updated_at: String,
    /// Embedding model used (e.g., "intfloat/e5-base-v2")
    pub model_name: String,
    /// Database schema version (compare with `CURRENT_SCHEMA_VERSION`)
    pub schema_version: i32,
}
450
// ============ Line Number Conversion ============

/// Convert a SQLite `INTEGER` line number into a 1-indexed `u32`.
///
/// SQLite hands back `i64`. Anything below 1 (zero or negative) maps to 1,
/// because line 0 does not exist in 1-indexed numbering; anything above
/// `u32::MAX` saturates to `u32::MAX` instead of truncating.
#[inline]
pub fn clamp_line_number(n: i64) -> u32 {
    const MAX_LINE: i64 = u32::MAX as i64;
    match n {
        i64::MIN..=0 => 1,
        1..=MAX_LINE => n as u32,
        _ => u32::MAX,
    }
}
462
463// ============ Embedding Serialization ============
464
465/// Convert embedding to bytes for storage.
466///
467/// # Panics
468/// Panics if embedding is not exactly 769 dimensions (768 model + 1 sentiment).
469/// This is intentional - storing wrong-sized embeddings corrupts the index.
470pub fn embedding_to_bytes(embedding: &Embedding) -> Vec<u8> {
471    assert_eq!(
472        embedding.len(),
473        EXPECTED_DIMENSIONS as usize,
474        "Embedding dimension mismatch: expected {}, got {}. This indicates a bug in the embedder.",
475        EXPECTED_DIMENSIONS,
476        embedding.len()
477    );
478    embedding
479        .as_slice()
480        .iter()
481        .flat_map(|f| f.to_le_bytes())
482        .collect()
483}
484
485/// Zero-copy view of embedding bytes as f32 slice (for hot paths)
486///
487/// Returns None if byte length doesn't match expected embedding size.
488/// Uses trace level logging to avoid impacting search performance.
489pub fn embedding_slice(bytes: &[u8]) -> Option<&[f32]> {
490    const EXPECTED_BYTES: usize = crate::EMBEDDING_DIM * 4;
491    if bytes.len() != EXPECTED_BYTES {
492        tracing::trace!(
493            expected = EXPECTED_BYTES,
494            actual = bytes.len(),
495            "Embedding byte length mismatch, skipping"
496        );
497        return None;
498    }
499    Some(bytemuck::cast_slice(bytes))
500}
501
502/// Convert embedding bytes to owned Vec (when ownership needed)
503///
504/// Returns None if byte length doesn't match expected embedding size (769 * 4 bytes).
505/// This prevents silently using corrupted/truncated embeddings.
506/// Uses trace level logging consistent with embedding_slice() since both are called on hot paths.
507pub fn bytes_to_embedding(bytes: &[u8]) -> Option<Vec<f32>> {
508    const EXPECTED_BYTES: usize = crate::EMBEDDING_DIM * 4;
509    if bytes.len() != EXPECTED_BYTES {
510        tracing::trace!(
511            expected = EXPECTED_BYTES,
512            actual = bytes.len(),
513            "Embedding byte length mismatch, skipping"
514        );
515        return None;
516    }
517    Some(bytemuck::cast_slice::<u8, f32>(bytes).to_vec())
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    // ===== SearchFilter validation tests =====

    #[test]
    fn test_search_filter_valid_default() {
        let filter = SearchFilter::default();
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_valid_with_name_boost() {
        let filter = SearchFilter {
            name_boost: 0.2,
            query_text: "test".to_string(),
            ..Default::default()
        };
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_valid_with_rrf() {
        let filter = SearchFilter {
            enable_rrf: true,
            query_text: "test".to_string(),
            ..Default::default()
        };
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_invalid_name_boost_negative() {
        let filter = SearchFilter {
            name_boost: -0.1,
            ..Default::default()
        };
        assert!(filter.validate().is_err());
        assert!(filter.validate().unwrap_err().contains("name_boost"));
    }

    #[test]
    fn test_search_filter_invalid_name_boost_too_high() {
        let filter = SearchFilter {
            name_boost: 1.5,
            query_text: "test".to_string(),
            ..Default::default()
        };
        assert!(filter.validate().is_err());
    }

    #[test]
    fn test_search_filter_invalid_name_boost_nan() {
        let filter = SearchFilter {
            name_boost: f32::NAN,
            query_text: "test".to_string(),
            ..Default::default()
        };
        assert!(filter.validate().unwrap_err().contains("name_boost"));
    }

    #[test]
    fn test_search_filter_invalid_note_weight() {
        for &weight in &[-0.5_f32, 1.5, f32::NAN] {
            let filter = SearchFilter {
                note_weight: weight,
                ..Default::default()
            };
            assert!(filter.validate().unwrap_err().contains("note_weight"));
        }
    }

    #[test]
    fn test_search_filter_note_only_with_zero_weight() {
        let filter = SearchFilter {
            note_only: true,
            note_weight: 0.0,
            ..Default::default()
        };
        assert!(filter.validate().unwrap_err().contains("contradictory"));
    }

    #[test]
    fn test_search_filter_note_only_valid() {
        let filter = SearchFilter {
            note_only: true,
            ..Default::default()
        };
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_invalid_missing_query_text() {
        let filter = SearchFilter {
            name_boost: 0.5,
            query_text: String::new(),
            ..Default::default()
        };
        assert!(filter.validate().is_err());
        assert!(filter.validate().unwrap_err().contains("query_text"));
    }

    #[test]
    fn test_search_filter_invalid_rrf_missing_query() {
        let filter = SearchFilter {
            enable_rrf: true,
            query_text: String::new(),
            ..Default::default()
        };
        assert!(filter.validate().is_err());
    }

    #[test]
    fn test_search_filter_with_query_sets_text() {
        let filter = SearchFilter::new().with_query("retry logic");
        assert_eq!(filter.query_text, "retry logic");
        let filter = SearchFilter {
            enable_rrf: true,
            ..filter
        };
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_valid_path_pattern() {
        let filter = SearchFilter {
            path_pattern: Some("src/**/*.rs".to_string()),
            ..Default::default()
        };
        assert!(filter.validate().is_ok());
    }

    #[test]
    fn test_search_filter_invalid_path_pattern_syntax() {
        let filter = SearchFilter {
            path_pattern: Some("[invalid".to_string()),
            ..Default::default()
        };
        assert!(filter.validate().is_err());
        assert!(filter.validate().unwrap_err().contains("glob"));
    }

    #[test]
    fn test_search_filter_path_pattern_too_long() {
        let filter = SearchFilter {
            path_pattern: Some("a".repeat(501)),
            ..Default::default()
        };
        assert!(filter.validate().is_err());
        assert!(filter.validate().unwrap_err().contains("too long"));
    }

    #[test]
    fn test_search_filter_path_pattern_control_char() {
        let filter = SearchFilter {
            path_pattern: Some("src/\u{7}*.rs".to_string()),
            ..Default::default()
        };
        assert!(filter.validate().unwrap_err().contains("control"));
    }

    #[test]
    fn test_search_filter_path_pattern_brace_depth() {
        let pattern = format!("{}x{}", "{a,".repeat(11), "}".repeat(11));
        let filter = SearchFilter {
            path_pattern: Some(pattern),
            ..Default::default()
        };
        assert!(filter.validate().unwrap_err().contains("nested braces"));
    }

    // ===== clamp_line_number tests =====

    #[test]
    fn test_clamp_line_number_normal() {
        assert_eq!(clamp_line_number(1), 1);
        assert_eq!(clamp_line_number(100), 100);
    }

    #[test]
    fn test_clamp_line_number_negative() {
        // Line numbers are 1-indexed, so negative/zero clamps to 1
        assert_eq!(clamp_line_number(-1), 1);
        assert_eq!(clamp_line_number(-1000), 1);
        assert_eq!(clamp_line_number(0), 1);
    }

    #[test]
    fn test_clamp_line_number_overflow() {
        assert_eq!(clamp_line_number(i64::MAX), u32::MAX);
        assert_eq!(clamp_line_number(u32::MAX as i64 + 1), u32::MAX);
    }

    // ===== embedding byte conversion tests =====

    #[test]
    fn test_embedding_bytes_wrong_length_rejected() {
        assert!(bytes_to_embedding(&[]).is_none());
        assert!(bytes_to_embedding(&[0u8; 3]).is_none());
        assert!(embedding_slice(&[0u8; 3]).is_none());
    }

    #[test]
    fn test_bytes_to_embedding_decodes_little_endian() {
        let mut bytes = vec![0u8; crate::EMBEDDING_DIM * 4];
        bytes[..4].copy_from_slice(&1.5f32.to_le_bytes());
        let decoded = bytes_to_embedding(&bytes).expect("correct length must decode");
        assert_eq!(decoded.len(), crate::EMBEDDING_DIM);
        assert_eq!(decoded[0], 1.5);
        assert!(decoded[1..].iter().all(|&v| v == 0.0));
    }
}