Skip to main content

cqlite_core/schema/
cql_parser.rs

1//! CQL Schema Parser
2//!
3//! This module provides parsing capabilities for CQL CREATE TABLE statements
4//! to extract table schema information including table names, column definitions,
5//! partition keys, clustering keys, and type information.
6
7use crate::cql::{CqlCreateTable, CqlDataType};
8use crate::error::{Error, Result};
9use crate::parser::types::CqlTypeId;
10use crate::schema::{ClusteringColumn, Column, KeyColumn, TableSchema};
11use nom::{
12    branch::alt,
13    bytes::complete::{tag_no_case, take_while, take_while1},
14    character::complete::char,
15    combinator::{map, opt},
16    multi::{separated_list0, separated_list1},
17    sequence::{delimited, preceded, separated_pair, tuple},
18    IResult,
19};
20use serde_json;
21use std::collections::HashMap;
22
23/// CQL keyword parser - case insensitive
24fn keyword(s: &str) -> impl Fn(&str) -> IResult<&str, &str> + '_ {
25    move |input| tag_no_case(s)(input)
26}
27
28/// Parse whitespace and comments
29fn ws(input: &str) -> IResult<&str, &str> {
30    take_while(|c: char| c.is_whitespace())(input)
31}
32
33/// Parse mandatory whitespace
34fn ws1(input: &str) -> IResult<&str, &str> {
35    take_while1(|c: char| c.is_whitespace())(input)
36}
37
38/// Parse identifier (table name, column name, etc.)
39fn identifier(input: &str) -> IResult<&str, String> {
40    let (input, name) = alt((
41        // Quoted identifier
42        delimited(char('"'), take_while1(|c: char| c != '"'), char('"')),
43        // Unquoted identifier
44        take_while1(|c: char| c.is_alphanumeric() || c == '_'),
45    ))(input)?;
46
47    Ok((input, name.to_string()))
48}
49
50/// Parse a qualified table name (keyspace.table or just table)
51fn qualified_table_name(input: &str) -> IResult<&str, (Option<String>, String)> {
52    let (input, first) = identifier(input)?;
53    let (input, second) = opt(preceded(char('.'), identifier))(input)?;
54
55    match second {
56        Some(table) => Ok((input, (Some(first), table))),
57        None => Ok((input, (None, first))),
58    }
59}
60
61/// Parse CQL data type
62fn cql_type(input: &str) -> IResult<&str, String> {
63    // Handle complex types like list<text>, map<text, bigint>, frozen<set<uuid>>
64    fn parse_type_inner(input: &str) -> IResult<&str, String> {
65        let (input, base) = alt((
66            // Collection types
67            map(
68                tuple((
69                    alt((keyword("list"), keyword("set"))),
70                    char('<'),
71                    parse_type_inner,
72                    char('>'),
73                )),
74                |(collection, _, inner, _)| format!("{}<{}>", collection, inner),
75            ),
76            // Map type
77            map(
78                tuple((
79                    keyword("map"),
80                    char('<'),
81                    parse_type_inner,
82                    char(','),
83                    ws,
84                    parse_type_inner,
85                    char('>'),
86                )),
87                |(_, _, key_type, _, _, value_type, _)| {
88                    format!("map<{}, {}>", key_type, value_type)
89                },
90            ),
91            // Tuple type
92            map(
93                tuple((
94                    keyword("tuple"),
95                    char('<'),
96                    separated_list1(tuple((ws, char(','), ws)), parse_type_inner),
97                    char('>'),
98                )),
99                |(_, _, types, _)| format!("tuple<{}>", types.join(", ")),
100            ),
101            // Frozen type
102            map(
103                tuple((keyword("frozen"), char('<'), parse_type_inner, char('>'))),
104                |(_, _, inner, _)| format!("frozen<{}>", inner),
105            ),
106            // Simple types and UDTs
107            map(identifier, |name| name),
108        ))(input)?;
109
110        Ok((input, base))
111    }
112
113    let (input, _) = ws(input)?;
114    let (input, type_name) = parse_type_inner(input)?;
115    let (input, _) = ws(input)?;
116
117    Ok((input, type_name))
118}
119
120/// Parse column definition (with optional STATIC modifier and inline PRIMARY KEY)
121/// Returns (name, data_type, is_static)
122fn column_definition(input: &str) -> IResult<&str, (String, String, bool)> {
123    let (input, _) = ws(input)?;
124    let (input, name) = identifier(input)?;
125    let (input, _) = ws1(input)?;
126    let (input, data_type) = cql_type(input)?;
127    let (input, _) = ws(input)?;
128
129    // Check for STATIC modifier (Issue #255)
130    let (input, is_static) = opt(keyword("static"))(input)?;
131    let is_static = is_static.is_some();
132    let (input, _) = ws(input)?;
133
134    // Check for inline PRIMARY KEY (parse it but don't modify data_type)
135    // The PRIMARY KEY constraint is tracked via partition_keys/clustering_keys, not in data_type
136    let (input, _is_primary) = opt(tuple((keyword("primary"), ws1, keyword("key"))))(input)?;
137
138    // Return the data_type as-is (e.g., "uuid", not "uuid PRIMARY KEY")
139    // Issue #192: data_type must be a pure CQL type name for proper type matching
140    Ok((input, (name, data_type, is_static)))
141}
142
143/// Parse PRIMARY KEY specification
144fn primary_key_spec(input: &str) -> IResult<&str, (Vec<String>, Vec<String>)> {
145    let (input, _) = ws(input)?;
146    let (input, _) = keyword("primary")(input)?;
147    let (input, _) = ws1(input)?;
148    let (input, _) = keyword("key")(input)?;
149    let (input, _) = ws(input)?;
150    let (input, _) = char('(')(input)?;
151    let (input, _) = ws(input)?;
152
153    // Parse partition key (can be composite)
154    let (input, partition_keys) = alt((
155        // Composite partition key: ((col1, col2), clustering...)
156        map(
157            tuple((
158                char('('),
159                ws,
160                separated_list1(tuple((ws, char(','), ws)), identifier),
161                ws,
162                char(')'),
163            )),
164            |(_, _, keys, _, _)| keys,
165        ),
166        // Single partition key: (col1, clustering...)
167        map(identifier, |key| vec![key]),
168    ))(input)?;
169
170    let (input, _) = ws(input)?;
171
172    // Parse clustering keys (optional)
173    let (input, clustering_keys) = opt(preceded(
174        tuple((char(','), ws)),
175        separated_list1(tuple((ws, char(','), ws)), identifier),
176    ))(input)?;
177
178    let (input, _) = ws(input)?;
179    let (input, _) = char(')')(input)?;
180
181    Ok((input, (partition_keys, clustering_keys.unwrap_or_default())))
182}
183
184/// Parse table options (WITH clause)
185fn table_options(input: &str) -> IResult<&str, HashMap<String, String>> {
186    let (input, _) = ws(input)?;
187    let (input, _) = keyword("with")(input)?;
188    let (input, _) = ws1(input)?;
189
190    // Parse option = value pairs
191    let option_pair = map(
192        separated_pair(
193            identifier,
194            tuple((ws, char('='), ws)),
195            alt((
196                // String value
197                delimited(char('\''), take_while(|c: char| c != '\''), char('\'')),
198                // Numeric or identifier value
199                take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '.'),
200            )),
201        ),
202        |(key, value)| (key, value.to_string()),
203    );
204
205    let (input, options) = separated_list0(tuple((ws, keyword("and"), ws)), option_pair)(input)?;
206
207    Ok((input, options.into_iter().collect()))
208}
209
/// Lexical role of a single character of CQL source text.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CqlCharClass {
    /// Ordinary statement text outside any literal or comment.
    Code,
    /// Part of a `'...'` string, `"..."` quoted identifier or `$$...$$`
    /// literal, delimiters included.
    Quoted,
    /// Part of a `--`, `//` or `/* ... */` comment, delimiters included.
    Comment,
}

