Skip to main content

nautilus_schema/
ast.rs

1//! Abstract Syntax Tree (AST) for the nautilus schema language.
2//!
3//! This module defines the complete AST structure for representing parsed schemas.
4//! All nodes include [`Span`] information for precise error diagnostics.
5//!
6//! The AST supports the Visitor pattern via the [`accept`](Schema::accept) methods,
7//! allowing flexible traversal and transformation operations.
8//!
9//! # Example
10//!
11//! ```ignore
12//! use nautilus_schema::{Lexer, Parser};
13//!
14//! let source = r#"
15//!     model User {
16//!       id    Int    @id @default(autoincrement())
17//!       email String @unique
18//!     }
19//! "#;
20//!
21//! let tokens = Lexer::new(source).collect::<Result<Vec<_>, _>>().unwrap();
22//! let schema = Parser::new(&tokens).parse_schema().unwrap();
23//!
24//! println!("Found {} declarations", schema.declarations.len());
25//! ```
26
27use crate::span::Span;
28use std::fmt;
29
30/// Top-level schema document containing all declarations.
31#[derive(Debug, Clone, PartialEq)]
32pub struct Schema {
33    /// All declarations in the schema (datasources, generators, models, enums).
34    pub declarations: Vec<Declaration>,
35    /// Span covering the entire schema.
36    pub span: Span,
37}
38
39impl Schema {
40    /// Creates a new schema with the given declarations.
41    pub fn new(declarations: Vec<Declaration>, span: Span) -> Self {
42        Self { declarations, span }
43    }
44
45    /// Finds all model declarations in the schema.
46    pub fn models(&self) -> impl Iterator<Item = &ModelDecl> {
47        self.declarations.iter().filter_map(|d| match d {
48            Declaration::Model(m) => Some(m),
49            _ => None,
50        })
51    }
52
53    /// Finds all enum declarations in the schema.
54    pub fn enums(&self) -> impl Iterator<Item = &EnumDecl> {
55        self.declarations.iter().filter_map(|d| match d {
56            Declaration::Enum(e) => Some(e),
57            _ => None,
58        })
59    }
60
61    /// Finds all composite type declarations in the schema.
62    pub fn types(&self) -> impl Iterator<Item = &TypeDecl> {
63        self.declarations.iter().filter_map(|d| match d {
64            Declaration::Type(t) => Some(t),
65            _ => None,
66        })
67    }
68
69    /// Finds the first datasource declaration.
70    pub fn datasource(&self) -> Option<&DatasourceDecl> {
71        self.declarations.iter().find_map(|d| match d {
72            Declaration::Datasource(ds) => Some(ds),
73            _ => None,
74        })
75    }
76
77    /// Finds the first generator declaration.
78    pub fn generator(&self) -> Option<&GeneratorDecl> {
79        self.declarations.iter().find_map(|d| match d {
80            Declaration::Generator(g) => Some(g),
81            _ => None,
82        })
83    }
84}
85
86/// A top-level declaration in the schema.
87#[derive(Debug, Clone, PartialEq)]
88pub enum Declaration {
89    /// A datasource block.
90    Datasource(DatasourceDecl),
91    /// A generator block.
92    Generator(GeneratorDecl),
93    /// A model block.
94    Model(ModelDecl),
95    /// An enum block.
96    Enum(EnumDecl),
97    /// A composite type block.
98    Type(TypeDecl),
99}
100
101impl Declaration {
102    /// Returns the span of this declaration.
103    pub fn span(&self) -> Span {
104        match self {
105            Declaration::Datasource(d) => d.span,
106            Declaration::Generator(g) => g.span,
107            Declaration::Model(m) => m.span,
108            Declaration::Enum(e) => e.span,
109            Declaration::Type(t) => t.span,
110        }
111    }
112
113    /// Returns the name of this declaration.
114    pub fn name(&self) -> &str {
115        match self {
116            Declaration::Datasource(d) => &d.name.value,
117            Declaration::Generator(g) => &g.name.value,
118            Declaration::Model(m) => &m.name.value,
119            Declaration::Enum(e) => &e.name.value,
120            Declaration::Type(t) => &t.name.value,
121        }
122    }
123}
124
125/// A datasource block declaration.
126///
127/// # Example
128///
129/// ```prisma
130/// datasource db {
131///   provider = "postgresql"
132///   url      = env("DATABASE_URL")
133/// }
134/// ```
135#[derive(Debug, Clone, PartialEq)]
136pub struct DatasourceDecl {
137    /// The name of the datasource (e.g., "db").
138    pub name: Ident,
139    /// Configuration fields (key-value pairs).
140    pub fields: Vec<ConfigField>,
141    /// Span covering the entire datasource block.
142    pub span: Span,
143}
144
145impl DatasourceDecl {
146    /// Finds a configuration field by name.
147    pub fn find_field(&self, name: &str) -> Option<&ConfigField> {
148        self.fields.iter().find(|f| f.name.value == name)
149    }
150
151    /// Gets the provider value if present.
152    pub fn provider(&self) -> Option<&str> {
153        self.find_field("provider").and_then(|f| match &f.value {
154            Expr::Literal(Literal::String(s, _)) => Some(s.as_str()),
155            _ => None,
156        })
157    }
158}
159
160/// A generator block declaration.
161///
162/// # Example
163///
164/// ```prisma
165/// generator client {
166///   provider = "nautilus-client-rs"
167///   output   = "../generated"
168/// }
169/// ```
170#[derive(Debug, Clone, PartialEq)]
171pub struct GeneratorDecl {
172    /// The name of the generator (e.g., "client").
173    pub name: Ident,
174    /// Configuration fields (key-value pairs).
175    pub fields: Vec<ConfigField>,
176    /// Span covering the entire generator block.
177    pub span: Span,
178}
179
180impl GeneratorDecl {
181    /// Finds a configuration field by name.
182    pub fn find_field(&self, name: &str) -> Option<&ConfigField> {
183        self.fields.iter().find(|f| f.name.value == name)
184    }
185}
186
187/// A configuration field in a datasource or generator block.
188#[derive(Debug, Clone, PartialEq)]
189pub struct ConfigField {
190    /// The field name.
191    pub name: Ident,
192    /// The field value (typically a string or function call).
193    pub value: Expr,
194    /// Span covering the entire field declaration.
195    pub span: Span,
196}
197
198/// A model block declaration.
199///
200/// # Example
201///
202/// ```prisma
203/// model User {
204///   id    Int    @id @default(autoincrement())
205///   email String @unique
206///   @@map("users")
207/// }
208/// ```
209#[derive(Debug, Clone, PartialEq)]
210pub struct ModelDecl {
211    /// The model name (e.g., "User").
212    pub name: Ident,
213    /// Field declarations.
214    pub fields: Vec<FieldDecl>,
215    /// Model-level attributes (@@map, @@id, etc.).
216    pub attributes: Vec<ModelAttribute>,
217    /// Span covering the entire model block.
218    pub span: Span,
219}
220
221impl ModelDecl {
222    /// Finds a field by name.
223    pub fn find_field(&self, name: &str) -> Option<&FieldDecl> {
224        self.fields.iter().find(|f| f.name.value == name)
225    }
226
227    /// Gets the physical table name from @@map attribute, or the model name.
228    pub fn table_name(&self) -> &str {
229        self.attributes
230            .iter()
231            .find_map(|attr| match attr {
232                ModelAttribute::Map(name) => Some(name.as_str()),
233                _ => None,
234            })
235            .unwrap_or(&self.name.value)
236    }
237
238    /// Checks if this model has a composite primary key (@@id).
239    pub fn has_composite_key(&self) -> bool {
240        self.attributes
241            .iter()
242            .any(|attr| matches!(attr, ModelAttribute::Id(_)))
243    }
244
245    /// Returns all fields that are part of relations.
246    /// This includes fields with user-defined types (model/enum references).
247    pub fn relation_fields(&self) -> impl Iterator<Item = &FieldDecl> {
248        self.fields.iter().filter(|f| {
249            // Either has @relation attribute or is a user-defined type
250            f.has_relation_attribute() || matches!(f.field_type, FieldType::UserType(_))
251        })
252    }
253}
254
255/// A field declaration within a model.
256#[derive(Debug, Clone, PartialEq)]
257pub struct FieldDecl {
258    /// The field name.
259    pub name: Ident,
260    /// The field type.
261    pub field_type: FieldType,
262    /// Optional or array modifier.
263    pub modifier: FieldModifier,
264    /// Field-level attributes (@id, @unique, etc.).
265    pub attributes: Vec<FieldAttribute>,
266    /// Span covering the entire field declaration.
267    pub span: Span,
268}
269
270impl FieldDecl {
271    /// Checks if this field is optional (has `?` modifier).
272    pub fn is_optional(&self) -> bool {
273        matches!(self.modifier, FieldModifier::Optional)
274    }
275
276    /// Checks if this field has an explicit not-null modifier (`!`).
277    pub fn is_not_null(&self) -> bool {
278        matches!(self.modifier, FieldModifier::NotNull)
279    }
280
281    /// Checks if this field is an array (has `[]` modifier).
282    pub fn is_array(&self) -> bool {
283        matches!(self.modifier, FieldModifier::Array)
284    }
285
286    /// Finds a field attribute by kind.
287    pub fn find_attribute(&self, kind: &str) -> Option<&FieldAttribute> {
288        self.attributes.iter().find(|attr| {
289            matches!(
290                (kind, attr),
291                ("id", FieldAttribute::Id)
292                    | ("unique", FieldAttribute::Unique)
293                    | ("default", FieldAttribute::Default(_, _))
294                    | ("map", FieldAttribute::Map(_))
295                    | ("relation", FieldAttribute::Relation { .. })
296                    | ("check", FieldAttribute::Check { .. })
297            )
298        })
299    }
300
301    /// Checks if this field has a @relation attribute.
302    pub fn has_relation_attribute(&self) -> bool {
303        self.attributes
304            .iter()
305            .any(|attr| matches!(attr, FieldAttribute::Relation { .. }))
306    }
307
308    /// Gets the physical column name from @map attribute, or the field name.
309    pub fn column_name(&self) -> &str {
310        self.attributes
311            .iter()
312            .find_map(|attr| match attr {
313                FieldAttribute::Map(name) => Some(name.as_str()),
314                _ => None,
315            })
316            .unwrap_or(&self.name.value)
317    }
318}
319
/// The optional suffix written after a field's type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldModifier {
    /// No suffix: a plain required field.
    None,
    /// The `?` suffix: an optional field.
    Optional,
    /// The `!` suffix: an explicitly not-null field.
    NotNull,
    /// The `[]` suffix: an array field.
    Array,
}
332
/// The declared type of a field.
#[derive(Debug, Clone, PartialEq)]
pub enum FieldType {
    /// `String`.
    String,
    /// `Boolean`.
    Boolean,
    /// `Int` (32-bit).
    Int,
    /// `BigInt` (64-bit).
    BigInt,
    /// `Float`.
    Float,
    /// `Decimal(precision, scale)`.
    Decimal {
        /// Total number of digits.
        precision: u32,
        /// Number of digits after the decimal point.
        scale: u32,
    },
    /// `DateTime`.
    DateTime,
    /// `Bytes`.
    Bytes,
    /// `Json`.
    Json,
    /// `Uuid`.
    Uuid,
    /// `Jsonb` (PostgreSQL only).
    Jsonb,
    /// `Xml` (PostgreSQL only).
    Xml,
    /// `Char(length)`: fixed-length character column.
    Char {
        /// Column length.
        length: u32,
    },
    /// `VarChar(length)`: variable-length character column.
    VarChar {
        /// Maximum column length.
        length: u32,
    },
    /// A user-defined name (a model or enum reference).
    UserType(String),
}

