1use crate::{
2 ArtifactOutput, CompilerSettings, Graph, Project, ProjectPathsConfig, SourceParser, Updates,
3 apply_updates,
4 compilers::{Compiler, ParsedSource},
5 filter::MaybeSolData,
6 resolver::parse::SolData,
7};
8use foundry_compilers_artifacts::{
9 ast::{visitor::Visitor, *},
10 output_selection::OutputSelection,
11 sources::{Source, Sources},
12};
13use foundry_compilers_core::{
14 error::{Result, SolcError},
15 utils,
16};
17use itertools::Itertools;
18use std::{
19 collections::{BTreeSet, HashMap, HashSet},
20 hash::Hash,
21 path::{Path, PathBuf},
22 sync::Arc,
23};
24use visitor::Walk;
25
/// Location of an item (a definition or a reference) inside a source file,
/// expressed as a byte range `[start, end)` into that file's content.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ItemLocation {
    // Path of the source file containing the item.
    path: PathBuf,
    // Byte offset of the item's first character.
    start: usize,
    // Byte offset one past the item's last character.
    end: usize,
}
33
34impl ItemLocation {
35 fn try_from_source_loc(src: &SourceLocation, path: PathBuf) -> Option<Self> {
36 let start = src.start?;
37 let end = start + src.length?;
38
39 Some(Self { path, start, end })
40 }
41
42 const fn length(&self) -> usize {
43 self.end - self.start
44 }
45}
46
/// AST visitor state that records, for every referenced declaration id, the
/// set of locations within `path` where the reference occurs.
struct ReferencesCollector {
    // File whose AST is being walked; stored into every collected location.
    path: PathBuf,
    // Maps declaration AST ids to the locations that reference them.
    references: HashMap<isize, HashSet<ItemLocation>>,
}
60
61impl ReferencesCollector {
62 fn process_referenced_declaration(&mut self, id: isize, src: &SourceLocation) {
63 if let Some(loc) = ItemLocation::try_from_source_loc(src, self.path.clone()) {
64 self.references.entry(id).or_default().insert(loc);
65 }
66 }
67}
68
impl Visitor for ReferencesCollector {
    /// Records plain identifiers that resolve to a declaration.
    fn visit_identifier(&mut self, identifier: &Identifier) {
        if let Some(id) = identifier.referenced_declaration {
            self.process_referenced_declaration(id, &identifier.src);
        }
    }

    /// Records qualified identifier paths (e.g. `Lib.Thing`).
    fn visit_identifier_path(&mut self, path: &IdentifierPath) {
        self.process_referenced_declaration(path.referenced_declaration, &path.src);
    }

