// foundry_compilers/flatten.rs
1use crate::{
2    ArtifactOutput, CompilerSettings, Graph, Project, ProjectPathsConfig, SourceParser, Updates,
3    apply_updates,
4    compilers::{Compiler, ParsedSource},
5    filter::MaybeSolData,
6    resolver::parse::SolData,
7};
8use foundry_compilers_artifacts::{
9    ast::{visitor::Visitor, *},
10    output_selection::OutputSelection,
11    sources::{Source, Sources},
12};
13use foundry_compilers_core::{
14    error::{Result, SolcError},
15    utils,
16};
17use itertools::Itertools;
18use std::{
19    collections::{BTreeSet, HashMap, HashSet},
20    hash::Hash,
21    path::{Path, PathBuf},
22    sync::Arc,
23};
24use visitor::Walk;
25
/// Alternative of `SourceLocation` which includes path of the file.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ItemLocation {
    /// Path of the source file this location points into.
    path: PathBuf,
    /// Byte offset of the first character of the item.
    start: usize,
    /// Byte offset one past the last character of the item (exclusive end).
    end: usize,
}
33
34impl ItemLocation {
35    fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option<Self> {
36        let start = src.start?;
37        let end = start + src.length?;
38
39        Some(Self { path, start, end })
40    }
41
42    const fn length(&self) -> usize {
43        self.end - self.start
44    }
45}
46
/// Visitor exploring AST and collecting all references to declarations via `Identifier` and
/// `IdentifierPath` nodes.
///
/// It also collects `MemberAccess` parts. So, if we have `X.Y` expression, loc and AST ID will be
/// saved for Y only.
///
/// That way, even if we have a long `MemberAccess` expression (a.b.c.d) then the first member (a)
/// will be collected as either `Identifier` or `IdentifierPath`, and all subsequent parts (b, c, d)
/// will be collected as `MemberAccess` parts.
struct ReferencesCollector {
    /// Path of the file whose AST is being visited; used for every collected location.
    path: PathBuf,
    /// Mapping from referenced declaration AST id to all locations referencing it.
    references: HashMap<isize, HashSet<ItemLocation>>,
}
60
61impl ReferencesCollector {
62    fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) {
63        if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) {
64            self.references.entry(id).or_default().insert(loc);
65        }
66    }
67}
68
69impl Visitor for ReferencesCollector {
70    fn visit_identifier(&mut self, identifier: &Identifier) {
71        if let Some(id) = identifier.referenced_declaration {
72            self.process_referenced_declaration(id, &identifier.src);
73        }
74    }
75
76    fn visit_identifier_path(&mut self, path: &IdentifierPath) {
77        self.process_referenced_declaration(path.referenced_declaration, &path.src);
78    }
79
80    fn visit_member_access(&mut self, access: &MemberAccess) {
81        if let Some(referenced_declaration) = access.referenced_declaration
82            && let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length)
83        {
84            let name_length = access.member_name.len();
85            // Accessed member name is in the last name.len() symbols of the expression.
86            let start = src_start + src_length - name_length;
87            let end = start + name_length;
88
89            self.references.entry(referenced_declaration).or_default().insert(ItemLocation {
90                start,
91                end,
92                path: self.path.clone(),
93            });
94        }
95    }
96
97    fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
98        let mut src = reference.src;
99
100        // If suffix is used in assembly reference (e.g. value.slot), it will be included into src.
101        // However, we are only interested in the referenced name, thus we strip .<suffix> part.
102        if let Some(suffix) = &reference.suffix
103            && let Some(len) = src.length.as_mut()
104        {
105            let suffix_len = suffix.to_string().len();
106            *len -= suffix_len + 1;
107        }
108
109        self.process_referenced_declaration(reference.declaration as isize, &src);
110    }
111}
112
/// Final ingredients of a flattened file: updated sources plus the pragmas and
/// license that should head the output.
pub struct FlatteningResult {
    /// Updated source in the order they should be written to the output file.
    sources: Vec<String>,
    /// Pragmas that should be present in the target file.
    pragmas: Vec<String>,
    /// License identifier that should be present in the target file.
    license: Option<String>,
}
121
122impl FlatteningResult {
123    fn new(
124        mut flattener: Flattener,
125        updates: Updates,
126        pragmas: Vec<String>,
127        license: Option<String>,
128    ) -> Self {
129        apply_updates(&mut flattener.sources, updates);
130
131        let sources = flattener
132            .ordered_sources
133            .iter()
134            .map(|path| flattener.sources.remove(path).unwrap().content)
135            .map(Arc::unwrap_or_clone)
136            .collect();
137
138        Self { sources, pragmas, license }
139    }
140
141    fn get_flattened_target(&self) -> String {
142        let mut result = String::new();
143
144        if let Some(license) = &self.license {
145            result.push_str(&format!("// {license}\n"));
146        }
147        for pragma in &self.pragmas {
148            result.push_str(&format!("{pragma}\n"));
149        }
150        for source in &self.sources {
151            result.push_str(&format!("\n\n{source}"));
152        }
153
154        format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())
155    }
156}
157
/// Errors that can occur while preparing or performing flattening.
#[derive(Debug, thiserror::Error)]
pub enum FlattenerError {
    /// Compiling the flattening target (to obtain its ASTs) failed.
    #[error("Failed to compile {0}")]
    Compilation(SolcError),
    /// Any other error (I/O, source resolution, serialization, ...).
    #[error(transparent)]
    Other(SolcError),
}
165
166impl<T: Into<SolcError>> From<T> for FlattenerError {
167    fn from(err: T) -> Self {
168        Self::Other(err.into())
169    }
170}
171
/// Context for flattening. Stores all sources and ASTs that are in scope of the flattening target.
pub struct Flattener {
    /// Target file to flatten.
    target: PathBuf,
    /// Sources including only target and its dependencies (imports of any depth).
    sources: Sources,
    /// Vec of (path, ast) pairs.
    asts: Vec<(PathBuf, SourceUnit)>,
    /// Sources in the order they should be written to the output file.
    ordered_sources: Vec<PathBuf>,
    /// Project root directory.
    project_root: PathBuf,
}
185
186impl Flattener {
    /// Compiles the target file and prepares AST and analysis data for flattening.
    ///
    /// Reconfigures the project to a one-shot, artifact-free compilation that only
    /// requests AST output, then reads the target plus all of its dependencies
    /// (in output order) and pairs each with its strongly typed AST.
    pub fn new<C: Compiler, T: ArtifactOutput<CompilerContract = C::CompilerContract>>(
        mut project: Project<C, T>,
        target: &Path,
    ) -> std::result::Result<Self, FlattenerError>
    where
        C::Parser: SourceParser<ParsedSource: MaybeSolData>,
    {
        // Configure project to compile the target file and only request AST for target file.
        project.cached = false;
        project.no_artifacts = true;
        project.settings.update_output_selection(|selection| {
            *selection = OutputSelection::ast_output_selection();
        });

        let output = project.compile_file(target).map_err(FlattenerError::Compilation)?;

        if output.has_compiler_errors() {
            return Err(FlattenerError::Compilation(SolcError::msg(&output)));
        }

        let output = output.compiler_output;

        // Resolve the import graph starting from the target to discover dependencies.
        let sources = Source::read_all([target.to_path_buf()])?;
        let graph = Graph::<C::Parser>::resolve_sources(&project.paths, sources)?;

        // Target and its transitive dependencies in flattened-output order.
        let ordered_sources = collect_ordered_deps(target, &project.paths, &graph)?;

        // Normalize Windows path separators to forward slashes so paths match
        // the keys used by the compiler output.
        #[cfg(windows)]
        let ordered_sources = {
            let mut sources = ordered_sources;
            use path_slash::PathBufExt;
            for p in &mut sources {
                *p = PathBuf::from(p.to_slash_lossy().to_string());
            }
            sources
        };

        let sources = Source::read_all(&ordered_sources)?;

        // Convert all ASTs from artifacts to strongly typed ASTs
        let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new();
        for (path, ast) in output.sources.0.iter().filter_map(|(path, files)| {
            if let Some(ast) = files.first().and_then(|source| source.source_file.ast.as_ref())
                && sources.contains_key(path)
            {
                return Some((path, ast));
            }
            None
        }) {
            // Round-trip through JSON to turn the loosely typed artifact AST into
            // the strongly typed `SourceUnit` representation.
            asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?));
        }

        Ok(Self {
            target: target.into(),
            sources,
            asts,
            ordered_sources,
            project_root: project.root().to_path_buf(),
        })
    }
