foundry_compilers/
flatten.rs

1use crate::{
2    apply_updates,
3    compilers::{Compiler, ParsedSource},
4    filter::MaybeSolData,
5    resolver::parse::SolData,
6    ArtifactOutput, CompilerSettings, Graph, Project, ProjectPathsConfig, SourceParser, Updates,
7};
8use foundry_compilers_artifacts::{
9    ast::{visitor::Visitor, *},
10    output_selection::OutputSelection,
11    solc::ExternalInlineAssemblyReference,
12    sources::{Source, Sources},
13    ContractDefinitionPart, SourceUnit, SourceUnitPart,
14};
15use foundry_compilers_core::{
16    error::{Result, SolcError},
17    utils,
18};
19use itertools::Itertools;
20use std::{
21    collections::{BTreeSet, HashMap, HashSet},
22    hash::Hash,
23    path::{Path, PathBuf},
24    sync::Arc,
25};
26use visitor::Walk;
27
28/// Alternative of `SourceLocation` which includes path of the file.
29#[derive(Clone, Debug, PartialEq, Eq, Hash)]
30struct ItemLocation {
31    path: PathBuf,
32    start: usize,
33    end: usize,
34}
35
36impl ItemLocation {
37    fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option<Self> {
38        let start = src.start?;
39        let end = start + src.length?;
40
41        Some(Self { path, start, end })
42    }
43
44    fn length(&self) -> usize {
45        self.end - self.start
46    }
47}
48
49/// Visitor exploring AST and collecting all references to declarations via `Identifier` and
50/// `IdentifierPath` nodes.
51///
52/// It also collects `MemberAccess` parts. So, if we have `X.Y` expression, loc and AST ID will be
53/// saved for Y only.
54///
55/// That way, even if we have a long `MemberAccess` expression (a.b.c.d) then the first member (a)
56/// will be collected as either `Identifier` or `IdentifierPath`, and all subsequent parts (b, c, d)
57/// will be collected as `MemberAccess` parts.
58struct ReferencesCollector {
59    path: PathBuf,
60    references: HashMap<isize, HashSet<ItemLocation>>,
61}
62
63impl ReferencesCollector {
64    fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) {
65        if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) {
66            self.references.entry(id).or_default().insert(loc);
67        }
68    }
69}
70
71impl Visitor for ReferencesCollector {
72    fn visit_identifier(&mut self, identifier: &Identifier) {
73        if let Some(id) = identifier.referenced_declaration {
74            self.process_referenced_declaration(id, &identifier.src);
75        }
76    }
77
78    fn visit_identifier_path(&mut self, path: &IdentifierPath) {
79        self.process_referenced_declaration(path.referenced_declaration, &path.src);
80    }
81
82    fn visit_member_access(&mut self, access: &MemberAccess) {
83        if let Some(referenced_declaration) = access.referenced_declaration {
84            if let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length) {
85                let name_length = access.member_name.len();
86                // Accessed member name is in the last name.len() symbols of the expression.
87                let start = src_start + src_length - name_length;
88                let end = start + name_length;
89
90                self.references.entry(referenced_declaration).or_default().insert(ItemLocation {
91                    start,
92                    end,
93                    path: self.path.to_path_buf(),
94                });
95            }
96        }
97    }
98
99    fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
100        let mut src = reference.src;
101
102        // If suffix is used in assembly reference (e.g. value.slot), it will be included into src.
103        // However, we are only interested in the referenced name, thus we strip .<suffix> part.
104        if let Some(suffix) = &reference.suffix {
105            if let Some(len) = src.length.as_mut() {
106                let suffix_len = suffix.to_string().len();
107                *len -= suffix_len + 1;
108            }
109        }
110
111        self.process_referenced_declaration(reference.declaration as isize, &src);
112    }
113}
114
115pub struct FlatteningResult {
116    /// Updated source in the order they should be written to the output file.
117    sources: Vec<String>,
118    /// Pragmas that should be present in the target file.
119    pragmas: Vec<String>,
120    /// License identifier that should be present in the target file.
121    license: Option<String>,
122}
123
124impl FlatteningResult {
125    fn new(
126        mut flattener: Flattener,
127        updates: Updates,
128        pragmas: Vec<String>,
129        license: Option<String>,
130    ) -> Self {
131        apply_updates(&mut flattener.sources, updates);
132
133        let sources = flattener
134            .ordered_sources
135            .iter()
136            .map(|path| flattener.sources.remove(path).unwrap().content)
137            .map(Arc::unwrap_or_clone)
138            .collect();
139
140        Self { sources, pragmas, license }
141    }
142
143    fn get_flattened_target(&self) -> String {
144        let mut result = String::new();
145
146        if let Some(license) = &self.license {
147            result.push_str(&format!("// {license}\n"));
148        }
149        for pragma in &self.pragmas {
150            result.push_str(&format!("{pragma}\n"));
151        }
152        for source in &self.sources {
153            result.push_str(&format!("\n\n{source}"));
154        }
155
156        format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())
157    }
158}
159
160#[derive(Debug, thiserror::Error)]
161pub enum FlattenerError {
162    #[error("Failed to compile {0}")]
163    Compilation(SolcError),
164    #[error(transparent)]
165    Other(SolcError),
166}
167
168impl<T: Into<SolcError>> From<T> for FlattenerError {
169    fn from(err: T) -> Self {
170        Self::Other(err.into())
171    }
172}
173
174/// Context for flattening. Stores all sources and ASTs that are in scope of the flattening target.
175pub struct Flattener {
176    /// Target file to flatten.
177    target: PathBuf,
178    /// Sources including only target and it dependencies (imports of any depth).
179    sources: Sources,
180    /// Vec of (path, ast) pairs.
181    asts: Vec<(PathBuf, SourceUnit)>,
182    /// Sources in the order they should be written to the output file.
183    ordered_sources: Vec<PathBuf>,
184    /// Project root directory.
185    project_root: PathBuf,
186}
187
188impl Flattener {
189    /// Compiles the target file and prepares AST and analysis data for flattening.
190    pub fn new<C: Compiler, T: ArtifactOutput<CompilerContract = C::CompilerContract>>(
191        mut project: Project<C, T>,
192        target: &Path,
193    ) -> std::result::Result<Self, FlattenerError>
194    where
195        C::Parser: SourceParser<ParsedSource: MaybeSolData>,
196    {
197        // Configure project to compile the target file and only request AST for target file.
198        project.cached = false;
199        project.no_artifacts = true;
200        project.settings.update_output_selection(|selection| {
201            *selection = OutputSelection::ast_output_selection();
202        });
203
204        let output = project.compile_file(target).map_err(FlattenerError::Compilation)?;
205
206        if output.has_compiler_errors() {
207            return Err(FlattenerError::Compilation(SolcError::msg(&output)));
208        }
209
210        let output = output.compiler_output;
211
212        let sources = Source::read_all([target.to_path_buf()])?;
213        let graph = Graph::<C::Parser>::resolve_sources(&project.paths, sources)?;
214
215        let ordered_sources = collect_ordered_deps(target, &project.paths, &graph)?;
216
217        #[cfg(windows)]
218        let ordered_sources = {
219            let mut sources = ordered_sources;
220            use path_slash::PathBufExt;
221            for p in &mut sources {
222                *p = PathBuf::from(p.to_slash_lossy().to_string());
223            }
224            sources
225        };
226
227        let sources = Source::read_all(&ordered_sources)?;
228
229        // Convert all ASTs from artifacts to strongly typed ASTs
230        let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new();
231        for (path, ast) in output.sources.0.iter().filter_map(|(path, files)| {
232            if let Some(ast) = files.first().and_then(|source| source.source_file.ast.as_ref()) {
233                if sources.contains_key(path) {
234                    return Some((path, ast));
235                }
236            }
237            None
238        }) {
239            asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?));
240        }
241
242        Ok(Self {
243            target: target.into(),
244            sources,
245            asts,
246            ordered_sources,
247            project_root: project.root().to_path_buf(),
248        })
249    }
250
251    /// Flattens target file and returns the result as a string
252    ///
253    /// Flattening process includes following steps:
254    /// 1. Find all file-level definitions and rename references to them via aliased or qualified
255    ///    imports.
256    /// 2. Find all duplicates among file-level definitions and rename them to avoid conflicts.
257    /// 3. Remove all imports.
258    /// 4. Remove all pragmas except for the ones in the target file.
259    /// 5. Remove all license identifiers except for the one in the target file.
260    pub fn flatten(self) -> String {
261        let mut updates = Updates::new();
262
263        self.append_filenames(&mut updates);
264        let top_level_names = self.rename_top_level_definitions(&mut updates);
265        self.rename_contract_level_types_references(&top_level_names, &mut updates);
266        self.remove_qualified_imports(&mut updates);
267        self.update_inheritdocs(&top_level_names, &mut updates);
268
269        self.remove_imports(&mut updates);
270        let target_pragmas = self.process_pragmas(&mut updates);
271        let target_license = self.process_licenses(&mut updates);
272
273        self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
274    }
275
276    fn flatten_result(
277        self,
278        updates: Updates,
279        target_pragmas: Vec<String>,
280        target_license: Option<String>,
281    ) -> FlatteningResult {
282        FlatteningResult::new(self, updates, target_pragmas, target_license)
283    }
284
285    /// Appends a comment with the file name to the beginning of each source.
286    fn append_filenames(&self, updates: &mut Updates) {
287        for path in &self.ordered_sources {
288            updates.entry(path.clone()).or_default().insert((
289                0,
290                0,
291                format!("// {}\n", path.strip_prefix(&self.project_root).unwrap_or(path).display()),
292            ));
293        }
294    }
295
296    /// Finds and goes over all references to file-level definitions and updates them to match
297    /// definition name. This is needed for two reasons:
298    /// 1. We want to rename all aliased or qualified imports.
299    /// 2. We want to find any duplicates and rename them to avoid conflicts.
300    ///
301    /// If we find more than 1 declaration with the same name, it's name is getting changed.
302    /// Two Counter contracts will be renamed to Counter_0 and Counter_1
303    ///
304    /// Returns mapping from top-level declaration id to its name (possibly updated)
305    fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap<usize, String> {
306        let top_level_definitions = self.collect_top_level_definitions();
307        let references = self.collect_references();
308
309        let mut top_level_names = HashMap::new();
310
311        for (name, ids) in top_level_definitions {
312            let mut definition_name = name.to_string();
313            let needs_rename = ids.len() > 1;
314
315            let mut ids = ids.clone().into_iter().collect::<Vec<_>>();
316            if needs_rename {
317                // `loc.path` is expected to be different for each id because there can't be 2
318                // top-level declarations with the same name in the same file.
319                //
320                // Sorting by index loc.path and loc.start in sorted files to make the renaming
321                // process deterministic.
322                ids.sort_by_key(|(_, loc)| {
323                    (self.ordered_sources.iter().position(|p| p == &loc.path).unwrap(), loc.start)
324                });
325            }
326            for (i, (id, loc)) in ids.iter().enumerate() {
327                if needs_rename {
328                    definition_name = format!("{name}_{i}");
329                }
330                updates.entry(loc.path.clone()).or_default().insert((
331                    loc.start,
332                    loc.end,
333                    definition_name.clone(),
334                ));
335                if let Some(references) = references.get(&(*id as isize)) {
336                    for loc in references {
337                        updates.entry(loc.path.clone()).or_default().insert((
338                            loc.start,
339                            loc.end,
340                            definition_name.clone(),
341                        ));
342                    }
343                }
344
345                top_level_names.insert(*id, definition_name.clone());
346            }
347        }
348        top_level_names
349    }
350
351    /// This is not very clean, but in most cases effective enough method to remove qualified
352    /// imports from sources.
353    ///
354    /// Every qualified import part is an `Identifier` with `referencedDeclaration` field matching
355    /// ID of one of the import directives.
356    ///
357    /// This approach works by firstly collecting all IDs of import directives, and then looks for
358    /// any references of them. Once the reference is found, it's full length is getting removed
359    /// from source + 1 character ('.')
360    ///
361    /// This should work correctly for vast majority of cases, however there are situations for
362    /// which such approach won't work, most of which are related to code being formatted in an
363    /// uncommon way.
364    fn remove_qualified_imports(&self, updates: &mut Updates) {
365        let imports_ids = self
366            .asts
367            .iter()
368            .flat_map(|(_, ast)| {
369                ast.nodes.iter().filter_map(|node| match node {
370                    SourceUnitPart::ImportDirective(directive) => Some(directive.id),
371                    _ => None,
372                })
373            })
374            .collect::<HashSet<_>>();
375
376        let references = self.collect_references();
377
378        for (id, locs) in references {
379            if !imports_ids.contains(&(id as usize)) {
380                continue;
381            }
382
383            for loc in locs {
384                updates.entry(loc.path).or_default().insert((
385                    loc.start,
386                    loc.end + 1,
387                    String::new(),
388                ));
389            }
390        }
391    }
392
393    /// Here we are going through all references to items defined in scope of contracts and updating
394    /// them to be using correct parent contract name.
395    ///
396    /// This will only operate on references from `IdentifierPath` nodes.
397    fn rename_contract_level_types_references(
398        &self,
399        top_level_names: &HashMap<usize, String>,
400        updates: &mut Updates,
401    ) {
402        let contract_level_definitions = self.collect_contract_level_definitions();
403
404        for (path, ast) in &self.asts {
405            for node in &ast.nodes {
406                let mut collector =
407                    ReferencesCollector { path: self.target.clone(), references: HashMap::new() };
408
409                node.walk(&mut collector);
410
411                let references = collector.references;
412
413                for (id, locs) in references {
414                    if let Some((name, contract_id)) =
415                        contract_level_definitions.get(&(id as usize))
416                    {
417                        for loc in &locs {
418                            // If child item is referenced directly by it's name it's either defined
419                            // in the same contract or in one of it's base contracts, so we don't
420                            // have to change anything.
421                            // Comparing lengths is enough because such items cannot be aliased.
422                            if loc.length() == name.len() {
423                                continue;
424                            }
425                            // If it was referenced somehow else, we rename it to `Parent.Child`
426                            // format.
427                            let parent_name = top_level_names.get(contract_id).unwrap();
428                            updates.entry(path.clone()).or_default().insert((
429                                loc.start,
430                                loc.end,
431                                format!("{parent_name}.{name}"),
432                            ));
433                        }
434                    }
435                }
436            }
437        }
438    }
439
440    /// Finds all @inheritdoc tags in natspec comments and tries replacing them.
441    ///
442    /// We will either replace contract name or remove @inheritdoc tag completely to avoid
443    /// generating invalid source code.
444    fn update_inheritdocs(&self, top_level_names: &HashMap<usize, String>, updates: &mut Updates) {
445        trace!("updating @inheritdoc tags");
446        for (path, ast) in &self.asts {
447            // Collect all exported symbols for this source unit
448            // @inheritdoc value is either one of those or qualified import path which we don't
449            // support
450            let exported_symbols = ast
451                .exported_symbols
452                .iter()
453                .filter_map(
454                    |(name, ids)| {
455                        if !ids.is_empty() {
456                            Some((name.as_str(), ids[0]))
457                        } else {
458                            None
459                        }
460                    },
461                )
462                .collect::<HashMap<_, _>>();
463
464            // Collect all docs in all contracts
465            let docs = ast
466                .nodes
467                .iter()
468                .filter_map(|node| match node {
469                    SourceUnitPart::ContractDefinition(d) => Some(d),
470                    _ => None,
471                })
472                .flat_map(|contract| {
473                    contract.nodes.iter().filter_map(|node| match node {
474                        ContractDefinitionPart::EventDefinition(event) => {
475                            event.documentation.as_ref()
476                        }
477                        ContractDefinitionPart::ErrorDefinition(error) => {
478                            error.documentation.as_ref()
479                        }
480                        ContractDefinitionPart::FunctionDefinition(func) => {
481                            func.documentation.as_ref()
482                        }
483                        ContractDefinitionPart::VariableDeclaration(var) => {
484                            var.documentation.as_ref()
485                        }
486                        _ => None,
487                    })
488                });
489
490            docs.for_each(|doc| {
491                let Documentation::Structured(doc) = doc else {
492                    return
493                };
494                let src_start = doc.src.start.unwrap();
495                let src_end = src_start + doc.src.length.unwrap();
496
497                // Documentation node has `text` field, however, it does not contain
498                // slashes and we can't use if to find positions.
499                let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end];
500                let tag_len = "@inheritdoc".len();
501
502                if let Some(tag_start) = content.find("@inheritdoc") {
503                    trace!("processing doc with content {:?}", content);
504                    if let Some(name_start) = content[tag_start + tag_len..]
505                        .find(|c| c != ' ')
506                        .map(|p| p + tag_start + tag_len)
507                    {
508                        let name_end = content[name_start..]
509                            .find([' ', '\n', '*', '/'])
510                            .map(|p| p + name_start)
511                            .unwrap_or(content.len());
512
513                        let name = &content[name_start..name_end];
514                        trace!("found name {name}");
515
516                        let mut new_name = None;
517
518                        if let Some(ast_id) = exported_symbols.get(name) {
519                            if let Some(name) = top_level_names.get(ast_id) {
520                                new_name = Some(name);
521                            } else {
522                                trace!(identifiers=?top_level_names, "ast id {ast_id} cannot be matched to top-level identifier");
523                            }
524                        }
525
526                        if let Some(new_name) = new_name {
527                            trace!("updating tag value with {new_name}");
528                            updates.entry(path.to_path_buf()).or_default().insert((
529                                src_start + name_start,
530                                src_start + name_end,
531                                new_name.to_string(),
532                            ));
533                        } else {
534                            trace!("name is unknown, removing @inheritdoc tag");
535                            updates.entry(path.to_path_buf()).or_default().insert((
536                                src_start + tag_start,
537                                src_start + name_end,
538                                String::new(),
539                            ));
540                        }
541                    }
542                }
543            });
544        }
545    }
546
547    /// Processes all ASTs and collects all top-level definitions in the form of
548    /// a mapping from name to (definition id, source location)
549    fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> {
550        self.asts
551            .iter()
552            .flat_map(|(path, ast)| {
553                ast.nodes
554                    .iter()
555                    .filter_map(|node| match node {
556                        SourceUnitPart::ContractDefinition(contract) => Some((
557                            &contract.name,
558                            contract.id,
559                            &contract.src,
560                            &contract.name_location,
561                        )),
562                        SourceUnitPart::EnumDefinition(enum_) => {
563                            Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location))
564                        }
565                        SourceUnitPart::StructDefinition(struct_) => {
566                            Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location))
567                        }
568                        SourceUnitPart::FunctionDefinition(func) => {
569                            Some((&func.name, func.id, &func.src, &func.name_location))
570                        }
571                        SourceUnitPart::VariableDeclaration(var) => {
572                            Some((&var.name, var.id, &var.src, &var.name_location))
573                        }
574                        SourceUnitPart::UserDefinedValueTypeDefinition(type_) => {
575                            Some((&type_.name, type_.id, &type_.src, &type_.name_location))
576                        }
577                        _ => None,
578                    })
579                    .map(|(name, id, src, maybe_name_src)| {
580                        let loc = match maybe_name_src {
581                            Some(src) => {
582                                ItemLocation::try_from_source_loc(src, path.clone()).unwrap()
583                            }
584                            None => {
585                                // Find location of name in source
586                                let content: &str = &self.sources.get(path).unwrap().content;
587                                let start = src.start.unwrap();
588                                let end = start + src.length.unwrap();
589
590                                let name_start = content[start..end].find(name).unwrap();
591                                let name_end = name_start + name.len();
592
593                                ItemLocation {
594                                    path: path.clone(),
595                                    start: start + name_start,
596                                    end: start + name_end,
597                                }
598                            }
599                        };
600
601                        (name, (id, loc))
602                    })
603            })
604            .fold(HashMap::new(), |mut acc, (name, (id, item_location))| {
605                acc.entry(name).or_default().insert((id, item_location));
606                acc
607            })
608    }
609
610    /// Collect all contract-level definitions in the form of a mapping from definition id to
611    /// (definition name, contract id)
612    fn collect_contract_level_definitions(&self) -> HashMap<usize, (&String, usize)> {
613        self.asts
614            .iter()
615            .flat_map(|(_, ast)| {
616                ast.nodes.iter().filter_map(|node| match node {
617                    SourceUnitPart::ContractDefinition(contract) => {
618                        Some((contract.id, &contract.nodes))
619                    }
620                    _ => None,
621                })
622            })
623            .flat_map(|(contract_id, nodes)| {
624                nodes.iter().filter_map(move |node| match node {
625                    ContractDefinitionPart::EnumDefinition(enum_) => {
626                        Some((enum_.id, (&enum_.name, contract_id)))
627                    }
628                    ContractDefinitionPart::ErrorDefinition(error) => {
629                        Some((error.id, (&error.name, contract_id)))
630                    }
631                    ContractDefinitionPart::EventDefinition(event) => {
632                        Some((event.id, (&event.name, contract_id)))
633                    }
634                    ContractDefinitionPart::StructDefinition(struct_) => {
635                        Some((struct_.id, (&struct_.name, contract_id)))
636                    }
637                    ContractDefinitionPart::FunctionDefinition(function) => {
638                        Some((function.id, (&function.name, contract_id)))
639                    }
640                    ContractDefinitionPart::VariableDeclaration(variable) => {
641                        Some((variable.id, (&variable.name, contract_id)))
642                    }
643                    ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => {
644                        Some((value_type.id, (&value_type.name, contract_id)))
645                    }
646                    _ => None,
647                })
648            })
649            .collect()
650    }
651
652    /// Collects all references to any declaration in the form of a mapping from declaration id to
653    /// set of source locations it appears in
654    fn collect_references(&self) -> HashMap<isize, HashSet<ItemLocation>> {
655        self.asts
656            .iter()
657            .flat_map(|(path, ast)| {
658                let mut collector =
659                    ReferencesCollector { path: path.clone(), references: HashMap::new() };
660                ast.walk(&mut collector);
661                collector.references
662            })
663            .fold(HashMap::new(), |mut acc, (id, locs)| {
664                acc.entry(id).or_default().extend(locs);
665                acc
666            })
667    }
668
669    /// Removes all imports from all sources.
670    fn remove_imports(&self, updates: &mut Updates) {
671        for loc in self.collect_imports() {
672            updates.entry(loc.path.clone()).or_default().insert((
673                loc.start,
674                loc.end,
675                String::new(),
676            ));
677        }
678    }
679
680    // Collects all imports locations.
681    fn collect_imports(&self) -> HashSet<ItemLocation> {
682        self.asts
683            .iter()
684            .flat_map(|(path, ast)| {
685                ast.nodes.iter().filter_map(|node| match node {
686                    SourceUnitPart::ImportDirective(import) => {
687                        ItemLocation::try_from_source_loc(&import.src, path.clone())
688                    }
689                    _ => None,
690                })
691            })
692            .collect()
693    }
694
695    /// Removes all pragma directives from all sources. Returns Vec with experimental and combined
696    /// version pragmas (if present).
697    fn process_pragmas(&self, updates: &mut Updates) -> Vec<String> {
698        let mut abicoder_v2 = None;
699
700        let pragmas = self.collect_pragmas();
701        let mut version_pragmas = Vec::new();
702
703        for loc in &pragmas {
704            let pragma_content = self.read_location(loc);
705            if pragma_content.contains("experimental") || pragma_content.contains("abicoder") {
706                if abicoder_v2.is_none() {
707                    abicoder_v2 = Some(self.read_location(loc).to_string());
708                }
709            } else if pragma_content.contains("solidity") {
710                version_pragmas.push(pragma_content);
711            }
712
713            updates.entry(loc.path.clone()).or_default().insert((
714                loc.start,
715                loc.end,
716                String::new(),
717            ));
718        }
719
720        let mut pragmas = Vec::new();
721
722        if let Some(version_pragma) = combine_version_pragmas(&version_pragmas) {
723            pragmas.push(version_pragma);
724        }
725
726        if let Some(pragma) = abicoder_v2 {
727            pragmas.push(pragma);
728        }
729
730        pragmas
731    }
732
733    // Collects all pragma directives locations.
734    fn collect_pragmas(&self) -> HashSet<ItemLocation> {
735        self.asts
736            .iter()
737            .flat_map(|(path, ast)| {
738                ast.nodes.iter().filter_map(|node| match node {
739                    SourceUnitPart::PragmaDirective(import) => {
740                        ItemLocation::try_from_source_loc(&import.src, path.clone())
741                    }
742                    _ => None,
743                })
744            })
745            .collect()
746    }
747
748    /// Removes all license identifiers from all sources. Returns license identifier from target
749    /// file, if any.
750    fn process_licenses(&self, updates: &mut Updates) -> Option<String> {
751        let mut target_license = None;
752
753        for loc in &self.collect_licenses() {
754            if loc.path == self.target {
755                let license_line = self.read_location(loc);
756                let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
757                target_license = Some(license_line[license_start..].trim().to_string());
758            }
759            updates.entry(loc.path.clone()).or_default().insert((
760                loc.start,
761                loc.end,
762                String::new(),
763            ));
764        }
765
766        target_license
767    }
768
769    // Collects all SPDX-License-Identifier locations.
770    fn collect_licenses(&self) -> HashSet<ItemLocation> {
771        self.sources
772            .iter()
773            .flat_map(|(path, source)| {
774                let mut licenses = HashSet::new();
775                if let Some(license_start) = source.content.find("SPDX-License-Identifier:") {
776                    let start =
777                        source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0);
778                    let end = start
779                        + source.content[start..]
780                            .find('\n')
781                            .unwrap_or(source.content.len() - start);
782                    licenses.insert(ItemLocation { path: path.clone(), start, end });
783                }
784                licenses
785            })
786            .collect()
787    }
788
789    // Reads value from the given location of a source file.
790    fn read_location(&self, loc: &ItemLocation) -> &str {
791        let content: &str = &self.sources.get(&loc.path).unwrap().content;
792        &content[loc.start..loc.end]
793    }
794}
795
796/// Performs DFS to collect all dependencies of a target
797fn collect_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
798    path: &Path,
799    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
800    graph: &Graph<P>,
801    deps: &mut HashSet<PathBuf>,
802) -> Result<()> {
803    if deps.insert(path.to_path_buf()) {
804        let target_dir = path.parent().ok_or_else(|| {
805            SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display()))
806        })?;
807
808        let node_id = graph
809            .files()
810            .get(path)
811            .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?;
812
813        if let Some(data) = graph.node(*node_id).data.sol_data() {
814            for import in &data.imports {
815                let path = paths.resolve_import(target_dir, import.data().path())?;
816                collect_deps(&path, paths, graph, deps)?;
817            }
818        }
819    }
820    Ok(())
821}
822
823/// We want to make order in which sources are written to resulted flattened file
824/// deterministic.
825///
826/// We can't just sort files alphabetically as it might break compilation, because Solidity
827/// does not allow base class definitions to appear after derived contract
828/// definitions.
829///
830/// Instead, we sort files by the number of their dependencies (imports of any depth) in ascending
831/// order. If files have the same number of dependencies, we sort them alphabetically.
832/// Target file is always placed last.
833pub fn collect_ordered_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
834    path: &Path,
835    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
836    graph: &Graph<P>,
837) -> Result<Vec<PathBuf>> {
838    let mut deps = HashSet::new();
839    collect_deps(path, paths, graph, &mut deps)?;
840
841    // Remove path prior counting dependencies
842    // It will be added later to the end of resulted Vec
843    deps.remove(path);
844
845    let mut paths_with_deps_count = Vec::new();
846    for path in deps {
847        let mut path_deps = HashSet::new();
848        collect_deps(&path, paths, graph, &mut path_deps)?;
849        paths_with_deps_count.push((path_deps.len(), path));
850    }
851
852    paths_with_deps_count.sort_by(|(count_0, path_0), (count_1, path_1)| {
853        // Compare dependency counts
854        match count_0.cmp(count_1) {
855            o if !o.is_eq() => return o,
856            _ => {}
857        };
858
859        // Try comparing file names
860        if let Some((name_0, name_1)) = path_0.file_name().zip(path_1.file_name()) {
861            match name_0.cmp(name_1) {
862                o if !o.is_eq() => return o,
863                _ => {}
864            }
865        }
866
867        // If both filenames and dependecy counts are equal, fallback to comparing file paths
868        path_0.cmp(path_1)
869    });
870
871    let mut ordered_deps =
872        paths_with_deps_count.into_iter().map(|(_, path)| path).collect::<Vec<_>>();
873
874    ordered_deps.push(path.to_path_buf());
875
876    Ok(ordered_deps)
877}
878
879pub fn combine_version_pragmas(pragmas: &[impl AsRef<str>]) -> Option<String> {
880    let versions = pragmas
881        .iter()
882        .map(AsRef::as_ref)
883        .filter_map(SolData::parse_version_pragma)
884        .filter_map(Result::ok)
885        .flat_map(|req| req.comparators)
886        .map(|comp| comp.to_string())
887        .collect::<BTreeSet<_>>();
888    if versions.is_empty() {
889        return None;
890    }
891    Some(format!("pragma solidity {};", versions.iter().format(" ")))
892}