use anyhow::{Context, Result, anyhow, bail};
use clap::Args;
use colored::Colorize;
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use crate::client::{AthenaClient, backend::QueryResult};
use crate::provisioning::split_provision_statements;
/// Dotenv files read from the current working directory, in merge order, when no
/// `--env-file` is given; a key in a later file overrides the same key from an earlier one.
const DEFAULT_SQL_ENV_FILES: [&str; 2] = [".env", ".env.local"];
/// Maximum number of result rows printed for each executed statement.
const SQL_RESULT_ROW_PREVIEW_LIMIT: usize = 20;
/// Env keys that may supply the Athena URL, consulted in this order.
const ATHENA_URL_ENV_KEYS: [&str; 4] = [
    "ATHENA_URL",
    "ATHENA_DB_URL",
    "ATHENA_GATEWAY_URL",
    "NEXT_PUBLIC_ATHENA_BASE_URL",
];
/// Env keys that may supply the Athena API key, consulted in this order.
const ATHENA_KEY_ENV_KEYS: [&str; 3] =
    ["ATHENA_API_KEY", "ATHENA_GATEWAY_API_KEY", "ATHENA_KEY_12"];
/// Env keys that may supply the Athena client name, consulted in this order.
const ATHENA_CLIENT_ENV_KEYS: [&str; 2] = ["ATHENA_CLIENT", "NEXT_PUBLIC_ATHENA_CLIENT"];
// Arguments for the `sql` subcommand. Each connection setting resolves from its explicit flag
// first, then from the selected env file(s), then from the process environment.
// Plain `//` comments are used at struct level on purpose: doc comments on a clap-derived
// struct can be picked up as command help text. The `///` docs on the fields below are
// intentional and become each argument's `--help` description.
#[derive(Args, Debug, Clone)]
pub struct SqlArgs {
    /// SQL file whose statements are applied in order
    #[arg(value_name = "PATH")]
    pub sql_file: PathBuf,
    /// Dotenv file to load instead of the default `.env` and `.env.local` (relative paths are
    /// resolved against the current directory)
    #[arg(long, value_name = "PATH")]
    pub env_file: Option<PathBuf>,
    /// Athena URL; takes precedence over ATHENA_URL, ATHENA_DB_URL, ATHENA_GATEWAY_URL and
    /// NEXT_PUBLIC_ATHENA_BASE_URL
    #[arg(long, value_name = "URL")]
    pub url: Option<String>,
    /// Athena API key; takes precedence over ATHENA_API_KEY, ATHENA_GATEWAY_API_KEY and
    /// ATHENA_KEY_12
    #[arg(long, value_name = "KEY")]
    pub key: Option<String>,
    /// Athena client name; takes precedence over ATHENA_CLIENT and NEXT_PUBLIC_ATHENA_CLIENT
    #[arg(long, value_name = "CLIENT")]
    pub client: Option<String>,
}
/// Key/value pairs merged from the dotenv file(s) selected for a `sql` run.
#[derive(Debug, Clone)]
struct LoadedSqlEnv {
    /// Merged entries; a key defined in a later file replaces the value from an earlier one.
    values: HashMap<String, String>,
    /// Human-readable, comma-separated list of the files that were loaded, or a fixed note
    /// when no file was loaded at all.
    source_label: String,
}
/// Fully resolved connection settings for a `sql` run.
///
/// `Debug` is implemented by hand so the API key is never written to logs or panic messages.
#[derive(Clone)]
struct ResolvedSqlCommand {
    /// Athena gateway URL.
    url: String,
    /// Athena API key (secret; redacted in `Debug` output).
    key: String,
    /// Athena client name.
    client: String,
}

