Skip to main content

openauth_cli/
app.rs

1use std::ffi::OsString;
2use std::fs;
3use std::io;
4use std::path::{Path, PathBuf};
5
6use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
7use clap_complete::Shell;
8use inquire::Confirm;
9use openauth_core::db::sql::SqlDialect;
10use serde::Serialize;
11
12use crate::config::{CliConfig, ConfigError};
13use crate::db::{self, DbCliError};
14use crate::diagnostics::{doctor, DiagnosticReport, Severity};
15use crate::plugins::{is_official_plugin, official_plugins, rust_snippet};
16use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
17use crate::secret::{assess_secret, generate_secret, SecretSeverity};
18use crate::workspace;
19
// Root command-line definition for the `openauth` binary.
// Plain `//` comments are used here because clap turns `///` doc comments into help text.
20#[derive(Debug, Parser)]
21#[command(name = "openauth", version, about = "Command-line tools for OpenAuth.")]
22pub struct Cli {
    // Project directory to operate in; accepted on every subcommand, defaults to ".".
23    #[arg(long, global = true, default_value = ".")]
24    cwd: PathBuf,
    // The subcommand to run.
25    #[command(subcommand)]
26    command: Commands,
27}
28
// Top-level subcommands; each variant is dispatched in `execute_async`.
// Plain `//` comments are used so the derived clap help output is unchanged.
29#[derive(Debug, Subcommand)]
30enum Commands {
    // Writes `openauth.toml` and updates `.env.example`.
31    Init(InitArgs),
    // Runs diagnostics and fails on errors (or on warnings with `--strict`).
32    Doctor(DiagnosticArgs),
    // Prints toolchain, workspace and configuration details.
33    Info(InfoArgs),
    // Generates a secret, or assesses an existing one.
34    Secret(SecretArgs),
    // Database subcommand group (`status`, `generate`, `migrate`).
35    Db(DbArgs),
    // Handled by the same function as `db generate`.
36    Generate(GenerateArgs),
    // Handled by the same function as `db migrate`.
37    Migrate(MigrateArgs),
    // Schema subcommand group (`print`).
38    Schema(SchemaArgs),
    // Plugin subcommand group (`list`, `add`, `remove`).
39    Plugins(PluginsArgs),
    // Emits a shell completion script on stdout.
40    Completions(CompletionsArgs),
41}
42
// Options for `init`. Plain `//` comments keep clap's generated help unchanged.
43#[derive(Debug, clap::Args)]
44struct InitArgs {
    // Framework name; `init` falls back to the first detected framework, then "axum".
45    #[arg(long)]
46    framework: Option<String>,
    // Database adapter; `init` falls back to "sqlx".
47    #[arg(long)]
48    adapter: Option<String>,
    // Database provider; `init` falls back to `DATABASE_URL` detection, then "sqlite".
49    #[arg(long)]
50    database: Option<String>,
    // Auth base URL; `init` falls back to "http://localhost:3000/api/auth".
51    #[arg(long)]
52    base_url: Option<String>,
    // Plugins to enable, comma-separated; validated by `normalize_plugins`.
53    #[arg(long, value_delimiter = ',')]
54    plugins: Vec<String>,
    // Skip the interactive confirmation prompt.
55    #[arg(short = 'y', long)]
56    yes: bool,
    // Allow replacing an existing `openauth.toml`.
57    #[arg(long)]
58    force: bool,
59}
60
// Options for `doctor`. Plain `//` comments keep clap's generated help unchanged.
61#[derive(Debug, clap::Args)]
62struct DiagnosticArgs {
    // Forwarded to `doctor` as its production flag.
63    #[arg(long)]
64    production: bool,
    // Print the report as JSON instead of the text layout.
65    #[arg(long)]
66    json: bool,
    // Treat warnings as failures when deciding the exit status.
67    #[arg(long)]
68    strict: bool,
69}
70
// Options for `info`. Plain `//` comments keep clap's generated help unchanged.
71#[derive(Debug, clap::Args)]
72struct InfoArgs {
    // Print the diagnostic report as JSON instead of the text summary.
73    #[arg(long)]
74    json: bool,
75}
76
// Options for `secret`. Plain `//` comments keep clap's generated help unchanged.
// `--check` and `--check-env` are mutually exclusive; that is enforced in `secret_command`.
77#[derive(Debug, clap::Args)]
78struct SecretArgs {
    // Size passed to `generate_secret` when no check is requested (default 32).
79    #[arg(long, default_value_t = 32)]
80    bytes: usize,
    // A literal secret value to assess.
81    #[arg(long)]
82    check: Option<String>,
    // Name of an environment variable whose value is assessed.
83    #[arg(long)]
84    check_env: Option<String>,
85}
86
// Wrapper for the `db` subcommand group.
87#[derive(Debug, clap::Args)]
88struct DbArgs {
    // The selected `db` subcommand.
89    #[command(subcommand)]
90    command: DbCommands,
91}
92
// Subcommands under `db`. Plain `//` comments keep clap's generated help unchanged.
93#[derive(Debug, Subcommand)]
94enum DbCommands {
    // Shows the pending schema plan.
95    Status(StatusArgs),
    // Writes a migration for the pending schema plan.
96    Generate(GenerateArgs),
    // Applies the pending schema plan after confirmation.
97    Migrate(MigrateArgs),
98}
99
// Options for `db status`. Plain `//` comments keep clap's generated help unchanged.
100#[derive(Debug, clap::Args)]
101struct StatusArgs {
    // Print the plan summary as JSON instead of the text layout.
102    #[arg(long)]
103    json: bool,
    // Fail (non-zero exit) when the plan is not empty.
104    #[arg(long)]
105    check: bool,
106}
107
// Options for `generate` / `db generate`. Plain `//` comments keep clap's help unchanged.
108#[derive(Debug, clap::Args)]
109struct GenerateArgs {
    // Output location; relative paths are resolved against the working directory.
    // When omitted, the configured migrations directory is used.
110    #[arg(long)]
111    output: Option<PathBuf>,
    // Forwarded to `db::plan` as its second argument.
112    #[arg(long)]
113    from_empty: bool,
    // Forwarded to `db::write_migration`.
114    #[arg(long)]
115    force: bool,
116}
117
// Options for `migrate` / `db migrate`. Plain `//` comments keep clap's help unchanged.
118#[derive(Debug, clap::Args)]
119struct MigrateArgs {
    // Print the plan and stop without applying anything.
120    #[arg(long)]
121    dry_run: bool,
    // Skip the interactive confirmation prompt.
122    #[arg(short = 'y', long)]
123    yes: bool,
124}
125
// Wrapper for the `schema` subcommand group.
126#[derive(Debug, clap::Args)]
127struct SchemaArgs {
    // The selected `schema` subcommand.
128    #[command(subcommand)]
129    command: SchemaCommands,
130}
131
// Subcommands under `schema`. Plain `//` comments keep clap's generated help unchanged.
132#[derive(Debug, Subcommand)]
133enum SchemaCommands {
    // Prints the target schema as SQL or JSON.
134    Print(SchemaPrintArgs),
135}
136
// Options for `schema print`. Plain `//` comments keep clap's generated help unchanged.
137#[derive(Debug, clap::Args)]
138struct SchemaPrintArgs {
    // Output format; defaults to SQL.
139    #[arg(long, value_enum, default_value_t = SchemaFormat::Sql)]
140    format: SchemaFormat,
    // SQL dialect name, defaults to "sqlite"; only consulted for the SQL format.
141    #[arg(long, default_value = "sqlite")]
142    dialect: String,
143}
144
// Output formats accepted by `schema print --format`.
145#[derive(Debug, Clone, Copy, ValueEnum)]
146enum SchemaFormat {
    // Compile the full schema plan for the chosen dialect and print it.
147    Sql,
    // Serialize the target schema as pretty-printed JSON.
148    Json,
149}
150
// Wrapper for the `plugins` subcommand group.
151#[derive(Debug, clap::Args)]
152struct PluginsArgs {
    // The selected `plugins` subcommand.
153    #[command(subcommand)]
154    command: PluginsCommands,
155}
156
// Subcommands under `plugins`. Plain `//` comments keep clap's generated help unchanged.
157#[derive(Debug, Subcommand)]
158enum PluginsCommands {
    // Lists the official plugins.
159    List(PluginListArgs),
    // Adds a plugin to `openauth.toml`.
160    Add(PluginChangeArgs),
    // Removes a plugin from `openauth.toml`.
161    Remove(PluginChangeArgs),
162}
163
// Options for `plugins list`. Plain `//` comments keep clap's generated help unchanged.
164#[derive(Debug, clap::Args)]
165struct PluginListArgs {
    // Print the plugin list as JSON instead of one line per plugin.
166    #[arg(long)]
167    json: bool,
168}
169
// Options shared by `plugins add` and `plugins remove`.
// Plain `//` comments keep clap's generated help unchanged.
170#[derive(Debug, clap::Args)]
171struct PluginChangeArgs {
    // Positional plugin identifier.
172    plugin: String,
    // Skip the interactive confirmation prompt.
173    #[arg(short = 'y', long)]
174    yes: bool,
175}
176
// Options for `completions`. Plain `//` comments keep clap's generated help unchanged.
177#[derive(Debug, clap::Args)]
178struct CompletionsArgs {
    // Positional target shell for the generated completion script.
179    shell: Shell,
180}
181
182pub fn run() -> i32 {
183    run_from(std::env::args_os())
184}
185
186pub fn run_cargo() -> i32 {
187    let mut args = std::env::args_os().collect::<Vec<_>>();
188    if args
189        .get(1)
190        .and_then(|arg| arg.to_str())
191        .is_some_and(is_cargo_subcommand_name)
192    {
193        args.remove(1);
194    }
195    run_from(args)
196}
197
/// Returns `true` when `value` is exactly one of the subcommand names that
/// `run_cargo` strips from the argument list.
fn is_cargo_subcommand_name(value: &str) -> bool {
    const SUBCOMMAND_NAMES: [&str; 4] = ["openauth", "open-auth", "better-auth", "betterauth"];
    SUBCOMMAND_NAMES.contains(&value)
}
204
205pub fn run_from<I, T>(args: I) -> i32
206where
207    I: IntoIterator<Item = T>,
208    T: Into<OsString> + Clone,
209{
210    match Cli::try_parse_from(args) {
211        Ok(cli) => match execute(cli) {
212            Ok(()) => 0,
213            Err(error) => {
214                eprintln!("{error}");
215                1
216            }
217        },
218        Err(error) => {
219            let _ = error.print();
220            error.exit_code()
221        }
222    }
223}
224
225fn execute(cli: Cli) -> Result<(), AppError> {
226    let runtime = tokio::runtime::Runtime::new().map_err(AppError::Runtime)?;
227    runtime.block_on(async move { execute_async(cli).await })
228}
229
230async fn execute_async(cli: Cli) -> Result<(), AppError> {
231    let cwd = absolute_cwd(&cli.cwd)?;
232    match cli.command {
233        Commands::Init(args) => init(&cwd, args),
234        Commands::Doctor(args) => doctor_command(&cwd, args).await,
235        Commands::Info(args) => info_command(&cwd, args).await,
236        Commands::Secret(args) => secret_command(args),
237        Commands::Db(args) => match args.command {
238            DbCommands::Status(args) => db_status(&cwd, args).await,
239            DbCommands::Generate(args) => db_generate(&cwd, args).await,
240            DbCommands::Migrate(args) => db_migrate(&cwd, args).await,
241        },
242        Commands::Generate(args) => db_generate(&cwd, args).await,
243        Commands::Migrate(args) => db_migrate(&cwd, args).await,
244        Commands::Schema(args) => match args.command {
245            SchemaCommands::Print(args) => schema_print(&cwd, args),
246        },
247        Commands::Plugins(args) => match args.command {
248            PluginsCommands::List(args) => plugins_list(args),
249            PluginsCommands::Add(args) => plugin_add(&cwd, args).await,
250            PluginsCommands::Remove(args) => plugin_remove(&cwd, args),
251        },
252        Commands::Completions(args) => completions(args),
253    }
254}
255
256fn init(cwd: &Path, args: InitArgs) -> Result<(), AppError> {
257    let config_path = cwd.join("openauth.toml");
258    if config_path.exists() && !args.force {
259        return Err(AppError::Message(format!(
260            "{} already exists. Use --force to overwrite it.",
261            config_path.display()
262        )));
263    }
264
265    let detected = workspace::inspect(cwd).ok();
266    let framework = args
267        .framework
268        .or_else(|| {
269            detected
270                .as_ref()
271                .and_then(|info| info.detected_frameworks.first())
272                .map(|item| item.name.clone())
273        })
274        .unwrap_or_else(|| "axum".to_owned());
275    let database = args.database.or_else(detect_provider_from_env).or_else(|| {
276        detected.as_ref().and_then(|info| {
277            if info
278                .detected_databases
279                .iter()
280                .any(|item| item.name == "sqlx")
281            {
282                Some("sqlite".to_owned())
283            } else {
284                None
285            }
286        })
287    });
288
289    let config = CliConfig {
290        project: crate::config::ProjectConfig {
291            framework: Some(framework.clone()),
292            base_url: args
293                .base_url
294                .unwrap_or_else(|| "http://localhost:3000/api/auth".to_owned()),
295            ..crate::config::ProjectConfig::default()
296        },
297        database: crate::config::DatabaseConfig {
298            adapter: args.adapter.unwrap_or_else(|| "sqlx".to_owned()),
299            provider: database.or(Some("sqlite".to_owned())),
300            ..crate::config::DatabaseConfig::default()
301        },
302        plugins: crate::config::PluginsConfig {
303            enabled: normalize_plugins(args.plugins)?,
304        },
305        ..CliConfig::default()
306    };
307
308    if config_path.exists() && !confirm("Overwrite existing openauth.toml?", args.yes)? {
309        return Err(AppError::Message("Initialization aborted.".to_owned()));
310    }
311    config.write(&config_path)?;
312    update_env_example(cwd, &config)?;
313    println!("Created openauth.toml");
314    println!("Updated .env.example");
315    if framework == "axum" {
316        println!();
317        println!("Axum integration snippet:");
318        println!("let app = openauth_axum::router(auth)?;");
319    }
320    Ok(())
321}
322
323async fn doctor_command(cwd: &Path, args: DiagnosticArgs) -> Result<(), AppError> {
324    let config = load_config(cwd)?;
325    let report = doctor(cwd, &config, args.production).await;
326    if args.json {
327        print_json(&report)?;
328    } else {
329        print_report(&report);
330    }
331    if report.has_errors() || (args.strict && report.has_warnings()) {
332        return Err(AppError::ExitOnly);
333    }
334    Ok(())
335}
336
337async fn info_command(cwd: &Path, args: InfoArgs) -> Result<(), AppError> {
338    let config = load_config(cwd)?;
339    let report = doctor(cwd, &config, false).await;
340    if args.json {
341        print_json(&report)?;
342    } else {
343        println!("OpenAuth info");
344        println!("Rust: {}", report.rust);
345        println!("Cargo: {}", report.cargo);
346        if let Some(root) = report.workspace_root {
347            println!("Workspace: {root}");
348        }
349        println!(
350            "Framework: {}",
351            config.project.framework.unwrap_or_default()
352        );
353        println!("Adapter: {}", config.database.adapter);
354        println!(
355            "Database provider: {}",
356            config.database.provider.unwrap_or_default()
357        );
358        println!("Plugins: {}", config.plugins.enabled.join(", "));
359    }
360    Ok(())
361}
362
363fn secret_command(args: SecretArgs) -> Result<(), AppError> {
364    let value = match (args.check, args.check_env) {
365        (Some(value), None) => Some(value),
366        (None, Some(env)) => Some(std::env::var(&env).unwrap_or_default()),
367        (Some(_), Some(_)) => {
368            return Err(AppError::Message(
369                "Use only one of --check or --check-env.".to_owned(),
370            ))
371        }
372        (None, None) => None,
373    };
374    let Some(secret) = value else {
375        println!("{}", generate_secret(args.bytes));
376        return Ok(());
377    };
378
379    let assessment = assess_secret(&secret, true);
380    match assessment.severity {
381        SecretSeverity::Ok => {
382            println!("{}", assessment.message);
383            Ok(())
384        }
385        SecretSeverity::Warning => {
386            eprintln!("{}", assessment.message);
387            Ok(())
388        }
389        SecretSeverity::Error => Err(AppError::Message(assessment.message)),
390    }
391}
392
393async fn db_status(cwd: &Path, args: StatusArgs) -> Result<(), AppError> {
394    let config = load_config(cwd)?;
395    let planned = db::plan(&config, false).await?;
396    let summary = planned.summary();
397    if args.json {
398        print_json(&summary)?;
399    } else {
400        print_plan(&planned);
401    }
402    if args.check && !planned.plan.is_empty() {
403        return Err(AppError::ExitOnly);
404    }
405    Ok(())
406}
407
408async fn db_generate(cwd: &Path, args: GenerateArgs) -> Result<(), AppError> {
409    let config = load_config(cwd)?;
410    let planned = db::plan(&config, args.from_empty).await?;
411    if planned.plan.is_empty() {
412        println!("Schema is already up to date.");
413        return Ok(());
414    }
415    let output = args
416        .output
417        .as_ref()
418        .map(|path| resolve_project_path(cwd, path))
419        .unwrap_or_else(|| cwd.join(&config.database.migrations_dir));
420    let path = db::write_migration(&config, &planned, Some(&output), args.force)?;
421    println!("Generated migration: {}", path.display());
422    Ok(())
423}
424
425async fn db_migrate(cwd: &Path, args: MigrateArgs) -> Result<(), AppError> {
426    let config = load_config(cwd)?;
427    let planned = db::plan(&config, false).await?;
428    if planned.plan.is_empty() {
429        println!("No migrations needed.");
430        return Ok(());
431    }
432    print_plan(&planned);
433    if args.dry_run {
434        println!("Dry run complete; no changes were applied.");
435        return Ok(());
436    }
437    if !confirm("Apply these migrations?", args.yes)? {
438        println!("Migration cancelled.");
439        return Ok(());
440    }
441    db::migrate(&config).await?;
442    println!("Migration completed successfully.");
443    Ok(())
444}
445
446fn schema_print(cwd: &Path, args: SchemaPrintArgs) -> Result<(), AppError> {
447    let config = load_config(cwd)?;
448    let schema = target_schema(&config)?;
449    match args.format {
450        SchemaFormat::Json => print_json(&schema)?,
451        SchemaFormat::Sql => {
452            let dialect = dialect_from_provider(&args.dialect).ok_or_else(|| {
453                AppError::Message(format!("unsupported dialect `{}`", args.dialect))
454            })?;
455            let plan = full_schema_plan(dialect, &schema)?;
456            println!("{}", plan.compile());
457        }
458    }
459    Ok(())
460}
461
462fn plugins_list(args: PluginListArgs) -> Result<(), AppError> {
463    let plugins = official_plugins();
464    if args.json {
465        print_json(&plugins)?;
466    } else {
467        for plugin in plugins {
468            let schema = if plugin.schema { "schema" } else { "no schema" };
469            println!("{} ({schema})", plugin.id);
470        }
471    }
472    Ok(())
473}
474
475async fn plugin_add(cwd: &Path, args: PluginChangeArgs) -> Result<(), AppError> {
476    if !is_official_plugin(&args.plugin) {
477        return Err(AppError::Message(format!(
478            "`{}` is not an official OpenAuth plugin.",
479            args.plugin
480        )));
481    }
482    let path = cwd.join("openauth.toml");
483    let source = fs::read_to_string(&path).map_err(|source| AppError::Io {
484        context: format!("failed to read {}", path.display()),
485        source,
486    })?;
487    let updated = CliConfig::add_plugin_to_document(&source, &args.plugin)?;
488    if !confirm(
489        &format!("Add `{}` to openauth.toml?", args.plugin),
490        args.yes,
491    )? {
492        return Err(AppError::Message("Plugin update aborted.".to_owned()));
493    }
494    fs::write(&path, updated).map_err(|source| AppError::Io {
495        context: format!("failed to write {}", path.display()),
496        source,
497    })?;
498    println!("Added plugin `{}` to openauth.toml.", args.plugin);
499    if let Some(snippet) = rust_snippet(&args.plugin) {
500        println!("Rust snippet: {snippet}");
501    }
502    let config = load_config(cwd)?;
503    match db::plan(&config, false).await {
504        Ok(plan) if !plan.plan.is_empty() => {
505            println!("This plugin changes the database schema.");
506            println!("Run `openauth db generate` or `openauth db migrate`.");
507        }
508        Ok(_) => {}
509        Err(error) => {
510            println!("Database impact could not be checked: {error}");
511        }
512    }
513    Ok(())
514}
515
516fn plugin_remove(cwd: &Path, args: PluginChangeArgs) -> Result<(), AppError> {
517    let path = cwd.join("openauth.toml");
518    let source = fs::read_to_string(&path).map_err(|source| AppError::Io {
519        context: format!("failed to read {}", path.display()),
520        source,
521    })?;
522    let updated = CliConfig::remove_plugin_from_document(&source, &args.plugin)?;
523    if !confirm(
524        &format!("Remove `{}` from openauth.toml?", args.plugin),
525        args.yes,
526    )? {
527        return Err(AppError::Message("Plugin update aborted.".to_owned()));
528    }
529    fs::write(&path, updated).map_err(|source| AppError::Io {
530        context: format!("failed to write {}", path.display()),
531        source,
532    })?;
533    println!("Removed plugin `{}` from openauth.toml.", args.plugin);
534    println!("OpenAuth does not generate destructive migrations in v1.");
535    Ok(())
536}
537
538fn completions(args: CompletionsArgs) -> Result<(), AppError> {
539    let mut command = Cli::command();
540    let name = command.get_name().to_owned();
541    clap_complete::generate(args.shell, &mut command, name, &mut io::stdout());
542    Ok(())
543}
544
545fn load_config(cwd: &Path) -> Result<CliConfig, AppError> {
546    let path = cwd.join("openauth.toml");
547    CliConfig::load(&path).map_err(AppError::Config)
548}
549
550fn update_env_example(cwd: &Path, config: &CliConfig) -> Result<(), AppError> {
551    let path = cwd.join(".env.example");
552    let mut content = if path.exists() {
553        fs::read_to_string(&path).map_err(|source| AppError::Io {
554            context: format!("failed to read {}", path.display()),
555            source,
556        })?
557    } else {
558        String::new()
559    };
560    append_env_if_missing(
561        &mut content,
562        &config.security.secret_env,
563        generate_secret(32),
564    );
565    append_env_if_missing(
566        &mut content,
567        &config.database.url_env,
568        default_database_url(config),
569    );
570    fs::write(&path, content).map_err(|source| AppError::Io {
571        context: format!("failed to write {}", path.display()),
572        source,
573    })
574}
575
/// Appends `key=value` as a new line to `content` unless a line already defines `key`.
///
/// A line counts as defining `key` when, after optional leading whitespace and an
/// optional `export` keyword, it starts with `key` followed by `=` (whitespace before
/// the `=` is tolerated). Commented-out lines such as `# KEY=...` do not count.
///
/// A missing trailing newline on existing content is repaired before appending, and
/// the appended line always ends with a newline.
fn append_env_if_missing(content: &mut String, key: &str, value: impl AsRef<str>) {
    let defines_key = |line: &str| {
        let line = line.trim_start();
        // Accept the shell-compatible `export KEY=value` spelling used in many env files.
        let assignment = match line.strip_prefix("export") {
            Some(rest) if rest.starts_with(char::is_whitespace) => rest.trim_start(),
            _ => line,
        };
        // Requiring `=` right after the key keeps `FOOBAR=` from matching key `FOO`.
        assignment
            .strip_prefix(key)
            .is_some_and(|rest| rest.trim_start().starts_with('='))
    };
    if content.lines().any(defines_key) {
        return;
    }

    if !content.is_empty() && !content.ends_with('\n') {
        content.push('\n');
    }
    content.push_str(key);
    content.push('=');
    content.push_str(value.as_ref());
    content.push('\n');
}
588
589fn default_database_url(config: &CliConfig) -> &'static str {
590    match config.database.provider.as_deref() {
591        Some("postgres") | Some("postgresql") | Some("pg") => {
592            "postgres://user:password@localhost:5432/openauth"
593        }
594        Some("mysql") => "mysql://user:password@localhost:3306/openauth",
595        _ => "sqlite://openauth.sqlite",
596    }
597}
598
/// Guesses the database provider from the `DATABASE_URL` environment variable.
///
/// Returns `None` when the variable cannot be read or matches none of the known
/// URL schemes / SQLite file extensions.
fn detect_provider_from_env() -> Option<String> {
    let url = std::env::var("DATABASE_URL").ok()?;
    let has_scheme = |schemes: &[&str]| schemes.iter().any(|scheme| url.starts_with(*scheme));
    let has_extension = |exts: &[&str]| exts.iter().any(|ext| url.ends_with(*ext));

    let provider = if has_scheme(&["postgres://", "postgresql://"]) {
        "postgres"
    } else if has_scheme(&["mysql://"]) {
        "mysql"
    } else if has_scheme(&["sqlite://"]) || has_extension(&[".sqlite", ".db"]) {
        "sqlite"
    } else {
        return None;
    };
    Some(provider.to_owned())
}
612
613fn normalize_plugins(plugins: Vec<String>) -> Result<Vec<String>, AppError> {
614    let mut normalized = Vec::new();
615    for plugin in plugins {
616        let plugin = plugin.trim();
617        if plugin.is_empty() {
618            continue;
619        }
620        if !is_official_plugin(plugin) {
621            return Err(AppError::Message(format!(
622                "`{plugin}` is not an official OpenAuth plugin."
623            )));
624        }
625        if !normalized.iter().any(|existing| existing == plugin) {
626            normalized.push(plugin.to_owned());
627        }
628    }
629    Ok(normalized)
630}
631
632fn print_report(report: &DiagnosticReport) {
633    println!("OpenAuth doctor");
634    println!("Rust: {}", report.rust);
635    println!("Cargo: {}", report.cargo);
636    if let Some(root) = &report.workspace_root {
637        println!("Workspace: {root}");
638    }
639    for finding in &report.findings {
640        let label = match finding.severity {
641            Severity::Info => "INFO",
642            Severity::Warn => "WARN",
643            Severity::Error => "ERROR",
644        };
645        println!("[{label}] {}: {}", finding.code, finding.message);
646    }
647}
648
649fn print_plan(planned: &db::PlannedMigration) {
650    let dialect = dialect_from_provider(&planned.provider)
651        .map(dialect_name)
652        .unwrap_or("unknown");
653    println!("OpenAuth schema plan ({dialect})");
654    println!("Tables to create: {}", planned.plan.to_be_created.len());
655    for table in &planned.plan.to_be_created {
656        println!("  - {}", table.table_name);
657    }
658    println!("Columns to add: {}", planned.plan.to_be_added.len());
659    for column in &planned.plan.to_be_added {
660        println!("  - {}.{}", column.table_name, column.column_name);
661    }
662    println!(
663        "Indexes to create: {}",
664        planned.plan.indexes_to_be_created.len()
665    );
666    for index in &planned.plan.indexes_to_be_created {
667        println!("  - {}", index.index_name);
668    }
669    for warning in &planned.plan.warnings {
670        println!("WARNING: {warning:?}");
671    }
672}
673
674fn print_json<T>(value: &T) -> Result<(), AppError>
675where
676    T: Serialize,
677{
678    let rendered = serde_json::to_string_pretty(value)?;
679    println!("{rendered}");
680    Ok(())
681}
682
683fn confirm(message: &str, yes: bool) -> Result<bool, AppError> {
684    if yes {
685        return Ok(true);
686    }
687    Confirm::new(message)
688        .with_default(false)
689        .prompt()
690        .map_err(|error| AppError::Message(format!("prompt failed: {error}")))
691}
692
693fn absolute_cwd(cwd: &Path) -> Result<PathBuf, AppError> {
694    let path = if cwd.is_absolute() {
695        cwd.to_path_buf()
696    } else {
697        std::env::current_dir()
698            .map_err(|source| AppError::Io {
699                context: "failed to read current directory".to_owned(),
700                source,
701            })?
702            .join(cwd)
703    };
704    if path.exists() {
705        Ok(path)
706    } else {
707        Err(AppError::Message(format!(
708            "The directory {} does not exist.",
709            path.display()
710        )))
711    }
712}
713
/// Resolves `path` against `cwd`: relative paths are joined onto `cwd`, absolute
/// paths are returned unchanged.
fn resolve_project_path(cwd: &Path, path: &Path) -> PathBuf {
    if path.is_relative() {
        return cwd.join(path);
    }
    path.to_path_buf()
}
721
/// Errors produced while executing a CLI command.
///
/// `run_from` prints the `Display` text of a returned error to stderr and exits with
/// status 1.
722#[derive(Debug, thiserror::Error)]
723enum AppError {
    /// A ready-made, user-facing message.
724    #[error("{0}")]
725    Message(String),
    /// Failure from the configuration layer; displayed as the inner error.
726    #[error(transparent)]
727    Config(#[from] ConfigError),
    /// Failure from the database CLI helpers; displayed as the inner error.
728    #[error(transparent)]
729    Db(#[from] DbCliError),
    /// Failure from `openauth_core`; displayed as the inner error.
730    #[error(transparent)]
731    OpenAuth(#[from] openauth_core::error::OpenAuthError),
    /// JSON serialization failure; displayed as the inner error.
732    #[error(transparent)]
733    Json(#[from] serde_json::Error),
    /// The Tokio runtime could not be created.
734    #[error("failed to start async runtime: {0}")]
735    Runtime(std::io::Error),
    /// A filesystem operation failed; `context` says which operation.
736    #[error("{context}: {source}")]
737    Io {
738        context: String,
739        source: std::io::Error,
740    },
    /// Signals a failing exit status with no message (its `Display` text is empty).
741    #[error("")]
742    ExitOnly,
743}
744
745#[allow(dead_code)]
746fn _dialect_for_lints(dialect: SqlDialect) -> &'static str {
747    dialect_name(dialect)
748}