// oak_python/ast/mod.rs

#![doc = include_str!("readme.md")]
use core::range::Range;

4/// Root node of a Python source file.
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6#[derive(Debug, Clone, PartialEq)]
7pub struct PythonRoot {
8    /// The program structure
9    pub program: Program,
10    /// Source code span
11    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
12    pub span: Range<usize>,
13}
14
15/// A Python program consisting of a list of statements.
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17#[derive(Debug, Clone, PartialEq)]
18pub struct Program {
19    /// List of statements in the program
20    pub statements: Vec<Statement>,
21}
22
23/// Represents a Python statement.
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, PartialEq)]
26pub enum Statement {
27    /// Function definition
28    FunctionDef {
29        /// Decorators applied to the function
30        decorators: Vec<Expression>,
31        /// Function name
32        name: String,
33        /// List of parameters
34        parameters: Vec<Parameter>,
35        /// Optional return type annotation
36        return_type: Option<Type>,
37        /// Function body
38        body: Vec<Statement>,
39    },
40    /// Async function definition
41    AsyncFunctionDef {
42        /// Decorators applied to the function
43        decorators: Vec<Expression>,
44        /// Function name
45        name: String,
46        /// List of parameters
47        parameters: Vec<Parameter>,
48        /// Optional return type annotation
49        return_type: Option<Type>,
50        /// Function body
51        body: Vec<Statement>,
52    },
53    /// Class definition
54    ClassDef {
55        /// Decorators applied to the class
56        decorators: Vec<Expression>,
57        /// Class name
58        name: String,
59        /// Base classes
60        bases: Vec<Expression>,
61        /// Class body
62        body: Vec<Statement>,
63    },
64    /// Variable assignment
65    Assignment {
66        /// Target expression
67        target: Expression,
68        /// Value expression
69        value: Expression,
70    },
71    /// Augmented assignment (e.g., `+=`, `-=`)
72    AugmentedAssignment {
73        /// Target expression
74        target: Expression,
75        /// Augmented operator
76        operator: AugmentedOperator,
77        /// Value expression
78        value: Expression,
79    },
80    /// Expression statement
81    Expression(Expression),
82    /// Return statement
83    Return(Option<Expression>),
84    /// If statement
85    If {
86        /// Test expression
87        test: Expression,
88        /// Body of the if block
89        body: Vec<Statement>,
90        /// Else block (or empty)
91        orelse: Vec<Statement>,
92    },
93    /// For loop
94    For {
95        /// Loop target
96        target: Expression,
97        /// Iterable expression
98        iter: Expression,
99        /// Loop body
100        body: Vec<Statement>,
101        /// Else block (or empty)
102        orelse: Vec<Statement>,
103    },
104    /// Async for loop
105    AsyncFor {
106        /// Loop target
107        target: Expression,
108        /// Iterable expression
109        iter: Expression,
110        /// Loop body
111        body: Vec<Statement>,
112        /// Else block (or empty)
113        orelse: Vec<Statement>,
114    },
115    /// While loop
116    While {
117        /// Test expression
118        test: Expression,
119        /// Loop body
120        body: Vec<Statement>,
121        /// Else block (or empty)
122        orelse: Vec<Statement>,
123    },
124    /// Break statement
125    Break,
126    /// Continue statement
127    Continue,
128    /// Pass statement
129    Pass,
130    /// Import statement
131    Import {
132        /// List of names being imported
133        names: Vec<ImportName>,
134    },
135    /// From-import statement
136    ImportFrom {
137        /// Optional module name
138        module: Option<String>,
139        /// List of names being imported
140        names: Vec<ImportName>,
141    },
142    /// Global statement
143    Global {
144        /// List of global names
145        names: Vec<String>,
146    },
147    /// Nonlocal statement
148    Nonlocal {
149        /// List of nonlocal names
150        names: Vec<String>,
151    },
152    /// Try statement
153    Try {
154        /// Try body
155        body: Vec<Statement>,
156        /// Exception handlers
157        handlers: Vec<ExceptHandler>,
158        /// Else block
159        orelse: Vec<Statement>,
160        /// Finally block
161        finalbody: Vec<Statement>,
162    },
163    /// Raise statement
164    Raise {
165        /// Optional exception
166        exc: Option<Expression>,
167        /// Optional cause
168        cause: Option<Expression>,
169    },
170    /// With statement
171    With {
172        /// With items
173        items: Vec<WithItem>,
174        /// With body
175        body: Vec<Statement>,
176    },
177    /// Async with statement
178    AsyncWith {
179        /// With items
180        items: Vec<WithItem>,
181        /// With body
182        body: Vec<Statement>,
183    },
184    /// Assert statement
185    Assert {
186        /// Test expression
187        test: Expression,
188        /// Optional error message
189        msg: Option<Expression>,
190    },
191    /// Match statement
192    Match {
193        /// Subject expression
194        subject: Expression,
195        /// Match cases
196        cases: Vec<MatchCase>,
197    },
198    /// Delete statement
199    Delete {
200        /// Targets to delete
201        targets: Vec<Expression>,
202    },
203}
204
205/// Represents a case in a match statement.
206#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
207#[derive(Debug, Clone, PartialEq)]
208pub struct MatchCase {
209    /// Pattern to match
210    pub pattern: Pattern,
211    /// Optional guard expression
212    pub guard: Option<Expression>,
213    /// Case body
214    pub body: Vec<Statement>,
215}
216
217/// Represents a pattern in a match case.
218#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
219#[derive(Debug, Clone, PartialEq)]
220pub enum Pattern {
221    /// Value pattern
222    Value(Expression),
223    /// Wildcard pattern
224    Wildcard,
225    /// As pattern
226    As {
227        /// Optional sub-pattern
228        pattern: Option<Box<Pattern>>,
229        /// Target name
230        name: String,
231    },
232    /// Sequence pattern
233    Sequence(Vec<Pattern>),
234    /// Mapping pattern
235    Mapping {
236        /// Keys to match
237        keys: Vec<Expression>,
238        /// Corresponding patterns
239        patterns: Vec<Pattern>,
240    },
241    /// Class pattern
242    Class {
243        /// Class expression
244        cls: Expression,
245        /// Positional patterns
246        patterns: Vec<Pattern>,
247        /// Keyword names
248        keywords: Vec<String>,
249        /// Keyword patterns
250        keyword_patterns: Vec<Pattern>,
251    },
252    /// Or pattern
253    Or(Vec<Pattern>),
254}
255
256/// Represents a Python expression.
257#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
258#[derive(Debug, Clone, PartialEq)]
259pub enum Expression {
260    /// Literal value
261    Literal(Literal),
262    /// Identifier name
263    Name(String),
264    /// Binary operation
265    BinaryOp {
266        /// Left operand
267        left: Box<Expression>,
268        /// Binary operator
269        operator: BinaryOperator,
270        /// Right operand
271        right: Box<Expression>,
272    },
273    /// Unary operation
274    UnaryOp {
275        /// Unary operator
276        operator: UnaryOperator,
277        /// Operand
278        operand: Box<Expression>,
279    },
280    /// Boolean operation (and, or)
281    BoolOp {
282        /// Boolean operator
283        operator: BoolOperator,
284        /// List of values
285        values: Vec<Expression>,
286    },
287    /// Comparison operation
288    Compare {
289        /// Leftmost operand
290        left: Box<Expression>,
291        /// Comparison operators
292        ops: Vec<CompareOperator>,
293        /// Subsequent operands
294        comparators: Vec<Expression>,
295    },
296    /// Function call
297    Call {
298        /// Function being called
299        func: Box<Expression>,
300        /// Positional arguments
301        args: Vec<Expression>,
302        /// Keyword arguments
303        keywords: Vec<Keyword>,
304    },
305    /// Attribute access
306    Attribute {
307        /// Base expression
308        value: Box<Expression>,
309        /// Attribute name
310        attr: String,
311    },
312    /// Subscript access
313    Subscript {
314        /// Base expression
315        value: Box<Expression>,
316        /// Slice or index expression
317        slice: Box<Expression>,
318    },
319    /// List literal
320    List {
321        /// List elements
322        elts: Vec<Expression>,
323    },
324    /// Tuple literal
325    Tuple {
326        /// Tuple elements
327        elts: Vec<Expression>,
328    },
329    /// Slice expression
330    Slice {
331        /// Optional lower bound
332        lower: Option<Box<Expression>>,
333        /// Optional upper bound
334        upper: Option<Box<Expression>>,
335        /// Optional step
336        step: Option<Box<Expression>>,
337    },
338    /// Dictionary literal
339    Dict {
340        /// Optional keys
341        keys: Vec<Option<Expression>>,
342        /// Values
343        values: Vec<Expression>,
344    },
345    /// Set literal
346    Set {
347        /// Set elements
348        elts: Vec<Expression>,
349    },
350    /// List comprehension
351    ListComp {
352        /// Result expression
353        elt: Box<Expression>,
354        /// Generators
355        generators: Vec<Comprehension>,
356    },
357    /// Dictionary comprehension
358    DictComp {
359        /// Key expression
360        key: Box<Expression>,
361        /// Value expression
362        value: Box<Expression>,
363        /// Generators
364        generators: Vec<Comprehension>,
365    },
366    /// Set comprehension
367    SetComp {
368        /// Result expression
369        elt: Box<Expression>,
370        /// Generators
371        generators: Vec<Comprehension>,
372    },
373    /// Generator expression
374    GeneratorExp {
375        /// Result expression
376        elt: Box<Expression>,
377        /// Generators
378        generators: Vec<Comprehension>,
379    },
380    /// Lambda expression
381    Lambda {
382        /// Lambda arguments
383        args: Vec<Parameter>,
384        /// Lambda body
385        body: Box<Expression>,
386    },
387    /// Conditional expression (ternary operator)
388    IfExp {
389        /// Test expression
390        test: Box<Expression>,
391        /// Body expression
392        body: Box<Expression>,
393        /// Else expression
394        orelse: Box<Expression>,
395    },
396    /// f-string
397    JoinedStr {
398        /// f-string parts
399        values: Vec<Expression>,
400    },
401    /// Formatted value within an f-string
402    FormattedValue {
403        /// Value to format
404        value: Box<Expression>,
405        /// Conversion type
406        conversion: usize,
407        /// Optional format specification
408        format_spec: Option<Box<Expression>>,
409    },
410    /// Yield expression
411    Yield(Option<Box<Expression>>),
412    /// Yield from expression
413    YieldFrom(Box<Expression>),
414    /// Await expression
415    Await(Box<Expression>),
416    /// Starred expression (*args, **kwargs)
417    Starred {
418        /// Value being starred
419        value: Box<Expression>,
420        /// Whether it's a double star (**kwargs)
421        is_double: bool,
422    },
423}
424
/// Represents a literal value.
///
/// Only `PartialEq` is derived (not `Eq`) because the `Float` variant holds
/// an `f64`, for which `NaN != NaN`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// Integer literal, stored as an `i64`.
    Integer(i64),
    /// Float literal.
    Float(f64),
    /// String literal.
    String(String),
    /// Bytes literal.
    Bytes(Vec<u8>),
    /// Boolean literal.
    Boolean(bool),
    /// None literal.
    None,
}