/// Classify every character of `text` as code, quoted text or comment.
///
/// Returns `(byte_offset, char, class)` for each character in order. As in
/// CQL, a quote inside `'...'` or `"..."` is escaped by doubling it;
/// backslashes have no special meaning. Block comments do not nest. A line
/// comment ends at `\n`, and that newline is classified as code.
fn classify_cql_chars(text: &str) -> Vec<(usize, char, CqlCharClass)> {
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum State {
        Code,
        SingleQuoted,
        DoubleQuoted,
        DollarQuoted,
        LineComment,
        BlockComment,
    }

    let chars: Vec<(usize, char)> = text.char_indices().collect();
    let mut classified = Vec::with_capacity(chars.len());
    let mut state = State::Code;
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i].1;
        let next = chars.get(i + 1).map(|&(_, n)| n);

        // Number of characters this step consumes, and the class they get.
        let (width, class) = match state {
            State::Code => match (c, next) {
                ('-', Some('-')) | ('/', Some('/')) => {
                    state = State::LineComment;
                    (2, CqlCharClass::Comment)
                }
                ('/', Some('*')) => {
                    state = State::BlockComment;
                    (2, CqlCharClass::Comment)
                }
                ('$', Some('$')) => {
                    state = State::DollarQuoted;
                    (2, CqlCharClass::Quoted)
                }
                ('\'', _) => {
                    state = State::SingleQuoted;
                    (1, CqlCharClass::Quoted)
                }
                ('"', _) => {
                    state = State::DoubleQuoted;
                    (1, CqlCharClass::Quoted)
                }
                _ => (1, CqlCharClass::Code),
            },
            State::SingleQuoted | State::DoubleQuoted => {
                let quote = if state == State::SingleQuoted { '\'' } else { '"' };
                if c != quote {
                    (1, CqlCharClass::Quoted)
                } else if next == Some(quote) {
                    // A doubled quote is an escaped quote; the literal continues.
                    (2, CqlCharClass::Quoted)
                } else {
                    state = State::Code;
                    (1, CqlCharClass::Quoted)
                }
            }
            State::DollarQuoted => {
                if c == '$' && next == Some('$') {
                    state = State::Code;
                    (2, CqlCharClass::Quoted)
                } else {
                    (1, CqlCharClass::Quoted)
                }
            }
            State::LineComment => {
                if c == '\n' {
                    state = State::Code;
                    (1, CqlCharClass::Code)
                } else {
                    (1, CqlCharClass::Comment)
                }
            }
            State::BlockComment => {
                if c == '*' && next == Some('/') {
                    state = State::Code;
                    (2, CqlCharClass::Comment)
                } else {
                    (1, CqlCharClass::Comment)
                }
            }
        };

        // `width` is 2 only when `next` exists, so this slice is in bounds.
        classified.extend(chars[i..i + width].iter().map(|&(pos, ch)| (pos, ch, class)));
        i += width;
    }

    classified
}