    /// Records member accesses, narrowing the recorded span to just the
    /// member name itself.
    fn visit_member_access(&mut self, access: &MemberAccess) {
        if let Some(referenced_declaration) = access.referenced_declaration
            && let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length)
        {
            let name_length = access.member_name.len();
            // The member name occupies the last `name_length` bytes of the
            // whole access expression's span.
            let start = src_start + src_length - name_length;
            let end = start + name_length;

            self.references.entry(referenced_declaration).or_default().insert(ItemLocation {
                start,
                end,
                path: self.path.clone(),
            });
        }
    }

    /// Records references made from inline assembly blocks.
    fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
        let mut src = reference.src;

        // Shrink the span so it no longer covers the `.suffix` (plus the
        // separating dot) and only the referenced name remains.
        // NOTE(review): assumes `src.length` always includes the suffix and
        // its dot; otherwise this subtraction would underflow — confirm.
        if let Some(suffix) = &reference.suffix
            && let Some(len) = src.length.as_mut()
        {
            let suffix_len = suffix.to_string().len();
            *len -= suffix_len + 1;
        }

        self.process_referenced_declaration(reference.declaration as isize, &src);
    }
}
112
/// Intermediate result of a flattening run: the rewritten per-file contents
/// plus the pragmas and license line to emit at the top of the output.
pub struct FlatteningResult {
    // Rewritten source contents, in dependency order (target last).
    sources: Vec<String>,
    // Combined pragma statements for the flattened output.
    pragmas: Vec<String>,
    // License line of the target file, if one was found.
    license: Option<String>,
}
121
122impl FlatteningResult {
123 fn new(
124 mut flattener: Flattener,
125 updates: Updates,
126 pragmas: Vec<String>,
127 license: Option<String>,
128 ) -> Self {
129 apply_updates(&mut flattener.sources, updates);
130
131 let sources = flattener
132 .ordered_sources
133 .iter()
134 .map(|path| flattener.sources.remove(path).unwrap().content)
135 .map(Arc::unwrap_or_clone)
136 .collect();
137
138 Self { sources, pragmas, license }
139 }
140
141 fn get_flattened_target(&self) -> String {
142 let mut result = String::new();
143
144 if let Some(license) = &self.license {
145 result.push_str(&format!("// {license}\n"));
146 }
147 for pragma in &self.pragmas {
148 result.push_str(&format!("{pragma}\n"));
149 }
150 for source in &self.sources {
151 result.push_str(&format!("\n\n{source}"));
152 }
153
154 format!("{}\n", utils::RE_THREE_OR_MORE_NEWLINES.replace_all(&result, "\n\n").trim())
155 }
156}
157
/// Errors produced while preparing or running the flattener.
#[derive(Debug, thiserror::Error)]
pub enum FlattenerError {
    /// The target (or one of its dependencies) failed to compile.
    #[error("Failed to compile {0}")]
    Compilation(SolcError),
    /// Any other failure (I/O, import resolution, serialization, ...).
    #[error(transparent)]
    Other(SolcError),
}
165
// Any error convertible into a `SolcError` becomes the `Other` variant;
// `Compilation` is only constructed explicitly at compile-failure sites.
impl<T: Into<SolcError>> From<T> for FlattenerError {
    fn from(err: T) -> Self {
        Self::Other(err.into())
    }
}
171
/// Flattens a target Solidity file and its dependency graph into a single
/// compilable source by rewriting the resolved sources in place.
pub struct Flattener {
    // Path of the file being flattened.
    target: PathBuf,
    // Contents of the target and all of its transitive dependencies.
    sources: Sources,
    // Parsed AST for each flattened source file.
    asts: Vec<(PathBuf, SourceUnit)>,
    // Sources in dependency order; the target is the last element.
    ordered_sources: Vec<PathBuf>,
    // Project root, used to strip prefixes when printing file names.
    project_root: PathBuf,
}
185
186impl Flattener {
    /// Compiles `target` with AST output enabled and gathers everything the
    /// flattener needs: the ordered dependency list, the raw sources, and a
    /// parsed AST per file.
    ///
    /// # Errors
    ///
    /// Returns [`FlattenerError::Compilation`] when the compiler reports
    /// errors, and [`FlattenerError::Other`] for resolution/serialization
    /// failures.
    pub fn new<C: Compiler, T: ArtifactOutput<CompilerContract = C::CompilerContract>>(
        mut project: Project<C, T>,
        target: &Path,
    ) -> std::result::Result<Self, FlattenerError>
    where
        C::Parser: SourceParser<ParsedSource: MaybeSolData>,
    {
        // Compile fresh (no cache, no artifacts) and request only the AST.
        project.cached = false;
        project.no_artifacts = true;
        project.settings.update_output_selection(|selection| {
            *selection = OutputSelection::ast_output_selection();
        });

        let output = project.compile_file(target).map_err(FlattenerError::Compilation)?;

        if output.has_compiler_errors() {
            return Err(FlattenerError::Compilation(SolcError::msg(&output)));
        }

        let output = output.compiler_output;

        let sources = Source::read_all([target.to_path_buf()])?;
        let graph = Graph::<C::Parser>::resolve_sources(&project.paths, sources)?;

        let ordered_sources = collect_ordered_deps(target, &project.paths, &graph)?;

        // On Windows, convert paths to forward-slash form so they line up
        // with the path keys in the compiler output.
        #[cfg(windows)]
        let ordered_sources = {
            let mut sources = ordered_sources;
            use path_slash::PathBufExt;
            for p in &mut sources {
                *p = PathBuf::from(p.to_slash_lossy().to_string());
            }
            sources
        };

        let sources = Source::read_all(&ordered_sources)?;

        // Keep only the ASTs of files that are part of the flattened set.
        // The serde round-trip converts the compiler's AST representation
        // into this crate's typed `SourceUnit`.
        let mut asts: Vec<(PathBuf, SourceUnit)> = Vec::new();
        for (path, ast) in output.sources.0.iter().filter_map(|(path, files)| {
            if let Some(ast) = files.first().and_then(|source| source.source_file.ast.as_ref())
                && sources.contains_key(path)
            {
                return Some((path, ast));
            }
            None
        }) {
            asts.push((PathBuf::from(path), serde_json::from_str(&serde_json::to_string(ast)?)?));
        }

        Ok(Self {
            target: target.into(),
            sources,
            asts,
            ordered_sources,
            project_root: project.root().to_path_buf(),
        })
    }
248
    /// Flattens the target file by running the rewrite passes in sequence
    /// and concatenating the updated sources.
    pub fn flatten(self) -> String {
        let mut updates = Updates::new();

        // All passes record updates against the ORIGINAL source offsets;
        // nothing is applied until `flatten_result`, so pass order only
        // determines which updates get recorded, not their positions.
        self.append_filenames(&mut updates);
        let top_level_names = self.rename_top_level_definitions(&mut updates);
        self.rename_contract_level_types_references(&top_level_names, &mut updates);
        self.remove_qualified_imports(&mut updates);
        self.update_inheritdocs(&top_level_names, &mut updates);

        self.remove_imports(&mut updates);
        let target_pragmas = self.process_pragmas(&mut updates);
        let target_license = self.process_licenses(&mut updates);

        self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
    }
273
    /// Consumes the flattener, applying `updates` and packaging the
    /// flattened sources together with the pragmas and license to emit.
    fn flatten_result(
        self,
        updates: Updates,
        target_pragmas: Vec<String>,
        target_license: Option<String>,
    ) -> FlatteningResult {
        FlatteningResult::new(self, updates, target_pragmas, target_license)
    }
282
283 fn append_filenames(&self, updates: &mut Updates) {
285 for path in &self.ordered_sources {
286 updates.entry(path.clone()).or_default().insert((
287 0,
288 0,
289 format!("// {}\n", path.strip_prefix(&self.project_root).unwrap_or(path).display()),
290 ));
291 }
292 }
293
    /// Renames top-level definitions that share a name across files to
    /// unique `{name}_{i}` identifiers, updating both the definitions
    /// themselves and every reference to them.
    ///
    /// Returns a map from each definition's AST id to its final (possibly
    /// renamed) identifier, consumed by later passes.
    fn rename_top_level_definitions(&self, updates: &mut Updates) -> HashMap<usize, String> {
        let top_level_definitions = self.collect_top_level_definitions();
        let references = self.collect_references();

        let mut top_level_names = HashMap::new();

        for (name, ids) in top_level_definitions {
            let mut definition_name = name.clone();
            // Only names defined more than once need disambiguation.
            let needs_rename = ids.len() > 1;

            let mut ids = ids.clone().into_iter().collect::<Vec<_>>();
            if needs_rename {
                // Deterministic suffixes: order by position of the defining
                // file in the flattened output, then by offset in that file.
                ids.sort_by_key(|(_, loc)| {
                    (self.ordered_sources.iter().position(|p| p == &loc.path).unwrap(), loc.start)
                });
            }
            for (i, (id, loc)) in ids.iter().enumerate() {
                if needs_rename {
                    definition_name = format!("{name}_{i}");
                }
                // Rewrite the definition's own name token...
                updates.entry(loc.path.clone()).or_default().insert((
                    loc.start,
                    loc.end,
                    definition_name.clone(),
                ));
                // ...and every reference that resolves to this definition.
                if let Some(references) = references.get(&(*id as isize)) {
                    for loc in references {
                        updates.entry(loc.path.clone()).or_default().insert((
                            loc.start,
                            loc.end,
                            definition_name.clone(),
                        ));
                    }
                }

                top_level_names.insert(*id, definition_name.clone());
            }
        }
        top_level_names
    }
348
    /// Removes import-alias qualifiers by deleting every reference that
    /// resolves to an import directive (e.g. the `Lib` in `Lib.something`).
    fn remove_qualified_imports(&self, updates: &mut Updates) {
        // AST ids of every import directive across all files.
        let imports_ids = self
            .asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ImportDirective(directive) => Some(directive.id),
                    _ => None,
                })
            })
            .collect::<HashSet<_>>();

        let references = self.collect_references();

        for (id, locs) in references {
            if !imports_ids.contains(&(id as usize)) {
                continue;
            }

            for loc in locs {
                // `end + 1` removes one extra character after the alias —
                // presumably the `.` separator; verify against callers.
                updates.entry(loc.path).or_default().insert((
                    loc.start,
                    loc.end + 1,
                    String::new(),
                ));
            }
        }
    }