248
249    /// Flattens target file and returns the result as a string
250    ///
251    /// Flattening process includes following steps:
252    /// 1. Find all file-level definitions and rename references to them via aliased or qualified
253    ///    imports.
254    /// 2. Find all duplicates among file-level definitions and rename them to avoid conflicts.
255    /// 3. Remove all imports.
256    /// 4. Remove all pragmas except for the ones in the target file.
257    /// 5. Remove all license identifiers except for the one in the target file.
258    pub fn flatten(self) -> String {
259        let mut updates = Updates::new();
260
261        self.append_filenames(&mut updates);
262        let top_level_names = self.rename_top_level_definitions(&mut updates);
263        self.rename_contract_level_types_references(&top_level_names, &mut updates);
264        self.remove_qualified_imports(&mut updates);
265        self.update_inheritdocs(&top_level_names, &mut updates);
266
267        self.remove_imports(&mut updates);
268        let target_pragmas = self.process_pragmas(&mut updates);
269        let target_license = self.process_licenses(&mut updates);
270
271        self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
272    }
273
    /// Consumes the flattener and builds a [`FlatteningResult`] by applying the
    /// collected `updates` to the sources.
    fn flatten_result(
        self,
        updates: Updates,
        target_pragmas: Vec<String>,
        target_license: Option<String>,
    ) -> FlatteningResult {
        FlatteningResult::new(self, updates, target_pragmas, target_license)
    }