/// Represents an augmented assignment operator.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AugmentedOperator {
    /// `+=`
    Add,
    /// `-=`
    Sub,
    /// `*=`
    Mult,
    /// `/=`
    Div,
    /// `//=`
    FloorDiv,
    /// `%=`
    Mod,
    /// `**=`
    Pow,
    /// `<<=`
    LShift,
    /// `>>=`
    RShift,
    /// `|=`
    BitOr,
    /// `^=`
    BitXor,
    /// `&=`
    BitAnd,
}

/// Represents a binary operator.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mult,
    /// `/`
    Div,
    /// `//`
    FloorDiv,
    /// `%`
    Mod,
    /// `**`
    Pow,
    /// `<<`
    LShift,
    /// `>>`
    RShift,
    /// `|`
    BitOr,
    /// `^`
    BitXor,
    /// `&`
    BitAnd,
}

/// Represents a unary operator.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// `~`
    Invert,
    /// `not`
    Not,
    /// `+`
    UAdd,
    /// `-`
    USub,
}

/// Represents a boolean operator.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoolOperator {
    /// `and`
    And,
    /// `or`
    Or,
}

/// Represents a comparison operator.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOperator {
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtE,
    /// `>`
    Gt,
    /// `>=`
    GtE,
    /// `is`
    Is,
    /// `is not`
    IsNot,
    /// `in`
    In,
    /// `not in`
    NotIn,
}