390
    /// Rewrites references to contract-level definitions (enums, structs,
    /// events, errors, ...) to `Parent.Name` form, with `Parent` taken from
    /// the possibly-renamed top-level identifiers.
    fn rename_contract_level_types_references(
        &self,
        top_level_names: &HashMap<usize, String>,
        updates: &mut Updates,
    ) {
        let contract_level_definitions = self.collect_contract_level_definitions();

        for (path, ast) in &self.asts {
            for node in &ast.nodes {
                // NOTE(review): the collector is seeded with `self.target`
                // instead of the current `path`; harmless today because only
                // the locations' start/end offsets are read below (updates
                // use the outer `path`) — confirm this is intentional.
                let mut collector =
                    ReferencesCollector { path: self.target.clone(), references: HashMap::new() };

                node.walk(&mut collector);

                let references = collector.references;

                for (id, locs) in references {
                    if let Some((name, contract_id)) =
                        contract_level_definitions.get(&(id as usize))
                    {
                        for loc in &locs {
                            // Skip spans covering exactly the bare name;
                            // only longer spans are rewritten.
                            if loc.length() == name.len() {
                                continue;
                            }
                            let parent_name = top_level_names.get(contract_id).unwrap();
                            updates.entry(path.clone()).or_default().insert((
                                loc.start,
                                loc.end,
                                format!("{parent_name}.{name}"),
                            ));
                        }
                    }
                }
            }
        }
    }
