// oak_python/ast/mod.rs

1#![doc = include_str!("readme.md")]
2use core::range::Range;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6/// Root node of a Python source file.
7#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
8#[derive(Debug, Clone, PartialEq)]
9pub struct PythonRoot {
10    /// The program structure
11    pub program: Program,
12    /// Source code span
13    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
14    pub span: Range<usize>,
15}
16
17/// A Python program consisting of a list of statements.
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19#[derive(Debug, Clone, PartialEq)]
20pub struct Program {
21    /// List of statements in the program
22    pub statements: Vec<Statement>,
23}
24
25/// Represents a Python statement.
26#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
27#[derive(Debug, Clone, PartialEq)]
28pub enum Statement {
29    /// Function definition
30    FunctionDef {
31        /// Decorators applied to the function
32        decorators: Vec<Expression>,
33        /// Function name
34        name: String,
35        /// List of parameters
36        parameters: Vec<Parameter>,
37        /// Optional return type annotation
38        return_type: Option<Type>,
39        /// Function body
40        body: Vec<Statement>,
41    },
42    /// Async function definition
43    AsyncFunctionDef {
44        /// Decorators applied to the function
45        decorators: Vec<Expression>,
46        /// Function name
47        name: String,
48        /// List of parameters
49        parameters: Vec<Parameter>,
50        /// Optional return type annotation
51        return_type: Option<Type>,
52        /// Function body
53        body: Vec<Statement>,
54    },
55    /// Class definition
56    ClassDef {
57        /// Decorators applied to the class
58        decorators: Vec<Expression>,
59        /// Class name
60        name: String,
61        /// Base classes
62        bases: Vec<Expression>,
63        /// Class body
64        body: Vec<Statement>,
65    },
66    /// Variable assignment
67    Assignment {
68        /// Target expression
69        target: Expression,
70        /// Value expression
71        value: Expression,
72    },
73    /// Augmented assignment (e.g., `+=`, `-=`)
74    AugmentedAssignment {
75        /// Target expression
76        target: Expression,
77        /// Augmented operator
78        operator: AugmentedOperator,
79        /// Value expression
80        value: Expression,
81    },
82    /// Expression statement
83    Expression(Expression),
84    /// Return statement
85    Return(Option<Expression>),
86    /// If statement
87    If {
88        /// Test expression
89        test: Expression,
90        /// Body of the if block
91        body: Vec<Statement>,
92        /// Else block (or empty)
93        orelse: Vec<Statement>,
94    },
95    /// For loop
96    For {
97        /// Loop target
98        target: Expression,
99        /// Iterable expression
100        iter: Expression,
101        /// Loop body
102        body: Vec<Statement>,
103        /// Else block (or empty)
104        orelse: Vec<Statement>,
105    },
106    /// Async for loop
107    AsyncFor {
108        /// Loop target
109        target: Expression,
110        /// Iterable expression
111        iter: Expression,
112        /// Loop body
113        body: Vec<Statement>,
114        /// Else block (or empty)
115        orelse: Vec<Statement>,
116    },
117    /// While loop
118    While {
119        /// Test expression
120        test: Expression,
121        /// Loop body
122        body: Vec<Statement>,
123        /// Else block (or empty)
124        orelse: Vec<Statement>,
125    },
126    /// Break statement
127    Break,
128    /// Continue statement
129    Continue,
130    /// Pass statement
131    Pass,
132    /// Import statement
133    Import {
134        /// List of names being imported
135        names: Vec<ImportName>,
136    },
137    /// From-import statement
138    ImportFrom {
139        /// Optional module name
140        module: Option<String>,
141        /// List of names being imported
142        names: Vec<ImportName>,
143    },
144    /// Global statement
145    Global {
146        /// List of global names
147        names: Vec<String>,
148    },
149    /// Nonlocal statement
150    Nonlocal {
151        /// List of nonlocal names
152        names: Vec<String>,
153    },
154    /// Try statement
155    Try {
156        /// Try body
157        body: Vec<Statement>,
158        /// Exception handlers
159        handlers: Vec<ExceptHandler>,
160        /// Else block
161        orelse: Vec<Statement>,
162        /// Finally block
163        finalbody: Vec<Statement>,
164    },
165    /// Raise statement
166    Raise {
167        /// Optional exception
168        exc: Option<Expression>,
169        /// Optional cause
170        cause: Option<Expression>,
171    },
172    /// With statement
173    With {
174        /// With items
175        items: Vec<WithItem>,
176        /// With body
177        body: Vec<Statement>,
178    },
179    /// Async with statement
180    AsyncWith {
181        /// With items
182        items: Vec<WithItem>,
183        /// With body
184        body: Vec<Statement>,
185    },
186    /// Assert statement
187    Assert {
188        /// Test expression
189        test: Expression,
190        /// Optional error message
191        msg: Option<Expression>,
192    },
193    /// Match statement
194    Match {
195        /// Subject expression
196        subject: Expression,
197        /// Match cases
198        cases: Vec<MatchCase>,
199    },
200    /// Delete statement
201    Delete {
202        /// Targets to delete
203        targets: Vec<Expression>,
204    },
205}
206
207/// Represents a case in a match statement.
208#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
209#[derive(Debug, Clone, PartialEq)]
210pub struct MatchCase {
211    /// Pattern to match
212    pub pattern: Pattern,
213    /// Optional guard expression
214    pub guard: Option<Expression>,
215    /// Case body
216    pub body: Vec<Statement>,
217}
218
219/// Represents a pattern in a match case.
220#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
221#[derive(Debug, Clone, PartialEq)]
222pub enum Pattern {
223    /// Value pattern
224    Value(Expression),
225    /// Wildcard pattern
226    Wildcard,
227    /// As pattern
228    As {
229        /// Optional sub-pattern
230        pattern: Option<Box<Pattern>>,
231        /// Target name
232        name: String,
233    },
234    /// Sequence pattern
235    Sequence(Vec<Pattern>),
236    /// Mapping pattern
237    Mapping {
238        /// Keys to match
239        keys: Vec<Expression>,
240        /// Corresponding patterns
241        patterns: Vec<Pattern>,
242    },
243    /// Class pattern
244    Class {
245        /// Class expression
246        cls: Expression,
247        /// Positional patterns
248        patterns: Vec<Pattern>,
249        /// Keyword names
250        keywords: Vec<String>,
251        /// Keyword patterns
252        keyword_patterns: Vec<Pattern>,
253    },
254    /// Or pattern
255    Or(Vec<Pattern>),
256}
257
258/// Represents a Python expression.
259#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
260#[derive(Debug, Clone, PartialEq)]
261pub enum Expression {
262    /// Literal value
263    Literal(Literal),
264    /// Identifier name
265    Name(String),
266    /// Binary operation
267    BinaryOp {
268        /// Left operand
269        left: Box<Expression>,
270        /// Binary operator
271        operator: BinaryOperator,
272        /// Right operand
273        right: Box<Expression>,
274    },
275    /// Unary operation
276    UnaryOp {
277        /// Unary operator
278        operator: UnaryOperator,
279        /// Operand
280        operand: Box<Expression>,
281    },
282    /// Boolean operation (and, or)
283    BoolOp {
284        /// Boolean operator
285        operator: BoolOperator,
286        /// List of values
287        values: Vec<Expression>,
288    },
289    /// Comparison operation
290    Compare {
291        /// Leftmost operand
292        left: Box<Expression>,
293        /// Comparison operators
294        ops: Vec<CompareOperator>,
295        /// Subsequent operands
296        comparators: Vec<Expression>,
297    },
298    /// Function call
299    Call {
300        /// Function being called
301        func: Box<Expression>,
302        /// Positional arguments
303        args: Vec<Expression>,
304        /// Keyword arguments
305        keywords: Vec<Keyword>,
306    },
307    /// Attribute access
308    Attribute {
309        /// Base expression
310        value: Box<Expression>,
311        /// Attribute name
312        attr: String,
313    },
314    /// Subscript access
315    Subscript {
316        /// Base expression
317        value: Box<Expression>,
318        /// Slice or index expression
319        slice: Box<Expression>,
320    },
321    /// List literal
322    List {
323        /// List elements
324        elts: Vec<Expression>,
325    },
326    /// Tuple literal
327    Tuple {
328        /// Tuple elements
329        elts: Vec<Expression>,
330    },
331    /// Slice expression
332    Slice {
333        /// Optional lower bound
334        lower: Option<Box<Expression>>,
335        /// Optional upper bound
336        upper: Option<Box<Expression>>,
337        /// Optional step
338        step: Option<Box<Expression>>,
339    },
340    /// Dictionary literal
341    Dict {
342        /// Optional keys
343        keys: Vec<Option<Expression>>,
344        /// Values
345        values: Vec<Expression>,
346    },
347    /// Set literal
348    Set {
349        /// Set elements
350        elts: Vec<Expression>,
351    },
352    /// List comprehension
353    ListComp {
354        /// Result expression
355        elt: Box<Expression>,
356        /// Generators
357        generators: Vec<Comprehension>,
358    },
359    /// Dictionary comprehension
360    DictComp {
361        /// Key expression
362        key: Box<Expression>,
363        /// Value expression
364        value: Box<Expression>,
365        /// Generators
366        generators: Vec<Comprehension>,
367    },
368    /// Set comprehension
369    SetComp {
370        /// Result expression
371        elt: Box<Expression>,
372        /// Generators
373        generators: Vec<Comprehension>,
374    },
375    /// Generator expression
376    GeneratorExp {
377        /// Result expression
378        elt: Box<Expression>,
379        /// Generators
380        generators: Vec<Comprehension>,
381    },
382    /// Lambda expression
383    Lambda {
384        /// Lambda arguments
385        args: Vec<Parameter>,
386        /// Lambda body
387        body: Box<Expression>,
388    },
389    /// Conditional expression (ternary operator)
390    IfExp {
391        /// Test expression
392        test: Box<Expression>,
393        /// Body expression
394        body: Box<Expression>,
395        /// Else expression
396        orelse: Box<Expression>,
397    },
398    /// f-string
399    JoinedStr {
400        /// f-string parts
401        values: Vec<Expression>,
402    },
403    /// Formatted value within an f-string
404    FormattedValue {
405        /// Value to format
406        value: Box<Expression>,
407        /// Conversion type
408        conversion: usize,
409        /// Optional format specification
410        format_spec: Option<Box<Expression>>,
411    },
412    /// Yield expression
413    Yield(Option<Box<Expression>>),
414    /// Yield from expression
415    YieldFrom(Box<Expression>),
416    /// Await expression
417    Await(Box<Expression>),
418    /// Starred expression (*args, **kwargs)
419    Starred {
420        /// Value being starred
421        value: Box<Expression>,
422        /// Whether it's a double star (**kwargs)
423        is_double: bool,
424    },
425}
426
/// Represents a literal value.
///
/// Only `PartialEq` is derived (not `Eq`) because the `Float` variant holds an
/// `f64`; two `Float(NaN)` literals therefore compare unequal.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// Integer literal
    ///
    /// Stored as `i64`, so values outside that range are not representable here.
    Integer(i64),
    /// Float literal
    Float(f64),
    /// String literal
    String(String),
    /// Bytes literal
    Bytes(Vec<u8>),
    /// Boolean literal
    Boolean(bool),
    /// None literal
    None,
}
444
/// Represents an augmented assignment operator.
///
/// Fieldless, so it is `Copy` and usable as a hash-map/set key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AugmentedOperator {
    /// `+=`
    Add,
    /// `-=`
    Sub,
    /// `*=`
    Mult,
    /// `/=`
    Div,
    /// `//=`
    FloorDiv,
    /// `%=`
    Mod,
    /// `**=`
    Pow,
    /// `<<=`
    LShift,
    /// `>>=`
    RShift,
    /// `|=`
    BitOr,
    /// `^=`
    BitXor,
    /// `&=`
    BitAnd,
}
474
/// Represents a binary operator.
///
/// Fieldless, so it is `Copy` and usable as a hash-map/set key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mult,
    /// `/`
    Div,
    /// `//`
    FloorDiv,
    /// `%`
    Mod,
    /// `**`
    Pow,
    /// `<<`
    LShift,
    /// `>>`
    RShift,
    /// `|`
    BitOr,
    /// `^`
    BitXor,
    /// `&`
    BitAnd,
}
504
/// Represents a unary operator.
///
/// Fieldless, so it is `Copy` and usable as a hash-map/set key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// `~`
    Invert,
    /// `not`
    Not,
    /// `+`
    UAdd,
    /// `-`
    USub,
}
518
/// Represents a boolean operator.
///
/// Fieldless, so it is `Copy` and usable as a hash-map/set key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoolOperator {
    /// `and`
    And,
    /// `or`
    Or,
}
528
/// Represents a comparison operator.
///
/// Fieldless, so it is `Copy` and usable as a hash-map/set key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOperator {
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtE,
    /// `>`
    Gt,
    /// `>=`
    GtE,
    /// `is`
    Is,
    /// `is not`
    IsNot,
    /// `in`
    In,
    /// `not in`
    NotIn,
}
554
555/// Represents a function parameter.
556#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
557#[derive(Debug, Clone, PartialEq)]
558pub struct Parameter {
559    /// Parameter name
560    pub name: String,
561    /// Optional type annotation
562    pub annotation: Option<Type>,
563    /// Optional default value
564    pub default: Option<Expression>,
565    /// Whether it's a variable positional argument (*args)
566    pub is_vararg: bool,
567    /// Whether it's a variable keyword argument (**kwargs)
568    pub is_kwarg: bool,
569}
570
/// Represents a type annotation.
///
/// The enum is recursive: generic arguments, union members and the optional
/// payload are themselves [`Type`] values.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// Basic type name
    Name(String),
    /// Generic type (a named type applied to type arguments)
    Generic {
        /// Type name
        name: String,
        /// Type arguments
        args: Vec<Type>,
    },
    /// Union type
    Union(Vec<Type>),
    /// Optional type
    Optional(Box<Type>),
}
589
590/// Represents a keyword argument.
591#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
592#[derive(Debug, Clone, PartialEq)]
593pub struct Keyword {
594    /// Optional argument name
595    pub arg: Option<String>,
596    /// Argument value
597    pub value: Expression,
598}
599
600/// Represents a comprehension in a list/dict/set/generator.
601#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
602#[derive(Debug, Clone, PartialEq)]
603pub struct Comprehension {
604    /// Target expression
605    pub target: Expression,
606    /// Iterable expression
607    pub iter: Expression,
608    /// Optional conditions
609    pub ifs: Vec<Expression>,
610    /// Whether it's an async comprehension
611    pub is_async: bool,
612}
613
/// Represents a name in an import statement (`name` or `name as asname`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct ImportName {
    /// Name being imported
    pub name: String,
    /// Optional alias (asname); `None` when the import has no `as` clause
    pub asname: Option<String>,
}
623
624/// Represents an exception handler in a try statement.
625#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
626#[derive(Debug, Clone, PartialEq)]
627pub struct ExceptHandler {
628    /// Optional exception type
629    pub type_: Option<Expression>,
630    /// Optional name for the exception instance
631    pub name: Option<String>,
632    /// Handler body
633    pub body: Vec<Statement>,
634}
635
636/// Represents an item in a with statement.
637#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
638#[derive(Debug, Clone, PartialEq)]
639pub struct WithItem {
640    /// Context manager expression
641    pub context_expr: Expression,
642    /// Optional variables to bind to
643    pub optional_vars: Option<Expression>,
644}
645
646impl Program {
647    /// Creates a new empty program.
648    pub fn new() -> Self {
649        Self { statements: Vec::new() }
650    }
651
652    /// Adds a statement to the program.
653    pub fn add_statement(&mut self, statement: Statement) {
654        self.statements.push(statement)
655    }
656}
657
658impl Default for Program {
659    fn default() -> Self {
660        Self::new()
661    }
662}
663
664impl Expression {
665    /// Creates a name expression.
666    pub fn name(name: impl Into<String>) -> Self {
667        Self::Name(name.into())
668    }
669
670    /// Creates a string literal expression.
671    pub fn string(value: impl Into<String>) -> Self {
672        Self::Literal(Literal::String(value.into()))
673    }
674
675    /// Creates an integer literal expression.
676    pub fn integer(value: i64) -> Self {
677        Self::Literal(Literal::Integer(value))
678    }
679
680    /// Creates a float literal expression.
681    pub fn float(value: f64) -> Self {
682        Self::Literal(Literal::Float(value))
683    }
684
685    /// Creates a boolean literal expression.
686    pub fn boolean(value: bool) -> Self {
687        Self::Literal(Literal::Boolean(value))
688    }
689
690    /// Creates a None literal expression.
691    pub fn none() -> Self {
692        Self::Literal(Literal::None)
693    }
694}
695
696impl Statement {
697    /// Creates a function definition statement.
698    pub fn function_def(name: impl Into<String>, parameters: Vec<Parameter>, return_type: Option<Type>, body: Vec<Statement>) -> Self {
699        Self::FunctionDef { decorators: Vec::new(), name: name.into(), parameters, return_type, body }
700    }
701
702    /// Creates an assignment statement.
703    pub fn assignment(target: Expression, value: Expression) -> Self {
704        Self::Assignment { target, value }
705    }
706
707    /// Creates an expression statement.
708    pub fn expression(expr: Expression) -> Self {
709        Self::Expression(expr)
710    }
711
712    /// Creates a return statement.
713    pub fn return_stmt(value: Option<Expression>) -> Self {
714        Self::Return(value)
715    }
716}