1use std::ffi::OsString;
2use std::fs;
3use std::io;
4use std::path::{Path, PathBuf};
5
6use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
7use clap_complete::Shell;
8use inquire::Confirm;
9use openauth_core::db::sql::SqlDialect;
10use serde::Serialize;
11
12use crate::config::{CliConfig, ConfigError};
13use crate::db::{self, DbCliError};
14use crate::diagnostics::{doctor, DiagnosticReport, Severity};
15use crate::plugins::{is_official_plugin, official_plugins, rust_snippet};
16use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
17use crate::secret::{assess_secret, generate_secret, SecretSeverity};
18use crate::workspace;
19
// Top-level `openauth` argument parser.
//
// Comments on this clap-derived item are plain `//` on purpose: clap's derive turns
// `///` doc comments into `--help`/`about` text, which would change CLI output.
#[derive(Debug, Parser)]
#[command(name = "openauth", version, about = "Command-line tools for OpenAuth.")]
pub struct Cli {
    // Project directory the command works in (default: the current directory).
    // `global = true` lets `--cwd` appear before or after the subcommand. Relative
    // values are resolved against the process working directory by `absolute_cwd`.
    #[arg(long, global = true, default_value = ".")]
    cwd: PathBuf,
    // The subcommand to run.
    #[command(subcommand)]
    command: Commands,
}
28
// Top-level subcommands; `execute_async` dispatches each variant to its handler.
// (`//` rather than `///`: clap would turn doc comments into subcommand help text.)
#[derive(Debug, Subcommand)]
enum Commands {
    // Create `openauth.toml` and seed `.env.example` (handled by `init`).
    Init(InitArgs),
    // Run diagnostics (handled by `doctor_command`).
    Doctor(DiagnosticArgs),
    // Print project and toolchain information (handled by `info_command`).
    Info(InfoArgs),
    // Generate a secret or assess an existing one (handled by `secret_command`).
    Secret(SecretArgs),
    // Database subcommands: `status`, `generate`, `migrate`.
    Db(DbArgs),
    // Shortcut for `db generate`; dispatches to the same handler (`db_generate`).
    Generate(GenerateArgs),
    // Shortcut for `db migrate`; dispatches to the same handler (`db_migrate`).
    Migrate(MigrateArgs),
    // Schema subcommands (`schema print`).
    Schema(SchemaArgs),
    // Plugin subcommands: `list`, `add`, `remove`.
    Plugins(PluginsArgs),
    // Print a shell-completion script (handled by `completions`).
    Completions(CompletionsArgs),
}
42
// Flags for `openauth init`. Every option is optional; `init` fills gaps from
// workspace/environment detection and then from built-in defaults.
#[derive(Debug, clap::Args)]
struct InitArgs {
    // Framework name; otherwise the first framework reported by `workspace::inspect`,
    // falling back to `axum`.
    #[arg(long)]
    framework: Option<String>,
    // Database adapter; defaults to `sqlx`.
    #[arg(long)]
    adapter: Option<String>,
    // Database provider; otherwise guessed from `DATABASE_URL`, then from workspace
    // detection, finally `sqlite`.
    #[arg(long)]
    database: Option<String>,
    // Auth base URL; defaults to `http://localhost:3000/api/auth`.
    #[arg(long)]
    base_url: Option<String>,
    // Comma-separated plugin ids; validated and de-duplicated by `normalize_plugins`.
    #[arg(long, value_delimiter = ',')]
    plugins: Vec<String>,
    // Answer "yes" to the overwrite confirmation prompt.
    #[arg(short = 'y', long)]
    yes: bool,
    // Allow replacing an existing `openauth.toml` (still prompts unless `--yes`).
    #[arg(long)]
    force: bool,
}
60
// Flags for `openauth doctor`.
#[derive(Debug, clap::Args)]
struct DiagnosticArgs {
    // Forwarded to `doctor` as its production flag.
    // NOTE(review): exact extra checks live in `crate::diagnostics` — confirm there.
    #[arg(long)]
    production: bool,
    // Print the report as pretty JSON instead of text.
    #[arg(long)]
    json: bool,
    // Treat warnings as failures (non-zero exit), not just errors.
    #[arg(long)]
    strict: bool,
}

