ast_grep_napi/
find_files.rs

1use ast_grep_config::RuleCore;
2use ast_grep_core::pinned::{NodeData, PinnedNodeData};
3use ast_grep_core::{AstGrep, NodeMatch};
4use ignore::{WalkBuilder, WalkParallel, WalkState};
5use napi::anyhow::{anyhow, Context, Result as Ret};
6use napi::bindgen_prelude::*;
7use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode};
8use napi::{JsNumber, Task};
9use napi_derive::napi;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU32, Ordering};
12
13use crate::doc::{JsDoc, NapiConfig};
14use crate::napi_lang::{build_files, LangOption, NapiLang};
15use crate::sg_node::{SgNode, SgRoot};
16
/// Task that parses a source string into an `SgRoot`.
pub struct ParseAsync {
  /// Source text to parse. It is moved out of the task (left empty) when
  /// `compute` runs.
  pub src: String,
  /// Language used to build the document for `src`.
  pub lang: NapiLang,
}
21
22impl Task for ParseAsync {
23  type Output = SgRoot;
24  type JsValue = SgRoot;
25
26  fn compute(&mut self) -> Result<Self::Output> {
27    let src = std::mem::take(&mut self.src);
28    let doc = JsDoc::new(src, self.lang);
29    Ok(SgRoot(AstGrep::doc(doc), "anonymous".into()))
30  }
31  fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
32    Ok(output)
33  }
34}
35
/// One item yielded by the `ignore` walker: a directory entry or the error
/// hit while producing it.
type Entry = std::result::Result<ignore::DirEntry, ignore::Error>;
37
/// Task that walks files in parallel and hands every walked entry to a
/// `producer` function together with a caller-supplied payload `D`.
pub struct IterateFiles<D> {
  /// Parallel walker; it is taken out of the task when `compute` runs.
  walk: WalkParallel,
  /// Language selection passed by reference to every `producer` call.
  lang_option: LangOption,
  /// Payload passed by reference to every `producer` call. In this file it is
  /// a thread-safe function, optionally paired with a rule.
  tsfn: D,
  /// Called once per walked entry. `Ok(true)` makes `compute` count the entry,
  /// `Ok(false)` leaves the count untouched, and `Err` makes the walker skip.
  producer: fn(&D, Entry, &LangOption) -> Ret<bool>,
}
44
45impl<T: 'static + Send + Sync> Task for IterateFiles<T> {
46  type Output = u32;
47  type JsValue = JsNumber;
48
49  fn compute(&mut self) -> Result<Self::Output> {
50    let tsfn = &self.tsfn;
51    let file_count = AtomicU32::new(0);
52    let producer = self.producer;
53    let walker = std::mem::replace(&mut self.walk, WalkBuilder::new(".").build_parallel());
54    walker.run(|| {
55      let file_count = &file_count;
56      let lang_option = &self.lang_option;
57      Box::new(move |entry| match producer(tsfn, entry, lang_option) {
58        Ok(succeed) => {
59          if succeed {
60            // file is sent to JS thread, increment file count
61            file_count.fetch_add(1, Ordering::AcqRel);
62          }
63          WalkState::Continue
64        }
65        Err(_) => WalkState::Skip,
66      })
67    });
68    Ok(file_count.load(Ordering::Acquire))
69  }
70  fn resolve(&mut self, env: Env, output: Self::Output) -> Result<Self::JsValue> {
71    env.create_uint32(output)
72  }
73}
74
// Queue size passed to `create_threadsafe_function` for the JS callbacks.
// See https://github.com/ast-grep/ast-grep/issues/206
// Node.js has a limit of 1000 on its synchronous iteration count:
// https://github.com/nodejs/node/blob/8ba54e50496a6a5c21d93133df60a9f7cb6c46ce/src/node_api.cc#L336
const THREAD_FUNC_QUEUE_SIZE: usize = 1000;
79
/// File-iteration task whose payload is a thread-safe function receiving one
/// `SgRoot` per call.
type ParseFiles = IterateFiles<ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>>;
81
/// Options object accepted by `parse_files` in place of a plain list of paths.
#[napi(object)]
pub struct FileOption {
  /// Paths handed to the file walker.
  pub paths: Vec<String>,
  /// Lists of globs keyed by a string.
  /// NOTE(review): the key is presumably a language name; confirm against
  /// `NapiLang::lang_globs`.
  pub language_globs: HashMap<String, Vec<String>>,
}
87
88#[napi]
89pub fn parse_files(
90  paths: Either<Vec<String>, FileOption>,
91  callback: JsFunction,
92) -> Result<AsyncTask<ParseFiles>> {
93  let tsfn: ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled> =
94    callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| Ok(vec![ctx.value]))?;
95  let (paths, globs) = match paths {
96    Either::A(v) => (v, HashMap::new()),
97    Either::B(FileOption {
98      paths,
99      language_globs,
100    }) => (paths, NapiLang::lang_globs(language_globs)),
101  };
102  let walk = build_files(paths, &globs)?;
103  Ok(AsyncTask::new(ParseFiles {
104    walk,
105    tsfn,
106    lang_option: LangOption::infer(&globs),
107    producer: call_sg_root,
108  }))
109}
110
111// returns if the entry is a file and sent to JavaScript queue
112fn call_sg_root(
113  tsfn: &ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>,
114  entry: std::result::Result<ignore::DirEntry, ignore::Error>,
115  lang_option: &LangOption,
116) -> Ret<bool> {
117  let entry = entry?;
118  if !entry
119    .file_type()
120    .context("could not use stdin as file")?
121    .is_file()
122  {
123    return Ok(false);
124  }
125  let (root, path) = get_root(entry, lang_option)?;
126  let sg = SgRoot(root, path);
127  tsfn.call(Ok(sg), ThreadsafeFunctionCallMode::Blocking);
128  Ok(true)
129}
130
131fn get_root(entry: ignore::DirEntry, lang_option: &LangOption) -> Ret<(AstGrep<JsDoc>, String)> {
132  let path = entry.into_path();
133  let file_content = std::fs::read_to_string(&path)?;
134  let lang = lang_option
135    .get_lang(&path)
136    .context(anyhow!("file not recognized"))?;
137  let doc = JsDoc::new(file_content, lang);
138  Ok((AstGrep::doc(doc), path.to_string_lossy().into()))
139}
140
/// File-iteration task whose payload pairs the thread-safe function receiving
/// `PinnedNodes` with the rule used to search each file.
pub type FindInFiles = IterateFiles<(
  ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
  RuleCore<NapiLang>,
)>;
145
/// The matches found in one file, kept together with the pinned document they
/// refer to (field 0), plus the path of that file (field 1).
pub struct PinnedNodes(
  PinnedNodeData<JsDoc, Vec<NodeMatch<'static, JsDoc>>>,
  String,
);
// NOTE(review): these impls assert that `PinnedNodes` may be moved to and
// shared with other threads. Presumably this holds because the matches are
// only read again after being re-adopted in `from_pinned_data`. Confirm
// against the `PinnedNodeData` contract.
unsafe impl Send for PinnedNodes {}
unsafe impl Sync for PinnedNodes {}
152
/// Configuration object for finding matching nodes in files.
#[napi(object)]
pub struct FindConfig {
  /// The file paths to search recursively for files.
  pub paths: Vec<String>,
  /// A Rule object describing which nodes will match.
  pub matcher: NapiConfig,
  /// A list of glob patterns for files that should be treated as the specified language,
  /// e.g. ['*.vue', '*.svelte'] for html.findFiles, or ['*.ts'] for tsx.findFiles.
  /// It is slightly different from https://ast-grep.github.io/reference/sgconfig.html#languageglobs
  pub language_globs: Option<Vec<String>>,
}
164
165pub fn find_in_files_impl(
166  lang: NapiLang,
167  config: FindConfig,
168  callback: JsFunction,
169) -> Result<AsyncTask<FindInFiles>> {
170  let tsfn = callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| {
171    from_pinned_data(ctx.value, ctx.env)
172  })?;
173  let FindConfig {
174    paths,
175    matcher,
176    language_globs,
177  } = config;
178  let rule = matcher.parse_with(lang)?;
179  let walk = lang.find_files(paths, language_globs)?;
180  Ok(AsyncTask::new(FindInFiles {
181    walk,
182    tsfn: (tsfn, rule),
183    lang_option: LangOption::Specified(lang),
184    producer: call_sg_node,
185  }))
186}
187
188// TODO: optimize
189fn from_pinned_data(pinned: PinnedNodes, env: napi::Env) -> Result<Vec<Vec<SgNode>>> {
190  let (root, nodes) = pinned.0.into_raw();
191  let sg_root = SgRoot(AstGrep { inner: root }, pinned.1);
192  let reference = SgRoot::into_reference(sg_root, env)?;
193  let mut v = vec![];
194  for mut node in nodes {
195    let root_ref = reference.clone(env)?;
196    let sg_node = SgNode {
197      inner: root_ref.share_with(env, |root| {
198        let r = &root.0.inner;
199        node.visit_nodes(|n| unsafe { r.readopt(n) });
200        Ok(node)
201      })?,
202    };
203    v.push(sg_node);
204  }
205  Ok(vec![v])
206}
207
208fn call_sg_node(
209  (tsfn, rule): &(
210    ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
211    RuleCore<NapiLang>,
212  ),
213  entry: std::result::Result<ignore::DirEntry, ignore::Error>,
214  lang_option: &LangOption,
215) -> Ret<bool> {
216  let entry = entry?;
217  if !entry
218    .file_type()
219    .context("could not use stdin as file")?
220    .is_file()
221  {
222    return Ok(false);
223  }
224  let (root, path) = get_root(entry, lang_option)?;
225  let mut pinned = PinnedNodeData::new(root.inner, |r| r.root().find_all(rule).collect());
226  let hits: &Vec<_> = pinned.get_data();
227  if hits.is_empty() {
228    return Ok(false);
229  }
230  let pinned = PinnedNodes(pinned, path);
231  tsfn.call(Ok(pinned), ThreadsafeFunctionCallMode::Blocking);
232  Ok(true)
233}