use crate::Diagnostic;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use rustpython_compiler_core::{Mode, bytecode::CodeObject, frozen};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
use syn::{
self, LitByteStr, LitStr, Macro,
parse::{ParseStream, Parser, Result as ParseResult},
spanned::Spanned,
};
/// The three ways a macro invocation can supply Python code.
enum CompilationSourceKind {
    /// A single `.py` file; `rel_path` is joined onto `base` before reading.
    File { base: PathBuf, rel_path: PathBuf },
    /// A directory tree walked recursively; `rel_path` is joined onto `base`.
    Dir { base: PathBuf, rel_path: PathBuf },
    /// Python code given inline as a string (dedented before compiling).
    SourceCode(String),
}
/// One compiled Python module, as collected by `CompilationSource::compile`.
struct CompiledModule {
    /// True when the module came from an `__init__.py` file.
    package: bool,
    /// The bytecode produced by the `Compiler`.
    code: CodeObject,
}
/// A parsed `source`/`file`/`dir` macro argument together with the spans used
/// to attach diagnostics to it.
struct CompilationSource {
    /// Spans handed to `Diagnostic::spans_error`: the argument's key identifier,
    /// and the span at the parser cursor just after its value.
    span: (Span, Span),
    /// What kind of source this is and where it lives.
    kind: CompilationSourceKind,
}
pub trait Compiler {
fn compile(
&self,
source: &str,
mode: Mode,
module_name: String,
) -> Result<CodeObject, Box<dyn core::error::Error>>;
}
impl CompilationSource {
    /// Compiles `source` with `compiler`, converting a failure into a [`Diagnostic`]
    /// attached to this source's span.
    ///
    /// `origin` is only evaluated on failure; it names where the code came from
    /// (a file path or `"string literal"`) in the error message.
    fn compile_string<D: core::fmt::Display, F: FnOnce() -> D>(
        &self,
        source: &str,
        mode: Mode,
        module_name: String,
        compiler: &dyn Compiler,
        origin: F,
    ) -> Result<CodeObject, Diagnostic> {
        compiler.compile(source, mode, module_name).map_err(|err| {
            Diagnostic::spans_error(
                self.span,
                format!("Python compile error from {}: {}", origin(), err),
            )
        })
    }

    /// Compiles this source into a map from module name to compiled module.
    ///
    /// A `Dir` source is walked recursively (module names are derived from the
    /// directory layout and `module_name` is ignored). A `File` or `SourceCode`
    /// source yields exactly one non-package entry keyed by `module_name`.
    fn compile(
        &self,
        mode: Mode,
        module_name: String,
        compiler: &dyn Compiler,
    ) -> Result<HashMap<String, CompiledModule>, Diagnostic> {
        match &self.kind {
            CompilationSourceKind::Dir { base, rel_path } => {
                self.compile_dir(base, &base.join(rel_path), String::new(), mode, compiler)
            }
            // Listed explicitly (no `_`) so a new source kind forces a decision here.
            CompilationSourceKind::File { .. } | CompilationSourceKind::SourceCode(_) => {
                Ok(hashmap! {
                    module_name.clone() => CompiledModule {
                        code: self.compile_single(mode, module_name, compiler)?,
                        package: false,
                    },
                })
            }
        }
    }

    /// Compiles a `File` or `SourceCode` source into a single code object.
    ///
    /// Inline source code is passed through `textwrap::dedent` first, so code
    /// written indented inside a string literal still compiles.
    ///
    /// # Panics
    /// Panics when called on a `Dir` source; use `compile` for those.
    fn compile_single(
        &self,
        mode: Mode,
        module_name: String,
        compiler: &dyn Compiler,
    ) -> Result<CodeObject, Diagnostic> {
        match &self.kind {
            CompilationSourceKind::File { base, rel_path } => {
                let path = base.join(rel_path);
                let source = fs::read_to_string(&path).map_err(|err| {
                    Diagnostic::spans_error(
                        self.span,
                        format!("Error reading file {path:?}: {err}"),
                    )
                })?;
                self.compile_string(&source, mode, module_name, compiler, || rel_path.display())
            }
            CompilationSourceKind::SourceCode(code) => self.compile_string(
                &textwrap::dedent(code),
                mode,
                module_name,
                compiler,
                || "string literal",
            ),
            CompilationSourceKind::Dir { .. } => {
                unreachable!("Can't use compile_single with directory source")
            }
        }
    }