// Flags for `openauth info`.
#[derive(Debug, clap::Args)]
struct InfoArgs {
    // Print the diagnostic report as pretty JSON instead of the text summary.
    #[arg(long)]
    json: bool,
}

// Flags for `openauth secret`. Without `--check`/`--check-env` a new secret is
// generated; the two check flags are rejected together at runtime.
#[derive(Debug, clap::Args)]
struct SecretArgs {
    // Size argument passed to `generate_secret` (default 32); ignored when checking.
    #[arg(long, default_value_t = 32)]
    bytes: usize,
    // Secret value to assess instead of generating one.
    #[arg(long)]
    check: Option<String>,
    // Name of an environment variable whose value should be assessed.
    #[arg(long)]
    check_env: Option<String>,
}
86
// `openauth db …` wrapper holding the chosen database subcommand.
#[derive(Debug, clap::Args)]
struct DbArgs {
    #[command(subcommand)]
    command: DbCommands,
}

// Database subcommands; `generate` and `migrate` are also available at top level.
#[derive(Debug, Subcommand)]
enum DbCommands {
    // Show the pending schema plan (handled by `db_status`).
    Status(StatusArgs),
    // Write a migration file for the pending plan (handled by `db_generate`).
    Generate(GenerateArgs),
    // Apply the pending plan to the database (handled by `db_migrate`).
    Migrate(MigrateArgs),
}

// Flags for `db status`.
#[derive(Debug, clap::Args)]
struct StatusArgs {
    // Print the plan summary as pretty JSON instead of text.
    #[arg(long)]
    json: bool,
    // Exit non-zero (via `AppError::ExitOnly`) when the plan is not empty.
    #[arg(long)]
    check: bool,
}

// Flags for `db generate` / `generate`.
#[derive(Debug, clap::Args)]
struct GenerateArgs {
    // Output location passed to `db::write_migration`; relative paths resolve against
    // `--cwd`. Defaults to the config's `database.migrations_dir`, so presumably a
    // directory — TODO confirm in `db::write_migration`.
    #[arg(long)]
    output: Option<PathBuf>,
    // Forwarded as the second argument of `db::plan`; presumably plans as if the
    // database were empty — confirm in `crate::db`.
    #[arg(long)]
    from_empty: bool,
    // Forwarded to `db::write_migration`; presumably permits overwriting an existing
    // migration file — confirm in `crate::db`.
    #[arg(long)]
    force: bool,
}

// Flags for `db migrate` / `migrate`.
#[derive(Debug, clap::Args)]
struct MigrateArgs {
    // Print the plan and stop without applying anything.
    #[arg(long)]
    dry_run: bool,
    // Answer "yes" to the "Apply these migrations?" prompt.
    #[arg(short = 'y', long)]
    yes: bool,
}
125
// `openauth schema …` wrapper holding the chosen schema subcommand.
#[derive(Debug, clap::Args)]
struct SchemaArgs {
    #[command(subcommand)]
    command: SchemaCommands,
}

// Schema subcommands.
#[derive(Debug, Subcommand)]
enum SchemaCommands {
    // Print the target schema derived from the config (handled by `schema_print`).
    Print(SchemaPrintArgs),
}

// Flags for `schema print`.
#[derive(Debug, clap::Args)]
struct SchemaPrintArgs {
    // Output format; defaults to SQL.
    #[arg(long, value_enum, default_value_t = SchemaFormat::Sql)]
    format: SchemaFormat,
    // SQL dialect, resolved by `dialect_from_provider`; an unrecognised value is an
    // error. Only consulted for `--format sql`.
    #[arg(long, default_value = "sqlite")]
    dialect: String,
}

