rem_extract/extract/
extraction.rs

1use std::{
2    fs,
3    io::{
4        self,
5        ErrorKind
6    },
7    path::PathBuf
8};
9
10use ra_ap_ide_db::EditionedFileId;
11use ra_ap_project_model::{
12    CargoConfig,
13    ProjectWorkspace,
14    ProjectManifest,
15};
16
17use ra_ap_ide::{
18    Analysis, AssistConfig, AssistResolveStrategy, TextSize
19};
20
21use ra_ap_syntax::{
22    algo, ast::HasName, AstNode, SourceFile
23};
24
25use ra_ap_hir::Semantics;
26
27use ra_ap_ide_assists::Assist;
28
29use ra_ap_vfs::AbsPathBuf;
30
31use crate::{
32    error::ExtractionError,
33    extract::extraction_utils::{
34        apply_edits, apply_extract_function, check_braces, check_comment, convert_to_abs_path_buf, filter_extract_function_assist, fixup_controlflow, generate_frange, generate_frange_from_fileid, get_assists, get_cargo_config, get_cargo_toml, get_manifest_dir, load_project_manifest, load_project_workspace, load_workspace_data, rename_function, trim_range
35    },
36};
37
38use rem_interface::metrics as mx;
39
40#[derive(Debug, PartialEq, Clone)]
41pub struct ExtractionInput {
42    pub file_path: String,
43    pub new_fn_name: String,
44    pub start_idx: u32,
45    pub end_idx: u32,
46}
47
48impl ExtractionInput {
49    pub fn new(
50        file_path: &str,
51        new_fn_name: &str,
52        start_idx: u32,
53        end_idx: u32,
54    ) -> Self { ExtractionInput {
55            file_path: file_path.to_string(),
56            new_fn_name: new_fn_name.to_string(),
57            start_idx,
58            end_idx,
59        }
60    }
61
62    #[allow(dead_code)]
63    pub fn new_absolute(
64        file_path: &str,
65        new_fn_name: &str,
66        start_idx: u32,
67        end_idx: u32,
68    ) -> Self { ExtractionInput {
69            file_path: convert_to_abs_path_buf(file_path).unwrap().as_str().to_string(),
70            new_fn_name: new_fn_name.to_string(),
71            start_idx,
72            end_idx,
73        }
74    }
75}
76
77// ========================================
78// Checks for the validity of the input
79// ========================================
80
81// Check if the file exists and is readable
82fn check_file_exists(file_path: &str) -> Result<(), ExtractionError> {
83    if fs::metadata(file_path).is_err() {
84        return Err(ExtractionError::Io(io::Error::new(
85            ErrorKind::NotFound,
86            format!("File not found: {}", file_path),
87        )));
88    }
89    Ok(())
90}
91
92// Check if the idx pair is valid
93fn check_idx(input: &ExtractionInput) -> Result<(), ExtractionError> {
94    if input.start_idx == input.end_idx {
95        return Err(ExtractionError::SameIdx);
96    } else if input.start_idx > input.end_idx {
97        return Err(ExtractionError::InvalidIdxPair);
98    }
99    if input.start_idx == 0 {
100        return Err(ExtractionError::InvalidStartIdx);
101    }
102    if input.end_idx == 0 {
103        return Err(ExtractionError::InvalidEndIdx);
104    }
105    Ok(())
106}
107
108fn verify_input(input: &ExtractionInput) -> Result<(), ExtractionError> {
109    // Execute each input validation step one by one
110    check_file_exists(&input.file_path)?;
111    check_idx(input)?;
112
113    Ok(())
114}
115
116pub fn extract_method_file(input: ExtractionInput) -> Result<(String, String), ExtractionError> {
117    mx::mark("Extraction Start");
118
119    // Extract the struct information
120    let input_path: &str = &input.file_path;
121    let callee_name: &str = &input.new_fn_name;
122    let start_idx: u32 = input.start_idx;
123    let end_idx: u32 = input.end_idx;
124
125    let text: String = fs::read_to_string(&input.file_path).unwrap();
126
127    // Verify the input data
128    verify_input(&input)?;
129
130    mx::mark("Load the analysis");
131
132    let (analysis,file_id) = Analysis::from_single_file(text.clone());
133
134    mx::mark("Analysis Loaded");
135
136    let assist_config: AssistConfig = super::extraction_utils::generate_assist_config();
137    let diagnostics_config = super::extraction_utils::generate_diagnostics_config();
138    let resolve: AssistResolveStrategy = super::extraction_utils::generate_resolve_strategy();
139    let range: (u32, u32) = (start_idx, end_idx);
140
141    let frange = generate_frange_from_fileid(file_id, range);
142
143    mx::mark("Get the assists");
144
145    let assists: Vec<Assist> = analysis.assists_with_fixes(
146        &assist_config,
147        &diagnostics_config,
148        resolve,
149        frange
150    ).unwrap();
151
152    mx::mark("Filter for extract function assist");
153
154    let assist: Assist = filter_extract_function_assist( assists )?;
155
156    mx::mark("Apply extract function assist");
157
158    let src_change = assist.source_change
159        .as_ref()
160        .unwrap()
161        .clone();
162
163    let (text_edit, maybe_snippet_edit) =
164        src_change.get_source_and_snippet_edit(
165            file_id,
166        ).unwrap();
167
168    let edited_text: String = apply_edits(
169        text.clone(),
170        text_edit.clone(),
171        maybe_snippet_edit.clone(),
172    );
173
174    let renamed_text: String = rename_function(
175        edited_text,
176        "fun_name",
177        callee_name,
178    );
179
180    // Ensure that the output file imports std::ops::ControlFlow if it uses it
181    let fixed_cf_text: String = fixup_controlflow( renamed_text );
182
183    mx::mark("Extraction End");
184
185    let parent_method: String = parent_method_from_text(
186        text,
187        &range,
188    );
189
190    Ok( (fixed_cf_text, parent_method) )
191}
192
193// ========================================
194// Performs the method extraction
195// ========================================
196
197/// Function to extract the code segment based on cursor positions
198/// If successful, returns the `String` of the output code, followed by a
199/// `String` of the caller method
200pub fn extract_method(input: ExtractionInput) -> Result<(String, String), ExtractionError> {
201
202    mx::mark("Extraction Start");
203
204    // Extract the struct information
205    let input_path: &str = &input.file_path;
206    let callee_name: &str = &input.new_fn_name;
207    let start_idx: u32 = input.start_idx;
208    let end_idx: u32 = input.end_idx;
209
210    // Convert the input path to an `AbsPathBuf`
211    let input_abs_path: AbsPathBuf = convert_to_abs_path_buf(input_path).unwrap();
212
213    // Verify the input data
214    verify_input(&input)?;
215
216    let manifest_dir: PathBuf = get_manifest_dir(
217        &PathBuf::from(input_abs_path.as_str())
218    )?;
219    let cargo_toml: AbsPathBuf = get_cargo_toml( &manifest_dir );
220    // println!("Cargo.toml {:?}", cargo_toml);
221
222    mx::mark("Load the project workspace");
223
224    let project_manifest: ProjectManifest = load_project_manifest( &cargo_toml );
225    // println!("Project Manifest {:?}", project_manifest);
226
227    // MARKER: Load the cargo config
228    mx::mark("Load the cargo config");
229
230    let cargo_config: CargoConfig = get_cargo_config( &project_manifest );
231    // println!("Cargo Config {:?}", cargo_config);
232
233    // MARKER: Load the project workspace
234    mx::mark("Load the project workspace");
235
236    let workspace: ProjectWorkspace = load_project_workspace( &project_manifest, &cargo_config );
237    // println!("Project Workspace {:?}", workspace);
238
239    // MARKER: Load the analysis database and VFS
240    mx::mark("Load the analysis database and VFS");
241
242    let (db, vfs) = load_workspace_data(workspace, &cargo_config);
243
244    // Parse the cursor positions into the range
245    let range_: (u32, u32) = (
246        start_idx,
247        end_idx,
248    );
249
250    // MARKER: Database Loaded
251    mx::mark("Database Loaded");
252
253    // Before we go too far, lets do few more quick checks now that we have the
254    // analysis
255    // 1. Check if the function to extract is not just a comment
256    // 2. Check if the function to extract has matching braces
257    // 3. Convert the range to a trimmed range.
258    let sema: Semantics<'_, ra_ap_ide::RootDatabase> = Semantics::new( &db );
259    let frange_: ra_ap_hir::FileRangeWrapper<ra_ap_vfs::FileId> = generate_frange( &input_abs_path, &vfs, range_.clone() );
260    let edition: EditionedFileId = EditionedFileId::current_edition( frange_.file_id );
261    let source_file: SourceFile = sema.parse( edition );
262    let range: (u32, u32) = trim_range( &source_file, &range_ );
263    check_comment( &source_file, &range )?;
264    check_braces( &source_file, &range )?;
265
266    // MARKER: Run the analysis
267    mx::mark("Run the analysis");
268
269    // let analysis_host: AnalysisHost = AnalysisHost::with_database( db );
270    // let analysis: Analysis = run_analysis( analysis_host );
271
272    // MARKER: Get the assists and filter for extract function assist
273    mx::mark("Get the assists");
274
275    let assists: Vec<Assist> = get_assists( &db, &vfs, &input_abs_path, range );
276
277    // mx::mark("1");
278    // let assists_2: Vec<Assist> = get_assists(&analysis, &vfs, &input_abs_path, range);
279
280    mx::mark("Filter for extract function assist");
281
282    let assist: Assist = filter_extract_function_assist( assists )?;
283
284    mx::mark("Apply extract function assist");
285
286    let modified_code: String = apply_extract_function(
287        &assist,
288        &input_abs_path,
289        &vfs,
290        &callee_name,
291    )?;
292
293    mx::mark("Get parent method");
294
295    let parent_method: String = parent_method(
296        &source_file,
297        range,
298    )?;
299
300    // MARKER: Extraction End
301    mx::mark("Extraction End");
302
303    Ok( (modified_code, parent_method) )
304}
305
306/// Gets the caller method, based on the input code and the cursor positions
307/// If successful, returns the `String` of the caller method
308/// If unsuccessful, returns an `ExtractionError`
309pub fn parent_method(
310    source_file: &SourceFile,
311    range: (u32, u32),
312) -> Result<String, ExtractionError> {
313    let start: TextSize = TextSize::new(range.0);
314
315    // We want the last function that occurs before the start of the range
316    let node: Option<ra_ap_syntax::ast::Fn> = algo::find_node_at_offset::<ra_ap_syntax::ast::Fn>(
317        source_file.syntax(),
318        start,
319    );
320
321    let fn_name: String = match node {
322        Some(n) => n.name().map_or("".to_string(), |name| name.text().to_string()),
323        None => "".to_string(),
324    };
325
326    if fn_name.is_empty() {
327        return Err(ExtractionError::ParentMethodNotFound);
328    }
329
330    Ok( fn_name.trim().to_string() )
331}
332
333use proc_macro2::Span;
334use syn::spanned::Spanned;
335use syn::visit::Visit;
336
337/// Return the name of the function/method that contains the given [start, end)
338/// byte range in `text`. Returns empty string if none found.
339///
340/// NOTE: Requires `proc-macro2` with the "span-locations" feature enabled.
341pub fn parent_method_from_text(text: String, range: &(u32, u32)) -> String {
342    let Ok(file) = syn::parse_file(&text) else {
343        return String::new();
344    };
345
346    let line_starts = compute_line_starts(&text);
347    let selection = (range.0 as usize, range.1 as usize);
348
349    let mut visitor = FnCollector {
350        text: &text,
351        line_starts: &line_starts,
352        fns: Vec::new(),
353    };
354    visitor.visit_file(&file);
355
356    // Find the *innermost* function that contains the selection.
357    let mut best: Option<(&str, usize, usize)> = None;
358    for (name, start, end) in visitor.fns {
359        if contains((start, end), selection) {
360            match best {
361                None => best = Some((name, start, end)),
362                Some((_, b_start, b_end)) => {
363                    if (end - start) < (b_end - b_start) {
364                        best = Some((name, start, end));
365                    }
366                }
367            }
368        }
369    }
370
371    best.map(|(name, _, _)| name.to_string()).unwrap_or_default()
372}
373
374/// Collect function spans (name, start_byte, end_byte).
375struct FnCollector<'a> {
376    text: &'a str,
377    line_starts: &'a [usize],
378    fns: Vec<(&'a str, usize, usize)>,
379}
380
381impl<'a, 'ast> Visit<'ast> for FnCollector<'a> {
382    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
383        // Free function
384        let name = node.sig.ident.to_string();
385        let (start, end) = span_to_offsets(node.block.span(), self.line_starts, self.text);
386        self.fns.push((self.leak(name), start, end));
387        // Recurse into the function in case there are nested modules, etc.
388        syn::visit::visit_item_fn(self, node);
389    }
390
391    fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
392        for item in &node.items {
393            if let syn::ImplItem::Fn(m) = item {
394                let name = m.sig.ident.to_string();
395                let (start, end) = span_to_offsets(m.block.span(), self.line_starts, self.text);
396                self.fns.push((self.leak(name), start, end));
397            }
398        }
399        syn::visit::visit_item_impl(self, node);
400    }
401
402    fn visit_item_trait(&mut self, node: &'ast syn::ItemTrait) {
403        for item in &node.items {
404            if let syn::TraitItem::Fn(f) = item {
405                if let Some(block) = &f.default {
406                    let name = f.sig.ident.to_string();
407                    let (start, end) = span_to_offsets(block.span(), self.line_starts, self.text);
408                    self.fns.push((self.leak(name), start, end));
409                }
410            }
411        }
412        syn::visit::visit_item_trait(self, node);
413    }
414
415    fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
416        // For inline modules (`mod m { ... }`) the content is present; recurse.
417        if let Some((_brace, items)) = &node.content {
418            for it in items {
419                self.visit_item(it);
420            }
421        }
422        // For `mod m;` (file modules) we can't see into another file from this text.
423    }
424}
425
426impl<'a> FnCollector<'a> {
427    /// Leak a `String` into a `'static` str so we can store &str in self.fns without lifetimes hell.
428    /// This is fine for short-lived analysis in a tool; if you prefer, store `String` instead.
429    fn leak(&self, s: String) -> &'static str {
430        Box::leak(s.into_boxed_str())
431    }
432}
433
/// Returns the byte offset at which every line of `text` begins.
///
/// Element `i` is the start of line `i + 1` (line numbers are 1-based), so the
/// first element is always `0`; each `\n` opens a new line right after it.
fn compute_line_starts(text: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(text.match_indices('\n').map(|(newline_at, _)| newline_at + 1))
        .collect()
}
444
445/// Convert a Span to byte start/end offsets within `text`.
446///
447/// This relies on proc_macro2's "span-locations" to get (line, column).
448fn span_to_offsets(span: Span, line_starts: &[usize], text: &str) -> (usize, usize) {
449    let start = span.start();
450    let end = span.end();
451
452    // Line numbers are 1-based; columns are (effectively) byte offsets within the line.
453    let start_off = lc_to_offset(start.line, start.column, line_starts, text);
454    let end_off = lc_to_offset(end.line, end.column, line_starts, text);
455
456    (start_off.min(text.len()), end_off.min(text.len()))
457}
458
/// Converts a 1-based `line` and 0-based `column` into a byte offset in `text`.
///
/// `column` is counted in UTF-8 *characters* (as proc_macro2 reports it), so
/// it is walked character by character rather than added as a byte count;
/// otherwise any multi-byte character earlier on the line would shift the
/// result.
///
/// Returns `text.len()` when `line` is 0, lies past the last line, or
/// `line_starts` does not describe valid boundaries of `text`. A `column`
/// beyond the end of the line is clamped to the end of that line.
fn lc_to_offset(line: usize, column: usize, line_starts: &[usize], text: &str) -> usize {
    if line == 0 || line > line_starts.len() {
        return text.len();
    }

    let base = line_starts[line - 1];
    // The line runs up to the start of the next one (or to the end of the text).
    let line_end = line_starts.get(line).copied().unwrap_or(text.len());
    let Some(line_text) = text.get(base..line_end) else {
        return text.len();
    };

    let byte_column = line_text
        .char_indices()
        .nth(column)
        .map_or(line_text.len(), |(byte_idx, _)| byte_idx);

    base + byte_column
}
466
/// Reports whether the `inner` range lies entirely within the `outer` range.
///
/// Both ranges are `(start, end)` pairs and the bounds may coincide. An
/// `inner` range whose start exceeds its end is never contained.
fn contains(outer: (usize, usize), inner: (usize, usize)) -> bool {
    let inner_is_ordered = inner.0 <= inner.1;
    let starts_within = outer.0 <= inner.0;
    let ends_within = inner.1 <= outer.1;

    starts_within && ends_within && inner_is_ordered
}