Skip to main content

cqlite_cli/commands/
schema.rs

1use crate::cli_types::SchemaCommands;
2use anyhow::{Context, Result};
3use cqlite_core::{
4    schema::{
5        parse_cql_schema, AggregatorConfig, ClusteringColumn, ClusteringOrder, Column, KeyColumn,
6        SchemaAggregator, TableSchema,
7    },
8    Database,
9};
10use serde_json;
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13
/// Routes a parsed schema subcommand to the function that implements it.
///
/// Compiled only when the `state_machine` feature is enabled. Any error produced by the
/// selected handler is returned to the caller unchanged.
#[cfg(feature = "state_machine")]
pub async fn handle_schema_command(database: &Database, command: SchemaCommands) -> Result<()> {
    match command {
        SchemaCommands::List => list_tables(database).await?,
        SchemaCommands::Describe { table } => describe_table(database, &table).await?,
        SchemaCommands::Create { schema } => create_table_from_file(database, &schema).await?,
        SchemaCommands::Drop { table, force } => drop_table(database, &table, force).await?,
        SchemaCommands::Load { paths } => load_schemas(database, &paths).await?,
    }

    Ok(())
}
24
25#[cfg(not(feature = "state_machine"))]
26pub async fn handle_schema_command(_database: &Database, _command: SchemaCommands) -> Result<()> {
27    Err(anyhow::anyhow!(
28        "Schema commands requiring query execution are not available in M1.\n\
29         Build with --features state_machine or use SSTableReader directly.\n\
30         See CLAUDE.md for M1 API examples."
31    ))
32}
33
/// Prints a hard-coded placeholder table list.
///
/// The database handle is not consulted; the names printed are fixed sample values and the
/// final line of output says so.
#[cfg(feature = "state_machine")]
#[allow(dead_code)]
async fn list_tables(_database: &Database) -> Result<()> {
    // TODO: Implement actual table listing from database
    const PLACEHOLDER_LINES: [&str; 5] = [
        "Tables in database:",
        "- users",
        "- orders",
        "- products",
        "\nNote: Table listing not yet implemented",
    ];

    for line in PLACEHOLDER_LINES.iter() {
        println!("{}", line);
    }

    Ok(())
}
46
/// Prints a hard-coded placeholder description for `table`.
///
/// Only the table name is taken from the arguments; the column list is a fixed sample and the
/// database handle is not consulted.
#[cfg(feature = "state_machine")]
#[allow(dead_code)]
async fn describe_table(_database: &Database, table: &str) -> Result<()> {
    // TODO: Implement actual table description from database schema
    const PLACEHOLDER_LINES: [&str; 5] = [
        "Columns:",
        "- id: UUID (primary key)",
        "- name: TEXT",
        "- created_at: TIMESTAMP",
        "\nNote: Table description not yet implemented",
    ];

    println!("Describing table '{}'", table);
    for line in PLACEHOLDER_LINES.iter() {
        println!("{}", line);
    }

    Ok(())
}
60
/// Reads a DDL file and executes its contents against `database`.
///
/// # Errors
///
/// Fails when the file cannot be read, or when the database rejects the statement; in the
/// latter case the failure is also printed before the error is returned.
#[cfg(feature = "state_machine")]
async fn create_table_from_file(database: &Database, file: &Path) -> Result<()> {
    println!("Creating table from DDL file: {}", file.display());

    let ddl = std::fs::read_to_string(file)
        .with_context(|| format!("Failed to read DDL file: {}", file.display()))?;

    // Report the failure on the console first, then surface it to the caller.
    let outcome = database.execute(&ddl).await.map_err(|e| {
        println!("Failed to create table: {}", e);
        anyhow::anyhow!("Table creation failed: {}", e)
    })?;

    println!("Table created successfully");
    if outcome.rows_affected > 0 {
        println!("Rows affected: {}", outcome.rows_affected);
    }

    Ok(())
}
85
/// Drops `table`, asking for confirmation on stdin unless `force` is set.
///
/// `table` may be a bare table name or `keyspace.table`. Because the name is interpolated into
/// the `DROP TABLE` statement, it is validated first and anything that is not a plain or
/// double-quoted CQL identifier is rejected without touching the database.
///
/// # Errors
///
/// Fails when the name is not a valid table reference, when reading the confirmation from
/// stdin fails, or when the database rejects the statement. Declining the confirmation is not
/// an error.
#[cfg(feature = "state_machine")]
async fn drop_table(database: &Database, table: &str, force: bool) -> Result<()> {
    // Validate before prompting so the user is never asked to confirm a statement that
    // would not be executed as written.
    if !is_safe_table_reference(table) {
        return Err(anyhow::anyhow!(
            "Invalid table name '{}': expected [keyspace.]table made of CQL identifiers",
            table
        ));
    }

    if !force {
        // Ask for confirmation; anything not starting with 'y'/'Y' cancels.
        println!("Are you sure you want to drop table '{}'? [y/N]", table);
        let mut input = String::new();
        std::io::stdin().read_line(&mut input)?;
        if !input.trim().to_lowercase().starts_with('y') {
            println!("Table drop cancelled");
            return Ok(());
        }
    } else {
        println!("Force dropping table '{}'", table);
    }

    let drop_sql = format!("DROP TABLE {}", table);
    match database.execute(&drop_sql).await {
        Ok(result) => {
            println!("Table '{}' dropped successfully", table);
            if result.rows_affected > 0 {
                println!("Rows affected: {}", result.rows_affected);
            }
        }
        Err(e) => {
            println!("Failed to drop table: {}", e);
            return Err(anyhow::anyhow!("Table drop failed: {}", e));
        }
    }

    Ok(())
}

