use std::collections::{HashMap, HashSet};
use quote::ToTokens;
use syn::{
parse::Parser,
punctuated::Punctuated,
spanned::Spanned,
visit_mut::{self, VisitMut},
Expr, Ident, Item, Pat, Path, PathArguments, PathSegment, Token, UseTree,
};
use crate::{
into_wir::{Error, Errors},
util::extract_path_ident,
wir::WSpan,
};
pub fn extract_use_map(items: &mut [Item]) -> Result<HashMap<Ident, Path>, Errors> {
let mut use_map = HashMap::new();
let mut use_path_vec = Vec::new();
for item in items.iter_mut() {
let Item::Use(item_use) = item else {
continue;
};
let use_prefix = Path {
leading_colon: item_use.leading_colon,
segments: Punctuated::new(),
};
recurse_use_tree(&mut use_map, &mut use_path_vec, &item_use.tree, use_prefix)?;
}
let mut errors: Vec<Result<(), Error>> = Vec::new();
for use_path in use_path_vec {
let Some(first_segment) = use_path.segments.first() else {
panic!("Unexpected zero-segment path");
};
if first_segment.ident != "machine_check" && first_segment.ident != "std" {
errors.push(Err(Error::unsupported_construct(
"Using paths not starting with 'machine_check' or 'std'",
WSpan::from_syn(&use_path),
)));
}
for segment in use_path.segments.iter() {
if segment.ident == "self" || segment.ident == "super" {
errors.push(Err(Error::unsupported_construct(
"Use path segment 'self' or 'super'",
WSpan::from_syn(&segment.ident),
)));
}
}
}
for use_path in use_map.values_mut() {
if use_path.leading_colon.is_some() {
continue;
}
let Some(first_segment) = use_path.segments.first_mut() else {
panic!("Unexpected zero-segment path");
};
use_path.leading_colon = Some(Token));
}
Errors::vec_result(errors)?;
Ok(use_map)
}
/// Rewrites the paths inside `items` whose first segment is a name from `use_map` into the
/// corresponding imported path.
///
/// # Errors
///
/// Returns the first error recorded by the visitor while walking the items.
pub fn resolve_use_items(items: &mut [Item], use_map: &HashMap<Ident, Path>) -> Result<(), Errors> {
    let mut resolver = Visitor {
        use_map,
        local_scopes_idents: Vec::new(),
        result: Ok(()),
    };
    items
        .iter_mut()
        .for_each(|item| resolver.visit_item_mut(item));
    // Every scope opened for a block must have been closed again by now.
    assert!(resolver.local_scopes_idents.is_empty());
    resolver.result.map_err(Errors::single)
}
/// Rewrites the paths inside a single expression whose first segment is a name from
/// `use_map` into the corresponding imported path.
///
/// # Errors
///
/// Returns the first error recorded by the visitor while walking the expression.
pub fn resolve_use_expr(expr: &mut Expr, use_map: &HashMap<Ident, Path>) -> Result<(), Errors> {
    let mut resolver = Visitor {
        use_map,
        local_scopes_idents: Vec::new(),
        result: Ok(()),
    };
    resolver.visit_expr_mut(expr);
    let Visitor {
        result,
        local_scopes_idents,
        ..
    } = resolver;
    // Block scopes are pushed and popped in pairs, so none may remain open.
    assert!(local_scopes_idents.is_empty());
    result.map_err(Errors::single)
}
/// Removes every `use` declaration from `items`, keeping all other items in order.
///
/// Never fails; it always returns `Ok(())`.
pub fn remove_use(items: &mut Vec<Item>) -> Result<(), Error> {
    let is_use_item = |item: &Item| matches!(item, Item::Use(_));
    items.retain(|item| !is_use_item(item));
    Ok(())
}
fn recurse_use_tree(
use_map: &mut HashMap<Ident, Path>,
use_path_vec: &mut Vec<Path>,
use_tree: &UseTree,
mut use_prefix: Path,
) -> Result<(), Error> {
let use_ident = match use_tree {
UseTree::Path(use_path) => {
use_prefix.segments.push(PathSegment {
ident: use_path.ident.clone(),
arguments: PathArguments::None,
});
recurse_use_tree(use_map, use_path_vec, &use_path.tree, use_prefix)?;
return Ok(());
}
UseTree::Group(use_group) => {
for item in &use_group.items {
recurse_use_tree(use_map, use_path_vec, item, use_prefix.clone())?;
}
return Ok(());
}
UseTree::Name(use_name) => {
use_prefix.segments.push(PathSegment {
ident: use_name.ident.clone(),
arguments: PathArguments::None,
});
&use_name.ident
}
UseTree::Rename(use_rename) => {
use_prefix.segments.push(PathSegment {
ident: use_rename.ident.clone(),
arguments: PathArguments::None,
});
&use_rename.rename
}
UseTree::Glob(use_glob) => {
return Err(Error::unsupported_syn_construct("Wildcard use", &use_glob));
}
};
if let Some(_previous) = use_map.insert(use_ident.clone(), use_prefix.clone()) {
Err(Error::unsupported_syn_construct(
"Duplicate use declaration",
&use_ident,
))
} else {
use_path_vec.push(use_prefix);
Ok(())
}
}
/// Syntax-tree visitor that substitutes `use`-imported names in paths with the imported paths.
struct Visitor<'m> {
    /// Maps each imported identifier to the path it refers to.
    use_map: &'m HashMap<Ident, Path>,
    /// Identifiers bound by `let` statements, one set per enclosing block (innermost last).
    local_scopes_idents: Vec<HashSet<Ident>>,
    /// The first error recorded during the visit, or `Ok(())` if none was recorded.
    result: Result<(), Error>,
}
impl VisitMut for Visitor<'_> {
    /// Replaces a leading `use`-imported name in `path` by the full imported path.
    ///
    /// A single-identifier path naming a `let` binding of an enclosing block is left as is,
    /// because the local binding shadows the import. Substitution repeats while the path is
    /// still relative and its first segment is imported, but never substitutes the same name
    /// twice, so a cyclic map cannot loop forever.
    fn visit_path_mut(&mut self, path: &mut Path) {
        if let Some(path_ident) = extract_path_ident(path) {
            let is_local = self
                .local_scopes_idents
                .iter()
                .any(|local_scope| local_scope.contains(path_ident));
            if is_local {
                return;
            }
        }
        let mut used_idents = HashSet::new();
        while path.leading_colon.is_none() {
            let path_span = path.span();
            let first_segment = path
                .segments
                .first_mut()
                .expect("Path should have at least one segment");
            let first_ident = first_segment.ident.clone();
            let Some(use_path) = self.use_map.get(&first_ident) else {
                break;
            };
            if !used_idents.insert(first_ident) {
                break;
            }
            // Re-span the imported segments so diagnostics point at the path being resolved.
            // `set_span` (unlike re-creating the ident from its string) also keeps raw
            // identifiers such as `r#type` intact instead of panicking.
            let mut leading_segments = use_path.segments.clone();
            for leading_segment in leading_segments.iter_mut() {
                leading_segment.ident.set_span(path_span);
            }
            // The last imported segment takes the alias's place, keeping the alias segment's
            // own arguments (e.g. generics); the remaining imported segments are prepended.
            let last_use_path_segment = leading_segments
                .pop()
                .expect("Use path should have at least one segment")
                .into_value();
            first_segment.ident = last_use_path_segment.ident;
            let trailing_segments = std::mem::take(&mut path.segments);
            path.segments = leading_segments
                .into_iter()
                .chain(trailing_segments)
                .collect();
            if use_path.leading_colon.is_some() {
                path.leading_colon = Some(Token![::](path_span));
            }
        }
        visit_mut::visit_path_mut(self, path);
    }

    /// Also resolves the paths listed in `#[derive(...)]`.
    ///
    /// Their arguments are kept as raw tokens (`meta_list.tokens`). If those tokens do not
    /// parse as a comma-separated path list, they are left unchanged.
    fn visit_attribute_mut(&mut self, attr: &mut syn::Attribute) {
        if let syn::Meta::List(meta_list) = &mut attr.meta {
            if meta_list.path.is_ident("derive") {
                let parser = Punctuated::<Path, Token![,]>::parse_terminated;
                if let Ok(mut punctuated) = parser.parse2(meta_list.tokens.clone()) {
                    for path in punctuated.iter_mut() {
                        self.visit_path_mut(path);
                    }
                    meta_list.tokens = punctuated.to_token_stream();
                }
            }
        }
        visit_mut::visit_attribute_mut(self, attr);
    }

    /// Opens a fresh scope for the `let` bindings of the block and closes it afterwards.
    fn visit_block_mut(&mut self, block: &mut syn::Block) {
        self.local_scopes_idents.push(HashSet::new());
        visit_mut::visit_block_mut(self, block);
        assert!(self.local_scopes_idents.pop().is_some());
    }

    /// Records the identifier bound by a `let` statement in the innermost scope.
    ///
    /// Only `let x` and `let x: T` patterns are supported; other patterns record an error
    /// (if none was recorded yet) and are still visited.
    fn visit_local_mut(&mut self, local: &mut syn::Local) {
        let mut local_pat = &local.pat;
        if let Pat::Type(pat_type) = local_pat {
            local_pat = &pat_type.pat;
        }
        let Pat::Ident(local_pat) = local_pat else {
            if self.result.is_ok() {
                self.result = Err(Error::unsupported_syn_construct(
                    "Local pattern that is not ident or typed local",
                    &local_pat,
                ));
            }
            visit_mut::visit_local_mut(self, local);
            return;
        };
        let local_ident = local_pat.ident.clone();
        // The new binding is only in scope after this statement, so the type and initializer
        // are visited first: in `let x = x;` the initializer's `x` must still be resolvable.
        visit_mut::visit_local_mut(self, local);
        self.local_scopes_idents
            .last_mut()
            .expect("Local should be in some scope")
            .insert(local_ident);
    }
}