1use crate::{
2 apply_updates,
3 compilers::{Compiler, ParsedSource},
4 filter::MaybeSolData,
5 resolver::parse::SolData,
6 ArtifactOutput, CompilerSettings, Graph, Project, ProjectPathsConfig, SourceParser, Updates,
7};
8use foundry_compilers_artifacts::{
9 ast::{visitor::Visitor, *},
10 output_selection::OutputSelection,
11 solc::ExternalInlineAssemblyReference,
12 sources::{Source, Sources},
13 ContractDefinitionPart, SourceUnit, SourceUnitPart,
14};
15use foundry_compilers_core::{
16 error::{Result, SolcError},
17 utils,
18};
19use itertools::Itertools;
20use std::{
21 collections::{BTreeSet, HashMap, HashSet},
22 hash::Hash,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use visitor::Walk;
27
/// Location of an item (identifier, definition, pragma, …) in a source file,
/// expressed as absolute byte offsets into that file's content.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ItemLocation {
    // Path of the source file this location points into.
    path: PathBuf,
    // Byte offset of the first character of the item.
    start: usize,
    // Byte offset one past the last character of the item.
    end: usize,
}
35
36impl ItemLocation {
37 fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option<Self> {
38 let start = src.start?;
39 let end = start + src.length?;
40
41 Some(Self { path, start, end })
42 }
43
44 fn length(&self) -> usize {
45 self.end - self.start
46 }
47}
48
/// AST visitor that records, for every referenced declaration id, the set of
/// locations at which it is referenced within a single source file.
struct ReferencesCollector {
    // Path of the file the visited AST belongs to.
    path: PathBuf,
    // Maps declaration AST ids to all reference locations found so far.
    references: HashMap<isize, HashSet<ItemLocation>>,
}
62
63impl ReferencesCollector {
64 fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) {
65 if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) {
66 self.references.entry(id).or_default().insert(loc);
67 }
68 }
69}
70
impl Visitor for ReferencesCollector {
    fn visit_identifier(&mut self, identifier: &Identifier) {
        // Plain identifiers only carry a declaration id when the compiler
        // resolved them; unresolved ones are skipped.
        if let Some(id) = identifier.referenced_declaration {
            self.process_referenced_declaration(id, &identifier.src);
        }
    }

    fn visit_identifier_path(&mut self, path: &IdentifierPath) {
        self.process_referenced_declaration(path.referenced_declaration, &path.src);
    }

