use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
/// Token limit for a single embedding input: `EmbeddingUnit::needs_chunking`
/// reports units whose `token_count` is above this value.
pub const MAX_EMBEDDING_TOKENS: usize = 8_192;
/// Token budget for a code preview; the test module asserts it stays below
/// `MAX_EMBEDDING_TOKENS`. NOTE(review): not referenced outside tests in this
/// file — confirm its meaning at the use site.
pub const MAX_CODE_PREVIEW_TOKENS: usize = 6_000;
/// Presumably the number of tokens repeated between consecutive chunks — confirm
/// against the chunker. The test module asserts it stays below
/// `MAX_CODE_PREVIEW_TOKENS`.
pub const CHUNK_OVERLAP_TOKENS: usize = 200;
/// A named pattern used to tag code with a semantic category.
///
/// NOTE(review): `pattern` is written in regex syntax (a `\b`-delimited
/// alternation). Compiling and matching happen in a consumer not visible here.
/// Presumably the matching `name` ends up in `EmbeddingUnit::semantic_tags` —
/// confirm against the caller.
#[derive(Debug, Clone)]
pub struct SemanticPattern {
/// Category label, e.g. `"crud"` or `"auth"`.
pub name: &'static str,
/// Regex source text for this category.
pub pattern: &'static str,
}
/// Built-in semantic categories, each a `\b`-bounded alternation of keywords.
///
/// The categories overlap. For example, `map`/`reduce`/`filter` appear in both
/// "transform" and "iteration", `error` appears in "error_handling" and "logging",
/// and `select`/`insert`/`update`/`delete` appear in "crud" and "database". A
/// single word can therefore match several categories.
///
/// None of the patterns carries a `(?i)` flag. Whether matching is
/// case-insensitive depends on how the consumer compiles them.
pub static SEMANTIC_PATTERNS: &[SemanticPattern] = &[
SemanticPattern {
name: "crud",
pattern: r"\b(create|read|update|delete|insert|select|save|load|fetch|store|persist|get|set|add|remove)\b",
},
SemanticPattern {
name: "validation",
pattern: r"\b(valid|validate|check|verify|assert|ensure|sanitize|normalize|parse|format)\b",
},
SemanticPattern {
name: "transform",
pattern: r"\b(convert|transform|map|reduce|filter|sort|merge|split|join|serialize|deserialize)\b",
},
SemanticPattern {
name: "error_handling",
pattern: r"\b(try|catch|except|raise|throw|error|exception|fail|panic)\b",
},
SemanticPattern {
name: "async_ops",
pattern: r"\b(async|await|promise|future|callback|then|concurrent|parallel|thread)\b",
},
SemanticPattern {
name: "iteration",
pattern: r"\b(for|while|loop|iterate|each|map|reduce|filter)\b",
},
SemanticPattern {
name: "api_endpoint",
pattern: r"\b(route|endpoint|handler|controller|get|post|put|delete|patch|request|response)\b",
},
SemanticPattern {
name: "database",
pattern: r"\b(query|sql|select|insert|update|delete|table|schema|migration|model|entity)\b",
},
SemanticPattern {
name: "auth",
pattern: r"\b(auth|login|logout|session|token|jwt|oauth|permission|role|access)\b",
},
SemanticPattern {
name: "cache",
pattern: r"\b(cache|memoize|memo|store|redis|memcache|ttl|expire|invalidate)\b",
},
// NOTE(review): the standalone words `it` and `should` also match ordinary
// English in comments and docstrings, so this tag may be over-applied.
SemanticPattern {
name: "test",
pattern: r"\b(test|spec|mock|stub|assert|expect|should|describe|it)\b",
},
SemanticPattern {
name: "logging",
pattern: r"\b(log|logger|debug|info|warn|error|trace|print|console)\b",
},
SemanticPattern {
name: "config",
pattern: r"\b(config|setting|option|env|environment|parameter|argument)\b",
},
];
/// Kind of code construct an `EmbeddingUnit` holds.
///
/// Serde uses the lowercase variant name (`rename_all = "lowercase"`). That is
/// the same string `UnitKind::as_str` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UnitKind {
Function,
/// The only kind for which `EmbeddingUnit::qualified_name` includes `parent`.
Method,
Class,
Module,
/// Presumably a piece of a split unit (the tests pair it with `chunk_total > 1`).
Chunk,
}
impl UnitKind {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::Function => "function",
Self::Method => "method",
Self::Class => "class",
Self::Module => "module",
Self::Chunk => "chunk",
}
}
}
impl std::fmt::Display for UnitKind {
    /// Writes the lowercase name returned by `as_str`.
    ///
    /// `Formatter::pad` is used instead of `write!(f, "{}", ..)`. `write!` starts a
    /// fresh format and silently drops the caller's width, fill, alignment and
    /// precision, so `{:<8}` would not line up in tables. With `{}`, the output is
    /// unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Structural complexity counters for a code unit.
///
/// The values are supplied by whatever analyzes the code; nothing in this type
/// computes them. `Default` (and `empty()`) is all zeros. `is_complex` and
/// `describe` compare each counter against its threshold.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CodeComplexity {
/// Nesting depth as reported by the analyzer.
pub depth: usize,
/// Branch count as reported by the analyzer.
pub branches: usize,
/// Loop count as reported by the analyzer.
pub loops: usize,
}
impl CodeComplexity {
    /// Nesting depth above which a unit is reported as "deep nesting".
    pub const DEPTH_THRESHOLD: usize = 3;
    /// Branch count above which a unit is reported as "many branches".
    pub const BRANCH_THRESHOLD: usize = 5;
    /// Loop count above which a unit is reported as "multiple loops".
    pub const LOOP_THRESHOLD: usize = 2;

