1mod types;
2mod utils;
3
4use rayon::prelude::*;
5
6use std::{borrow::Cow, env, error::Error, sync::Arc};
7
8use colored::Colorize;
9pub use types::{Args, Config, FileResult, Pattern};
10pub use utils::{print_each_result, print_results, process_file};
11use walkdir::{DirEntry, WalkDir};
12
13use crate::utils::print_single_result;
14
/// Returns how many entries `matches` holds.
///
/// Every element of the slice is counted exactly once, regardless of its
/// contents, so an empty slice yields `0`.
pub fn count_lines_with_matches(matches: &[(usize, String)]) -> usize {
    matches.len()
}
18
/// A yes/no test over raw bytes.
///
/// Implementors decide whether a given byte slice satisfies their query.
pub trait Matcher {
    /// Returns `true` when `slice` satisfies this matcher's query.
    fn matches_query(&self, slice: &[u8]) -> bool;
}
22
23impl Matcher for Pattern {
24 fn matches_query(&self, slice: &[u8]) -> bool {
25 match self {
26 Pattern::Regex(re) => re.is_match(slice),
27 Pattern::Literal { pattern, .. } => pattern.is_match(slice),
28
29 Pattern::MultipleLiteral { pattern, .. } => pattern.is_match(slice),
30 }
31 }
32}
33
34pub fn highlight_match(line: &[u8], pat: &Pattern) -> String {
35 let mut highlighted_string = String::from("");
36 let mut last = 0;
37
38 match pat {
39 Pattern::Literal { pattern, .. } | Pattern::MultipleLiteral { pattern, .. } => {
40 let matches: Vec<(usize, usize)> = pattern
41 .find_iter(line)
42 .map(|m| (m.start(), m.end()))
43 .collect();
44
45 for (start, end) in matches {
46 highlighted_string.push_str(&String::from_utf8_lossy(&line[last..start]));
47
48 highlighted_string.push_str(
49 &String::from_utf8_lossy(&line[start..end])
50 .red()
51 .underline()
52 .bold()
53 .to_string(),
54 );
55
56 last = end;
57 }
58 }
59 Pattern::Regex(re) => {
60 let matches: Vec<(usize, usize)> =
61 re.find_iter(line).map(|x| (x.start(), x.end())).collect();
62
63 for (start, end) in matches {
64 highlighted_string.push_str(&String::from_utf8_lossy(&line[last..start]));
65 highlighted_string.push_str(
66 &String::from_utf8_lossy(&line[start..end])
67 .red()
68 .underline()
69 .bold()
70 .to_string(),
71 );
72 last = end;
73 }
74 }
75 }
76 if last < line.len() {
77 highlighted_string.push_str(&String::from_utf8_lossy(&line[last..]));
78 }
79 highlighted_string
80}
81
82pub fn process_lines<'a>(
83 query: &Pattern,
84 contents: &'a [u8],
85 invert: bool,
86 highlight: bool,
87) -> Vec<(usize, Cow<'a, str>)> {
88 contents
89 .split(|&b| b == b'\n')
90 .enumerate()
91 .filter_map(|(i, line)| {
92 let matched = query.matches_query(line);
93 if matched ^ invert {
94 if highlight {
95 Some((i + 1, Cow::Owned(highlight_match(line, query))))
96 } else {
97 Some((
98 i + 1,
99 Cow::Owned(String::from_utf8_lossy(line).into_owned()),
100 ))
101 }
102 } else {
103 None
104 }
105 })
106 .collect()
107}
108
109pub fn run(config: Config) -> Result<(), Box<dyn Error>> {
110 let current = env::current_dir()?;
111
112 let config = Arc::new(config);
113
114 if config.recursive {
115 let entries: Vec<DirEntry> = WalkDir::new(¤t)
116 .into_iter()
117 .filter_map(|e: Result<DirEntry, walkdir::Error>| e.ok())
118 .filter(|e| e.file_type().is_file())
119 .collect();
120
121 let results: Vec<FileResult> = entries
122 .par_iter()
123 .filter_map(|e| process_file(e.clone(), Arc::clone(&config)).ok())
124 .collect();
125
126 print_results(results, config)?;
127 } else {
128 let entry = match WalkDir::new(&config.file_path)
129 .max_depth(1)
130 .into_iter()
131 .next()
132 {
133 Some(Ok(e)) => e,
134 Some(Err(e)) => return Err(Box::new(e)),
135 None => return Err("Entry was not found in current directory".into()),
136 };
137
138 let config = Arc::clone(&config);
139
140 let result = process_file(entry, Arc::clone(&config))?;
141
142 print_single_result(result, Arc::clone(&config))?;
143 }
144
145 Ok(())
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
    use colored::Colorize;
    use std::borrow::Cow;

    /// Builds a case-sensitive `Pattern::Literal` for a single needle.
    fn literal(needle: &str) -> Pattern {
        Pattern::Literal {
            pattern: AhoCorasick::new(&[needle]).unwrap(),
            case_insensitive: false,
        }
    }

    #[test]
    fn literal_match() {
        let pattern = literal("foo");

        assert!(pattern.matches_query(b"foo"));
        assert!(!pattern.matches_query(b"Foo"));
    }

    #[test]
    fn multiple_literal_match() {
        let pattern = Pattern::MultipleLiteral {
            pattern: AhoCorasick::new(&["foo", "bar"]).unwrap(),
            case_insensitive: false,
        };

        assert!(pattern.matches_query(b"foo"));
        assert!(pattern.matches_query(b"bar"));
        assert!(!pattern.matches_query(b"baz"));
    }

    #[test]
    fn highlight_literal() {
        let pattern = literal("foo");

        let result = highlight_match(b"foo bar", &pattern);

        // Only the matched text is styled; the rest of the line is untouched.
        let expected = format!("{} bar", "foo".red().underline().bold());
        assert_eq!(result, expected);
    }

    #[test]
    fn process_lines_basic() {
        let pattern = literal("foo");
        let text = "foo\nbar\nfoo bar";

        let result: Vec<(usize, Cow<str>)> =
            process_lines(&pattern, text.as_bytes(), false, false);

        // Lines 1 and 3 contain "foo"; numbering is 1-based.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, 1);
        assert_eq!(result[1].0, 3);
    }

    #[test]
    fn invert_lines() {
        let pattern = literal("foo");
        let text = "foo\nbar\nbaz";

        let result = process_lines(&pattern, text.as_bytes(), true, false);

        // Inversion keeps exactly the lines without "foo".
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].1, "bar");
        assert_eq!(result[1].1, "baz");
    }

    #[test]
    fn ignore_case_literal() {
        let matcher = AhoCorasickBuilder::new()
            .ascii_case_insensitive(true)
            .build(&["foo"])
            .unwrap();
        let pattern = Pattern::Literal {
            pattern: matcher,
            case_insensitive: true,
        };

        assert!(pattern.matches_query(b"FOO"));
    }
}