    fn visit_member_access(&mut self, access: &MemberAccess) {
        if let Some(referenced_declaration) = access.referenced_declaration {
            if let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length) {
                // The src span covers the whole `expr.member` expression;
                // narrow it down to just the trailing member name.
                let name_length = access.member_name.len();
                let start = src_start + src_length - name_length;
                let end = start + name_length;

                self.references.entry(referenced_declaration).or_default().insert(ItemLocation {
                    start,
                    end,
                    path: self.path.to_path_buf(),
                });
            }
        }
    }

    fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
        let mut src = reference.src;

        // When the reference carries a suffix (e.g. `.slot`), shrink the span
        // so it covers only the identifier itself: drop the suffix plus the
        // separating `.` character.
        if let Some(suffix) = &reference.suffix {
            if let Some(len) = src.length.as_mut() {
                let suffix_len = suffix.to_string().len();
                *len -= suffix_len + 1;
            }
        }

        self.process_referenced_declaration(reference.declaration as isize, &src);
    }
}
114
/// Intermediate flattening result: processed source contents in dependency
/// order along with the pragmas and license to emit at the top of the
/// flattened file.
pub struct FlatteningResult {
    // Updated source contents in the order they should be concatenated.
    sources: Vec<String>,
    // Pragmas (version / abicoder) that should be present in the target file.
    pragmas: Vec<String>,
    // SPDX license line of the target file, if any.
    license: Option<String>,
}
123
124impl FlatteningResult {
125 fn new(
126 mut flattener: Flattener,
127 updates: Updates,
128 pragmas: Vec<String>,
129 license: Option<String>,
130 ) -> Self {
131 apply_updates(&mut flattener.sources, updates);
132
133 let sources = flattener
134 .ordered_sources
135 .iter()
136 .map(|path| flattener.sources.remove(path).unwrap().content)
137 .map(Arc::unwrap_or_clone)
138 .collect();
139
140 Self { sources, pragmas, license }
141 }
142
143 fn get_flattened_target(&self) -> String {
144 let mut result = String::new();
145
146 if let Some(license) = &self.license {
147 result.push_str(&format!("// {license}\n"));
148 }
149 for pragma in &self.pragmas {
150 result.push_str(&format!("{pragma}\n"));
151 }
152 for source in &self.sources {
153 result.push_str(&format!("\n\n{source}"));
154 }
155
156 format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())
157 }
158}
159
/// Errors that can occur while flattening a target file.
#[derive(Debug, thiserror::Error)]
pub enum FlattenerError {
    /// Compilation of the target (needed to obtain ASTs) failed.
    #[error("Failed to compile {0}")]
    Compilation(SolcError),
    /// Any other error, e.g. I/O or import resolution.
    #[error(transparent)]
    Other(SolcError),
}
167
// Any error convertible into `SolcError` is treated as a non-compilation
// error; `Compilation` is only constructed explicitly.
impl<T: Into<SolcError>> From<T> for FlattenerError {
    fn from(err: T) -> Self {
        Self::Other(err.into())
    }
}
173
/// Collects a target file together with its transitive dependencies and
/// flattens them into a single source via coordinated text edits.
pub struct Flattener {
    // Path of the file being flattened.
    target: PathBuf,
    // Contents of the target file and all of its transitive dependencies.
    sources: Sources,
    // Parsed ASTs of all flattened files.
    asts: Vec<(PathBuf, SourceUnit)>,
    // Flattened file paths in the order their contents should be emitted.
    ordered_sources: Vec<PathBuf>,
    // Root of the project, used to strip path prefixes in emitted comments.
    project_root: PathBuf,
}
187
188impl Flattener {
    /// Compiles `target` with AST-only output and collects everything needed
    /// to flatten it: ordered dependencies, their sources, and their ASTs.
    ///
    /// # Errors
    ///
    /// Returns [`FlattenerError::Compilation`] if compilation fails or reports
    /// compiler errors, and [`FlattenerError::Other`] for I/O or resolution
    /// failures.
    pub fn new<C: Compiler, T: ArtifactOutput<CompilerContract = C::CompilerContract>>(
        mut project: Project<C, T>,
        target: &Path,
    ) -> std::result::Result<Self, FlattenerError>
    where
        C::Parser: SourceParser<ParsedSource: MaybeSolData>,
    {
        // Disable caching/artifacts and restrict output to ASTs only: we need
        // the syntax trees, not compiled bytecode.
        project.cached = false;
        project.no_artifacts = true;
        project.settings.update_output_selection(|selection| {
            *selection = OutputSelection::ast_output_selection();
        });

        let output = project.compile_file(target).map_err(FlattenerError::Compilation)?;

        if output.has_compiler_errors() {
            return Err(FlattenerError::Compilation(SolcError::msg(&output)));
        }

        let output = output.compiler_output;

        let sources = Source::read_all([target.to_path_buf()])?;
        let graph = Graph::<C::Parser>::resolve_sources(&project.paths, sources)?;

        let ordered_sources = collect_ordered_deps(target, &project.paths, &graph)?;

        // Normalize to forward slashes so paths compare equal with the
        // compiler output keys on Windows.
        #[cfg(windows)]
        let ordered_sources = {
            let mut sources = ordered_sources;
            use path_slash::PathBufExt;
            for p in &mut sources {
                *p = PathBuf::from(p.to_slash_lossy().to_string());
            }
            sources
        };

        let sources = Source::read_all(&ordered_sources)?;

        // Convert each compiler AST into a typed `SourceUnit` via a serde
        // round-trip, keeping only ASTs of files in the flattened set.
        let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new();
        for (path, ast) in output.sources.0.iter().filter_map(|(path, files)| {
            if let Some(ast) = files.first().and_then(|source| source.source_file.ast.as_ref()) {
                if sources.contains_key(path) {
                    return Some((path, ast));
                }
            }
            None
        }) {
            asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?));
        }

        Ok(Self {
            target: target.into(),
            sources,
            asts,
            ordered_sources,
            project_root: project.root().to_path_buf(),
        })
    }
