Skip to main content

rigsql_core/
segment.rs

1use strum::{Display, EnumString};
2
3use crate::{Span, Token};
4
5/// A node in the Concrete Syntax Tree.
6///
7/// Leaf nodes wrap individual tokens. Branch nodes group children
8/// under a named production (e.g. `SelectStatement`, `WhereClause`).
9#[derive(Debug, Clone)]
10pub enum Segment {
11    Token(TokenSegment),
12    Node(NodeSegment),
13}
14
15impl Segment {
16    pub fn span(&self) -> Span {
17        match self {
18            Segment::Token(t) => t.token.span,
19            Segment::Node(n) => n.span,
20        }
21    }
22
23    pub fn segment_type(&self) -> SegmentType {
24        match self {
25            Segment::Token(t) => t.segment_type,
26            Segment::Node(n) => n.segment_type,
27        }
28    }
29
30    /// Recursively collect all leaf tokens in order.
31    pub fn tokens(&self) -> Vec<&Token> {
32        match self {
33            Segment::Token(t) => vec![&t.token],
34            Segment::Node(n) => n.children.iter().flat_map(|c| c.tokens()).collect(),
35        }
36    }
37
38    /// Iterator over direct children (empty for token segments).
39    pub fn children(&self) -> &[Segment] {
40        match self {
41            Segment::Token(_) => &[],
42            Segment::Node(n) => &n.children,
43        }
44    }
45
46    /// Recursively visit all segments depth-first.
47    pub fn walk(&self, visitor: &mut dyn FnMut(&Segment)) {
48        visitor(self);
49        if let Segment::Node(n) = self {
50            for child in &n.children {
51                child.walk(visitor);
52            }
53        }
54    }
55
56    /// Reconstruct source text from leaf tokens.
57    pub fn raw(&self) -> String {
58        self.tokens().iter().map(|t| t.text.as_str()).collect()
59    }
60}
61
62/// A leaf segment wrapping a single token.
63#[derive(Debug, Clone)]
64pub struct TokenSegment {
65    pub token: Token,
66    pub segment_type: SegmentType,
67}
68
69/// A branch segment grouping children under a named production.
70#[derive(Debug, Clone)]
71pub struct NodeSegment {
72    pub segment_type: SegmentType,
73    pub children: Vec<Segment>,
74    pub span: Span,
75}
76
77impl NodeSegment {
78    /// Create a new node from children, computing span automatically.
79    pub fn new(segment_type: SegmentType, children: Vec<Segment>) -> Self {
80        let span = if children.is_empty() {
81            Span::new(0, 0)
82        } else {
83            let first = children.first().unwrap().span();
84            let last = children.last().unwrap().span();
85            first.merge(last)
86        };
87        Self {
88            segment_type,
89            children,
90            span,
91        }
92    }
93}
94
95/// Type tag for CST segments.
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, EnumString)]
97pub enum SegmentType {
98    // Top-level
99    File,
100    Statement,
101
102    // DML Statements
103    SelectStatement,
104    InsertStatement,
105    UpdateStatement,
106    DeleteStatement,
107
108    // DDL Statements
109    CreateTableStatement,
110    AlterTableStatement,
111    DropStatement,
112
113    // PostgreSQL
114    TypeCastExpression,
115    OnConflictClause,
116    ArrayAccessExpression,
117
118    // TSQL
119    TableHint,
120
121    // TSQL Statements
122    DeclareStatement,
123    SetVariableStatement,
124    IfStatement,
125    BeginEndBlock,
126    WhileStatement,
127    TryCatchBlock,
128    ExecStatement,
129    ReturnStatement,
130    PrintStatement,
131    ThrowStatement,
132    RaiserrorStatement,
133    GoStatement,
134
135    // Clauses
136    SelectClause,
137    FromClause,
138    WhereClause,
139    GroupByClause,
140    HavingClause,
141    OrderByClause,
142    LimitClause,
143    OffsetClause,
144    JoinClause,
145    OnClause,
146    UsingClause,
147    SetClause,
148    ValuesClause,
149    ReturningClause,
150    WithClause,
151    CteDefinition,
152    InsertColumnsClause,
153
154    // Expressions
155    ColumnRef,
156    TableRef,
157    FunctionCall,
158    FunctionArgs,
159    Expression,
160    BinaryExpression,
161    UnaryExpression,
162    ParenExpression,
163    CaseExpression,
164    WhenClause,
165    ElseClause,
166    Subquery,
167    ExistsExpression,
168    InExpression,
169    BetweenExpression,
170    CastExpression,
171    IsNullExpression,
172    LikeExpression,
173
174    // Window functions
175    WindowExpression,
176    OverClause,
177    PartitionByClause,
178    WindowFrameClause,
179
180    // Alias
181    AliasExpression,
182
183    // Column / Table definition
184    ColumnDefinition,
185    DataType,
186    ColumnConstraint,
187    TableConstraint,
188
189    // Order
190    OrderByExpression,
191    SortOrder,
192
193    // Atoms (leaf-level semantic types)
194    Keyword,
195    Identifier,
196    QualifiedIdentifier,
197    QuotedIdentifier,
198    Literal,
199    NumericLiteral,
200    StringLiteral,
201    BooleanLiteral,
202    NullLiteral,
203    Operator,
204    ComparisonOperator,
205    ArithmeticOperator,
206    Comma,
207    Dot,
208    Semicolon,
209    Star,
210    LParen,
211    RParen,
212
213    // Trivia
214    Whitespace,
215    Newline,
216    LineComment,
217    BlockComment,
218
219    // Fallback
220    Unparsable,
221}
222
223impl SegmentType {
224    /// Returns true if this is a trivia type (whitespace/comment).
225    pub fn is_trivia(self) -> bool {
226        matches!(
227            self,
228            SegmentType::Whitespace
229                | SegmentType::Newline
230                | SegmentType::LineComment
231                | SegmentType::BlockComment
232        )
233    }
234}