282
283    /// Appends a comment with the file name to the beginning of each source.
284    fn append_filenames(&self, updates: &mut Updates) {
285        for path in &self.ordered_sources {
286            updates.entry(path.clone()).or_default().insert((
287                0,
288                0,
289                format!("// {}\n", path.strip_prefix(&self.project_root).unwrap_or(path).display()),
290            ));
291        }
292    }
293
294    /// Finds and goes over all references to file-level definitions and updates them to match
295    /// definition name. This is needed for two reasons:
296    /// 1. We want to rename all aliased or qualified imports.
297    /// 2. We want to find any duplicates and rename them to avoid conflicts.
298    ///
299    /// If we find more than 1 declaration with the same name, it's name is getting changed.
300    /// Two Counter contracts will be renamed to Counter_0 and Counter_1
301    ///
302    /// Returns mapping from top-level declaration id to its name (possibly updated)
303    fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap<usize, String> {
304        let top_level_definitions = self.collect_top_level_definitions();
305        let references = self.collect_references();
306
307        let mut top_level_names = HashMap::new();
308
309        for (name, ids) in top_level_definitions {
310            let mut definition_name = name.clone();
311            let needs_rename = ids.len() > 1;
312
313            let mut ids = ids.clone().into_iter().collect::<Vec<_>>();
314            if needs_rename {
315                // `loc.path` is expected to be different for each id because there can't be 2
316                // top-level declarations with the same name in the same file.
317                //
318                // Sorting by index loc.path and loc.start in sorted files to make the renaming
319                // process deterministic.
320                ids.sort_by_key(|(_, loc)| {
321                    (self.ordered_sources.iter().position(|p| p == &loc.path).unwrap(), loc.start)
322                });
323            }
324            for (i, (id, loc)) in ids.iter().enumerate() {
325                if needs_rename {
326                    definition_name = format!("{name}_{i}");
327                }
328                updates.entry(loc.path.clone()).or_default().insert((
329                    loc.start,
330                    loc.end,
331                    definition_name.clone(),
332                ));
333                if let Some(references) = references.get(&(*id as isize)) {
334                    for loc in references {
335                        updates.entry(loc.path.clone()).or_default().insert((
336                            loc.start,
337                            loc.end,
338                            definition_name.clone(),
339                        ));
340                    }
341                }
342
343                top_level_names.insert(*id, definition_name.clone());
344            }
345        }
346        top_level_names
347    }
348
349    /// This is not very clean, but in most cases effective enough method to remove qualified
350    /// imports from sources.
351    ///
352    /// Every qualified import part is an `Identifier` with `referencedDeclaration` field matching
353    /// ID of one of the import directives.
354    ///
355    /// This approach works by firstly collecting all IDs of import directives, and then looks for
356    /// any references of them. Once the reference is found, it's full length is getting removed
357    /// from source + 1 character ('.')
358    ///
359    /// This should work correctly for vast majority of cases, however there are situations for
360    /// which such approach won't work, most of which are related to code being formatted in an
361    /// uncommon way.
362    fn remove_qualified_imports(&self, updates: &mut Updates) {
363        let imports_ids = self
364            .asts
365            .iter()
366            .flat_map(|(_, ast)| {
367                ast.nodes.iter().filter_map(|node| match node {
368                    SourceUnitPart::ImportDirective(directive) => Some(directive.id),
369                    _ => None,
370                })
371            })
372            .collect::<HashSet<_>>();
373
374        let references = self.collect_references();
375
376        for (id, locs) in references {
377            if !imports_ids.contains(&(id as usize)) {
378                continue;
379            }
380
381            for loc in locs {
382                updates.entry(loc.path).or_default().insert((
383                    loc.start,
384                    loc.end + 1,
385                    String::new(),
386                ));
387            }
388        }
389    }
390
    /// Here we are going through all references to items defined in scope of contracts and updating
    /// them to be using correct parent contract name.
    ///
    /// This will only operate on references from `IdentifierPath` nodes.
    fn rename_contract_level_types_references(
        &self,
        top_level_names: &HashMap<usize, String>,
        updates: &mut Updates,
    ) {
        let contract_level_definitions = self.collect_contract_level_definitions();

        for (path, ast) in &self.asts {
            for node in &ast.nodes {
                // NOTE(review): the collector is given `self.target` as its path even though
                // we are visiting `path`'s AST; only `loc.start`/`loc.end`/`loc.length()` of
                // the collected locations are used below, so the mismatch appears harmless —
                // confirm before relying on `loc.path` from these references.
                let mut collector =
                    ReferencesCollector { path: self.target.clone(), references: HashMap::new() };

                node.walk(&mut collector);

                let references = collector.references;

                for (id, locs) in references {
                    if let Some((name, contract_id)) =
                        contract_level_definitions.get(&(id as usize))
                    {
                        for loc in &locs {
                            // If child item is referenced directly by it's name it's either defined
                            // in the same contract or in one of it's base contracts, so we don't
                            // have to change anything.
                            // Comparing lengths is enough because such items cannot be aliased.
                            if loc.length() == name.len() {
                                continue;
                            }
                            // If it was referenced somehow else, we rename it to `Parent.Child`
                            // format.
                            let parent_name = top_level_names.get(contract_id).unwrap();
                            updates.entry(path.clone()).or_default().insert((
                                loc.start,
                                loc.end,
                                format!("{parent_name}.{name}"),
                            ));
                        }
                    }
                }
            }
        }
    }
