foundry_compilers/
flatten.rs

use crate::{
    compilers::{Compiler, ParsedSource},
    filter::MaybeSolData,
    resolver::parse::SolData,
    ArtifactOutput, CompilerSettings, Graph, Project, ProjectPathsConfig,
};
use foundry_compilers_artifacts::{
    ast::{visitor::Visitor, *},
    output_selection::OutputSelection,
    solc::ExternalInlineAssemblyReference,
    sources::{Source, Sources},
    ContractDefinitionPart, SourceUnit, SourceUnitPart,
};
use foundry_compilers_core::{
    error::{Result, SolcError},
    utils,
};
use itertools::Itertools;
use std::{
    collections::{HashMap, HashSet},
    hash::Hash,
    path::{Path, PathBuf},
};
use visitor::Walk;

/// Alternative to `SourceLocation` that additionally includes the path of the file.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ItemLocation {
    path: PathBuf,
    start: usize,
    end: usize,
}

impl ItemLocation {
    fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option<Self> {
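        // E.g. a `SourceLocation { start: 10, length: 5 }` yields
        // `ItemLocation { start: 10, end: 15 }`.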
        let start = src.start?;
        let end = start + src.length?;

        Some(Self { path, start, end })
    }

    fn length(&self) -> usize {
        self.end - self.start
    }
}

/// Visitor exploring the AST and collecting all references to declarations via `Identifier` and
/// `IdentifierPath` nodes.
///
/// It also collects `MemberAccess` parts, so for an `X.Y` expression the loc and AST ID are
/// saved for `Y` only.
///
/// That way, even for a long `MemberAccess` expression (`a.b.c.d`), the first member (`a`) is
/// collected as either `Identifier` or `IdentifierPath`, and all subsequent parts (`b`, `c`,
/// `d`) are collected as `MemberAccess` parts.
struct ReferencesCollector {
    path: PathBuf,
    references: HashMap<isize, HashSet<ItemLocation>>,
}

impl ReferencesCollector {
    fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) {
        if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) {
            self.references.entry(id).or_default().insert(loc);
        }
    }
}

impl Visitor for ReferencesCollector {
    fn visit_identifier(&mut self, identifier: &Identifier) {
        if let Some(id) = identifier.referenced_declaration {
            self.process_referenced_declaration(id, &identifier.src);
        }
    }

    fn visit_identifier_path(&mut self, path: &IdentifierPath) {
        self.process_referenced_declaration(path.referenced_declaration, &path.src);
    }

    fn visit_member_access(&mut self, access: &MemberAccess) {
        if let Some(referenced_declaration) = access.referenced_declaration {
            if let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length) {
                let name_length = access.member_name.len();
                // The accessed member name occupies the last `member_name.len()` bytes of
                // the expression.
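                // E.g. for `foo.bar` with src start 100 and length 7, `bar` spans
                // bytes 104..107.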
                let start = src_start + src_length - name_length;
                let end = start + name_length;

                self.references.entry(referenced_declaration).or_default().insert(ItemLocation {
                    start,
                    end,
                    path: self.path.to_path_buf(),
                });
            }
        }
    }

    fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
        let mut src = reference.src.clone();

        // If a suffix is used in the assembly reference (e.g. `value.slot`), it is included in
        // the src. However, we are only interested in the referenced name, so we strip the
        // `.<suffix>` part.
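        // E.g. for `value.slot` (10 bytes) this subtracts the 5 bytes of `.slot`,
        // keeping only `value`.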
        if let Some(suffix) = &reference.suffix {
            if let Some(len) = src.length.as_mut() {
                let suffix_len = suffix.to_string().len();
                *len -= suffix_len + 1;
            }
        }

        self.process_referenced_declaration(reference.declaration as isize, &src);
    }
}

/// Updates to be applied to the sources.
/// source_path -> (start, end, new_value)
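/// E.g. `{ "src/A.sol" => {(10, 17, "Counter_0".to_string())} }` replaces bytes 10..17 of
/// `src/A.sol` with `Counter_0`.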
type Updates = HashMap<PathBuf, HashSet<(usize, usize, String)>>;

pub struct FlatteningResult<'a> {
    /// Updated sources in the order they should be written to the output file.
    sources: Vec<String>,
    /// Pragmas that should be present in the target file.
    pragmas: Vec<String>,
    /// License identifier that should be present in the target file.
    license: Option<&'a str>,
}

impl<'a> FlatteningResult<'a> {
    fn new(
        flattener: &Flattener,
        mut updates: Updates,
        pragmas: Vec<String>,
        license: Option<&'a str>,
    ) -> Self {
        let mut sources = Vec::new();

        for path in &flattener.ordered_sources {
            let mut content = flattener.sources.get(path).unwrap().content.as_bytes().to_vec();
            let mut offset: isize = 0;
            if let Some(updates) = updates.remove(path) {
                let mut updates = updates.iter().collect::<Vec<_>>();
                updates.sort_by_key(|(start, _, _)| *start);
                for (start, end, new_value) in updates {
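                    // Earlier (sorted) replacements shift all later ranges by `offset`:
                    // e.g. replacing bytes 10..15 with a 2-byte string adds -3 to it.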
                    let start = (*start as isize + offset) as usize;
                    let end = (*end as isize + offset) as usize;

                    content.splice(start..end, new_value.bytes());
                    offset += new_value.len() as isize - (end - start) as isize;
                }
            }
            let content = format!(
                "// {}\n{}",
                path.strip_prefix(&flattener.project_root).unwrap_or(path).display(),
                String::from_utf8(content).unwrap()
            );
            sources.push(content);
        }

        Self { sources, pragmas, license }
    }

    fn get_flattened_target(&self) -> String {
        let mut result = String::new();

        if let Some(license) = &self.license {
            result.push_str(&format!("// {license}\n"));
        }
        for pragma in &self.pragmas {
            result.push_str(&format!("{pragma}\n"));
        }
        for source in &self.sources {
            result.push_str(&format!("\n\n{source}"));
        }

        format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())
    }
}

#[derive(Debug, thiserror::Error)]
pub enum FlattenerError {
    #[error("Failed to compile {0}")]
    Compilation(SolcError),
    #[error(transparent)]
    Other(SolcError),
}

impl<T: Into<SolcError>> From<T> for FlattenerError {
    fn from(err: T) -> Self {
        Self::Other(err.into())
    }
}

/// Context for flattening. Stores all sources and ASTs that are in scope of the flattening
/// target.
pub struct Flattener {
    /// Target file to flatten.
    target: PathBuf,
    /// Sources including only the target and its dependencies (imports of any depth).
    sources: Sources,
    /// Vec of (path, ast) pairs.
    asts: Vec<(PathBuf, SourceUnit)>,
    /// Sources in the order they should be written to the output file.
    ordered_sources: Vec<PathBuf>,
    /// Project root directory.
    project_root: PathBuf,
}

impl Flattener {
    /// Compiles the target file and prepares AST and analysis data for flattening.
    pub fn new<C: Compiler, T: ArtifactOutput<CompilerContract = C::CompilerContract>>(
        mut project: Project<C, T>,
        target: &Path,
    ) -> std::result::Result<Self, FlattenerError>
    where
        C::ParsedSource: MaybeSolData,
    {
        // Configure the project to compile the target file and request only the AST as output.
        project.cached = false;
        project.no_artifacts = true;
        project.settings.update_output_selection(|selection| {
            *selection = OutputSelection::ast_output_selection();
        });

        let output = project.compile_file(target).map_err(FlattenerError::Compilation)?;

        if output.has_compiler_errors() {
            return Err(FlattenerError::Compilation(SolcError::msg(&output)));
        }

        let output = output.compiler_output;

        let sources = Source::read_all_files(vec![target.to_path_buf()])?;
        let graph = Graph::<C::ParsedSource>::resolve_sources(&project.paths, sources)?;

        let ordered_sources = collect_ordered_deps(&target.to_path_buf(), &project.paths, &graph)?;

        #[cfg(windows)]
        let ordered_sources = {
            let mut sources = ordered_sources;
            use path_slash::PathBufExt;
            for p in &mut sources {
                *p = PathBuf::from(p.to_slash_lossy().to_string());
            }
            sources
        };

        let sources = Source::read_all(&ordered_sources)?;

        // Convert all ASTs from artifacts to strongly typed ASTs.
        let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new();
        for (path, ast) in output.sources.0.iter().filter_map(|(path, files)| {
            if let Some(ast) = files.first().and_then(|source| source.source_file.ast.as_ref()) {
                if sources.contains_key(path) {
                    return Some((path, ast));
                }
            }
            None
        }) {
            asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?));
        }

        Ok(Self {
            target: target.into(),
            sources,
            asts,
            ordered_sources,
            project_root: project.root().clone(),
        })
    }

    /// Flattens the target file and returns the result as a string.
    ///
    /// The flattening process includes the following steps:
    /// 1. Find all file-level definitions and rename references to them via aliased or qualified
    ///    imports.
    /// 2. Find all duplicates among file-level definitions and rename them to avoid conflicts.
    /// 3. Remove all imports.
    /// 4. Remove all pragmas except for the ones in the target file.
    /// 5. Remove all license identifiers except for the one in the target file.
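    ///
    /// A minimal usage sketch (the project setup and target path below are illustrative
    /// assumptions, not a fixed API):
    ///
    /// ```ignore
    /// let project = Project::builder().build(Default::default())?;
    /// let flattener = Flattener::new(project, Path::new("src/Counter.sol"))?;
    /// let flattened: String = flattener.flatten();
    /// ```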
    pub fn flatten(&self) -> String {
        let mut updates = Updates::new();

        let top_level_names = self.rename_top_level_definitions(&mut updates);
        self.rename_contract_level_types_references(&top_level_names, &mut updates);
        self.remove_qualified_imports(&mut updates);
        self.update_inheritdocs(&top_level_names, &mut updates);

        self.remove_imports(&mut updates);
        let target_pragmas = self.process_pragmas(&mut updates);
        let target_license = self.process_licenses(&mut updates);

        self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
    }

    fn flatten_result<'a>(
        &'a self,
        updates: Updates,
        target_pragmas: Vec<String>,
        target_license: Option<&'a str>,
    ) -> FlatteningResult<'a> {
        FlatteningResult::new(self, updates, target_pragmas, target_license)
    }

    /// Finds and goes over all references to file-level definitions and updates them to match
    /// the definition name. This is needed for two reasons:
    /// 1. We want to rename all aliased or qualified imports.
    /// 2. We want to find any duplicates and rename them to avoid conflicts.
    ///
    /// If we find more than one declaration with the same name, its name gets changed:
    /// two `Counter` contracts will be renamed to `Counter_0` and `Counter_1`.
    ///
    /// Returns a mapping from top-level declaration ID to its (possibly updated) name.
    fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap<usize, String> {
        let top_level_definitions = self.collect_top_level_definitions();
        let references = self.collect_references();

        let mut top_level_names = HashMap::new();

        for (name, ids) in top_level_definitions {
            let mut definition_name = name.to_string();
            let needs_rename = ids.len() > 1;

            let mut ids = ids.clone().into_iter().collect::<Vec<_>>();
            if needs_rename {
                // `loc.path` is expected to be different for each id because there can't be
                // two top-level declarations with the same name in the same file.
                //
                // Sorting by the index of `loc.path` in the sorted files makes the renaming
                // process deterministic.
                ids.sort_by_key(|(_, loc)| {
                    self.ordered_sources.iter().position(|p| p == &loc.path).unwrap()
                });
            }
            for (i, (id, loc)) in ids.iter().enumerate() {
                if needs_rename {
                    definition_name = format!("{name}_{i}");
                }
                updates.entry(loc.path.clone()).or_default().insert((
                    loc.start,
                    loc.end,
                    definition_name.clone(),
                ));
                if let Some(references) = references.get(&(*id as isize)) {
                    for loc in references {
                        updates.entry(loc.path.clone()).or_default().insert((
                            loc.start,
                            loc.end,
                            definition_name.clone(),
                        ));
                    }
                }

                top_level_names.insert(*id, definition_name.clone());
            }
        }
        top_level_names
    }

    /// This is not very clean, but in most cases an effective enough method to remove qualified
    /// imports from sources.
    ///
    /// Every qualified import part is an `Identifier` with a `referencedDeclaration` field
    /// matching the ID of one of the import directives.
    ///
    /// This approach works by first collecting all IDs of import directives and then looking for
    /// references to them. Once a reference is found, its full length plus one character (the
    /// trailing `.`) is removed from the source. For example, with `import "./Lib.sol" as Lib;`,
    /// the expression `Lib.add(1, 2)` becomes `add(1, 2)`.
    ///
    /// This should work correctly for the vast majority of cases; however, there are situations
    /// for which this approach won't work, most of which involve code formatted in an uncommon
    /// way.
    fn remove_qualified_imports(&self, updates: &mut Updates) {
        let imports_ids = self
            .asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ImportDirective(directive) => Some(directive.id),
                    _ => None,
                })
            })
            .collect::<HashSet<_>>();

        let references = self.collect_references();

        for (id, locs) in references {
            if !imports_ids.contains(&(id as usize)) {
                continue;
            }

            for loc in locs {
                updates.entry(loc.path).or_default().insert((
                    loc.start,
                    loc.end + 1,
                    String::new(),
                ));
            }
        }
    }

    /// Goes through all references to items defined in the scope of contracts and updates them
    /// to use the correct parent contract name.
    ///
    /// This only operates on references from `IdentifierPath` nodes.
    fn rename_contract_level_types_references(
        &self,
        top_level_names: &HashMap<usize, String>,
        updates: &mut Updates,
    ) {
        let contract_level_definitions = self.collect_contract_level_definitions();

        for (path, ast) in &self.asts {
            for node in &ast.nodes {
                let mut collector =
                    ReferencesCollector { path: self.target.clone(), references: HashMap::new() };

                node.walk(&mut collector);

                let references = collector.references;

                for (id, locs) in references {
                    if let Some((name, contract_id)) =
                        contract_level_definitions.get(&(id as usize))
                    {
                        for loc in &locs {
                            // If the child item is referenced directly by its name, it's either
                            // defined in the same contract or in one of its base contracts, so
                            // we don't have to change anything.
                            // Comparing lengths is enough because such items cannot be aliased.
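                            // E.g. a bare `MyStruct` reference is left as-is, while a longer
                            // `Lib.MyStruct` one is rewritten below.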
                            if loc.length() == name.len() {
                                continue;
                            }
                            // If it was referenced in some other way, we rename it to the
                            // `Parent.Child` format.
                            let parent_name = top_level_names.get(contract_id).unwrap();
                            updates.entry(path.clone()).or_default().insert((
                                loc.start,
                                loc.end,
                                format!("{parent_name}.{name}"),
                            ));
                        }
                    }
                }
            }
        }
    }

    /// Finds all `@inheritdoc` tags in natspec comments and tries to replace them.
    ///
    /// We either replace the contract name or remove the `@inheritdoc` tag completely to avoid
    /// generating invalid source code.
    fn update_inheritdocs(&self, top_level_names: &HashMap<usize, String>, updates: &mut Updates) {
        trace!("updating @inheritdoc tags");
        for (path, ast) in &self.asts {
            // Collect all exported symbols for this source unit. The @inheritdoc value is
            // either one of those or a qualified import path, which we don't support.
            let exported_symbols = ast
                .exported_symbols
                .iter()
                .filter_map(
                    |(name, ids)| {
                        if !ids.is_empty() {
                            Some((name.as_str(), ids[0]))
                        } else {
                            None
                        }
                    },
                )
                .collect::<HashMap<_, _>>();

            // Collect all docs in all contracts.
            let docs = ast
                .nodes
                .iter()
                .filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(d) => Some(d),
                    _ => None,
                })
                .flat_map(|contract| {
                    contract.nodes.iter().filter_map(|node| match node {
                        ContractDefinitionPart::EventDefinition(event) => {
                            event.documentation.as_ref()
                        }
                        ContractDefinitionPart::ErrorDefinition(error) => {
                            error.documentation.as_ref()
                        }
                        ContractDefinitionPart::FunctionDefinition(func) => {
                            func.documentation.as_ref()
                        }
                        ContractDefinitionPart::VariableDeclaration(var) => {
                            var.documentation.as_ref()
                        }
                        _ => None,
                    })
                });

            docs.for_each(|doc| {
                let Documentation::Structured(doc) = doc else {
                    return;
                };
                let src_start = doc.src.start.unwrap();
                let src_end = src_start + doc.src.length.unwrap();

                // The Documentation node has a `text` field; however, it does not contain the
                // comment slashes, so we can't use it to find positions in the source.
                let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end];
                let tag_len = "@inheritdoc".len();

                if let Some(tag_start) = content.find("@inheritdoc") {
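                    // E.g. for a comment `/** @inheritdoc IERC20 */`, `name` below
                    // resolves to "IERC20".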
                    trace!("processing doc with content {:?}", content);
                    if let Some(name_start) = content[tag_start + tag_len..]
                        .find(|c| c != ' ')
                        .map(|p| p + tag_start + tag_len)
                    {
                        let name_end = content[name_start..]
                            .find([' ', '\n', '*', '/'])
                            .map(|p| p + name_start)
                            .unwrap_or(content.len());

                        let name = &content[name_start..name_end];
                        trace!("found name {name}");

                        let mut new_name = None;

                        if let Some(ast_id) = exported_symbols.get(name) {
                            if let Some(name) = top_level_names.get(ast_id) {
                                new_name = Some(name);
                            } else {
                                trace!(identifiers=?top_level_names, "ast id {ast_id} cannot be matched to top-level identifier");
                            }
                        }

                        if let Some(new_name) = new_name {
                            trace!("updating tag value with {new_name}");
                            updates.entry(path.to_path_buf()).or_default().insert((
                                src_start + name_start,
                                src_start + name_end,
                                new_name.to_string(),
                            ));
                        } else {
                            trace!("name is unknown, removing @inheritdoc tag");
                            updates.entry(path.to_path_buf()).or_default().insert((
                                src_start + tag_start,
                                src_start + name_end,
                                String::new(),
                            ));
                        }
                    }
                }
            });
        }
    }

    /// Processes all ASTs and collects all top-level definitions in the form of
    /// a mapping from name to (definition id, source location).
    fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes
                    .iter()
                    .filter_map(|node| match node {
                        SourceUnitPart::ContractDefinition(contract) => Some((
                            &contract.name,
                            contract.id,
                            &contract.src,
                            &contract.name_location,
                        )),
                        SourceUnitPart::EnumDefinition(enum_) => {
                            Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location))
                        }
                        SourceUnitPart::StructDefinition(struct_) => {
                            Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location))
                        }
                        SourceUnitPart::FunctionDefinition(func) => {
                            Some((&func.name, func.id, &func.src, &func.name_location))
                        }
                        SourceUnitPart::VariableDeclaration(var) => {
                            Some((&var.name, var.id, &var.src, &var.name_location))
                        }
                        SourceUnitPart::UserDefinedValueTypeDefinition(type_) => {
                            Some((&type_.name, type_.id, &type_.src, &type_.name_location))
                        }
                        _ => None,
                    })
                    .map(|(name, id, src, maybe_name_src)| {
                        let loc = match maybe_name_src {
                            Some(src) => {
                                ItemLocation::try_from_source_loc(src, path.clone()).unwrap()
                            }
                            None => {
                                // Find the location of the name in the source.
                                let content: &str = &self.sources.get(path).unwrap().content;
                                let start = src.start.unwrap();
                                let end = start + src.length.unwrap();

                                let name_start = content[start..end].find(name).unwrap();
                                let name_end = name_start + name.len();

                                ItemLocation {
                                    path: path.clone(),
                                    start: start + name_start,
                                    end: start + name_end,
                                }
                            }
                        };

                        (name, (id, loc))
                    })
            })
            .fold(HashMap::new(), |mut acc, (name, (id, item_location))| {
                acc.entry(name).or_default().insert((id, item_location));
                acc
            })
    }

    /// Collect all contract-level definitions in the form of a mapping from definition id to
    /// (definition name, contract id).
    fn collect_contract_level_definitions(&self) -> HashMap<usize, (&String, usize)> {
        self.asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(contract) => {
                        Some((contract.id, &contract.nodes))
                    }
                    _ => None,
                })
            })
            .flat_map(|(contract_id, nodes)| {
                nodes.iter().filter_map(move |node| match node {
                    ContractDefinitionPart::EnumDefinition(enum_) => {
                        Some((enum_.id, (&enum_.name, contract_id)))
                    }
                    ContractDefinitionPart::ErrorDefinition(error) => {
                        Some((error.id, (&error.name, contract_id)))
                    }
                    ContractDefinitionPart::EventDefinition(event) => {
                        Some((event.id, (&event.name, contract_id)))
                    }
                    ContractDefinitionPart::StructDefinition(struct_) => {
                        Some((struct_.id, (&struct_.name, contract_id)))
                    }
                    ContractDefinitionPart::FunctionDefinition(function) => {
                        Some((function.id, (&function.name, contract_id)))
                    }
                    ContractDefinitionPart::VariableDeclaration(variable) => {
                        Some((variable.id, (&variable.name, contract_id)))
                    }
                    ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => {
                        Some((value_type.id, (&value_type.name, contract_id)))
                    }
                    _ => None,
                })
            })
            .collect()
    }

    /// Collects all references to any declaration in the form of a mapping from declaration id
    /// to the set of source locations it appears in.
    fn collect_references(&self) -> HashMap<isize, HashSet<ItemLocation>> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                let mut collector =
                    ReferencesCollector { path: path.clone(), references: HashMap::new() };
                ast.walk(&mut collector);
                collector.references
            })
            .fold(HashMap::new(), |mut acc, (id, locs)| {
                acc.entry(id).or_default().extend(locs);
                acc
            })
    }

    /// Removes all imports from all sources.
    fn remove_imports(&self, updates: &mut Updates) {
        for loc in self.collect_imports() {
            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }
    }

    // Collects all import locations.
    fn collect_imports(&self) -> HashSet<ItemLocation> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ImportDirective(import) => {
                        ItemLocation::try_from_source_loc(&import.src, path.clone())
                    }
                    _ => None,
                })
            })
            .collect()
    }

    /// Removes all pragma directives from all sources. Returns a Vec containing the combined
    /// version pragma and the experimental/abicoder pragma (if present).
    fn process_pragmas(&self, updates: &mut Updates) -> Vec<String> {
        let mut abicoder_v2 = None;

        let pragmas = self.collect_pragmas();
        let mut version_pragmas = Vec::new();

        for loc in &pragmas {
            let pragma_content = self.read_location(loc);
            if pragma_content.contains("experimental") || pragma_content.contains("abicoder") {
                if abicoder_v2.is_none() {
                    abicoder_v2 = Some(self.read_location(loc).to_string());
                }
            } else if pragma_content.contains("solidity") {
                version_pragmas.push(pragma_content);
            }

            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }

        let mut pragmas = Vec::new();

        if let Some(version_pragma) = combine_version_pragmas(version_pragmas) {
            pragmas.push(version_pragma);
        }

        if let Some(pragma) = abicoder_v2 {
            pragmas.push(pragma);
        }

        pragmas
    }

    // Collects all pragma directive locations.
    fn collect_pragmas(&self) -> HashSet<ItemLocation> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::PragmaDirective(import) => {
                        ItemLocation::try_from_source_loc(&import.src, path.clone())
                    }
                    _ => None,
                })
            })
            .collect()
    }

    /// Removes all license identifiers from all sources. Returns the license identifier from
    /// the target file, if any.
    fn process_licenses(&self, updates: &mut Updates) -> Option<&str> {
        let mut target_license = None;

        for loc in &self.collect_licenses() {
            if loc.path == self.target {
                let license_line = self.read_location(loc);
                let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
                target_license = Some(license_line[license_start..].trim());
            }
            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }

        target_license
    }

    // Collects all SPDX-License-Identifier locations.
    fn collect_licenses(&self) -> HashSet<ItemLocation> {
        self.sources
            .iter()
            .flat_map(|(path, source)| {
                let mut licenses = HashSet::new();
                if let Some(license_start) = source.content.find("SPDX-License-Identifier:") {
                    let start =
                        source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0);
                    let end = start
                        + source.content[start..]
                            .find('\n')
                            .unwrap_or(source.content.len() - start);
                    licenses.insert(ItemLocation { path: path.clone(), start, end });
                }
                licenses
            })
            .collect()
    }

    // Reads the value at the given location of a source file.
    fn read_location(&self, loc: &ItemLocation) -> &str {
        let content: &str = &self.sources.get(&loc.path).unwrap().content;
        &content[loc.start..loc.end]
    }
}

/// Performs DFS to collect all dependencies of a target.
fn collect_deps<D: ParsedSource + MaybeSolData>(
    path: &PathBuf,
    paths: &ProjectPathsConfig<D::Language>,
    graph: &Graph<D>,
    deps: &mut HashSet<PathBuf>,
) -> Result<()> {
    if deps.insert(path.clone()) {
        let target_dir = path.parent().ok_or_else(|| {
            SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display()))
        })?;

        let node_id = graph
            .files()
            .get(path)
            .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?;

        if let Some(data) = graph.node(*node_id).data.sol_data() {
            for import in &data.imports {
                let path = paths.resolve_import(target_dir, import.data().path())?;
                collect_deps(&path, paths, graph, deps)?;
            }
        }
    }
    Ok(())
}

/// We want to make the order in which sources are written to the resulting flattened file
/// deterministic.
///
/// We can't just sort files alphabetically, as that might break compilation: Solidity does not
/// allow base class definitions to appear after derived contract definitions.
///
/// Instead, we sort files by the number of their dependencies (imports of any depth) in
/// ascending order. If files have the same number of dependencies, we sort them alphabetically.
/// The target file is always placed last. E.g. if the target `A.sol` imports `B.sol`, which in
/// turn imports `C.sol`, the resulting order is `C.sol`, `B.sol`, `A.sol`.
pub fn collect_ordered_deps<D: ParsedSource + MaybeSolData>(
    path: &PathBuf,
    paths: &ProjectPathsConfig<D::Language>,
    graph: &Graph<D>,
) -> Result<Vec<PathBuf>> {
    let mut deps = HashSet::new();
    collect_deps(path, paths, graph, &mut deps)?;

    // Remove the path prior to counting dependencies;
    // it will be added to the end of the resulting Vec later.
    deps.remove(path);

    let mut paths_with_deps_count = Vec::new();
    for path in deps {
        let mut path_deps = HashSet::new();
        collect_deps(&path, paths, graph, &mut path_deps)?;
        paths_with_deps_count.push((path_deps.len(), path));
    }

    paths_with_deps_count.sort_by(|(count_0, path_0), (count_1, path_1)| {
        // Compare dependency counts.
        match count_0.cmp(count_1) {
            o if !o.is_eq() => return o,
            _ => {}
        };

        // Try comparing file names.
        if let Some((name_0, name_1)) = path_0.file_name().zip(path_1.file_name()) {
            match name_0.cmp(name_1) {
                o if !o.is_eq() => return o,
                _ => {}
            }
        }

        // If both filenames and dependency counts are equal, fall back to comparing file paths.
        path_0.cmp(path_1)
    });

    let mut ordered_deps =
        paths_with_deps_count.into_iter().map(|(_, path)| path).collect::<Vec<_>>();

    ordered_deps.push(path.clone());

    Ok(ordered_deps)
}

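/// Combines multiple version pragmas into a single one covering all of the unique comparators.
///
/// E.g. `pragma solidity ^0.8.0;` and `pragma solidity >=0.8.4;` are expected to merge into
/// `pragma solidity >=0.8.4 ^0.8.0;` (comparators are deduplicated and sorted as strings).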
pub fn combine_version_pragmas(pragmas: Vec<&str>) -> Option<String> {
    let mut versions = pragmas
        .into_iter()
        .filter_map(|p| {
            SolData::parse_version_req(
                p.replace("pragma", "").replace("solidity", "").replace(';', "").trim(),
            )
            .ok()
        })
        .flat_map(|req| req.comparators)
        .collect::<HashSet<_>>()
        .into_iter()
        .map(|comp| comp.to_string())
        .collect::<Vec<_>>();

    versions.sort();

    if !versions.is_empty() {
        return Some(format!("pragma solidity {};", versions.iter().format(" ")));
    }

    None
}