impl fmt::Display for FieldType {
    /// Writes the type using its schema spelling, e.g. `Decimal(10, 2)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            FieldType::Decimal { precision, scale } => {
                return write!(f, "Decimal({}, {})", precision, scale);
            }
            FieldType::Char { length } => return write!(f, "Char({})", length),
            FieldType::VarChar { length } => return write!(f, "VarChar({})", length),
            FieldType::UserType(name) => name.as_str(),
            FieldType::String => "String",
            FieldType::Boolean => "Boolean",
            FieldType::Int => "Int",
            FieldType::BigInt => "BigInt",
            FieldType::Float => "Float",
            FieldType::DateTime => "DateTime",
            FieldType::Bytes => "Bytes",
            FieldType::Json => "Json",
            FieldType::Uuid => "Uuid",
            FieldType::Jsonb => "Jsonb",
            FieldType::Xml => "Xml",
        };
        f.write_str(keyword)
    }
}
402
/// How an array field is persisted when the database lacks native arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageStrategy {
    /// Use the database's own array type (PostgreSQL).
    Native,
    /// Serialize the array as JSON (MySQL, SQLite).
    Json,
}
411
/// Whether a computed (generated) column is persisted or evaluated on read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputedKind {
    /// The value is written to disk (PostgreSQL, MySQL, SQLite).
    Stored,
    /// The value is recomputed on every read and not stored (MySQL and SQLite only).
    Virtual,
}
420
421/// A field-level attribute (@id, @unique, etc.).
422#[derive(Debug, Clone, PartialEq)]
423pub enum FieldAttribute {
424    /// @id attribute.
425    Id,
426    /// @unique attribute.
427    Unique,
428    /// @default(value) attribute.
429    /// The `Span` covers the full `@default(...)` token range.
430    Default(Expr, Span),
431    /// @map("name") attribute.
432    Map(String),
433    /// @store(json) attribute for array storage strategy.
434    Store {
435        /// Storage strategy (currently only "json" supported).
436        strategy: StorageStrategy,
437        /// Span of the entire attribute.
438        span: Span,
439    },
440    /// @relation(...) attribute.
441    Relation {
442        /// name: "relationName" (optional, required for multiple relations)
443        name: Option<String>,
444        /// fields: [field1, field2]
445        fields: Option<Vec<Ident>>,
446        /// references: [field1, field2]
447        references: Option<Vec<Ident>>,
448        /// onDelete: Cascade | SetNull | ...
449        on_delete: Option<ReferentialAction>,
450        /// onUpdate: Cascade | SetNull | ...
451        on_update: Option<ReferentialAction>,
452        /// Span of the entire attribute.
453        span: Span,
454    },
455    /// @updatedAt — auto-set to current timestamp on every write.
456    UpdatedAt {
457        /// Span covering `@updatedAt`.
458        span: Span,
459    },
460    /// @computed(expr, Stored | Virtual) — database-generated column.
461    Computed {
462        /// Parsed SQL expression (e.g. `price * quantity`).
463        expr: crate::sql_expr::SqlExpr,
464        /// Whether the value is stored on disk or computed on every read.
465        kind: ComputedKind,
466        /// Span of the entire `@computed(...)` attribute.
467        span: Span,
468    },
469    /// @check(bool_expr) — column-level CHECK constraint.
470    Check {
471        /// Parsed boolean expression (e.g. `age >= 0 AND age <= 150`).
472        expr: crate::bool_expr::BoolExpr,
473        /// Span of the entire `@check(...)` attribute.
474        span: Span,
475    },
476}
477
/// Action taken on a foreign key when the referenced row changes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferentialAction {
    /// `CASCADE`.
    Cascade,
    /// `RESTRICT`.
    Restrict,
    /// `NO ACTION`.
    NoAction,
    /// `SET NULL`.
    SetNull,
    /// `SET DEFAULT`.
    SetDefault,
}

