Skip to main content

visitor_demo/
visitor_demo.rs

//! Example: Using the Visitor pattern to analyze schemas.
//!
//! This example demonstrates:
//! - Implementing custom visitors
//! - Traversing AST nodes
//! - Collecting information from schemas
//! - Building relationships between models
//!
//! Run with: cargo run --package nautilus-schema --example visitor_demo

use nautilus_schema::{
    ast::*,
    visitor::{walk_model, Visitor},
    Lexer, Parser, Result,
};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};

/// Schema source that every visitor in this demo analyzes.
///
/// Declares a PostgreSQL datasource, a `Role` enum and four models (`User`,
/// `Profile`, `Post`, `Comment`). Together they give the visitors a mix of:
/// required, optional (`?`) and list (`[]`) relation fields; `@default(...)`
/// values; field-level `@unique` attributes; and model-level `@@map` /
/// `@@index` attributes.
const SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url = env("DATABASE_URL")
}

enum Role {
  USER
  ADMIN
  MODERATOR
}

model User {
  id        Int      @id @default(autoincrement())
  email     String   @unique
  username  String   @unique
  role      Role     @default(USER)
  createdAt DateTime @default(now())
  
  posts     Post[]
  comments  Comment[]
  profile   Profile?
  
  @@map("users")
}

model Profile {
  id     Int    @id @default(autoincrement())
  userId Int    @unique
  bio    String?
  avatar String?
  
  user   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
  
  @@map("profiles")
}

model Post {
  id        Int      @id @default(autoincrement())
  authorId  Int
  title     String
  content   String
  published Boolean  @default(false)
  createdAt DateTime @default(now())
  
  author    User     @relation(fields: [authorId], references: [id])
  comments  Comment[]
  
  @@map("posts")
  @@index([authorId])
}

model Comment {
  id        Int      @id @default(autoincrement())
  postId    Int
  authorId  Int
  content   String
  createdAt DateTime @default(now())
  
  post      Post     @relation(fields: [postId], references: [id], onDelete: Cascade)
  author    User     @relation(fields: [authorId], references: [id])
  
  @@map("comments")
  @@index([postId, authorId])
}
"#;