437
    /// Rewrites `@inheritdoc` tags so they keep pointing at the correct
    /// (possibly renamed) top-level identifier; tags whose target cannot be
    /// resolved are removed entirely.
    fn update_inheritdocs(&self, top_level_names: &HashMap<usize, String>, updates: &mut Updates) {
        trace!("updating @inheritdoc tags");
        for (path, ast) in &self.asts {
            // Name -> AST id of every symbol this file exports (first id
            // wins when a symbol maps to several ids).
            let exported_symbols = ast
                .exported_symbols
                .iter()
                .filter(|(_, ids)| !ids.is_empty())
                .map(|(name, ids)| (name.as_str(), ids[0]))
                .collect::<HashMap<_, _>>();

            // All documentation nodes attached to contract members that can
            // carry an @inheritdoc tag.
            let docs = ast
                .nodes
                .iter()
                .filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(d) => Some(d),
                    _ => None,
                })
                .flat_map(|contract| {
                    contract.nodes.iter().filter_map(|node| match node {
                        ContractDefinitionPart::EventDefinition(event) => {
                            event.documentation.as_ref()
                        }
                        ContractDefinitionPart::ErrorDefinition(error) => {
                            error.documentation.as_ref()
                        }
                        ContractDefinitionPart::FunctionDefinition(func) => {
                            func.documentation.as_ref()
                        }
                        ContractDefinitionPart::VariableDeclaration(var) => {
                            var.documentation.as_ref()
                        }
                        _ => None,
                    })
                });

            docs.for_each(|doc| {
                let Documentation::Structured(doc) = doc else {
                    return
                };
                let src_start = doc.src.start.unwrap();
                let src_end = src_start + doc.src.length.unwrap();

                // Raw text of the doc comment; all offsets below are
                // relative to `src_start`.
                let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end];
                let tag_len = "@inheritdoc".len();

                if let Some(tag_start) = content.find("@inheritdoc") {
                    trace!("processing doc with content {:?}", content);
                    // The tag's argument starts at the first non-space
                    // character after the tag...
                    if let Some(name_start) = content[tag_start + tag_len..]
                        .find(|c| c != ' ')
                        .map(|p| p + tag_start + tag_len)
                    {
                        // ...and ends at the next whitespace or comment
                        // delimiter (or the end of the doc text).
                        let name_end = content[name_start..]
                            .find([' ', '\n', '*', '/'])
                            .map(|p| p + name_start)
                            .unwrap_or(content.len());

                        let name = &content[name_start..name_end];
                        trace!("found name {name}");

                        let mut new_name = None;

                        if let Some(ast_id) = exported_symbols.get(name) {
                            if let Some(name) = top_level_names.get(ast_id) {
                                new_name = Some(name);
                            } else {
                                trace!(identifiers=?top_level_names, "ast id {ast_id} cannot be matched to top-level identifier");
                            }
                        }

                        if let Some(new_name) = new_name {
                            // Replace just the tag's argument.
                            trace!("updating tag value with {new_name}");
                            updates.entry(path.clone()).or_default().insert((
                                src_start + name_start,
                                src_start + name_end,
                                new_name.clone(),
                            ));
                        } else {
                            // Unknown target: drop the whole tag.
                            trace!("name is unknown, removing @inheritdoc tag");
                            updates.entry(path.clone()).or_default().insert((
                                src_start + tag_start,
                                src_start + name_end,
                                String::new(),
                            ));
                        }
                    }
                }
            });
        }
    }