/// Split CQL file content into individual statements at top-level semicolons.
///
/// Semicolons inside `'...'` strings (quotes escaped by doubling),
/// `"..."` quoted identifiers, `$$...$$` literals and `--`, `//` or
/// `/* ... */` comments do not split. Each statement is returned without its
/// terminating `;`. Leading and trailing whitespace and comments are removed,
/// but comments inside the statement are kept. Pieces that contain nothing
/// but whitespace and comments are dropped.
pub fn split_cql_statements(input: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut start = 0;

    for (pos, c, class) in classify_cql_chars(input) {
        if c == ';' && class == CqlCharClass::Code {
            pieces.push(&input[start..pos]);
            start = pos + c.len_utf8();
        }
    }
    pieces.push(&input[start..]);

    pieces
        .into_iter()
        .map(strip_leading_trailing_comments)
        .filter(|statement| !statement.is_empty())
        .collect()
}

/// Strip leading and trailing whitespace and comments from a statement.
///
/// The result runs from the first to the last character that is neither
/// whitespace nor part of a comment. Text inside string literals and quoted
/// identifiers counts as content. Returns an empty string when `stmt` holds
/// only whitespace and comments.
fn strip_leading_trailing_comments(stmt: &str) -> String {
    let mut span: Option<(usize, usize)> = None;

    for (pos, c, class) in classify_cql_chars(stmt) {
        if class != CqlCharClass::Comment && !c.is_whitespace() {
            let end = pos + c.len_utf8();
            span = Some(span.map_or((pos, end), |(start, _)| (start, end)));
        }
    }

    span.map_or_else(String::new, |(start, end)| stmt[start..end].to_string())
}
363
#[cfg(test)]
mod tests_splitter {
    use super::*;

    /// Statements preceded by line and block comments are split apart, and the
    /// leading `--` comment does not leak into the first statement.
    #[test]
    fn test_split_with_comments() {
        let source = r#"
        -- Comment
        CREATE TYPE test.udt (field text);

        /* Multi-line
           comment */
        CREATE TABLE test.tbl (id int PRIMARY KEY);
        "#;

        let statements = split_cql_statements(source);
        assert_eq!(statements.len(), 2);

        let (type_stmt, table_stmt) = (&statements[0], &statements[1]);
        assert!(type_stmt.contains("CREATE TYPE"));
        assert!(!type_stmt.contains("--"));
        assert!(table_stmt.contains("CREATE TABLE"));
    }
}
386
/// Coarse classification of a CQL statement.
#[derive(Debug, Clone, PartialEq)]
pub enum StatementType {
    /// `CREATE TABLE ...`, with or without `IF NOT EXISTS`.
    CreateTable,
    /// `CREATE TYPE ...`, with or without `IF NOT EXISTS`.
    CreateType,
    /// Anything else. Holds the statement's first word, lowercased, or
    /// `"unknown"` for a statement with no words.
    Other(String),
}

