// code_search/lib.rs

use colored::Colorize;
use indicatif::{ProgressBar, ProgressFinish, ProgressStyle};
use lang::{
    CQuery, CSharpQuery, CppQuery, GoQuery, JavaQuery, JavascriptQuery, PythonQuery, RustQuery,
    SymbolQuery,
};
use regex::Regex;
use rustyline::{
    hint::{Hint, Hinter},
    Completer, Context, Helper, Highlighter, Validator,
};
use std::{
    collections::HashSet,
    ffi::OsStr,
    fs::{self, read_dir, File},
    io::{self, BufRead, BufReader},
    path::{Path, PathBuf},
    rc::Rc,
};
use tree_sitter::{Node, Parser, Query, QueryCursor};

mod lang;
23
24#[derive(Completer, Helper, Highlighter, Validator)]
25pub struct CodeHinter {
26    pub hints: HashSet<CommandHint>,
27}
28
/// One hint: the text shown to the user and how much of it is inserted when the
/// hint is accepted.
#[derive(Hash, Debug, PartialEq, Eq)]
pub struct CommandHint {
    /// Full text rendered as the hint.
    display: String,
    /// Length in bytes of the prefix of `display` that is returned as the
    /// completion; `0` means the hint offers no completion.
    complete_up_to: usize,
}
34
35impl Hint for CommandHint {
36    fn completion(&self) -> Option<&str> {
37        if self.complete_up_to > 0 {
38            Some(&self.display[..self.to_owned().complete_up_to])
39        } else {
40            None
41        }
42    }
43
44    fn display(&self) -> &str {
45        &self.display
46    }
47}
48
49impl CommandHint {
50    fn new(text: &str, complete_up_to: &str) -> Self {
51        assert!(text.starts_with(complete_up_to));
52        Self {
53            display: text.into(),
54            complete_up_to: complete_up_to.len(),
55        }
56    }
57
58    fn suffix(&self, strip_chars: usize) -> Self {
59        Self {
60            display: self.display[strip_chars..].to_owned(),
61            complete_up_to: self.complete_up_to.saturating_sub(strip_chars),
62        }
63    }
64}
65
66impl Hinter for CodeHinter {
67    type Hint = CommandHint;
68
69    fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option<CommandHint> {
70        if line.is_empty() || pos < line.len() {
71            return None;
72        }
73
74        self.hints
75            .iter()
76            .filter_map(|hint| {
77                // expect hint after word complete, like redis cli, add condition:
78                // line.ends_with(" ")
79                if hint.display.starts_with(line) {
80                    Some(hint.suffix(pos))
81                } else {
82                    None
83                }
84            })
85            .next()
86    }
87}
88
89pub fn diy_hints() -> HashSet<CommandHint> {
90    let mut set = HashSet::new();
91    set.insert(CommandHint::new("help", "help"));
92    set.insert(CommandHint::new(
93        format!("outline {}", "代码文件路径".bright_black()).as_str(),
94        "outline ",
95    ));
96    set.insert(CommandHint::new(
97        format!("search {}", "path search_key".bright_black()).as_str(),
98        "search ",
99    ));
100    set.insert(CommandHint::new("quit()", "quit()"));
101    set
102}
103
/// Returns `true` when `extention` (a file extension without the leading dot)
/// is one of the file types this crate collects.
///
/// The comparison is exact and case-sensitive.
fn valid_language_file(extention: &str) -> bool {
    // A constant slice avoids allocating a fresh `Vec` on every call.
    const VALID_EXTENSIONS: &[&str] = &[
        "rs",
        "js",
        "ts",
        "java",
        "py",
        "go",
        "c",
        "cpp",
        "md",
        "txt",
        "html",
        "css",
        "cs",
        "kt",
        "swift",
        "php",
        "rb",
        "sh",
        "sql",
        "vb",
        "lua",
        "hs",
        "scala",
        "erl",
        "m",
        "r",
        "h",
        "hpp",
        "toml",
        "yaml",
        "yml",
        "properties",
    ];
    VALID_EXTENSIONS.contains(&extention)
}
141/*
142* 递归目录
143*/
144pub fn recursion_dir(root_path: &Path, pathes: &mut Vec<PathBuf>, filter: &str) {
145    if root_path.is_dir() {
146        for entry in read_dir(root_path).expect("Error read Dir") {
147            let dir_entry = entry.expect("Error");
148            let path_buf = dir_entry.path();
149
150            recursion_dir(path_buf.as_path(), pathes, filter);
151        }
152    } else if root_path.is_file() {
153        if root_path.extension().is_some() {
154            let extension = root_path
155                .extension()
156                .unwrap_or(OsStr::new(""))
157                .to_str()
158                .unwrap();
159            if (filter.is_empty() || filter == extension) && valid_language_file(extension) {
160                pathes.push(root_path.to_path_buf());
161            }
162        }
163    }
164}
165
166pub fn get_symbol_query(extention: &str) -> Box<dyn SymbolQuery> {
167    match extention {
168        "rs" => Box::new(RustQuery),
169        "java" => Box::new(JavaQuery),
170        "py" => Box::new(PythonQuery),
171        "c" => Box::new(CQuery),
172        "cs" => Box::new(CSharpQuery),
173        "cpp" => Box::new(CppQuery),
174        "js" => Box::new(JavascriptQuery),
175        "go" => Box::new(GoQuery),
176        _ => Box::new(RustQuery),
177    }
178}
179/**
180* 获取源码中的所有符号
181*
182*
183*/
184pub fn get_all_symbols(
185    code: &String,
186    search_key: &str,
187    symbol_query: Box<dyn SymbolQuery>,
188) -> Vec<(usize, String)> {
189    let mut parser = Parser::new();
190    parser
191        .set_language(&symbol_query.get_lang())
192        .expect("Error load Rust grammer");
193    let tree = parser.parse(code.as_str(), None).unwrap();
194
195    let mut query_cursor = QueryCursor::new();
196    let mut filed_vec = vec![];
197    for sq in symbol_query.get_queries() {
198        let query = Query::new(
199            &symbol_query.get_lang(),
200            sq.replace(":?", search_key).as_str(),
201        )
202        .unwrap();
203        let captures = query_cursor.captures(&query, tree.root_node(), code.as_bytes());
204        for (m, capture_index) in captures {
205            let capture = m.captures[capture_index];
206            let node = capture.node;
207            let text = node.utf8_text(code.as_bytes()).unwrap();
208            filed_vec.push((node.start_position().row + 1, text.to_string()));
209        }
210    }
211    return filed_vec;
212}
213/**
214* 打印大纲
215*/
216pub fn print_outline(code: &String, symbol_query: Box<dyn SymbolQuery>) {
217    let mut parser = Parser::new();
218    parser
219        .set_language(&symbol_query.get_lang())
220        .expect("Error load Rust grammer");
221    let tree = parser.parse(code.as_str(), None).unwrap();
222    let root_node = tree.root_node();
223    recursion_outline(root_node, code, 0, &symbol_query);
224}
225
226pub fn recursion_outline(
227    node: Node,
228    code: &String,
229    indent: usize,
230    symbol_query: &Box<dyn SymbolQuery>,
231) {
232    let mut temp_indent = indent;
233    if symbol_query.is_key_node(&node) {
234        print!("{}", " ".repeat(indent));
235        let output = symbol_query.get_definition(code, &node);
236        println!("{}", output);
237        temp_indent += 2;
238    }
239
240    for child in node.children(&mut node.walk()) {
241        recursion_outline(child, code, temp_indent, symbol_query)
242    }
243}
244
245pub fn find_text_in_file(
246    filename: &str,
247    text: &str,
248    reg: Option<Rc<Regex>>,
249) -> Result<Vec<(usize, String)>, std::io::Error> {
250    let file = File::open(filename)?;
251    let reader = BufReader::new(file);
252    let mut found_lines = Vec::new();
253
254    for (line_number, line) in reader.lines().enumerate() {
255        let line = line.unwrap_or("".to_string());
256        if reg.is_some() {
257            let reg = reg.clone().unwrap();
258            if reg.captures(line.as_str()).is_some() {
259                found_lines.push((line_number + 1, line));
260            }
261        } else if line.contains(text) {
262            found_lines.push((line_number + 1, line));
263        }
264    }
265    Ok(found_lines)
266}
267
/// Returns the canonical absolute form of `path` as a `String`.
///
/// An empty string is returned when the path cannot be resolved (for example
/// because it does not exist) or when the resolved path is not valid UTF-8.
pub fn get_absolute_path(path: &Path) -> String {
    // Canonicalising directly, without a prior `exists()` check, avoids a panic
    // if the path disappears between the check and the call.
    fs::canonicalize(path)
        .ok()
        .and_then(|absolute| absolute.to_str().map(str::to_owned))
        .unwrap_or_default()
}
276
/// One entry of the symbol index: where a symbol was found and its text.
#[derive(Debug, PartialEq, Eq)]
pub struct CodeIndex {
    /// Path of the file that contains the symbol.
    pub path: String,
    /// Line number of the symbol within the file.
    pub line: usize,
    /// Source text recorded for the symbol.
    pub line_code: String,
}
283
284/**
285* 构建索引
286*/
287pub fn build_index(project_path: &Path) -> Vec<CodeIndex> {
288    let mut index_list = vec![];
289    let mut pathes = vec![];
290    // 获取项目中的文件
291    recursion_dir(project_path, &mut pathes, "");
292    let files = pathes.len();
293    let pb = ProgressBar::new(files as u64);
294    pb.set_style(
295        ProgressStyle::with_template("{spinner:.green} {pos}/{len} [{elapsed_precise}] {msg}")
296            .unwrap(),
297    );
298    let mut progress = 1;
299    if pathes.len() > 0 {
300        for path in pathes {
301            let path_extension = path.extension().unwrap().to_str().unwrap();
302            let path_str = get_absolute_path(&path);
303
304            let code = fs::read_to_string(Path::new(path_str.as_str())).unwrap_or("".to_string());
305            let result = get_all_symbols(&code, ".*", get_symbol_query(path_extension));
306            if result.len() > 0 {
307                result
308                    .iter()
309                    .map(|item| CodeIndex {
310                        path: path_str.to_string(),
311                        line: item.0,
312                        line_code: item.1.clone(),
313                    })
314                    .for_each(|item| {
315                        pb.set_message(item.path.clone());
316                        pb.set_position(progress);
317                        index_list.push(item);
318                    });
319            }
320            progress += 1;
321        }
322    }
323    pb.with_finish(ProgressFinish::AndClear);
324    return index_list;
325}