byteorm_lib/
lib.rs

1mod ast;
2mod parser;
3
4use pest::Parser;
5use pest_derive::Parser;
6
7
8pub mod rustgen;
/// Parser for the schema definition language described by `grammar.pest`.
///
/// The parsing implementation is generated by the `Parser` derive from
/// `pest_derive`, using the grammar file named in the `#[grammar]` attribute.
/// NOTE(review): presumably driven by `parser::parse_schema` — confirm.
#[derive(Parser)]
#[grammar = "grammar.pest"]
pub struct SchemaParser;
12
13pub use ast::*;
14pub use parser::parse_schema;
15
16
17pub mod snapshot {
18    use tokio_postgres::Client;
19    use crate::Schema;
20
21    pub async fn init_snapshot_table(client: &Client) -> Result<(), Box<dyn std::error::Error>> {
22        let create_table = "CREATE TABLE IF NOT EXISTS _byteorm_schema ( id SERIAL PRIMARY KEY, schema_json JSONB NOT NULL, created_at TIMESTAMP DEFAULT now(), updated_at TIMESTAMP DEFAULT now() )";
23
24        client.execute(create_table, &[]).await?;
25        Ok(())
26    }
27
28    pub async fn save_snapshot(client: &Client, schema: &Schema) -> Result<(), Box<dyn std::error::Error>> {
29        let json_value = serde_json::to_value(schema)?;
30
31        client.execute(
32            "DELETE FROM _byteorm_schema WHERE id != 0",
33            &[],
34        ).await.ok();
35
36        client.execute(
37            "INSERT INTO _byteorm_schema (schema_json, updated_at) VALUES ($1, now())",
38            &[&json_value], 
39        ).await?;
40
41        println!("Snapshot saved to database");
42        Ok(())
43    }
44
45
46    pub async fn load_snapshot(client: &Client) -> Result<Option<Schema>, Box<dyn std::error::Error>> {
47        let rows = client.query(
48            "SELECT schema_json FROM _byteorm_schema ORDER BY updated_at DESC LIMIT 1",
49            &[],
50        ).await?;
51
52        if rows.is_empty() {
53            return Ok(None);
54        }
55
56        let json_value: serde_json::Value = rows[0].get(0);
57        Ok(Some(serde_json::from_value(json_value)?))
58    }
59
60}
61
62
63
64pub mod diff {
65    use crate::{Schema, Model, Field, Modifier};
66
67    #[derive(Debug, Clone)]
68    pub enum Change {
69        CreateTable(Model),
70        AddColumn { table: String, field: Field },
71        RemoveColumn { table: String, column: String },
72        AlterColumn { table: String, column: String, old: Field, new: Field },
73        RemoveTable(String),
74    }
75
76    pub fn diff_schemas(previous: Option<&Schema>, current: &Schema) -> Vec<Change> {
77        let mut changes = Vec::new();
78
79        if let Some(prev) = previous {
80            for prev_model in &prev.models {
81                if !current.models.iter().any(|m| m.name == prev_model.name) {
82                    changes.push(Change::RemoveTable(prev_model.name.clone()));
83                }
84            }
85
86            for curr_model in &current.models {
87                if let Some(prev_model) = prev.models.iter().find(|m| m.name == curr_model.name) {
88                    for curr_field in &curr_model.fields {
89                        if !prev_model.fields.iter().any(|f| f.name == curr_field.name) {
90                            changes.push(Change::AddColumn {
91                                table: curr_model.name.clone(),
92                                field: curr_field.clone(),
93                            });
94                        }
95                    }
96
97                    for prev_field in &prev_model.fields {
98                        if !curr_model.fields.iter().any(|f| f.name == prev_field.name) {
99                            changes.push(Change::RemoveColumn {
100                                table: curr_model.name.clone(),
101                                column: prev_field.name.clone(),
102                            });
103                        }
104                    }
105                }
106            }
107        } else {
108            for model in &current.models {
109                changes.push(Change::CreateTable(model.clone()));
110            }
111        }
112
113        changes
114    }
115}
116
117pub mod codegen {
118    use crate::{Field, Modifier, diff::Change};
119
120    pub fn postgres_type(type_name: &str) -> &'static str {
121        match type_name {
122            "BigInt" => "BIGINT",
123            "Int" => "INTEGER",
124            "String" => "TEXT",
125            "JsonB" => "JSONB",
126            "TimestamptZ" => "TIMESTAMP WITH TIME ZONE",
127            _ => "TEXT",
128        }
129    }
130
131    pub fn field_to_sql(field: &Field) -> String {
132        let mut sql = format!("{} {}", field.name, postgres_type(&field.type_name));
133
134        for modifier in &field.modifiers {
135            match modifier {
136                Modifier::PrimaryKey => sql.push_str(" PRIMARY KEY"),
137                Modifier::NotNull => sql.push_str(" NOT NULL"),
138                Modifier::Nullable => sql.push_str(" NULL"),
139                Modifier::Unique => sql.push_str(" UNIQUE"),
140                Modifier::ForeignKey { model, field } => {
141                    let fk_field = field.as_deref().unwrap_or("id");
142                    sql.push_str(&format!(" REFERENCES {} ({})", model, fk_field));
143                }
144            }
145        }
146
147        if field.is_sql_default() {
148            if let Some(value) = field.get_default_value() {
149                if value.contains("(") && value.contains(")") {
150                    sql.push_str(&format!(" DEFAULT {}", value));
151                } else {
152                    sql.push_str(&format!(" DEFAULT '{}'", value));
153                }
154            }
155        }
156
157        sql
158    }
159
160
161
162    pub fn change_to_sql(change: &Change) -> String {
163        match change {
164            Change::CreateTable(model) => {
165                let mut sql = format!("CREATE TABLE IF NOT EXISTS {} ( ", model.name);
166                let fields_count = model.fields.len();
167                for (idx, field) in model.fields.iter().enumerate() {
168                    sql.push_str(&field_to_sql(field));
169                    if idx < fields_count - 1 {
170                        sql.push_str(", ");
171                    } else {
172                        sql.push_str(" ");
173                    }
174                }
175                sql.push_str(");");
176                sql
177            }
178
179            Change::AddColumn { table, field } => {
180                format!("ALTER TABLE {} ADD COLUMN {};", table, field_to_sql(field))
181            }
182            Change::RemoveColumn { table, column } => {
183                format!("ALTER TABLE {} DROP COLUMN {};", table, column)
184            }
185            Change::RemoveTable(name) => {
186                format!("DROP TABLE IF EXISTS {};", name)
187            }
188            _ => String::new(),
189        }
190    }
191
192
193    pub fn generate_migration_sql(changes: &[Change]) -> String {
194        let mut sql = String::new();
195        for change in changes {
196            sql.push_str(&change_to_sql(change));
197        }
198        sql
199    }
200}
201
202
203pub mod db {
204    use tokio_postgres::Client;
205    use std::env;
206
207    pub async fn connect() -> Result<Client, Box<dyn std::error::Error>> {
208        let db_url = env::var("DATABASE_URL")
209            .unwrap_or_else(|_| "host=localhost user=postgres dbname=byteorm".to_string());
210
211        let (client, connection) = tokio_postgres::connect(&db_url, tokio_postgres::tls::NoTls).await?;
212
213        tokio::spawn(async move {
214            if let Err(e) = connection.await {
215                eprintln!("Connection error: {}", e);
216            }
217        });
218
219        Ok(client)
220    }
221
222    pub async fn execute_sql(client: &Client, sql: &str) -> Result<(), Box<dyn std::error::Error>> {
223        for statement in sql.split(';').filter(|s| !s.trim().is_empty()) {
224            let normalized = statement
225                .lines()
226                .map(|l| l.trim())
227                .filter(|l| !l.is_empty())
228                .collect::<Vec<_>>()
229                .join(" ");
230
231            println!("Executing: {}", normalized.trim());
232            match client.execute(&normalized, &[]).await {
233                Ok(_) => println!("  ✅ OK"),
234                Err(e) => {
235                    eprintln!("  ❌ Error: {}", e);
236                    eprintln!("  📍 Error details:");
237                    eprintln!("     Code: {:?}", e.code());
238                    eprintln!("     Message: {}", e);
239                    return Err(e.into());
240                }
241            }
242        }
243        Ok(())
244    }
245
246}
247