537
    /// Groups every top-level definition (contract, enum, struct, function,
    /// variable, user-defined value type) by name; each entry carries the
    /// definition's AST id and the exact location of its *name* token.
    fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> {
        self.asts
            .iter()
            .flat_map(|(path, ast)| {
                ast.nodes
                    .iter()
                    .filter_map(|node| match node {
                        SourceUnitPart::ContractDefinition(contract) => Some((
                            &contract.name,
                            contract.id,
                            &contract.src,
                            &contract.name_location,
                        )),
                        SourceUnitPart::EnumDefinition(enum_) => {
                            Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location))
                        }
                        SourceUnitPart::StructDefinition(struct_) => {
                            Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location))
                        }
                        SourceUnitPart::FunctionDefinition(func) => {
                            Some((&func.name, func.id, &func.src, &func.name_location))
                        }
                        SourceUnitPart::VariableDeclaration(var) => {
                            Some((&var.name, var.id, &var.src, &var.name_location))
                        }
                        SourceUnitPart::UserDefinedValueTypeDefinition(type_) => {
                            Some((&type_.name, type_.id, &type_.src, &type_.name_location))
                        }
                        _ => None,
                    })
                    .map(|(name, id, src, maybe_name_src)| {
                        let loc = match maybe_name_src {
                            // The AST provides the name's span directly.
                            Some(src) => {
                                ItemLocation::try_from_source_loc(src, path.clone()).unwrap()
                            }
                            // No name location: fall back to searching for
                            // the first occurrence of the name inside the
                            // definition's full source span.
                            None => {
                                let content: &str = &self.sources.get(path).unwrap().content;
                                let start = src.start.unwrap();
                                let end = start + src.length.unwrap();

                                // `name_start` is relative to the slice, so
                                // `start` is added back below.
                                let name_start = content[start..end].find(name).unwrap();
                                let name_end = name_start + name.len();

                                ItemLocation {
                                    path: path.clone(),
                                    start: start + name_start,
                                    end: start + name_end,
                                }
                            }
                        };

                        (name, (id, loc))
                    })
            })
            .fold(HashMap::new(), |mut acc, (name, (id, item_location))| {
                acc.entry(name).or_default().insert((id, item_location));
                acc
            })
    }
600
    /// Maps each contract-member definition's AST id to its name and the
    /// AST id of the contract that declares it.
    fn collect_contract_level_definitions(&self) -> HashMap<usize, (&String, usize)> {
        self.asts
            .iter()
            .flat_map(|(_, ast)| {
                ast.nodes.iter().filter_map(|node| match node {
                    SourceUnitPart::ContractDefinition(contract) => {
                        Some((contract.id, &contract.nodes))
                    }
                    _ => None,
                })
            })
            .flat_map(|(contract_id, nodes)| {
                nodes.iter().filter_map(move |node| match node {
                    ContractDefinitionPart::EnumDefinition(enum_) => {
                        Some((enum_.id, (&enum_.name, contract_id)))
                    }
                    ContractDefinitionPart::ErrorDefinition(error) => {
                        Some((error.id, (&error.name, contract_id)))
                    }
                    ContractDefinitionPart::EventDefinition(event) => {
                        Some((event.id, (&event.name, contract_id)))
                    }
                    ContractDefinitionPart::StructDefinition(struct_) => {
                        Some((struct_.id, (&struct_.name, contract_id)))
                    }
                    ContractDefinitionPart::FunctionDefinition(function) => {
                        Some((function.id, (&function.name, contract_id)))
                    }
                    ContractDefinitionPart::VariableDeclaration(variable) => {
                        Some((variable.id, (&variable.name, contract_id)))
                    }
                    ContractDefinitionPart::UserDefinedValueTypeDefinition(value_type) => {
                        Some((value_type.id, (&value_type.name, contract_id)))
                    }
                    _ => None,
                })
            })
            .collect()
    }