    /// Returns all-zero counters; same as `Default::default()`.
    #[must_use]
    pub fn empty() -> Self {
        Self::default()
    }

    /// `depth` strictly exceeds `DEPTH_THRESHOLD`.
    fn has_deep_nesting(&self) -> bool {
        self.depth > Self::DEPTH_THRESHOLD
    }

    /// `branches` strictly exceeds `BRANCH_THRESHOLD`.
    fn has_many_branches(&self) -> bool {
        self.branches > Self::BRANCH_THRESHOLD
    }

    /// `loops` strictly exceeds `LOOP_THRESHOLD`.
    fn has_multiple_loops(&self) -> bool {
        self.loops > Self::LOOP_THRESHOLD
    }

    /// Returns `true` when any counter is strictly above its threshold.
    ///
    /// This uses the same predicates as `describe`, so `is_complex()` is `true`
    /// exactly when `describe()` returns `Some`.
    #[must_use]
    pub fn is_complex(&self) -> bool {
        self.has_deep_nesting() || self.has_many_branches() || self.has_multiple_loops()
    }

    /// Describes which thresholds are exceeded.
    ///
    /// The result is a comma-separated list in the fixed order "deep nesting",
    /// "many branches", "multiple loops", or `None` when nothing is exceeded.
    #[must_use]
    pub fn describe(&self) -> Option<String> {
        let mut parts = Vec::new();
        if self.has_deep_nesting() {
            parts.push("deep nesting");
        }
        if self.has_many_branches() {
            parts.push("many branches");
        }
        if self.has_multiple_loops() {
            parts.push("multiple loops");
        }
        if parts.is_empty() {
            None
        } else {
            Some(parts.join(", "))
        }
    }
}
/// One piece of code (function, method, class, module or chunk) prepared for
/// embedding, together with metadata about where it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUnit {
/// Identifier. `EmbeddingUnit::new` builds it as `"{file}::{name}"`; nothing in
/// this file updates it afterwards (e.g. when `parent` is set).
pub id: String,
/// Source file path as given by the caller (e.g. `src/main.py` in the tests).
pub file: String,
/// Unit name; for chunks the tests use labels like `big_function[1/3]`.
pub name: String,
/// Kind of construct this unit holds.
pub kind: UnitKind,
/// Source text of the unit.
pub code: String,
/// Signature text; empty after `new`.
pub signature: String,
/// Docstring, if any; `None` after `new`.
pub docstring: Option<String>,
/// First line of the unit (0- vs 1-based is not defined here).
pub start_line: usize,
/// Last line of the unit; `new` sets it equal to `start_line`.
pub end_line: usize,
/// Token count, compared against `MAX_EMBEDDING_TOKENS` by `needs_chunking`;
/// 0 after `new`.
pub token_count: usize,
/// Semantic category tags; empty after `new`.
pub semantic_tags: Vec<String>,
/// Enclosing item name; `qualified_name` uses it only for `UnitKind::Method`.
pub parent: Option<String>,
/// Source language name (e.g. `"python"` in the tests).
pub language: String,
/// Names this unit calls; omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub calls: Vec<String>,
/// Names that call this unit; omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub called_by: Vec<String>,
/// Presumably a control-flow summary — confirm with the producer; omitted when empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub cfg_summary: String,
/// Presumably a data-flow summary — confirm with the producer; omitted when empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub dfg_summary: String,
/// Dependency description; omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub dependencies: String,
/// Complexity counters; all zeros when absent from the input.
#[serde(default)]
pub complexity: CodeComplexity,
/// Position of this piece within a split unit (the tests use 0 for the first of
/// three); defaults to 0.
#[serde(default)]
pub chunk_index: usize,
/// Number of pieces the original unit was split into; defaults to 1.
/// `is_chunk` is true when this is greater than 1.
#[serde(default = "default_chunk_total")]
pub chunk_total: usize,
}
/// Serde default for `EmbeddingUnit::chunk_total`: an unsplit unit counts as one
/// chunk, matching what `EmbeddingUnit::new` sets.
fn default_chunk_total() -> usize {
1
}
impl EmbeddingUnit {
    /// Creates a unit for `name` defined in `file`, starting at `start_line`.
    ///
    /// The constructor sets:
    /// - `id` to `"{file}::{name}"`;
    /// - `end_line` to `start_line`;
    /// - `chunk_total` to 1;
    /// - every other field to its empty or zero value (signature, docstring,
    ///   token count, tags, parent, call lists, summaries, complexity), for the
    ///   caller to fill in.
    #[must_use]
    pub fn new(
        file: impl Into<String>,
        name: impl Into<String>,
        kind: UnitKind,
        code: impl Into<String>,
        start_line: usize,
        language: impl Into<String>,
    ) -> Self {
        let name = name.into();
        let file = file.into();
        let code = code.into();
        Self {
            id: format!("{}::{}", file, name),
            file,
            name,
            kind,
            code,
            signature: String::new(),
            docstring: None,
            start_line,
            end_line: start_line,
            token_count: 0,
            semantic_tags: Vec::new(),
            parent: None,
            language: language.into(),
            calls: Vec::new(),
            called_by: Vec::new(),
            cfg_summary: String::new(),
            dfg_summary: String::new(),
            dependencies: String::new(),
            complexity: CodeComplexity::default(),
            chunk_index: 0,
            chunk_total: 1,
        }
    }