437
    /// Finds all @inheritdoc tags in natspec comments and tries replacing them.
    ///
    /// We will either replace contract name or remove @inheritdoc tag completely to avoid
    /// generating invalid source code.
    fn update_inheritdocs(&self, top_level_names: &HashMap<usize, String>, updates: &mut Updates) {
        trace!("updating @inheritdoc tags");
        for (path, ast) in &self.asts {
            // Collect all exported symbols for this source unit
            // @inheritdoc value is either one of those or qualified import path which we don't
            // support
            let exported_symbols = ast
                .exported_symbols
                .iter()
                .filter(|(_, ids)| !ids.is_empty())
                .map(|(name, ids)| (name.as_str(), ids[0]))
                .collect::<HashMap<_, _>>();

            // Collect all docs in all contracts
            let docs = ast
                .nodes
                .iter()
                .filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(d) => Some(d),
                    _ => None,
                })
                .flat_map(|contract| {
                    contract.nodes.iter().filter_map(|node| match node {
                        ContractDefinitionPart::EventDefinition(event) => {
                            event.documentation.as_ref()
                        }
                        ContractDefinitionPart::ErrorDefinition(error) => {
                            error.documentation.as_ref()
                        }
                        ContractDefinitionPart::FunctionDefinition(func) => {
                            func.documentation.as_ref()
                        }
                        ContractDefinitionPart::VariableDeclaration(var) => {
                            var.documentation.as_ref()
                        }
                        _ => None,
                    })
                });

            docs.for_each(|doc| {
                let Documentation::Structured(doc) = doc else {
                    return
                };
                let src_start = doc.src.start.unwrap();
                let src_end = src_start + doc.src.length.unwrap();

                // Documentation node has `text` field, however, it does not contain
                // slashes and we can't use if to find positions.
                let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end];
                let tag_len = "@inheritdoc".len();

                if let Some(tag_start) = content.find("@inheritdoc") {
                    trace!("processing doc with content {:?}", content);
                    // First non-space character after the tag starts the referenced name.
                    if let Some(name_start) = content[tag_start + tag_len..]
                        .find(|c| c != ' ')
                        .map(|p| p + tag_start + tag_len)
                    {
                        // The name ends at the next whitespace or comment delimiter
                        // (or at the end of the doc block).
                        let name_end = content[name_start..]
                            .find([' ', '\n', '*', '/'])
                            .map(|p| p + name_start)
                            .unwrap_or(content.len());

                        let name = &content[name_start..name_end];
                        trace!("found name {name}");

                        let mut new_name = None;

                        if let Some(ast_id) = exported_symbols.get(name) {
                            if let Some(name) = top_level_names.get(ast_id) {
                                new_name = Some(name);
                            } else {
                                trace!(identifiers=?top_level_names, "ast id {ast_id} cannot be matched to top-level identifier");
                            }
                        }

                        if let Some(new_name) = new_name {
                            trace!("updating tag value with {new_name}");
                            // Offsets above are doc-relative; shift by src_start to make
                            // them file-relative before recording the update.
                            updates.entry(path.clone()).or_default().insert((
                                src_start + name_start,
                                src_start + name_end,
                                new_name.clone(),
                            ));
                        } else {
                            trace!("name is unknown, removing @inheritdoc tag");
                            updates.entry(path.clone()).or_default().insert((
                                src_start + tag_start,
                                src_start + name_end,
                                String::new(),
                            ));
                        }
                    }
                }
            });
        }
    }
