python_ast/ast/tree/
function_def.rs1use log::debug;
2use proc_macro2::TokenStream;
3use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
4use quote::{format_ident, quote};
5use serde::{Deserialize, Serialize};
6use crate::ast::tree::statement::PyStatementTrait;
7
8use crate::{
9 CodeGen, CodeGenContext, ExprType, Object, ParameterList, PythonOptions, Statement,
10 StatementType, SymbolTableNode, SymbolTableScopes,
11};
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
14pub struct FunctionDef {
15 pub name: String,
16 pub args: ParameterList,
17 pub body: Vec<Statement>,
18 pub decorator_list: Vec<ExprType>,
19}
20
21impl<'a> FromPyObject<'a> for FunctionDef {
22 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
23 let name: String = ob.getattr("name")?.extract()?;
24 let args: ParameterList = ob.getattr("args")?.extract()?;
25 let body: Vec<Statement> = ob.getattr("body")?.extract()?;
26
27 let decorator_list: Vec<ExprType> = ob.getattr("decorator_list")?.extract().unwrap_or_default();
29
30 Ok(FunctionDef {
31 name,
32 args,
33 body,
34 decorator_list,
35 })
36 }
37}
38
39impl PyStatementTrait for FunctionDef {
40}
41
42impl CodeGen for FunctionDef {
43 type Context = CodeGenContext;
44 type Options = PythonOptions;
45 type SymbolTable = SymbolTableScopes;
46
47 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
48 let mut symbols = symbols;
49 symbols.insert(
50 self.name.clone(),
51 SymbolTableNode::FunctionDef(self.clone()),
52 );
53 symbols
54 }
55
56 fn to_rust(
57 self,
58 ctx: Self::Context,
59 options: Self::Options,
60 symbols: SymbolTableScopes,
61 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
62 let mut streams = TokenStream::new();
63 let fn_name = format_ident!("{}", self.name);
64
65 let visibility = if self.name.starts_with("_") && !self.name.starts_with("__") {
68 quote!() } else if self.name.starts_with("__") && self.name.ends_with("__") {
70 quote!(pub(crate)) } else {
72 quote!(pub) };
74
75 let is_async = match ctx.clone() {
76 CodeGenContext::Async(_) => {
77 quote!(async)
78 }
79 _ => quote!(),
80 };
81
82 let parameters = self
83 .args
84 .clone()
85 .to_rust(ctx.clone(), options.clone(), symbols.clone())
86 .expect(format!("parsing arguments {:?}", self.args).as_str());
87
88 for s in self.body.iter() {
89 streams.extend(
90 s.clone()
91 .to_rust(ctx.clone(), options.clone(), symbols.clone())
92 .expect(format!("parsing statement {:?}", s).as_str()),
93 );
94 streams.extend(quote!(;));
95 }
96
97 let function = if let Some(docstring) = self.get_docstring() {
98 let doc_lines: Vec<_> = docstring
100 .lines()
101 .map(|line| {
102 if line.trim().is_empty() {
103 quote! { #[doc = ""] }
104 } else {
105 let doc_line = format!("{}", line);
106 quote! { #[doc = #doc_line] }
107 }
108 })
109 .collect();
110
111 quote! {
112 #(#doc_lines)*
113 #visibility #is_async fn #fn_name(#parameters) {
114 #streams
115 }
116 }
117 } else {
118 quote! {
119 #visibility #is_async fn #fn_name(#parameters) {
120 #streams
121 }
122 }
123 };
124
125 debug!("function: {}", function);
126 Ok(function)
127 }
128}
129
130impl FunctionDef {
131 fn get_docstring(&self) -> Option<String> {
132 if self.body.is_empty() {
133 return None;
134 }
135
136 let expr = self.body[0].clone();
137 match expr.statement {
138 StatementType::Expr(e) => match e.value {
139 ExprType::Constant(c) => {
140 let raw_string = c.to_string();
141 Some(self.format_docstring(&raw_string))
143 },
144 _ => None,
145 },
146 _ => None,
147 }
148 }
149
150 fn format_docstring(&self, raw: &str) -> String {
151 let content = raw.trim_matches('"');
153
154 let lines: Vec<&str> = content.lines().collect();
156 if lines.is_empty() {
157 return String::new();
158 }
159
160 let mut formatted = vec![lines[0].trim().to_string()];
162
163 if lines.len() > 1 {
164 if !lines[0].trim().is_empty() && !lines[1].trim().is_empty() {
166 formatted.push(String::new());
167 }
168
169 for line in lines.iter().skip(1) {
171 let cleaned = line.trim();
172 if cleaned.starts_with("Args:") {
173 formatted.push(String::new());
174 formatted.push("# Arguments".to_string());
175 } else if cleaned.starts_with("Returns:") {
176 formatted.push(String::new());
177 formatted.push("# Returns".to_string());
178 } else if cleaned.starts_with("Example:") {
179 formatted.push(String::new());
180 formatted.push("# Examples".to_string());
181 } else if cleaned.starts_with(">>>") {
182 formatted.push(format!("```rust"));
184 formatted.push(format!("// {}", cleaned));
185 } else if !cleaned.is_empty() {
186 formatted.push(cleaned.to_string());
187 }
188 }
189
190 if content.contains(">>>") {
192 formatted.push("```".to_string());
193 }
194 }
195
196 formatted.join("\n")
197 }
198}
199
200impl Object for FunctionDef {}