    /// Returns `true` when this unit is one of several pieces (`chunk_total > 1`).
    #[must_use]
    pub fn is_chunk(&self) -> bool {
        self.chunk_total > 1
    }

    /// Returns `true` when `token_count` is strictly above `MAX_EMBEDDING_TOKENS`.
    ///
    /// `token_count` is 0 after `new`, so this stays `false` until the caller sets it.
    #[must_use]
    pub fn needs_chunking(&self) -> bool {
        self.token_count > MAX_EMBEDDING_TOKENS
    }

    /// Returns `"{file}::{parent}.{name}"` for a method with a parent, and
    /// `"{file}::{name}"` otherwise.
    ///
    /// The parent is ignored for every kind other than `UnitKind::Method`.
    #[must_use]
    pub fn qualified_name(&self) -> String {
        match &self.parent {
            Some(parent) if self.kind == UnitKind::Method => {
                format!("{}::{}.{}", self.file, parent, self.name)
            }
            _ => format!("{}::{}", self.file, self.name),
        }
    }

    /// Serializes the unit into a map from serde field name to JSON value.
    ///
    /// Fields that the serde attributes skip when empty (e.g. `calls`,
    /// `cfg_summary`) are absent. An empty map is returned if serialization fails
    /// or does not produce a JSON object.
    #[must_use]
    pub fn to_map(&self) -> HashMap<String, serde_json::Value> {
        match serde_json::to_value(self) {
            // Take ownership of the object rather than cloning it out of a borrow.
            Ok(serde_json::Value::Object(fields)) => fields.into_iter().collect(),
            _ => HashMap::new(),
        }
    }
}
/// A search hit: the matched unit plus its score and optional highlights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// The matched unit.
pub unit: EmbeddingUnit,
/// Score assigned by the search; its range and ordering are not defined in this file.
pub score: f32,
/// Highlighted text snippets. Omitted from serialized output when empty and
/// defaulted to empty when absent.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub highlights: Vec<String>,
}
impl SearchResult {
#[must_use]
pub fn new(unit: EmbeddingUnit, score: f32) -> Self {
Self {
unit,
score,
highlights: Vec::new(),
}
}
#[must_use]
pub fn with_highlights(unit: EmbeddingUnit, score: f32, highlights: Vec<String>) -> Self {
Self {
unit,
score,
highlights,
}
}
}
/// A piece of text together with the offsets it was taken from.
///
/// The offsets are stored exactly as given. Whether they count characters or
/// bytes, and whether `end_char` is exclusive, is decided by the code that
/// builds chunks (not visible here).
#[derive(Debug, Clone)]
pub struct ChunkInfo {
    /// The chunk's text.
    pub text: String,
    /// Offset at which the chunk starts in its source text.
    pub start_char: usize,
    /// Offset at which the chunk ends in its source text.
    pub end_char: usize,
}