642
643 fn collect_references(&self) -> HashMap<isize, HashSet<ItemLocation>> {
646 self.asts
647 .iter()
648 .flat_map(|(path, ast)| {
649 let mut collector =
650 ReferencesCollector { path: path.clone(), references: HashMap::new() };
651 ast.walk(&mut collector);
652 collector.references
653 })
654 .fold(HashMap::new(), |mut acc, (id, locs)| {
655 acc.entry(id).or_default().extend(locs);
656 acc
657 })
658 }
659
660 fn remove_imports(&self, updates: &mut Updates) {
662 for loc in self.collect_imports() {
663 updates.entry(loc.path.clone()).or_default().insert((
664 loc.start,
665 loc.end,
666 String::new(),
667 ));
668 }
669 }
670
671 fn collect_imports(&self) -> HashSet<ItemLocation> {
673 self.asts
674 .iter()
675 .flat_map(|(path, ast)| {
676 ast.nodes.iter().filter_map(|node| match node {
677 SourceUnitPart::ImportDirective(import) => {
678 ItemLocation::try_from_source_loc(&import.src, path.clone())
679 }
680 _ => None,
681 })
682 })
683 .collect()
684 }
685
    /// Removes every pragma directive and returns the set to re-emit at the
    /// top of the flattened file: one combined `pragma solidity` constraint
    /// plus the first `abicoder`/`experimental` pragma encountered.
    fn process_pragmas(&self, updates: &mut Updates) -> Vec<String> {
        // First matching abicoder/experimental pragma wins; later
        // duplicates are dropped.
        let mut abicoder_v2 = None;

        let pragmas = self.collect_pragmas();
        let mut version_pragmas = Vec::new();

        // NOTE(review): `pragmas` is a HashSet, so which abicoder pragma is
        // kept when several differ depends on iteration order — confirm
        // this nondeterminism is acceptable.
        for loc in &pragmas {
            let pragma_content = self.read_location(loc);
            if pragma_content.contains("experimental") || pragma_content.contains("abicoder") {
                if abicoder_v2.is_none() {
                    abicoder_v2 = Some(self.read_location(loc).to_string());
                }
            } else if pragma_content.contains("solidity") {
                version_pragmas.push(pragma_content);
            }

            // Every pragma is deleted from its original position.
            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }

        let mut pragmas = Vec::new();

        if let Some(version_pragma) = combine_version_pragmas(&version_pragmas) {
            pragmas.push(version_pragma);
        }

        if let Some(pragma) = abicoder_v2 {
            pragmas.push(pragma);
        }

        pragmas
    }
723
724 fn collect_pragmas(&self) -> HashSet<ItemLocation> {
726 self.asts
727 .iter()
728 .flat_map(|(path, ast)| {
729 ast.nodes.iter().filter_map(|node| match node {
730 SourceUnitPart::PragmaDirective(import) => {
731 ItemLocation::try_from_source_loc(&import.src, path.clone())
732 }
733 _ => None,
734 })
735 })
736 .collect()
737 }
738
    /// Removes every SPDX license line and returns the target file's
    /// license (if present) for re-emission at the top of the output.
    fn process_licenses(&self, updates: &mut Updates) -> Option<String> {
        let mut target_license = None;

        for loc in &self.collect_licenses() {
            if loc.path == self.target {
                let license_line = self.read_location(loc);
                // `collect_licenses` only yields lines containing this
                // marker, so the `unwrap` cannot fail.
                let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
                target_license = Some(license_line[license_start..].trim().to_string());
            }
            // License lines are removed from all files, target included.
            updates.entry(loc.path.clone()).or_default().insert((
                loc.start,
                loc.end,
                String::new(),
            ));
        }

        target_license
    }