/// Classify a CQL statement by its leading keywords.
///
/// Matching is case-insensitive and accepts any whitespace, including
/// newlines, between `CREATE` and `TABLE`/`TYPE`. `--` and `//` line comments
/// and `/* ... */` block comments are ignored. String literals are not
/// tracked, which is harmless because only the first two words are inspected.
pub fn classify_statement(statement: &str) -> StatementType {
    /// Remove comments. Whitespace is left where they stood, so they neither
    /// contribute words nor glue adjacent words together.
    fn strip_comments(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        let mut rest = text;
        while !rest.is_empty() {
            if let Some(body) = rest.strip_prefix("--").or_else(|| rest.strip_prefix("//")) {
                // Keep the terminating newline as a word separator.
                rest = body.find('\n').map_or("", |end| &body[end..]);
            } else if let Some(body) = rest.strip_prefix("/*") {
                out.push(' ');
                rest = body.find("*/").map_or("", |end| &body[end + 2..]);
            } else {
                let mut chars = rest.chars();
                if let Some(c) = chars.next() {
                    out.push(c);
                }
                rest = chars.as_str();
            }
        }
        out
    }

    let cleaned = strip_comments(&statement.to_lowercase());
    let mut words = cleaned.split_whitespace();
    let first = words.next();

    match (first, words.next()) {
        (Some("create"), Some("table")) => StatementType::CreateTable,
        (Some("create"), Some("type")) => StatementType::CreateType,
        _ => StatementType::Other(first.unwrap_or("unknown").to_string()),
    }
}
433
434/// Parse CREATE TYPE statement to extract UDT definition
435#[allow(clippy::type_complexity)]
436pub fn parse_create_type(
437    input: &str,
438) -> IResult<&str, (String, Option<String>, Vec<(String, String)>)> {
439    let (input, _) = ws(input)?;
440    let (input, _) = keyword("create")(input)?;
441    let (input, _) = ws1(input)?;
442    let (input, _) = keyword("type")(input)?;
443    let (input, _) = ws1(input)?;
444
445    // Optional IF NOT EXISTS
446    let (input, _) = opt(tuple((
447        keyword("if"),
448        ws1,
449        keyword("not"),
450        ws1,
451        keyword("exists"),
452        ws1,
453    )))(input)?;
454
455    // Type name (qualified or unqualified)
456    let (input, (keyspace, type_name)) = qualified_table_name(input)?;
457
458    let (input, _) = ws(input)?;
459    let (input, _) = char('(')(input)?;
460    let (input, _) = ws(input)?;
461
462    // Parse field definitions
463    let (input, fields) = separated_list1(
464        tuple((ws, char(','), ws)),
465        map(
466            tuple((identifier, ws1, cql_type)),
467            |(name, _, field_type)| (name, field_type),
468        ),
469    )(input)?;
470
471    let (input, _) = ws(input)?;
472    let (input, _) = char(')')(input)?;
473
474    Ok((input, (type_name, keyspace, fields)))
475}
476
477/// Parse a complete CREATE TABLE statement
478pub fn parse_create_table(input: &str) -> IResult<&str, TableSchema> {
479    let (input, _) = ws(input)?;
480    let (input, _) = keyword("create")(input)?;
481    let (input, _) = ws1(input)?;
482    let (input, _) = keyword("table")(input)?;
483    let (input, _) = ws1(input)?;
484
485    // Optional IF NOT EXISTS
486    let (input, _) = opt(tuple((
487        keyword("if"),
488        ws1,
489        keyword("not"),
490        ws1,
491        keyword("exists"),
492        ws1,
493    )))(input)?;
494
495    // Table name (qualified or unqualified)
496    let (input, (keyspace, table_name)) = qualified_table_name(input)?;
497
498    let (input, _) = ws(input)?;
499    let (input, _) = char('(')(input)?;
500    let (input, _) = ws(input)?;
501
502    // Parse column definitions and constraints
503    // Columns are stored as (name, data_type, is_static)
504    let mut columns: Vec<(String, String, bool)> = Vec::new();
505    let mut partition_keys = Vec::new();
506    let mut clustering_keys = Vec::new();
507    let mut primary_key_found = false;
508
509    let (input, items) = separated_list1(
510        tuple((ws, char(','), ws)),
511        alt((
512            // Primary key constraint - returns 3-tuple with is_static=false (unused)
513            map(primary_key_spec, |keys| {
514                (
515                    "PRIMARY_KEY".to_string(),
516                    serde_json::to_string(&keys).unwrap_or_default(),
517                    false, // is_static not applicable for PRIMARY KEY constraint
518                )
519            }),
520            // Column definition - returns (name, data_type, is_static)
521            column_definition,
522        )),
523    )(input)?;
524
525    // Process parsed items
526    for (name, value, is_static) in items {
527        if name == "PRIMARY_KEY" {
528            // Parse the JSON-encoded key specification
529            if let Ok(keys_tuple) = serde_json::from_str::<(Vec<String>, Vec<String>)>(&value) {
530                partition_keys = keys_tuple.0;
531                clustering_keys = keys_tuple.1;
532                primary_key_found = true;
533            }
534            continue;
535        }
536        columns.push((name, value, is_static));
537    }
538
539    let (input, _) = ws(input)?;
540    let (input, _) = char(')')(input)?;
541
542    // Parse optional WITH clause
543    let (input, _options) = opt(table_options)(input)?;
544
545    // If no primary key was found in constraints, look for inline PRIMARY KEY or use first column
546    if !primary_key_found && !columns.is_empty() {
547        // Check if any column has "PRIMARY KEY" in its type (inline definition)
548        let mut found_inline = false;
549        for (col_name, col_type, _is_static) in &columns {
550            if col_type.to_lowercase().contains("primary key") {
551                partition_keys.push(col_name.clone());
552                found_inline = true;
553                break;
554            }
555        }
556
557        // If still no primary key found, assume first column is partition key
558        if !found_inline {
559            partition_keys.push(columns[0].0.clone());
560        }
561    }
562
563    // Build schema
564    let schema = TableSchema {
565        keyspace: keyspace.unwrap_or_else(|| "default".to_string()),
566        table: table_name,
567        partition_keys: partition_keys
568            .into_iter()
569            .enumerate()
570            .map(|(pos, name)| {
571                let data_type = columns
572                    .iter()
573                    .find(|(col_name, _, _)| col_name == &name)
574                    .map(|(_, dt, _)| dt.clone())
575                    .unwrap_or_else(|| "text".to_string());
576
577                KeyColumn {
578                    name,
579                    data_type,
580                    position: pos,
581                }
582            })
583            .collect(),
584        clustering_keys: clustering_keys
585            .into_iter()
586            .enumerate()
587            .map(|(pos, name)| {
588                let data_type = columns
589                    .iter()
590                    .find(|(col_name, _, _)| col_name == &name)
591                    .map(|(_, dt, _)| dt.clone())
592                    .unwrap_or_else(|| "text".to_string());
593
594                ClusteringColumn {
595                    name,
596                    data_type,
597                    position: pos,
598                    order: crate::schema::ClusteringOrder::Asc,
599                }
600            })
601            .collect(),
602        columns: columns
603            .into_iter()
604            .map(|(name, data_type_with_constraints, is_static)| {
605                // Remove PRIMARY KEY constraint from data type
606                let data_type = if data_type_with_constraints
607                    .to_lowercase()
608                    .contains("primary key")
609                {
610                    data_type_with_constraints
611                        .to_lowercase()
612                        .replace("primary key", "")
613                        .trim()
614                        .to_string()
615                } else {
616                    data_type_with_constraints
617                };
618
619                Column {
620                    name,
621                    data_type,
622                    nullable: true,
623                    default: None,
624                    is_static,
625                }
626            })
627            .collect(),
628        comments: HashMap::new(),
629    };
630
631    Ok((input, schema))
632}
633
634/// Convert CQL type string to internal CqlTypeId
635pub fn cql_type_to_type_id(cql_type: &str) -> Result<CqlTypeId> {
636    let type_lower = cql_type.trim().to_lowercase();
637
638    // Handle collection types
639    if type_lower.starts_with("list<") {
640        return Ok(CqlTypeId::List);
641    }
642    if type_lower.starts_with("set<") {
643        return Ok(CqlTypeId::Set);
644    }
645    if type_lower.starts_with("map<") {
646        return Ok(CqlTypeId::Map);
647    }
648    if type_lower.starts_with("tuple<") {
649        return Ok(CqlTypeId::Tuple);
650    }
651    if type_lower.starts_with("frozen<") {
652        // Extract inner type from frozen<type>
653        if let Some(inner_start) = type_lower.find('<') {
654            if let Some(inner_end) = type_lower.rfind('>') {
655                let inner_type = &type_lower[inner_start + 1..inner_end];
656                return cql_type_to_type_id(inner_type);
657            }
658        }
659    }
660
661    // Handle primitive types
662    match type_lower.as_str() {
663        "ascii" => Ok(CqlTypeId::Ascii),
664        "bigint" | "long" => Ok(CqlTypeId::BigInt),
665        "blob" => Ok(CqlTypeId::Blob),
666        "boolean" | "bool" => Ok(CqlTypeId::Boolean),
667        "counter" => Ok(CqlTypeId::Counter),
668        "decimal" => Ok(CqlTypeId::Decimal),
669        "double" => Ok(CqlTypeId::Double),
670        "float" => Ok(CqlTypeId::Float),
671        "int" | "integer" => Ok(CqlTypeId::Int),
672        "timestamp" => Ok(CqlTypeId::Timestamp),
673        "uuid" => Ok(CqlTypeId::Uuid),
674        "varchar" | "text" => Ok(CqlTypeId::Varchar),
675        "varint" => Ok(CqlTypeId::Varint),
676        "timeuuid" => Ok(CqlTypeId::Timeuuid),
677        "inet" => Ok(CqlTypeId::Inet),
678        "date" => Ok(CqlTypeId::Date),
679        "time" => Ok(CqlTypeId::Time),
680        "smallint" => Ok(CqlTypeId::Smallint),
681        "tinyint" => Ok(CqlTypeId::Tinyint),
682        "duration" => Ok(CqlTypeId::Duration),
683        _ => {
684            // Assume it's a UDT if not a known primitive type
685            Ok(CqlTypeId::Udt)
686        }
687    }
688}
689
690/// Extract table name from CQL CREATE TABLE statement
691pub fn extract_table_name(cql: &str) -> Result<(Option<String>, String)> {
692    match parse_create_table(cql) {
693        Ok((_, schema)) => {
694            let keyspace = if schema.keyspace == "default" {
695                None
696            } else {
697                Some(schema.keyspace)
698            };
699            Ok((keyspace, schema.table))
700        }
701        Err(_) => {
702            // Fallback: simple regex-like extraction
703            let cql_lower = cql.to_lowercase();
704            if let Some(table_start) = cql_lower.find("create table") {
705                let after_table = &cql[table_start + 12..];
706                if let Some(if_not_exists) = after_table.find("if not exists") {
707                    let after_if = &after_table[if_not_exists + 13..];
708                    return extract_simple_table_name(after_if);
709                }
710                return extract_simple_table_name(after_table);
711            }
712
713            Err(Error::schema(
714                "Failed to extract table name from CQL".to_string(),
715            ))
716        }
717    }
718}
719
720/// Simple table name extraction fallback
721fn extract_simple_table_name(input: &str) -> Result<(Option<String>, String)> {
722    let trimmed = input.trim();
723    let words: Vec<&str> = trimmed.split_whitespace().collect();
724
725    if words.is_empty() {
726        return Err(Error::schema("No table name found".to_string()));
727    }
728
729    let table_name = words[0];
730
731    // Handle qualified names
732    if let Some(dot_pos) = table_name.find('.') {
733        let keyspace = &table_name[..dot_pos];
734        let table = &table_name[dot_pos + 1..];
735        Ok((Some(keyspace.to_string()), table.to_string()))
736    } else {
737        Ok((None, table_name.to_string()))
738    }
739}
740
/// Check whether a schema's `(keyspace, table)` matches a target name.
///
/// Table names must be identical (case-sensitive). A target without a
/// keyspace matches any schema keyspace, including none. A target with a
/// keyspace matches only a schema with exactly that keyspace.
pub fn table_name_matches(
    schema_keyspace: &Option<String>,
    schema_table: &str,
    target_keyspace: &Option<String>,
    target_table: &str,
) -> bool {
    let keyspace_ok = match target_keyspace {
        None => true,
        Some(_) => schema_keyspace == target_keyspace,
    };
    schema_table == target_table && keyspace_ok
}
761
762/// Parse CQL schema and extract metadata for SSTable reading
763pub fn parse_cql_schema(cql: &str) -> Result<TableSchema> {
764    match parse_create_table(cql) {
765        Ok((_, schema)) => {
766            // Validate the parsed schema
767            schema.validate()?;
768            Ok(schema)
769        }
770        Err(nom::Err::Error(e) | nom::Err::Failure(e)) => Err(Error::schema(format!(
771            "Failed to parse CQL schema: {:?}",
772            e
773        ))),
774        Err(nom::Err::Incomplete(_)) => Err(Error::schema("Incomplete CQL schema".to_string())),
775    }
776}
777
778/// Parse CQL schema using the visitor pattern (preferred method for new code)
779///
780/// This function demonstrates how to use the visitor pattern for AST-based parsing.
781/// It provides better error handling, validation, and is more maintainable than
782/// the legacy nom-based parser.
783pub fn parse_cql_schema_with_visitor(cql: &str) -> Result<TableSchema> {
784    // Note: This is a demonstration function. In a complete implementation,
785    // you would first parse the CQL into an AST using a parser (nom or ANTLR),
786    // then use the visitor pattern to convert it to TableSchema.
787    //
788    // For now, this uses the existing nom parser for demonstration purposes.
789
790    use crate::cql::traits::CqlVisitor;
791    use crate::cql::visitor::SchemaBuilderVisitor;
792    use crate::cql::CqlStatement;
793
794    // Parse using the existing nom parser to get the TableSchema
795    let schema = parse_cql_schema(cql)?;
796
797    // Demonstrate the visitor pattern by reconstructing the AST and then using the visitor
798    // (In real usage, you would have the AST from a parser)
799    let ast = table_schema_to_ast(&schema)?;
800    let statement = CqlStatement::CreateTable(ast);
801
802    // Use the visitor to convert AST back to TableSchema
803    let mut visitor = SchemaBuilderVisitor;
804    visitor.visit_statement(&statement)
805}
806
807/// Helper function to convert TableSchema to AST for demonstration
808/// (In real usage, the AST would come directly from a parser)
809fn table_schema_to_ast(schema: &TableSchema) -> Result<CqlCreateTable> {
810    use crate::cql::{
811        CqlColumnDef, CqlCreateTable, CqlIdentifier, CqlPrimaryKey, CqlTable, CqlTableOptions,
812    };
813
814    // Convert table reference
815    let table = if schema.keyspace == "default" {
816        CqlTable::new(&schema.table)
817    } else {
818        CqlTable::with_keyspace(&schema.keyspace, &schema.table)
819    };
820
821    // Convert columns
822    let columns: Result<Vec<CqlColumnDef>> = schema
823        .columns
824        .iter()
825        .map(|col| {
826            Ok(CqlColumnDef {
827                name: CqlIdentifier::new(&col.name),
828                data_type: string_to_cql_data_type(&col.data_type)?,
829                is_static: col.is_static,
830            })
831        })
832        .collect();
833
834    let columns = columns?;
835
836    // Convert primary key
837    let partition_key: Vec<CqlIdentifier> = schema
838        .partition_keys
839        .iter()
840        .map(|pk| CqlIdentifier::new(&pk.name))
841        .collect();
842
843    let clustering_key: Vec<CqlIdentifier> = schema
844        .clustering_keys
845        .iter()
846        .map(|ck| CqlIdentifier::new(&ck.name))
847        .collect();
848
849    Ok(CqlCreateTable {
850        if_not_exists: false,
851        table,
852        columns,
853        primary_key: CqlPrimaryKey {
854            partition_key,
855            clustering_key,
856        },
857        options: CqlTableOptions {
858            options: HashMap::new(),
859        },
860    })
861}
862
863/// Convert string type to CqlDataType (simplified version)
864fn string_to_cql_data_type(type_str: &str) -> Result<CqlDataType> {
865    use crate::cql::{CqlDataType, CqlIdentifier};
866
867    let type_lower = type_str.trim().to_lowercase();
868
869    // Handle collection types
870    if type_lower.starts_with("list<") && type_lower.ends_with('>') {
871        let inner_type_str = &type_lower[5..type_lower.len() - 1];
872        let inner_type = string_to_cql_data_type(inner_type_str)?;
873        return Ok(CqlDataType::List(Box::new(inner_type)));
874    }
875
876    if type_lower.starts_with("set<") && type_lower.ends_with('>') {
877        let inner_type_str = &type_lower[4..type_lower.len() - 1];
878        let inner_type = string_to_cql_data_type(inner_type_str)?;
879        return Ok(CqlDataType::Set(Box::new(inner_type)));
880    }
881
882    if type_lower.starts_with("map<") && type_lower.ends_with('>') {
883        let inner = &type_lower[4..type_lower.len() - 1];
884        if let Some(comma_pos) = inner.find(',') {
885            let key_type_str = inner[..comma_pos].trim();
886            let value_type_str = inner[comma_pos + 1..].trim();
887            let key_type = string_to_cql_data_type(key_type_str)?;
888            let value_type = string_to_cql_data_type(value_type_str)?;
889            return Ok(CqlDataType::Map(Box::new(key_type), Box::new(value_type)));
890        }
891    }
892
893    if type_lower.starts_with("frozen<") && type_lower.ends_with('>') {
894        let inner_type_str = &type_lower[7..type_lower.len() - 1];
895        let inner_type = string_to_cql_data_type(inner_type_str)?;
896        return Ok(CqlDataType::Frozen(Box::new(inner_type)));
897    }
898
899    // Handle primitive types
900    match type_lower.as_str() {
901        "boolean" | "bool" => Ok(CqlDataType::Boolean),
902        "tinyint" => Ok(CqlDataType::TinyInt),
903        "smallint" => Ok(CqlDataType::SmallInt),
904        "int" => Ok(CqlDataType::Int),
905        "bigint" | "long" => Ok(CqlDataType::BigInt),
906        "varint" => Ok(CqlDataType::Varint),
907        "decimal" => Ok(CqlDataType::Decimal),
908        "float" => Ok(CqlDataType::Float),
909        "double" => Ok(CqlDataType::Double),
910        "text" | "varchar" => Ok(CqlDataType::Text),
911        "ascii" => Ok(CqlDataType::Ascii),
912        "blob" => Ok(CqlDataType::Blob),
913        "timestamp" => Ok(CqlDataType::Timestamp),
914        "date" => Ok(CqlDataType::Date),
915        "time" => Ok(CqlDataType::Time),
916        "uuid" => Ok(CqlDataType::Uuid),
917        "timeuuid" => Ok(CqlDataType::TimeUuid),
918        "inet" => Ok(CqlDataType::Inet),
919        "duration" => Ok(CqlDataType::Duration),
920        "counter" => Ok(CqlDataType::Counter),
921        _ => {
922            // Assume it's a UDT
923            Ok(CqlDataType::Udt(CqlIdentifier::new(type_str)))
924        }
925    }
926}
927
#[cfg(test)]
mod tests {
    //! Unit tests for CQL schema parsing, table-name extraction and type conversion.
    use super::*;

