1use nom::{
14 branch::alt,
15 bytes::complete::{tag, tag_no_case, take_while1},
16 character::complete::{multispace1, char, not_line_ending},
17 combinator::{opt, map},
18 multi::{separated_list0, many0},
19 sequence::{preceded},
20 Parser,
21 IResult,
22};
23use serde::{Deserialize, Serialize};
24
25#[derive(Debug, Clone, Serialize, Deserialize, Default)]
27pub struct Schema {
28 pub tables: Vec<TableDef>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TableDef {
34 pub name: String,
35 pub columns: Vec<ColumnDef>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ColumnDef {
41 pub name: String,
42 #[serde(rename = "type", alias = "typ")]
43 pub typ: String,
44 #[serde(default)]
46 pub is_array: bool,
47 #[serde(default)]
49 pub type_params: Option<Vec<String>>,
50 #[serde(default)]
51 pub nullable: bool,
52 #[serde(default)]
53 pub primary_key: bool,
54 #[serde(default)]
55 pub unique: bool,
56 #[serde(default)]
57 pub references: Option<String>,
58 #[serde(default)]
59 pub default_value: Option<String>,
60 #[serde(default)]
62 pub check: Option<String>,
63 #[serde(default)]
65 pub is_serial: bool,
66}
67
68impl Default for ColumnDef {
69 fn default() -> Self {
70 Self {
71 name: String::new(),
72 typ: String::new(),
73 is_array: false,
74 type_params: None,
75 nullable: true,
76 primary_key: false,
77 unique: false,
78 references: None,
79 default_value: None,
80 check: None,
81 is_serial: false,
82 }
83 }
84}
85
86impl Schema {
87 pub fn parse(input: &str) -> Result<Self, String> {
89 match parse_schema(input) {
90 Ok(("", schema)) => Ok(schema),
91 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
92 Err(e) => Err(format!("Parse error: {:?}", e)),
93 }
94 }
95
96 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
98 self.tables.iter().find(|t| t.name.eq_ignore_ascii_case(name))
99 }
100
101 pub fn to_json(&self) -> Result<String, String> {
103 serde_json::to_string_pretty(self)
104 .map_err(|e| format!("JSON serialization failed: {}", e))
105 }
106
107 pub fn from_json(json: &str) -> Result<Self, String> {
109 serde_json::from_str(json)
110 .map_err(|e| format!("JSON deserialization failed: {}", e))
111 }
112
113 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
115 let content = std::fs::read_to_string(path)
116 .map_err(|e| format!("Failed to read file: {}", e))?;
117
118 if content.trim().starts_with('{') {
120 Self::from_json(&content)
121 } else {
122 Self::parse(&content)
123 }
124 }
125}
126
127impl TableDef {
128 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
130 self.columns.iter().find(|c| c.name.eq_ignore_ascii_case(name))
131 }
132
133 pub fn to_ddl(&self) -> String {
135 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
136
137 let mut col_defs = Vec::new();
138 for col in &self.columns {
139 let mut line = format!(" {}", col.name);
140
141 let mut typ = col.typ.to_uppercase();
143 if let Some(params) = &col.type_params {
144 typ = format!("{}({})", typ, params.join(", "));
145 }
146 if col.is_array {
147 typ.push_str("[]");
148 }
149 line.push_str(&format!(" {}", typ));
150
151 if col.primary_key {
153 line.push_str(" PRIMARY KEY");
154 }
155 if !col.nullable && !col.primary_key && !col.is_serial {
156 line.push_str(" NOT NULL");
157 }
158 if col.unique && !col.primary_key {
159 line.push_str(" UNIQUE");
160 }
161 if let Some(ref default) = col.default_value {
162 line.push_str(&format!(" DEFAULT {}", default));
163 }
164 if let Some(ref refs) = col.references {
165 line.push_str(&format!(" REFERENCES {}", refs));
166 }
167 if let Some(ref check) = col.check {
168 line.push_str(&format!(" CHECK({})", check));
169 }
170
171 col_defs.push(line);
172 }
173
174 sql.push_str(&col_defs.join(",\n"));
175 sql.push_str("\n)");
176 sql
177 }
178}
179
180fn identifier(input: &str) -> IResult<&str, &str> {
186 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
187}
188
189fn ws_and_comments(input: &str) -> IResult<&str, ()> {
191 let (input, _) = many0(alt((
192 map(multispace1, |_| ()),
193 map((tag("--"), not_line_ending), |_| ()),
194 ))).parse(input)?;
195 Ok((input, ()))
196}
197
/// Intermediate result of [`parse_type_info`], copied into a `ColumnDef`.
struct TypeInfo {
    /// Lower-cased base type name, e.g. `varchar`.
    name: String,
    /// `true` when the type is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
    /// Comma-separated parameters from `(...)`, trimmed; `None` if there were no parens.
    params: Option<Vec<String>>,
    /// `true` when the type was suffixed with `[]`.
    is_array: bool,
}
205
206fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
209 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
211
212 let (input, params) = if input.starts_with('(') {
214 let paren_start = 1;
215 let mut paren_end = paren_start;
216 for (i, c) in input[paren_start..].char_indices() {
217 if c == ')' {
218 paren_end = paren_start + i;
219 break;
220 }
221 }
222 let param_str = &input[paren_start..paren_end];
223 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
224 (&input[paren_end + 1..], Some(params))
225 } else {
226 (input, None)
227 };
228
229 let (input, is_array) = if input.starts_with("[]") {
231 (&input[2..], true)
232 } else {
233 (input, false)
234 };
235
236 let lower = type_name.to_lowercase();
237 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
238
239 Ok((input, TypeInfo {
240 name: lower,
241 params,
242 is_array,
243 is_serial,
244 }))
245}
246
247fn constraint_text(input: &str) -> IResult<&str, &str> {
249 let mut paren_depth = 0;
250 let mut end = 0;
251
252 for (i, c) in input.char_indices() {
253 match c {
254 '(' => paren_depth += 1,
255 ')' => {
256 if paren_depth == 0 {
257 break; }
259 paren_depth -= 1;
260 }
261 ',' if paren_depth == 0 => break,
262 '\n' | '\r' if paren_depth == 0 => break,
263 _ => {}
264 }
265 end = i + c.len_utf8();
266 }
267
268 if end == 0 {
269 Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::TakeWhile1)))
270 } else {
271 Ok((&input[end..], &input[..end]))
272 }
273}
274
275fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
277 let (input, _) = ws_and_comments(input)?;
278 let (input, name) = identifier(input)?;
279 let (input, _) = multispace1(input)?;
280 let (input, type_info) = parse_type_info(input)?;
281
282 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
284
285 let mut col = ColumnDef {
287 name: name.to_string(),
288 typ: type_info.name,
289 is_array: type_info.is_array,
290 type_params: type_info.params,
291 is_serial: type_info.is_serial,
292 nullable: !type_info.is_serial, ..Default::default()
294 };
295
296 if let Some(constraints) = constraint_str {
297 let lower = constraints.to_lowercase();
298
299 if lower.contains("primary_key") || lower.contains("primary key") {
300 col.primary_key = true;
301 col.nullable = false;
302 }
303 if lower.contains("not_null") || lower.contains("not null") {
304 col.nullable = false;
305 }
306 if lower.contains("unique") {
307 col.unique = true;
308 }
309
310 if let Some(idx) = lower.find("references ") {
312 let rest = &constraints[idx + 11..];
313 let mut paren_depth = 0;
315 let mut end = rest.len();
316 for (i, c) in rest.char_indices() {
317 match c {
318 '(' => paren_depth += 1,
319 ')' => {
320 if paren_depth == 0 {
321 end = i;
322 break;
323 }
324 paren_depth -= 1;
325 }
326 c if c.is_whitespace() && paren_depth == 0 => {
327 end = i;
328 break;
329 }
330 _ => {}
331 }
332 }
333 col.references = Some(rest[..end].to_string());
334 }
335
336 if let Some(idx) = lower.find("default ") {
338 let rest = &constraints[idx + 8..];
339 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
340 col.default_value = Some(rest[..end].to_string());
341 }
342
343 if let Some(idx) = lower.find("check(") {
345 let rest = &constraints[idx + 6..];
346 let mut depth = 1;
348 let mut end = rest.len();
349 for (i, c) in rest.char_indices() {
350 match c {
351 '(' => depth += 1,
352 ')' => {
353 depth -= 1;
354 if depth == 0 {
355 end = i;
356 break;
357 }
358 }
359 _ => {}
360 }
361 }
362 col.check = Some(rest[..end].to_string());
363 }
364 }
365
366 Ok((input, col))
367}
368
369fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
371 let (input, _) = ws_and_comments(input)?;
372 let (input, _) = char('(').parse(input)?;
373 let (input, columns) = separated_list0(
374 char(','),
375 parse_column,
376 ).parse(input)?;
377 let (input, _) = ws_and_comments(input)?;
378 let (input, _) = char(')').parse(input)?;
379
380 Ok((input, columns))
381}
382
383fn parse_table(input: &str) -> IResult<&str, TableDef> {
385 let (input, _) = ws_and_comments(input)?;
386 let (input, _) = tag_no_case("table").parse(input)?;
387 let (input, _) = multispace1(input)?;
388 let (input, name) = identifier(input)?;
389 let (input, columns) = parse_column_list(input)?;
390
391 Ok((input, TableDef {
392 name: name.to_string(),
393 columns,
394 }))
395}
396
397fn parse_schema(input: &str) -> IResult<&str, Schema> {
399 let (input, _) = ws_and_comments(input)?;
400 let (input, tables) = many0(parse_table).parse(input)?;
401 let (input, _) = ws_and_comments(input)?;
402
403 Ok((input, Schema { tables }))
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a schema, failing the calling test if the parser rejects it.
    fn parse_ok(input: &str) -> Schema {
        Schema::parse(input).expect("parse failed")
    }

    #[test]
    fn test_parse_simple_table() {
        let schema = parse_ok(
            r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!((id.name.as_str(), id.typ.as_str()), ("id", "uuid"));
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        assert!(
            users.columns[2].nullable,
            "a column without constraints should stay nullable"
        );
    }

    #[test]
    fn test_parse_multiple_tables() {
        let schema = parse_ok(
            r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");

        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references.as_deref(), Some("users(id)"));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_parse_comments() {
        let schema = parse_ok(
            r#"
            -- This is a comment
            table foo (
                bar text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let schema = parse_ok(
            r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
            "#,
        );
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    #[test]
    fn test_type_params() {
        let schema = parse_ok(
            r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
            "#,
        );
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        let expected_price_params = vec!["10".to_string(), "2".to_string()];
        assert_eq!(price.type_params, Some(expected_price_params));

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    #[test]
    fn test_check_constraint() {
        let schema = parse_ok(
            r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
            "#,
        );
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check.as_deref(), Some("age >= 18"));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check.as_deref(), Some("salary > 0"));
    }
}