537
    /// Processes all ASTs and collects all top-level definitions in the form of
    /// a mapping from name to (definition id, source location)
    fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes
                    .iter()
                    .filter_map(|node| match node {
                        SourceUnitPart::ContractDefinition(contract) => Some((
                            &contract.name,
                            contract.id,
                            &contract.src,
                            &contract.name_location,
                        )),
                        SourceUnitPart::EnumDefinition(enum_) => {
                            Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location))
                        }
                        SourceUnitPart::StructDefinition(struct_) => {
                            Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location))
                        }
                        SourceUnitPart::FunctionDefinition(func) => {
                            Some((&func.name, func.id, &func.src, &func.name_location))
                        }
                        SourceUnitPart::VariableDeclaration(var) => {
                            Some((&var.name, var.id, &var.src, &var.name_location))
                        }
                        SourceUnitPart::UserDefinedValueTypeDefinition(type_) => {
                            Some((&type_.name, type_.id, &type_.src, &type_.name_location))
                        }
                        _ => None,
                    })
                    .map(|(name, id, src, maybe_name_src)| {
                        let loc = match maybe_name_src {
                            // Compiler provided an explicit location for the name itself.
                            Some(src) => {
                                ItemLocation::try_from_source_loc(src, path.clone()).unwrap()
                            }
                            None => {
                                // Find location of name in source
                                let content: &str = &self.sources.get(path).unwrap().content;
                                let start = src.start.unwrap();
                                let end = start + src.length.unwrap();

                                // First occurrence of the name inside the item's own span.
                                let name_start = content[start..end].find(name).unwrap();
                                let name_end = name_start + name.len();

                                ItemLocation {
                                    path: path.clone(),
                                    start: start + name_start,
                                    end: start + name_end,
                                }
                            }
                        };

                        (name, (id, loc))
                    })
            })
            .fold(HashMap::new(), |mut acc, (name, (id, item_location))| {
                // The same name may be defined in multiple files; group all of them.
                acc.entry(name).or_default().insert((id, item_location));
                acc
            })
    }