    #[test]
    fn test_simple_table_parsing() {
        let cql = r#"
            CREATE TABLE users (
                id uuid PRIMARY KEY,
                name text,
                email text
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.table, "users");
        assert_eq!(schema.columns.len(), 3);
        assert_eq!(schema.partition_keys.len(), 1);
        assert_eq!(schema.partition_keys[0].name, "id");
    }

    #[test]
    fn test_qualified_table_name() {
        let cql = r#"
            CREATE TABLE myapp.users (
                id bigint PRIMARY KEY,
                name text
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.keyspace, "myapp");
        assert_eq!(schema.table, "users");
    }

    #[test]
    fn test_complex_types() {
        let cql = r#"
            CREATE TABLE complex_table (
                id uuid PRIMARY KEY,
                tags set<text>,
                metadata map<text, text>,
                coordinates list<double>
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.columns.len(), 4);

        let tags_col = schema.columns.iter().find(|c| c.name == "tags").unwrap();
        assert_eq!(tags_col.data_type, "set<text>");

        let metadata_col = schema
            .columns
            .iter()
            .find(|c| c.name == "metadata")
            .unwrap();
        assert_eq!(metadata_col.data_type, "map<text, text>");

        let coordinates_col = schema
            .columns
            .iter()
            .find(|c| c.name == "coordinates")
            .unwrap();
        assert_eq!(coordinates_col.data_type, "list<double>");
    }

