use neco_textview::{LineIndex, TextRange};
use regex::Regex;
use std::fmt;
/// Describes what to search for and how to interpret the pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SearchQuery {
    /// Text to search for. Matched literally unless `is_regex` is set.
    pub pattern: String,
    /// When `true`, `pattern` is compiled as `regex` crate syntax; when
    /// `false`, it is escaped so every character matches literally.
    pub is_regex: bool,
    /// When `false`, matching ignores case (an inline `(?i)` flag is applied).
    pub case_sensitive: bool,
    /// Restrict matches to whole words; see `build_regex` for the exact
    /// boundary rules.
    pub whole_word: bool,
}
/// One occurrence of a query inside a searched text.
///
/// Stores the matched byte span together with the line/column of the match
/// start, as reported by the `LineIndex` passed to the search function.
#[derive(Debug, Clone)]
pub struct SearchMatch {
    /// Byte offsets `[start, end)` of the match within the searched text.
    range: TextRange,
    /// Line of the match start (0-based in the module's tests).
    line: u32,
    /// Column of the match start, in whatever unit `LineIndex` reports
    /// (NOTE(review): unit not visible here — confirm bytes vs. chars).
    column: u32,
}

impl SearchMatch {
    /// Byte range covered by this match.
    pub fn range(&self) -> &TextRange {
        let range = &self.range;
        range
    }

    /// Line on which the match starts.
    pub fn line(&self) -> u32 {
        let line = self.line;
        line
    }

    /// Column at which the match starts.
    pub fn column(&self) -> u32 {
        let column = self.column;
        column
    }
}
/// Failure modes of the search and replace functions in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SearchError {
    /// The query failed to compile; carries the regex engine's error text.
    InvalidRegex(String),
}

impl fmt::Display for SearchError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidRegex(detail) => {
                f.write_str("invalid regex: ")?;
                f.write_str(detail)
            }
        }
    }
}