553/// Represents a function parameter.
554#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
555#[derive(Debug, Clone, PartialEq)]
556pub struct Parameter {
557    /// Parameter name
558    pub name: String,
559    /// Optional type annotation
560    pub annotation: Option<Type>,
561    /// Optional default value
562    pub default: Option<Expression>,
563    /// Whether it's a variable positional argument (*args)
564    pub is_vararg: bool,
565    /// Whether it's a variable keyword argument (**kwargs)
566    pub is_kwarg: bool,
567}
568
/// Represents a type annotation.
///
/// Every variant is built from `String`s and nested `Type`s only, so the
/// enum also implements `Eq` and `Hash`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// Basic type name.
    Name(String),
    /// Generic type.
    Generic {
        /// Type name.
        name: String,
        /// Type arguments.
        args: Vec<Type>,
    },
    /// Union type.
    Union(Vec<Type>),
    /// Optional type.
    Optional(Box<Type>),
}

588/// Represents a keyword argument.
589#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
590#[derive(Debug, Clone, PartialEq)]
591pub struct Keyword {
592    /// Optional argument name
593    pub arg: Option<String>,
594    /// Argument value
595    pub value: Expression,
596}
597
598/// Represents a comprehension in a list/dict/set/generator.
599#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
600#[derive(Debug, Clone, PartialEq)]
601pub struct Comprehension {
602    /// Target expression
603    pub target: Expression,
604    /// Iterable expression
605    pub iter: Expression,
606    /// Optional conditions
607    pub ifs: Vec<Expression>,
608    /// Whether it's an async comprehension
609    pub is_async: bool,
610}
611
/// Represents a name in an import statement.
///
/// Both fields are plain strings, so the struct also implements `Eq` and
/// `Hash`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImportName {
    /// Name being imported.
    pub name: String,
    /// Optional alias (asname).
    pub asname: Option<String>,
}

