1use ast_grep_config::RuleCore;
2use ast_grep_core::pinned::{NodeData, PinnedNodeData};
3use ast_grep_core::{AstGrep, NodeMatch};
4use ignore::{WalkBuilder, WalkParallel, WalkState};
5use napi::anyhow::{anyhow, Context, Result as Ret};
6use napi::bindgen_prelude::*;
7use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode};
8use napi::{JsNumber, Task};
9use napi_derive::napi;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU32, Ordering};
12
13use crate::doc::{JsDoc, NapiConfig};
14use crate::napi_lang::{build_files, LangOption, NapiLang};
15use crate::sg_node::{SgNode, SgRoot};
16
17pub struct ParseAsync {
18 pub src: String,
19 pub lang: NapiLang,
20}
21
22impl Task for ParseAsync {
23 type Output = SgRoot;
24 type JsValue = SgRoot;
25
26 fn compute(&mut self) -> Result<Self::Output> {
27 let src = std::mem::take(&mut self.src);
28 let doc = JsDoc::new(src, self.lang);
29 Ok(SgRoot(AstGrep::doc(doc), "anonymous".into()))
30 }
31 fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
32 Ok(output)
33 }
34}
35
36type Entry = std::result::Result<ignore::DirEntry, ignore::Error>;
37
38pub struct IterateFiles<D> {
39 walk: WalkParallel,
40 lang_option: LangOption,
41 tsfn: D,
42 producer: fn(&D, Entry, &LangOption) -> Ret<bool>,
43}
44
45impl<T: 'static + Send + Sync> Task for IterateFiles<T> {
46 type Output = u32;
47 type JsValue = JsNumber;
48
49 fn compute(&mut self) -> Result<Self::Output> {
50 let tsfn = &self.tsfn;
51 let file_count = AtomicU32::new(0);
52 let producer = self.producer;
53 let walker = std::mem::replace(&mut self.walk, WalkBuilder::new(".").build_parallel());
54 walker.run(|| {
55 let file_count = &file_count;
56 let lang_option = &self.lang_option;
57 Box::new(move |entry| match producer(tsfn, entry, lang_option) {
58 Ok(succeed) => {
59 if succeed {
60 file_count.fetch_add(1, Ordering::AcqRel);
62 }
63 WalkState::Continue
64 }
65 Err(_) => WalkState::Skip,
66 })
67 });
68 Ok(file_count.load(Ordering::Acquire))
69 }
70 fn resolve(&mut self, env: Env, output: Self::Output) -> Result<Self::JsValue> {
71 env.create_uint32(output)
72 }
73}
74
/// Queue size passed to `create_threadsafe_function` for the JS callbacks created in this file.
const THREAD_FUNC_QUEUE_SIZE: usize = 1_000;
79
80type ParseFiles = IterateFiles<ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>>;
81
82#[napi(object)]
83pub struct FileOption {
84 pub paths: Vec<String>,
85 pub language_globs: HashMap<String, Vec<String>>,
86}
87
88#[napi]
89pub fn parse_files(
90 paths: Either<Vec<String>, FileOption>,
91 callback: JsFunction,
92) -> Result<AsyncTask<ParseFiles>> {
93 let tsfn: ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled> =
94 callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| Ok(vec![ctx.value]))?;
95 let (paths, globs) = match paths {
96 Either::A(v) => (v, HashMap::new()),
97 Either::B(FileOption {
98 paths,
99 language_globs,
100 }) => (paths, NapiLang::lang_globs(language_globs)),
101 };
102 let walk = build_files(paths, &globs)?;
103 Ok(AsyncTask::new(ParseFiles {
104 walk,
105 tsfn,
106 lang_option: LangOption::infer(&globs),
107 producer: call_sg_root,
108 }))
109}
110
111fn call_sg_root(
113 tsfn: &ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>,
114 entry: std::result::Result<ignore::DirEntry, ignore::Error>,
115 lang_option: &LangOption,
116) -> Ret<bool> {
117 let entry = entry?;
118 if !entry
119 .file_type()
120 .context("could not use stdin as file")?
121 .is_file()
122 {
123 return Ok(false);
124 }
125 let (root, path) = get_root(entry, lang_option)?;
126 let sg = SgRoot(root, path);
127 tsfn.call(Ok(sg), ThreadsafeFunctionCallMode::Blocking);
128 Ok(true)
129}
130
131fn get_root(entry: ignore::DirEntry, lang_option: &LangOption) -> Ret<(AstGrep<JsDoc>, String)> {
132 let path = entry.into_path();
133 let file_content = std::fs::read_to_string(&path)?;
134 let lang = lang_option
135 .get_lang(&path)
136 .context(anyhow!("file not recognized"))?;
137 let doc = JsDoc::new(file_content, lang);
138 Ok((AstGrep::doc(doc), path.to_string_lossy().into()))
139}
140
141pub type FindInFiles = IterateFiles<(
142 ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
143 RuleCore<NapiLang>,
144)>;
145
146pub struct PinnedNodes(
147 PinnedNodeData<JsDoc, Vec<NodeMatch<'static, JsDoc>>>,
148 String,
149);
150unsafe impl Send for PinnedNodes {}
151unsafe impl Sync for PinnedNodes {}
152
153#[napi(object)]
154pub struct FindConfig {
155 pub paths: Vec<String>,
157 pub matcher: NapiConfig,
159 pub language_globs: Option<Vec<String>>,
163}
164
165pub fn find_in_files_impl(
166 lang: NapiLang,
167 config: FindConfig,
168 callback: JsFunction,
169) -> Result<AsyncTask<FindInFiles>> {
170 let tsfn = callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| {
171 from_pinned_data(ctx.value, ctx.env)
172 })?;
173 let FindConfig {
174 paths,
175 matcher,
176 language_globs,
177 } = config;
178 let rule = matcher.parse_with(lang)?;
179 let walk = lang.find_files(paths, language_globs)?;
180 Ok(AsyncTask::new(FindInFiles {
181 walk,
182 tsfn: (tsfn, rule),
183 lang_option: LangOption::Specified(lang),
184 producer: call_sg_node,
185 }))
186}
187
188fn from_pinned_data(pinned: PinnedNodes, env: napi::Env) -> Result<Vec<Vec<SgNode>>> {
190 let (root, nodes) = pinned.0.into_raw();
191 let sg_root = SgRoot(AstGrep { inner: root }, pinned.1);
192 let reference = SgRoot::into_reference(sg_root, env)?;
193 let mut v = vec![];
194 for mut node in nodes {
195 let root_ref = reference.clone(env)?;
196 let sg_node = SgNode {
197 inner: root_ref.share_with(env, |root| {
198 let r = &root.0.inner;
199 node.visit_nodes(|n| unsafe { r.readopt(n) });
200 Ok(node)
201 })?,
202 };
203 v.push(sg_node);
204 }
205 Ok(vec![v])
206}
207
208fn call_sg_node(
209 (tsfn, rule): &(
210 ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
211 RuleCore<NapiLang>,
212 ),
213 entry: std::result::Result<ignore::DirEntry, ignore::Error>,
214 lang_option: &LangOption,
215) -> Ret<bool> {
216 let entry = entry?;
217 if !entry
218 .file_type()
219 .context("could not use stdin as file")?
220 .is_file()
221 {
222 return Ok(false);
223 }
224 let (root, path) = get_root(entry, lang_option)?;
225 let mut pinned = PinnedNodeData::new(root.inner, |r| r.root().find_all(rule).collect());
226 let hits: &Vec<_> = pinned.get_data();
227 if hits.is_empty() {
228 return Ok(false);
229 }
230 let pinned = PinnedNodes(pinned, path);
231 tsfn.call(Ok(pinned), ThreadsafeFunctionCallMode::Blocking);
232 Ok(true)
233}