use std::collections::HashSet;
use rusqlite::{Connection, OptionalExtension};
use crate::query::{memory, symbol};
use crate::search::lexical;
/// Upper bound on the rendered context; `render` compares it against the context's byte length.
pub const MAX_CONTEXT_CHARS: usize = 1500;
/// Limit passed to `symbol::lookup` in the symbol lane of `compose`.
const MAX_SYMBOLS: u32 = 3;
/// Limit passed to each memory query (`memories_for_symbol`, `memory_search`, `memories_for_path`).
const MAX_MEMORIES: u32 = 4;
/// Limit passed to `lexical::search_lexical_only` in the lexical lane of `compose`.
const MAX_LEXICAL_HITS: u32 = 3;
/// Lexical hits scoring below this fraction of the best hit's score are dropped.
const LEXICAL_RELATIVE_FLOOR: f64 = 0.6;
/// Number of characters of a memory body that `clamp_body` keeps before appending an ellipsis.
const MAX_MEMORY_BODY_CHARS: usize = 240;
/// Reduces a regex-flavoured search pattern to plain, space-separated search terms.
///
/// Rules, applied left to right:
/// * A backslash escape (`\b`, `\s`, `\\`, …) becomes a separator and the escaped
///   character is dropped — except `\.`, which is handled like a literal dot.
/// * A dot (escaped or not) is kept only when the nearest preceding non-space output
///   character and the immediately following input character are both ASCII word
///   characters (`[A-Za-z0-9_]`); otherwise it becomes a separator.
/// * The metacharacters `^ $ * + ? ( ) [ ] { } |` become separators.
/// * A lone trailing backslash (an incomplete escape) also becomes a separator, so no
///   regex syntax leaks into the result.
/// * Every other character is copied through.
///
/// Finally, runs of whitespace are collapsed to single spaces and the ends are trimmed.
/// The result is empty when the pattern consisted of regex syntax only.
pub fn normalize_pattern(pattern: &str) -> String {
    /// True for characters that may appear inside an identifier segment.
    fn is_word(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    /// Chooses what a dot is replaced with: `.` when it joins two word characters
    /// (looking back past spaces already emitted), a space otherwise.
    fn dot_or_space(emitted: &str, following: Option<char>) -> char {
        let prev_word = emitted.chars().rev().find(|c| *c != ' ').is_some_and(is_word);
        let next_word = following.is_some_and(is_word);
        if prev_word && next_word { '.' } else { ' ' }
    }

    let chars: Vec<char> = pattern.chars().collect();
    let mut out = String::with_capacity(pattern.len());
    let mut i = 0;
    while i < chars.len() {
        match chars[i] {
            '\\' => {
                let replacement = if chars.get(i + 1) == Some(&'.') {
                    dot_or_space(&out, chars.get(i + 2).copied())
                } else {
                    // Any other escape — or a dangling backslash at the very end of the
                    // pattern — carries no searchable text.
                    ' '
                };
                out.push(replacement);
                // Skip the backslash and the character it escapes (if there is one).
                i += 2;
            }
            '.' => {
                let replacement = dot_or_space(&out, chars.get(i + 1).copied());
                out.push(replacement);
                i += 1;
            }
            '^' | '$' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '|' => {
                out.push(' ');
                i += 1;
            }
            ch => {
                out.push(ch);
                i += 1;
            }
        }
    }
    out.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Returns `normalized` itself when it has the shape of a (possibly qualified) identifier.
///
/// The shape is: at least three bytes long, first character an ASCII letter or `_`,
/// and every later character an ASCII letter, digit, `_`, `:` or `.`. Anything else —
/// including any string that contains a space — yields `None`.
pub fn identifier_candidate(normalized: &str) -> Option<&str> {
    let mut rest = normalized.chars();
    match rest.next() {
        Some(head)
            if normalized.len() >= 3 && (head == '_' || head.is_ascii_alphabetic()) =>
        {
            // A space can never pass this check, so no separate whitespace test is needed.
            let tail_ok =
                rest.all(|c| matches!(c, '_' | ':' | '.') || c.is_ascii_alphanumeric());
            if tail_ok { Some(normalized) } else { None }
        }
        _ => None,
    }
}
/// Declaration and modifier keywords (e.g. `fn`, `struct`, `class`, `def`, `public`) that
/// `extract_symbol_identifier` skips when looking for the single identifier in a pattern.
const DEFINITION_KEYWORDS: &[&str] = &[
"fn", "pub", "mut", "let", "const", "static", "struct", "enum", "trait", "impl", "type", "mod",
"use", "async", "await", "return", "class", "def", "func", "function", "interface", "export",
"import", "var", "val", "public", "private", "protected", "final", "override", "suspend",
"void", "extern", "unsafe", "where", "dyn",
];
/// Extracts the single identifier a normalized pattern refers to, if there is exactly one.
///
/// When the whole string is identifier-shaped it is returned as is. Otherwise the string
/// is split on single spaces, tokens listed in `DEFINITION_KEYWORDS` are ignored, and the
/// result is `Some` only when exactly one token remains and that token is
/// identifier-shaped (so `"fn watcher_main"` yields `"watcher_main"`).
pub fn extract_symbol_identifier(normalized: &str) -> Option<&str> {
    if let Some(whole) = identifier_candidate(normalized) {
        return Some(whole);
    }

    let mut significant =
        normalized.split(' ').filter(|token| !DEFINITION_KEYWORDS.contains(token));
    match (significant.next(), significant.next()) {
        // Exactly one non-keyword token: it must itself look like an identifier.
        (Some(only), None) => identifier_candidate(only),
        // No token at all, or more than one: the pattern is not a symbol reference.
        _ => None,
    }
}
/// Identifiers that `compose` must leave out of its result.
#[derive(Debug, Default, Clone)]
pub struct DedupeFilter {
/// Memories whose `memory_id` is in this set are removed from the output.
pub memory_ids: HashSet<String>,
/// Symbols whose key (`"{path}:{qualified_name}"`) is in this set are skipped.
pub symbol_keys: HashSet<String>,
}
/// Result of `compose`: the rendered context plus the identifiers of what it contains.
#[derive(Debug)]
pub struct GrepAugment {
/// Rendered context text; it starts with a fixed preamble line.
pub context: String,
/// IDs of the memories whose lines were written into `context`.
pub memory_ids: Vec<String>,
/// Keys of the symbols whose lines were written into `context`.
pub symbol_keys: Vec<String>,
}
pub fn compose(
conn: &Connection,
raw_pattern: &str,
search_path: Option<&str>,
dedupe: &DedupeFilter,
) -> anyhow::Result<Option<GrepAugment>> {
let normalized = normalize_pattern(raw_pattern);
if normalized.is_empty() {
return Ok(None);
}
let mut memories = Vec::new();
let mut symbol_items: Vec<SymbolItem> = Vec::new();
let mut symbol_lane_had_hits = false;
let mut seen_memory_ids: HashSet<String> = HashSet::new();
if let Some(ident) = extract_symbol_identifier(&normalized) {
let bare = ident.rsplit([':', '.']).next().unwrap_or(ident);
for hit in symbol::lookup(conn, bare, None, MAX_SYMBOLS)? {
symbol_lane_had_hits = true;
let key = format!("{}:{}", hit.path, hit.qualified_name);
if dedupe.symbol_keys.contains(&key) {
continue;
}
let (callers, callees) = edge_counts(conn, &hit)?;
let start_line = line_for_symbol(conn, &hit)?;
let line_suffix = match start_line {
Some(l) => format!("{}:{}", hit.path, l),
None => hit.path.clone(),
};
let rendered = format!(
"- `{}` ({}) — {} — {} callers / {} callees{}",
hit.qualified_name,
hit.kind,
line_suffix,
callers,
callees,
hit.signature.as_deref().map(|s| format!(" — `{s}`")).unwrap_or_default(),
);
for m in memory::memories_for_symbol(conn, &hit, MAX_MEMORIES)? {
if seen_memory_ids.insert(m.memory_id.clone()) {
memories.push(m);
}
}
symbol_items.push(SymbolItem { rendered, key });
}
}
for m in memory::memory_search(conn, &normalized, MAX_MEMORIES)? {
if seen_memory_ids.insert(m.memory_id.clone()) {
memories.push(m);
}
}
if let Some(path) = search_path {
for m in memory::memories_for_path(conn, path, MAX_MEMORIES)? {
if seen_memory_ids.insert(m.memory_id.clone()) {
memories.push(m);
}
}
}
memories.retain(|m| !dedupe.memory_ids.contains(&m.memory_id));
let lexical_lines = if !symbol_lane_had_hits {
let hits = lexical::search_lexical_only(conn, &normalized, MAX_LEXICAL_HITS, false)?;
let best = hits.iter().map(|hit| hit.score).fold(0.0_f64, f64::max);
let floor = best * LEXICAL_RELATIVE_FLOOR;
hits.into_iter()
.filter(|hit| hit.score >= floor)
.map(|hit| {
format!("- {}:{}-{} — {}", hit.path, hit.start_line, hit.end_line, hit.summary)
})
.collect::<Vec<_>>()
} else {
Vec::new()
};
if memories.is_empty() && symbol_items.is_empty() && lexical_lines.is_empty() {
return Ok(None);
}
Ok(Some(render(memories, symbol_items, lexical_lines)))
}
/// A symbol hit prepared by `compose` for rendering.
struct SymbolItem {
/// The complete bullet line shown for the symbol.
rendered: String,
/// Dedupe key of the symbol, built as `"{path}:{qualified_name}"`.
key: String,
}
/// One output line inside a `Section`, plus the identifier it reports when emitted.
struct RenderItem {
/// Text of the line, without a trailing newline.
line: String,
/// Set for memory lines; pushed to `GrepAugment::memory_ids` when the line is written.
memory_id: Option<String>,
/// Set for symbol lines; pushed to `GrepAugment::symbol_keys` when the line is written.
symbol_key: Option<String>,
}
/// A titled group of lines that `render` writes into the context.
struct Section {
/// Header line; written only together with the section's first item.
header: String,
/// The section's lines, in output order.
items: Vec<RenderItem>,
/// Optional closing line, appended after the items when it still fits the size cap.
footer: Option<String>,
}
fn clamp_body(body: &str) -> String {
let collapsed: String = body.split_whitespace().collect::<Vec<_>>().join(" ");
if collapsed.chars().count() <= MAX_MEMORY_BODY_CHARS {
collapsed
} else {
let truncated: String = collapsed.chars().take(MAX_MEMORY_BODY_CHARS).collect();
format!("{truncated}…")
}
}
fn render(
memories: Vec<memory::RepoMemory>,
symbol_items: Vec<SymbolItem>,
lexical_lines: Vec<String>,
) -> GrepAugment {
let mut sections: Vec<Section> = Vec::new();
if !memories.is_empty() {
let items = memories
.into_iter()
.map(|m| RenderItem {
line: format!(
"- [{} | {}] {} — {} (rag-rat: memory_search)",
m.kind,
m.status,
m.title,
clamp_body(&m.body),
),
memory_id: Some(m.memory_id),
symbol_key: None,
})
.collect();
sections.push(Section {
header: "**Repo memories bound to this code:**".to_string(),
items,
footer: None,
});
}
if !symbol_items.is_empty() {
let items = symbol_items
.into_iter()
.map(|s| RenderItem { line: s.rendered, memory_id: None, symbol_key: Some(s.key) })
.collect();
sections.push(Section {
header: "**Known symbols matching this pattern:**".to_string(),
items,
footer: Some("(rag-rat: impact_surface <name> before editing)".to_string()),
});
}
if !lexical_lines.is_empty() {
let items = lexical_lines
.into_iter()
.map(|line| RenderItem { line, memory_id: None, symbol_key: None })
.collect();
sections.push(Section {
header: "**Indexed hits (rag-rat semantic_search has more):**".to_string(),
items,
footer: None,
});
}
let mut context = String::from("rag-rat index context for this search:\n");
let mut memory_ids: Vec<String> = Vec::new();
let mut symbol_keys: Vec<String> = Vec::new();
'section: for section in sections {
let mut section_committed = false;
for item in section.items {
let needed = if section_committed {
item.line.len() + 1
} else {
section.header.len() + 1 + item.line.len() + 1
};
if context.len() + needed > MAX_CONTEXT_CHARS {
break 'section;
}
if !section_committed {
context.push_str(§ion.header);
context.push('\n');
section_committed = true;
}
context.push_str(&item.line);
context.push('\n');
if let Some(mid) = item.memory_id {
memory_ids.push(mid);
}
if let Some(key) = item.symbol_key {
symbol_keys.push(key);
}
}
if section_committed
&& let Some(footer) = section.footer
&& context.len() + footer.len() < MAX_CONTEXT_CHARS
{
context.push_str(&footer);
context.push('\n');
}
}
GrepAugment { context: context.trim_end().to_string(), memory_ids, symbol_keys }
}
/// Counts the edges pointing at a symbol and the edges originating from it.
///
/// Returns `(callers, callees)`: callers are edges whose `to_symbol_id` equals the
/// symbol's id or whose `target_qualified_name` equals its qualified name; callees are
/// edges whose `from_symbol_id` equals the symbol's id.
///
/// # Errors
/// Propagates any error from the two count queries.
fn edge_counts(conn: &Connection, hit: &symbol::SymbolHit) -> anyhow::Result<(i64, i64)> {
    let inbound = conn.query_row(
        "SELECT COUNT(*) FROM edges WHERE to_symbol_id = ?1 OR target_qualified_name = ?2",
        rusqlite::params![hit.symbol_id, hit.qualified_name],
        |row| row.get::<_, i64>(0),
    )?;
    let outbound = conn.query_row(
        "SELECT COUNT(*) FROM edges WHERE from_symbol_id = ?1",
        [hit.symbol_id],
        |row| row.get::<_, i64>(0),
    )?;
    Ok((inbound, outbound))
}
/// Finds the start line of the smallest chunk that covers the symbol's start byte.
///
/// Only chunks of the symbol's file are considered. Returns `Ok(None)` when no chunk
/// contains that byte offset.
///
/// # Errors
/// Propagates any query error other than "no rows".
fn line_for_symbol(conn: &Connection, hit: &symbol::SymbolHit) -> anyhow::Result<Option<i64>> {
    // `\n\` keeps the line break in the SQL text while allowing the continuation lines
    // to be indented in the source.
    let sql = "SELECT start_line FROM chunks\n\
               WHERE file_id = ?1 AND start_byte <= ?2 AND end_byte >= ?2\n\
               ORDER BY (end_byte - start_byte) ASC LIMIT 1";
    let start_line: Option<i64> = conn
        .query_row(sql, rusqlite::params![hit.file_id, hit.start_byte], |row| row.get(0))
        .optional()?;
    Ok(start_line)
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use rusqlite::Connection;
use super::*;
use crate::index::schema;
use crate::query::memory::{self, RepoMemoryBindTarget, RepoMemoryCreate};
// Builds an in-memory database with the schema applied and a minimal fixture:
// one file, one symbol, one chunk, two edges, one memory and one FTS row.
fn seeded_conn() -> Connection {
let conn = Connection::open_in_memory().unwrap();
schema::apply(&conn).unwrap();
// The single source file, `src/watch.rs`.
conn.execute(
"INSERT INTO files(path, language, kind, sha256, modified_at_ms, indexed_at_ms)
VALUES ('src/watch.rs', 'rust', 'source', 'abc', 0, 0)",
[],
)
.unwrap();
// The symbol `watch::watcher_main`, covering bytes 0..100 of file 1.
conn.execute(
"INSERT INTO symbols(file_id, language, name, qualified_name, kind, start_byte,
end_byte, signature, docs)
VALUES (1, 'rust', 'watcher_main', 'watch::watcher_main', 'function', 0, 100,
'fn watcher_main(config: Config)', NULL)",
[],
)
.unwrap();
// A chunk over the same byte range, spanning lines 1-20.
conn.execute(
"INSERT INTO chunks(file_id, chunk_kind, symbol_path, start_byte, end_byte,
start_line, end_line, text, text_hash)
VALUES (1, 'symbol', 'watch::watcher_main', 0, 100, 1, 20,
'fn watcher_main() { /* election retry loop */ }', 'h1')",
[],
)
.unwrap();
// Remember the chunk's rowid so the FTS row inserted at the end can reuse it.
let chunk_id = conn.last_insert_rowid();
// One edge pointing at symbol 1 ...
conn.execute(
"INSERT INTO edges(source_file_id, from_symbol_id, to_symbol_id, to_name,
target_qualified_name, edge_kind, confidence)
VALUES (1, NULL, 1, 'watcher_main', 'watch::watcher_main', 'calls_name', 'exact')",
[],
)
.unwrap();
// ... and one edge originating from symbol 1.
conn.execute(
"INSERT INTO edges(source_file_id, from_symbol_id, to_symbol_id, to_name,
target_qualified_name, edge_kind, confidence)
VALUES (1, 1, NULL, 'maintenance_pass', NULL, 'calls_name', 'name_only')",
[],
)
.unwrap();
// A memory bound to symbol id 1 — presumably the symbol inserted above, assuming the
// first row of a fresh table gets id 1.
memory::create_memory(
&conn,
RepoMemoryCreate {
kind: "Invariant".to_string(),
title: "One watcher per worktree".to_string(),
body: "The election lock guarantees a single watcher; never bind without it."
.to_string(),
confidence: "high".to_string(),
created_by: Some("test".to_string()),
source: None,
tags: vec![],
bind: RepoMemoryBindTarget {
symbol_id: Some(1),
logical_symbol_id: None,
chunk_id: None,
edge_id: None,
path: None,
start_line: None,
end_line: None,
commit_hash: None,
github_owner: None,
github_repo: None,
github_number: None,
start_logical_symbol_id: None,
end_logical_symbol_id: None,
edge_sequence_hash: None,
path_summary: None,
},
},
)
.unwrap();
// Full-text row for the chunk, keyed by the chunk's rowid.
conn.execute(
"INSERT INTO chunk_fts(rowid, text)
VALUES (?1, 'fn watcher_main() { /* election retry loop */ }')",
[chunk_id],
)
.unwrap();
conn
}
// An identifier pattern surfaces both the symbol and its bound memory, memory first.
#[test]
fn compose_identifier_pattern_yields_symbol_and_memory() {
    let conn = seeded_conn();
    let filter = DedupeFilter::default();
    let augment = compose(&conn, r"watcher_main\b", None, &filter)
        .unwrap()
        .expect("payload expected");
    let context = &augment.context;

    let symbol_pos = context.find("src/watch.rs");
    let memory_pos = context.find("One watcher per worktree");
    assert!(symbol_pos.is_some(), "symbol location present");
    assert!(memory_pos.is_some(), "memory title present");
    assert!(memory_pos.unwrap() < symbol_pos.unwrap(), "memories render before symbols");

    assert_eq!(augment.memory_ids.len(), 1);
    assert_eq!(augment.symbol_keys.len(), 1);
    assert!(context.len() <= MAX_CONTEXT_CHARS);
}
// Feeding the ids of a first result back in as the filter leaves nothing to report.
#[test]
fn compose_respects_dedupe_filter_and_returns_none_when_everything_filtered() {
    let conn = seeded_conn();
    let first = compose(&conn, "watcher_main", None, &DedupeFilter::default())
        .unwrap()
        .expect("first payload");

    let filter = DedupeFilter {
        memory_ids: first.memory_ids.into_iter().collect(),
        symbol_keys: first.symbol_keys.into_iter().collect(),
    };
    let second = compose(&conn, "watcher_main", None, &filter).unwrap();
    assert!(second.is_none());
}
// Keyword-prefixed definitions reduce to their identifier; free text does not.
#[test]
fn extract_symbol_identifier_handles_definition_patterns() {
    let cases: [(&str, Option<&str>); 6] = [
        ("watcher_main", Some("watcher_main")),
        ("fn watcher_main", Some("watcher_main")),
        ("pub struct SymbolIndex", Some("SymbolIndex")),
        ("pub async fn resolve_all_edges", Some("resolve_all_edges")),
        ("election retry loop", None),
        ("foo == bar", None),
    ];
    for (normalized, expected) in cases {
        assert_eq!(extract_symbol_identifier(normalized), expected, "input: {normalized:?}");
    }
}
// A `fn <name>` pattern is answered by the symbol lane, and the lexical lane stays silent.
#[test]
fn compose_definition_pattern_routes_to_symbol_lane_not_lexical() {
    let conn = seeded_conn();
    let augment = compose(&conn, r"fn watcher_main", None, &DedupeFilter::default())
        .unwrap()
        .expect("payload expected");
    let context = &augment.context;

    assert!(context.contains("watch::watcher_main"), "symbol lane fired");
    assert!(context.contains("One watcher per worktree"), "bound memory surfaced");
    let has_lexical_section = context.contains("Indexed hits");
    assert!(
        !has_lexical_section,
        "lexical lane must be suppressed when the symbol lane has hits: {}",
        context
    );
    assert!(!augment.symbol_keys.is_empty());
}
// Free text that is not a single identifier still finds the file via the lexical lane.
#[test]
fn compose_non_identifier_pattern_uses_lexical_lane() {
    let conn = seeded_conn();
    let filter = DedupeFilter::default();
    let augment =
        compose(&conn, "election retry loop", None, &filter).unwrap().expect("lexical payload");
    assert!(augment.context.contains("src/watch.rs"));
}
// A pattern that matches nothing in any lane produces no payload.
#[test]
fn compose_unknown_pattern_yields_none() {
    let conn = seeded_conn();
    let outcome = compose(&conn, "zzqqyyxx_nothing", None, &DedupeFilter::default()).unwrap();
    assert!(outcome.is_none());
}
// Anchors, quantifiers, groups and escapes all turn into separators.
#[test]
fn normalize_strips_regex_metacharacters_and_anchors() {
    let cases = [
        (r"^fn\s+watcher_main\b", "fn watcher_main"),
        (r"Watcher::spawn(_with_fleet)?", "Watcher::spawn _with_fleet"),
        ("plain words", "plain words"),
        (r".*[]()|+?^$\\", ""),
    ];
    for (pattern, expected) in cases {
        assert_eq!(normalize_pattern(pattern), expected, "pattern: {pattern:?}");
    }
}
// A dot survives only when it sits directly between two word characters.
#[test]
fn normalize_preserves_dot_between_word_chars() {
    let cases = [
        ("foo.bar", "foo.bar"),
        (r"foo\.bar", "foo.bar"),
        (".foo", "foo"),
        ("foo.", "foo"),
        ("foo. bar", "foo bar"),
    ];
    for (pattern, expected) in cases {
        assert_eq!(normalize_pattern(pattern), expected, "pattern: {pattern:?}");
    }
}
// Identifier-shaped strings are returned unchanged; everything else is rejected.
#[test]
fn identifier_candidate_accepts_identifier_shapes_only() {
    for accepted in ["watcher_main", "Watcher::spawn", "foo.bar"] {
        assert_eq!(identifier_candidate(accepted), Some(accepted));
    }
    for rejected in ["fn watcher_main", "ab", "1abc", ""] {
        assert_eq!(identifier_candidate(rejected), None, "input: {rejected:?}");
    }
}
// Dot-qualified names survive normalization (escaped or not) and still count as identifiers.
#[test]
fn normalize_and_identifier_candidate_compose_for_dot_qualified() {
    for pattern in ["foo.bar", r"foo\.bar"] {
        let normalized = normalize_pattern(pattern);
        assert_eq!(normalized, "foo.bar");
        assert_eq!(identifier_candidate(&normalized), Some("foo.bar"));
    }
}
// Overfills the context with four long memories bound to the seeded symbol and checks
// that truncation keeps the output within the cap, never leaves a header as the last
// line, and reports exactly the ids/keys of the items that were rendered.
#[test]
fn render_truncation_respects_cap_no_dangling_headers_ids_match() {
let conn = seeded_conn();
// 300 nine-byte words: longer than the per-memory body clamp, shorter than 4000 bytes.
let long_body: String =
(0u32..300).map(|i| format!("word{i:04}")).collect::<Vec<_>>().join(" ");
assert!(long_body.len() > MAX_MEMORY_BODY_CHARS, "body must survive clamp");
assert!(long_body.len() < 4000, "must not exceed validation cap");
let titles = [
"Truncation memory one — extra padding words fill the title field here ok",
"Truncation memory two — extra padding words fill the title field here ok",
"Truncation memory three — extra padding words fill the title field here",
"Truncation memory four — extra padding words fill the title field here ok",
];
// Create one memory per title, all bound to symbol id 1, and keep their ids.
let mut created_ids: Vec<String> = Vec::new();
for title in &titles {
let result = memory::create_memory(
&conn,
RepoMemoryCreate {
kind: "Invariant".to_string(),
title: title.to_string(),
body: long_body.clone(),
confidence: "high".to_string(),
created_by: Some("test".to_string()),
source: None,
tags: vec![],
bind: RepoMemoryBindTarget {
symbol_id: Some(1),
logical_symbol_id: None,
chunk_id: None,
edge_id: None,
path: None,
start_line: None,
end_line: None,
commit_hash: None,
github_owner: None,
github_repo: None,
github_number: None,
start_logical_symbol_id: None,
end_logical_symbol_id: None,
edge_sequence_hash: None,
path_summary: None,
},
},
)
.unwrap();
created_ids.push(result.memory.memory_id);
}
assert_eq!(created_ids.len(), 4, "all four memories must be created");
// Sanity check on the fixture itself: a lower-bound estimate of the untruncated output
// must already exceed the cap, otherwise the test would not exercise truncation.
let per_mem_line_min: usize = "- [Invariant | active] ".len() + titles[0].len() + " — ".len() + MAX_MEMORY_BODY_CHARS + 1; let preamble_len = "rag-rat index context for this search:\n".len();
let mem_header_len = "**Repo memories bound to this code:**\n".len();
let symbol_section_min: usize = "**Known symbols matching this pattern:**\n".len() + 80; let candidate_total =
preamble_len + mem_header_len + 4 * per_mem_line_min + symbol_section_min;
assert!(
candidate_total > MAX_CONTEXT_CHARS,
"candidate_total={candidate_total} must exceed MAX_CONTEXT_CHARS={MAX_CONTEXT_CHARS} \
for truncation to trigger",
);
let out = compose(&conn, "watcher_main", None, &DedupeFilter::default())
.unwrap()
.expect("payload expected");
// 1. The size cap holds.
assert!(
out.context.len() <= MAX_CONTEXT_CHARS,
"context.len()={} > MAX_CONTEXT_CHARS={}",
out.context.len(),
MAX_CONTEXT_CHARS,
);
// 2. No section header may be the final line of the output.
let section_headers = [
"**Repo memories bound to this code:**",
"**Known symbols matching this pattern:**",
"**Indexed hits (rag-rat semantic_search has more):**",
];
let lines: Vec<&str> = out.context.lines().collect();
for (idx, line) in lines.iter().enumerate() {
let is_header = section_headers.iter().any(|h| line.trim() == *h);
if is_header {
assert!(
idx + 1 < lines.len(),
"section header '{line}' is the last line — dangling header",
);
}
}
// 3. A memory id is reported if and only if its title was rendered.
for (title, id) in titles.iter().zip(created_ids.iter()) {
let in_context = out.context.contains(*title);
let id_present = out.memory_ids.contains(id);
assert_eq!(
in_context, id_present,
"mismatch for '{title}': in_context={in_context}, id_present={id_present}",
);
}
// 4. Symbol keys are reported if and only if the symbol was rendered.
let sym_in_context = out.context.contains("watch::watcher_main");
let sym_keys_non_empty = !out.symbol_keys.is_empty();
assert_eq!(
sym_in_context, sym_keys_non_empty,
"symbol context/key mismatch: sym_in_context={sym_in_context}, \
sym_keys_non_empty={sym_keys_non_empty}",
);
// 5. Something must actually have been cut; otherwise the fixture is too small.
let all_titles_present = titles.iter().all(|t| out.context.contains(*t));
let symbol_present = out.context.contains("watch::watcher_main");
assert!(
!all_titles_present || !symbol_present,
"no truncation detected: all memory titles and the symbol section all fit within \
MAX_CONTEXT_CHARS — increase body/title size so the cap is actually exercised",
);
}
// Short bodies pass through, whitespace collapses, and long bodies end in an ellipsis.
#[test]
fn clamp_body_truncates_long_bodies_and_collapses_whitespace() {
    assert_eq!(clamp_body("hello world"), "hello world");
    assert_eq!(clamp_body("line one\nline two\n indented"), "line one line two indented");

    let clamped = clamp_body(&"x".repeat(300));
    assert!(clamped.ends_with('…'), "truncated body must end with ellipsis");
    // Everything before the ellipsis is exactly the permitted number of characters.
    let kept: String = clamped.chars().take(MAX_MEMORY_BODY_CHARS).collect();
    assert_eq!(kept.len(), MAX_MEMORY_BODY_CHARS);
}
}