1use rosetta_stone_core::{IrType, IrExpr, IrLiteral, TranspileError, Result};
13use rosetta_stone_ir::{IrModule, IrFunction, IrTypeDef, Visibility};
14use proc_macro2::TokenStream;
15use quote::{quote, format_ident};
16
/// Knobs controlling how [`RustCodegen`] renders an IR module as Rust source.
#[derive(Clone, Debug)]
pub struct CodegenOptions {
    /// Intended to gate emission of `unsafe` code.
    ///
    /// NOTE(review): `RustCodegen` in this file never reads this flag; functions are marked
    /// `unsafe` purely from `IrFunction::is_unsafe`.
    pub allow_unsafe: bool,
    /// Gates the module header built by `RustCodegen`; that header currently contributes no
    /// tokens.
    pub include_comments: bool,
    /// When set, the generated tokens are parsed with `syn` and pretty-printed with
    /// `prettyplease`; otherwise the raw `TokenStream` text is returned.
    pub format_output: bool,
    /// Intended to request generated test code.
    ///
    /// NOTE(review): not read anywhere in this file.
    pub generate_tests: bool,
    /// Emit `use ndarray::{Array1, Array2, ArrayD};` at the top of the generated module.
    pub use_ndarray: bool,
}
31
32impl Default for CodegenOptions {
33 fn default() -> Self {
34 Self {
35 allow_unsafe: true,
36 include_comments: true,
37 format_output: true,
38 generate_tests: false,
39 use_ndarray: true,
40 }
41 }
42}
43
44pub struct RustCodegen {
46 options: CodegenOptions,
47}
48
49impl RustCodegen {
50 pub fn new(options: CodegenOptions) -> Self {
52 Self { options }
53 }
54
55 pub fn generate(&self, module: &IrModule) -> Result<String> {
57 let tokens = self.generate_tokens(module)?;
58
59 if self.options.format_output {
60 self.format_code(tokens)
61 } else {
62 Ok(tokens.to_string())
63 }
64 }
65
66 fn generate_tokens(&self, module: &IrModule) -> Result<TokenStream> {
68 let module_comment = if self.options.include_comments {
69 let _lang = format!("{:?}", module.source_lang);
70 quote! {
71 }
74 } else {
75 quote! {}
76 };
77
78 let imports = self.generate_imports(module)?;
80
81 let types: Vec<TokenStream> = module.types
83 .iter()
84 .map(|t| self.generate_type_def(t))
85 .collect::<Result<_>>()?;
86
87 let constants: Vec<TokenStream> = module.constants
89 .iter()
90 .map(|c| {
91 let name = format_ident!("{}", &c.name);
92 let ty = self.ir_type_to_rust(&c.ty)?;
93 let value = self.expr_to_tokens(&c.value)?;
94 let vis = self.visibility_to_tokens(c.visibility);
95 Ok(quote! {
96 #vis const #name: #ty = #value;
97 })
98 })
99 .collect::<Result<_>>()?;
100
101 let functions: Vec<TokenStream> = module.functions
103 .iter()
104 .map(|f| self.generate_function(f))
105 .collect::<Result<_>>()?;
106
107 Ok(quote! {
108 #module_comment
109 #imports
110 #(#types)*
111 #(#constants)*
112 #(#functions)*
113 })
114 }
115
116 fn generate_imports(&self, _module: &IrModule) -> Result<TokenStream> {
118 let mut imports = vec![];
119
120 if self.options.use_ndarray {
121 imports.push(quote! { use ndarray::{Array1, Array2, ArrayD}; });
122 }
123
124 Ok(quote! { #(#imports)* })
125 }
126
127 fn generate_type_def(&self, typedef: &IrTypeDef) -> Result<TokenStream> {
129 match typedef {
130 IrTypeDef::Struct { name, fields, derives } => {
131 let name_ident = format_ident!("{}", name);
132 let derive_idents: Vec<_> = derives.iter()
133 .map(|d| format_ident!("{}", d))
134 .collect();
135
136 let field_tokens: Vec<TokenStream> = fields.iter()
137 .map(|(name, ty)| {
138 let field_name = format_ident!("{}", name);
139 let field_ty = self.ir_type_to_rust(ty)?;
140 Ok(quote! { pub #field_name: #field_ty })
141 })
142 .collect::<Result<_>>()?;
143
144 Ok(quote! {
145 #[derive(#(#derive_idents),*)]
146 pub struct #name_ident {
147 #(#field_tokens),*
148 }
149 })
150 }
151 IrTypeDef::Enum { name, variants } => {
152 let name_ident = format_ident!("{}", name);
153 let variant_tokens: Vec<TokenStream> = variants.iter()
154 .map(|v| {
155 let var_name = format_ident!("{}", &v.name);
156 if let Some(disc) = v.discriminant {
157 quote! { #var_name = #disc }
158 } else {
159 quote! { #var_name }
160 }
161 })
162 .collect();
163
164 Ok(quote! {
165 pub enum #name_ident {
166 #(#variant_tokens),*
167 }
168 })
169 }
170 IrTypeDef::Alias { name, target } => {
171 let name_ident = format_ident!("{}", name);
172 let target_ty = self.ir_type_to_rust(target)?;
173 Ok(quote! {
174 pub type #name_ident = #target_ty;
175 })
176 }
177 }
178 }
179
180 fn generate_function(&self, func: &IrFunction) -> Result<TokenStream> {
182 let name = format_ident!("{}", &func.name);
183 let vis = self.visibility_to_tokens(func.visibility);
184
185 let params: Vec<TokenStream> = func.params.iter()
186 .map(|p| {
187 let param_name = format_ident!("{}", &p.name);
188 let param_ty = self.ir_type_to_rust(&p.ty)?;
189 if p.by_ref && p.is_mutable {
190 Ok(quote! { #param_name: &mut #param_ty })
191 } else if p.by_ref {
192 Ok(quote! { #param_name: &#param_ty })
193 } else {
194 Ok(quote! { #param_name: #param_ty })
195 }
196 })
197 .collect::<Result<_>>()?;
198
199 let return_ty = self.ir_type_to_rust(&func.return_type)?;
200
201 let body_tokens: Vec<TokenStream> = func.body.iter()
202 .map(|expr| self.expr_to_tokens(expr))
203 .collect::<Result<_>>()?;
204
205 let unsafe_kw = if func.is_unsafe {
206 quote! { unsafe }
207 } else {
208 quote! {}
209 };
210
211 Ok(quote! {
212 #vis #unsafe_kw fn #name(#(#params),*) -> #return_ty {
213 #(#body_tokens)*
214 }
215 })
216 }
217
218 fn ir_type_to_rust(&self, ty: &IrType) -> Result<TokenStream> {
220 match ty {
221 IrType::Int(8) => Ok(quote! { i8 }),
222 IrType::Int(16) => Ok(quote! { i16 }),
223 IrType::Int(32) => Ok(quote! { i32 }),
224 IrType::Int(64) => Ok(quote! { i64 }),
225 IrType::Int(_) => Ok(quote! { i32 }),
226 IrType::Float(32) => Ok(quote! { f32 }),
227 IrType::Float(64) => Ok(quote! { f64 }),
228 IrType::Float(_) => Ok(quote! { f64 }),
229 IrType::Bool => Ok(quote! { bool }),
230 IrType::String => Ok(quote! { String }),
231 IrType::Char => Ok(quote! { char }),
232 IrType::Unit => Ok(quote! { () }),
233 IrType::Array(inner, Some(size)) => {
234 let inner_ty = self.ir_type_to_rust(inner)?;
235 let size_lit = proc_macro2::Literal::usize_unsuffixed(*size);
236 Ok(quote! { [#inner_ty; #size_lit] })
237 }
238 IrType::Array(inner, None) => {
239 let inner_ty = self.ir_type_to_rust(inner)?;
240 Ok(quote! { Vec<#inner_ty> })
241 }
242 IrType::Vec(inner) => {
243 let inner_ty = self.ir_type_to_rust(inner)?;
244 Ok(quote! { Vec<#inner_ty> })
245 }
246 IrType::Ref(inner) => {
247 let inner_ty = self.ir_type_to_rust(inner)?;
248 Ok(quote! { &#inner_ty })
249 }
250 IrType::MutRef(inner) => {
251 let inner_ty = self.ir_type_to_rust(inner)?;
252 Ok(quote! { &mut #inner_ty })
253 }
254 IrType::Box(inner) => {
255 let inner_ty = self.ir_type_to_rust(inner)?;
256 Ok(quote! { Box<#inner_ty> })
257 }
258 IrType::Option(inner) => {
259 let inner_ty = self.ir_type_to_rust(inner)?;
260 Ok(quote! { Option<#inner_ty> })
261 }
262 IrType::Result(ok, err) => {
263 let ok_ty = self.ir_type_to_rust(ok)?;
264 let err_ty = self.ir_type_to_rust(err)?;
265 Ok(quote! { Result<#ok_ty, #err_ty> })
266 }
267 IrType::Tuple(types) => {
268 let type_tokens: Vec<TokenStream> = types.iter()
269 .map(|t| self.ir_type_to_rust(t))
270 .collect::<Result<_>>()?;
271 Ok(quote! { (#(#type_tokens),*) })
272 }
273 IrType::Struct(name) => {
274 let name_ident = format_ident!("{}", name);
275 Ok(quote! { #name_ident })
276 }
277 IrType::Fn(params, ret) => {
278 let param_types: Vec<TokenStream> = params.iter()
279 .map(|t| self.ir_type_to_rust(t))
280 .collect::<Result<_>>()?;
281 let ret_ty = self.ir_type_to_rust(ret)?;
282 Ok(quote! { fn(#(#param_types),*) -> #ret_ty })
283 }
284 IrType::Iterator(inner) => {
285 let inner_ty = self.ir_type_to_rust(inner)?;
286 Ok(quote! { impl Iterator<Item = #inner_ty> })
287 }
288 IrType::Any => Ok(quote! { Box<dyn std::any::Any> }),
289 IrType::Unknown => Ok(quote! { () }),
290 }
291 }
292
293 #[allow(dead_code)]
295 fn node_to_tokens(&self, _node: &rosetta_stone_core::IrNode) -> Result<TokenStream> {
296 Ok(quote! {})
298 }
299
300 fn expr_to_tokens(&self, expr: &IrExpr) -> Result<TokenStream> {
302 match expr {
303 IrExpr::Int(i) => Ok(quote! { #i }),
305 IrExpr::Float(f) => {
306 let f_lit = proc_macro2::Literal::f64_unsuffixed(*f);
307 Ok(quote! { #f_lit })
308 }
309 IrExpr::Bool(b) => Ok(quote! { #b }),
310 IrExpr::String(s) => Ok(quote! { #s.to_string() }),
311 IrExpr::Char(c) => Ok(quote! { #c }),
312 IrExpr::Nil => Ok(quote! { None }),
313 IrExpr::Symbol(s) => {
314 let sym = format_ident!("{}", s);
315 Ok(quote! { #sym })
316 }
317
318 IrExpr::Identifier(name) => {
320 let ident = format_ident!("{}", name);
321 Ok(quote! { #ident })
322 }
323 IrExpr::PatternVar(name) => {
324 let ident = format_ident!("{}", name);
325 Ok(quote! { #ident })
326 }
327
328 IrExpr::BinaryOp { op, left, right } => {
330 let left_tokens = self.expr_to_tokens(left)?;
331 let right_tokens = self.expr_to_tokens(right)?;
332 let op_tokens = self.string_binop_to_tokens(op);
333 Ok(quote! { (#left_tokens #op_tokens #right_tokens) })
334 }
335 IrExpr::UnaryOp { op, operand } => {
336 let operand_tokens = self.expr_to_tokens(operand)?;
337 let op_tokens = self.string_unaryop_to_tokens(op);
338 Ok(quote! { #op_tokens #operand_tokens })
339 }
340
341 IrExpr::Call { func, args } => {
343 let func_tokens = self.expr_to_tokens(func)?;
344 let arg_tokens: Vec<TokenStream> = args.iter()
345 .map(|a| self.expr_to_tokens(a))
346 .collect::<Result<_>>()?;
347 Ok(quote! { #func_tokens(#(#arg_tokens),*) })
348 }
349
350 IrExpr::List(items) => {
352 let item_tokens: Vec<TokenStream> = items.iter()
353 .map(|i| self.expr_to_tokens(i))
354 .collect::<Result<_>>()?;
355 Ok(quote! { vec![#(#item_tokens),*] })
356 }
357 IrExpr::ListCons { head, tail } => {
358 let h = self.expr_to_tokens(head)?;
359 let t = self.expr_to_tokens(tail)?;
360 Ok(quote! { std::iter::once(#h).chain(#t.into_iter()).collect::<Vec<_>>() })
361 }
362 IrExpr::StructInit { name, fields } => {
363 let name_ident = format_ident!("{}", name);
364 let field_inits: Vec<TokenStream> = fields.iter()
365 .map(|(fname, fval)| {
366 let fi = format_ident!("{}", fname);
367 let fv = self.expr_to_tokens(fval)?;
368 Ok(quote! { #fi: #fv })
369 })
370 .collect::<Result<_>>()?;
371 Ok(quote! { #name_ident { #(#field_inits),* } })
372 }
373 IrExpr::Tuple(items) => {
374 let item_tokens: Vec<TokenStream> = items.iter()
375 .map(|i| self.expr_to_tokens(i))
376 .collect::<Result<_>>()?;
377 Ok(quote! { (#(#item_tokens),*) })
378 }
379
380 IrExpr::FieldAccess { object, field } => {
382 let obj = self.expr_to_tokens(object)?;
383 let fld = format_ident!("{}", field);
384 Ok(quote! { #obj.#fld })
385 }
386 IrExpr::FieldAssign { object, field, value } => {
387 let obj = format_ident!("{}", object);
388 let fld = format_ident!("{}", field);
389 let val = self.expr_to_tokens(value)?;
390 Ok(quote! { #obj.#fld = #val })
391 }
392 IrExpr::Index { array, index } => {
393 let arr = self.expr_to_tokens(array)?;
394 let idx = self.expr_to_tokens(index)?;
395 Ok(quote! { #arr[#idx] })
396 }
397
398 IrExpr::If { condition, then_branch, else_branch } => {
400 let cond = self.expr_to_tokens(condition)?;
401 let then_b = self.expr_to_tokens(then_branch)?;
402 if let Some(else_b) = else_branch {
403 let else_tokens = self.expr_to_tokens(else_b)?;
404 Ok(quote! { if #cond { #then_b } else { #else_tokens } })
405 } else {
406 Ok(quote! { if #cond { #then_b } })
407 }
408 }
409 IrExpr::Cond(branches) => {
410 if branches.is_empty() {
411 return Ok(quote! { () });
412 }
413 let mut tokens = TokenStream::new();
414 for (i, (cond, body)) in branches.iter().enumerate() {
415 let c = self.expr_to_tokens(cond)?;
416 let b = self.expr_to_tokens(body)?;
417 if i == 0 {
418 tokens = quote! { if #c { #b } };
419 } else {
420 tokens = quote! { #tokens else if #c { #b } };
421 }
422 }
423 Ok(tokens)
424 }
425 IrExpr::Match { scrutinee, arms } => {
426 let scrut = self.expr_to_tokens(scrutinee)?;
427 let arm_tokens: Vec<TokenStream> = arms.iter()
428 .map(|(pat, body)| {
429 let p = self.expr_to_tokens(pat)?;
430 let b = self.expr_to_tokens(body)?;
431 Ok(quote! { #p => #b })
432 })
433 .collect::<Result<_>>()?;
434 Ok(quote! { match #scrut { #(#arm_tokens),* } })
435 }
436 IrExpr::Block(exprs) => {
437 let expr_tokens: Vec<TokenStream> = exprs.iter()
438 .map(|e| self.expr_to_tokens(e))
439 .collect::<Result<_>>()?;
440 Ok(quote! { { #(#expr_tokens;)* } })
441 }
442 IrExpr::Return(expr) => {
443 let e = self.expr_to_tokens(expr)?;
444 Ok(quote! { return #e })
445 }
446
447 IrExpr::Assign { target, value } => {
449 let t = format_ident!("{}", target);
450 let v = self.expr_to_tokens(value)?;
451 Ok(quote! { #t = #v })
452 }
453 IrExpr::Let { name, value, body } => {
454 let n = format_ident!("{}", name);
455 let v = self.expr_to_tokens(value)?;
456 let b = self.expr_to_tokens(body)?;
457 Ok(quote! { { let #n = #v; #b } })
458 }
459
460 IrExpr::Lambda { params, body } => {
462 let param_idents: Vec<_> = params.iter()
463 .map(|p| format_ident!("{}", p))
464 .collect();
465 let body_tokens: Vec<TokenStream> = body.iter()
466 .map(|e| self.expr_to_tokens(e))
467 .collect::<Result<_>>()?;
468 Ok(quote! { |#(#param_idents),*| { #(#body_tokens;)* } })
469 }
470
471 IrExpr::Quote(expr) => {
473 let e = self.expr_to_tokens(expr)?;
474 Ok(quote! { #e })
475 }
476 IrExpr::Quasiquote(expr) => {
477 let e = self.expr_to_tokens(expr)?;
478 Ok(quote! { #e })
479 }
480 IrExpr::Unquote(expr) => {
481 let e = self.expr_to_tokens(expr)?;
482 Ok(quote! { #e })
483 }
484
485 IrExpr::Goal { pattern, body } => {
487 let p = self.expr_to_tokens(pattern)?;
488 let b: Vec<TokenStream> = body.iter()
489 .map(|e| self.expr_to_tokens(e))
490 .collect::<Result<_>>()?;
491 Ok(quote! { goal(#p, || { #(#b;)* }) })
492 }
493 IrExpr::Unify { left, right } => {
494 let l = self.expr_to_tokens(left)?;
495 let r = self.expr_to_tokens(right)?;
496 Ok(quote! { unify(#l, #r) })
497 }
498 IrExpr::PatternMatch { value, pattern } => {
499 let v = self.expr_to_tokens(value)?;
500 let p = self.expr_to_tokens(pattern)?;
501 Ok(quote! { matches!(#v, #p) })
502 }
503 IrExpr::Choice(choices) => {
504 let choice_tokens: Vec<TokenStream> = choices.iter()
505 .map(|c| self.expr_to_tokens(c))
506 .collect::<Result<_>>()?;
507 Ok(quote! { choice([#(|| #choice_tokens),*]) })
508 }
509
510 IrExpr::WmeCreate { class, attributes } => {
512 let c = format_ident!("{}", class);
513 let attrs: Vec<TokenStream> = attributes.iter()
514 .map(|(k, v)| {
515 let ki = format_ident!("{}", k);
516 let vi = self.expr_to_tokens(v)?;
517 Ok(quote! { #ki: #vi })
518 })
519 .collect::<Result<_>>()?;
520 Ok(quote! { wm.insert(#c { #(#attrs),* }) })
521 }
522
523 IrExpr::Comment(_text) => {
525 Ok(quote! { })
526 }
527 }
528 }
529
530 #[allow(dead_code)]
532 fn literal_to_tokens(&self, lit: &IrLiteral) -> Result<TokenStream> {
533 match lit {
534 IrLiteral::Int(i) => Ok(quote! { #i }),
535 IrLiteral::Float(f) => {
536 let f_lit = proc_macro2::Literal::f64_unsuffixed(*f);
537 Ok(quote! { #f_lit })
538 }
539 IrLiteral::Bool(b) => Ok(quote! { #b }),
540 IrLiteral::String(s) => Ok(quote! { #s.to_string() }),
541 IrLiteral::Char(c) => Ok(quote! { #c }),
542 }
543 }
544
545 fn string_binop_to_tokens(&self, op: &str) -> TokenStream {
547 match op {
548 "+" | "add" => quote! { + },
549 "-" | "sub" => quote! { - },
550 "*" | "mul" => quote! { * },
551 "/" | "div" => quote! { / },
552 "%" | "mod" | "rem" => quote! { % },
553 "==" | "eq" => quote! { == },
554 "!=" | "ne" | "<>" => quote! { != },
555 "<" | "lt" => quote! { < },
556 "<=" | "le" => quote! { <= },
557 ">" | "gt" => quote! { > },
558 ">=" | "ge" => quote! { >= },
559 "&&" | "and" => quote! { && },
560 "||" | "or" => quote! { || },
561 "^" | "xor" => quote! { ^ },
562 "&" | "bitand" => quote! { & },
563 "|" | "bitor" => quote! { | },
564 "<<" | "shl" => quote! { << },
565 ">>" | "shr" => quote! { >> },
566 "**" | "pow" => quote! { .pow },
567 _ => {
568 let op_ident = format_ident!("{}", op);
569 quote! { .#op_ident() }
570 }
571 }
572 }
573
574 fn string_unaryop_to_tokens(&self, op: &str) -> TokenStream {
576 match op {
577 "-" | "neg" => quote! { - },
578 "!" | "not" => quote! { ! },
579 "~" | "bitnot" => quote! { ! },
580 "*" | "deref" => quote! { * },
581 "&" | "ref" => quote! { & },
582 _ => quote! { }
583 }
584 }
585
586 fn visibility_to_tokens(&self, vis: Visibility) -> TokenStream {
588 match vis {
589 Visibility::Public => quote! { pub },
590 Visibility::Private => quote! {},
591 Visibility::Crate => quote! { pub(crate) },
592 Visibility::Super => quote! { pub(super) },
593 }
594 }
595
596 fn format_code(&self, tokens: TokenStream) -> Result<String> {
598 let file: syn::File = syn::parse2(tokens)
599 .map_err(|e| TranspileError::CodegenError(format!("Parse error: {}", e)))?;
600
601 Ok(prettyplease::unparse(&file))
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use rosetta_stone_core::SourceLanguage;
    use rosetta_stone_ir::IrBuilder;

    /// Generates code for a one-function module: `pub fn add(a: i32, b: i32) -> i32` with an
    /// empty body, rendered with `options`.
    fn generate_add(options: CodegenOptions) -> Result<String> {
        let mut builder = IrBuilder::with_language("test", SourceLanguage::Fortran77);
        builder.add_function(rosetta_stone_ir::IrFunction {
            name: "add".to_string(),
            generics: vec![],
            params: vec![
                rosetta_stone_ir::IrParam {
                    name: "a".to_string(),
                    ty: IrType::Int(32),
                    is_mutable: false,
                    by_ref: false,
                },
                rosetta_stone_ir::IrParam {
                    name: "b".to_string(),
                    ty: IrType::Int(32),
                    is_mutable: false,
                    by_ref: false,
                },
            ],
            return_type: IrType::Int(32),
            body: vec![],
            is_unsafe: false,
            visibility: Visibility::Public,
            attributes: vec![],
            span: None,
        });

        let module = builder.build();
        RustCodegen::new(options).generate(&module)
    }

    #[test]
    fn test_basic_codegen() {
        let code = generate_add(CodegenOptions::default()).expect("codegen should succeed");

        // Beyond "it didn't fail": the signature and default imports must be present.
        assert!(code.contains("pub fn add("), "missing signature in:\n{}", code);
        assert!(code.contains("a: i32"), "missing parameter in:\n{}", code);
        assert!(code.contains("ndarray"), "missing ndarray import in:\n{}", code);
    }

    #[test]
    fn test_unformatted_codegen_without_ndarray() {
        let options = CodegenOptions {
            format_output: false,
            use_ndarray: false,
            ..CodegenOptions::default()
        };
        let code = generate_add(options).expect("codegen should succeed");

        assert!(code.contains("fn add"), "missing function in:\n{}", code);
        assert!(!code.contains("ndarray"), "unexpected ndarray import in:\n{}", code);
    }
}