1use nom::{
14 branch::alt,
15 bytes::complete::{tag, tag_no_case, take_while1},
16 character::complete::{multispace1, char, not_line_ending},
17 combinator::{opt, map},
18 multi::{separated_list0, many0},
19 sequence::{preceded},
20 Parser,
21 IResult,
22};
23use serde::{Deserialize, Serialize};
24
25#[derive(Debug, Clone, Serialize, Deserialize, Default)]
27pub struct Schema {
28 pub tables: Vec<TableDef>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TableDef {
34 pub name: String,
35 pub columns: Vec<ColumnDef>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ColumnDef {
41 pub name: String,
42 #[serde(rename = "type", alias = "typ")]
43 pub typ: String,
44 #[serde(default)]
46 pub is_array: bool,
47 #[serde(default)]
49 pub type_params: Option<Vec<String>>,
50 #[serde(default)]
51 pub nullable: bool,
52 #[serde(default)]
53 pub primary_key: bool,
54 #[serde(default)]
55 pub unique: bool,
56 #[serde(default)]
57 pub references: Option<String>,
58 #[serde(default)]
59 pub default_value: Option<String>,
60 #[serde(default)]
62 pub check: Option<String>,
63 #[serde(default)]
65 pub is_serial: bool,
66}
67
68impl Default for ColumnDef {
69 fn default() -> Self {
70 Self {
71 name: String::new(),
72 typ: String::new(),
73 is_array: false,
74 type_params: None,
75 nullable: true,
76 primary_key: false,
77 unique: false,
78 references: None,
79 default_value: None,
80 check: None,
81 is_serial: false,
82 }
83 }
84}
85
86impl Schema {
87 pub fn parse(input: &str) -> Result<Self, String> {
89 match parse_schema(input) {
90 Ok(("", schema)) => Ok(schema),
91 Ok((remaining, _)) => Err(format!("Unexpected content: '{}'", remaining.trim())),
92 Err(e) => Err(format!("Parse error: {:?}", e)),
93 }
94 }
95
96 pub fn find_table(&self, name: &str) -> Option<&TableDef> {
98 self.tables.iter().find(|t| t.name.eq_ignore_ascii_case(name))
99 }
100
101 pub fn to_json(&self) -> Result<String, String> {
103 serde_json::to_string_pretty(self)
104 .map_err(|e| format!("JSON serialization failed: {}", e))
105 }
106
107 pub fn from_json(json: &str) -> Result<Self, String> {
109 serde_json::from_str(json)
110 .map_err(|e| format!("JSON deserialization failed: {}", e))
111 }
112
113 pub fn from_file(path: &std::path::Path) -> Result<Self, String> {
115 let content = std::fs::read_to_string(path)
116 .map_err(|e| format!("Failed to read file: {}", e))?;
117
118 if content.trim().starts_with('{') {
120 Self::from_json(&content)
121 } else {
122 Self::parse(&content)
123 }
124 }
125}
126
127impl TableDef {
128 pub fn find_column(&self, name: &str) -> Option<&ColumnDef> {
130 self.columns.iter().find(|c| c.name.eq_ignore_ascii_case(name))
131 }
132}
133
134fn identifier(input: &str) -> IResult<&str, &str> {
140 take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
141}
142
143fn ws_and_comments(input: &str) -> IResult<&str, ()> {
145 let (input, _) = many0(alt((
146 map(multispace1, |_| ()),
147 map((tag("--"), not_line_ending), |_| ()),
148 ))).parse(input)?;
149 Ok((input, ()))
150}
151
/// Intermediate result of parsing a column type, produced by `parse_type_info`.
struct TypeInfo {
    /// Lower-cased type name.
    name: String,
    /// Comma-separated, trimmed parameters from a `(...)` suffix, if present.
    params: Option<Vec<String>>,
    /// Whether the type was followed by `[]`.
    is_array: bool,
    /// Whether the name is `serial`, `bigserial` or `smallserial`.
    is_serial: bool,
}
159
160fn parse_type_info(input: &str) -> IResult<&str, TypeInfo> {
163 let (input, type_name) = take_while1(|c: char| c.is_alphanumeric()).parse(input)?;
165
166 let (input, params) = if input.starts_with('(') {
168 let paren_start = 1;
169 let mut paren_end = paren_start;
170 for (i, c) in input[paren_start..].char_indices() {
171 if c == ')' {
172 paren_end = paren_start + i;
173 break;
174 }
175 }
176 let param_str = &input[paren_start..paren_end];
177 let params: Vec<String> = param_str.split(',').map(|s| s.trim().to_string()).collect();
178 (&input[paren_end + 1..], Some(params))
179 } else {
180 (input, None)
181 };
182
183 let (input, is_array) = if input.starts_with("[]") {
185 (&input[2..], true)
186 } else {
187 (input, false)
188 };
189
190 let lower = type_name.to_lowercase();
191 let is_serial = lower == "serial" || lower == "bigserial" || lower == "smallserial";
192
193 Ok((input, TypeInfo {
194 name: lower,
195 params,
196 is_array,
197 is_serial,
198 }))
199}
200
201fn constraint_text(input: &str) -> IResult<&str, &str> {
203 let mut paren_depth = 0;
204 let mut end = 0;
205
206 for (i, c) in input.char_indices() {
207 match c {
208 '(' => paren_depth += 1,
209 ')' => {
210 if paren_depth == 0 {
211 break; }
213 paren_depth -= 1;
214 }
215 ',' if paren_depth == 0 => break,
216 '\n' | '\r' if paren_depth == 0 => break,
217 _ => {}
218 }
219 end = i + c.len_utf8();
220 }
221
222 if end == 0 {
223 Err(nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::TakeWhile1)))
224 } else {
225 Ok((&input[end..], &input[..end]))
226 }
227}
228
229fn parse_column(input: &str) -> IResult<&str, ColumnDef> {
231 let (input, _) = ws_and_comments(input)?;
232 let (input, name) = identifier(input)?;
233 let (input, _) = multispace1(input)?;
234 let (input, type_info) = parse_type_info(input)?;
235
236 let (input, constraint_str) = opt(preceded(multispace1, constraint_text)).parse(input)?;
238
239 let mut col = ColumnDef {
241 name: name.to_string(),
242 typ: type_info.name,
243 is_array: type_info.is_array,
244 type_params: type_info.params,
245 is_serial: type_info.is_serial,
246 nullable: !type_info.is_serial, ..Default::default()
248 };
249
250 if let Some(constraints) = constraint_str {
251 let lower = constraints.to_lowercase();
252
253 if lower.contains("primary_key") || lower.contains("primary key") {
254 col.primary_key = true;
255 col.nullable = false;
256 }
257 if lower.contains("not_null") || lower.contains("not null") {
258 col.nullable = false;
259 }
260 if lower.contains("unique") {
261 col.unique = true;
262 }
263
264 if let Some(idx) = lower.find("references ") {
266 let rest = &constraints[idx + 11..];
267 let mut paren_depth = 0;
269 let mut end = rest.len();
270 for (i, c) in rest.char_indices() {
271 match c {
272 '(' => paren_depth += 1,
273 ')' => {
274 if paren_depth == 0 {
275 end = i;
276 break;
277 }
278 paren_depth -= 1;
279 }
280 c if c.is_whitespace() && paren_depth == 0 => {
281 end = i;
282 break;
283 }
284 _ => {}
285 }
286 }
287 col.references = Some(rest[..end].to_string());
288 }
289
290 if let Some(idx) = lower.find("default ") {
292 let rest = &constraints[idx + 8..];
293 let end = rest.find(|c: char| c.is_whitespace()).unwrap_or(rest.len());
294 col.default_value = Some(rest[..end].to_string());
295 }
296
297 if let Some(idx) = lower.find("check(") {
299 let rest = &constraints[idx + 6..];
300 let mut depth = 1;
302 let mut end = rest.len();
303 for (i, c) in rest.char_indices() {
304 match c {
305 '(' => depth += 1,
306 ')' => {
307 depth -= 1;
308 if depth == 0 {
309 end = i;
310 break;
311 }
312 }
313 _ => {}
314 }
315 }
316 col.check = Some(rest[..end].to_string());
317 }
318 }
319
320 Ok((input, col))
321}
322
323fn parse_column_list(input: &str) -> IResult<&str, Vec<ColumnDef>> {
325 let (input, _) = ws_and_comments(input)?;
326 let (input, _) = char('(').parse(input)?;
327 let (input, columns) = separated_list0(
328 char(','),
329 parse_column,
330 ).parse(input)?;
331 let (input, _) = ws_and_comments(input)?;
332 let (input, _) = char(')').parse(input)?;
333
334 Ok((input, columns))
335}
336
337fn parse_table(input: &str) -> IResult<&str, TableDef> {
339 let (input, _) = ws_and_comments(input)?;
340 let (input, _) = tag_no_case("table").parse(input)?;
341 let (input, _) = multispace1(input)?;
342 let (input, name) = identifier(input)?;
343 let (input, columns) = parse_column_list(input)?;
344
345 Ok((input, TableDef {
346 name: name.to_string(),
347 columns,
348 }))
349}
350
351fn parse_schema(input: &str) -> IResult<&str, Schema> {
353 let (input, _) = ws_and_comments(input)?;
354 let (input, tables) = many0(parse_table).parse(input)?;
355 let (input, _) = ws_and_comments(input)?;
356
357 Ok((input, Schema { tables }))
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking with "parse failed" when the schema is rejected.
    fn parse_ok(input: &str) -> Schema {
        Schema::parse(input).expect("parse failed")
    }

    /// Fetches a column by name, panicking with "<name> not found" when absent.
    fn column<'a>(table: &'a TableDef, name: &str) -> &'a ColumnDef {
        table
            .find_column(name)
            .unwrap_or_else(|| panic!("{} not found", name))
    }

    /// Shorthand for an owned string inside an `Option`.
    fn some(text: &str) -> Option<String> {
        Some(text.to_string())
    }

    #[test]
    fn test_parse_simple_table() {
        let schema = parse_ok(
            r#"
            table users (
                id uuid primary_key,
                email text not null,
                name text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);

        let users = &schema.tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);

        // Primary keys are implicitly non-nullable.
        let id = &users.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.typ, "uuid");
        assert!(id.primary_key);
        assert!(!id.nullable);

        let email = &users.columns[1];
        assert_eq!(email.name, "email");
        assert!(!email.nullable);

        // Columns without constraints stay nullable.
        let name = &users.columns[2];
        assert!(name.nullable);
    }

    #[test]
    fn test_parse_multiple_tables() {
        let schema = parse_ok(
            r#"
            -- Users table
            table users (
                id uuid primary_key,
                email text not null unique
            )

            -- Orders table
            table orders (
                id uuid primary_key,
                user_id uuid references users(id),
                total i64 not null default 0
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 2);

        let orders = schema.find_table("orders").expect("orders not found");

        let user_id = column(orders, "user_id");
        assert_eq!(user_id.references, some("users(id)"));

        let total = column(orders, "total");
        assert_eq!(total.default_value, some("0"));
    }

    #[test]
    fn test_parse_comments() {
        let schema = parse_ok(
            r#"
            -- This is a comment
            table foo (
                bar text
            )
            "#,
        );
        assert_eq!(schema.tables.len(), 1);
    }

    #[test]
    fn test_array_types() {
        let schema = parse_ok(
            r#"
            table products (
                id uuid primary_key,
                tags text[],
                prices decimal[]
            )
            "#,
        );
        let products = &schema.tables[0];

        let tags = column(products, "tags");
        assert_eq!(tags.typ, "text");
        assert!(tags.is_array);

        let prices = column(products, "prices");
        assert!(prices.is_array);
    }

    #[test]
    fn test_type_params() {
        let schema = parse_ok(
            r#"
            table items (
                id serial primary_key,
                name varchar(255) not null,
                price decimal(10,2),
                code varchar(50) unique
            )
            "#,
        );
        let items = &schema.tables[0];

        // Serial columns are flagged and never nullable.
        let id = column(items, "id");
        assert!(id.is_serial);
        assert!(!id.nullable);

        let name = column(items, "name");
        assert_eq!(name.typ, "varchar");
        assert_eq!(name.type_params, Some(vec!["255".to_string()]));

        let price = column(items, "price");
        assert_eq!(
            price.type_params,
            Some(vec!["10".to_string(), "2".to_string()])
        );

        let code = column(items, "code");
        assert!(code.unique);
    }

    #[test]
    fn test_check_constraint() {
        let schema = parse_ok(
            r#"
            table employees (
                id uuid primary_key,
                age i32 check(age >= 18),
                salary decimal check(salary > 0)
            )
            "#,
        );
        let employees = &schema.tables[0];

        let age = column(employees, "age");
        assert_eq!(age.check, some("age >= 18"));

        let salary = column(employees, "salary");
        assert_eq!(salary.check, some("salary > 0"));
    }
}