// Output formats accepted by `schema print --format`.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum SchemaFormat {
    // SQL compiled from `full_schema_plan` for the chosen dialect.
    Sql,
    // The target schema serialized as pretty JSON.
    Json,
}
150
// `openauth plugins …` wrapper holding the chosen plugin subcommand.
#[derive(Debug, clap::Args)]
struct PluginsArgs {
    #[command(subcommand)]
    command: PluginsCommands,
}

// Plugin subcommands.
#[derive(Debug, Subcommand)]
enum PluginsCommands {
    // List the official plugins (handled by `plugins_list`).
    List(PluginListArgs),
    // Enable an official plugin in `openauth.toml` (handled by `plugin_add`).
    Add(PluginChangeArgs),
    // Disable a plugin in `openauth.toml` (handled by `plugin_remove`).
    Remove(PluginChangeArgs),
}

// Flags for `plugins list`.
#[derive(Debug, clap::Args)]
struct PluginListArgs {
    // Print the plugin list as pretty JSON instead of one line per plugin.
    #[arg(long)]
    json: bool,
}

// Arguments shared by `plugins add` and `plugins remove`.
#[derive(Debug, clap::Args)]
struct PluginChangeArgs {
    // Plugin id (positional). `add` requires an official plugin; `remove` does not
    // check.
    plugin: String,
    // Answer "yes" to the confirmation prompt.
    #[arg(short = 'y', long)]
    yes: bool,
}