622/// Represents an exception handler in a try statement.
623#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
624#[derive(Debug, Clone, PartialEq)]
625pub struct ExceptHandler {
626    /// Optional exception type
627    pub type_: Option<Expression>,
628    /// Optional name for the exception instance
629    pub name: Option<String>,
630    /// Handler body
631    pub body: Vec<Statement>,
632}
633
634/// Represents an item in a with statement.
635#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
636#[derive(Debug, Clone, PartialEq)]
637pub struct WithItem {
638    /// Context manager expression
639    pub context_expr: Expression,
640    /// Optional variables to bind to
641    pub optional_vars: Option<Expression>,
642}
643
644impl Program {
645    /// Creates a new empty program.
646    pub fn new() -> Self {
647        Self { statements: Vec::new() }
648    }
649
650    /// Adds a statement to the program.
651    pub fn add_statement(&mut self, statement: Statement) {
652        self.statements.push(statement)
653    }
654}
655
656impl Default for Program {
657    fn default() -> Self {
658        Self::new()
659    }
660}
661
662impl Expression {
663    /// Creates a name expression.
664    pub fn name(name: impl Into<String>) -> Self {
665        Self::Name(name.into())
666    }
667
668    /// Creates a string literal expression.
669    pub fn string(value: impl Into<String>) -> Self {
670        Self::Literal(Literal::String(value.into()))
671    }
672
673    /// Creates an integer literal expression.
674    pub fn integer(value: i64) -> Self {
675        Self::Literal(Literal::Integer(value))
676    }
677
678    /// Creates a float literal expression.
679    pub fn float(value: f64) -> Self {
680        Self::Literal(Literal::Float(value))
681    }
682
683    /// Creates a boolean literal expression.
684    pub fn boolean(value: bool) -> Self {
685        Self::Literal(Literal::Boolean(value))
686    }
687
688    /// Creates a None literal expression.
689    pub fn none() -> Self {
690        Self::Literal(Literal::None)
691    }
692}
693
694impl Statement {
695    /// Creates a function definition statement.
696    pub fn function_def(name: impl Into<String>, parameters: Vec<Parameter>, return_type: Option<Type>, body: Vec<Statement>) -> Self {
697        Self::FunctionDef { decorators: Vec::new(), name: name.into(), parameters, return_type, body }
698    }
699
700    /// Creates an assignment statement.
701    pub fn assignment(target: Expression, value: Expression) -> Self {
702        Self::Assignment { target, value }
703    }
704
705    /// Creates an expression statement.
706    pub fn expression(expr: Expression) -> Self {
707        Self::Expression(expr)
708    }
709
710    /// Creates a return statement.
711    pub fn return_stmt(value: Option<Expression>) -> Self {
712        Self::Return(value)
713    }
714}