85// Custom visitor 1: Collect statistics about the schema
86#[derive(Default, Debug)]
87struct SchemaStats {
88    models: usize,
89    enums: usize,
90    total_fields: usize,
91    unique_constraints: usize,
92    indexes: usize,
93    relations: usize,
94    optional_fields: usize,
95    array_fields: usize,
96}
97
98impl Visitor for SchemaStats {
99    fn visit_model(&mut self, model: &ModelDecl) -> Result<()> {
100        self.models += 1;
101        self.total_fields += model.fields.len();
102
103        // Count indexes
104        for attr in &model.attributes {
105            if let ModelAttribute::Index { .. } = attr {
106                self.indexes += 1
107            }
108        }
109
110        walk_model(self, model)
111    }
112
113    fn visit_enum(&mut self, _enum_decl: &EnumDecl) -> Result<()> {
114        self.enums += 1;
115        Ok(())
116    }
117
118    fn visit_field(&mut self, field: &FieldDecl) -> Result<()> {
119        if field.is_optional() {
120            self.optional_fields += 1;
121        }
122        if field.is_array() {
123            self.array_fields += 1;
124        }
125
126        for attr in &field.attributes {
127            match attr {
128                FieldAttribute::Unique => self.unique_constraints += 1,
129                FieldAttribute::Relation { .. } => self.relations += 1,
130                _ => {}
131            }
132        }
133
134        Ok(())
135    }
136}
137
138// Custom visitor 2: Build a relationship graph
139#[derive(Default)]
140struct RelationshipGraph {
141    // model_name -> [(related_model, field_name, is_required)]
142    relationships: HashMap<String, Vec<(String, String, bool)>>,
143}
144
145impl Visitor for RelationshipGraph {
146    fn visit_model(&mut self, model: &ModelDecl) -> Result<()> {
147        let model_name = model.name.value.clone();
148
149        for field in &model.fields {
150            if let FieldType::UserType(related_type) = &field.field_type {
151                if field.has_relation_attribute() || field.is_array() {
152                    let is_required = !field.is_optional() && !field.is_array();
153
154                    self.relationships
155                        .entry(model_name.clone())
156                        .or_default()
157                        .push((related_type.clone(), field.name.value.clone(), is_required));
158                }
159            }
160        }
161
162        walk_model(self, model)
163    }
164}
165
166impl RelationshipGraph {
167    fn print(&self) {
168        println!("🔗 Relationship Graph:");
169        for (model, relations) in &self.relationships {
170            println!("\n  {} has:", model);
171            for (related, field_name, is_required) in relations {
172                let cardinality = if *is_required { "one" } else { "zero or more" };
173                println!(
174                    "    - {} {} via field '{}'",
175                    cardinality, related, field_name
176                );
177            }
178        }
179    }
180}
181
182// Custom visitor 3: Find all default values
183#[derive(Default)]
184struct DefaultValueCollector {
185    // model.field -> default_expr
186    defaults: HashMap<String, String>,
187}
188
189impl Visitor for DefaultValueCollector {
190    fn visit_model(&mut self, model: &ModelDecl) -> Result<()> {
191        let model_name = model.name.value.clone();
192
193        for field in &model.fields {
194            for attr in &field.attributes {
195                if let FieldAttribute::Default(expr, _) = attr {
196                    let key = format!("{}.{}", model_name, field.name.value);
197                    let value = format!("{:?}", expr);
198                    self.defaults.insert(key, value);
199                }
200            }
201        }
202
203        walk_model(self, model)
204    }
205}
206
207// Custom visitor 4: Validate naming conventions
208struct NamingValidator {
209    errors: Vec<String>,
210}
211
212impl NamingValidator {
213    fn new() -> Self {
214        Self { errors: Vec::new() }
215    }
216}
217
218impl Visitor for NamingValidator {
219    fn visit_model(&mut self, model: &ModelDecl) -> Result<()> {
220        // Models should start with uppercase
221        if !model.name.value.chars().next().unwrap().is_uppercase() {
222            self.errors.push(format!(
223                "Model '{}' should start with uppercase",
224                model.name.value
225            ));
226        }
227
228        walk_model(self, model)
229    }
230
231    fn visit_field(&mut self, field: &FieldDecl) -> Result<()> {
232        // Fields should start with lowercase
233        if !field.name.value.chars().next().unwrap().is_lowercase() {
234            self.errors.push(format!(
235                "Field '{}' should start with lowercase",
236                field.name.value
237            ));
238        }
239
240        Ok(())
241    }
242
243    fn visit_enum(&mut self, enum_decl: &EnumDecl) -> Result<()> {
244        // Enums should start with uppercase
245        if !enum_decl.name.value.chars().next().unwrap().is_uppercase() {
246            self.errors.push(format!(
247                "Enum '{}' should start with uppercase",
248                enum_decl.name.value
249            ));
250        }
251
252        // Enum variants should be UPPERCASE
253        for variant in &enum_decl.variants {
254            if !variant
255                .name
256                .value
257                .chars()
258                .all(|c| c.is_uppercase() || c == '_')
259            {
260                self.errors.push(format!(
261                    "Enum variant '{}' should be UPPERCASE",
262                    variant.name.value
263                ));
264            }
265        }
266
267        Ok(())
268    }
269}
270
271// Custom visitor 5: Find model dependencies for migration ordering
272#[derive(Default)]
273struct MigrationOrderer {
274    // Models with their foreign key dependencies
275    dependencies: HashMap<String, HashSet<String>>,
276}
277
278impl Visitor for MigrationOrderer {
279    fn visit_model(&mut self, model: &ModelDecl) -> Result<()> {
280        let model_name = model.name.value.clone();
281        let mut deps = HashSet::new();
282
283        for field in &model.fields {
284            if let FieldType::UserType(related_type) = &field.field_type {
285                // Only count as dependency if it's a foreign key (has @relation with fields)
286                for attr in &field.attributes {
287                    if let FieldAttribute::Relation {
288                        fields: Some(_), ..
289                    } = attr
290                    {
291                        deps.insert(related_type.clone());
292                    }
293                }
294            }
295        }
296
297        self.dependencies.insert(model_name, deps);
298        walk_model(self, model)
299    }
300}
301
302impl MigrationOrderer {
303    fn print(&self) {
304        println!("📦 Migration Order (based on foreign key dependencies):");
305        println!("  Models should be created in this order:\n");
306
307        let mut remaining: HashSet<_> = self.dependencies.keys().cloned().collect();
308        let mut order = Vec::new();
309
310        while !remaining.is_empty() {
311            // Find models with no unmet dependencies
312            let can_create: Vec<_> = remaining
313                .iter()
314                .filter(|m| {
315                    self.dependencies[*m]
316                        .iter()
317                        .all(|dep| order.contains(dep) || !remaining.contains(dep))
318                })
319                .cloned()
320                .collect();
321
322            if can_create.is_empty() {
323                println!("  ⚠️  Circular dependency detected!");
324                break;
325            }
326
327            for model in can_create {
328                order.push(model.clone());
329                remaining.remove(&model);
330
331                let deps_list: Vec<_> = self.dependencies[&model].iter().collect();
332                if deps_list.is_empty() {
333                    println!("  {}. {} (no dependencies)", order.len(), model);
334                } else {
335                    println!(
336                        "  {}. {} (depends on: {})",
337                        order.len(),
338                        model,
339                        deps_list
340                            .iter()
341                            .map(|s| s.as_str())
342                            .collect::<Vec<_>>()
343                            .join(", ")
344                    );
345                }
346            }
347        }
348    }
349}
350
351fn main() -> Result<()> {
352    println!("=== Nautilus Schema Visitor Pattern Demo ===\n");
353
354    // Parse the schema
355    let mut lexer = Lexer::new(SCHEMA);
356    let mut tokens = Vec::new();
357    loop {
358        let token = lexer.next_token()?;
359        if matches!(token.kind, nautilus_schema::TokenKind::Eof) {
360            tokens.push(token);
361            break;
362        }
363        tokens.push(token);
364    }
365    let schema = Parser::new(&tokens, SCHEMA).parse_schema()?;
366
367    println!(
368        "Parsed schema with {} declarations\n",
369        schema.declarations.len()
370    );
371    let separator = "=".repeat(60);
372    println!("{}", separator);
373
374    // Visitor 1: Collect statistics
375    println!("\n1️⃣  Schema Statistics\n");
376    let mut stats = SchemaStats::default();
377    stats.visit_schema(&schema)?;
378    println!("{:#?}", stats);
379
380    println!("\n{}", separator);
381
382    // Visitor 2: Build relationship graph
383    println!("\n2️⃣  Relationship Analysis\n");
384    let mut graph = RelationshipGraph::default();
385    graph.visit_schema(&schema)?;
386    graph.print();
387
388    println!("\n{}", separator);
389
390    // Visitor 3: Find default values
391    println!("\n3️⃣  Default Values\n");
392    let mut defaults = DefaultValueCollector::default();
393    defaults.visit_schema(&schema)?;
394    for (field, default) in &defaults.defaults {
395        println!("  {} = {}", field, default);
396    }
397
398    println!("\n{}", separator);
399
400    // Visitor 4: Validate naming conventions
401    println!("\n4️⃣  Naming Convention Validation\n");
402    let mut validator = NamingValidator::new();
403    validator.visit_schema(&schema)?;
404    if validator.errors.is_empty() {
405        println!("  ✅ All names follow conventions!");
406    } else {
407        println!("  ⚠️  Found {} naming issues:", validator.errors.len());
408        for error in &validator.errors {
409            println!("    - {}", error);
410        }
411    }
412
413    println!("\n{}", separator);
414
415    // Visitor 5: Determine migration order
416    println!("\n5️⃣  Migration Planning\n");
417    let mut orderer = MigrationOrderer::default();
418    orderer.visit_schema(&schema)?;
419    orderer.print();
420
421    println!("\n{}", separator);
422    println!("\n✅ Visitor demo completed successfully!");
423
424    Ok(())
425}