impl std::fmt::Debug for ResolvedSqlCommand {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedSqlCommand")
            .field("url", &self.url)
            .field("key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}
impl ResolvedSqlCommand {
    /// Resolves the Athena URL, API key, and client name for a `sql` run.
    ///
    /// Each setting comes from the explicit CLI flag first, then from the first matching key
    /// in `file_env`, then from the process environment; blank values are treated as unset.
    ///
    /// # Errors
    ///
    /// Returns an error naming the CLI flag and the accepted env keys for the first setting
    /// that cannot be resolved.
    fn from_args(args: &SqlArgs, file_env: &HashMap<String, String>) -> Result<Self> {
        // `with_context` builds the (allocating) error text only on the failure path.
        let url = resolve_preferred_value(args.url.as_deref(), file_env, &ATHENA_URL_ENV_KEYS)
            .with_context(|| {
                format!(
                    "missing Athena URL; pass --url or set one of {} in the selected env file(s)",
                    join_env_keys(&ATHENA_URL_ENV_KEYS)
                )
            })?;
        let key = resolve_preferred_value(args.key.as_deref(), file_env, &ATHENA_KEY_ENV_KEYS)
            .with_context(|| {
                format!(
                    "missing Athena API key; pass --key or set one of {} in the selected env file(s)",
                    join_env_keys(&ATHENA_KEY_ENV_KEYS)
                )
            })?;
        let client =
            resolve_preferred_value(args.client.as_deref(), file_env, &ATHENA_CLIENT_ENV_KEYS)
                .with_context(|| {
                    format!(
                        "missing Athena client name; pass --client or set one of {} in the selected env file(s)",
                        join_env_keys(&ATHENA_CLIENT_ENV_KEYS)
                    )
                })?;
        Ok(Self { url, key, client })
    }
}
/// Executes every statement in `args.sql_file` against Athena, printing progress as it goes.
///
/// Connection settings are resolved from CLI flags, the selected env file(s), and the process
/// environment. Statements run sequentially; the first failing statement aborts the run.
///
/// # Errors
///
/// Fails when the working directory, env files, or SQL file cannot be read, when a setting
/// is missing, when the file has no executable statements, when the client cannot be built,
/// or when any statement is rejected.
pub async fn run_sql_command(args: SqlArgs) -> Result<()> {
    // Prints a cyan "INFO" tag followed by `message` in bright white.
    fn announce(message: String) {
        println!("{} {}", "INFO".bright_cyan().bold(), message.bright_white());
    }

    let working_dir = std::env::current_dir().context("resolving current working directory")?;
    let env = load_sql_env_values(&working_dir, args.env_file.as_deref())?;
    let settings = ResolvedSqlCommand::from_args(&args, &env.values)?;

    let script = fs::read_to_string(&args.sql_file)
        .with_context(|| format!("reading SQL file {}", args.sql_file.display()))?;
    let statements = split_provision_statements(&script);
    if statements.is_empty() {
        bail!(
            "SQL file '{}' did not contain any executable statements",
            args.sql_file.display()
        );
    }
    let total = statements.len();

    announce(format!(
        "Applying {} statement(s) from {}",
        total,
        args.sql_file.display()
    ));
    announce(format!("Env source: {}", env.source_label));
    announce(format!("Athena URL: {}", settings.url));
    announce(format!("Athena client: {}", settings.client));

    let athena = AthenaClient::new(&settings.url, &settings.key, &settings.client)
        .await
        .map_err(|err| {
            anyhow!(
                "failed to build Athena client for '{}' / '{}': {}",
                settings.url,
                settings.client,
                err
            )
        })?;

    for (position, statement) in statements.iter().enumerate() {
        let ordinal = position + 1;
        let outcome = athena.sql(statement).await.map_err(|err| {
            anyhow!(
                "statement {}/{} failed ({}): {}",
                ordinal,
                total,
                statement_preview(statement),
                err
            )
        })?;
        print_statement_result(ordinal, total, statement, &outcome)?;
    }

    let summary = format!(
        "Applied {} statement(s) from {}",
        total,
        args.sql_file.display()
    );
    println!("{} {}", "OK".bright_green().bold(), summary.bright_white());
    Ok(())
}
fn load_sql_env_values(base_dir: &Path, env_file: Option<&Path>) -> Result<LoadedSqlEnv> {
let mut values = HashMap::new();
let mut loaded_paths = Vec::new();
if let Some(env_file) = env_file {
let path = resolve_env_path(base_dir, env_file);
merge_dotenv_file(&path, &mut values)?;
loaded_paths.push(path);
} else {
for file_name in DEFAULT_SQL_ENV_FILES {
let path = base_dir.join(file_name);
if !path.exists() {
continue;
}
merge_dotenv_file(&path, &mut values)?;
loaded_paths.push(path);
}
}
let source_label = if loaded_paths.is_empty() {
"process environment fallback only".to_string()
} else {
loaded_paths
.iter()
.map(|path| path.display().to_string())
.collect::<Vec<_>>()
.join(", ")
};
Ok(LoadedSqlEnv {
values,
source_label,
})
}
/// Resolves a user-supplied env file path: absolute paths are returned unchanged, relative
/// paths are joined onto `base_dir`.
fn resolve_env_path(base_dir: &Path, env_file: &Path) -> PathBuf {
    if !env_file.is_absolute() {
        return base_dir.join(env_file);
    }
    env_file.to_owned()
}
/// Reads `path` as a dotenv file and inserts its entries into `values`, replacing any value
/// already stored under the same key.
///
/// # Errors
///
/// Propagates open/parse failures from `read_dotenv_file`. In that case nothing is inserted
/// into `values`.
fn merge_dotenv_file(path: &Path, values: &mut HashMap<String, String>) -> Result<()> {
    let entries = read_dotenv_file(path)?;
    values.extend(entries);
    Ok(())
}
/// Parses the dotenv file at `path` into `(key, value)` pairs, preserving file order.
///
/// Pairs are only collected and returned; nothing here writes them into the process
/// environment.
///
/// # Errors
///
/// Fails if the file cannot be opened, or on the first entry that fails to parse.
// The `allow` silences the deprecation warning on `dotenv::from_path_iter`.
// NOTE(review): consider migrating to the maintained `dotenvy` fork, which keeps this API.
#[allow(deprecated)]
fn read_dotenv_file(path: &Path) -> Result<Vec<(String, String)>> {
    dotenv::from_path_iter(path)
        .with_context(|| format!("opening dotenv file {}", path.display()))?
        .map(|entry| entry.with_context(|| format!("parsing dotenv file {}", path.display())))
        .collect()
}
/// Returns the first non-blank value for a setting, trimmed.
///
/// Precedence: `explicit` (typically a CLI flag), then `file_env` looked up by each of `keys`
/// in order, then the process environment looked up by the same `keys` in order. Blank or
/// whitespace-only values at every level are treated as unset, so the search continues.
fn resolve_preferred_value(
    explicit: Option<&str>,
    file_env: &HashMap<String, String>,
    keys: &[&str],
) -> Option<String> {
    normalize_non_empty(explicit)
        .or_else(|| {
            keys.iter()
                .find_map(|key| normalize_non_empty(file_env.get(*key).map(String::as_str)))
        })
        .or_else(|| {
            // Normalize inside `find_map` (as for `file_env`) so an exported-but-blank key,
            // e.g. `ATHENA_URL=`, does not hide a valid value under a later key.
            keys.iter()
                .find_map(|key| normalize_non_empty(std::env::var(key).ok().as_deref()))
        })
}

/// Trims `value` and returns it as an owned string, or `None` if it is absent or blank.
fn normalize_non_empty(value: Option<&str>) -> Option<String> {
    value
        .map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
        .map(ToOwned::to_owned)
}
/// Formats env key names as a comma-separated list for error messages.
fn join_env_keys(keys: &[&str]) -> String {
    const SEPARATOR: &str = ", ";
    keys.join(SEPARATOR)
}
/// Collapses all whitespace runs in `statement` to single spaces and truncates the result to
/// at most `PREVIEW_LIMIT` characters, appending `...` when truncated.
///
/// Truncation counts characters, not bytes, so statements containing non-ASCII text (string
/// literals, comments) never cause a slice through the middle of a UTF-8 sequence. For ASCII
/// input this is identical to a byte-based cut.
fn statement_preview(statement: &str) -> String {
    const PREVIEW_LIMIT: usize = 96;
    let normalized = statement.split_whitespace().collect::<Vec<_>>().join(" ");
    // The byte offset of the (PREVIEW_LIMIT + 1)-th char, if any, is a valid char boundary
    // marking the end of the first PREVIEW_LIMIT characters.
    match normalized.char_indices().nth(PREVIEW_LIMIT) {
        Some((cut, _)) => format!("{}...", &normalized[..cut]),
        None => normalized,
    }
}
/// Prints the outcome of statement `index` of `total`.
///
/// Output is an "OK" headline with a preview of the statement, then any status, message, and
/// count reported by Athena. If rows were returned, it also prints a pretty-printed JSON
/// preview of up to `SQL_RESULT_ROW_PREVIEW_LIMIT` rows.
///
/// # Errors
///
/// Fails only if the row preview cannot be serialized to JSON.
fn print_statement_result(
    index: usize,
    total: usize,
    statement: &str,
    result: &QueryResult,
) -> Result<()> {
    // Prints a cyan "INFO" tag followed by `text` in bright white.
    fn info(text: String) {
        println!("{} {}", "INFO".bright_cyan().bold(), text.bright_white());
    }

    let headline = format!(
        "Statement {}/{} applied: {}",
        index,
        total,
        statement_preview(statement)
    );
    println!("{} {}", "OK".bright_green().bold(), headline.bright_white());

    if let Some(status) = result.status.as_deref() {
        info(format!("Status: {}", status));
    }
    if let Some(message) = result.message.as_deref() {
        info(format!("Message: {}", message));
    }
    if let Some(count) = result.count {
        info(format!("Count: {}", count));
    }

    if result.rows.is_empty() {
        return Ok(());
    }

    let preview: Vec<Value> = result
        .rows
        .iter()
        .take(SQL_RESULT_ROW_PREVIEW_LIMIT)
        .cloned()
        .collect();
    info(format!(
        "Returned {} row(s); showing {} row(s).",
        result.rows.len(),
        preview.len()
    ));
    let rendered =
        serde_json::to_string_pretty(&preview).context("serializing Athena SQL row preview")?;
    println!("{}", rendered);
    if result.rows.len() > SQL_RESULT_ROW_PREVIEW_LIMIT {
        info(format!(
            "Output truncated after {} row(s).",
            SQL_RESULT_ROW_PREVIEW_LIMIT
        ));
    }
    Ok(())
}
// Unit tests for env-file loading and setting resolution. Several tests change process
// environment variables through `TestEnvGuard` (presumably restoring the previous values on
// drop — confirm in `crate::test_support`), so each test binds its guard for the whole body.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::TestEnvGuard;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Scratch directory under the system temp dir, removed (best effort) on drop.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Creates a fresh directory whose name combines `prefix`, the process id, and the
        /// current time in nanoseconds, making collisions between test runs unlikely.
        fn new(prefix: &str) -> Self {
            let unique = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("current time should be after unix epoch")
                .as_nanos();
            let path = std::env::temp_dir().join(format!(
                "athena_cli_sql_{}_{}_{}",
                prefix,
                std::process::id(),
                unique
            ));
            fs::create_dir_all(&path).expect("temp dir should be created");
            Self { path }
        }

        /// Writes `contents` to `relative` inside the directory, panicking on failure.
        fn write(&self, relative: &str, contents: &str) {
            fs::write(self.path.join(relative), contents).expect("temp file should be written");
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp dir must not fail the test.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Unsets every env key the SQL command consults so the process-env fallback cannot
    /// leak into a test. Keep the returned guard alive for the duration of the test.
    fn clear_sql_env_keys() -> TestEnvGuard {
        TestEnvGuard::apply(&[
            ("ATHENA_URL", None),
            ("ATHENA_DB_URL", None),
            ("ATHENA_GATEWAY_URL", None),
            ("NEXT_PUBLIC_ATHENA_BASE_URL", None),
            ("ATHENA_API_KEY", None),
            ("ATHENA_GATEWAY_API_KEY", None),
            ("ATHENA_KEY_12", None),
            ("ATHENA_CLIENT", None),
            ("NEXT_PUBLIC_ATHENA_CLIENT", None),
        ])
    }

    // Both default files are read; `.env.local` overrides `ATHENA_CLIENT` from `.env`, while
    // keys only present in `.env` survive.
    #[test]
    fn load_sql_env_values_merges_default_env_files_with_local_override() {
        let _env = clear_sql_env_keys();
        let temp = TempDir::new("defaults");
        temp.write(
            ".env",
            "ATHENA_DB_URL=https://base.example.com\nATHENA_KEY_12=base-key\nATHENA_CLIENT=base-client\n",
        );
        temp.write(".env.local", "ATHENA_CLIENT=local-client\n");
        let loaded = load_sql_env_values(&temp.path, None).expect("default env files should load");
        assert_eq!(
            loaded.values.get("ATHENA_DB_URL").map(String::as_str),
            Some("https://base.example.com")
        );
        assert_eq!(
            loaded.values.get("ATHENA_KEY_12").map(String::as_str),
            Some("base-key")
        );
        assert_eq!(
            loaded.values.get("ATHENA_CLIENT").map(String::as_str),
            Some("local-client")
        );
        assert!(loaded.source_label.contains(".env"));
        assert!(loaded.source_label.contains(".env.local"));
    }

    // An explicit env file (relative to the base dir) replaces the defaults: `.env` is present
    // but must not be loaded or listed in the source label.
    #[test]
    fn load_sql_env_values_uses_explicit_env_file_instead_of_defaults() {
        let _env = clear_sql_env_keys();
        let temp = TempDir::new("explicit");
        temp.write(".env", "ATHENA_CLIENT=base-client\n");
        temp.write(
            "ops.env",
            "ATHENA_CLIENT=ops-client\nATHENA_KEY_12=ops-key\n",
        );
        let loaded = load_sql_env_values(&temp.path, Some(Path::new("ops.env")))
            .expect("explicit env file should load");
        assert_eq!(
            loaded.values.get("ATHENA_CLIENT").map(String::as_str),
            Some("ops-client")
        );
        assert_eq!(
            loaded.values.get("ATHENA_KEY_12").map(String::as_str),
            Some("ops-key")
        );
        assert!(!loaded.source_label.contains(".env,"));
        assert!(loaded.source_label.ends_with("ops.env"));
    }

    // Precedence check: the CLI flag wins for the URL; for key and client (no flag) the env
    // file values win over the process environment values set below.
    #[test]
    fn resolved_sql_command_prefers_cli_then_file_then_process_env() {
        let _env = TestEnvGuard::apply(&[
            ("ATHENA_DB_URL", Some("https://process.example.com")),
            ("ATHENA_KEY_12", Some("process-key")),
            ("ATHENA_CLIENT", Some("process-client")),
            ("ATHENA_URL", None),
            ("ATHENA_GATEWAY_URL", None),
            ("NEXT_PUBLIC_ATHENA_BASE_URL", None),
            ("ATHENA_API_KEY", None),
            ("ATHENA_GATEWAY_API_KEY", None),
            ("NEXT_PUBLIC_ATHENA_CLIENT", None),
        ]);
        let args = SqlArgs {
            sql_file: PathBuf::from("sql/provision.sql"),
            env_file: None,
            url: Some("https://cli.example.com".to_string()),
            key: None,
            client: None,
        };
        let file_env = HashMap::from([
            ("ATHENA_KEY_12".to_string(), "file-key".to_string()),
            ("ATHENA_CLIENT".to_string(), "file-client".to_string()),
        ]);
        let resolved =
            ResolvedSqlCommand::from_args(&args, &file_env).expect("settings should resolve");
        assert_eq!(resolved.url, "https://cli.example.com");
        assert_eq!(resolved.key, "file-key");
        assert_eq!(resolved.client, "file-client");
    }

    // With no client flag, file entry, or process env value, resolution fails and the error
    // names the env key the user could set.
    #[test]
    fn resolved_sql_command_requires_client_name() {
        let _env = clear_sql_env_keys();
        let args = SqlArgs {
            sql_file: PathBuf::from("sql/provision.sql"),
            env_file: None,
            url: Some("https://gateway.example.com".to_string()),
            key: Some("secret".to_string()),
            client: None,
        };
        let err = ResolvedSqlCommand::from_args(&args, &HashMap::new())
            .expect_err("client should be required");
        assert!(err.to_string().contains("ATHENA_CLIENT"));
    }
}