1mod ast;
2mod parser;
3
4use pest::Parser;
5use pest_derive::Parser;
6
7
8pub mod rustgen;
/// Pest parser for the schema language.
///
/// The parser implementation is generated by the `pest_derive` `Parser` derive from the
/// rules in `grammar.pest`.
#[derive(Parser)]
#[grammar = "grammar.pest"]
pub struct SchemaParser;
12
13pub use ast::*;
14pub use parser::parse_schema;
15
16
17pub mod snapshot {
18 use tokio_postgres::Client;
19 use crate::Schema;
20
21 pub async fn init_snapshot_table(client: &Client) -> Result<(), Box<dyn std::error::Error>> {
22 let create_table = "CREATE TABLE IF NOT EXISTS _byteorm_schema ( id SERIAL PRIMARY KEY, schema_json JSONB NOT NULL, created_at TIMESTAMP DEFAULT now(), updated_at TIMESTAMP DEFAULT now() )";
23
24 client.execute(create_table, &[]).await?;
25 Ok(())
26 }
27
28 pub async fn save_snapshot(client: &Client, schema: &Schema) -> Result<(), Box<dyn std::error::Error>> {
29 let json_value = serde_json::to_value(schema)?;
30
31 client.execute(
32 "DELETE FROM _byteorm_schema WHERE id != 0",
33 &[],
34 ).await.ok();
35
36 client.execute(
37 "INSERT INTO _byteorm_schema (schema_json, updated_at) VALUES ($1, now())",
38 &[&json_value],
39 ).await?;
40
41 println!("Snapshot saved to database");
42 Ok(())
43 }
44
45
46 pub async fn load_snapshot(client: &Client) -> Result<Option<Schema>, Box<dyn std::error::Error>> {
47 let rows = client.query(
48 "SELECT schema_json FROM _byteorm_schema ORDER BY updated_at DESC LIMIT 1",
49 &[],
50 ).await?;
51
52 if rows.is_empty() {
53 return Ok(None);
54 }
55
56 let json_value: serde_json::Value = rows[0].get(0);
57 Ok(Some(serde_json::from_value(json_value)?))
58 }
59
60}
61
62
63
64pub mod diff {
65 use crate::{Schema, Model, Field, Modifier};
66
67 #[derive(Debug, Clone)]
68 pub enum Change {
69 CreateTable(Model),
70 AddColumn { table: String, field: Field },
71 RemoveColumn { table: String, column: String },
72 AlterColumn { table: String, column: String, old: Field, new: Field },
73 RemoveTable(String),
74 }
75
76 pub fn diff_schemas(previous: Option<&Schema>, current: &Schema) -> Vec<Change> {
77 let mut changes = Vec::new();
78
79 if let Some(prev) = previous {
80 for prev_model in &prev.models {
81 if !current.models.iter().any(|m| m.name == prev_model.name) {
82 changes.push(Change::RemoveTable(prev_model.name.clone()));
83 }
84 }
85
86 for curr_model in ¤t.models {
87 if let Some(prev_model) = prev.models.iter().find(|m| m.name == curr_model.name) {
88 for curr_field in &curr_model.fields {
89 if !prev_model.fields.iter().any(|f| f.name == curr_field.name) {
90 changes.push(Change::AddColumn {
91 table: curr_model.name.clone(),
92 field: curr_field.clone(),
93 });
94 }
95 }
96
97 for prev_field in &prev_model.fields {
98 if !curr_model.fields.iter().any(|f| f.name == prev_field.name) {
99 changes.push(Change::RemoveColumn {
100 table: curr_model.name.clone(),
101 column: prev_field.name.clone(),
102 });
103 }
104 }
105 }
106 }
107 } else {
108 for model in ¤t.models {
109 changes.push(Change::CreateTable(model.clone()));
110 }
111 }
112
113 changes
114 }
115}
116
117pub mod codegen {
118 use crate::{Field, Modifier, diff::Change};
119
120 pub fn postgres_type(type_name: &str) -> &'static str {
121 match type_name {
122 "BigInt" => "BIGINT",
123 "Int" => "INTEGER",
124 "String" => "TEXT",
125 "JsonB" => "JSONB",
126 "TimestamptZ" => "TIMESTAMP WITH TIME ZONE",
127 _ => "TEXT",
128 }
129 }
130
131 pub fn field_to_sql(field: &Field) -> String {
132 let mut sql = format!("{} {}", field.name, postgres_type(&field.type_name));
133
134 for modifier in &field.modifiers {
135 match modifier {
136 Modifier::PrimaryKey => sql.push_str(" PRIMARY KEY"),
137 Modifier::NotNull => sql.push_str(" NOT NULL"),
138 Modifier::Nullable => sql.push_str(" NULL"),
139 Modifier::Unique => sql.push_str(" UNIQUE"),
140 Modifier::ForeignKey { model, field } => {
141 let fk_field = field.as_deref().unwrap_or("id");
142 sql.push_str(&format!(" REFERENCES {} ({})", model, fk_field));
143 }
144 }
145 }
146
147 if field.is_sql_default() {
148 if let Some(value) = field.get_default_value() {
149 if value.contains("(") && value.contains(")") {
150 sql.push_str(&format!(" DEFAULT {}", value));
151 } else {
152 sql.push_str(&format!(" DEFAULT '{}'", value));
153 }
154 }
155 }
156
157 sql
158 }
159
160
161
162 pub fn change_to_sql(change: &Change) -> String {
163 match change {
164 Change::CreateTable(model) => {
165 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} ( ", model.name);
166 let fields_count = model.fields.len();
167 for (idx, field) in model.fields.iter().enumerate() {
168 sql.push_str(&field_to_sql(field));
169 if idx < fields_count - 1 {
170 sql.push_str(", ");
171 } else {
172 sql.push_str(" ");
173 }
174 }
175 sql.push_str(");");
176 sql
177 }
178
179 Change::AddColumn { table, field } => {
180 format!("ALTER TABLE {} ADD COLUMN {};", table, field_to_sql(field))
181 }
182 Change::RemoveColumn { table, column } => {
183 format!("ALTER TABLE {} DROP COLUMN {};", table, column)
184 }
185 Change::RemoveTable(name) => {
186 format!("DROP TABLE IF EXISTS {};", name)
187 }
188 _ => String::new(),
189 }
190 }
191
192
193 pub fn generate_migration_sql(changes: &[Change]) -> String {
194 let mut sql = String::new();
195 for change in changes {
196 sql.push_str(&change_to_sql(change));
197 }
198 sql
199 }
200}
201
202
203pub mod db {
204 use tokio_postgres::Client;
205 use std::env;
206
207 pub async fn connect() -> Result<Client, Box<dyn std::error::Error>> {
208 let db_url = env::var("DATABASE_URL")
209 .unwrap_or_else(|_| "host=localhost user=postgres dbname=byteorm".to_string());
210
211 let (client, connection) = tokio_postgres::connect(&db_url, tokio_postgres::tls::NoTls).await?;
212
213 tokio::spawn(async move {
214 if let Err(e) = connection.await {
215 eprintln!("Connection error: {}", e);
216 }
217 });
218
219 Ok(client)
220 }
221
222 pub async fn execute_sql(client: &Client, sql: &str) -> Result<(), Box<dyn std::error::Error>> {
223 for statement in sql.split(';').filter(|s| !s.trim().is_empty()) {
224 let normalized = statement
225 .lines()
226 .map(|l| l.trim())
227 .filter(|l| !l.is_empty())
228 .collect::<Vec<_>>()
229 .join(" ");
230
231 println!("Executing: {}", normalized.trim());
232 match client.execute(&normalized, &[]).await {
233 Ok(_) => println!(" ✅ OK"),
234 Err(e) => {
235 eprintln!(" ❌ Error: {}", e);
236 eprintln!(" 📍 Error details:");
237 eprintln!(" Code: {:?}", e.code());
238 eprintln!(" Message: {}", e);
239 return Err(e.into());
240 }
241 }
242 }
243 Ok(())
244 }
245
246}
247