/// Returns `true` when `table` is `identifier` or `identifier.identifier`.
///
/// Quoted identifiers that themselves contain a dot are not supported and are rejected.
#[cfg(feature = "state_machine")]
fn is_safe_table_reference(table: &str) -> bool {
    let mut parts = table.split('.');
    match (parts.next(), parts.next(), parts.next()) {
        (Some(name), None, _) => is_cql_identifier(name),
        (Some(keyspace), Some(name), None) => {
            is_cql_identifier(keyspace) && is_cql_identifier(name)
        }
        _ => false,
    }
}

/// Returns `true` for an unquoted identifier (ASCII letter, then letters, digits or `_`) or a
/// non-empty double-quoted identifier with no embedded double quote.
#[cfg(feature = "state_machine")]
fn is_cql_identifier(part: &str) -> bool {
    let quoted_inner = part
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match quoted_inner {
        Some(inner) => !inner.is_empty() && !inner.contains('"'),
        None => {
            let mut chars = part.chars();
            matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
                && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        }
    }
}
117
/// Loads schema files from `paths` into freshly created registries and reports the outcome.
///
/// The `_database` handle is not used: a schema registry and a UDT registry are created here,
/// filled by a `SchemaAggregator`, and their contents are printed.
///
/// If the aggregator reports any errors they are written to stderr and the process exits with
/// status 3. Warnings are printed but do not stop the run.
///
/// # Errors
///
/// Fails when the platform or the schema registry cannot be created, when the aggregator
/// itself fails, or when the registry cannot list its schemas.
#[cfg(feature = "state_machine")]
async fn load_schemas(_database: &Database, paths: &[PathBuf]) -> Result<()> {
    // Future enhancement: Database should expose registry accessors so the temporary
    // registries below are no longer needed.
    use cqlite_core::{
        platform::Platform,
        schema::{
            registry::{SchemaRegistry, SchemaRegistryConfig},
            UdtRegistry,
        },
        Config,
    };
    use std::sync::Arc;
    use tokio::sync::RwLock;

    println!("Loading schemas from {} paths...", paths.len());

    // Temporary registries that receive whatever the aggregator loads.
    let config = Config::default();
    let platform = Platform::new(&config)
        .await
        .context("Failed to initialize platform")?;
    let platform = Arc::new(platform);

    let registry = SchemaRegistry::new(SchemaRegistryConfig::default(), platform, config.clone())
        .await
        .context("Failed to create schema registry")?;
    let schema_registry = Arc::new(RwLock::new(registry));
    let udt_registry = Arc::new(RwLock::new(UdtRegistry::new()));

    let mut aggregator = SchemaAggregator::new(
        Arc::clone(&schema_registry),
        Arc::clone(&udt_registry),
        AggregatorConfig {
            graceful_degradation: true,
            validate_udt_dependencies: true,
        },
    );

    let outcome = aggregator
        .load_from_paths(paths)
        .await
        .context("Failed to load schemas")?;

    // Any reported error aborts the whole run.
    if !outcome.errors.is_empty() {
        eprintln!("\nErrors encountered during schema loading:");
        for failure in &outcome.errors {
            match &failure.file_path {
                Some(path) => {
                    eprintln!("  Error in file {}: {}", path.display(), failure.message)
                }
                None => eprintln!("  Error: {}", failure.message),
            }
        }
        eprintln!(
            "\nSchema loading failed with {} errors. Please fix the schemas and retry.",
            outcome.errors.len()
        );
        // Exit with code 3 for schema validation errors per M2 spec
        std::process::exit(3);
    }

    if !outcome.warnings.is_empty() {
        println!("\nWarnings:");
        for notice in &outcome.warnings {
            match &notice.file_path {
                Some(path) => println!("  Warning in {}: {}", path.display(), notice.message),
                None => println!("  Warning: {}", notice.message),
            }
        }
    }

    if outcome.schemas_loaded > 0 || outcome.udts_loaded > 0 {
        println!(
            "\nSuccessfully loaded {} schemas and {} UDTs",
            outcome.schemas_loaded, outcome.udts_loaded
        );
    }

    // Show what ended up in the temporary schema registry.
    let registry_guard = schema_registry.read().await;
    let listed = registry_guard.list_schemas(None).await?;
    if !listed.is_empty() {
        println!("\nRegistered schemas:");
        for entry in &listed {
            println!(
                "  {}.{} ({} columns)",
                entry.keyspace,
                entry.table,
                entry.columns.len()
            );
        }
    }

    // Show how many UDTs ended up in the temporary UDT registry.
    let udt_guard = udt_registry.read().await;
    let udt_count = udt_guard.total_udts();
    if udt_count > 0 {
        println!("\nRegistered {} UDTs", udt_count);
    }

    println!("\nSchema loading completed successfully!");
    Ok(())
}
235
236#[allow(dead_code)]
237async fn validate_schema(file_path: &Path) -> Result<()> {
238    println!("Validating schema: {}", file_path.display());
239
240    // Detect file format based on extension
241    let extension = file_path
242        .extension()
243        .and_then(|ext| ext.to_str())
244        .unwrap_or("");
245
246    match extension.to_lowercase().as_str() {
247        "json" => validate_json_schema(file_path).await,
248        "cql" | "sql" => validate_cql_schema(file_path).await,
249        _ => {
250            // Try to auto-detect based on content
251            let content = std::fs::read_to_string(file_path)
252                .with_context(|| format!("Failed to read schema file: {}", file_path.display()))?;
253
254            if content.trim_start().starts_with('{') {
255                println!("šŸ“ Auto-detected JSON format");
256                validate_json_schema(file_path).await
257            } else if content.to_uppercase().contains("CREATE TABLE") {
258                println!("šŸ“ Auto-detected CQL DDL format");
259                validate_cql_schema(file_path).await
260            } else {
261                println!("āŒ Unable to determine file format. Supported formats:");
262                println!("  - .json files: JSON schema format");
263                println!("  - .cql/.sql files: CQL DDL format");
264                println!("\nExample JSON schema:");
265                println!(
266                    "{{\n  \"keyspace\": \"example\",\n  \"table\": \"users\",\n  \"partition_keys\": [{{\"name\": \"id\", \"type\": \"uuid\", \"position\": 0}}],\n  \"clustering_keys\": [],\n  \"columns\": [{{\"name\": \"id\", \"type\": \"uuid\", \"nullable\": false}}]\n}}"
267                );
268                println!("\nExample CQL DDL:");
269                println!(
270                    "CREATE TABLE example.users (\n  id uuid PRIMARY KEY,\n  name text,\n  email text\n);"
271                );
272                Err(anyhow::anyhow!("Unsupported file format"))
273            }
274        }
275    }
276}
277
278#[allow(dead_code)]
279async fn validate_json_schema(json_path: &Path) -> Result<()> {
280    // Read the JSON file
281    let schema_content = std::fs::read_to_string(json_path)
282        .with_context(|| format!("Failed to read JSON schema file: {}", json_path.display()))?;
283
284    // Try to parse it as a TableSchema
285    match serde_json::from_str::<TableSchema>(&schema_content) {
286        Ok(schema) => {
287            println!("āœ… JSON Schema validation successful!");
288            print_schema_details(&schema);
289        }
290        Err(e) => {
291            println!("āŒ JSON Schema validation failed!");
292            println!("Error: {}", e);
293
294            // Try to provide helpful error messages
295            if e.to_string().contains("missing field") {
296                println!("\nšŸ’” Hint: Make sure all required fields are present:");
297                println!("- keyspace (string)");
298                println!("- table (string)");
299                println!("- partition_keys (array)");
300                println!("- clustering_keys (array)");
301                println!("- columns (array)");
302            } else if e.to_string().contains("unknown variant") {
303                println!("\nšŸ’” Hint: Check that all data types are valid CQL types");
304                println!("Valid types: text, bigint, int, uuid, timestamp, etc.");
305            }
306
307            return Err(e.into());
308        }
309    }
310
311    Ok(())
312}
313
314#[allow(dead_code)]
315async fn validate_cql_schema(cql_path: &Path) -> Result<()> {
316    // Read the CQL file
317    let cql_content = std::fs::read_to_string(cql_path)
318        .with_context(|| format!("Failed to read CQL schema file: {}", cql_path.display()))?;
319
320    // Parse CQL DDL and convert to TableSchema
321    match parse_cql_schema(&cql_content) {
322        Ok(schema) => {
323            println!("āœ… CQL DDL validation successful!");
324            print_schema_details(&schema);
325        }
326        Err(e) => {
327            println!("āŒ CQL DDL validation failed!");
328            println!("Error: {}", e);
329            println!("\nšŸ’” Hints for CQL DDL:");
330            println!("- Use CREATE TABLE keyspace.table_name syntax");
331            println!("- Define PRIMARY KEY explicitly");
332            println!("- Use valid CQL data types");
333            println!("\nExample:");
334            println!("CREATE TABLE example.users (");
335            println!("  id uuid PRIMARY KEY,");
336            println!("  name text,");
337            println!("  created_at timestamp");
338            println!(");");
339            return Err(e.into());
340        }
341    }
342
343    Ok(())
344}
345
346#[allow(dead_code)]
347fn print_schema_details(schema: &TableSchema) {
348    println!("šŸ“‹ Table: {}.{}", schema.keyspace, schema.table);
349    println!("šŸ“Š Columns: {}", schema.columns.len());
350
351    // Show column details
352    for (i, column) in schema.columns.iter().enumerate() {
353        let nullable_str = if column.nullable {
354            "nullable"
355        } else {
356            "not null"
357        };
358        println!(
359            "  {}. {} ({}, {})",
360            i + 1,
361            column.name,
362            column.data_type,
363            nullable_str
364        );
365    }
366
367    if !schema.partition_keys.is_empty() {
368        let key_names: Vec<String> = schema
369            .partition_keys
370            .iter()
371            .map(|k| k.name.clone())
372            .collect();
373        println!("šŸ”‘ Partition keys: {}", key_names.join(", "));
374    }
375
376    if !schema.clustering_keys.is_empty() {
377        let clustering_names: Vec<String> = schema
378            .clustering_keys
379            .iter()
380            .map(|k| k.name.clone())
381            .collect();
382        println!("šŸ”— Clustering keys: {}", clustering_names.join(", "));
383    }
384}
385
386/// Parse CQL DDL and convert to TableSchema
387#[allow(dead_code)]
388fn parse_cql_ddl(cql_content: &str) -> Result<TableSchema> {
389    let cql_content = cql_content.trim().to_uppercase();
390
391    // Find CREATE TABLE statement
392    let create_table_start = cql_content
393        .find("CREATE TABLE")
394        .ok_or_else(|| anyhow::anyhow!("No CREATE TABLE statement found"))?;
395
396    let table_part = &cql_content[create_table_start + 12..].trim(); // Skip "CREATE TABLE"
397
398    // Find the opening parenthesis
399    let paren_start = table_part
400        .find('(')
401        .ok_or_else(|| anyhow::anyhow!("Missing opening parenthesis in CREATE TABLE"))?;
402
403    // Extract table name part
404    let table_name_part = &table_part[..paren_start].trim();
405
406    // Parse keyspace and table name
407    let (keyspace, table_name) = if let Some(dot_pos) = table_name_part.find('.') {
408        let keyspace = table_name_part[..dot_pos].trim().to_lowercase();
409        let table = table_name_part[dot_pos + 1..].trim().to_lowercase();
410        (keyspace, table)
411    } else {
412        ("default".to_string(), table_name_part.trim().to_lowercase())
413    };
414
415    // Find the matching closing parenthesis
416    let mut paren_depth = 0;
417    let mut column_end = paren_start;
418    let table_chars: Vec<char> = table_part.chars().collect();
419
420    for (i, &ch) in table_chars.iter().enumerate().skip(paren_start) {
421        match ch {
422            '(' => paren_depth += 1,
423            ')' => {
424                paren_depth -= 1;
425                if paren_depth == 0 {
426                    column_end = i;
427                    break;
428                }
429            }
430            _ => {}
431        }
432    }
433
434    if paren_depth != 0 {
435        return Err(anyhow::anyhow!("Unmatched parentheses in CREATE TABLE"));
436    }
437
438    // Extract column definitions (between parentheses)
439    let column_definitions = &table_part[paren_start + 1..column_end];
440
441    // Parse column definitions
442    let (columns, partition_keys, clustering_keys) = parse_column_definitions(column_definitions)?;
443
444    let schema = TableSchema {
445        keyspace,
446        table: table_name,
447        partition_keys,
448        clustering_keys,
449        columns,
450        comments: HashMap::new(),
451    };
452
453    // Validate the parsed schema
454    schema
455        .validate()
456        .with_context(|| "Generated schema validation failed")?;
457
458    Ok(schema)
459}
460
461/// Parse column definitions from CQL DDL
462#[allow(dead_code)]
463fn parse_column_definitions(
464    definitions: &str,
465) -> Result<(Vec<Column>, Vec<KeyColumn>, Vec<ClusteringColumn>)> {
466    let mut columns = Vec::new();
467    let mut partition_keys = Vec::new();
468    let mut clustering_keys = Vec::new();
469    let mut primary_key_found = false;
470
471    // Split by commas, but be careful with nested types like map<text, int>
472    let column_parts = split_column_definitions(definitions)?;
473
474    for part in column_parts {
475        let part = part.trim();
476
477        if part.to_uppercase().starts_with("PRIMARY KEY") {
478            // Parse PRIMARY KEY (col1, col2, ...)
479            parse_primary_key_constraint(
480                part,
481                &columns,
482                &mut partition_keys,
483                &mut clustering_keys,
484            )?;
485            primary_key_found = true;
486        } else {
487            // Parse column definition: name type [PRIMARY KEY]
488            let column_parts: Vec<&str> = part.split_whitespace().collect();
489            if column_parts.len() < 2 {
490                return Err(anyhow::anyhow!("Invalid column definition: {}", part));
491            }
492
493            let column_name = column_parts[0].to_string();
494            let column_type = column_parts[1].to_string();
495            let is_primary_key = part.to_uppercase().contains("PRIMARY KEY");
496
497            let column = Column {
498                name: column_name.clone(),
499                data_type: column_type.clone(),
500                nullable: !is_primary_key, // Primary key columns are not nullable
501                default: None,
502                is_static: false, // Quick schema creation doesn't support STATIC yet
503            };
504
505            columns.push(column);
506
507            // If this column is marked as PRIMARY KEY, add it as partition key
508            if is_primary_key && !primary_key_found {
509                partition_keys.push(KeyColumn {
510                    name: column_name,
511                    data_type: column_type,
512                    position: partition_keys.len(),
513                });
514            }
515        }
516    }
517
518    // If no PRIMARY KEY constraint was found and no inline PRIMARY KEY,
519    // assume first column is the primary key
520    if partition_keys.is_empty() && !columns.is_empty() {
521        let first_col = &columns[0];
522        partition_keys.push(KeyColumn {
523            name: first_col.name.clone(),
524            data_type: first_col.data_type.clone(),
525            position: 0,
526        });
527
528        // Update the first column to be non-nullable
529        if let Some(col) = columns.get_mut(0) {
530            col.nullable = false;
531        }
532    }
533
534    Ok((columns, partition_keys, clustering_keys))
535}
536
537/// Split column definitions while respecting nested types
538#[allow(dead_code)]
539fn split_column_definitions(definitions: &str) -> Result<Vec<String>> {
540    let mut parts = Vec::new();
541    let mut current_part = String::new();
542    let mut paren_depth = 0;
543    let mut angle_depth = 0;
544
545    for ch in definitions.chars() {
546        match ch {
547            '(' => paren_depth += 1,
548            ')' => paren_depth -= 1,
549            '<' => angle_depth += 1,
550            '>' => angle_depth -= 1,
551            ',' if paren_depth == 0 && angle_depth == 0 => {
552                if !current_part.trim().is_empty() {
553                    parts.push(current_part.trim().to_string());
554                }
555                current_part.clear();
556                continue;
557            }
558            _ => {}
559        }
560        current_part.push(ch);
561    }
562
563    if !current_part.trim().is_empty() {
564        parts.push(current_part.trim().to_string());
565    }
566
567    Ok(parts)
568}
569
570/// Parse PRIMARY KEY constraint like "PRIMARY KEY (id)" or "PRIMARY KEY ((user_id, tenant_id), created_at)"
571#[allow(dead_code)]
572fn parse_primary_key_constraint(
573    constraint: &str,
574    columns: &[Column],
575    partition_keys: &mut Vec<KeyColumn>,
576    clustering_keys: &mut Vec<ClusteringColumn>,
577) -> Result<()> {
578    // Find the opening parenthesis after PRIMARY KEY
579    let paren_start = constraint
580        .find('(')
581        .ok_or_else(|| anyhow::anyhow!("Missing opening parenthesis in PRIMARY KEY"))?;
582
583    // Find the matching closing parenthesis
584    let mut paren_depth = 0;
585    let mut paren_end = paren_start;
586    let constraint_chars: Vec<char> = constraint.chars().collect();
587
588    for (i, &ch) in constraint_chars.iter().enumerate().skip(paren_start) {
589        match ch {
590            '(' => paren_depth += 1,
591            ')' => {
592                paren_depth -= 1;
593                if paren_depth == 0 {
594                    paren_end = i;
595                    break;
596                }
597            }
598            _ => {}
599        }
600    }
601
602    if paren_depth != 0 {
603        return Err(anyhow::anyhow!("Unmatched parentheses in PRIMARY KEY"));
604    }
605
606    // Extract the key specification (inside parentheses)
607    let key_spec = &constraint[paren_start + 1..paren_end].trim();
608
609    // Check if it's a composite primary key with partition and clustering keys
610    // Format: ((partition_key1, partition_key2), clustering_key1, clustering_key2)
611    if key_spec.trim_start().starts_with('(') && key_spec.contains("),") {
612        // Parse composite key
613        parse_composite_primary_key(key_spec, columns, partition_keys, clustering_keys)
614    } else {
615        // Simple primary key - all columns are partition keys
616        let key_names: Vec<&str> = key_spec.split(',').map(|s| s.trim()).collect();
617
618        for (position, key_name) in key_names.iter().enumerate() {
619            let column = columns
620                .iter()
621                .find(|c| c.name == *key_name)
622                .ok_or_else(|| {
623                    anyhow::anyhow!(
624                        "Primary key column '{}' not found in column definitions",
625                        key_name
626                    )
627                })?;
628
629            partition_keys.push(KeyColumn {
630                name: column.name.clone(),
631                data_type: column.data_type.clone(),
632                position,
633            });
634        }
635
636        Ok(())
637    }
638}
639
640/// Parse composite primary key with explicit partition and clustering keys
641#[allow(dead_code)]
642fn parse_composite_primary_key(
643    key_spec: &str,
644    columns: &[Column],
645    partition_keys: &mut Vec<KeyColumn>,
646    clustering_keys: &mut Vec<ClusteringColumn>,
647) -> Result<()> {
648    // Find the end of the partition key specification
649    let mut paren_depth = 0;
650    let mut partition_end = 0;
651
652    for (i, ch) in key_spec.char_indices() {
653        match ch {
654            '(' => paren_depth += 1,
655            ')' => {
656                paren_depth -= 1;
657                if paren_depth == 0 {
658                    partition_end = i;
659                    break;
660                }
661            }
662            _ => {}
663        }
664    }
665
666    if partition_end == 0 {
667        return Err(anyhow::anyhow!("Invalid composite primary key format"));
668    }
669
670    // Extract partition keys (inside the first parentheses)
671    let partition_spec = &key_spec[1..partition_end]; // Skip the opening '('
672    let partition_names: Vec<&str> = partition_spec.split(',').map(|s| s.trim()).collect();
673
674    for (position, key_name) in partition_names.iter().enumerate() {
675        let column = columns
676            .iter()
677            .find(|c| c.name == *key_name)
678            .ok_or_else(|| anyhow::anyhow!("Partition key column '{}' not found", key_name))?;
679
680        partition_keys.push(KeyColumn {
681            name: column.name.clone(),
682            data_type: column.data_type.clone(),
683            position,
684        });
685    }
686
687    // Extract clustering keys (after the first parentheses)
688    let remaining = &key_spec[partition_end + 1..].trim();
689    if remaining.starts_with(',') {
690        let clustering_spec = &remaining[1..].trim(); // Skip the comma
691        let clustering_names: Vec<&str> = clustering_spec.split(',').map(|s| s.trim()).collect();
692
693        for (position, key_name) in clustering_names.iter().enumerate() {
694            if key_name.is_empty() {
695                continue;
696            }
697
698            let column = columns
699                .iter()
700                .find(|c| c.name == *key_name)
701                .ok_or_else(|| anyhow::anyhow!("Clustering key column '{}' not found", key_name))?;
702
703            clustering_keys.push(ClusteringColumn {
704                name: column.name.clone(),
705                data_type: column.data_type.clone(),
706                position,
707                order: ClusteringOrder::Asc, // Default to ASC
708            });
709        }
710    }
711
712    Ok(())
713}