spydecy_codegen/
lib.rs

1//! Rust Code Generator - Sprint 4+
2//!
3//! Generates idiomatic Rust code from optimized Unified HIR.
4//! This is the final stage of the Spydecy pipeline:
5//!
6//! ```text
7//! Python + C → Unified HIR → Optimizer → Codegen → Rust Code
8//! ```
9
10#![warn(missing_docs, clippy::all, clippy::pedantic)]
11#![deny(unsafe_code)]
12#![allow(clippy::module_name_repetitions)]
13
14use anyhow::{Context, Result};
15use spydecy_hir::unified::{UnificationPattern, UnifiedHIR};
16
/// Rust code generator
///
/// Converts optimized `UnifiedHIR` into idiomatic Rust source code. The
/// generator tracks the current indentation depth while it emits nested
/// bodies such as function blocks.
#[derive(Debug, Clone)]
pub struct RustCodegen {
    /// Current indentation depth (number of `indent` repetitions per line)
    indent_level: usize,
    /// Indentation unit for a single level (default: 4 spaces)
    indent: String,
}
26
27impl RustCodegen {
28    /// Create a new Rust code generator
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            indent_level: 0,
33            indent: "    ".to_owned(), // 4 spaces
34        }
35    }
36
37    /// Generate Rust code from `UnifiedHIR`
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if the HIR cannot be converted to Rust code
42    pub fn generate(&mut self, hir: &UnifiedHIR) -> Result<String> {
43        match hir {
44            UnifiedHIR::Module {
45                name, declarations, ..
46            } => self.generate_module(name, declarations),
47            UnifiedHIR::Function {
48                name,
49                params,
50                return_type,
51                body,
52                ..
53            } => self.generate_function(name, params, return_type, body),
54            UnifiedHIR::Call {
55                callee,
56                args,
57                cross_mapping,
58                ..
59            } => {
60                // Check if this is an optimized pattern
61                if let Some(mapping) = cross_mapping {
62                    if mapping.boundary_eliminated {
63                        return Ok(Self::generate_optimized_call(callee, args, mapping.pattern));
64                    }
65                }
66                self.generate_call(callee, args)
67            }
68            UnifiedHIR::Variable { name, .. } => Ok(name.clone()),
69            UnifiedHIR::Return { value, .. } => self.generate_return(value.as_deref()),
70            UnifiedHIR::Assign { target, value, .. } => self.generate_assign(target, value),
71            _ => Ok("/* Unsupported HIR node */".to_owned()),
72        }
73    }
74
75    /// Generate a module
76    fn generate_module(&mut self, name: &str, declarations: &[UnifiedHIR]) -> Result<String> {
77        let mut output = String::new();
78
79        // Module header
80        output.push_str("// Module: ");
81        output.push_str(name);
82        output.push('\n');
83        output.push_str("// Generated by Spydecy\n");
84        output.push_str("#![allow(dead_code)]\n\n");
85
86        // Generate all declarations
87        for decl in declarations {
88            let code = self.generate(decl)?;
89            output.push_str(&code);
90            output.push_str("\n\n");
91        }
92
93        Ok(output)
94    }
95
96    /// Generate a function
97    fn generate_function(
98        &mut self,
99        name: &str,
100        params: &[spydecy_hir::unified::UnifiedParameter],
101        return_type: &spydecy_hir::types::Type,
102        body: &[UnifiedHIR],
103    ) -> Result<String> {
104        let mut output = String::new();
105
106        // Function signature
107        output.push_str("pub fn ");
108        output.push_str(name);
109        output.push('(');
110
111        // Parameters
112        for (i, param) in params.iter().enumerate() {
113            if i > 0 {
114                output.push_str(", ");
115            }
116            output.push_str(&param.name);
117            output.push_str(": ");
118            output.push_str(&self.generate_type(&param.param_type)?);
119        }
120
121        output.push(')');
122
123        // Return type
124        if !matches!(
125            return_type,
126            spydecy_hir::types::Type::Rust(spydecy_hir::types::RustType::Unit)
127        ) {
128            output.push_str(" -> ");
129            output.push_str(&self.generate_type(return_type)?);
130        }
131
132        output.push_str(" {\n");
133
134        // Function body
135        self.indent_level += 1;
136        for stmt in body {
137            let code = self.generate(stmt)?;
138            output.push_str(&self.indent());
139            output.push_str(&code);
140            if !code.trim().ends_with('}') && !code.trim().is_empty() {
141                output.push(';');
142            }
143            output.push('\n');
144        }
145        self.indent_level -= 1;
146
147        output.push('}');
148
149        Ok(output)
150    }
151
152    /// Extract the receiver variable name from arguments
153    ///
154    /// Returns the name of the first Variable argument, or "x" as fallback
155    fn extract_receiver_name(args: &[UnifiedHIR]) -> String {
156        args.first()
157            .and_then(|arg| {
158                if let UnifiedHIR::Variable { name, .. } = arg {
159                    Some(name.clone())
160                } else {
161                    None
162                }
163            })
164            .unwrap_or_else(|| "x".to_owned())
165    }
166
167    /// Generate an optimized call (post-boundary-elimination)
168    #[allow(clippy::unnecessary_wraps)]
169    fn generate_optimized_call(
170        callee: &str,
171        args: &[UnifiedHIR],
172        pattern: UnificationPattern,
173    ) -> String {
174        // Helper: Extract variable name from first argument
175        let receiver = Self::extract_receiver_name(args);
176
177        // Generate idiomatic Rust based on the pattern
178        match pattern {
179            UnificationPattern::LenPattern => {
180                // Vec::len() becomes <receiver>.len()
181                format!("{receiver}.len()")
182            }
183            UnificationPattern::AppendPattern => {
184                // Vec::push() becomes <receiver>.push(item)
185                format!("{receiver}.push(item)")
186            }
187            UnificationPattern::DictGetPattern => {
188                // HashMap::get() becomes <receiver>.get(&key)
189                format!("{receiver}.get(&key)")
190            }
191            UnificationPattern::ReversePattern => {
192                // Vec::reverse() becomes <receiver>.reverse()
193                format!("{receiver}.reverse()")
194            }
195            UnificationPattern::ClearPattern => {
196                // Vec::clear() becomes <receiver>.clear()
197                format!("{receiver}.clear()")
198            }
199            UnificationPattern::PopPattern => {
200                // Vec::pop() becomes <receiver>.pop()
201                format!("{receiver}.pop()")
202            }
203            UnificationPattern::InsertPattern => {
204                // Vec::insert() becomes <receiver>.insert(index, value)
205                format!("{receiver}.insert(index, value)")
206            }
207            UnificationPattern::ExtendPattern => {
208                // Vec::extend() becomes <receiver>.extend(iter)
209                format!("{receiver}.extend(iter)")
210            }
211            UnificationPattern::DictPopPattern => {
212                // HashMap::remove() becomes <receiver>.remove(&key)
213                format!("{receiver}.remove(&key)")
214            }
215            UnificationPattern::DictClearPattern => {
216                // HashMap::clear() becomes <receiver>.clear()
217                format!("{receiver}.clear()")
218            }
219            UnificationPattern::DictKeysPattern => {
220                // HashMap::keys() becomes <receiver>.keys()
221                format!("{receiver}.keys()")
222            }
223            UnificationPattern::Custom => format!("{callee}()"),
224        }
225    }
226
227    /// Generate a regular call
228    fn generate_call(&mut self, callee: &str, args: &[UnifiedHIR]) -> Result<String> {
229        let mut output = callee.to_owned();
230        output.push('(');
231
232        for (i, arg) in args.iter().enumerate() {
233            if i > 0 {
234                output.push_str(", ");
235            }
236            output.push_str(&self.generate(arg)?);
237        }
238
239        output.push(')');
240        Ok(output)
241    }
242
243    /// Generate a return statement
244    fn generate_return(&mut self, value: Option<&UnifiedHIR>) -> Result<String> {
245        if let Some(val) = value {
246            Ok(format!("return {}", self.generate(val)?))
247        } else {
248            Ok("return".to_owned())
249        }
250    }
251
252    /// Generate an assignment
253    fn generate_assign(&mut self, target: &str, value: &UnifiedHIR) -> Result<String> {
254        let val_code = self.generate(value)?;
255        Ok(format!("let {target} = {val_code}"))
256    }
257
258    /// Generate a type annotation
259    #[allow(clippy::only_used_in_recursion)]
260    fn generate_type(&self, ty: &spydecy_hir::types::Type) -> Result<String> {
261        use spydecy_hir::types::{RustType, Type};
262
263        match ty {
264            Type::Rust(rust_ty) => match rust_ty {
265                RustType::Int { bits, signed } => {
266                    let prefix = if *signed { "i" } else { "u" };
267                    let size = match bits {
268                        spydecy_hir::types::IntSize::I8 => "8",
269                        spydecy_hir::types::IntSize::I16 => "16",
270                        spydecy_hir::types::IntSize::I32 => "32",
271                        spydecy_hir::types::IntSize::I64 => "64",
272                        spydecy_hir::types::IntSize::I128 => "128",
273                        spydecy_hir::types::IntSize::ISize => "size",
274                    };
275                    Ok(format!("{prefix}{size}"))
276                }
277                RustType::Bool => Ok("bool".to_owned()),
278                RustType::String => Ok("String".to_owned()),
279                RustType::Str => Ok("&str".to_owned()),
280                RustType::Vec(inner) => Ok(format!("Vec<{}>", self.generate_type(inner)?)),
281                RustType::Option(inner) => Ok(format!("Option<{}>", self.generate_type(inner)?)),
282                RustType::Unit => Ok("()".to_owned()),
283                RustType::Reference { mutable, inner } => {
284                    let mut_str = if *mutable { "mut " } else { "" };
285                    Ok(format!("&{mut_str}{}", self.generate_type(inner)?))
286                }
287                _ => Ok("/* complex type */".to_owned()),
288            },
289            Type::Unknown => Ok("/* infer */".to_owned()),
290            _ => Ok("/* non-rust type */".to_owned()),
291        }
292    }
293
294    /// Get current indentation string
295    fn indent(&self) -> String {
296        self.indent.repeat(self.indent_level)
297    }
298}
299
300impl Default for RustCodegen {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
306/// Generate Rust code from `UnifiedHIR` (convenience function)
307///
308/// # Errors
309///
310/// Returns an error if code generation fails
311pub fn generate_rust(hir: &UnifiedHIR) -> Result<String> {
312    let mut codegen = RustCodegen::new();
313    codegen
314        .generate(hir)
315        .context("Failed to generate Rust code")
316}
317
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use spydecy_hir::{
        metadata::Metadata,
        types::{IntSize, RustType, Type},
        unified::{CrossMapping, UnificationPattern},
        Language, NodeId,
    };

    /// Build a Python-sourced `UnifiedHIR::Call` that targets Rust.
    fn call(
        callee: &str,
        args: Vec<UnifiedHIR>,
        inferred_type: Type,
        cross_mapping: Option<CrossMapping>,
    ) -> UnifiedHIR {
        UnifiedHIR::Call {
            id: NodeId::new(1),
            target_language: Language::Rust,
            callee: callee.to_owned(),
            args,
            inferred_type,
            source_language: Language::Python,
            cross_mapping,
            meta: Metadata::new(),
        }
    }

    /// Build a cross-language mapping for `pattern` with the given
    /// boundary-elimination flag.
    fn mapping(pattern: UnificationPattern, boundary_eliminated: bool) -> Option<CrossMapping> {
        Some(CrossMapping {
            python_node: None,
            c_node: None,
            pattern,
            boundary_eliminated,
        })
    }

    /// Build a Python-sourced variable reference of unknown type.
    fn variable(name: &str) -> UnifiedHIR {
        UnifiedHIR::Variable {
            id: NodeId::new(2),
            name: name.to_owned(),
            var_type: Type::Unknown,
            source_language: Language::Python,
            meta: Metadata::new(),
        }
    }

    /// Render a type annotation with a fresh generator.
    fn render_type(ty: &Type) -> String {
        RustCodegen::new()
            .generate_type(ty)
            .expect("Should generate type")
    }

    #[test]
    fn test_generate_len_pattern() {
        // Optimized len() call without a receiver argument falls back to `x`
        let hir = call(
            "Vec::len",
            vec![],
            Type::Rust(RustType::Int {
                bits: IntSize::ISize,
                signed: false,
            }),
            mapping(UnificationPattern::LenPattern, true),
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code.trim(), "x.len()");
    }

    #[test]
    fn test_generate_append_pattern() {
        let hir = call(
            "Vec::push",
            vec![],
            Type::Rust(RustType::Unit),
            mapping(UnificationPattern::AppendPattern, true),
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code.trim(), "x.push(item)");
    }

    #[test]
    fn test_generate_dict_get_pattern() {
        let hir = call(
            "HashMap::get",
            vec![variable("map")],
            Type::Unknown,
            mapping(UnificationPattern::DictGetPattern, true),
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code.trim(), "map.get(&key)");
    }

    #[test]
    fn test_generate_custom_pattern_without_args() {
        let hir = call(
            "helper",
            vec![],
            Type::Unknown,
            mapping(UnificationPattern::Custom, true),
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code, "helper()");
    }

    #[test]
    fn test_call_with_boundary_kept_is_plain_call() {
        // A mapping whose boundary was NOT eliminated must not be rewritten
        let hir = call(
            "len",
            vec![variable("v")],
            Type::Unknown,
            mapping(UnificationPattern::LenPattern, false),
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code, "len(v)");
    }

    #[test]
    fn test_generate_plain_call_with_args() {
        let hir = call(
            "max",
            vec![variable("a"), variable("b")],
            Type::Unknown,
            None,
        );

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code, "max(a, b)");
    }

    #[test]
    fn test_generate_variable() {
        let hir = UnifiedHIR::Variable {
            id: NodeId::new(1),
            name: "my_var".to_owned(),
            var_type: Type::Unknown,
            source_language: Language::Rust,
            meta: Metadata::new(),
        };

        let code = generate_rust(&hir).expect("Should generate code");
        assert_eq!(code, "my_var");
    }

    #[test]
    fn test_generate_type_int() {
        let ty = Type::Rust(RustType::Int {
            bits: IntSize::I32,
            signed: true,
        });
        assert_eq!(render_type(&ty), "i32");
    }

    #[test]
    fn test_generate_type_unsigned_pointer_sized() {
        let ty = Type::Rust(RustType::Int {
            bits: IntSize::ISize,
            signed: false,
        });
        assert_eq!(render_type(&ty), "usize");
    }

    #[test]
    fn test_generate_type_vec() {
        let ty = Type::Rust(RustType::Vec(Box::new(Type::Rust(RustType::Int {
            bits: IntSize::I32,
            signed: true,
        }))));
        assert_eq!(render_type(&ty), "Vec<i32>");
    }

    #[test]
    fn test_generate_type_nested_vec() {
        let byte = Type::Rust(RustType::Int {
            bits: IntSize::I8,
            signed: false,
        });
        let ty = Type::Rust(RustType::Vec(Box::new(Type::Rust(RustType::Vec(
            Box::new(byte),
        )))));
        assert_eq!(render_type(&ty), "Vec<Vec<u8>>");
    }

    #[test]
    fn test_generate_type_simple_and_unknown() {
        assert_eq!(render_type(&Type::Rust(RustType::Bool)), "bool");
        assert_eq!(render_type(&Type::Rust(RustType::String)), "String");
        assert_eq!(render_type(&Type::Rust(RustType::Str)), "&str");
        assert_eq!(render_type(&Type::Rust(RustType::Unit)), "()");
        assert_eq!(render_type(&Type::Unknown), "/* infer */");
    }
}