impl fmt::Display for ReferentialAction {
    /// Writes the schema spelling of the action (e.g. `SetNull`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let spelling = match self {
            ReferentialAction::Cascade => "Cascade",
            ReferentialAction::Restrict => "Restrict",
            ReferentialAction::NoAction => "NoAction",
            ReferentialAction::SetNull => "SetNull",
            ReferentialAction::SetDefault => "SetDefault",
        };
        f.write_str(spelling)
    }
}
504
505/// A model-level attribute (@@map, @@id, etc.).
506#[derive(Debug, Clone, PartialEq)]
507pub enum ModelAttribute {
508    /// @@map("table_name") attribute.
509    Map(String),
510    /// @@id([field1, field2]) composite primary key.
511    Id(Vec<Ident>),
512    /// @@unique([field1, field2]) composite unique constraint.
513    Unique(Vec<Ident>),
514    /// @@index([field1, field2], type: Hash, name: "idx_name", map: "db_idx") index.
515    Index {
516        /// Fields that form the index key.
517        fields: Vec<Ident>,
518        /// Optional index type (`type:` argument). `None` → let the DBMS choose.
519        index_type: Option<Ident>,
520        /// Optional logical name (`name:` argument).
521        name: Option<String>,
522        /// Optional physical DB name (`map:` argument).
523        map: Option<String>,
524    },
525    /// @@check(bool_expr) — table-level CHECK constraint.
526    Check {
527        /// Parsed boolean expression (e.g. `start_date < end_date`).
528        expr: crate::bool_expr::BoolExpr,
529        /// Span of the entire `@@check(...)` attribute.
530        span: Span,
531    },
532}
533
534/// An enum block declaration.
535///
536/// # Example
537///
538/// ```prisma
539/// enum Role {
540///   USER
541///   ADMIN
542/// }
543/// ```
544#[derive(Debug, Clone, PartialEq)]
545pub struct EnumDecl {
546    /// The enum name (e.g., "Role").
547    pub name: Ident,
548    /// Enum variants.
549    pub variants: Vec<EnumVariant>,
550    /// Span covering the entire enum block.
551    pub span: Span,
552}
553
554/// A composite type block declaration.
555///
556/// Composite types define named struct-like types that can be embedded in models.
557/// On PostgreSQL they map to native composite types; on MySQL/SQLite they are
558/// serialised to JSON (`@store(Json)` is required on the model field).
559///
560/// # Example
561///
562/// ```prisma
563/// type Address {
564///   street String
565///   city   String
566///   zip    String
567/// }
568/// ```
569#[derive(Debug, Clone, PartialEq)]
570pub struct TypeDecl {
571    /// The type name (e.g., "Address").
572    pub name: Ident,
573    /// Field declarations (scalars, enums, and arrays — no relations).
574    pub fields: Vec<FieldDecl>,
575    /// Span covering the entire type block.
576    pub span: Span,
577}
578
579impl TypeDecl {
580    /// Finds a field by name.
581    pub fn find_field(&self, name: &str) -> Option<&FieldDecl> {
582        self.fields.iter().find(|f| f.name.value == name)
583    }
584}
585
586/// An enum variant.
587#[derive(Debug, Clone, PartialEq)]
588pub struct EnumVariant {
589    /// The variant name.
590    pub name: Ident,
591    /// Span covering the variant.
592    pub span: Span,
593}
594
595/// An identifier.
596#[derive(Debug, Clone, PartialEq, Eq, Hash)]
597pub struct Ident {
598    /// The identifier value.
599    pub value: String,
600    /// Span of the identifier.
601    pub span: Span,
602}
603
604impl Ident {
605    /// Creates a new identifier.
606    pub fn new(value: String, span: Span) -> Self {
607        Self { value, span }
608    }
609}
610
611impl fmt::Display for Ident {
612    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
613        write!(f, "{}", self.value)
614    }
615}
616
617/// An expression (used in attribute arguments).
618#[derive(Debug, Clone, PartialEq)]
619pub enum Expr {
620    /// A literal value.
621    Literal(Literal),
622    /// A function call: `name(arg1, arg2, ...)`.
623    FunctionCall {
624        /// Function name.
625        name: Ident,
626        /// Arguments.
627        args: Vec<Expr>,
628        /// Span of the entire call.
629        span: Span,
630    },
631    /// An array: `[item1, item2, ...]`.
632    Array {
633        /// Array elements.
634        elements: Vec<Expr>,
635        /// Span of the entire array.
636        span: Span,
637    },
638    /// A named argument: `name: value`.
639    NamedArg {
640        /// Argument name.
641        name: Ident,
642        /// Argument value.
643        value: Box<Expr>,
644        /// Span of the entire named argument.
645        span: Span,
646    },
647    /// An identifier reference.
648    Ident(Ident),
649}
650
651impl Expr {
652    /// Returns the span of this expression.
653    pub fn span(&self) -> Span {
654        match self {
655            Expr::Literal(lit) => lit.span(),
656            Expr::FunctionCall { span, .. } => *span,
657            Expr::Array { span, .. } => *span,
658            Expr::NamedArg { span, .. } => *span,
659            Expr::Ident(ident) => ident.span,
660        }
661    }
662}
663
664/// A literal value.
665#[derive(Debug, Clone, PartialEq)]
666pub enum Literal {
667    /// String literal.
668    String(String, Span),
669    /// Number literal (stored as string, can be int or float).
670    Number(String, Span),
671    /// Boolean literal.
672    Boolean(bool, Span),
673}
674
675impl Literal {
676    /// Returns the span of this literal.
677    pub fn span(&self) -> Span {
678        match self {
679            Literal::String(_, span) => *span,
680            Literal::Number(_, span) => *span,
681            Literal::Boolean(_, span) => *span,
682        }
683    }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier with a dummy span to keep fixtures short.
    fn ident(value: &str) -> Ident {
        Ident::new(value.to_string(), Span::new(0, 1))
    }

    #[test]
    fn test_field_modifier() {
        assert_eq!(FieldModifier::None, FieldModifier::None);
        assert_ne!(FieldModifier::Optional, FieldModifier::Array);
    }

    #[test]
    fn test_field_type_display() {
        assert_eq!(FieldType::String.to_string(), "String");
        assert_eq!(FieldType::Int.to_string(), "Int");
        assert_eq!(
            FieldType::Decimal {
                precision: 10,
                scale: 2
            }
            .to_string(),
            "Decimal(10, 2)"
        );
    }

    #[test]
    fn test_ident() {
        let ident = Ident::new("test".to_string(), Span::new(0, 4));
        assert_eq!(ident.value, "test");
        assert_eq!(ident.to_string(), "test");
    }

    #[test]
    fn test_referential_action_display() {
        assert_eq!(ReferentialAction::Cascade.to_string(), "Cascade");
        assert_eq!(ReferentialAction::SetNull.to_string(), "SetNull");
    }

    #[test]
    fn test_model_table_name() {
        let model = ModelDecl {
            name: Ident::new("User".to_string(), Span::new(0, 4)),
            fields: vec![],
            attributes: vec![ModelAttribute::Map("users".to_string())],
            span: Span::new(0, 10),
        };
        assert_eq!(model.table_name(), "users");
    }

    #[test]
    fn test_model_table_name_default() {
        let model = ModelDecl {
            name: Ident::new("User".to_string(), Span::new(0, 4)),
            fields: vec![],
            attributes: vec![],
            span: Span::new(0, 10),
        };
        assert_eq!(model.table_name(), "User");
    }

    #[test]
    fn test_model_composite_key() {
        let model = ModelDecl {
            name: ident("Membership"),
            fields: vec![],
            attributes: vec![ModelAttribute::Id(vec![ident("userId"), ident("groupId")])],
            span: Span::new(0, 10),
        };
        assert!(model.has_composite_key());
        // @@id does not affect the physical table name.
        assert_eq!(model.table_name(), "Membership");
    }

    #[test]
    fn test_field_column_name() {
        let field = FieldDecl {
            name: Ident::new("userId".to_string(), Span::new(0, 6)),
            field_type: FieldType::Int,
            modifier: FieldModifier::None,
            attributes: vec![FieldAttribute::Map("user_id".to_string())],
            span: Span::new(0, 20),
        };
        assert_eq!(field.column_name(), "user_id");
    }

    #[test]
    fn test_field_modifier_predicates() {
        let mut field = FieldDecl {
            name: ident("tags"),
            field_type: FieldType::String,
            modifier: FieldModifier::Array,
            attributes: vec![],
            span: Span::new(0, 12),
        };
        assert!(field.is_array());
        assert!(!field.is_optional());
        assert!(!field.is_not_null());

        field.modifier = FieldModifier::Optional;
        assert!(field.is_optional());
        assert!(!field.is_array());

        field.modifier = FieldModifier::NotNull;
        assert!(field.is_not_null());
    }

    #[test]
    fn test_field_find_attribute() {
        let field = FieldDecl {
            name: ident("id"),
            field_type: FieldType::Int,
            modifier: FieldModifier::None,
            attributes: vec![FieldAttribute::Id, FieldAttribute::Unique],
            span: Span::new(0, 20),
        };
        assert_eq!(field.find_attribute("id"), Some(&FieldAttribute::Id));
        assert_eq!(field.find_attribute("unique"), Some(&FieldAttribute::Unique));
        assert!(field.find_attribute("map").is_none());
        assert!(!field.has_relation_attribute());
        // Without @map the column name falls back to the field name.
        assert_eq!(field.column_name(), "id");
    }

    #[test]
    fn test_datasource_provider() {
        let datasource = DatasourceDecl {
            name: ident("db"),
            fields: vec![ConfigField {
                name: ident("provider"),
                value: Expr::Literal(Literal::String(
                    "postgresql".to_string(),
                    Span::new(11, 23),
                )),
                span: Span::new(0, 23),
            }],
            span: Span::new(0, 40),
        };
        assert_eq!(datasource.provider(), Some("postgresql"));
        assert!(datasource.find_field("url").is_none());
    }

    #[test]
    fn test_datasource_provider_requires_string_literal() {
        let datasource = DatasourceDecl {
            name: ident("db"),
            fields: vec![ConfigField {
                name: ident("provider"),
                value: Expr::Ident(ident("postgresql")),
                span: Span::new(0, 20),
            }],
            span: Span::new(0, 30),
        };
        // The entry exists, but only string literals count as a provider.
        assert!(datasource.find_field("provider").is_some());
        assert_eq!(datasource.provider(), None);
    }

    #[test]
    fn test_declaration_accessors() {
        let declaration = Declaration::Enum(EnumDecl {
            name: ident("Role"),
            variants: vec![],
            span: Span::new(5, 20),
        });
        assert_eq!(declaration.name(), "Role");
        assert_eq!(declaration.span(), Span::new(5, 20));
    }

    #[test]
    fn test_schema_helpers() {
        let schema = Schema {
            declarations: vec![
                Declaration::Model(ModelDecl {
                    name: Ident::new("User".to_string(), Span::new(0, 4)),
                    fields: vec![],
                    attributes: vec![],
                    span: Span::new(0, 10),
                }),
                Declaration::Enum(EnumDecl {
                    name: Ident::new("Role".to_string(), Span::new(0, 4)),
                    variants: vec![],
                    span: Span::new(0, 10),
                }),
            ],
            span: Span::new(0, 100),
        };

        assert_eq!(schema.models().count(), 1);
        assert_eq!(schema.enums().count(), 1);
        assert!(schema.datasource().is_none());
    }

    #[test]
    fn test_schema_generator_and_types() {
        let schema = Schema::new(
            vec![
                Declaration::Generator(GeneratorDecl {
                    name: ident("client"),
                    fields: vec![],
                    span: Span::new(0, 10),
                }),
                Declaration::Type(TypeDecl {
                    name: ident("Address"),
                    fields: vec![],
                    span: Span::new(11, 30),
                }),
            ],
            Span::new(0, 30),
        );

        assert_eq!(
            schema.generator().map(|g| g.name.value.as_str()),
            Some("client")
        );
        assert_eq!(schema.types().count(), 1);
        assert_eq!(schema.models().count(), 0);
        assert!(schema.datasource().is_none());
    }
}