impl std::error::Error for SearchError {}
/// Compiles `query` into a [`Regex`].
///
/// * Plain-text patterns go through [`regex::escape`] so metacharacters match
///   literally. Regex patterns are used verbatim.
/// * `whole_word` wraps the pattern in a non-capturing group guarded by `\b`.
///   `\b` only holds between a word and a non-word character. For plain-text
///   patterns, the assertion is therefore emitted only on a side whose edge
///   character is a word character (alphanumeric or `_`). Otherwise a query
///   such as `foo()` could never match when followed by a space. Empty patterns
///   and regex patterns cannot be inspected this way and keep `\b` on both
///   sides.
/// * `case_sensitive == false` prepends the inline `(?i)` flag, which applies
///   to the whole expression, including the boundaries.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] with the engine's message if the
/// final pattern fails to compile.
fn build_regex(query: &SearchQuery) -> Result<Regex, SearchError> {
    let mut pat = if query.is_regex {
        query.pattern.clone()
    } else {
        regex::escape(&query.pattern)
    };
    if query.whole_word {
        // Approximation of the regex engine's `\w` class. NOTE(review): Unicode
        // marks/connector punctuation count as `\w` for `\b` but not here.
        let is_word = |c: char| c.is_alphanumeric() || c == '_';
        let (needs_left, needs_right) = if query.is_regex {
            (true, true)
        } else {
            // Empty pattern: neither predicate fires, so both boundaries stay.
            (
                !query.pattern.starts_with(|c: char| !is_word(c)),
                !query.pattern.ends_with(|c: char| !is_word(c)),
            )
        };
        let left = if needs_left { r"\b" } else { "" };
        let right = if needs_right { r"\b" } else { "" };
        // The non-capturing group keeps alternations like `a|b` inside the
        // boundaries.
        pat = format!("{left}(?:{pat}){right}");
    }
    if !query.case_sensitive {
        pat = format!("(?i){pat}");
    }
    Regex::new(&pat).map_err(|e| SearchError::InvalidRegex(e.to_string()))
}
/// Turns a `[start, end)` byte span reported by the regex engine into a
/// [`SearchMatch`], resolving the line/column of `start` via `line_index`.
///
/// Always returns `Ok`; the `Result` keeps call sites uniform with `?`.
///
/// # Panics
///
/// Panics if `start > end`, or if `line_index.offset_to_position` rejects
/// `start`. That would presumably mean `line_index` was built from a different
/// text than `text`.
fn match_from_offsets(
    text: &str,
    line_index: &LineIndex,
    start: usize,
    end: usize,
) -> Result<SearchMatch, SearchError> {
    let span = TextRange::new(start, end).expect("match offsets must satisfy start <= end");
    let position = line_index
        .offset_to_position(text, start)
        .expect("regex match start must be a valid offset");
    let (line, column) = (position.line(), position.column());
    Ok(SearchMatch {
        range: span,
        line,
        column,
    })
}
/// Returns every non-overlapping match of `query` in `text`, in order.
///
/// Matches are produced left to right by the regex engine. An empty pattern
/// matches at every position, including the end of the text.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] if the query fails to compile.
pub fn find_all(
    text: &str,
    line_index: &LineIndex,
    query: &SearchQuery,
) -> Result<Vec<SearchMatch>, SearchError> {
    let compiled = build_regex(query)?;
    compiled
        .find_iter(text)
        .map(|hit| match_from_offsets(text, line_index, hit.start(), hit.end()))
        .collect()
}
/// Finds the first match of `query` that starts at or after the byte offset
/// `from_offset`.
///
/// Text before `from_offset` still serves as context for assertions like
/// `\b`, so a whole-word search started mid-word skips the rest of that word.
/// Returns `Ok(None)` when there is no such match or when `from_offset` lies
/// past the end of `text`.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] if the query fails to compile. This is
/// checked before `from_offset` is validated.
pub fn find_next(
    text: &str,
    line_index: &LineIndex,
    query: &SearchQuery,
    from_offset: usize,
) -> Result<Option<SearchMatch>, SearchError> {
    let compiled = build_regex(query)?;
    let in_bounds = from_offset <= text.len();
    if !in_bounds {
        return Ok(None);
    }
    compiled
        .find_at(text, from_offset)
        .map(|hit| match_from_offsets(text, line_index, hit.start(), hit.end()))
        .transpose()
}
pub fn find_previous(
text: &str,
line_index: &LineIndex,
query: &SearchQuery,
to_offset: usize,
) -> Result<Option<SearchMatch>, SearchError> {
let all = find_all(text, line_index, query)?;
Ok(all.into_iter().rfind(|m| m.range().end() <= to_offset))
}
/// Replaces every match of `query` in `text` and returns the new text together
/// with the number of replacements made.
///
/// `replacement` is expanded with the regex crate's `$` syntax (`$1`,
/// `${name}`, `$$` for a literal `$`). This applies to plain-text queries as
/// well as regex queries.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] if the query fails to compile.
pub fn replace_all(
    text: &str,
    query: &SearchQuery,
    replacement: &str,
) -> Result<(String, usize), SearchError> {
    let compiled = build_regex(query)?;
    let mut output = String::with_capacity(text.len());
    let mut copied_up_to = 0;
    let mut replaced = 0;
    // One pass: copy the gap before each match, then its expansion.
    for caps in compiled.captures_iter(text) {
        let whole = caps
            .get(0)
            .expect("regex captures must include the full match");
        output.push_str(&text[copied_up_to..whole.start()]);
        caps.expand(replacement, &mut output);
        copied_up_to = whole.end();
        replaced += 1;
    }
    output.push_str(&text[copied_up_to..]);
    Ok((output, replaced))
}
/// Computes, without modifying `text`, the edits that [`replace_all`] would
/// apply.
///
/// Each element pairs the byte range of a match with its expanded replacement
/// (`$` group syntax, as in [`replace_all`]). Elements are in left-to-right
/// order and their ranges do not overlap.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] if the query fails to compile.
pub fn replace_all_ranges(
    text: &str,
    query: &SearchQuery,
    replacement: &str,
) -> Result<Vec<(TextRange, String)>, SearchError> {
    let compiled = build_regex(query)?;
    let edits: Vec<(TextRange, String)> = compiled
        .captures_iter(text)
        .map(|caps| {
            let whole = caps
                .get(0)
                .expect("regex captures must include the full match");
            let mut expanded = String::new();
            caps.expand(replacement, &mut expanded);
            let span = TextRange::new(whole.start(), whole.end())
                .expect("match offsets must satisfy start <= end");
            (span, expanded)
        })
        .collect();
    Ok(edits)
}
/// Replaces the first match of `query` starting at or after `from_offset`.
///
/// Returns the rewritten text together with the match that was replaced. The
/// match is located in, and positioned against, the original `text`. As with
/// [`find_next`], text before `from_offset` still serves as regex context.
/// `replacement` is expanded with `$` group syntax. Returns `Ok(None)` when
/// nothing matches or when `from_offset` lies past the end of `text`.
///
/// # Errors
///
/// Returns [`SearchError::InvalidRegex`] if the query fails to compile.
pub fn replace_next(
    text: &str,
    line_index: &LineIndex,
    query: &SearchQuery,
    replacement: &str,
    from_offset: usize,
) -> Result<Option<(String, SearchMatch)>, SearchError> {
    let compiled = build_regex(query)?;
    if from_offset > text.len() {
        return Ok(None);
    }
    let caps = match compiled.captures_at(text, from_offset) {
        Some(found) => found,
        None => return Ok(None),
    };
    let whole = caps
        .get(0)
        .expect("regex captures must include the full match");
    let located = match_from_offsets(text, line_index, whole.start(), whole.end())?;
    let mut expanded = String::new();
    caps.expand(replacement, &mut expanded);
    let rewritten = [
        &text[..whole.start()],
        expanded.as_str(),
        &text[whole.end()..],
    ]
    .concat();
    Ok(Some((rewritten, located)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Case-sensitive, literal (non-regex) query without whole-word matching.
    fn plain(pattern: &str) -> SearchQuery {
        SearchQuery {
            pattern: pattern.to_string(),
            is_regex: false,
            case_sensitive: true,
            whole_word: false,
        }
    }

    /// Case-sensitive regex query without whole-word matching.
    fn regex_query(pattern: &str) -> SearchQuery {
        SearchQuery {
            is_regex: true,
            ..plain(pattern)
        }
    }

    #[test]
    fn build_regex_invalid_returns_error() {
        let err = build_regex(&regex_query("[invalid")).unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn build_regex_plain_escapes_special_chars() {
        let re = build_regex(&plain("a.b")).expect("should compile");
        assert!(re.is_match("a.b"));
        assert!(!re.is_match("axb"));
    }

    #[test]
    fn build_regex_case_insensitive() {
        let q = SearchQuery {
            case_sensitive: false,
            ..plain("hello")
        };
        let re = build_regex(&q).expect("should compile");
        assert!(re.is_match("HELLO"));
        assert!(re.is_match("Hello"));
    }

    #[test]
    fn build_regex_whole_word() {
        let q = SearchQuery {
            whole_word: true,
            ..plain("foo")
        };
        let re = build_regex(&q).expect("should compile");
        assert!(re.is_match("foo"));
        assert!(!re.is_match("foobar"));
        assert!(!re.is_match("barfoo"));
    }

    #[test]
    fn build_regex_whole_word_case_insensitive() {
        let q = SearchQuery {
            case_sensitive: false,
            whole_word: true,
            ..plain("foo")
        };
        let re = build_regex(&q).expect("should compile");
        assert!(re.is_match("FOO bar"));
        assert!(!re.is_match("FOOBAR"));
    }

    #[test]
    fn build_regex_whole_word_keeps_alternation_inside_boundaries() {
        let q = SearchQuery {
            whole_word: true,
            ..regex_query("a|b")
        };
        let re = build_regex(&q).expect("should compile");
        assert!(re.is_match("x a y"));
        assert!(!re.is_match("xa"));
        assert!(!re.is_match("bx"));
    }

    #[test]
    fn find_all_basic_match() {
        let text = "hello world";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("world")).unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].range().start(), 6);
        assert_eq!(matches[0].range().end(), 11);
        assert_eq!(matches[0].line(), 0);
        assert_eq!(matches[0].column(), 6);
    }

    #[test]
    fn find_all_multiple_matches() {
        let text = "abcabc";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("abc")).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].range().start(), 0);
        assert_eq!(matches[1].range().start(), 3);
    }

    #[test]
    fn find_all_no_match() {
        let text = "hello";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("xyz")).unwrap();
        assert!(matches.is_empty());
    }

    #[test]
    fn find_all_case_insensitive() {
        let text = "Hello HELLO hello";
        let li = LineIndex::new(text);
        let q = SearchQuery {
            case_sensitive: false,
            ..plain("hello")
        };
        let matches = find_all(text, &li, &q).unwrap();
        assert_eq!(matches.len(), 3);
    }

    #[test]
    fn find_all_whole_word() {
        let text = "foo foobar barfoo foo";
        let li = LineIndex::new(text);
        let q = SearchQuery {
            whole_word: true,
            ..plain("foo")
        };
        let matches = find_all(text, &li, &q).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].range().start(), 0);
        assert_eq!(matches[1].range().start(), 18);
    }

    #[test]
    fn find_all_regex_with_groups() {
        let text = "2024-01-15 and 2025-12-31";
        let li = LineIndex::new(text);
        let q = regex_query(r"\d{4}-\d{2}-\d{2}");
        let matches = find_all(text, &li, &q).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].range().start(), 0);
        assert_eq!(matches[0].range().end(), 10);
        assert_eq!(matches[1].range().start(), 15);
    }

    #[test]
    fn find_all_multiline() {
        let text = "line1\nfoo\nline3\nfoo";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("foo")).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].line(), 1);
        assert_eq!(matches[0].column(), 0);
        assert_eq!(matches[1].line(), 3);
        assert_eq!(matches[1].column(), 0);
    }

    #[test]
    fn find_all_empty_text() {
        let text = "";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("x")).unwrap();
        assert!(matches.is_empty());
    }

    #[test]
    fn find_all_empty_pattern() {
        let text = "abc";
        let li = LineIndex::new(text);
        let matches = find_all(text, &li, &plain("")).unwrap();
        // An empty pattern matches at every position 0..=3 of "abc".
        assert_eq!(matches.len(), 4);
    }

    #[test]
    fn find_next_from_zero() {
        let text = "abc def abc";
        let li = LineIndex::new(text);
        let m = find_next(text, &li, &plain("abc"), 0).unwrap().unwrap();
        assert_eq!(m.range().start(), 0);
    }

    #[test]
    fn find_next_from_middle() {
        let text = "abc def abc";
        let li = LineIndex::new(text);
        let m = find_next(text, &li, &plain("abc"), 1).unwrap().unwrap();
        assert_eq!(m.range().start(), 8);
    }

    #[test]
    fn find_next_past_last_match() {
        let text = "abc def abc";
        let li = LineIndex::new(text);
        let m = find_next(text, &li, &plain("abc"), 9).unwrap();
        assert!(m.is_none());
    }

    #[test]
    fn find_next_from_beyond_text() {
        let text = "abc";
        let li = LineIndex::new(text);
        let m = find_next(text, &li, &plain("abc"), 100).unwrap();
        assert!(m.is_none());
    }

    #[test]
    fn find_next_whole_word_mid_word_offset() {
        let text = "foobar baz bar";
        let li = LineIndex::new(text);
        let q = SearchQuery {
            whole_word: true,
            ..plain("bar")
        };
        let m = find_next(text, &li, &q, 3).unwrap().unwrap();
        assert_eq!(m.range().start(), 11);
    }

    #[test]
    fn find_previous_returns_last_match_ending_at_offset() {
        let text = "abc abc abc";
        let li = LineIndex::new(text);
        let m = find_previous(text, &li, &plain("abc"), 7).unwrap().unwrap();
        assert_eq!(m.range().start(), 4);
        assert_eq!(m.range().end(), 7);
    }

    #[test]
    fn find_previous_returns_none_before_first_match() {
        let text = "abc abc";
        let li = LineIndex::new(text);
        let result = find_previous(text, &li, &plain("abc"), 2).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn find_previous_ignores_match_crossing_offset() {
        let text = "abc abc";
        let li = LineIndex::new(text);
        let m = find_previous(text, &li, &plain("abc"), 6).unwrap().unwrap();
        assert_eq!(m.range().start(), 0);
    }

    #[test]
    fn find_previous_beyond_text_returns_last_match() {
        let text = "abc abc";
        let li = LineIndex::new(text);
        let m = find_previous(text, &li, &plain("abc"), 100)
            .unwrap()
            .unwrap();
        assert_eq!(m.range().start(), 4);
    }

    #[test]
    fn replace_all_basic() {
        let text = "hello world";
        let (new_text, count) = replace_all(text, &plain("world"), "rust").unwrap();
        assert_eq!(new_text, "hello rust");
        assert_eq!(count, 1);
    }

    #[test]
    fn replace_all_multiple() {
        let text = "aaa";
        let (new_text, count) = replace_all(text, &plain("a"), "bb").unwrap();
        assert_eq!(new_text, "bbbbbb");
        assert_eq!(count, 3);
    }

    #[test]
    fn replace_all_no_match() {
        let text = "hello";
        let (new_text, count) = replace_all(text, &plain("xyz"), "abc").unwrap();
        assert_eq!(new_text, "hello");
        assert_eq!(count, 0);
    }

    #[test]
    fn replace_all_empty_text() {
        let text = "";
        let (new_text, count) = replace_all(text, &plain("x"), "y").unwrap();
        assert_eq!(new_text, "");
        assert_eq!(count, 0);
    }

    #[test]
    fn replace_all_regex_backreference() {
        let text = "foo123bar456";
        let (new_text, count) = replace_all(text, &regex_query(r"(\d+)"), "[$1]").unwrap();
        assert_eq!(new_text, "foo[123]bar[456]");
        assert_eq!(count, 2);
    }

    #[test]
    fn replace_all_count_agrees_with_ranges() {
        let text = "a1b22c333";
        let q = regex_query(r"\d+");
        let (new_text, count) = replace_all(text, &q, "#").unwrap();
        assert_eq!(new_text, "a#b#c#");
        assert_eq!(count, replace_all_ranges(text, &q, "#").unwrap().len());
    }

    #[test]
    fn replace_all_ranges_expands_regex_backreference() {
        let text = "foo123bar456";
        let replacements = replace_all_ranges(text, &regex_query(r"(\d+)"), "[$1]").unwrap();
        assert_eq!(replacements.len(), 2);
        assert_eq!(replacements[0].0.start(), 3);
        assert_eq!(replacements[0].0.end(), 6);
        assert_eq!(replacements[0].1, "[123]");
        assert_eq!(replacements[1].1, "[456]");
    }

    #[test]
    fn replace_all_ranges_no_match() {
        let replacements = replace_all_ranges("hello", &plain("xyz"), "abc").unwrap();
        assert!(replacements.is_empty());
    }

    #[test]
    fn replace_next_basic() {
        let text = "abc def abc";
        let li = LineIndex::new(text);
        let (new_text, m) = replace_next(text, &li, &plain("abc"), "XYZ", 0)
            .unwrap()
            .unwrap();
        assert_eq!(new_text, "XYZ def abc");
        assert_eq!(m.range().start(), 0);
        assert_eq!(m.range().end(), 3);
    }

    #[test]
    fn replace_next_from_offset() {
        let text = "abc def abc";
        let li = LineIndex::new(text);
        let (new_text, m) = replace_next(text, &li, &plain("abc"), "XYZ", 1)
            .unwrap()
            .unwrap();
        assert_eq!(new_text, "abc def XYZ");
        assert_eq!(m.range().start(), 8);
    }

    #[test]
    fn replace_next_regex_backreference() {
        let text = "foo123bar456";
        let li = LineIndex::new(text);
        let (new_text, m) = replace_next(text, &li, &regex_query(r"(\d+)"), "[$1]", 0)
            .unwrap()
            .unwrap();
        assert_eq!(new_text, "foo[123]bar456");
        assert_eq!(m.range().start(), 3);
    }

    #[test]
    fn replace_next_no_match() {
        let text = "hello";
        let li = LineIndex::new(text);
        let result = replace_next(text, &li, &plain("xyz"), "abc", 0).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn replace_next_from_beyond_text() {
        let text = "abc";
        let li = LineIndex::new(text);
        let result = replace_next(text, &li, &plain("abc"), "x", 100).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn search_error_display() {
        let err = SearchError::InvalidRegex("bad pattern".to_string());
        assert_eq!(err.to_string(), "invalid regex: bad pattern");
    }

    #[test]
    fn find_all_invalid_regex() {
        let text = "hello";
        let li = LineIndex::new(text);
        let err = find_all(text, &li, &regex_query("[")).unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn find_next_invalid_regex() {
        let text = "hello";
        let li = LineIndex::new(text);
        let err = find_next(text, &li, &regex_query("["), 0).unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn find_previous_invalid_regex() {
        let text = "hello";
        let li = LineIndex::new(text);
        let err = find_previous(text, &li, &regex_query("["), 5).unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn replace_all_invalid_regex() {
        let err = replace_all("x", &regex_query("["), "y").unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn replace_all_ranges_invalid_regex() {
        let err = replace_all_ranges("x", &regex_query("["), "y").unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }

    #[test]
    fn replace_next_invalid_regex() {
        let text = "hello";
        let li = LineIndex::new(text);
        let err = replace_next(text, &li, &regex_query("["), "y", 0).unwrap_err();
        assert!(matches!(err, SearchError::InvalidRegex(_)));
    }
}