// aether/pytranspile/python.rs

1#[cfg(feature = "pytranspile")]
2use pyo3::{prelude::*, types::PyList};
3
4use crate::pytranspile::diagnostics::{Diagnostic, Diagnostics};
5#[cfg(feature = "pytranspile")]
6use crate::pytranspile::ir::{Expr, Stmt};
7use crate::pytranspile::ir::{Module, Span};
8use crate::pytranspile::options::TranspileOptions;
9
10#[derive(Debug)]
11pub struct PythonIrResult {
12    pub module: Option<Module>,
13    pub diagnostics: Diagnostics,
14    pub numpy_used: bool,
15    pub io_used: bool,
16    pub console_used: bool,
17}
18
19impl PythonIrResult {
20    fn fail(diagnostics: Diagnostics) -> Self {
21        Self {
22            module: None,
23            diagnostics,
24            numpy_used: false,
25            io_used: false,
26            console_used: false,
27        }
28    }
29}
30
#[cfg(feature = "pytranspile")]
/// Parses `source` with CPython's `ast` module and lowers the tree into the crate IR.
///
/// If the `ast` module cannot be imported or `ast.parse` raises, the result carries a
/// `PY_IMPORT_AST_FAILED` / `PY_SYNTAX_ERROR` diagnostic and no module. Otherwise the
/// lowered module is returned, and any usage that `opts` forbids (numpy, filesystem/network,
/// console IO) is appended as an error diagnostic; the module is still returned in that case.
pub fn python_to_ir(source: &str, opts: &TranspileOptions) -> PythonIrResult {
    Python::attach(|py| {
        let mut diagnostics = Diagnostics::new();

        let ast = match py.import("ast") {
            Ok(module) => module,
            Err(err) => {
                let message = format!("failed to import python ast module: {err}");
                diagnostics.push(Diagnostic::error(
                    "PY_IMPORT_AST_FAILED",
                    message,
                    Span::default(),
                ));
                return PythonIrResult::fail(diagnostics);
            }
        };

        let tree = match ast.call_method1("parse", (source,)) {
            Ok(tree) => tree,
            Err(err) => {
                diagnostics.push(Diagnostic::error(
                    "PY_SYNTAX_ERROR",
                    err.to_string(),
                    Span::default(),
                ));
                return PythonIrResult::fail(diagnostics);
            }
        };

        let mut builder = IrBuilder::new(py, ast, opts, diagnostics);
        let module = builder.emit_module(&tree);

        // Usage forbidden by the options is reported after lowering, always in the
        // order numpy, io, console.
        let rejections = [
            (
                opts.reject_numpy && builder.numpy_used,
                "PY_NUMPY_REJECTED",
                "numpy usage is rejected by transpile options",
            ),
            (
                opts.reject_io && builder.io_used,
                "PY_IO_REJECTED",
                "filesystem/network usage is rejected by transpile options",
            ),
            (
                opts.reject_console && builder.console_used,
                "PY_CONSOLE_REJECTED",
                "console IO (print/input) is rejected by transpile options",
            ),
        ];
        for (rejected, code, message) in rejections {
            if rejected {
                builder
                    .diagnostics
                    .push(Diagnostic::error(code, message, Span::default()));
            }
        }

        PythonIrResult {
            module,
            numpy_used: builder.numpy_used,
            io_used: builder.io_used,
            console_used: builder.console_used,
            diagnostics: builder.diagnostics,
        }
    })
}
100
101#[cfg(not(feature = "pytranspile"))]
102pub fn python_to_ir(_source: &str, _opts: &TranspileOptions) -> PythonIrResult {
103    let mut diagnostics = Diagnostics::new();
104    diagnostics.push(Diagnostic::error(
105        "PYTRANSPILE_FEATURE_DISABLED",
106        "enable cargo feature `pytranspile` to use python_to_ir",
107        Span::default(),
108    ));
109    PythonIrResult::fail(diagnostics)
110}
111
112#[cfg(feature = "pytranspile")]
113struct IrBuilder {
114    diagnostics: Diagnostics,
115    numpy_used: bool,
116    io_used: bool,
117    console_used: bool,
118    aliases: std::collections::HashMap<String, String>,
119}
120
121#[cfg(feature = "pytranspile")]
122impl IrBuilder {
123    fn new(
124        _py: Python<'_>,
125        _ast: Bound<'_, PyModule>,
126        _opts: &TranspileOptions,
127        diagnostics: Diagnostics,
128    ) -> Self {
129        Self {
130            diagnostics,
131            numpy_used: false,
132            io_used: false,
133            console_used: false,
134            aliases: std::collections::HashMap::new(),
135        }
136    }
137
138    fn span_of(&self, node: &Bound<'_, PyAny>) -> Span {
139        let line = node
140            .getattr("lineno")
141            .ok()
142            .and_then(|v| v.extract().ok())
143            .unwrap_or(0);
144        let col = node
145            .getattr("col_offset")
146            .ok()
147            .and_then(|v| v.extract().ok())
148            .unwrap_or(0);
149        let end_line = node
150            .getattr("end_lineno")
151            .ok()
152            .and_then(|v| v.extract().ok())
153            .unwrap_or(0);
154        let end_col = node
155            .getattr("end_col_offset")
156            .ok()
157            .and_then(|v| v.extract().ok())
158            .unwrap_or(0);
159
160        Span {
161            line,
162            col,
163            end_line,
164            end_col,
165        }
166    }
167
168    fn node_type(&self, node: &Bound<'_, PyAny>) -> Option<String> {
169        node.get_type()
170            .name()
171            .ok()
172            .map(|s| s.to_string_lossy().into_owned())
173    }
174
175    fn type_name(&self, node: &Bound<'_, PyAny>, fallback: &str) -> String {
176        node.get_type()
177            .name()
178            .ok()
179            .map(|s| s.to_string_lossy().into_owned())
180            .unwrap_or_else(|| fallback.to_string())
181    }
182
183    fn emit_module(&mut self, node: &Bound<'_, PyAny>) -> Option<Module> {
184        let span = self.span_of(node);
185        let ty = self.node_type(node)?;
186        if ty != "Module" {
187            self.diagnostics.push(Diagnostic::error(
188                "PY_UNEXPECTED_ROOT",
189                format!("expected ast.Module, found {ty}"),
190                span,
191            ));
192            return None;
193        }
194
195        let body_list = node.getattr("body").ok()?.cast_into::<PyList>().ok()?;
196        let body: Vec<Bound<'_, PyAny>> = body_list.iter().collect();
197
198        let mut out = Vec::new();
199        for stmt in body {
200            out.push(self.emit_stmt(&stmt));
201        }
202
203        Some(Module { span, body: out })
204    }
205
206    fn emit_stmt(&mut self, node: &Bound<'_, PyAny>) -> Stmt {
207        let span = self.span_of(node);
208        let ty = self
209            .node_type(node)
210            .unwrap_or_else(|| "<unknown>".to_string());
211
212        match ty.as_str() {
213            "Assign" => {
214                let targets = node.getattr("targets").ok();
215                if let Some(targets) = targets
216                    && let Ok(list) = targets.cast_into::<PyList>()
217                {
218                    if list.len() != 1 {
219                        return Stmt::Unsupported {
220                            span,
221                            reason: "multiple assignment targets".to_string(),
222                        };
223                    }
224                    let target = self.emit_expr(&list.get_item(0).unwrap());
225                    let value = self.emit_expr(&node.getattr("value").unwrap());
226                    return Stmt::Assign {
227                        span,
228                        target,
229                        value,
230                    };
231                }
232                Stmt::Unsupported {
233                    span,
234                    reason: "invalid Assign".to_string(),
235                }
236            }
237            "AugAssign" => {
238                // target += value  ==>  target = (target + value)
239                let target = self.emit_expr(&node.getattr("target").unwrap());
240                let value = self.emit_expr(&node.getattr("value").unwrap());
241                let op = node.getattr("op").unwrap();
242                let op_name = self.type_name(&op, "<op>");
243                let combined = Expr::BinOp {
244                    span,
245                    op: op_name,
246                    left: Box::new(target.clone()),
247                    right: Box::new(value),
248                };
249                Stmt::Assign {
250                    span,
251                    target,
252                    value: combined,
253                }
254            }
255            "Expr" => {
256                let value = self.emit_expr(&node.getattr("value").unwrap());
257                // detect python print()
258                if let Expr::Call { func, .. } = &value
259                    && let Expr::Name { id, .. } = func.as_ref()
260                    && id == "print"
261                {
262                    self.console_used = true;
263                }
264                Stmt::ExprStmt { span, value }
265            }
266            "Return" => {
267                let value = node.getattr("value").ok().and_then(|v| {
268                    if v.is_none() {
269                        None
270                    } else {
271                        Some(self.emit_expr(&v))
272                    }
273                });
274                Stmt::Return { span, value }
275            }
276            "Pass" => Stmt::Pass { span },
277            "Break" => Stmt::Break { span },
278            "Continue" => Stmt::Continue { span },
279            "If" => {
280                let test = self.emit_expr(&node.getattr("test").unwrap());
281                let body = self.emit_stmt_list(node.getattr("body").unwrap());
282                let orelse = self.emit_stmt_list(node.getattr("orelse").unwrap());
283                Stmt::If {
284                    span,
285                    test,
286                    body,
287                    orelse,
288                }
289            }
290            "While" => {
291                let test = self.emit_expr(&node.getattr("test").unwrap());
292                let body = self.emit_stmt_list(node.getattr("body").unwrap());
293                Stmt::While { span, test, body }
294            }
295            "For" => {
296                let target = self.emit_expr(&node.getattr("target").unwrap());
297                let iter = self.emit_expr(&node.getattr("iter").unwrap());
298                let body = self.emit_stmt_list(node.getattr("body").unwrap());
299                Stmt::For {
300                    span,
301                    target,
302                    iter,
303                    body,
304                }
305            }
306            "FunctionDef" => {
307                let name: String = node.getattr("name").unwrap().extract().unwrap_or_default();
308                let args = node.getattr("args").unwrap();
309                let args_list = args.getattr("args").unwrap();
310                let mut params = Vec::new();
311                if let Ok(list) = args_list.cast_into::<PyList>() {
312                    for item in list.iter() {
313                        let arg_name: String =
314                            item.getattr("arg").unwrap().extract().unwrap_or_default();
315                        params.push(arg_name);
316                    }
317                }
318                let body = self.emit_stmt_list(node.getattr("body").unwrap());
319                Stmt::FunctionDef {
320                    span,
321                    name,
322                    args: params,
323                    body,
324                }
325            }
326            "Import" => {
327                let names = node.getattr("names").unwrap();
328                if let Ok(list) = names.cast_into::<PyList>() {
329                    // one per stmt; expand to multiple Stmt::Import for simplicity
330                    let mut out: Vec<Stmt> = Vec::new();
331                    for alias in list.iter() {
332                        let name: String =
333                            alias.getattr("name").unwrap().extract().unwrap_or_default();
334                        let asname: Option<String> =
335                            alias.getattr("asname").ok().and_then(|v| v.extract().ok());
336                        if name.split('.').next() == Some("numpy")
337                            || asname.as_deref() == Some("np")
338                        {
339                            self.numpy_used = true;
340                        }
341                        if let Some(a) = asname.clone() {
342                            self.aliases.insert(a, name.clone());
343                        }
344                        out.push(Stmt::Import {
345                            span,
346                            module: name,
347                            asname,
348                        });
349                    }
350                    if out.len() == 1 {
351                        return out.remove(0);
352                    }
353                    return Stmt::Unsupported {
354                        span,
355                        reason: "multiple imports in single statement".to_string(),
356                    };
357                }
358                Stmt::Unsupported {
359                    span,
360                    reason: "invalid Import".to_string(),
361                }
362            }
363            "ImportFrom" => {
364                let module: String = node
365                    .getattr("module")
366                    .ok()
367                    .and_then(|v| v.extract().ok())
368                    .unwrap_or_default();
369                if module.split('.').next() == Some("numpy") {
370                    self.numpy_used = true;
371                }
372                let names = node.getattr("names").unwrap();
373                if let Ok(list) = names.cast_into::<PyList>() {
374                    if list.len() != 1 {
375                        return Stmt::Unsupported {
376                            span,
377                            reason: "multiple from-import names".to_string(),
378                        };
379                    }
380                    let alias = list.get_item(0).unwrap();
381                    let name: String = alias.getattr("name").unwrap().extract().unwrap_or_default();
382                    let asname: Option<String> =
383                        alias.getattr("asname").ok().and_then(|v| v.extract().ok());
384                    if let Some(a) = asname.clone() {
385                        self.aliases.insert(a, format!("{module}.{name}"));
386                    }
387                    return Stmt::ImportFrom {
388                        span,
389                        module,
390                        name,
391                        asname,
392                    };
393                }
394                Stmt::Unsupported {
395                    span,
396                    reason: "invalid ImportFrom".to_string(),
397                }
398            }
399            _ => Stmt::Unsupported { span, reason: ty },
400        }
401    }
402
403    fn emit_stmt_list(&mut self, list_any: Bound<'_, PyAny>) -> Vec<Stmt> {
404        if let Ok(list) = list_any.cast_into::<PyList>() {
405            list.iter().map(|n| self.emit_stmt(&n)).collect()
406        } else {
407            vec![Stmt::Unsupported {
408                span: Span::default(),
409                reason: "expected list".to_string(),
410            }]
411        }
412    }
413
414    fn emit_expr(&mut self, node: &Bound<'_, PyAny>) -> Expr {
415        let span = self.span_of(node);
416        let ty = self
417            .node_type(node)
418            .unwrap_or_else(|| "<unknown>".to_string());
419
420        match ty.as_str() {
421            "Name" => {
422                let id: String = node.getattr("id").unwrap().extract().unwrap_or_default();
423                Expr::Name { span, id }
424            }
425            "Constant" => {
426                if node.is_none() {
427                    return Expr::None { span };
428                }
429                let value = node.getattr("value").ok();
430                if let Some(value) = value {
431                    if value.is_none() {
432                        return Expr::None { span };
433                    }
434                    if let Ok(v) = value.extract::<bool>() {
435                        return Expr::Bool { span, value: v };
436                    }
437                    if let Ok(v) = value.extract::<i64>() {
438                        return Expr::Number {
439                            span,
440                            value: v as f64,
441                        };
442                    }
443                    if let Ok(v) = value.extract::<f64>() {
444                        return Expr::Number { span, value: v };
445                    }
446                    if let Ok(v) = value.extract::<String>() {
447                        return Expr::String { span, value: v };
448                    }
449                }
450                Expr::Unsupported {
451                    span,
452                    reason: "unsupported constant".to_string(),
453                }
454            }
455            "BinOp" => {
456                let op = node.getattr("op").unwrap();
457                let op_name = self.type_name(&op, "<op>");
458                let left = self.emit_expr(&node.getattr("left").unwrap());
459                let right = self.emit_expr(&node.getattr("right").unwrap());
460                Expr::BinOp {
461                    span,
462                    op: op_name,
463                    left: Box::new(left),
464                    right: Box::new(right),
465                }
466            }
467            "UnaryOp" => {
468                let op = node.getattr("op").unwrap();
469                let op_name = self.type_name(&op, "<op>");
470                let operand = self.emit_expr(&node.getattr("operand").unwrap());
471                Expr::UnaryOp {
472                    span,
473                    op: op_name,
474                    operand: Box::new(operand),
475                }
476            }
477            "BoolOp" => {
478                let op = node.getattr("op").unwrap();
479                let op_name = self.type_name(&op, "<op>");
480                let values_any = node.getattr("values").unwrap();
481                if let Ok(list) = values_any.cast_into::<PyList>() {
482                    let values = list.iter().map(|n| self.emit_expr(&n)).collect();
483                    Expr::BoolOp {
484                        span,
485                        op: op_name,
486                        values,
487                    }
488                } else {
489                    Expr::Unsupported {
490                        span,
491                        reason: "invalid BoolOp".to_string(),
492                    }
493                }
494            }
495            "Compare" => {
496                let ops_any = node.getattr("ops").unwrap();
497                let comps_any = node.getattr("comparators").unwrap();
498                if let (Ok(ops), Ok(comps)) = (
499                    ops_any.cast_into::<PyList>(),
500                    comps_any.cast_into::<PyList>(),
501                ) {
502                    if ops.len() != 1 || comps.len() != 1 {
503                        return Expr::Unsupported {
504                            span,
505                            reason: "chained comparison".to_string(),
506                        };
507                    }
508                    let op_name = self.type_name(&ops.get_item(0).unwrap(), "<op>");
509                    let left = self.emit_expr(&node.getattr("left").unwrap());
510                    let right = self.emit_expr(&comps.get_item(0).unwrap());
511                    return Expr::Compare {
512                        span,
513                        op: op_name,
514                        left: Box::new(left),
515                        right: Box::new(right),
516                    };
517                }
518                Expr::Unsupported {
519                    span,
520                    reason: "invalid Compare".to_string(),
521                }
522            }
523            "Call" => {
524                let func = self.emit_expr(&node.getattr("func").unwrap());
525                let args_any = node.getattr("args").unwrap();
526                let mut args = Vec::new();
527                if let Ok(list) = args_any.cast_into::<PyList>() {
528                    for item in list.iter() {
529                        args.push(self.emit_expr(&item));
530                    }
531                }
532
533                // Precheck: filesystem/network based on callee
534                self.inspect_call_for_io(span, &func);
535
536                Expr::Call {
537                    span,
538                    func: Box::new(func),
539                    args,
540                }
541            }
542            "Attribute" => {
543                let value = self.emit_expr(&node.getattr("value").unwrap());
544                let attr: String = node.getattr("attr").unwrap().extract().unwrap_or_default();
545
546                // Detect numpy usage via alias `np.*`
547                if let Expr::Name { id, .. } = &value {
548                    if id == "np" {
549                        self.numpy_used = true;
550                    }
551                    if let Some(resolved) = self.aliases.get(id)
552                        && resolved.split('.').next() == Some("numpy")
553                    {
554                        self.numpy_used = true;
555                    }
556                }
557
558                Expr::Attribute {
559                    span,
560                    value: Box::new(value),
561                    attr,
562                }
563            }
564            "List" => {
565                let elts_any = node.getattr("elts").unwrap();
566                if let Ok(list) = elts_any.cast_into::<PyList>() {
567                    let elts = list.iter().map(|n| self.emit_expr(&n)).collect();
568                    Expr::List { span, elts }
569                } else {
570                    Expr::Unsupported {
571                        span,
572                        reason: "invalid List".to_string(),
573                    }
574                }
575            }
576            "Dict" => {
577                let keys_any = node.getattr("keys").unwrap();
578                let values_any = node.getattr("values").unwrap();
579                if let (Ok(keys), Ok(values)) = (
580                    keys_any.cast_into::<PyList>(),
581                    values_any.cast_into::<PyList>(),
582                ) {
583                    let mut items = Vec::new();
584                    for (k, v) in keys.iter().zip(values.iter()) {
585                        // Dict unpacking like {**x} appears as a None key in Python AST.
586                        if k.is_none() {
587                            return Expr::Unsupported {
588                                span,
589                                reason: "dict unpack is not supported".to_string(),
590                            };
591                        }
592                        items.push((self.emit_expr(&k), self.emit_expr(&v)));
593                    }
594                    Expr::Dict { span, items }
595                } else {
596                    Expr::Unsupported {
597                        span,
598                        reason: "invalid Dict".to_string(),
599                    }
600                }
601            }
602            "Subscript" => {
603                let value = self.emit_expr(&node.getattr("value").unwrap());
604                let slice = node.getattr("slice").unwrap();
605                // python 3.9+: slice is Expr; we reject Slice objects
606                let idx_ty = self.node_type(&slice).unwrap_or_default();
607                if idx_ty == "Slice" {
608                    return Expr::Unsupported {
609                        span,
610                        reason: "slice is not supported".to_string(),
611                    };
612                }
613                let index = self.emit_expr(&slice);
614                Expr::Subscript {
615                    span,
616                    value: Box::new(value),
617                    index: Box::new(index),
618                }
619            }
620            _ => Expr::Unsupported { span, reason: ty },
621        }
622    }
623
624    fn inspect_call_for_io(&mut self, _span: Span, func: &Expr) {
625        fn leftmost_name(expr: &Expr) -> Option<&str> {
626            match expr {
627                Expr::Name { id, .. } => Some(id.as_str()),
628                Expr::Attribute { value, .. } => leftmost_name(value.as_ref()),
629                _ => None,
630            }
631        }
632
633        if let Some(base) = leftmost_name(func) {
634            if base == "open" {
635                self.io_used = true;
636            }
637            if base == "socket" || base == "urllib" || base == "requests" || base == "http" {
638                self.io_used = true;
639            }
640            if base == "print" {
641                self.console_used = true;
642            }
643            if base == "os" || base == "pathlib" || base == "shutil" {
644                self.io_used = true;
645            }
646        }
647    }
648}