600
601    /// Collect all contract-level definitions in the form of a mapping from definition id to
602    /// (definition name, contract id)
603    fn collect_contract_level_definitions(&self) -> HashMap<usize, (&String, usize)> {
604        self.asts
605            .iter()
606            .flat_map(|(_, ast)| {
607                ast.nodes.iter().filter_map(|node| match node {
608                    SourceUnitPart::ContractDefinition(contract) => {
609                        Some((contract.id, &contract.nodes))
610                    }
611                    _ => None,
612                })
613            })
614            .flat_map(|(contract_id, nodes)| {
615                nodes.iter().filter_map(move |node| match node {
616                    ContractDefinitionPart::EnumDefinition(enum_) => {
617                        Some((enum_.id, (&enum_.name, contract_id)))
618                    }
619                    ContractDefinitionPart::ErrorDefinition(error) => {
620                        Some((error.id, (&error.name, contract_id)))
621                    }
622                    ContractDefinitionPart::EventDefinition(event) => {
623                        Some((event.id, (&event.name, contract_id)))
624                    }
625                    ContractDefinitionPart::StructDefinition(struct_) => {
626                        Some((struct_.id, (&struct_.name, contract_id)))
627                    }
628                    ContractDefinitionPart::FunctionDefinition(function) => {
629                        Some((function.id, (&function.name, contract_id)))
630                    }
631                    ContractDefinitionPart::VariableDeclaration(variable) => {
632                        Some((variable.id, (&variable.name, contract_id)))
633                    }
634                    ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => {
635                        Some((value_type.id, (&value_type.name, contract_id)))
636                    }
637                    _ => None,
638                })
639            })
640            .collect()
641    }
642
643    /// Collects all references to any declaration in the form of a mapping from declaration id to
644    /// set of source locations it appears in
645    fn collect_references(&self) -> HashMap<isize, HashSet<ItemLocation>> {
646        self.asts
647            .iter()
648            .flat_map(|(path, ast)| {
649                let mut collector =
650                    ReferencesCollector { path: path.clone(), references: HashMap::new() };
651                ast.walk(&mut collector);
652                collector.references
653            })
654            .fold(HashMap::new(), |mut acc, (id, locs)| {
655                acc.entry(id).or_default().extend(locs);
656                acc
657            })
658    }
659
660    /// Removes all imports from all sources.
661    fn remove_imports(&self, updates: &mut Updates) {
662        for loc in self.collect_imports() {
663            updates.entry(loc.path.clone()).or_default().insert((
664                loc.start,
665                loc.end,
666                String::new(),
667            ));
668        }
669    }
670
671    // Collects all imports locations.
672    fn collect_imports(&self) -> HashSet<ItemLocation> {
673        self.asts
674            .iter()
675            .flat_map(|(path, ast)| {
676                ast.nodes.iter().filter_map(|node| match node {
677                    SourceUnitPart::ImportDirective(import) => {
678                        ItemLocation::try_from_source_loc(&import.src, path.clone())
679                    }
680                    _ => None,
681                })
682            })
683            .collect()
684    }
685
686    /// Removes all pragma directives from all sources. Returns Vec with experimental and combined
687    /// version pragmas (if present).
688    fn process_pragmas(&self, updates: &mut Updates) -> Vec<String> {
689        let mut abicoder_v2 = None;
690
691        let pragmas = self.collect_pragmas();
692        let mut version_pragmas = Vec::new();
693
694        for loc in &pragmas {
695            let pragma_content = self.read_location(loc);
696            if pragma_content.contains("experimental") || pragma_content.contains("abicoder") {
697                if abicoder_v2.is_none() {
698                    abicoder_v2 = Some(self.read_location(loc).to_string());
699                }
700            } else if pragma_content.contains("solidity") {
701                version_pragmas.push(pragma_content);
702            }
703
704            updates.entry(loc.path.clone()).or_default().insert((
705                loc.start,
706                loc.end,
707                String::new(),
708            ));
709        }
710
711        let mut pragmas = Vec::new();
712
713        if let Some(version_pragma) = combine_version_pragmas(&version_pragmas) {
714            pragmas.push(version_pragma);
715        }
716
717        if let Some(pragma) = abicoder_v2 {
718            pragmas.push(pragma);
719        }
720
721        pragmas
722    }
723
724    // Collects all pragma directives locations.
725    fn collect_pragmas(&self) -> HashSet<ItemLocation> {
726        self.asts
727            .iter()
728            .flat_map(|(path, ast)| {
729                ast.nodes.iter().filter_map(|node| match node {
730                    SourceUnitPart::PragmaDirective(import) => {
731                        ItemLocation::try_from_source_loc(&import.src, path.clone())
732                    }
733                    _ => None,
734                })
735            })
736            .collect()
737    }
738
739    /// Removes all license identifiers from all sources. Returns license identifier from target
740    /// file, if any.
741    fn process_licenses(&self, updates: &mut Updates) -> Option<String> {
742        let mut target_license = None;
743
744        for loc in &self.collect_licenses() {
745            if loc.path == self.target {
746                let license_line = self.read_location(loc);
747                let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
748                target_license = Some(license_line[license_start..].trim().to_string());
749            }
750            updates.entry(loc.path.clone()).or_default().insert((
751                loc.start,
752                loc.end,
753                String::new(),
754            ));
755        }
756
757        target_license
758    }
759
760    // Collects all SPDX-License-Identifier locations.
761    fn collect_licenses(&self) -> HashSet<ItemLocation> {
762        self.sources
763            .iter()
764            .flat_map(|(path, source)| {
765                let mut licenses = HashSet::new();
766                if let Some(license_start) = source.content.find("SPDX-License-Identifier:") {
767                    let start =
768                        source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0);
769                    let end = start
770                        + source.content[start..]
771                            .find('\n')
772                            .unwrap_or(source.content.len() - start);
773                    licenses.insert(ItemLocation { path: path.clone(), start, end });
774                }
775                licenses
776            })
777            .collect()
778    }
779
780    // Reads value from the given location of a source file.
781    fn read_location(&self, loc: &ItemLocation) -> &str {
782        let content: &str = &self.sources.get(&loc.path).unwrap().content;
783        &content[loc.start..loc.end]
784    }
785}
786
787/// Performs DFS to collect all dependencies of a target
788fn collect_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
789    path: &Path,
790    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
791    graph: &Graph<P>,
792    deps: &mut HashSet<PathBuf>,
793) -> Result<()> {
794    if deps.insert(path.to_path_buf()) {
795        let target_dir = path.parent().ok_or_else(|| {
796            SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display()))
797        })?;
798
799        let node_id = graph
800            .files()
801            .get(path)
802            .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?;
803
804        if let Some(data) = graph.node(*node_id).data.sol_data() {
805            for import in &data.imports {
806                let path = paths.resolve_import(target_dir, import.data().path())?;
807                collect_deps(&path, paths, graph, deps)?;
808            }
809        }
810    }
811    Ok(())
812}
813
814/// We want to make order in which sources are written to resulted flattened file
815/// deterministic.
816///
817/// We can't just sort files alphabetically as it might break compilation, because Solidity
818/// does not allow base class definitions to appear after derived contract
819/// definitions.
820///
821/// Instead, we sort files by the number of their dependencies (imports of any depth) in ascending
822/// order. If files have the same number of dependencies, we sort them alphabetically.
823/// Target file is always placed last.
824pub fn collect_ordered_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
825    path: &Path,
826    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
827    graph: &Graph<P>,
828) -> Result<Vec<PathBuf>> {
829    let mut deps = HashSet::new();
830    collect_deps(path, paths, graph, &mut deps)?;
831
832    // Remove path prior counting dependencies
833    // It will be added later to the end of resulted Vec
834    deps.remove(path);
835
836    let mut paths_with_deps_count = Vec::new();
837    for path in deps {
838        let mut path_deps = HashSet::new();
839        collect_deps(&path, paths, graph, &mut path_deps)?;
840        paths_with_deps_count.push((path_deps.len(), path));
841    }
842
843    paths_with_deps_count.sort_by(|(count_0, path_0), (count_1, path_1)| {
844        // Compare dependency counts
845        match count_0.cmp(count_1) {
846            o if !o.is_eq() => return o,
847            _ => {}
848        };
849
850        // Try comparing file names
851        if let Some((name_0, name_1)) = path_0.file_name().zip(path_1.file_name()) {
852            match name_0.cmp(name_1) {
853                o if !o.is_eq() => return o,
854                _ => {}
855            }
856        }
857
858        // If both filenames and dependency counts are equal, fallback to comparing file paths
859        path_0.cmp(path_1)
860    });
861
862    let mut ordered_deps =
863        paths_with_deps_count.into_iter().map(|(_, path)| path).collect::<Vec<_>>();
864
865    ordered_deps.push(path.to_path_buf());
866
867    Ok(ordered_deps)
868}
869
870pub fn combine_version_pragmas(pragmas: &[impl AsRef<str>]) -> Option<String> {
871    let versions = pragmas
872        .iter()
873        .map(AsRef::as_ref)
874        .filter_map(SolData::parse_version_pragma)
875        .filter_map(Result::ok)
876        .flat_map(|req| req.comparators)
877        .map(|comp| comp.to_string())
878        .collect::<BTreeSet<_>>();
879    if versions.is_empty() {
880        return None;
881    }
882    Some(format!("pragma solidity {};", versions.iter().format(" ")))
883}