    /// Recursively compiles every `.py` file under `path`.
    ///
    /// `parent` is the dotted module prefix for `path` (empty at the root).
    /// A subdirectory extends the prefix with its name. `__init__.py` is keyed by
    /// the prefix itself and marked as a package; any other `foo.py` becomes
    /// `<parent>.foo`. Compile failures are skipped (not reported) for files
    /// whose stem starts with `badsyntax_` or whose package name ends in
    /// `.encoded_modules`; every other failure aborts with a diagnostic.
    fn compile_dir(
        &self,
        base: &Path,
        path: &Path,
        parent: String,
        mode: Mode,
        compiler: &dyn Compiler,
    ) -> Result<HashMap<String, CompiledModule>, Diagnostic> {
        let mut code_map = HashMap::new();
        let paths = fs::read_dir(path)
            .or_else(|err| match Self::windows_symlink_target(path) {
                Some(target) => fs::read_dir(target),
                None => Err(err),
            })
            .map_err(|err| {
                Diagnostic::spans_error(self.span, format!("Error listing dir {path:?}: {err}"))
            })?;
        for path in paths {
            let path = path.map_err(|err| {
                Diagnostic::spans_error(self.span, format!("Failed to list file: {err}"))
            })?;
            let path = path.path();
            let file_name = path.file_name().unwrap().to_str().ok_or_else(|| {
                Diagnostic::spans_error(self.span, format!("Invalid UTF-8 in file name {path:?}"))
            })?;
            if path.is_dir() {
                code_map.extend(self.compile_dir(
                    base,
                    &path,
                    if parent.is_empty() {
                        file_name.to_string()
                    } else {
                        format!("{parent}.{file_name}")
                    },
                    mode,
                    compiler,
                )?);
            } else if file_name.ends_with(".py") {
                // `file_name` is valid UTF-8 (checked above), so its stem is too.
                let stem = path.file_stem().unwrap().to_str().unwrap();
                let is_init = stem == "__init__";
                let module_name = if is_init {
                    parent.clone()
                } else if parent.is_empty() {
                    stem.to_owned()
                } else {
                    format!("{parent}.{stem}")
                };
                // Reads and compiles `src_path`; error messages name the file
                // actually read, while the compile-error origin stays the entry
                // as listed (relative to `base` when possible).
                let compile_path = |src_path: &Path| {
                    let source = fs::read_to_string(src_path).map_err(|err| {
                        Diagnostic::spans_error(
                            self.span,
                            format!("Error reading file {src_path:?}: {err}"),
                        )
                    })?;
                    self.compile_string(&source, mode, module_name.clone(), compiler, || {
                        path.strip_prefix(base).ok().unwrap_or(&path).display()
                    })
                };
                let code = compile_path(&path).or_else(|err| {
                    match Self::windows_symlink_target(&path) {
                        Some(target) if target.exists() => compile_path(&target),
                        _ => Err(err),
                    }
                });
                let code = match code {
                    Ok(code) => code,
                    Err(_)
                        if stem.starts_with("badsyntax_")
                            || parent.ends_with(".encoded_modules") =>
                    {
                        continue;
                    }
                    Err(e) => return Err(e),
                };
                code_map.insert(
                    module_name,
                    CompiledModule {
                        code,
                        package: is_init,
                    },
                );
            }
        }
        Ok(code_map)
    }