    #[test]
    fn test_table_name_extraction() {
        let cql = "CREATE TABLE IF NOT EXISTS myapp.users (id uuid PRIMARY KEY)";
        let (keyspace, table) = extract_table_name(cql).unwrap();
        assert_eq!(keyspace, Some("myapp".to_string()));
        assert_eq!(table, "users");
    }

    #[test]
    fn test_cql_type_conversion() {
        assert_eq!(cql_type_to_type_id("text").unwrap(), CqlTypeId::Varchar);
        assert_eq!(cql_type_to_type_id("bigint").unwrap(), CqlTypeId::BigInt);
        assert_eq!(cql_type_to_type_id("list<text>").unwrap(), CqlTypeId::List);
        assert_eq!(
            cql_type_to_type_id("frozen<set<uuid>>").unwrap(),
            CqlTypeId::Set
        );
    }

    #[test]
    fn test_string_to_cql_data_type() {
        // Primitives: case-insensitive, whitespace-tolerant, with aliases.
        assert!(matches!(
            string_to_cql_data_type("INT").unwrap(),
            CqlDataType::Int
        ));
        assert!(matches!(
            string_to_cql_data_type(" text ").unwrap(),
            CqlDataType::Text
        ));
        assert!(matches!(
            string_to_cql_data_type("varchar").unwrap(),
            CqlDataType::Text
        ));
        assert!(matches!(
            string_to_cql_data_type("bool").unwrap(),
            CqlDataType::Boolean
        ));
        assert!(matches!(
            string_to_cql_data_type("long").unwrap(),
            CqlDataType::BigInt
        ));

        // Collection and frozen wrappers.
        assert!(matches!(
            string_to_cql_data_type("list<int>").unwrap(),
            CqlDataType::List(_)
        ));
        assert!(matches!(
            string_to_cql_data_type("set<uuid>").unwrap(),
            CqlDataType::Set(_)
        ));
        assert!(matches!(
            string_to_cql_data_type("frozen<set<uuid>>").unwrap(),
            CqlDataType::Frozen(_)
        ));
        match string_to_cql_data_type("map<text, bigint>").unwrap() {
            CqlDataType::Map(key, value) => {
                assert!(matches!(*key, CqlDataType::Text));
                assert!(matches!(*value, CqlDataType::BigInt));
            }
            _ => panic!("expected a map type"),
        }

        // Unknown names fall back to a user-defined type.
        assert!(matches!(
            string_to_cql_data_type("address_type").unwrap(),
            CqlDataType::Udt(_)
        ));
    }

