1use nom::{
14 IResult, Parser,
15 branch::alt,
16 bytes::complete::{tag, tag_no_case, take_while1},
17 character::complete::{char, multispace1, not_line_ending},
18 combinator::{map, opt},
19 multi::{many0, separated_list0},
20 sequence::preceded,
21};
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct Schema {
27 pub tables: Vec<TableDef>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct TableDef {
33 pub name: String,
34 pub columns: Vec<ColumnDef>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ColumnDef {
40 pub name: String,
41 #[serde(rename = "type", alias = "typ")]
42 pub typ: String,
43 #[serde(default)]
45 pub is_array: bool,
46 #[serde(default)]
48 pub type_params: Option<Vec<String>>,
49 #[serde(default)]
50 pub nullable: bool,
51 #[serde(default)]
52 pub primary_key: bool,
53 #[serde(default)]
54 pub unique: bool,
55 #[serde(default)]
56 pub references: Option<String>,
57 #[serde(default)]
58 pub default_value: Option<String>,
59 #[serde(default)]
61 pub check: Option<String>,
62 #[serde(default)]
64 pub is_serial: bool,
65}
66
67impl Default for ColumnDef {
68 fn default() -> Self {
69 Self {
70 name: String::new(),
71 typ: String::new(),
72 is_array: false,
73 type_params: None,
74 nullable: true,
75 primary_key: false,
76 unique: false,
77 references: None,
78 default_value: None,
79 check: None,
80 is_serial: false,
81 }
82 }
83}
84
85impl Schema {
86 pub fn parse(input: &str) -> Result<Self, String> {
88 match parse_schema(input) {
89 Ok(("", schema)) => Ok(schema),
90 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
91 Err(e) => Err(format!("Parse error: {:?}", e)),
92 }
93 }
94
95 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
97 self.tables
98 .iter()
99 .find(|t| t.name.eq_ignore_ascii_case(name))
100 }
101
102 pub fn to_json(&self) -> Result<String, String> {
104 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
105 }
106
107 pub fn from_json(json: &str) -> Result<Self, String> {
109 serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
110 }
111
112 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
114 let content =
115 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
116
117 if content.trim().starts_with('{') {
119 Self::from_json(&content)
120 } else {
121 Self::parse(&content)
122 }
123 }
124}
125
126impl TableDef {
127 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
129 self.columns
130 .iter()
131 .find(|c| c.name.eq_ignore_ascii_case(name))
132 }
133
134 pub fn to_ddl(&self) -> String {
136 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
137
138 let mut col_defs = Vec::new();
139 for col in &self.columns {
140 let mut line = format!(" {}", col.name);
141
142 let mut typ = col.typ.to_uppercase();
144 if let Some(params) = &col.type_params {
145 typ = format!("{}({})", typ, params.join(", "));
146 }
147 if col.is_array {
148 typ.push_str("[]");
149 }
150 line.push_str(&format!(" {}", typ));
151
152 if col.primary_key {
154 line.push_str(" PRIMARY KEY");
155 }
156 if !col.nullable && !col.primary_key && !col.is_serial {
157 line.push_str(" NOT NULL");
158 }
159 if col.unique && !col.primary_key {
160 line.push_str(" UNIQUE");
161 }
162 if let Some(ref default) = col.default_value {
163 line.push_str(&format!(" DEFAULT {}", default));
164 }
165 if let Some(ref refs) = col.references {
166 line.push_str(&format!(" REFERENCES {}", refs));
167 }
168 if let Some(ref check) = col.check {
169 line.push_str(&format!(" CHECK({})", check));
170 }
171
172 col_defs.push(line);
173 }
174
175 sql.push_str(&col_defs.join(",\n"));
176 sql.push_str("\n)");
177 sql
178 }
179}
180
181fn identifier(input: &str) -> IResult<&str, &str> {
187 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
188}
189
190fn ws_and_comments(input: &str) -> IResult<&str, ()> {
192 let (input, _) = many0(alt((
193 map(multispace1, |_| ()),
194 map((tag("--"), not_line_ending), |_| ()),
195 )))
196 .parse(input)?;
197 Ok((input, ()))
198}
199
/// Pieces of a column type as parsed by `parse_type_info`, e.g. `varchar(255)` or `text[]`.
struct TypeInfo {
    /// Base type name, lowercased.
    name: String,
    /// Trimmed, comma-separated parameters from a `(...)` suffix, if one was present.
    params: Option<Vec<String>>,
    /// Whether the type ended in `[]`.
    is_array: bool,
    /// Whether the base type is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
207
208fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
211 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
213
214 let (input, params) = if input.starts_with('(') {
216 let paren_start = 1;
217 let mut paren_end = paren_start;
218 for (i, c) in input[paren_start..].char_indices() {
219 if c == ')' {
220 paren_end = paren_start + i;
221 break;
222 }
223 }
224 let param_str = &input[paren_start..paren_end];
225 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
226 (&input[paren_end + 1..], Some(params))
227 } else {
228 (input, None)
229 };
230
231 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
233 (stripped, true)
234 } else {
235 (input, false)
236 };
237
238 let lower = type_name.to_lowercase();
239 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
240
241 Ok((
242 input,
243 TypeInfo {
244 name: lower,
245 params,
246 is_array,
247 is_serial,
248 },
249 ))
250}
251
252fn constraint_text(input: &str) -> IResult<&str, &str> {
254 let mut paren_depth = 0;
255 let mut end = 0;
256
257 for (i, c) in input.char_indices() {
258 match c {
259 '(' => paren_depth += 1,
260 ')' => {
261 if paren_depth == 0 {
262 break; }
264 paren_depth -= 1;
265 }
266 ',' if paren_depth == 0 => break,
267 '\n' | '\r' if paren_depth == 0 => break,
268 _ => {}
269 }
270 end = i + c.len_utf8();
271 }
272
273 if end == 0 {
274 Err(nom::Err::Error(nom::error::Error::new(
275 input,
276 nom::error::ErrorKind::TakeWhile1,
277 )))
278 } else {
279 Ok((&input[end..], &input[..end]))
280 }
281}
282
283fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
285 let (input, _) = ws_and_comments(input)?;
286 let (input, name) = identifier(input)?;
287 let (input, _) = multispace1(input)?;
288 let (input, type_info) = parse_type_info(input)?;
289
290 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
292
293 let mut col = ColumnDef {
295 name: name.to_string(),
296 typ: type_info.name,
297 is_array: type_info.is_array,
298 type_params: type_info.params,
299 is_serial: type_info.is_serial,
300 nullable: !type_info.is_serial, ..Default::default()
302 };
303
304 if let Some(constraints) = constraint_str {
305 let lower = constraints.to_lowercase();
306
307 if lower.contains("primary_key") || lower.contains("primary key") {
308 col.primary_key = true;
309 col.nullable = false;
310 }
311 if lower.contains("not_null") || lower.contains("not null") {
312 col.nullable = false;
313 }
314 if lower.contains("unique") {
315 col.unique = true;
316 }
317
318 if let Some(idx) = lower.find("references ") {
320 let rest = &constraints[idx + 11..];
321 let mut paren_depth = 0;
323 let mut end = rest.len();
324 for (i, c) in rest.char_indices() {
325 match c {
326 '(' => paren_depth += 1,
327 ')' => {
328 if paren_depth == 0 {
329 end = i;
330 break;
331 }
332 paren_depth -= 1;
333 }
334 c if c.is_whitespace() && paren_depth == 0 => {
335 end = i;
336 break;
337 }
338 _ => {}
339 }
340 }
341 col.references = Some(rest[..end].to_string());
342 }
343
344 if let Some(idx) = lower.find("default ") {
346 let rest = &constraints[idx + 8..];
347 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
348 col.default_value = Some(rest[..end].to_string());
349 }
350
351 if let Some(idx) = lower.find("check(") {
353 let rest = &constraints[idx + 6..];
354 let mut depth = 1;
356 let mut end = rest.len();
357 for (i, c) in rest.char_indices() {
358 match c {
359 '(' => depth += 1,
360 ')' => {
361 depth -= 1;
362 if depth == 0 {
363 end = i;
364 break;
365 }
366 }
367 _ => {}
368 }
369 }
370 col.check = Some(rest[..end].to_string());
371 }
372 }
373
374 Ok((input, col))
375}
376
377fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
379 let (input, _) = ws_and_comments(input)?;
380 let (input, _) = char('(').parse(input)?;
381 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
382 let (input, _) = ws_and_comments(input)?;
383 let (input, _) = char(')').parse(input)?;
384
385 Ok((input, columns))
386}
387
388fn parse_table(input: &str) -> IResult<&str, TableDef> {
390 let (input, _) = ws_and_comments(input)?;
391 let (input, _) = tag_no_case("table").parse(input)?;
392 let (input, _) = multispace1(input)?;
393 let (input, name) = identifier(input)?;
394 let (input, columns) = parse_column_list(input)?;
395
396 Ok((
397 input,
398 TableDef {
399 name: name.to_string(),
400 columns,
401 },
402 ))
403}
404
405fn parse_schema(input: &str) -> IResult<&str, Schema> {
407 let (input, _) = ws_and_comments(input)?;
408 let (input, tables) = many0(parse_table).parse(input)?;
409 let (input, _) = ws_and_comments(input)?;
410
411 Ok((input, Schema { tables }))
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, failing the test if the schema does not parse.
    fn parse_ok(src: &str) -> Schema {
        Schema::parse(src).expect("parse failed")
    }

    #[test]
    fn test_parse_simple_table() {
        let schema = parse_ok(
            r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let (id, email, name) = (&users.columns[0], &users.columns[1], &users.columns[2]);

        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        assert!(name.nullable);
    }

    #[test]
    fn test_parse_multiple_tables() {
        let schema = parse_ok(
            r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");

        let user_id = orders.find_column("user_id").expect("user_id not found");
        assert_eq!(user_id.references.as_deref(), Some("users(id)"));

        let total = orders.find_column("total").expect("total not found");
        assert_eq!(total.default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_parse_comments() {
        let schema = parse_ok(
            r#"
            -- This is a comment
            table foo (
                bar text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let schema = parse_ok(
            r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
            "#,
        );
        let products = &schema.tables[0];

        let tags = products.find_column("tags").expect("tags not found");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = products.find_column("prices").expect("prices not found");
        assert!(prices.is_array);
    }

    #[test]
    fn test_type_params() {
        let schema = parse_ok(
            r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
            "#,
        );
        let items = &schema.tables[0];

        let id = items.find_column("id").expect("id not found");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = items.find_column("name").expect("name not found");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = items.find_column("price").expect("price not found");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = items.find_column("code").expect("code not found");
        assert!(code.unique);
    }

    #[test]
    fn test_check_constraint() {
        let schema = parse_ok(
            r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
            "#,
        );
        let employees = &schema.tables[0];

        let age = employees.find_column("age").expect("age not found");
        assert_eq!(age.check.as_deref(), Some("age >= 18"));

        let salary = employees.find_column("salary").expect("salary not found");
        assert_eq!(salary.check.as_deref(), Some("salary > 0"));
    }
}