759
760 fn collect_licenses(&self) -> HashSet<ItemLocation> {
762 self.sources
763 .iter()
764 .flat_map(|(path, source)| {
765 let mut licenses = HashSet::new();
766 if let Some(license_start) = source.content.find("SPDX-License-Identifier:") {
767 let start =
768 source.content[..license_start].rfind('\n').map(|i| i + 1).unwrap_or(0);
769 let end = start
770 + source.content[start..]
771 .find('\n')
772 .unwrap_or(source.content.len() - start);
773 licenses.insert(ItemLocation { path: path.clone(), start, end });
774 }
775 licenses
776 })
777 .collect()
778 }
779
780 fn read_location(&self, loc: &ItemLocation) -> &str {
782 let content: &str = &self.sources.get(&loc.path).unwrap().content;
783 &content[loc.start..loc.end]
784 }
785}
786
/// Recursively inserts `path` and all of its transitive imports into `deps`.
///
/// `deps` doubles as the visited set, so import cycles terminate.
fn collect_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
    path: &Path,
    paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
    graph: &Graph<P>,
    deps: &mut HashSet<PathBuf>,
) -> Result<()> {
    // `insert` returns false for already-visited paths, stopping recursion.
    if deps.insert(path.to_path_buf()) {
        let target_dir = path.parent().ok_or_else(|| {
            SolcError::msg(format!("failed to get parent directory for \"{}\"", path.display()))
        })?;

        let node_id = graph
            .files()
            .get(path)
            .ok_or_else(|| SolcError::msg(format!("cannot resolve file at {}", path.display())))?;

        if let Some(data) = graph.node(*node_id).data.sol_data() {
            for import in &data.imports {
                // Imports resolve relative to the importing file's directory.
                let path = paths.resolve_import(target_dir, import.data().path())?;
                collect_deps(&path, paths, graph, deps)?;
            }
        }
    }
    Ok(())
}
813
814pub fn collect_ordered_deps<P: SourceParser<ParsedSource: MaybeSolData>>(
825 path: &Path,
826 paths: &ProjectPathsConfig<<P::ParsedSource as ParsedSource>::Language>,
827 graph: &Graph<P>,
828) -> Result<Vec<PathBuf>> {
829 let mut deps = HashSet::new();
830 collect_deps(path, paths, graph, &mut deps)?;
831
832 deps.remove(path);
835
836 let mut paths_with_deps_count = Vec::new();
837 for path in deps {
838 let mut path_deps = HashSet::new();
839 collect_deps(&path, paths, graph, &mut path_deps)?;
840 paths_with_deps_count.push((path_deps.len(), path));
841 }
842
843 paths_with_deps_count.sort_by(|(count_0, path_0), (count_1, path_1)| {
844 match count_0.cmp(count_1) {
846 o if !o.is_eq() => return o,
847 _ => {}
848 };
849
850 if let Some((name_0, name_1)) = path_0.file_name().zip(path_1.file_name()) {
852 match name_0.cmp(name_1) {
853 o if !o.is_eq() => return o,
854 _ => {}
855 }
856 }
857
858 path_0.cmp(path_1)
860 });
861
862 let mut ordered_deps =
863 paths_with_deps_count.into_iter().map(|(_, path)| path).collect::<Vec<_>>();
864
865 ordered_deps.push(path.to_path_buf());
866
867 Ok(ordered_deps)
868}
869
870pub fn combine_version_pragmas(pragmas: &[impl AsRef<str>]) -> Option<String> {
871 let versions = pragmas
872 .iter()
873 .map(AsRef::as_ref)
874 .filter_map(SolData::parse_version_pragma)
875 .filter_map(Result::ok)
876 .flat_map(|req| req.comparators)
877 .map(|comp| comp.to_string())
878 .collect::<BTreeSet<_>>();
879 if versions.is_empty() {
880 return None;
881 }
882 Some(format!("pragma solidity {};", versions.iter().format(" ")))
883}