1use anyhow::{Context, Result};
4use std::collections::BTreeSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::cli::{AddArgs, AddFeature};
9use crate::templates::{BackendTemplateContext, BackendTemplateEngine};
10use crate::{
11 ensure_dir, print_info, print_success, print_warning, write_file, TIDEWAY_VERSION,
12};
13
14pub fn run(args: AddArgs) -> Result<()> {
15 let project_dir = PathBuf::from(&args.path);
16 let cargo_path = project_dir.join("Cargo.toml");
17
18 if !cargo_path.exists() {
19 return Err(anyhow::anyhow!(
20 "Cargo.toml not found in {}",
21 project_dir.display()
22 ));
23 }
24
25 let cargo_contents = fs::read_to_string(&cargo_path)
26 .with_context(|| format!("Failed to read {}", cargo_path.display()))?;
27
28 let project_name = project_name_from_cargo(&cargo_contents, &project_dir);
29 let project_name_pascal = to_pascal_case(&project_name);
30
31 update_cargo_toml(&cargo_path, &cargo_contents, args.feature)?;
32 update_env_example(&project_dir, args.feature, &project_name)?;
33
34 if args.feature == AddFeature::Auth {
35 scaffold_auth(&project_dir, &project_name, &project_name_pascal, args.force)?;
36 print_info("Auth scaffold created in src/auth/");
37 if args.wire {
38 wire_auth_in_main(&project_dir, &project_name)?;
39 } else {
40 print_info("Next steps: wire AuthModule + SimpleAuthProvider in main.rs");
41 }
42 }
43
44 if args.feature == AddFeature::Database && args.wire {
45 wire_database_in_main(&project_dir)?;
46 }
47
48 if args.feature == AddFeature::Openapi {
49 ensure_openapi_docs_file(&project_dir)?;
50 if args.wire {
51 wire_openapi_in_main(&project_dir)?;
52 } else {
53 print_info("Next steps: wire OpenAPI in main.rs");
54 }
55 }
56
57 print_success(&format!("Added {}", args.feature));
58 Ok(())
59}
60
61fn update_cargo_toml(path: &Path, contents: &str, feature: AddFeature) -> Result<()> {
62 let mut doc = contents.parse::<toml_edit::DocumentMut>()?;
63
64 let deps = doc["dependencies"].or_insert(toml_edit::Item::Table(toml_edit::Table::new()));
65
66 let tideway_item = deps
67 .as_table_mut()
68 .expect("dependencies should be a table")
69 .entry("tideway");
70
71 let feature_name = feature.to_string();
72
73 match tideway_item {
74 toml_edit::Entry::Vacant(entry) => {
75 let mut table = toml_edit::InlineTable::new();
76 table.get_or_insert("version", TIDEWAY_VERSION);
77 table.get_or_insert("features", array_value(&[feature_name.as_str()]));
78 entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
79 }
80 toml_edit::Entry::Occupied(mut entry) => {
81 if entry.get().is_str() {
82 let version = entry
83 .get()
84 .as_str()
85 .unwrap_or(TIDEWAY_VERSION)
86 .to_string();
87 let mut table = toml_edit::InlineTable::new();
88 table.get_or_insert("version", version);
89 table.get_or_insert("features", array_value(&[feature_name.as_str()]));
90 entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
91 } else {
92 let item = entry.get_mut();
93 let features = item["features"]
94 .or_insert(toml_edit::Item::Value(toml_edit::Value::Array(toml_edit::Array::new())))
95 .as_array_mut()
96 .expect("features should be an array");
97
98 if !features.iter().any(|v| v.as_str() == Some(&feature_name)) {
99 features.push(feature_name);
100 }
101 }
102 }
103 }
104
105 if feature == AddFeature::Database {
106 let deps_table = deps.as_table_mut().expect("dependencies should be a table");
107 deps_table
108 .entry("sea-orm")
109 .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
110 {
111 let mut table = toml_edit::InlineTable::new();
112 table.get_or_insert("version", "1.1");
113 table.get_or_insert(
114 "features",
115 array_value(&["sqlx-postgres", "runtime-tokio-rustls"]),
116 );
117 table
118 },
119 )));
120 }
121
122 if feature == AddFeature::Auth {
123 let deps_table = deps.as_table_mut().expect("dependencies should be a table");
124 deps_table
125 .entry("async-trait")
126 .or_insert(toml_edit::value("0.1"));
127 deps_table
128 .entry("serde")
129 .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
130 {
131 let mut table = toml_edit::InlineTable::new();
132 table.get_or_insert("version", "1.0");
133 table.get_or_insert("features", array_value(&["derive"]));
134 table
135 },
136 )));
137 deps_table
138 .entry("serde_json")
139 .or_insert(toml_edit::value("1.0"));
140 }
141
142 write_file(path, &doc.to_string())
143 .with_context(|| format!("Failed to write {}", path.display()))?;
144 Ok(())
145}
146
147fn update_env_example(project_dir: &Path, feature: AddFeature, project_name: &str) -> Result<()> {
148 let env_path = project_dir.join(".env.example");
149 let mut lines = if env_path.exists() {
150 fs::read_to_string(&env_path)
151 .with_context(|| format!("Failed to read {}", env_path.display()))?
152 .lines()
153 .map(|line| line.to_string())
154 .collect::<Vec<_>>()
155 } else {
156 vec![
157 "# Server".to_string(),
158 "TIDEWAY_HOST=0.0.0.0".to_string(),
159 "TIDEWAY_PORT=8000".to_string(),
160 String::new(),
161 ]
162 };
163
164 let mut existing = BTreeSet::new();
165 for line in &lines {
166 if let Some((key, _)) = line.split_once('=') {
167 existing.insert(key.trim().to_string());
168 }
169 }
170
171 match feature {
172 AddFeature::Database => {
173 if !existing.contains("DATABASE_URL") {
174 lines.push("# Database".to_string());
175 lines.push(format!(
176 "DATABASE_URL=postgres://postgres:postgres@localhost:5432/{}",
177 project_name
178 ));
179 lines.push(String::new());
180 }
181 }
182 AddFeature::Auth => {
183 if !existing.contains("JWT_SECRET") {
184 lines.push("# Auth".to_string());
185 lines.push("JWT_SECRET=your-super-secret-jwt-key-change-in-production".to_string());
186 lines.push(String::new());
187 }
188 }
189 _ => {}
190 }
191
192 write_file(&env_path, &lines.join("\n"))
193 .with_context(|| format!("Failed to write {}", env_path.display()))?;
194 Ok(())
195}
196
197fn scaffold_auth(
198 project_dir: &Path,
199 project_name: &str,
200 project_name_pascal: &str,
201 force: bool,
202) -> Result<()> {
203 let context = BackendTemplateContext {
204 project_name: project_name.to_string(),
205 project_name_pascal: project_name_pascal.to_string(),
206 has_organizations: false,
207 database: "postgres".to_string(),
208 tideway_version: TIDEWAY_VERSION.to_string(),
209 tideway_features: vec!["auth".to_string()],
210 has_tideway_features: true,
211 has_auth_feature: true,
212 has_database_feature: false,
213 has_openapi_feature: false,
214 needs_arc: true,
215 has_config: false,
216 };
217
218 let engine = BackendTemplateEngine::new(context)?;
219 let auth_dir = project_dir.join("src").join("auth");
220
221 write_file_with_force(
222 &auth_dir.join("mod.rs"),
223 &engine.render("starter/src/auth/mod.rs")?,
224 force,
225 )?;
226 write_file_with_force(
227 &auth_dir.join("provider.rs"),
228 &engine.render("starter/src/auth/provider.rs")?,
229 force,
230 )?;
231 write_file_with_force(
232 &auth_dir.join("routes.rs"),
233 &engine.render("starter/src/auth/routes.rs")?,
234 force,
235 )?;
236
237 Ok(())
238}
239
240fn wire_auth_in_main(project_dir: &Path, project_name: &str) -> Result<()> {
241 let main_path = project_dir.join("src").join("main.rs");
242 if !main_path.exists() {
243 print_warning("src/main.rs not found; skipping auto-wiring");
244 return Ok(());
245 }
246
247 let mut contents = fs::read_to_string(&main_path)
248 .with_context(|| format!("Failed to read {}", main_path.display()))?;
249
250 if !contents.contains("mod auth;") {
251 if contents.contains("mod routes;") {
252 contents = contents.replace("mod routes;\n", "mod routes;\nmod auth;\n");
253 } else {
254 contents = format!("mod auth;\n{}", contents);
255 }
256 }
257
258 contents = ensure_use_line(
259 contents,
260 "use axum::Extension;",
261 "use tideway::auth",
262 );
263 contents = ensure_use_line(
264 contents,
265 "use crate::auth::{AuthModule, SimpleAuthProvider};",
266 "use tideway::auth",
267 );
268 contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
269 contents = ensure_use_line(
270 contents,
271 "use tideway::auth::{JwtIssuer, JwtIssuerConfig};",
272 "use tideway::auth",
273 );
274
275 let has_jwt_secret = contents.contains("let jwt_secret");
276 let has_jwt_issuer = contents.contains("let jwt_issuer");
277 let has_auth_provider = contents.contains("auth_provider");
278 let has_auth_module = contents.contains("auth_module");
279
280 if has_jwt_secret && has_jwt_issuer {
281 if !has_auth_provider || !has_auth_module {
282 if let Some(insert_at) = contents.find("let jwt_issuer") {
283 let after = contents[insert_at..]
284 .find(";\n")
285 .map(|idx| insert_at + idx + 2)
286 .unwrap_or(insert_at);
287 let insert = format!(
288 " let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n let auth_module = AuthModule::new(jwt_issuer.clone());\n"
289 );
290 contents.insert_str(after, &insert);
291 }
292 }
293 } else {
294 let block = format!(
295 " let jwt_secret = std::env::var(\"JWT_SECRET\").expect(\"JWT_SECRET is not set\");\n let jwt_issuer = Arc::new(JwtIssuer::new(JwtIssuerConfig::with_secret(\n &jwt_secret,\n \"{}\",\n )).expect(\"Failed to create JWT issuer\"));\n let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n let auth_module = AuthModule::new(jwt_issuer.clone());\n\n",
296 project_name
297 );
298 contents = insert_before_app_builder(contents, &block)?;
299 }
300
301 contents = insert_auth_into_app_builder(contents)?;
302
303 write_file(&main_path, &contents)
304 .with_context(|| format!("Failed to write {}", main_path.display()))?;
305 print_success("Wired auth into src/main.rs");
306 Ok(())
307}
308
309pub fn wire_database_in_main(project_dir: &Path) -> Result<()> {
310 let main_path = project_dir.join("src").join("main.rs");
311 if !main_path.exists() {
312 print_warning("src/main.rs not found; skipping auto-wiring");
313 return Ok(());
314 }
315
316 let mut contents = fs::read_to_string(&main_path)
317 .with_context(|| format!("Failed to read {}", main_path.display()))?;
318
319 if !contents.contains("async fn main") {
320 print_warning("main.rs is not async; skipping database wiring");
321 return Ok(());
322 }
323
324 contents = ensure_use_line(
325 contents,
326 "use tideway::{AppContext, SeaOrmPool};",
327 "use tideway::",
328 );
329 contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
330
331 let has_database_block = contents.contains("DATABASE_URL")
332 || contents.contains("sea_orm::Database::connect")
333 || contents.contains("with_database");
334
335 if !has_database_block {
336 let block = " let database_url = std::env::var(\"DATABASE_URL\").expect(\"DATABASE_URL is not set\");\n let db = sea_orm::Database::connect(&database_url)\n .await\n .expect(\"Failed to connect to database\");\n\n";
337 contents = insert_before_app_builder(contents, block)?;
338 }
339
340 if !contents.contains(".with_database(") {
341 contents = insert_database_into_app_builder(contents)?;
342 }
343
344 write_file(&main_path, &contents)
345 .with_context(|| format!("Failed to write {}", main_path.display()))?;
346 print_success("Wired database into src/main.rs");
347 Ok(())
348}
349
/// Ensures `line` (a full `use ...;` statement) appears in `contents`.
///
/// If `line` is already present verbatim, `contents` is returned unchanged.
/// Otherwise `line` is inserted on its own line directly after the first
/// statement containing `anchor`. Insertion happens after that statement's
/// terminating `;` line, not the anchor's own line. This keeps a multi-line group
/// such as `use tideway::{\n    App,\n};` intact instead of splicing the new
/// line inside the braces. If the anchor is absent or its statement has no
/// terminator, `line` is prepended to the file.
fn ensure_use_line(mut contents: String, line: &str, anchor: &str) -> String {
    if contents.contains(line) {
        return contents;
    }

    let insert_at = contents.find(anchor).and_then(|start| {
        let semicolon = start + contents[start..].find(';')?;
        let newline = semicolon + contents[semicolon..].find('\n')?;
        Some(newline + 1)
    });

    match insert_at {
        Some(at) => {
            contents.insert_str(at, &format!("{line}\n"));
            contents
        }
        None => format!("{line}\n{contents}"),
    }
}
366
367fn insert_before_app_builder(mut contents: String, block: &str) -> Result<String> {
368 if let Some(pos) = contents.find("let app = App::") {
369 contents.insert_str(pos, block);
370 Ok(contents)
371 } else {
372 print_warning("Could not find app builder; skipping auth wiring");
373 Ok(contents)
374 }
375}
376
377fn insert_auth_into_app_builder(mut contents: String) -> Result<String> {
378 if contents.contains("register_module(auth_module)") {
379 return Ok(contents);
380 }
381
382 if let Some(pos) = contents.find("let app = App::") {
383 let line_end = contents[pos..]
384 .find('\n')
385 .map(|idx| pos + idx)
386 .unwrap_or(contents.len());
387 let indent = contents[pos..]
388 .chars()
389 .take_while(|c| c.is_whitespace())
390 .collect::<String>();
391 let insert = format!(
392 "{} .with_global_layer(Extension(auth_provider))\n{} .register_module(auth_module)\n",
393 indent, indent
394 );
395 contents.insert_str(line_end + 1, &insert);
396 Ok(contents)
397 } else {
398 print_warning("Could not find app builder; skipping auth module registration");
399 Ok(contents)
400 }
401}
402
403fn insert_database_into_app_builder(mut contents: String) -> Result<String> {
404 if let Some(pos) = contents.find("let app = App::") {
405 let line_end = contents[pos..]
406 .find('\n')
407 .map(|idx| pos + idx)
408 .unwrap_or(contents.len());
409 let indent = contents[pos..]
410 .chars()
411 .take_while(|c| c.is_whitespace())
412 .collect::<String>();
413 let insert = format!(
414 "{} .with_context(\n{} AppContext::builder()\n{} .with_database(Arc::new(SeaOrmPool::new(db, database_url)))\n{} .build()\n{} )\n",
415 indent, indent, indent, indent, indent
416 );
417 contents.insert_str(line_end + 1, &insert);
418 Ok(contents)
419 } else {
420 print_warning("Could not find app builder; skipping database wiring");
421 Ok(contents)
422 }
423}
424
/// Wires OpenAPI documentation into `src/main.rs`.
///
/// Steps:
/// - Declares `mod openapi_docs;`.
/// - Adds the `ConfigBuilder` / `openapi` imports, plus `AppConfig` when
///   `mod config;` exists.
/// - Inserts a `let config = ConfigBuilder::...` binding before the app builder
///   when config construction appears in the file but is not bound to `config`.
/// - Splices the OpenAPI router block in after the app builder via
///   `insert_openapi_into_app_builder`.
///
/// A missing `src/main.rs` only produces a warning. If OpenAPI already looks wired
/// (by substring), nothing is written.
///
/// Every check below runs against the progressively edited `contents`, so the
/// order of these steps matters.
fn wire_openapi_in_main(project_dir: &Path) -> Result<()> {
    let main_path = project_dir.join("src").join("main.rs");
    if !main_path.exists() {
        print_warning("src/main.rs not found; skipping auto-wiring");
        return Ok(());
    }

    let mut contents = fs::read_to_string(&main_path)
        .with_context(|| format!("Failed to read {}", main_path.display()))?;

    // Idempotence guard: a previous run (or the user) already wired OpenAPI.
    if contents.contains("openapi::create_openapi_router") || contents.contains("openapi_merge_module") {
        print_info("OpenAPI already appears wired in main.rs");
        return Ok(());
    }

    contents = ensure_use_line(contents, "use tideway::ConfigBuilder;", "use tideway::");
    // NOTE(review): added whenever `mod config;` exists, even if `AppConfig` is never
    // referenced afterwards, which may cause an unused-import warning. Confirm intended.
    if contents.contains("mod config;") {
        contents = ensure_use_line(contents, "use crate::config::AppConfig;", "use tideway::");
    }
    contents = ensure_use_line(contents, "use tideway::openapi;", "use tideway::");

    // Declare the docs module next to `mod routes;` when present, otherwise at the top.
    if !contents.contains("mod openapi_docs;") {
        if contents.contains("mod routes;") {
            contents = contents.replace("mod routes;\n", "mod routes;\nmod openapi_docs;\n");
        } else {
            contents = format!("mod openapi_docs;\n{}", contents);
        }
    }

    // `has_config_var`: a `config` binding already exists in one of two known forms.
    // `config_available`: config construction appears somewhere in the file.
    let has_config_var = contents.contains("let config = ConfigBuilder::new()")
        || contents.contains("let config = AppConfig::from_env()");
    let config_available = contents.contains("ConfigBuilder::new()")
        || contents.contains("AppConfig::from_env()");

    // NOTE(review): when neither form appears at all, no `config` binding is
    // inserted, yet the OpenAPI block added below still references `config`.
    // Confirm this is intended.
    if !has_config_var && config_available {
        let config_block = " let config = ConfigBuilder::new()\n .from_env()\n .build()\n .expect(\"Invalid TIDEWAY_* config\");\n\n";
        contents = insert_before_app_builder(contents, config_block)?;
    }

    // With `AppConfig`, the tideway settings are reached through `config.tideway`
    // (presumably a field wrapping the tideway config; verify against the
    // template). A `ConfigBuilder` config is used directly.
    if contents.contains("let config = AppConfig::from_env()") {
        contents = insert_openapi_into_app_builder(contents, "config.tideway")?;
    } else {
        contents = insert_openapi_into_app_builder(contents, "config")?;
    }

    write_file(&main_path, &contents)
        .with_context(|| format!("Failed to write {}", main_path.display()))?;
    print_success("Wired OpenAPI into src/main.rs");
    Ok(())
}
475
476fn insert_openapi_into_app_builder(mut contents: String, config_ref: &str) -> Result<String> {
477 if contents.contains("create_openapi_router") {
478 return Ok(contents);
479 }
480
481 if let Some(pos) = contents.find("let app = App::") {
482 if let Some(end_pos) = contents[pos..].find(";\n\n") {
484 let insert_at = pos + end_pos + 3;
485 let block = format!(
486 "\n #[cfg(feature = \"openapi\")]\n if {config_ref}.openapi.enabled {{\n let openapi = tideway::openapi_merge_module!(openapi_docs, ApiDoc);\n let openapi_router = tideway::openapi::create_openapi_router(openapi, &{config_ref}.openapi);\n app = app.merge_router(openapi_router);\n }}\n"
487 );
488 contents.insert_str(insert_at, &block);
489 } else {
490 print_warning("Could not find app builder termination; skipping OpenAPI wiring");
491 }
492 Ok(contents)
493 } else {
494 print_warning("Could not find app builder; skipping OpenAPI wiring");
495 Ok(contents)
496 }
497}
498
499fn ensure_openapi_docs_file(project_dir: &Path) -> Result<()> {
500 let docs_path = project_dir.join("src").join("openapi_docs.rs");
501 if docs_path.exists() {
502 return Ok(());
503 }
504
505 let contents = r#"#[cfg(feature = "openapi")]
506tideway::openapi_doc!(pub(crate) ApiDoc, paths());
507"#;
508
509 if let Some(parent) = docs_path.parent() {
510 ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
511 }
512
513 write_file(&docs_path, &contents)
514 .with_context(|| format!("Failed to write {}", docs_path.display()))?;
515 print_success("Created src/openapi_docs.rs");
516 Ok(())
517}
518
519
520fn write_file_with_force(path: &Path, contents: &str, force: bool) -> Result<()> {
521 if path.exists() && !force {
522 print_warning(&format!(
523 "Skipping {} (use --force to overwrite)",
524 path.display()
525 ));
526 return Ok(());
527 }
528
529 if let Some(parent) = path.parent() {
530 ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
531 }
532
533 write_file(path, contents)
534 .with_context(|| format!("Failed to write {}", path.display()))?;
535 Ok(())
536}
537
538fn project_name_from_cargo(contents: &str, project_dir: &Path) -> String {
539 let doc = contents
540 .parse::<toml_edit::DocumentMut>()
541 .ok()
542 .and_then(|doc| doc["package"]["name"].as_str().map(|s| s.to_string()));
543
544 doc.unwrap_or_else(|| {
545 project_dir
546 .file_name()
547 .and_then(|n| n.to_str())
548 .unwrap_or("my_app")
549 .to_string()
550 })
551 .replace('-', "_")
552}
553
/// Converts a `snake_case` name to `PascalCase`.
///
/// Splits on `_` and drops empty segments. The first character of each segment is
/// uppercased (which may expand it, e.g. `ß` → `SS`) and the rest is kept as-is.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
566
567pub fn array_value(values: &[&str]) -> toml_edit::Value {
568 let mut array = toml_edit::Array::new();
569 for value in values {
570 array.push(*value);
571 }
572 toml_edit::Value::Array(array)
573}