250
    /// Flattens the target file into a single string.
    ///
    /// Every pass records its edits into `updates` against the *original*
    /// source offsets; the edits are applied in one batch at the end, so
    /// offsets stay valid across passes. Renames run before the passes that
    /// depend on the final top-level names.
    pub fn flatten(self) -> String {
        let mut updates = Updates::new();

        self.append_filenames(&mut updates);
        let top_level_names = self.rename_top_level_definitions(&mut updates);
        self.rename_contract_level_types_references(&top_level_names, &mut updates);
        self.remove_qualified_imports(&mut updates);
        self.update_inheritdocs(&top_level_names, &mut updates);

        self.remove_imports(&mut updates);
        let target_pragmas = self.process_pragmas(&mut updates);
        let target_license = self.process_licenses(&mut updates);

        self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
    }
275
    /// Consumes `self` and wraps the collected edits, pragmas, and license into
    /// a [`FlatteningResult`].
    fn flatten_result(
        self,
        updates: Updates,
        target_pragmas: Vec<String>,
        target_license: Option<String>,
    ) -> FlatteningResult {
        FlatteningResult::new(self, updates, target_pragmas, target_license)
    }
284
285 fn append_filenames(&self, updates: &mut Updates) {
287 for path in &self.ordered_sources {
288 updates.entry(path.clone()).or_default().insert((
289 0,
290 0,
291 format!("// {}\n", path.strip_prefix(&self.project_root).unwrap_or(path).display()),
292 ));
293 }
294 }
295
296 fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap<usize, String> {
306 let top_level_definitions = self.collect_top_level_definitions();
307 let references = self.collect_references();
308
309 let mut top_level_names = HashMap::new();
310
311 for (name, ids) in top_level_definitions {
312 let mut definition_name = name.to_string();
313 let needs_rename = ids.len() > 1;
314
315 let mut ids = ids.clone().into_iter().collect::<Vec<_>>();
316 if needs_rename {
317 ids.sort_by_key(|(_, loc)| {
323 (self.ordered_sources.iter().position(|p| p == &loc.path).unwrap(), loc.start)
324 });
325 }
326 for (i, (id, loc)) in ids.iter().enumerate() {
327 if needs_rename {
328 definition_name = format!("{name}_{i}");
329 }
330 updates.entry(loc.path.clone()).or_default().insert((
331 loc.start,
332 loc.end,
333 definition_name.clone(),
334 ));
335 if let Some(references) = references.get(&(*id as isize)) {
336 for loc in references {
337 updates.entry(loc.path.clone()).or_default().insert((
338 loc.start,
339 loc.end,
340 definition_name.clone(),
341 ));
342 }
343 }
344
345 top_level_names.insert(*id, definition_name.clone());
346 }
347 }
348 top_level_names
349 }
350
    /// Strips `Alias.` qualifiers from references that resolve to import
    /// directives (e.g. `Contracts.Contract` imported via
    /// `import * as Contracts`), since the flattened file has a single
    /// namespace.
    fn remove_qualified_imports(&self, updates: &mut Updates) {
        // AST ids of all import directives across all files.
        let imports_ids = self
            .asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ImportDirective(directive) => Some(directive.id),
                    _ => None,
                })
            })
            .collect::<HashSet<_>>();

        let references = self.collect_references();

        for (id, locs) in references {
            // Only references that point at an import directive are qualifiers.
            if !imports_ids.contains(&(id as usize)) {
                continue;
            }

            for loc in locs {
                // `end + 1` also removes the `.` following the alias.
                updates.entry(loc.path).or_default().insert((
                    loc.start,
                    loc.end + 1,
                    String::new(),
                ));
            }
        }
    }
