1use nom::{
14 IResult, Parser,
15 branch::alt,
16 bytes::complete::{tag, tag_no_case, take_while1},
17 character::complete::{char, multispace1, not_line_ending},
18 combinator::{map, opt},
19 multi::{many0, separated_list0},
20 sequence::preceded,
21};
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct Schema {
27 #[serde(default)]
29 pub version: Option<u32>,
30 pub tables: Vec<TableDef>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct TableDef {
35 pub name: String,
36 pub columns: Vec<ColumnDef>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ColumnDef {
41 pub name: String,
42 #[serde(rename = "type", alias = "typ")]
43 pub typ: String,
44 #[serde(default)]
46 pub is_array: bool,
47 #[serde(default)]
49 pub type_params: Option<Vec<String>>,
50 #[serde(default)]
51 pub nullable: bool,
52 #[serde(default)]
53 pub primary_key: bool,
54 #[serde(default)]
55 pub unique: bool,
56 #[serde(default)]
57 pub references: Option<String>,
58 #[serde(default)]
59 pub default_value: Option<String>,
60 #[serde(default)]
62 pub check: Option<String>,
63 #[serde(default)]
65 pub is_serial: bool,
66}
67
68impl Default for ColumnDef {
69 fn default() -> Self {
70 Self {
71 name: String::new(),
72 typ: String::new(),
73 is_array: false,
74 type_params: None,
75 nullable: true,
76 primary_key: false,
77 unique: false,
78 references: None,
79 default_value: None,
80 check: None,
81 is_serial: false,
82 }
83 }
84}
85
86impl Schema {
87 pub fn parse(input: &str) -> Result<Self, String> {
89 match parse_schema(input) {
90 Ok(("", schema)) => Ok(schema),
91 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
92 Err(e) => Err(format!("Parse error: {:?}", e)),
93 }
94 }
95
96 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
98 self.tables
99 .iter()
100 .find(|t| t.name.eq_ignore_ascii_case(name))
101 }
102
103 pub fn to_json(&self) -> Result<String, String> {
105 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization failed: {}", e))
106 }
107
108 pub fn from_json(json: &str) -> Result<Self, String> {
110 serde_json::from_str(json).map_err(|e| format!("JSON deserialization failed: {}", e))
111 }
112
113 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
115 let content =
116 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
117
118 if content.trim().starts_with('{') {
119 Self::from_json(&content)
120 } else {
121 Self::parse(&content)
122 }
123 }
124}
125
126impl TableDef {
127 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
129 self.columns
130 .iter()
131 .find(|c| c.name.eq_ignore_ascii_case(name))
132 }
133
134 pub fn to_ddl(&self) -> String {
136 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.name);
137
138 let mut col_defs = Vec::new();
139 for col in &self.columns {
140 let mut line = format!(" {}", col.name);
141
142 let mut typ = col.typ.to_uppercase();
144 if let Some(params) = &col.type_params {
145 typ = format!("{}({})", typ, params.join(", "));
146 }
147 if col.is_array {
148 typ.push_str("[]");
149 }
150 line.push_str(&format!(" {}", typ));
151
152 if col.primary_key {
154 line.push_str(" PRIMARY KEY");
155 }
156 if !col.nullable && !col.primary_key && !col.is_serial {
157 line.push_str(" NOT NULL");
158 }
159 if col.unique && !col.primary_key {
160 line.push_str(" UNIQUE");
161 }
162 if let Some(ref default) = col.default_value {
163 line.push_str(&format!(" DEFAULT {}", default));
164 }
165 if let Some(ref refs) = col.references {
166 line.push_str(&format!(" REFERENCES {}", refs));
167 }
168 if let Some(ref check) = col.check {
169 line.push_str(&format!(" CHECK({})", check));
170 }
171
172 col_defs.push(line);
173 }
174
175 sql.push_str(&col_defs.join(",\n"));
176 sql.push_str("\n)");
177 sql
178 }
179}
180
181fn identifier(input: &str) -> IResult<&str, &str> {
187 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
188}
189
190fn ws_and_comments(input: &str) -> IResult<&str, ()> {
192 let (input, _) = many0(alt((
193 map(multispace1, |_| ()),
194 map((tag("--"), not_line_ending), |_| ()),
195 )))
196 .parse(input)?;
197 Ok((input, ()))
198}
199
/// Intermediate result of `parse_type_info`, later copied into a `ColumnDef`.
struct TypeInfo {
    /// Lowercased base type name.
    name: String,
    /// Trimmed parenthesised parameters, if the type had any.
    params: Option<Vec<String>>,
    /// Type was followed by `[]`.
    is_array: bool,
    /// Type is one of the serial variants.
    is_serial: bool,
}
206
207fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
210 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
211
212 let (input, params) = if input.starts_with('(') {
213 let paren_start = 1;
214 let mut paren_end = paren_start;
215 for (i, c) in input[paren_start..].char_indices() {
216 if c == ')' {
217 paren_end = paren_start + i;
218 break;
219 }
220 }
221 let param_str = &input[paren_start..paren_end];
222 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
223 (&input[paren_end + 1..], Some(params))
224 } else {
225 (input, None)
226 };
227
228 let (input, is_array) = if let Some(stripped) = input.strip_prefix("[]") {
229 (stripped, true)
230 } else {
231 (input, false)
232 };
233
234 let lower = type_name.to_lowercase();
235 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
236
237 Ok((
238 input,
239 TypeInfo {
240 name: lower,
241 params,
242 is_array,
243 is_serial,
244 },
245 ))
246}
247
248fn constraint_text(input: &str) -> IResult<&str, &str> {
250 let mut paren_depth = 0;
251 let mut end = 0;
252
253 for (i, c) in input.char_indices() {
254 match c {
255 '(' => paren_depth += 1,
256 ')' => {
257 if paren_depth == 0 {
258 break; }
260 paren_depth -= 1;
261 }
262 ',' if paren_depth == 0 => break,
263 '\n' | '\r' if paren_depth == 0 => break,
264 _ => {}
265 }
266 end = i + c.len_utf8();
267 }
268
269 if end == 0 {
270 Err(nom::Err::Error(nom::error::Error::new(
271 input,
272 nom::error::ErrorKind::TakeWhile1,
273 )))
274 } else {
275 Ok((&input[end..], &input[..end]))
276 }
277}
278
279fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
281 let (input, _) = ws_and_comments(input)?;
282 let (input, name) = identifier(input)?;
283 let (input, _) = multispace1(input)?;
284 let (input, type_info) = parse_type_info(input)?;
285
286 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
287
288 let mut col = ColumnDef {
289 name: name.to_string(),
290 typ: type_info.name,
291 is_array: type_info.is_array,
292 type_params: type_info.params,
293 is_serial: type_info.is_serial,
294 nullable: !type_info.is_serial, ..Default::default()
296 };
297
298 if let Some(constraints) = constraint_str {
299 let lower = constraints.to_lowercase();
300
301 if lower.contains("primary_key") || lower.contains("primary key") {
302 col.primary_key = true;
303 col.nullable = false;
304 }
305 if lower.contains("not_null") || lower.contains("not null") {
306 col.nullable = false;
307 }
308 if lower.contains("unique") {
309 col.unique = true;
310 }
311
312 if let Some(idx) = lower.find("references ") {
313 let rest = &constraints[idx + 11..];
314 let mut paren_depth = 0;
316 let mut end = rest.len();
317 for (i, c) in rest.char_indices() {
318 match c {
319 '(' => paren_depth += 1,
320 ')' => {
321 if paren_depth == 0 {
322 end = i;
323 break;
324 }
325 paren_depth -= 1;
326 }
327 c if c.is_whitespace() && paren_depth == 0 => {
328 end = i;
329 break;
330 }
331 _ => {}
332 }
333 }
334 col.references = Some(rest[..end].to_string());
335 }
336
337 if let Some(idx) = lower.find("default ") {
338 let rest = &constraints[idx + 8..];
339 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
340 col.default_value = Some(rest[..end].to_string());
341 }
342
343 if let Some(idx) = lower.find("check(") {
344 let rest = &constraints[idx + 6..];
345 let mut depth = 1;
347 let mut end = rest.len();
348 for (i, c) in rest.char_indices() {
349 match c {
350 '(' => depth += 1,
351 ')' => {
352 depth -= 1;
353 if depth == 0 {
354 end = i;
355 break;
356 }
357 }
358 _ => {}
359 }
360 }
361 col.check = Some(rest[..end].to_string());
362 }
363 }
364
365 Ok((input, col))
366}
367
368fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
370 let (input, _) = ws_and_comments(input)?;
371 let (input, _) = char('(').parse(input)?;
372 let (input, columns) = separated_list0(char(','), parse_column).parse(input)?;
373 let (input, _) = ws_and_comments(input)?;
374 let (input, _) = char(')').parse(input)?;
375
376 Ok((input, columns))
377}
378
379fn parse_table(input: &str) -> IResult<&str, TableDef> {
381 let (input, _) = ws_and_comments(input)?;
382 let (input, _) = tag_no_case("table").parse(input)?;
383 let (input, _) = multispace1(input)?;
384 let (input, name) = identifier(input)?;
385 let (input, columns) = parse_column_list(input)?;
386
387 Ok((
388 input,
389 TableDef {
390 name: name.to_string(),
391 columns,
392 },
393 ))
394}
395
396fn parse_schema(input: &str) -> IResult<&str, Schema> {
398 let version = extract_version_directive(input);
400
401 let (input, _) = ws_and_comments(input)?;
402 let (input, tables) = many0(parse_table).parse(input)?;
403 let (input, _) = ws_and_comments(input)?;
404
405 Ok((input, Schema { version, tables }))
406}
407
/// Finds the first `-- qail: version=N` directive and parses `N`.
///
/// Lines are trimmed before matching. The first `-- qail:` line whose payload
/// starts with `version=` decides the result. If its value is not a valid
/// `u32`, the result is `None` and later directives are ignored.
fn extract_version_directive(input: &str) -> Option<u32> {
    input
        .lines()
        .filter_map(|line| line.trim().strip_prefix("-- qail:"))
        .find_map(|directive| directive.trim().strip_prefix("version="))
        .and_then(|raw| raw.trim().parse().ok())
}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking with "parse failed" on error.
    fn parse_ok(input: &str) -> Schema {
        Schema::parse(input).expect("parse failed")
    }

    /// Fetches a table the test expects to exist.
    fn table<'a>(schema: &'a Schema, name: &str) -> &'a TableDef {
        schema
            .find_table(name)
            .unwrap_or_else(|| panic!("{} not found", name))
    }

    /// Fetches a column the test expects to exist.
    fn column<'a>(table: &'a TableDef, name: &str) -> &'a ColumnDef {
        table
            .find_column(name)
            .unwrap_or_else(|| panic!("{} not found", name))
    }

    #[test]
    fn test_parse_simple_table() {
        let schema = parse_ok(
            r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        assert!(users.columns[2].nullable);
    }

    #[test]
    fn test_parse_multiple_tables() {
        let schema = parse_ok(
            r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 2);

        let orders = table(&schema, "orders");
        assert_eq!(
            column(orders, "user_id").references,
            Some("users(id)".to_string())
        );
        assert_eq!(
            column(orders, "total").default_value,
            Some("0".to_string())
        );
    }

    #[test]
    fn test_parse_comments() {
        let schema = parse_ok(
            r#"
            -- This is a comment
            table foo (
                bar text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let schema = parse_ok(
            r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
            "#,
        );
        let products = &schema.tables[0];

        let tags = column(products, "tags");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        assert!(column(products, "prices").is_array);
    }

    #[test]
    fn test_type_params() {
        let schema = parse_ok(
            r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
            "#,
        );
        let items = &schema.tables[0];

        let id = column(items, "id");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = column(items, "name");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        assert_eq!(
            column(items, "price").type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        assert!(column(items, "code").unique);
    }

    #[test]
    fn test_check_constraint() {
        let schema = parse_ok(
            r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
            "#,
        );
        let employees = &schema.tables[0];

        assert_eq!(
            column(employees, "age").check,
            Some("age >= 18".to_string())
        );
        assert_eq!(
            column(employees, "salary").check,
            Some("salary > 0".to_string())
        );
    }

    #[test]
    fn test_version_directive() {
        let versioned = parse_ok(
            r#"
            -- qail: version=1
            table users (
                id uuid primary_key
            )
            "#,
        );
        assert_eq!(versioned.version, Some(1));
        assert_eq!(versioned.tables.len(), 1);

        let unversioned = parse_ok(
            r#"
            table items (
                id uuid primary_key
            )
            "#,
        );
        assert_eq!(unversioned.version, None);
    }
}