use crate::{
indent::{IndentConfig, IndentedWriter},
CodeGeneratorConfig, Encoding,
};
use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
use std::{
collections::{BTreeMap, HashMap},
io::{Result, Write},
path::PathBuf,
};
pub struct CodeGenerator<'a> {
config: &'a CodeGeneratorConfig,
serde_package_name: Option<String>,
external_qualified_names: HashMap<String, String>,
}
struct PythonEmitter<'a, T> {
out: IndentedWriter<T>,
generator: &'a CodeGenerator<'a>,
current_namespace: Vec<String>,
}
impl<'a> CodeGenerator<'a> {
pub fn new(config: &'a CodeGeneratorConfig) -> Self {
if config.c_style_enums {
panic!("Python 3 does not support generating c-style enums");
}
let mut external_qualified_names = HashMap::new();
for (module_path, names) in &config.external_definitions {
let module = {
let mut path = module_path.split('.').collect::<Vec<_>>();
if path.len() < 2 {
module_path
} else {
path.pop().unwrap()
}
};
for name in names {
external_qualified_names.insert(name.to_string(), format!("{}.{}", module, name));
}
}
Self {
config,
serde_package_name: None,
external_qualified_names,
}
}
pub fn with_serde_package_name(mut self, serde_package_name: Option<String>) -> Self {
self.serde_package_name = serde_package_name;
self
}
pub fn output(&self, out: &mut dyn Write, registry: &Registry) -> Result<()> {
let current_namespace = self
.config
.module_name
.split('.')
.map(String::from)
.collect();
let mut emitter = PythonEmitter {
out: IndentedWriter::new(out, IndentConfig::Space(4)),
generator: self,
current_namespace,
};
emitter.output_preamble()?;
for (name, format) in registry {
emitter.output_container(name, format)?;
}
Ok(())
}
}
impl<'a, T> PythonEmitter<'a, T>
where
T: Write,
{
fn quote_import(&self, module: &str) -> String {
let mut parts = module.split('.').collect::<Vec<_>>();
if parts.len() <= 1 {
format!("import {}", module)
} else {
let module_name = parts.pop().unwrap();
format!("from {} import {}", parts.join("."), module_name)
}
}
fn output_preamble(&mut self) -> Result<()> {
let from_serde_package = match &self.generator.serde_package_name {
None => "".to_string(),
Some(name) => format!("from {} ", name),
};
writeln!(
self.out,
r#"# pyre-strict
from dataclasses import dataclass
import typing
{}import serde_types as st"#,
from_serde_package,
)?;
for encoding in &self.generator.config.encodings {
writeln!(self.out, "{}import {}", from_serde_package, encoding.name())?;
}
for module in self.generator.config.external_definitions.keys() {
writeln!(self.out, "{}\n", self.quote_import(module))?;
}
Ok(())
}
fn quote_qualified_name(&self, name: &str) -> String {
self.generator
.external_qualified_names
.get(name)
.cloned()
.unwrap_or_else(|| {
format!("\"{}\"", name.to_string())
})
}
fn quote_type(&self, format: &Format) -> String {
use Format::*;
match format {
TypeName(x) => self.quote_qualified_name(x),
Unit => "st.unit".into(),
Bool => "bool".into(),
I8 => "st.int8".into(),
I16 => "st.int16".into(),
I32 => "st.int32".into(),
I64 => "st.int64".into(),
I128 => "st.int128".into(),
U8 => "st.uint8".into(),
U16 => "st.uint16".into(),
U32 => "st.uint32".into(),
U64 => "st.uint64".into(),
U128 => "st.uint128".into(),
F32 => "st.float32".into(),
F64 => "st.float64".into(),
Char => "st.char".into(),
Str => "str".into(),
Bytes => "bytes".into(),
Option(format) => format!("typing.Optional[{}]", self.quote_type(format)),
Seq(format) => format!("typing.Sequence[{}]", self.quote_type(format)),
Map { key, value } => format!(
"typing.Dict[{}, {}]",
self.quote_type(key),
self.quote_type(value)
),
Tuple(formats) => {
if formats.is_empty() {
"typing.Tuple[()]".into()
} else {
format!("typing.Tuple[{}]", self.quote_types(formats))
}
}
TupleArray { content, size } => format!(
"typing.Tuple[{}]",
self.quote_types(&vec![content.as_ref().clone(); *size])
),
Variable(_) => panic!("unexpected value"),
}
}
fn quote_types(&self, formats: &[Format]) -> String {
formats
.iter()
.map(|x| self.quote_type(x))
.collect::<Vec<_>>()
.join(", ")
}
fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
let mut path = self.current_namespace.clone();
path.push(name.to_string());
if let Some(doc) = self.generator.config.comments.get(&path) {
writeln!(self.out, "\"\"\"{}\"\"\"", doc)?;
}
Ok(())
}
fn output_custom_code(&mut self) -> std::io::Result<bool> {
match self
.generator
.config
.custom_code
.get(&self.current_namespace)
{
Some(code) => {
writeln!(self.out, "\n{}", code)?;
Ok(true)
}
None => Ok(false),
}
}
fn output_fields(&mut self, fields: &[Named<Format>]) -> Result<()> {
if fields.is_empty() {
writeln!(self.out, "pass")?;
return Ok(());
}
for field in fields {
writeln!(
self.out,
"{}: {}",
field.name,
self.quote_type(&field.value)
)?;
}
Ok(())
}
fn output_variant(
&mut self,
base: &str,
name: &str,
index: u32,
variant: &VariantFormat,
) -> Result<()> {
use VariantFormat::*;
let fields = match variant {
Unit => Vec::new(),
NewType(format) => vec![Named {
name: "value".to_string(),
value: format.as_ref().clone(),
}],
Tuple(formats) => vec![Named {
name: "value".to_string(),
value: Format::Tuple(formats.clone()),
}],
Struct(fields) => fields.clone(),
Variable(_) => panic!("incorrect value"),
};
writeln!(
self.out,
"\n@dataclass(frozen=True)\nclass {0}__{1}({0}):",
base, name
)?;
self.out.indent();
self.output_comment(name)?;
if self.generator.config.serialization {
writeln!(self.out, "INDEX = {} # type: int", index)?;
}
self.current_namespace.push(name.to_string());
self.output_fields(&fields)?;
self.output_custom_code()?;
self.current_namespace.pop();
self.out.unindent();
writeln!(self.out)
}
fn output_enum_container(
&mut self,
name: &str,
variants: &BTreeMap<u32, Named<VariantFormat>>,
) -> Result<()> {
writeln!(self.out, "\nclass {}:", name)?;
self.out.indent();
self.output_comment(name)?;
self.current_namespace.push(name.to_string());
if self.generator.config.serialization {
writeln!(
self.out,
"VARIANTS = [] # type: typing.Sequence[typing.Type[{}]]",
name
)?;
for encoding in &self.generator.config.encodings {
self.output_serialize_method_for_encoding(name, *encoding)?;
self.output_deserialize_method_for_encoding(name, *encoding)?;
}
}
let wrote_custom_code = self.output_custom_code()?;
if !self.generator.config.serialization && !wrote_custom_code {
writeln!(self.out, "pass")?;
}
writeln!(self.out)?;
self.out.unindent();
for (index, variant) in variants {
self.output_variant(name, &variant.name, *index, &variant.value)?;
}
self.current_namespace.pop();
if self.generator.config.serialization {
writeln!(
self.out,
"{}.VARIANTS = [\n{}]\n",
name,
variants
.iter()
.map(|(_, v)| format!(" {}__{},\n", name, v.name))
.collect::<Vec<_>>()
.join("")
)?;
}
Ok(())
}
fn output_serialize_method_for_encoding(
&mut self,
name: &str,
encoding: Encoding,
) -> Result<()> {
writeln!(
self.out,
r#"
def {0}_serialize(self) -> bytes:
return {0}.serialize(self, {1})"#,
encoding.name(),
name
)
}
fn output_deserialize_method_for_encoding(
&mut self,
name: &str,
encoding: Encoding,
) -> Result<()> {
writeln!(
self.out,
r#"
@staticmethod
def {0}_deserialize(input: bytes) -> '{1}':
v, buffer = {0}.deserialize(input, {1})
if buffer:
raise st.DeserializationError("Some input bytes were not read");
return v"#,
encoding.name(),
name
)
}
fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
use ContainerFormat::*;
let fields = match format {
UnitStruct => Vec::new(),
NewTypeStruct(format) => vec![Named {
name: "value".to_string(),
value: format.as_ref().clone(),
}],
TupleStruct(formats) => vec![Named {
name: "value".to_string(),
value: Format::Tuple(formats.clone()),
}],
Struct(fields) => fields.clone(),
Enum(variants) => {
self.output_enum_container(name, variants)?;
return Ok(());
}
};
writeln!(self.out, "\n@dataclass(frozen=True)\nclass {}:", name)?;
self.out.indent();
self.output_comment(name)?;
self.current_namespace.push(name.to_string());
self.output_fields(&fields)?;
for encoding in &self.generator.config.encodings {
self.output_serialize_method_for_encoding(name, *encoding)?;
self.output_deserialize_method_for_encoding(name, *encoding)?;
}
self.output_custom_code()?;
self.current_namespace.pop();
self.out.unindent();
writeln!(self.out)
}
}
pub struct Installer {
install_dir: PathBuf,
serde_package_name: Option<String>,
}
impl Installer {
pub fn new(install_dir: PathBuf, serde_package_name: Option<String>) -> Self {
Installer {
install_dir,
serde_package_name,
}
}
fn create_module_init_file(&self, name: &str) -> Result<std::fs::File> {
let dir_path = self.install_dir.join(name);
std::fs::create_dir_all(&dir_path)?;
std::fs::File::create(dir_path.join("__init__.py"))
}
fn fix_serde_package(&self, content: &str) -> String {
match &self.serde_package_name {
None => content.into(),
Some(name) => content
.replace(
"import serde_types",
&format!("from {} import serde_types", name),
)
.replace(
"import serde_binary",
&format!("from {} import serde_binary", name),
),
}
}
}
impl crate::SourceInstaller for Installer {
type Error = Box<dyn std::error::Error>;
fn install_module(
&self,
config: &crate::CodeGeneratorConfig,
registry: &Registry,
) -> std::result::Result<(), Self::Error> {
let mut file = self.create_module_init_file(&config.module_name)?;
let generator =
CodeGenerator::new(config).with_serde_package_name(self.serde_package_name.clone());
generator.output(&mut file, registry)?;
Ok(())
}
fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
let mut file = self.create_module_init_file("serde_types")?;
write!(
file,
"{}",
self.fix_serde_package(include_str!("../runtime/python/serde_types/__init__.py"))
)?;
let mut file = self.create_module_init_file("serde_binary")?;
write!(
file,
"{}",
self.fix_serde_package(include_str!("../runtime/python/serde_binary/__init__.py"))
)?;
Ok(())
}
fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
let mut file = self.create_module_init_file("bincode")?;
write!(
file,
"{}",
self.fix_serde_package(include_str!("../runtime/python/bincode/__init__.py"))
)?;
Ok(())
}
fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
let mut file = self.create_module_init_file("bcs")?;
write!(
file,
"{}",
self.fix_serde_package(include_str!("../runtime/python/bcs/__init__.py"))
)?;
Ok(())
}
}