392
    /// Rewrites qualified references to contract-level definitions
    /// (e.g. `Contract.Struct`) so the contract part uses the contract's final
    /// (possibly renamed) top-level name.
    fn rename_contract_level_types_references(
        &self,
        top_level_names: &HashMap<usize, String>,
        updates: &mut Updates,
    ) {
        let contract_level_definitions = self.collect_contract_level_definitions();

        for (path, ast) in &self.asts {
            for node in &ast.nodes {
                // NOTE(review): the collector is seeded with `self.target`, but
                // updates below are keyed by the current `path`; only the
                // collected start/end offsets are used here, the location's
                // path field is ignored — confirm this is intentional.
                let mut collector =
                    ReferencesCollector { path: self.target.clone(), references: HashMap::new() };

                node.walk(&mut collector);

                let references = collector.references;

                for (id, locs) in references {
                    if let Some((name, contract_id)) =
                        contract_level_definitions.get(&(id as usize))
                    {
                        for loc in &locs {
                            // A span exactly as long as the bare name is an
                            // unqualified reference; leave those untouched.
                            if loc.length() == name.len() {
                                continue;
                            }
                            let parent_name = top_level_names.get(contract_id).unwrap();
                            updates.entry(path.clone()).or_default().insert((
                                loc.start,
                                loc.end,
                                format!("{parent_name}.{name}"),
                            ));
                        }
                    }
                }
            }
        }
    }
439
    /// Processes `@inheritdoc` NatSpec tags: when the referenced parent maps to
    /// a known (possibly renamed) top-level identifier the tag value is
    /// updated to the new name, otherwise the whole tag is removed.
    fn update_inheritdocs(&self, top_level_names: &HashMap<usize, String>, updates: &mut Updates) {
        trace!("updating @inheritdoc tags");
        for (path, ast) in &self.asts {
            // Map exported symbol names to the first AST id they resolve to.
            let exported_symbols = ast
                .exported_symbols
                .iter()
                .filter_map(
                    |(name, ids)| {
                        if !ids.is_empty() {
                            Some((name.as_str(), ids[0]))
                        } else {
                            None
                        }
                    },
                )
                .collect::<HashMap<_, _>>();

            // All documentation nodes attached to contract members that can
            // carry an @inheritdoc tag.
            let docs = ast
                .nodes
                .iter()
                .filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(d) => Some(d),
                    _ => None,
                })
                .flat_map(|contract| {
                    contract.nodes.iter().filter_map(|node| match node {
                        ContractDefinitionPart::EventDefinition(event) => {
                            event.documentation.as_ref()
                        }
                        ContractDefinitionPart::ErrorDefinition(error) => {
                            error.documentation.as_ref()
                        }
                        ContractDefinitionPart::FunctionDefinition(func) => {
                            func.documentation.as_ref()
                        }
                        ContractDefinitionPart::VariableDeclaration(var) => {
                            var.documentation.as_ref()
                        }
                        _ => None,
                    })
                });

            docs.for_each(|doc| {
                // Only structured docs carry a source span we can edit.
                let Documentation::Structured(doc) = doc else {
                    return
                };
                let src_start = doc.src.start.unwrap();
                let src_end = src_start + doc.src.length.unwrap();

                let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end];
                let tag_len = "@inheritdoc".len();

                if let Some(tag_start) = content.find("@inheritdoc") {
                    trace!("processing doc with content {:?}", content);
                    // Skip the spaces after the tag to find the parent name.
                    if let Some(name_start) = content[tag_start + tag_len..]
                        .find(|c| c != ' ')
                        .map(|p| p + tag_start + tag_len)
                    {
                        // The name ends at whitespace, a newline, or a comment
                        // delimiter character.
                        let name_end = content[name_start..]
                            .find([' ', '\n', '*', '/'])
                            .map(|p| p + name_start)
                            .unwrap_or(content.len());

                        let name = &content[name_start..name_end];
                        trace!("found name {name}");

                        let mut new_name = None;

                        if let Some(ast_id) = exported_symbols.get(name) {
                            if let Some(name) = top_level_names.get(ast_id) {
                                new_name = Some(name);
                            } else {
                                trace!(identifiers=?top_level_names, "ast id {ast_id} cannot be matched to top-level identifier");
                            }
                        }

                        if let Some(new_name) = new_name {
                            trace!("updating tag value with {new_name}");
                            // Replace just the referenced parent name.
                            updates.entry(path.to_path_buf()).or_default().insert((
                                src_start + name_start,
                                src_start + name_end,
                                new_name.to_string(),
                            ));
                        } else {
                            trace!("name is unknown, removing @inheritdoc tag");
                            // Drop the whole tag including its value.
                            updates.entry(path.to_path_buf()).or_default().insert((
                                src_start + tag_start,
                                src_start + name_end,
                                String::new(),
                            ));
                        }
                    }
                }
            });
        }
    }