// Arguments for `openauth completions`.
#[derive(Debug, clap::Args)]
struct CompletionsArgs {
    // Shell to generate the completion script for (positional).
    shell: Shell,
}
181
182pub fn run() -> i32 {
183 run_from(std::env::args_os())
184}
185
186pub fn run_cargo() -> i32 {
187 let mut args = std::env::args_os().collect::<Vec<_>>();
188 if args
189 .get(1)
190 .and_then(|arg| arg.to_str())
191 .is_some_and(is_cargo_subcommand_name)
192 {
193 args.remove(1);
194 }
195 run_from(args)
196}
197
/// Returns `true` if `value` is one of the names this binary accepts as a cargo
/// subcommand name (an exact, case-sensitive match).
fn is_cargo_subcommand_name(value: &str) -> bool {
    const ACCEPTED_NAMES: [&str; 4] = ["openauth", "open-auth", "better-auth", "betterauth"];
    ACCEPTED_NAMES.contains(&value)
}
204
205pub fn run_from<I, T>(args: I) -> i32
206where
207 I: IntoIterator<Item = T>,
208 T: Into<OsString> + Clone,
209{
210 match Cli::try_parse_from(args) {
211 Ok(cli) => match execute(cli) {
212 Ok(()) => 0,
213 Err(error) => {
214 eprintln!("{error}");
215 1
216 }
217 },
218 Err(error) => {
219 let _ = error.print();
220 error.exit_code()
221 }
222 }
223}
224
225fn execute(cli: Cli) -> Result<(), AppError> {
226 let runtime = tokio::runtime::Runtime::new().map_err(AppError::Runtime)?;
227 runtime.block_on(async move { execute_async(cli).await })
228}
229
230async fn execute_async(cli: Cli) -> Result<(), AppError> {
231 let cwd = absolute_cwd(&cli.cwd)?;
232 match cli.command {
233 Commands::Init(args) => init(&cwd, args),
234 Commands::Doctor(args) => doctor_command(&cwd, args).await,
235 Commands::Info(args) => info_command(&cwd, args).await,
236 Commands::Secret(args) => secret_command(args),
237 Commands::Db(args) => match args.command {
238 DbCommands::Status(args) => db_status(&cwd, args).await,
239 DbCommands::Generate(args) => db_generate(&cwd, args).await,
240 DbCommands::Migrate(args) => db_migrate(&cwd, args).await,
241 },
242 Commands::Generate(args) => db_generate(&cwd, args).await,
243 Commands::Migrate(args) => db_migrate(&cwd, args).await,
244 Commands::Schema(args) => match args.command {
245 SchemaCommands::Print(args) => schema_print(&cwd, args),
246 },
247 Commands::Plugins(args) => match args.command {
248 PluginsCommands::List(args) => plugins_list(args),
249 PluginsCommands::Add(args) => plugin_add(&cwd, args).await,
250 PluginsCommands::Remove(args) => plugin_remove(&cwd, args),
251 },
252 Commands::Completions(args) => completions(args),
253 }
254}
255
256fn init(cwd: &Path, args: InitArgs) -> Result<(), AppError> {
257 let config_path = cwd.join("openauth.toml");
258 if config_path.exists() && !args.force {
259 return Err(AppError::Message(format!(
260 "{} already exists. Use --force to overwrite it.",
261 config_path.display()
262 )));
263 }
264
265 let detected = workspace::inspect(cwd).ok();
266 let framework = args
267 .framework
268 .or_else(|| {
269 detected
270 .as_ref()
271 .and_then(|info| info.detected_frameworks.first())
272 .map(|item| item.name.clone())
273 })
274 .unwrap_or_else(|| "axum".to_owned());
275 let database = args.database.or_else(detect_provider_from_env).or_else(|| {
276 detected.as_ref().and_then(|info| {
277 if info
278 .detected_databases
279 .iter()
280 .any(|item| item.name == "sqlx")
281 {
282 Some("sqlite".to_owned())
283 } else {
284 None
285 }
286 })
287 });
288
289 let config = CliConfig {
290 project: crate::config::ProjectConfig {
291 framework: Some(framework.clone()),
292 base_url: args
293 .base_url
294 .unwrap_or_else(|| "http://localhost:3000/api/auth".to_owned()),
295 ..crate::config::ProjectConfig::default()
296 },
297 database: crate::config::DatabaseConfig {
298 adapter: args.adapter.unwrap_or_else(|| "sqlx".to_owned()),
299 provider: database.or(Some("sqlite".to_owned())),
300 ..crate::config::DatabaseConfig::default()
301 },
302 plugins: crate::config::PluginsConfig {
303 enabled: normalize_plugins(args.plugins)?,
304 },
305 ..CliConfig::default()
306 };
307
308 if config_path.exists() && !confirm("Overwrite existing openauth.toml?", args.yes)? {
309 return Err(AppError::Message("Initialization aborted.".to_owned()));
310 }
311 config.write(&config_path)?;
312 update_env_example(cwd, &config)?;
313 println!("Created openauth.toml");
314 println!("Updated .env.example");
315 if framework == "axum" {
316 println!();
317 println!("Axum integration snippet:");
318 println!("let app = openauth_axum::router(auth)?;");
319 }
320 Ok(())
321}
322
323async fn doctor_command(cwd: &Path, args: DiagnosticArgs) -> Result<(), AppError> {
324 let config = load_config(cwd)?;
325 let report = doctor(cwd, &config, args.production).await;
326 if args.json {
327 print_json(&report)?;
328 } else {
329 print_report(&report);
330 }
331 if report.has_errors() || (args.strict && report.has_warnings()) {
332 return Err(AppError::ExitOnly);
333 }
334 Ok(())
335}
336
337async fn info_command(cwd: &Path, args: InfoArgs) -> Result<(), AppError> {
338 let config = load_config(cwd)?;
339 let report = doctor(cwd, &config, false).await;
340 if args.json {
341 print_json(&report)?;
342 } else {
343 println!("OpenAuth info");
344 println!("Rust: {}", report.rust);
345 println!("Cargo: {}", report.cargo);
346 if let Some(root) = report.workspace_root {
347 println!("Workspace: {root}");
348 }
349 println!(
350 "Framework: {}",
351 config.project.framework.unwrap_or_default()
352 );
353 println!("Adapter: {}", config.database.adapter);
354 println!(
355 "Database provider: {}",
356 config.database.provider.unwrap_or_default()
357 );
358 println!("Plugins: {}", config.plugins.enabled.join(", "));
359 }
360 Ok(())
361}
362
363fn secret_command(args: SecretArgs) -> Result<(), AppError> {
364 let value = match (args.check, args.check_env) {
365 (Some(value), None) => Some(value),
366 (None, Some(env)) => Some(std::env::var(&env).unwrap_or_default()),
367 (Some(_), Some(_)) => {
368 return Err(AppError::Message(
369 "Use only one of --check or --check-env.".to_owned(),
370 ))
371 }
372 (None, None) => None,
373 };
374 let Some(secret) = value else {
375 println!("{}", generate_secret(args.bytes));
376 return Ok(());
377 };
378
379 let assessment = assess_secret(&secret, true);
380 match assessment.severity {
381 SecretSeverity::Ok => {
382 println!("{}", assessment.message);
383 Ok(())
384 }
385 SecretSeverity::Warning => {
386 eprintln!("{}", assessment.message);
387 Ok(())
388 }
389 SecretSeverity::Error => Err(AppError::Message(assessment.message)),
390 }
391}
392
393async fn db_status(cwd: &Path, args: StatusArgs) -> Result<(), AppError> {
394 let config = load_config(cwd)?;
395 let planned = db::plan(&config, false).await?;
396 let summary = planned.summary();
397 if args.json {
398 print_json(&summary)?;
399 } else {
400 print_plan(&planned);
401 }
402 if args.check && !planned.plan.is_empty() {
403 return Err(AppError::ExitOnly);
404 }
405 Ok(())
406}
407
408async fn db_generate(cwd: &Path, args: GenerateArgs) -> Result<(), AppError> {
409 let config = load_config(cwd)?;
410 let planned = db::plan(&config, args.from_empty).await?;
411 if planned.plan.is_empty() {
412 println!("Schema is already up to date.");
413 return Ok(());
414 }
415 let output = args
416 .output
417 .as_ref()
418 .map(|path| resolve_project_path(cwd, path))
419 .unwrap_or_else(|| cwd.join(&config.database.migrations_dir));
420 let path = db::write_migration(&config, &planned, Some(&output), args.force)?;
421 println!("Generated migration: {}", path.display());
422 Ok(())
423}
424
425async fn db_migrate(cwd: &Path, args: MigrateArgs) -> Result<(), AppError> {
426 let config = load_config(cwd)?;
427 let planned = db::plan(&config, false).await?;
428 if planned.plan.is_empty() {
429 println!("No migrations needed.");
430 return Ok(());
431 }
432 print_plan(&planned);
433 if args.dry_run {
434 println!("Dry run complete; no changes were applied.");
435 return Ok(());
436 }
437 if !confirm("Apply these migrations?", args.yes)? {
438 println!("Migration cancelled.");
439 return Ok(());
440 }
441 db::migrate(&config).await?;
442 println!("Migration completed successfully.");
443 Ok(())
444}
445
446fn schema_print(cwd: &Path, args: SchemaPrintArgs) -> Result<(), AppError> {
447 let config = load_config(cwd)?;
448 let schema = target_schema(&config)?;
449 match args.format {
450 SchemaFormat::Json => print_json(&schema)?,
451 SchemaFormat::Sql => {
452 let dialect = dialect_from_provider(&args.dialect).ok_or_else(|| {
453 AppError::Message(format!("unsupported dialect `{}`", args.dialect))
454 })?;
455 let plan = full_schema_plan(dialect, &schema)?;
456 println!("{}", plan.compile());
457 }
458 }
459 Ok(())
460}
461
462fn plugins_list(args: PluginListArgs) -> Result<(), AppError> {
463 let plugins = official_plugins();
464 if args.json {
465 print_json(&plugins)?;
466 } else {
467 for plugin in plugins {
468 let schema = if plugin.schema { "schema" } else { "no schema" };
469 println!("{} ({schema})", plugin.id);
470 }
471 }
472 Ok(())
473}
474
475async fn plugin_add(cwd: &Path, args: PluginChangeArgs) -> Result<(), AppError> {
476 if !is_official_plugin(&args.plugin) {
477 return Err(AppError::Message(format!(
478 "`{}` is not an official OpenAuth plugin.",
479 args.plugin
480 )));
481 }
482 let path = cwd.join("openauth.toml");
483 let source = fs::read_to_string(&path).map_err(|source| AppError::Io {
484 context: format!("failed to read {}", path.display()),
485 source,
486 })?;
487 let updated = CliConfig::add_plugin_to_document(&source, &args.plugin)?;
488 if !confirm(
489 &format!("Add `{}` to openauth.toml?", args.plugin),
490 args.yes,
491 )? {
492 return Err(AppError::Message("Plugin update aborted.".to_owned()));
493 }
494 fs::write(&path, updated).map_err(|source| AppError::Io {
495 context: format!("failed to write {}", path.display()),
496 source,
497 })?;
498 println!("Added plugin `{}` to openauth.toml.", args.plugin);
499 if let Some(snippet) = rust_snippet(&args.plugin) {
500 println!("Rust snippet: {snippet}");
501 }
502 let config = load_config(cwd)?;
503 match db::plan(&config, false).await {
504 Ok(plan) if !plan.plan.is_empty() => {
505 println!("This plugin changes the database schema.");
506 println!("Run `openauth db generate` or `openauth db migrate`.");
507 }
508 Ok(_) => {}
509 Err(error) => {
510 println!("Database impact could not be checked: {error}");
511 }
512 }
513 Ok(())
514}
515
516fn plugin_remove(cwd: &Path, args: PluginChangeArgs) -> Result<(), AppError> {
517 let path = cwd.join("openauth.toml");
518 let source = fs::read_to_string(&path).map_err(|source| AppError::Io {
519 context: format!("failed to read {}", path.display()),
520 source,
521 })?;
522 let updated = CliConfig::remove_plugin_from_document(&source, &args.plugin)?;
523 if !confirm(
524 &format!("Remove `{}` from openauth.toml?", args.plugin),
525 args.yes,
526 )? {
527 return Err(AppError::Message("Plugin update aborted.".to_owned()));
528 }
529 fs::write(&path, updated).map_err(|source| AppError::Io {
530 context: format!("failed to write {}", path.display()),
531 source,
532 })?;
533 println!("Removed plugin `{}` from openauth.toml.", args.plugin);
534 println!("OpenAuth does not generate destructive migrations in v1.");
535 Ok(())
536}
537
538fn completions(args: CompletionsArgs) -> Result<(), AppError> {
539 let mut command = Cli::command();
540 let name = command.get_name().to_owned();
541 clap_complete::generate(args.shell, &mut command, name, &mut io::stdout());
542 Ok(())
543}
544
545fn load_config(cwd: &Path) -> Result<CliConfig, AppError> {
546 let path = cwd.join("openauth.toml");
547 CliConfig::load(&path).map_err(AppError::Config)
548}
549
550fn update_env_example(cwd: &Path, config: &CliConfig) -> Result<(), AppError> {
551 let path = cwd.join(".env.example");
552 let mut content = if path.exists() {
553 fs::read_to_string(&path).map_err(|source| AppError::Io {
554 context: format!("failed to read {}", path.display()),
555 source,
556 })?
557 } else {
558 String::new()
559 };
560 append_env_if_missing(
561 &mut content,
562 &config.security.secret_env,
563 generate_secret(32),
564 );
565 append_env_if_missing(
566 &mut content,
567 &config.database.url_env,
568 default_database_url(config),
569 );
570 fs::write(&path, content).map_err(|source| AppError::Io {
571 context: format!("failed to write {}", path.display()),
572 source,
573 })
574}
575
/// Appends `key=value` to dotenv-style `content` unless `key` is already assigned.
///
/// A line counts as an existing assignment when, after optional leading whitespace
/// and an optional `export ` prefix, it starts with `key` followed by optional
/// whitespace and `=`. If `content` is non-empty and lacks a trailing newline, one is
/// inserted first. The appended line is always newline-terminated.
fn append_env_if_missing(content: &mut String, key: &str, value: impl AsRef<str>) {
    if content.lines().any(|line| assigns_env_key(line, key)) {
        return;
    }
    if !content.is_empty() && !content.ends_with('\n') {
        content.push('\n');
    }
    content.push_str(key);
    content.push('=');
    content.push_str(value.as_ref());
    content.push('\n');
}