impl ChunkInfo {
    /// Bundles `text` with its offsets; no validation is performed.
    #[must_use]
    pub fn new(text: String, start_char: usize, end_char: usize) -> Self {
        ChunkInfo {
            start_char,
            end_char,
            text,
        }
    }
}
/// A named code location: file, item name and line.
///
/// `ContentHashedIndex` stores one of these for the first `add` of each piece
/// of content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeLocation {
    /// File the code lives in.
    pub file: String,
    /// Name of the function or other item.
    pub name: String,
    /// Line number as supplied by the caller.
    pub line: usize,
}

impl CodeLocation {
    /// Builds a location from anything convertible into owned strings.
    #[must_use]
    pub fn new(file: impl Into<String>, name: impl Into<String>, line: usize) -> Self {
        let file: String = file.into();
        let name: String = name.into();
        CodeLocation { file, name, line }
    }
}
/// Deduplicating index keyed by a hash of whitespace-normalized code.
///
/// Before hashing, every line is trimmed and blank lines are dropped. Snippets
/// that differ only in indentation, trailing spaces or blank lines are therefore
/// treated as the same content.
///
/// Only a 64-bit hash is stored, not the text. Two different snippets whose
/// hashes collide are therefore also reported as duplicates.
#[derive(Debug, Clone, Default)]
pub struct ContentHashedIndex {
    /// Location recorded by the first `add` of each content hash.
    seen: HashMap<u64, CodeLocation>,
    /// `add` calls rejected as duplicates since creation or the last `clear`.
    pub duplicates_found: usize,
    /// `add` calls that inserted new content since creation or the last `clear`.
    pub unique_items: usize,
}

impl ContentHashedIndex {
    /// Creates an empty index with zeroed counters.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Hashes `content` after trimming each line and skipping blank lines.
    ///
    /// NOTE(review): std does not promise that `DefaultHasher` output is stable
    /// across Rust releases. These hashes are only kept in memory and must not
    /// be persisted.
    fn hash_content(content: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        // Feed normalized lines straight into the hasher instead of collecting a
        // Vec and joining it into a String. `str`'s `Hash` impl is prefix-free
        // (it appends a terminator byte), so the line sequence is still hashed
        // unambiguously: "a\nb" and "ab" stay distinct.
        for line in content.lines().map(str::trim).filter(|line| !line.is_empty()) {
            line.hash(&mut hasher);
        }
        hasher.finish()
    }

    /// Returns where equivalent content was first added, if anywhere.
    ///
    /// "Equivalent" means the same normalized-content hash. The index is not
    /// modified.
    #[must_use]
    pub fn check_duplicate(&self, content: &str) -> Option<&CodeLocation> {
        self.seen.get(&Self::hash_content(content))
    }