546
    /// Collects all top-level definitions (contracts, enums, structs,
    /// functions, variables, and user-defined value types), grouped by name.
    ///
    /// The location stored for each definition is the span of its *name*, not
    /// the whole definition, so it can be renamed in place.
    fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes
                    .iter()
                    .filter_map(|node| match node {
                        SourceUnitPart::ContractDefinition(contract) => Some((
                            &contract.name,
                            contract.id,
                            &contract.src,
                            &contract.name_location,
                        )),
                        SourceUnitPart::EnumDefinition(enum_) => {
                            Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location))
                        }
                        SourceUnitPart::StructDefinition(struct_) => {
                            Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location))
                        }
                        SourceUnitPart::FunctionDefinition(func) => {
                            Some((&func.name, func.id, &func.src, &func.name_location))
                        }
                        SourceUnitPart::VariableDeclaration(var) => {
                            Some((&var.name, var.id, &var.src, &var.name_location))
                        }
                        SourceUnitPart::UserDefinedValueTypeDefinition(type_) => {
                            Some((&type_.name, type_.id, &type_.src, &type_.name_location))
                        }
                        _ => None,
                    })
                    .map(|(name, id, src, maybe_name_src)| {
                        let loc = match maybe_name_src {
                            // Prefer the dedicated name location when present.
                            Some(src) => {
                                ItemLocation::try_from_source_loc(src, path.clone()).unwrap()
                            }
                            None => {
                                // `name_location` may be absent (presumably on
                                // older compiler output — TODO confirm); fall
                                // back to finding the name text inside the
                                // definition's own source span.
                                let content: &str = &self.sources.get(path).unwrap().content;
                                let start = src.start.unwrap();
                                let end = start + src.length.unwrap();

                                let name_start = content[start..end].find(name).unwrap();
                                let name_end = name_start + name.len();

                                ItemLocation {
                                    path: path.clone(),
                                    start: start + name_start,
                                    end: start + name_end,
                                }
                            }
                        };

                        (name, (id, loc))
                    })
            })
            .fold(HashMap::new(), |mut acc, (name, (id, item_location))| {
                acc.entry(name).or_default().insert((id, item_location));
                acc
            })
    }
