use chrono::Utc;
use rusqlite::Connection;
use crate::error::Result;
use crate::storage::filter::{parse_filter, SqlBuilder};
use crate::storage::queries::{load_tags, memory_from_row};
use crate::types::{MatchInfo, Memory, MemoryScope, SearchStrategy};
/// One row returned by the BM25 keyword-search functions in this module.
#[derive(Debug)]
pub struct Bm25Result {
/// The matching memory; its `tags` are populated via `load_tags`.
pub memory: Memory,
/// Relevance score derived from the row's FTS5 `bm25()` value.
pub score: f32,
/// Whitespace-separated query tokens found (case-insensitively) in the content.
/// Empty unless the search was run with `explain`.
pub matched_terms: Vec<String>,
/// At most one context snippet around the first query token found in the content.
/// Empty unless the search was run with `explain`.
pub highlights: Vec<String>,
}
pub fn bm25_search(
conn: &Connection,
query: &str,
limit: i64,
explain: bool,
) -> Result<Vec<Bm25Result>> {
bm25_search_with_options(conn, query, limit, explain, None)
}
pub fn bm25_search_with_options(
conn: &Connection,
query: &str,
limit: i64,
explain: bool,
scope: Option<&MemoryScope>,
) -> Result<Vec<Bm25Result>> {
bm25_search_with_filter(conn, query, limit, explain, scope, None)
}
pub fn bm25_search_with_filter(
conn: &Connection,
query: &str,
limit: i64,
explain: bool,
scope: Option<&MemoryScope>,
filter: Option<&serde_json::Value>,
) -> Result<Vec<Bm25Result>> {
bm25_search_full(conn, query, limit, explain, scope, filter, false)
}
pub fn bm25_search_full(
conn: &Connection,
query: &str,
limit: i64,
explain: bool,
scope: Option<&MemoryScope>,
filter: Option<&serde_json::Value>,
include_transcripts: bool,
) -> Result<Vec<Bm25Result>> {
bm25_search_complete(
conn,
query,
limit,
explain,
scope,
filter,
include_transcripts,
false, None,
None,
None,
)
}
/// Keyword (BM25) search with every restriction except a scope path.
///
/// Forwards to [`bm25_search_complete_with_scope_path`] with no scope-path
/// restriction. See that function for the meaning of each argument.
// The flat parameter list is public API; silence the lint rather than break callers.
#[allow(clippy::too_many_arguments)]
pub fn bm25_search_complete(
    conn: &Connection,
    query: &str,
    limit: i64,
    explain: bool,
    scope: Option<&MemoryScope>,
    filter: Option<&serde_json::Value>,
    include_transcripts: bool,
    include_archived: bool,
    workspace: Option<&str>,
    workspaces: Option<&[String]>,
    tier: Option<&crate::types::MemoryTier>,
) -> Result<Vec<Bm25Result>> {
    let scope_path: Option<&str> = None;
    bm25_search_complete_with_scope_path(
        conn,
        query,
        limit,
        explain,
        scope,
        filter,
        include_transcripts,
        include_archived,
        workspace,
        workspaces,
        tier,
        scope_path,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn bm25_search_complete_with_scope_path(
conn: &Connection,
query: &str,
limit: i64,
explain: bool,
scope: Option<&MemoryScope>,
filter: Option<&serde_json::Value>,
include_transcripts: bool,
include_archived: bool,
workspace: Option<&str>,
workspaces: Option<&[String]>,
tier: Option<&crate::types::MemoryTier>,
scope_path: Option<&str>,
) -> Result<Vec<Bm25Result>> {
let escaped_query = escape_fts5_query(query);
let now = Utc::now().to_rfc3339();
let mut sql = String::from(
r#"
SELECT
m.id, m.content, m.memory_type, m.importance, m.access_count,
m.created_at, m.updated_at, m.last_accessed_at, m.owner_id,
m.visibility, m.version, m.has_embedding, m.metadata,
m.scope_type, m.scope_id, m.expires_at,
bm25(memories_fts) as score
FROM memories_fts fts
JOIN memories m ON fts.rowid = m.id
WHERE memories_fts MATCH ? AND m.valid_to IS NULL
AND (m.expires_at IS NULL OR m.expires_at > ?)
"#,
);
let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(escaped_query), Box::new(now)];
if !include_transcripts {
sql.push_str(" AND m.memory_type != 'transcript_chunk'");
}
if !include_archived {
sql.push_str(" AND (m.lifecycle_state IS NULL OR m.lifecycle_state != 'archived')");
}
if let Some(filter_json) = filter {
let filter_expr = parse_filter(filter_json)?;
let mut builder = SqlBuilder::new();
let filter_sql = builder.build_filter(&filter_expr)?;
sql.push_str(" AND ");
sql.push_str(&filter_sql);
for param in builder.take_params() {
params.push(param);
}
}
if let Some(scope) = scope {
sql.push_str(" AND m.scope_type = ?");
params.push(Box::new(scope.scope_type().to_string()));
if let Some(scope_id) = scope.scope_id() {
sql.push_str(" AND m.scope_id = ?");
params.push(Box::new(scope_id.to_string()));
} else {
sql.push_str(" AND m.scope_id IS NULL");
}
}
if let Some(ws) = workspace {
sql.push_str(" AND m.workspace = ?");
params.push(Box::new(ws.to_string()));
} else if let Some(ws_list) = workspaces {
if !ws_list.is_empty() {
let placeholders: Vec<&str> = ws_list.iter().map(|_| "?").collect();
sql.push_str(&format!(
" AND m.workspace IN ({})",
placeholders.join(", ")
));
for ws in ws_list {
params.push(Box::new(ws.clone()));
}
}
}
if let Some(t) = tier {
sql.push_str(&format!(" AND m.tier = '{}'", t.as_str()));
}
if let Some(sp) = scope_path {
let escaped = sp.replace('%', "\\%").replace('_', "\\_");
sql.push_str(" AND (m.scope_path = ? OR m.scope_path LIKE ? ESCAPE '\\')");
params.push(Box::new(sp.to_string()));
params.push(Box::new(format!("{}/", escaped) + "%"));
}
sql.push_str(" ORDER BY bm25(memories_fts) LIMIT ?");
params.push(Box::new(limit));
let mut stmt = conn.prepare(&sql)?;
let mut results = Vec::new();
let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
let rows = stmt.query_map(param_refs.as_slice(), |row| {
let memory = memory_from_row(row)?;
let score: f32 = row.get("score")?;
Ok((memory, score))
})?;
for row in rows {
let (mut memory, score) = row?;
memory.tags = load_tags(conn, memory.id)?;
let normalized_score = 1.0 / (1.0 + score.abs());
let matched_terms = if explain {
extract_matched_terms(query, &memory.content)
} else {
vec![]
};
let highlights = if explain {
generate_highlights(query, &memory.content)
} else {
vec![]
};
results.push(Bm25Result {
memory,
score: normalized_score,
matched_terms,
highlights,
});
}
Ok(results)
}
pub fn phrase_search(conn: &Connection, phrase: &str, limit: i64) -> Result<Vec<Bm25Result>> {
let query = format!("\"{}\"", phrase.replace('"', ""));
bm25_search(conn, &query, limit, false)
}
pub fn proximity_search(
conn: &Connection,
terms: &[&str],
max_distance: i32,
limit: i64,
) -> Result<Vec<Bm25Result>> {
if terms.is_empty() {
return Ok(vec![]);
}
let escaped_terms: Vec<String> = terms.iter().map(|t| escape_fts5_term(t)).collect();
let query = format!("NEAR({}, {})", escaped_terms.join(" "), max_distance);
bm25_search(conn, &query, limit, false)
}
pub fn field_search(
conn: &Connection,
field: &str,
query: &str,
limit: i64,
) -> Result<Vec<Bm25Result>> {
let valid_fields = ["content", "tags", "metadata"];
if !valid_fields.contains(&field) {
return bm25_search(conn, query, limit, false);
}
let field_query = format!("{}: {}", field, escape_fts5_query(query));
bm25_search(conn, &field_query, limit, false)
}
fn escape_fts5_query(query: &str) -> String {
let trimmed = query.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() > 2 {
let inner = &trimmed[1..trimmed.len() - 1];
let escaped_inner = inner.replace('"', "\"\"");
return format!("\"{}\"", escaped_inner);
}
trimmed
.split_whitespace()
.filter(|t| !t.is_empty())
.map(escape_fts5_term)
.collect::<Vec<_>>()
.join(" ")
}
/// Escapes one token so that FTS5 matches it literally.
///
/// A token is emitted bare only when FTS5 would read it as a plain bareword.
/// That means every character is one of:
/// * an ASCII letter or digit,
/// * `_`,
/// * the substitute character U+001A,
/// * a non-ASCII, non-whitespace character.
///
/// The token must also not be one of the upper-case keywords `AND`, `OR`,
/// `NOT`, `NEAR`. Anything else (`.`, `@`, `,`, `'`, `/`, `*`, `:`, brackets,
/// whitespace, ...) is wrapped in double quotes with embedded `"` doubled;
/// FTS5 reads that as a phrase. Without the quotes, FTS5 would reject such
/// tokens with a syntax error.
///
/// Returns an empty string for empty input.
fn escape_fts5_term(term: &str) -> String {
    if term.is_empty() {
        return String::new();
    }
    let is_bareword_char = |c: char| {
        c.is_ascii_alphanumeric()
            || c == '_'
            || c == '\u{1A}'
            || (!c.is_ascii() && !c.is_whitespace())
    };
    let is_operator = matches!(term, "AND" | "OR" | "NOT" | "NEAR");
    if !is_operator && term.chars().all(is_bareword_char) {
        return term.to_string();
    }
    let mut escaped = String::with_capacity(term.len() + 2);
    escaped.push('"');
    escaped.push_str(&term.replace('"', "\"\""));
    escaped.push('"');
    escaped
}

/// Characters stripped from both ends of a query token before it is matched
/// against content.
fn is_query_decoration(c: char) -> bool {
    matches!(c, '"' | '*' | '+' | '-')
}

/// Returns the whitespace-separated query tokens that occur in `content`.
///
/// Matching is case-insensitive and ignores leading and trailing `"`, `*`, `+`,
/// `-` on each token. Tokens made up only of those characters are skipped;
/// otherwise they would "match" every content as the empty string. Tokens are
/// returned exactly as written in `query`.
fn extract_matched_terms(query: &str, content: &str) -> Vec<String> {
    let content_lower = content.to_lowercase();
    query
        .split_whitespace()
        .filter(|term| {
            let term_lower = term.to_lowercase();
            let clean_term = term_lower.trim_matches(is_query_decoration);
            !clean_term.is_empty() && content_lower.contains(clean_term)
        })
        .map(String::from)
        .collect()
}

/// Maps a byte offset in `content.to_lowercase()` to the byte offset of the
/// corresponding character in `content`.
///
/// Lowercasing can change a character's UTF-8 length (`İ` becomes `i\u{307}`),
/// so offsets in the lowercased copy are not valid indices into `content`.
/// An offset that falls inside one character's lowercase expansion maps to the
/// next character.
fn original_offset(content: &str, lower_offset: usize) -> usize {
    let mut lower_len = 0;
    for (idx, ch) in content.char_indices() {
        if lower_len >= lower_offset {
            return idx;
        }
        lower_len += ch.to_lowercase().map(char::len_utf8).sum::<usize>();
    }
    content.len()
}

/// Builds a context snippet around the first query token found in `content`.
///
/// Tokens are tried in query order, after stripping `"`, `*`, `+`, `-` from
/// their ends. The first token found case-insensitively yields one snippet.
/// The snippet keeps about 30 bytes of context on each side, widened to whole
/// words. It gets a `...` prefix or suffix when cut from a longer text.
/// Returns an empty vector when no token matches.
///
/// Match offsets are translated from the lowercased copy back to `content`,
/// and the window is snapped to character boundaries. This prevents slicing
/// panics on multi-byte text, and on characters whose lowercase form has a
/// different UTF-8 length.
fn generate_highlights(query: &str, content: &str) -> Vec<String> {
    const CONTEXT_BYTES: usize = 30;

    let content_lower = content.to_lowercase();
    let hit = query
        .split_whitespace()
        .map(|t| t.trim_matches(is_query_decoration))
        .filter(|t| !t.is_empty())
        .find_map(|term| {
            let term_lower = term.to_lowercase();
            content_lower
                .find(&term_lower)
                .map(|pos| (pos, pos + term_lower.len()))
        });
    let (lower_start, lower_end) = match hit {
        Some(range) => range,
        None => return Vec::new(),
    };
    let match_start = original_offset(content, lower_start);
    let match_end = original_offset(content, lower_end);

    // Byte arithmetic can land inside a multi-byte character; snap outwards.
    let mut start = match_start.saturating_sub(CONTEXT_BYTES);
    while !content.is_char_boundary(start) {
        start -= 1;
    }
    let mut end = (match_end + CONTEXT_BYTES).min(content.len());
    while !content.is_char_boundary(end) {
        end += 1;
    }

    // Widen to whole words: back to just after the previous space, forward to the next one.
    let snippet_start = content[..start].rfind(' ').map_or(start, |p| p + 1);
    let snippet_end = content[end..].find(' ').map_or(end, |p| end + p);

    let mut snippet = String::new();
    if snippet_start > 0 {
        snippet.push_str("...");
    }
    snippet.push_str(content[snippet_start..snippet_end].trim());
    if snippet_end < content.len() {
        snippet.push_str("...");
    }
    vec![snippet]
}

impl Bm25Result {
    /// Describes this result as keyword-only match information.
    ///
    /// `keyword_score` is this result's `score`, and `semantic_score` is `None`.
    pub fn to_match_info(&self) -> MatchInfo {
        MatchInfo {
            strategy: SearchStrategy::KeywordOnly,
            matched_terms: self.matched_terms.clone(),
            highlights: self.highlights.clone(),
            semantic_score: None,
            keyword_score: Some(self.score),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_fts5_term_simple() {
        assert_eq!(escape_fts5_term("hello"), "hello");
        assert_eq!(escape_fts5_term("world"), "world");
        assert_eq!(escape_fts5_term("rust123"), "rust123");
    }

    #[test]
    fn test_escape_fts5_term_with_spaces() {
        assert_eq!(escape_fts5_term("hello world"), "\"hello world\"");
        assert_eq!(escape_fts5_term(" spaces "), "\" spaces \"");
    }

    #[test]
    fn test_escape_fts5_term_with_quotes() {
        assert_eq!(escape_fts5_term("test\"quote"), "\"test\"\"quote\"");
        assert_eq!(escape_fts5_term("\"quoted\""), "\"\"\"quoted\"\"\"");
    }

    #[test]
    fn test_escape_fts5_term_special_chars() {
        assert_eq!(escape_fts5_term("test*"), "\"test*\"");
        assert_eq!(escape_fts5_term("(group)"), "\"(group)\"");
        assert_eq!(escape_fts5_term("content:term"), "\"content:term\"");
        assert_eq!(escape_fts5_term("term^2"), "\"term^2\"");
        assert_eq!(escape_fts5_term("+required"), "\"+required\"");
        assert_eq!(escape_fts5_term("-excluded"), "\"-excluded\"");
    }

    #[test]
    fn test_escape_fts5_term_non_bareword_punctuation() {
        // FTS5 rejects these characters in a bareword, so they must be quoted.
        assert_eq!(escape_fts5_term("hello,"), "\"hello,\"");
        assert_eq!(escape_fts5_term("it's"), "\"it's\"");
        assert_eq!(escape_fts5_term("a/b"), "\"a/b\"");
        assert_eq!(escape_fts5_term("what?"), "\"what?\"");
    }

    #[test]
    fn test_escape_fts5_term_bareword_chars() {
        assert_eq!(escape_fts5_term("snake_case"), "snake_case");
        assert_eq!(escape_fts5_term("café"), "café");
        assert_eq!(escape_fts5_term("日本語"), "日本語");
    }

    #[test]
    fn test_escape_fts5_term_operators() {
        assert_eq!(escape_fts5_term("AND"), "\"AND\"");
        assert_eq!(escape_fts5_term("OR"), "\"OR\"");
        assert_eq!(escape_fts5_term("NOT"), "\"NOT\"");
        assert_eq!(escape_fts5_term("NEAR"), "\"NEAR\"");
        assert_eq!(escape_fts5_term("and"), "and");
        assert_eq!(escape_fts5_term("or"), "or");
    }

    #[test]
    fn test_escape_fts5_term_empty() {
        assert_eq!(escape_fts5_term(""), "");
    }

    #[test]
    fn test_escape_fts5_query_simple() {
        assert_eq!(escape_fts5_query("hello world"), "hello world");
        assert_eq!(escape_fts5_query("single"), "single");
    }

    #[test]
    fn test_escape_fts5_query_quoted_phrase() {
        assert_eq!(escape_fts5_query("\"exact phrase\""), "\"exact phrase\"");
        assert_eq!(
            escape_fts5_query("\"phrase with \"quotes\"\""),
            "\"phrase with \"\"quotes\"\"\""
        );
    }

    #[test]
    fn test_escape_fts5_query_whitespace() {
        assert_eq!(escape_fts5_query(""), "");
        assert_eq!(escape_fts5_query(" "), "");
        assert_eq!(escape_fts5_query(" hello world "), "hello world");
    }

    #[test]
    fn test_escape_fts5_query_injection_attempts() {
        assert_eq!(
            escape_fts5_query("hello OR (drop table)"),
            "hello \"OR\" \"(drop\" \"table)\""
        );
        assert_eq!(
            escape_fts5_query("content:malicious"),
            "\"content:malicious\""
        );
        assert_eq!(
            escape_fts5_query("NEAR(term1 term2, 5)"),
            "\"NEAR(term1\" \"term2,\" \"5)\""
        );
    }

    #[test]
    fn test_escape_fts5_query_real_world() {
        // FTS5 barewords may only contain ASCII letters/digits, '_' and non-ASCII
        // characters, so dotted or '@' tokens must be quoted to avoid syntax errors.
        assert_eq!(escape_fts5_query("user@example.com"), "\"user@example.com\"");
        assert_eq!(escape_fts5_query("file.rs"), "\"file.rs\"");
        assert_eq!(escape_fts5_query("C++"), "\"C++\"");
        assert_eq!(escape_fts5_query("node.js"), "\"node.js\"");
        assert_eq!(escape_fts5_query("@username"), "\"@username\"");
        assert_eq!(
            escape_fts5_query("https://example.com"),
            "\"https://example.com\""
        );
    }

    #[test]
    fn test_extract_matched_terms() {
        let terms = extract_matched_terms("hello world", "Hello there, World!");
        assert!(terms.contains(&"hello".to_string()));
        assert!(terms.contains(&"world".to_string()));
    }

    #[test]
    fn test_extract_matched_terms_partial() {
        let terms = extract_matched_terms("rust programming", "Rust is a programming language");
        assert!(terms.contains(&"rust".to_string()));
        assert!(terms.contains(&"programming".to_string()));
    }

    #[test]
    fn test_extract_matched_terms_no_match() {
        let terms = extract_matched_terms("xyz abc", "Hello world");
        assert!(terms.is_empty());
    }

    #[test]
    fn test_extract_matched_terms_skips_bare_operators() {
        let terms = extract_matched_terms("- rust *", "Rust is fun");
        assert_eq!(terms, vec!["rust".to_string()]);
    }

    #[test]
    fn test_generate_highlights() {
        let highlights = generate_highlights("test", "This is a test string for testing");
        assert!(!highlights.is_empty());
        assert!(highlights[0].contains("test"));
    }

    #[test]
    fn test_generate_highlights_no_match() {
        let highlights = generate_highlights("xyz", "Hello world");
        assert!(highlights.is_empty());
    }

    #[test]
    fn test_generate_highlights_multibyte_context() {
        // The 30-byte window starts in the middle of a two-byte 'é'.
        let content = format!("{} test", "é".repeat(20));
        let highlights = generate_highlights("test", &content);
        assert_eq!(highlights.len(), 1);
        assert!(highlights[0].starts_with("..."));
        assert!(highlights[0].ends_with("test"));
    }

    #[test]
    fn test_generate_highlights_lowercase_length_change() {
        // 'İ' (2 bytes) lowercases to "i\u{307}" (3 bytes), shifting offsets in the lowercased copy.
        let content = format!("{} needle", "\u{130}".repeat(40));
        let highlights = generate_highlights("needle", &content);
        assert_eq!(highlights.len(), 1);
        assert!(highlights[0].ends_with("needle"));
    }
}