    /// Windows-only fallback for a link stored as a plain text file.
    ///
    /// Presumably this covers git checkouts without symlink support
    /// (`core.symlinks=false`), where a symlink is written as an ordinary file
    /// containing the link target. Reads `link` as text and resolves the trimmed
    /// contents against `link`'s parent directory, matching how relative symlink
    /// targets are interpreted. Returns `None` on non-Windows targets, or when
    /// `link` cannot be canonicalized or read. Never panics.
    fn windows_symlink_target(link: &Path) -> Option<PathBuf> {
        if !cfg!(windows) {
            return None;
        }
        let target = fs::read_to_string(link.canonicalize().ok()?).ok()?;
        Some(link.parent()?.join(target.trim()))
    }
}
impl PyCompileArgs {
fn parse(input: TokenStream, allow_dir: bool) -> Result<Self, Diagnostic> {
let mut module_name = None;
let mut mode = None;
let mut source: Option<CompilationSource> = None;
let mut crate_name = None;
fn assert_source_empty(source: &Option<CompilationSource>) -> Result<(), syn::Error> {
if let Some(source) = source {
Err(syn::Error::new(
source.span.0,
"Cannot have more than one source",
))
} else {
Ok(())
}
}
syn::meta::parser(|meta| {
let ident = meta
.path
.get_ident()
.ok_or_else(|| meta.error("unknown arg"))?;
let check_str = || meta.value()?.call(parse_str);
let str_path = || {
let s = check_str()?;
let mut base_path = s
.span()
.unwrap()
.local_file()
.ok_or_else(|| err_span!(s, "filepath literal has no span information"))?;
base_path.pop();
Ok::<_, syn::Error>((base_path, PathBuf::from(s.value())))
};
if ident == "mode" {
let s = check_str()?;
match s.value().parse() {
Ok(mode_val) => mode = Some(mode_val),
Err(e) => bail_span!(s, "{}", e),
}
} else if ident == "module_name" {
module_name = Some(check_str()?.value())
} else if ident == "source" {
assert_source_empty(&source)?;
let code = check_str()?.value();
source = Some(CompilationSource {
kind: CompilationSourceKind::SourceCode(code),
span: (ident.span(), meta.input.cursor().span()),
});
} else if ident == "file" {
assert_source_empty(&source)?;
let (base, rel_path) = str_path()?;
source = Some(CompilationSource {
kind: CompilationSourceKind::File { base, rel_path },
span: (ident.span(), meta.input.cursor().span()),
});
} else if ident == "dir" {
if !allow_dir {
bail_span!(ident, "py_compile doesn't accept dir")
}
assert_source_empty(&source)?;
let (base, rel_path) = str_path()?;
source = Some(CompilationSource {
kind: CompilationSourceKind::Dir { base, rel_path },
span: (ident.span(), meta.input.cursor().span()),
});
} else if ident == "crate_name" {
let name = check_str()?.parse()?;
crate_name = Some(name);
} else {
return Err(meta.error("unknown attr"));
}
Ok(())
})
.parse2(input)?;
let source = source.ok_or_else(|| {
syn::Error::new(
Span::call_site(),
"Must have either file or source in py_compile!()/py_freeze!()",
)
})?;
Ok(Self {
source,
mode: mode.unwrap_or(Mode::Exec),
module_name: module_name.unwrap_or_else(|| "frozen".to_owned()),
crate_name: crate_name.unwrap_or_else(|| syn::parse_quote!(::rustpython_vm)),
})
}
}
fn parse_str(input: ParseStream<'_>) -> ParseResult<LitStr> {
let span = input.span();
if input.peek(LitStr) {
input.parse()
} else if let Ok(mac) = input.parse::<Macro>() {
Ok(LitStr::new(&mac.tokens.to_string(), mac.span()))
} else {
Err(syn::Error::new(span, "Expected string or stringify macro"))
}
}
/// Fully parsed arguments shared by `py_compile!` and `py_freeze!`.
struct PyCompileArgs {
    /// Path of the crate the generated code refers to (default `::rustpython_vm`).
    crate_name: syn::Path,
    /// Module name used for single-file/inline sources (default `"frozen"`).
    module_name: String,
    /// Compilation mode (default `Mode::Exec`).
    mode: Mode,
    /// Where the Python code comes from.
    source: CompilationSource,
}
pub fn impl_py_compile(
input: TokenStream,
compiler: &dyn Compiler,
) -> Result<TokenStream, Diagnostic> {
let args = PyCompileArgs::parse(input, false)?;
let crate_name = args.crate_name;
let code = args
.source
.compile_single(args.mode, args.module_name, compiler)?;
let frozen = frozen::FrozenCodeObject::encode(&code);
let bytes = LitByteStr::new(&frozen.bytes, Span::call_site());
let output = quote! {
#crate_name::frozen::FrozenCodeObject { bytes: &#bytes[..] }
};
Ok(output)
}
/// Expands `py_freeze!`: compiles a file, inline snippet or whole directory and
/// emits a `FrozenLib` expression over the encoded module table.
///
/// Modules are encoded in sorted module-name order so that the same inputs
/// always produce byte-identical output (reproducible builds).
pub fn impl_py_freeze(
    input: TokenStream,
    compiler: &dyn Compiler,
) -> Result<TokenStream, Diagnostic> {
    let args = PyCompileArgs::parse(input, true)?;
    let crate_name = args.crate_name;
    let code_map = args.source.compile(args.mode, args.module_name, compiler)?;
    // `code_map` is a `HashMap`, whose iteration order is randomised per process;
    // encoding it directly would make the embedded bytes differ between builds.
    // Keys are unique, so an unstable sort is deterministic here.
    let mut modules: Vec<_> = code_map.iter().collect();
    modules.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    let data = frozen::FrozenLib::encode(modules.into_iter().map(|(name, module)| {
        let frozen_module = frozen::FrozenModule {
            code: frozen::FrozenCodeObject::encode(&module.code),
            package: module.package,
        };
        (name.as_str(), frozen_module)
    }));
    let bytes = LitByteStr::new(&data.bytes, Span::call_site());
    let output = quote! {
        #crate_name::frozen::FrozenLib::from_ref(#bytes)
    };
    Ok(output)
}