609
    /// Collects all contract-level definitions, mapping each definition's AST
    /// id to its name and the id of the contract that declares it.
    fn collect_contract_level_definitions(&self) -> HashMap<usize, (&String, usize)> {
        self.asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(contract) => {
                        Some((contract.id, &contract.nodes))
                    }
                    _ => None,
                })
            })
            .flat_map(|(contract_id, nodes)| {
                // `move` captures `contract_id` (Copy) for the inner closure.
                nodes.iter().filter_map(move |node| match node {
                    ContractDefinitionPart::EnumDefinition(enum_) => {
                        Some((enum_.id, (&enum_.name, contract_id)))
                    }
                    ContractDefinitionPart::ErrorDefinition(error) => {
                        Some((error.id, (&error.name, contract_id)))
                    }
                    ContractDefinitionPart::EventDefinition(event) => {
                        Some((event.id, (&event.name, contract_id)))
                    }
                    ContractDefinitionPart::StructDefinition(struct_) => {
                        Some((struct_.id, (&struct_.name, contract_id)))
                    }
                    ContractDefinitionPart::FunctionDefinition(function) => {
                        Some((function.id, (&function.name, contract_id)))
                    }
                    ContractDefinitionPart::VariableDeclaration(variable) => {
                        Some((variable.id, (&variable.name, contract_id)))
                    }
                    ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => {
                        Some((value_type.id, (&value_type.name, contract_id)))
                    }
                    _ => None,
                })
            })
            .collect()
    }
651
652 fn collect_references(&self) -> HashMap<isize, HashSet<ItemLocation>> {
655 self.asts
656 .iter()
657 .flat_map(|(path, ast)| {
658 let mut collector =
659 ReferencesCollector { path: path.clone(), references: HashMap::new() };
660 ast.walk(&mut collector);
661 collector.references
662 })
663 .fold(HashMap::new(), |mut acc, (id, locs)| {
664 acc.entry(id).or_default().extend(locs);
665 acc
666 })
667 }
668
669 fn remove_imports(&self, updates: &mut Updates) {
671 for loc in self.collect_imports() {
672 updates.entry(loc.path.clone()).or_default().insert((
673 loc.start,
674 loc.end,
675 String::new(),
676 ));
677 }
678 }
679
680 fn collect_imports(&self) -> HashSet<ItemLocation> {
682 self.asts
683 .iter()
684 .flat_map(|(path, ast)| {
685 ast.nodes.iter().filter_map(|node| match node {
686 SourceUnitPart::ImportDirective(import) => {
687 ItemLocation::try_from_source_loc(&import.src, path.clone())
688 }
689 _ => None,
690 })
691 })
692 .collect()
693 }
694
695 fn process_pragmas(&self, updates: &mut Updates) -> Vec<String> {
698 let mut abicoder_v2 = None;
699
700 let pragmas = self.collect_pragmas();
701 let mut version_pragmas = Vec::new();
702
703 for loc in &pragmas {
704 let pragma_content = self.read_location(loc);
705 if pragma_content.contains("experimental") || pragma_content.contains("abicoder") {
706 if abicoder_v2.is_none() {
707 abicoder_v2 = Some(self.read_location(loc).to_string());
708 }
709 } else if pragma_content.contains("solidity") {
710 version_pragmas.push(pragma_content);
711 }
712
713 updates.entry(loc.path.clone()).or_default().insert((
714 loc.start,
715 loc.end,
716 String::new(),
717 ));
718 }
719
720 let mut pragmas = Vec::new();
721
722 if let Some(version_pragma) = combine_version_pragmas(&version_pragmas) {
723 pragmas.push(version_pragma);
724 }
725
726 if let Some(pragma) = abicoder_v2 {
727 pragmas.push(pragma);
728 }
729
730 pragmas
731 }
732
733 fn collect_pragmas(&self) -> HashSet<ItemLocation> {
735 self.asts
736 .iter()
737 .flat_map(|(path, ast)| {
738 ast.nodes.iter().filter_map(|node| match node {
739 SourceUnitPart::PragmaDirective(import) => {
740 ItemLocation::try_from_source_loc(&import.src, path.clone())
741 }
742 _ => None,
743 })
744 })
745 .collect()
746 }
747
    /// Removes `SPDX-License-Identifier` lines from all sources and returns
    /// the license identifier of the target file, if it has one.
    fn process_licenses(&self, updates: &mut Updates) -> Option<String> {
        let mut target_license = None;

        for loc in &self.collect_licenses() {
            if loc.path == self.target {
                let license_line = self.read_location(loc);
                // `collect_licenses` only yields lines containing this marker,
                // so the unwrap cannot fail.
                let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
                target_license = Some(license_line[license_start..].trim().to_string());
            }
            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }

        target_license
    }