/// Returns whether one dotenv line assigns `key`. Tolerates indentation, an
/// `export ` prefix, and whitespace before `=`, so such entries are not duplicated.
fn assigns_env_key(line: &str, key: &str) -> bool {
    let line = line.trim_start();
    let line = line.strip_prefix("export ").map_or(line, str::trim_start);
    line.strip_prefix(key)
        .is_some_and(|rest| rest.trim_start().starts_with('='))
}
588
589fn default_database_url(config: &CliConfig) -> &'static str {
590 match config.database.provider.as_deref() {
591 Some("postgres") | Some("postgresql") | Some("pg") => {
592 "postgres://user:password@localhost:5432/openauth"
593 }
594 Some("mysql") => "mysql://user:password@localhost:3306/openauth",
595 _ => "sqlite://openauth.sqlite",
596 }
597}
598
/// Guesses the database provider from the `DATABASE_URL` environment variable.
///
/// Returns `None` when the variable is unset, not valid Unicode, or unrecognised.
/// See `provider_from_database_url` for the accepted forms.
fn detect_provider_from_env() -> Option<String> {
    let url = std::env::var("DATABASE_URL").ok()?;
    provider_from_database_url(&url).map(str::to_owned)
}

/// Maps a database URL (or bare SQLite file path) to a provider name.
///
/// The scheme is compared case-insensitively, since URL schemes are case-insensitive:
/// - `postgres`/`postgresql` map to `"postgres"`;
/// - `mysql`/`mariadb` map to `"mysql"`;
/// - `sqlite` maps to `"sqlite"`, covering `sqlite://…`, `sqlite:file.db` and
///   `sqlite::memory:`.
///
/// Otherwise, a value ending in `.sqlite`, `.sqlite3` or `.db` (any case) is treated
/// as a SQLite file path.
fn provider_from_database_url(url: &str) -> Option<&'static str> {
    let url = url.trim();
    let scheme = url
        .split_once(':')
        .map(|(scheme, _)| scheme.to_ascii_lowercase());
    match scheme.as_deref() {
        Some("postgres" | "postgresql") => return Some("postgres"),
        Some("mysql" | "mariadb") => return Some("mysql"),
        Some("sqlite") => return Some("sqlite"),
        // Unknown scheme, or something like a Windows drive letter (`C:\…`): fall
        // through to the file-extension heuristic.
        _ => {}
    }
    let lowered = url.to_ascii_lowercase();
    [".sqlite", ".sqlite3", ".db"]
        .iter()
        .any(|extension| lowered.ends_with(*extension))
        .then_some("sqlite")
}
612
613fn normalize_plugins(plugins: Vec<String>) -> Result<Vec<String>, AppError> {
614 let mut normalized = Vec::new();
615 for plugin in plugins {
616 let plugin = plugin.trim();
617 if plugin.is_empty() {
618 continue;
619 }
620 if !is_official_plugin(plugin) {
621 return Err(AppError::Message(format!(
622 "`{plugin}` is not an official OpenAuth plugin."
623 )));
624 }
625 if !normalized.iter().any(|existing| existing == plugin) {
626 normalized.push(plugin.to_owned());
627 }
628 }
629 Ok(normalized)
630}
631
632fn print_report(report: &DiagnosticReport) {
633 println!("OpenAuth doctor");
634 println!("Rust: {}", report.rust);
635 println!("Cargo: {}", report.cargo);
636 if let Some(root) = &report.workspace_root {
637 println!("Workspace: {root}");
638 }
639 for finding in &report.findings {
640 let label = match finding.severity {
641 Severity::Info => "INFO",
642 Severity::Warn => "WARN",
643 Severity::Error => "ERROR",
644 };
645 println!("[{label}] {}: {}", finding.code, finding.message);
646 }
647}
648
649fn print_plan(planned: &db::PlannedMigration) {
650 let dialect = dialect_from_provider(&planned.provider)
651 .map(dialect_name)
652 .unwrap_or("unknown");
653 println!("OpenAuth schema plan ({dialect})");
654 println!("Tables to create: {}", planned.plan.to_be_created.len());
655 for table in &planned.plan.to_be_created {
656 println!(" - {}", table.table_name);
657 }
658 println!("Columns to add: {}", planned.plan.to_be_added.len());
659 for column in &planned.plan.to_be_added {
660 println!(" - {}.{}", column.table_name, column.column_name);
661 }
662 println!(
663 "Indexes to create: {}",
664 planned.plan.indexes_to_be_created.len()
665 );
666 for index in &planned.plan.indexes_to_be_created {
667 println!(" - {}", index.index_name);
668 }
669 for warning in &planned.plan.warnings {
670 println!("WARNING: {warning:?}");
671 }
672}
673
674fn print_json<T>(value: &T) -> Result<(), AppError>
675where
676 T: Serialize,
677{
678 let rendered = serde_json::to_string_pretty(value)?;
679 println!("{rendered}");
680 Ok(())
681}
682
683fn confirm(message: &str, yes: bool) -> Result<bool, AppError> {
684 if yes {
685 return Ok(true);
686 }
687 Confirm::new(message)
688 .with_default(false)
689 .prompt()
690 .map_err(|error| AppError::Message(format!("prompt failed: {error}")))
691}
692
693fn absolute_cwd(cwd: &Path) -> Result<PathBuf, AppError> {
694 let path = if cwd.is_absolute() {
695 cwd.to_path_buf()
696 } else {
697 std::env::current_dir()
698 .map_err(|source| AppError::Io {
699 context: "failed to read current directory".to_owned(),
700 source,
701 })?
702 .join(cwd)
703 };
704 if path.exists() {
705 Ok(path)
706 } else {
707 Err(AppError::Message(format!(
708 "The directory {} does not exist.",
709 path.display()
710 )))
711 }
712}
713
/// Resolves a user-supplied path: absolute paths are returned unchanged; relative
/// ones are interpreted against the project directory `cwd`.
fn resolve_project_path(cwd: &Path, path: &Path) -> PathBuf {
    if !path.is_absolute() {
        return cwd.join(path);
    }
    path.to_owned()
}
721
/// Errors returned by command handlers; `run_from` maps any of them to exit code 1.
#[derive(Debug, thiserror::Error)]
enum AppError {
    /// Free-form, user-facing message, displayed verbatim.
    #[error("{0}")]
    Message(String),
    /// Configuration error from `crate::config` (e.g. loading `openauth.toml`).
    #[error(transparent)]
    Config(#[from] ConfigError),
    /// Error from the database helpers in `crate::db`.
    #[error(transparent)]
    Db(#[from] DbCliError),
    /// Error surfaced by `openauth_core`.
    #[error(transparent)]
    OpenAuth(#[from] openauth_core::error::OpenAuthError),
    /// JSON serialization failure (see `print_json`).
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// The Tokio runtime could not be created.
    #[error("failed to start async runtime: {0}")]
    Runtime(std::io::Error),
    /// Filesystem/OS error, displayed as "<context>: <source>".
    #[error("{context}: {source}")]
    Io {
        // What was being attempted, e.g. "failed to read <path>".
        context: String,
        source: std::io::Error,
    },
    /// Failure without a message (it displays as an empty string). Used when the
    /// command has already printed its output, e.g. `doctor` finding errors or
    /// `db status --check` with pending changes.
    #[error("")]
    ExitOnly,
}
744
// Not called anywhere in this file. Its apparent purpose is to keep the `SqlDialect`
// import referenced, since this is the only use of that type visible here, so the
// `use` does not trigger an unused-import warning.
// NOTE(review): confirm nothing else relies on it before removing it or the import.
#[allow(dead_code)]
fn _dialect_for_lints(dialect: SqlDialect) -> &'static str {
    dialect_name(dialect)
}