    /// Records `content` unless equivalent content is already indexed.
    ///
    /// If the content is new, its location is stored, `unique_items` is
    /// incremented, and `true` is returned. Otherwise `duplicates_found` is
    /// incremented, the originally stored location is left untouched, and
    /// `false` is returned.
    pub fn add(
        &mut self,
        content: &str,
        file: &str,
        function_name: &str,
        line: usize,
    ) -> bool {
        use std::collections::hash_map::Entry;

        // One hash-map lookup via the entry API instead of `contains_key` + `insert`.
        match self.seen.entry(Self::hash_content(content)) {
            Entry::Occupied(_) => {
                self.duplicates_found += 1;
                false
            }
            Entry::Vacant(slot) => {
                slot.insert(CodeLocation::new(file, function_name, line));
                self.unique_items += 1;
                true
            }
        }
    }

    /// Returns `(unique_items, duplicates_found)`.
    #[must_use]
    pub fn stats(&self) -> (usize, usize) {
        (self.unique_items, self.duplicates_found)
    }

    /// Number of distinct content hashes stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.seen.len()
    }

    /// Returns `true` when no content has been stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.seen.is_empty()
    }

    /// Forgets all stored content and resets both counters to zero.
    pub fn clear(&mut self) {
        self.seen.clear();
        self.duplicates_found = 0;
        self.unique_items = 0;
    }

    /// Fraction of `add` calls that were duplicates, or `0.0` before any call.
    #[must_use]
    pub fn dedup_ratio(&self) -> f64 {
        let total = self.unique_items + self.duplicates_found;
        if total == 0 {
            0.0
        } else {
            self.duplicates_found as f64 / total as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unit_kind_as_str() {
        assert_eq!(UnitKind::Function.as_str(), "function");
        assert_eq!(UnitKind::Method.as_str(), "method");
        assert_eq!(UnitKind::Class.as_str(), "class");
        assert_eq!(UnitKind::Module.as_str(), "module");
        assert_eq!(UnitKind::Chunk.as_str(), "chunk");
    }

    #[test]
    fn test_unit_kind_display() {
        assert_eq!(format!("{}", UnitKind::Function), "function");
    }

    #[test]
    fn test_code_complexity_describe() {
        let simple = CodeComplexity {
            depth: 2,
            branches: 3,
            loops: 1,
        };
        assert!(simple.describe().is_none());
        assert!(!simple.is_complex());
        let complex = CodeComplexity {
            depth: 5,
            branches: 10,
            loops: 4,
        };
        let desc = complex.describe().unwrap();
        assert!(desc.contains("deep nesting"));
        assert!(desc.contains("many branches"));
        assert!(desc.contains("multiple loops"));
        assert!(complex.is_complex());
    }

    #[test]
    fn test_code_complexity_thresholds_are_exclusive() {
        // A counter equal to its threshold is still "simple"; one above is complex.
        let at_limit = CodeComplexity {
            depth: 3,
            branches: 5,
            loops: 2,
        };
        assert!(!at_limit.is_complex());
        assert!(at_limit.describe().is_none());
        let one_over = CodeComplexity {
            loops: 3,
            ..CodeComplexity::empty()
        };
        assert!(one_over.is_complex());
        assert_eq!(one_over.describe().as_deref(), Some("multiple loops"));
    }

    #[test]
    fn test_embedding_unit_new() {
        let unit = EmbeddingUnit::new(
            "src/main.py",
            "process_data",
            UnitKind::Function,
            "def process_data(): pass",
            10,
            "python",
        );
        assert_eq!(unit.id, "src/main.py::process_data");
        assert_eq!(unit.file, "src/main.py");
        assert_eq!(unit.name, "process_data");
        assert_eq!(unit.kind, UnitKind::Function);
        assert_eq!(unit.start_line, 10);
        assert_eq!(unit.end_line, 10);
        assert_eq!(unit.language, "python");
        assert!(!unit.is_chunk());
    }

    #[test]
    fn test_embedding_unit_qualified_name() {
        let mut unit = EmbeddingUnit::new(
            "src/model.py",
            "save",
            UnitKind::Method,
            "def save(self): pass",
            20,
            "python",
        );
        unit.parent = Some("User".to_string());
        assert_eq!(unit.qualified_name(), "src/model.py::User.save");
    }

    #[test]
    fn test_embedding_unit_qualified_name_ignores_parent_for_non_methods() {
        let mut unit = EmbeddingUnit::new(
            "src/model.py",
            "helper",
            UnitKind::Function,
            "def helper(): pass",
            5,
            "python",
        );
        assert_eq!(unit.qualified_name(), "src/model.py::helper");
        unit.parent = Some("User".to_string());
        assert_eq!(unit.qualified_name(), "src/model.py::helper");
    }

    #[test]
    fn test_embedding_unit_is_chunk() {
        let mut unit = EmbeddingUnit::new(
            "src/large.py",
            "big_function[1/3]",
            UnitKind::Chunk,
            "# chunk 1",
            1,
            "python",
        );
        unit.chunk_index = 0;
        unit.chunk_total = 3;
        assert!(unit.is_chunk());
    }

    #[test]
    fn test_embedding_unit_needs_chunking() {
        let mut unit = EmbeddingUnit::new(
            "src/large.py",
            "big",
            UnitKind::Function,
            "def big(): pass",
            1,
            "python",
        );
        assert!(!unit.needs_chunking());
        unit.token_count = MAX_EMBEDDING_TOKENS;
        assert!(!unit.needs_chunking());
        unit.token_count = MAX_EMBEDDING_TOKENS + 1;
        assert!(unit.needs_chunking());
    }

    #[test]
    fn test_embedding_unit_to_map() {
        let unit = EmbeddingUnit::new(
            "a.py",
            "f",
            UnitKind::Method,
            "def f(self): pass",
            3,
            "python",
        );
        let map = unit.to_map();
        assert_eq!(map["id"], "a.py::f");
        assert_eq!(map["kind"], "method");
        assert_eq!(map["signature"], "");
        assert_eq!(map["chunk_total"], 1);
        // Empty collections/strings are skipped by the serde attributes.
        assert!(!map.contains_key("calls"));
        assert!(!map.contains_key("cfg_summary"));
    }

    #[test]
    fn test_embedding_unit_deserialize_defaults() {
        let json = r#"{"id":"a.py::f","file":"a.py","name":"f","kind":"function","code":"def f(): pass","signature":"","docstring":null,"start_line":1,"end_line":1,"token_count":0,"semantic_tags":[],"parent":null,"language":"python"}"#;
        let unit: EmbeddingUnit = serde_json::from_str(json).unwrap();
        assert_eq!(unit.chunk_total, 1);
        assert_eq!(unit.chunk_index, 0);
        assert!(unit.calls.is_empty());
        assert!(unit.called_by.is_empty());
        assert!(!unit.is_chunk());
    }

    #[test]
    fn test_search_result() {
        let unit = EmbeddingUnit::new(
            "test.py",
            "test_fn",
            UnitKind::Function,
            "def test_fn(): pass",
            1,
            "python",
        );
        let result = SearchResult::new(unit.clone(), 0.95);
        assert_eq!(result.score, 0.95);
        assert!(result.highlights.is_empty());
        let result_with_highlights =
            SearchResult::with_highlights(unit, 0.95, vec!["highlighted text".to_string()]);
        assert_eq!(result_with_highlights.highlights.len(), 1);
    }

    #[test]
    fn test_semantic_patterns_defined() {
        assert!(!SEMANTIC_PATTERNS.is_empty());
        let pattern_names: Vec<_> = SEMANTIC_PATTERNS.iter().map(|p| p.name).collect();
        assert!(pattern_names.contains(&"crud"));
        assert!(pattern_names.contains(&"validation"));
        assert!(pattern_names.contains(&"error_handling"));
        assert!(pattern_names.contains(&"async_ops"));
    }

    #[test]
    fn test_constants() {
        assert!(MAX_EMBEDDING_TOKENS > 0);
        assert!(MAX_CODE_PREVIEW_TOKENS < MAX_EMBEDDING_TOKENS);
        assert!(CHUNK_OVERLAP_TOKENS < MAX_CODE_PREVIEW_TOKENS);
    }

    #[test]
    fn test_code_location_new() {
        let loc = CodeLocation::new("src/main.py", "process", 42);
        assert_eq!(loc.file, "src/main.py");
        assert_eq!(loc.name, "process");
        assert_eq!(loc.line, 42);
    }

    #[test]
    fn test_content_hashed_index_new() {
        let index = ContentHashedIndex::new();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.unique_items, 0);
        assert_eq!(index.duplicates_found, 0);
    }

    #[test]
    fn test_content_hashed_index_add_unique() {
        let mut index = ContentHashedIndex::new();
        assert!(index.add("def foo(): pass", "src/a.py", "foo", 10));
        assert_eq!(index.unique_items, 1);
        assert_eq!(index.duplicates_found, 0);
        assert_eq!(index.len(), 1);
        assert!(index.add("def bar(): return 1", "src/b.py", "bar", 20));
        assert_eq!(index.unique_items, 2);
        assert_eq!(index.duplicates_found, 0);
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn test_content_hashed_index_detect_duplicate() {
        let mut index = ContentHashedIndex::new();
        assert!(index.add("def foo(): pass", "src/a.py", "foo", 10));
        assert!(!index.add("def foo(): pass", "src/b.py", "foo", 20));
        assert_eq!(index.unique_items, 1);
        assert_eq!(index.duplicates_found, 1);
    }

    #[test]
    fn test_content_hashed_index_whitespace_normalization() {
        let mut index = ContentHashedIndex::new();
        let code1 = "def foo():\n return 1";
        assert!(index.add(code1, "src/a.py", "foo", 10));
        let code2 = " def foo():\n return 1 ";
        assert!(!index.add(code2, "src/b.py", "foo", 20));
        let code3 = "def foo():\n\n return 1\n\n";
        assert!(!index.add(code3, "src/c.py", "foo", 30));
        assert_eq!(index.unique_items, 1);
        assert_eq!(index.duplicates_found, 2);
    }

    #[test]
    fn test_content_hashed_index_line_boundaries_matter() {
        let mut index = ContentHashedIndex::new();
        assert!(index.add("ab", "f1.py", "fn1", 1));
        assert!(index.add("a\nb", "f2.py", "fn2", 2));
        assert_eq!(index.stats(), (2, 0));
    }

    #[test]
    fn test_content_hashed_index_check_duplicate() {
        let mut index = ContentHashedIndex::new();
        assert!(index.check_duplicate("def foo(): pass").is_none());
        index.add("def foo(): pass", "src/a.py", "foo", 10);
        let loc = index.check_duplicate("def foo(): pass").unwrap();
        assert_eq!(loc.file, "src/a.py");
        assert_eq!(loc.name, "foo");
        assert_eq!(loc.line, 10);
    }

    #[test]
    fn test_content_hashed_index_stats() {
        let mut index = ContentHashedIndex::new();
        index.add("code1", "f1.py", "fn1", 1);
        index.add("code2", "f2.py", "fn2", 2);
        index.add("code1", "f3.py", "fn1", 3); // duplicate of code1
        index.add("code3", "f4.py", "fn3", 4);
        index.add("code2", "f5.py", "fn2", 5); // duplicate of code2
        let (unique, dups) = index.stats();
        assert_eq!(unique, 3);
        assert_eq!(dups, 2);
    }

    #[test]
    fn test_content_hashed_index_dedup_ratio() {
        let mut index = ContentHashedIndex::new();
        assert_eq!(index.dedup_ratio(), 0.0);
        index.add("code1", "f1.py", "fn1", 1);
        index.add("code2", "f2.py", "fn2", 2);
        index.add("code1", "f3.py", "fn1", 3);
        index.add("code3", "f4.py", "fn3", 4);
        index.add("code2", "f5.py", "fn2", 5);
        assert!((index.dedup_ratio() - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_content_hashed_index_clear() {
        let mut index = ContentHashedIndex::new();
        index.add("code1", "f1.py", "fn1", 1);
        index.add("code1", "f2.py", "fn1", 2);
        assert!(!index.is_empty());
        assert_eq!(index.unique_items, 1);
        assert_eq!(index.duplicates_found, 1);
        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.unique_items, 0);
        assert_eq!(index.duplicates_found, 0);
        assert!(index.add("code1", "f1.py", "fn1", 1));
    }
}