    #[test]
    fn test_table_name_matching() {
        // Exact match
        assert!(table_name_matches(
            &Some("ks".to_string()),
            "users",
            &Some("ks".to_string()),
            "users"
        ));

        // Match with wildcard keyspace
        assert!(table_name_matches(
            &Some("ks".to_string()),
            "users",
            &None,
            "users"
        ));

        // No match - different table
        assert!(!table_name_matches(
            &Some("ks".to_string()),
            "users",
            &Some("ks".to_string()),
            "orders"
        ));

        // No match - different keyspace
        assert!(!table_name_matches(
            &Some("ks1".to_string()),
            "users",
            &Some("ks2".to_string()),
            "users"
        ));
    }

    #[test]
    fn test_composite_primary_key() {
        let cql = r#"
            CREATE TABLE time_series (
                partition_key text,
                clustering_key timestamp,
                value double,
                PRIMARY KEY (partition_key, clustering_key)
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.partition_keys.len(), 1);
        assert_eq!(schema.clustering_keys.len(), 1);

        assert_eq!(schema.partition_keys[0].name, "partition_key");
        assert_eq!(schema.clustering_keys[0].name, "clustering_key");
    }

    #[test]
    fn test_frozen_collections() {
        let cql = r#"
            CREATE TABLE frozen_test (
                id uuid PRIMARY KEY,
                frozen_set frozen<set<text>>,
                frozen_map frozen<map<text, bigint>>,
                nested_frozen frozen<list<frozen<set<uuid>>>>
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();

        let frozen_set = schema
            .columns
            .iter()
            .find(|c| c.name == "frozen_set")
            .unwrap();
        assert_eq!(frozen_set.data_type, "frozen<set<text>>");

        let frozen_map = schema
            .columns
            .iter()
            .find(|c| c.name == "frozen_map")
            .unwrap();
        assert_eq!(frozen_map.data_type, "frozen<map<text, bigint>>");

        let nested = schema
            .columns
            .iter()
            .find(|c| c.name == "nested_frozen")
            .unwrap();
        assert_eq!(nested.data_type, "frozen<list<frozen<set<uuid>>>>");
    }

    #[test]
    fn test_udt_columns() {
        let cql = r#"
            CREATE TABLE user_profiles (
                user_id uuid PRIMARY KEY,
                address address_type,
                preferences frozen<user_prefs>
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();

        let address_col = schema.columns.iter().find(|c| c.name == "address").unwrap();
        assert_eq!(address_col.data_type, "address_type");

        let prefs_col = schema
            .columns
            .iter()
            .find(|c| c.name == "preferences")
            .unwrap();
        assert_eq!(prefs_col.data_type, "frozen<user_prefs>");
    }

    #[test]
    fn test_tuple_types() {
        let cql = r#"
            CREATE TABLE tuple_test (
                id uuid PRIMARY KEY,
                coordinates tuple<double, double>,
                person_info tuple<text, int, boolean>
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();

        let coords = schema
            .columns
            .iter()
            .find(|c| c.name == "coordinates")
            .unwrap();
        assert_eq!(coords.data_type, "tuple<double, double>");

        let person = schema
            .columns
            .iter()
            .find(|c| c.name == "person_info")
            .unwrap();
        assert_eq!(person.data_type, "tuple<text, int, boolean>");
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let cql = r#"
            create table Users (
                ID UUID primary key,
                Name TEXT,
                Email VARCHAR
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.table, "Users");
        assert_eq!(schema.columns.len(), 3);
    }

    #[test]
    fn test_quoted_identifiers() {
        let cql = r#"
            CREATE TABLE "CaseSensitive" (
                "Id" uuid PRIMARY KEY,
                "Name With Spaces" text
            )
        "#;

        let schema = parse_cql_schema(cql).unwrap();
        assert_eq!(schema.table, "CaseSensitive");

        let space_col = schema.columns.iter().find(|c| c.name == "Name With Spaces");
        assert!(space_col.is_some());
    }

    #[test]
    fn test_fallback_table_extraction() {
        // Table-name extraction on a minimal statement without `IF NOT EXISTS`
        // (complements `test_table_name_extraction`, which uses that clause).
        let cql = "CREATE TABLE myapp.orders (id bigint PRIMARY KEY)";
        let (keyspace, table) = extract_table_name(cql).unwrap();
        assert_eq!(keyspace, Some("myapp".to_string()));
        assert_eq!(table, "orders");
    }

    #[test]
    fn test_all_primitive_types() {
        let type_mappings = vec![
            ("ascii", CqlTypeId::Ascii),
            ("bigint", CqlTypeId::BigInt),
            ("blob", CqlTypeId::Blob),
            ("boolean", CqlTypeId::Boolean),
            ("counter", CqlTypeId::Counter),
            ("decimal", CqlTypeId::Decimal),
            ("double", CqlTypeId::Double),
            ("float", CqlTypeId::Float),
            ("int", CqlTypeId::Int),
            ("timestamp", CqlTypeId::Timestamp),
            ("uuid", CqlTypeId::Uuid),
            ("varchar", CqlTypeId::Varchar),
            ("text", CqlTypeId::Varchar),
            ("varint", CqlTypeId::Varint),
            ("timeuuid", CqlTypeId::Timeuuid),
            ("inet", CqlTypeId::Inet),
            ("date", CqlTypeId::Date),
            ("time", CqlTypeId::Time),
            ("smallint", CqlTypeId::Smallint),
            ("tinyint", CqlTypeId::Tinyint),
            ("duration", CqlTypeId::Duration),
        ];

        for (cql_type, expected_id) in type_mappings {
            assert_eq!(
                cql_type_to_type_id(cql_type).unwrap(),
                expected_id,
                "Failed for type: {}",
                cql_type
            );
        }
    }
}