use std::fs;
use std::io::{self, IsTerminal, Read};
use std::ops::Deref;
use std::path::PathBuf;
use clap::ValueEnum;
use syntaqlite::any::AnyDialect;
use syntaqlite::any::{AnyParser, ParseOutcome};
use syntaqlite::fmt::FormatError;
use syntaqlite::fmt::KeywordCase;
use syntaqlite::semantic::DiagnosticMessage;
use syntaqlite::semantic::Severity;
use syntaqlite::util::DiagnosticRenderer;
use syntaqlite::{
Catalog, Diagnostic, FormatConfig, Formatter, SemanticAnalyzer, ValidationConfig,
};
use super::{Cli, Command};
#[derive(Clone, Copy, ValueEnum)]
pub(crate) enum KeywordCasing {
Upper,
Lower,
}
#[derive(Clone, Copy, ValueEnum)]
pub(crate) enum HostLanguage {
Python,
Typescript,
}
fn expand_paths(patterns: &[String]) -> Result<Vec<PathBuf>, String> {
let mut out = Vec::new();
for pat in patterns {
let matches: Vec<_> = glob::glob(pat)
.map_err(|e| format!("bad glob pattern {pat:?}: {e}"))?
.collect();
if matches.is_empty() {
return Err(format!("no files matched: {pat}"));
}
for entry in matches {
let path = entry.map_err(|e| format!("glob error: {e}"))?;
if path.is_file() {
out.push(path);
}
}
}
Ok(out)
}
fn require_dialect(dialect: Option<AnyDialect>) -> Result<AnyDialect, String> {
dialect.ok_or_else(|| {
"this command requires a dialect; build with --features=builtin-sqlite or use --dialect"
.to_string()
})
}
pub(crate) fn dispatch(cli: Cli, dialect: Option<AnyDialect>) -> Result<(), String> {
let base = if let Some(path) = &cli.dialect_path {
Some(
AnyDialect::load(path, cli.dialect_name.as_deref()).unwrap_or_else(|e| {
eprintln!("error: {e}");
std::process::exit(1);
}),
)
} else {
dialect
};
let configured = match base {
Some(d) => Some(apply_version_cflags(
d,
cli.sqlite_version.as_ref(),
&cli.sqlite_cflag,
)?),
None => None,
};
dispatch_commands(cli.command, configured)
}
fn apply_version_cflags(
mut dialect: AnyDialect,
version: Option<&String>,
cflags: &[String],
) -> Result<AnyDialect, String> {
use syntaqlite::util::{SqliteFlags, SqliteVersion};
if let Some(v) = version {
let ver = SqliteVersion::parse_with_latest(v)
.map_err(|e| format!("invalid --sqlite-version: {e}"))?;
dialect = dialect.with_version(ver);
}
if !cflags.is_empty() {
let mut flags = SqliteFlags::default();
for name in cflags {
let flag = syntaqlite::util::SqliteFlag::from_name(name)
.ok_or_else(|| format!("unknown --sqlite-cflag: {name}"))?;
flags = flags.with(flag);
}
dialect = dialect.with_cflags(flags);
}
Ok(dialect)
}
fn dispatch_commands(command: Command, dialect: Option<AnyDialect>) -> Result<(), String> {
match command {
Command::Parse {
files,
expression,
output,
} => require_dialect(dialect)
.and_then(|d| cmd_parse(&d, &files, expression.as_deref(), output)),
Command::Validate {
files,
expression,
schema,
lang,
} => require_dialect(dialect)
.and_then(|d| cmd_validate(&d, &files, expression.as_deref(), &schema, lang)),
Command::Lsp => require_dialect(dialect).and_then(|d| {
syntaqlite::lsp::LspServer::run(d).map_err(|e| format!("LSP error: {e}"))
}),
Command::Fmt {
files,
expression,
line_width,
indent_width,
keyword_case,
in_place,
check,
semicolons,
} => {
let config = FormatConfig::default()
.with_line_width(line_width)
.with_indent_width(indent_width)
.with_keyword_case(match keyword_case {
KeywordCasing::Upper => KeywordCase::Upper,
KeywordCasing::Lower => KeywordCase::Lower,
})
.with_semicolons(semicolons);
require_dialect(dialect)
.and_then(|d| cmd_fmt(&d, &files, expression.as_deref(), &config, in_place, check))
}
Command::Version => {
println!("syntaqlite {}", env!("CARGO_PKG_VERSION"));
Ok(())
}
Command::Dialect(args) => crate::codegen::dispatch_dialect(&args),
Command::DialectTool(cmd) => crate::codegen::dispatch_tool(cmd),
}
}
fn read_stdin() -> Result<String, String> {
if io::stdin().is_terminal() {
eprintln!("reading from stdin; paste SQL then press Ctrl-D (or pass files as arguments)");
}
let mut buf = String::new();
io::stdin()
.read_to_string(&mut buf)
.map_err(|e| format!("reading stdin: {e}"))?;
Ok(buf)
}
fn process_files(
files: &[String],
expression: Option<&str>,
on_inline: impl FnOnce(&str, &str) -> Result<(), String>,
mut on_file: impl FnMut(&str, &PathBuf, bool) -> Result<(), String>,
) -> Result<(), String> {
if let Some(expr) = expression {
return on_inline(expr, "<expression>");
}
let paths = expand_paths(files)?;
if paths.is_empty() {
return on_inline(&read_stdin()?, "<stdin>");
}
let multi = paths.len() > 1;
for path in &paths {
let source = fs::read_to_string(path).map_err(|e| format!("{}: {e}", path.display()))?;
on_file(&source, path, multi)?;
}
Ok(())
}
fn cmd_parse(
dialect: &AnyDialect,
files: &[String],
expression: Option<&str>,
output: crate::ParseOutput,
) -> Result<(), String> {
let mut total_stmts = 0u64;
let mut total_errors = 0u64;
let mut json_nodes: Vec<serde_json::Value> = Vec::new();
let mut process_source = |source: &str, file: &str| {
let (s, e, nodes) = cmd_parse_source(dialect, source, file, output);
total_stmts += s;
total_errors += e;
json_nodes.extend(nodes);
};
if let Some(expr) = expression {
process_source(expr, "<expression>");
} else {
let paths = expand_paths(files)?;
if paths.is_empty() {
let source = read_stdin()?;
process_source(&source, "<stdin>");
} else {
let multi = paths.len() > 1;
for path in &paths {
let source =
fs::read_to_string(path).map_err(|e| format!("{}: {e}", path.display()))?;
let file = path.display().to_string();
if multi && matches!(output, crate::ParseOutput::Text) {
println!("==> {file} <==");
}
process_source(&source, &file);
}
}
}
match output {
crate::ParseOutput::Summary => {
println!("{total_stmts} statements parsed, {total_errors} errors");
}
crate::ParseOutput::Json => {
println!(
"{}",
serde_json::to_string_pretty(&json_nodes)
.map_err(|e| format!("JSON serialization failed: {e}"))?
);
}
crate::ParseOutput::Text => {}
}
if total_errors > 0 {
Err(format!("{total_errors} syntax error(s)"))
} else {
Ok(())
}
}
fn cmd_parse_source(
dialect: &AnyDialect,
source: &str,
file: &str,
output: crate::ParseOutput,
) -> (u64, u64, Vec<serde_json::Value>) {
let parser = AnyParser::new(dialect.deref().clone());
let mut session = parser.parse(source);
let mut ast_out = String::new();
let mut json_nodes: Vec<serde_json::Value> = Vec::new();
let mut error_diags: Vec<Diagnostic> = Vec::new();
let mut count = 0u64;
loop {
match session.next() {
ParseOutcome::Ok(stmt) => {
match output {
crate::ParseOutput::Text => {
if count > 0 {
ast_out.push_str("----\n");
}
stmt.dump(&mut ast_out, 0);
}
crate::ParseOutput::Json => {
let val = stmt
.erase()
.root_node()
.map_or(serde_json::Value::Null, |n| {
serde_json::to_value(n).unwrap_or(serde_json::Value::Null)
});
json_nodes.push(val);
}
crate::ParseOutput::Summary => {}
}
count += 1;
}
ParseOutcome::Err(err) => {
let start = err.offset().unwrap_or(0);
let end = start + err.length().unwrap_or(0);
error_diags.push(Diagnostic::new(
start,
end,
DiagnosticMessage::Other(err.message().to_string()),
Severity::Error,
None,
));
}
ParseOutcome::Done => break,
}
}
if matches!(output, crate::ParseOutput::Text) {
print!("{ast_out}");
}
let n_err = error_diags.len() as u64;
if !error_diags.is_empty() {
DiagnosticRenderer::new(source, file)
.render_diagnostics(&error_diags, &mut io::stderr())
.ok();
}
(count, n_err, json_nodes)
}
fn cmd_fmt(
dialect: &AnyDialect,
files: &[String],
expression: Option<&str>,
config: &FormatConfig,
in_place: bool,
check: bool,
) -> Result<(), String> {
let mut errors = Vec::new();
let mut unformatted = Vec::new();
process_files(
files,
expression,
|source, label| {
if in_place || check {
return Err(format!(
"--{} requires file arguments",
if check { "check" } else { "in-place" }
));
}
let out = format_source(dialect, source, config).map_err(|e| {
render_format_error(&e, source, label);
format!("{label}: {e}")
})?;
print!("{out}");
Ok(())
},
|source, path, multi| {
match format_source(dialect, source, config) {
Ok(out) => {
if check {
if out != source {
unformatted.push(path.display().to_string());
}
} else if in_place {
if out != source {
fs::write(path, &out)
.map_err(|e| format!("{}: {e}", path.display()))?;
eprintln!("formatted {}", path.display());
}
} else {
if multi {
println!("==> {} <==", path.display());
}
print!("{out}");
}
}
Err(e) => {
let label = path.display().to_string();
render_format_error(&e, source, &label);
errors.push(format!("{label}: {e}"));
}
}
Ok(())
},
)?;
if !errors.is_empty() {
return Err(errors.join("\n"));
}
if !unformatted.is_empty() {
for f in &unformatted {
eprintln!("would reformat {f}");
}
return Err(format!(
"{} file(s) would be reformatted",
unformatted.len()
));
}
Ok(())
}
fn render_format_error(e: &FormatError, source: &str, file: &str) {
let start = e.offset().unwrap_or(0);
let end = start + e.length().unwrap_or(0);
let diag = Diagnostic::new(
start,
end,
DiagnosticMessage::Other(e.message().to_owned()),
Severity::Error,
None,
);
DiagnosticRenderer::new(source, file)
.render_diagnostic(&diag, &mut io::stderr())
.ok();
}
fn format_source(
dialect: &AnyDialect,
source: &str,
config: &FormatConfig,
) -> Result<String, FormatError> {
Formatter::with_dialect_config(dialect.clone(), config).format(source)
}
fn build_schema_catalog(dialect: &AnyDialect, schema_files: &[String]) -> Result<Catalog, String> {
if schema_files.is_empty() {
return Ok(Catalog::new(dialect.clone()));
}
let paths = expand_paths(schema_files)?;
let mut sources = Vec::new();
let mut uris = Vec::new();
for path in &paths {
let source =
fs::read_to_string(path).map_err(|e| format!("schema {}: {e}", path.display()))?;
uris.push(format!("file://{}", path.display()));
sources.push(source);
}
let pairs: Vec<(&str, Option<&str>)> = sources
.iter()
.zip(uris.iter())
.map(|(s, u)| (s.as_str(), Some(u.as_str())))
.collect();
let (catalog, errors) = Catalog::from_ddl(dialect.clone(), &pairs);
for err in &errors {
eprintln!("warning: schema: {err}");
}
Ok(catalog)
}
fn cmd_validate(
dialect: &AnyDialect,
files: &[String],
expression: Option<&str>,
schema_files: &[String],
lang: Option<HostLanguage>,
) -> Result<(), String> {
let schema_catalog = build_schema_catalog(dialect, schema_files)?;
let config = ValidationConfig::default();
let mut any_errors = false;
let validate = |source: &str, file: &str| -> bool {
match lang {
Some(lang) => {
validate_embedded_source(dialect, source, file, &config, lang, schema_files)
}
None => validate_source(dialect, source, file, &config, &schema_catalog),
}
};
process_files(
files,
expression,
|source, label| {
if validate(source, label) {
std::process::exit(1);
}
Ok(())
},
|source, path, multi| {
let file = path.display().to_string();
if multi {
println!("==> {file} <==");
}
if validate(source, &file) {
any_errors = true;
}
Ok(())
},
)?;
if any_errors {
std::process::exit(1);
}
Ok(())
}
fn validate_source(
dialect: &AnyDialect,
source: &str,
file: &str,
config: &ValidationConfig,
schema_catalog: &Catalog,
) -> bool {
let mut analyzer = SemanticAnalyzer::with_dialect(dialect.clone());
let model = analyzer.analyze(source, schema_catalog, config);
DiagnosticRenderer::new(source, file)
.render_diagnostics(model.diagnostics(), &mut io::stderr())
.unwrap_or(false)
}
fn validate_embedded_source(
dialect: &AnyDialect,
source: &str,
file: &str,
config: &ValidationConfig,
lang: HostLanguage,
schema_files: &[String],
) -> bool {
let fragments = match lang {
HostLanguage::Python => syntaqlite::embedded::extract_python(source),
HostLanguage::Typescript => syntaqlite::embedded::extract_typescript(source),
};
if fragments.is_empty() {
eprintln!("no SQL fragments found in {file}");
return false;
}
let catalog = match build_schema_catalog(dialect, schema_files) {
Ok(c) => c,
Err(e) => {
eprintln!("error: {e}");
return true;
}
};
let diags = syntaqlite::embedded::EmbeddedAnalyzer::new(dialect.clone())
.with_catalog(catalog)
.with_config(*config)
.validate(&fragments);
DiagnosticRenderer::new(source, file)
.render_diagnostics(&diags, &mut io::stderr())
.unwrap_or(false)
}