use proc_macro2::TokenStream;
use pyo3::FromPyObject;
use quote::{format_ident, quote};
use crate::{
CodeGen, CodeGenContext, ExprType, Name, PythonOptions, Statement, StatementType,
SymbolTableNode, SymbolTableScopes,
};
use log::debug;
use serde::{Deserialize, Serialize};
/// A Python `class` definition.
///
/// The fields mirror Python's `ast.ClassDef` node (presumably populated from it
/// through the `FromPyObject` derive — confirm against the caller). Code
/// generation turns each class into a Rust module containing a `Cls` trait that
/// holds the translated body and a `Data` struct implementing that trait.
#[derive(Clone, Debug, Default, FromPyObject, Serialize, Deserialize, PartialEq)]
pub struct ClassDef {
/// Class name; used as the generated module name and to pick its visibility.
pub name: String,
/// Base classes; used as supertrait bounds of the generated `Cls` trait.
pub bases: Vec<Name>,
/// Class keywords (e.g. `metaclass=...`); not read by the code generation in this file.
/// NOTE(review): Python's `ast.ClassDef.keywords` holds `ast.keyword` nodes, not
/// strings — confirm extraction succeeds when a class declares keywords.
pub keywords: Vec<String>,
/// Statements of the class body; a leading string-constant expression is the docstring.
pub body: Vec<Statement>,
}
impl CodeGen for ClassDef {
    type Context = CodeGenContext;
    type Options = PythonOptions;
    type SymbolTable = SymbolTableScopes;

    /// Registers this class in `symbols` under its Python name and returns the
    /// updated table.
    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
        let mut symbols = symbols;
        symbols.insert(self.name.clone(), SymbolTableNode::ClassDef(self));
        symbols
    }

    /// Translates the class into a Rust module named after the class.
    ///
    /// The module contains:
    /// * a `Cls` trait whose items are the translated body statements and whose
    ///   supertraits are `<Base>::Cls` for every Python base class, and
    /// * a `Data` struct (deriving `Clone`, `Default`) that implements `Cls`.
    ///
    /// The class docstring, if any, is emitted as `#[doc]` attributes on the module.
    ///
    /// Visibility is derived from the Python name:
    /// * `_name` (single leading underscore) → private;
    /// * `__name__` (dunder) → `pub(crate)`;
    /// * anything else → `pub`.
    ///
    /// # Errors
    ///
    /// Returns an error if any body statement fails to translate. The message
    /// includes the offending statement and the underlying error.
    fn to_rust(
        self,
        _ctx: Self::Context,
        options: Self::Options,
        symbols: Self::SymbolTable,
    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
        let class_name = format_ident!("{}", self.name);

        // Visibility is a token stream, not an `Ident`: `format_ident!` panics on
        // the empty string and on multi-token input such as `pub(crate)`.
        let visibility = if self.name.starts_with('_') && !self.name.starts_with("__") {
            quote!()
        } else if self.name.starts_with("__") && self.name.ends_with("__") {
            quote!(pub(crate))
        } else {
            quote!(pub)
        };

        // Every translated class is a module exposing a `Cls` trait, so each base
        // class is referenced as `<Base>::Cls` and joined with `+` as a supertrait.
        let base_traits: Vec<TokenStream> = self
            .bases
            .iter()
            .map(|base| {
                let base_name = format_ident!("{}", base.id);
                quote!(#base_name::Cls)
            })
            .collect();
        let bases = if base_traits.is_empty() {
            quote!()
        } else {
            quote!(: #(#base_traits)+*)
        };

        let mut streams = TokenStream::new();
        for statement in &self.body {
            let tokens = statement
                .clone()
                .to_rust(CodeGenContext::Class, options.clone(), symbols.clone())
                .map_err(|e| format!("Failed to parse statement {:?}: {}", statement, e))?;
            streams.extend(tokens);
        }

        // An absent docstring yields no attributes, so a single template serves both cases.
        let doc_lines: Vec<TokenStream> = self
            .get_docstring()
            .map(|docstring| {
                docstring
                    .lines()
                    .map(|line| {
                        let text = if line.trim().is_empty() { "" } else { line };
                        quote! { #[doc = #text] }
                    })
                    .collect()
            })
            .unwrap_or_default();

        // NOTE(review): when bases are present, `Data` must also implement every
        // base `Cls` trait for `impl Cls for Data {}` to compile; that is not generated here.
        let class = quote! {
            #(#doc_lines)*
            #visibility mod #class_name {
                use super::*;
                #visibility trait Cls #bases {
                    #streams
                }
                #[derive(Clone, Default)]
                #visibility struct Data {
                }
                impl Cls for Data {}
            }
        };
        debug!("class: {}", class);
        Ok(class)
    }
}
impl ClassDef {
    /// Returns the class docstring, formatted for `#[doc]` emission.
    ///
    /// The docstring is the first body statement when that statement is an
    /// expression wrapping a constant. Returns `None` when the body is empty or
    /// the first statement is anything else.
    fn get_docstring(&self) -> Option<String> {
        let first = self.body.first()?.clone();
        match first.statement {
            StatementType::Expr(e) => match e.value {
                ExprType::Constant(c) => {
                    let raw_string = c.to_string();
                    Some(self.format_docstring(&raw_string))
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// Normalizes a raw docstring literal into rustdoc-friendly lines.
    ///
    /// The steps are:
    /// * strip surrounding `"` characters;
    /// * trim every line;
    /// * drop leading and trailing blank lines;
    /// * keep paragraph breaks, collapsing runs of blank lines into one;
    /// * ensure the summary (first) line is followed by a blank line, so rustdoc
    ///   treats it as a separate summary paragraph.
    ///
    /// Returns an empty string for an empty or all-blank docstring.
    fn format_docstring(&self, raw: &str) -> String {
        let content = raw.trim_matches('"');
        let lines: Vec<&str> = content.lines().map(str::trim).collect();

        let first = match lines.iter().position(|line| !line.is_empty()) {
            Some(index) => index,
            None => return String::new(),
        };
        // A non-blank line exists (found above), so `rposition` always succeeds.
        let last = lines
            .iter()
            .rposition(|line| !line.is_empty())
            .unwrap_or(first);
        let lines = &lines[first..=last];

        let mut formatted: Vec<&str> = Vec::with_capacity(lines.len() + 1);
        formatted.push(lines[0]);
        // Separate the summary from the body when the source did not already.
        if lines.len() > 1 && !lines[1].is_empty() {
            formatted.push("");
        }
        for &line in &lines[1..] {
            // Keep paragraph breaks, but never emit two blank lines in a row.
            if line.is_empty() && formatted.last() == Some(&"") {
                continue;
            }
            formatted.push(line);
        }
        formatted.join("\n")
    }
}