768
769 fn collect_licenses(&self) -> HashSet<ItemLocation> {
771 self.sources
772 .iter()
773 .flat_map(|(path, source)| {
774 let mut licenses = HashSet::new();
775 if let Some(license_start) = source.content.find("SPDX-License-Identifier:") {
776 let start =
777 source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0);
778 let end = start
779 + source.content[start..]
780 .find('\n')
781 .unwrap_or(source.content.len() - start);
782 licenses.insert(ItemLocation { path: path.clone(), start, end });
783 }
784 licenses
785 })
786 .collect()
787 }
788
    /// Returns the source text covered by `loc`.
    ///
    /// Panics if the location's file is not in `self.sources` or the span is
    /// out of bounds — callers only pass locations derived from these sources.
    fn read_location(&self, loc: &ItemLocation) -> &str {
        let content: &str = &self.sources.get(&loc.path).unwrap().content;
        &content[loc.start..loc.end]
    }
794}
795
/// Performs a depth-first walk over `path`'s import graph, inserting every
/// visited file (including `path` itself) into `deps`.
///
/// `deps` doubles as the visited set, so cyclic imports terminate.
fn collect_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
    path: &Path,
    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
    graph: &Graph<P>,
    deps: &mut HashSet<PathBuf>,
) -> Result<()> {
    // `insert` returns false when the path was already visited.
    if deps.insert(path.to_path_buf()) {
        let target_dir = path.parent().ok_or_else(|| {
            SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display()))
        })?;

        let node_id = graph
            .files()
            .get(path)
            .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?;

        // Only nodes carrying Solidity data have imports to follow.
        if let Some(data) = graph.node(*node_id).data.sol_data() {
            for import in &data.imports {
                // Imports are resolved relative to the importing file's dir.
                let path = paths.resolve_import(target_dir, import.data().path())?;
                collect_deps(&path, paths, graph, deps)?;
            }
        }
    }
    Ok(())
}
822
823pub fn collect_ordered_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
834 path: &Path,
835 paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
836 graph: &Graph<P>,
837) -> Result<Vec<PathBuf>> {
838 let mut deps = HashSet::new();
839 collect_deps(path, paths, graph, &mut deps)?;
840
841 deps.remove(path);
844
845 let mut paths_with_deps_count = Vec::new();
846 for path in deps {
847 let mut path_deps = HashSet::new();
848 collect_deps(&path, paths, graph, &mut path_deps)?;
849 paths_with_deps_count.push((path_deps.len(), path));
850 }
851
852 paths_with_deps_count.sort_by(|(count_0, path_0), (count_1, path_1)| {
853 match count_0.cmp(count_1) {
855 o if !o.is_eq() => return o,
856 _ => {}
857 };
858
859 if let Some((name_0, name_1)) = path_0.file_name().zip(path_1.file_name()) {
861 match name_0.cmp(name_1) {
862 o if !o.is_eq() => return o,
863 _ => {}
864 }
865 }
866
867 path_0.cmp(path_1)
869 });
870
871 let mut ordered_deps =
872 paths_with_deps_count.into_iter().map(|(_, path)| path).collect::<Vec<_>>();
873
874 ordered_deps.push(path.to_path_buf());
875
876 Ok(ordered_deps)
877}
878
879pub fn combine_version_pragmas(pragmas: &[impl AsRef<str>]) -> Option<String> {
880 let versions = pragmas
881 .iter()
882 .map(AsRef::as_ref)
883 .filter_map(SolData::parse_version_pragma)
884 .filter_map(Result::ok)
885 .flat_map(|req| req.comparators)
886 .map(|comp| comp.to_string())
887 .collect::<BTreeSet<_>>();
888 if versions.is_empty() {
889 return None;
890 }
891 Some(format!("pragma solidity {};", versions.iter().format(" ")))
892}