1use std::future::Future;
2use std::io::{BufRead, IsTerminal, Write};
3use std::path::Path;
4
5use owo_colors::OwoColorize;
6
7use crate::api::HaError;
8use crate::config;
9use crate::output;
10
/// Horizontal rule printed under the wizard banner and again above the
/// "Configuration saved" summary.
const SEP: &str = "──────────────────────────────────────";
12
13fn sym_q() -> String {
14 "?".green().bold().to_string()
15}
16
17fn sym_ok() -> String {
18 "✔".green().to_string()
19}
20
21fn sym_fail() -> String {
22 "✖".red().to_string()
23}
24
25fn sym_dim(s: &str) -> String {
26 s.dimmed().to_string()
27}
28
29fn prompt_optional<R: BufRead, W: Write>(
30 r: &mut R,
31 w: &mut W,
32 label: &str,
33 default: &str,
34) -> String {
35 let _ = write!(w, "{} {} [{}]: ", sym_q(), label, sym_dim(default));
36 let _ = w.flush();
37 let mut input = String::new();
38 r.read_line(&mut input).unwrap_or(0);
39 let trimmed = input.trim().to_owned();
40 if trimmed.is_empty() {
41 default.to_owned()
42 } else {
43 trimmed
44 }
45}
46
47fn prompt_required<R: BufRead, W: Write>(
48 r: &mut R,
49 w: &mut W,
50 label: &str,
51 hint: &str,
52) -> Option<String> {
53 loop {
54 let _ = write!(
55 w,
56 "{} {} {}: ",
57 sym_q(),
58 label,
59 sym_dim(&format!("[{hint}]"))
60 );
61 let _ = w.flush();
62 let mut input = String::new();
63 match r.read_line(&mut input) {
64 Ok(0) | Err(_) => return None,
65 Ok(_) => {}
66 }
67 let trimmed = input.trim().to_owned();
68 if !trimmed.is_empty() {
69 return Some(trimmed);
70 }
71 let _ = writeln!(w, " {} {} is required.", sym_fail(), label);
72 }
73}
74
75fn prompt_credential_update<R: BufRead, W: Write>(
76 r: &mut R,
77 w: &mut W,
78 label: &str,
79 current: &str,
80) -> Option<String> {
81 let hint = format!("{} (Enter to keep)", output::mask_credential(current));
82 let _ = write!(w, "{} {} {}: ", sym_q(), label, sym_dim(&hint));
83 let _ = w.flush();
84 let mut input = String::new();
85 match r.read_line(&mut input) {
86 Ok(0) | Err(_) => return None,
87 Ok(_) => {}
88 }
89 let trimmed = input.trim().to_owned();
90 Some(if trimmed.is_empty() {
91 current.to_owned()
92 } else {
93 trimmed
94 })
95}
96
97fn prompt_confirm<R: BufRead, W: Write>(
98 r: &mut R,
99 w: &mut W,
100 label: &str,
101 default_yes: bool,
102) -> bool {
103 let hint = if default_yes { "Y/n" } else { "y/N" };
104 let _ = write!(w, "{} {} [{}]: ", sym_q(), label, sym_dim(hint));
105 let _ = w.flush();
106 let mut input = String::new();
107 r.read_line(&mut input).unwrap_or(0);
108 match input.trim().to_lowercase().as_str() {
109 "y" | "yes" => true,
110 "n" | "no" => false,
111 _ => default_yes,
112 }
113}
114
115fn print_json_schema(config_path: &Path) {
116 let path_str = config_path.to_string_lossy();
117 let schema = serde_json::json!({
118 "configPath": path_str,
119 "pathResolution": config::schema_config_path_description(),
120 "recommendedPermissions": config::recommended_permissions(config_path),
121 "tokenInstructions": {
122 "steps": [
123 "Open Home Assistant in your browser",
124 "Go to Settings → Profile (bottom left)",
125 "Scroll to 'Long-Lived Access Tokens'",
126 "Click 'Create Token', give it a name, copy it"
127 ]
128 },
129 "requiredFields": ["url", "token"],
130 "example": {
131 "configFile": path_str,
132 "format": "[default]\nurl = \"http://homeassistant.local:8123\"\ntoken = \"YOUR_LONG_LIVED_TOKEN\""
133 }
134 });
135 println!(
136 "{}",
137 serde_json::to_string_pretty(&schema).expect("serialize")
138 );
139}
140
141pub async fn run_init<R, W, Fut>(
146 reader: &mut R,
147 writer: &mut W,
148 config_path: &Path,
149 profile_arg: Option<&str>,
150 validate: impl Fn(String, String) -> Fut,
151) -> Result<(), HaError>
152where
153 R: BufRead,
154 W: Write,
155 Fut: Future<Output = Option<String>>,
156{
157 let _ = writeln!(writer, "\nHome Assistant CLI");
158 let _ = writeln!(writer, "{SEP}\n");
159
160 let existing_profiles = config::read_profile_names(config_path);
161 let is_first_setup = existing_profiles.is_empty();
162
163 let (profile_name, is_update) = if let Some(p) = profile_arg {
164 let is_update = existing_profiles.contains(&p.to_owned());
165 (p.to_owned(), is_update)
166 } else if is_first_setup {
167 ("default".to_owned(), false)
168 } else {
169 if existing_profiles.len() == 1 {
170 let p = &existing_profiles[0];
171 let cred = config::read_profile_credentials(config_path, p)
172 .map(|(url, _)| format!(" {}", output::mask_credential(&url)))
173 .unwrap_or_default();
174 let _ = writeln!(writer, " Profile: {}{}\n", p.bold(), sym_dim(&cred));
175 } else {
176 let _ = writeln!(writer, " Profiles:");
177 for p in &existing_profiles {
178 let cred = config::read_profile_credentials(config_path, p)
179 .map(|(url, _)| format!(" {}", output::mask_credential(&url)))
180 .unwrap_or_default();
181 let _ = writeln!(writer, " {}{}", p, sym_dim(&cred));
182 }
183 let _ = writeln!(writer);
184 }
185
186 let action = prompt_optional(reader, writer, "Action [update/add]", "update");
187 let _ = writeln!(writer);
188
189 if action.trim().eq_ignore_ascii_case("add") {
190 let Some(name) = prompt_required(reader, writer, "Profile name", "e.g. prod") else {
191 let _ = writeln!(writer, "\nAborted.");
192 return Ok(());
193 };
194 (name, false)
195 } else if existing_profiles.len() == 1 {
196 (existing_profiles[0].clone(), true)
197 } else {
198 let options = existing_profiles.join("/");
199 let chosen = prompt_optional(
200 reader,
201 writer,
202 &format!("Profile [{}]", options),
203 &existing_profiles[0],
204 );
205 let profile = chosen.trim().to_owned();
206 if !existing_profiles.contains(&profile) {
207 let _ = writeln!(writer, "\n {} Unknown profile '{}'.", sym_fail(), profile);
208 return Ok(());
209 }
210 (profile, true)
211 }
212 };
213
214 let (url, token) = if is_update {
215 let (cur_url, cur_token) = config::read_profile_credentials(config_path, &profile_name)
216 .expect("update mode requires existing credentials");
217 let Some(url) = prompt_credential_update(reader, writer, "URL", &cur_url) else {
218 let _ = writeln!(writer, "\nAborted.");
219 return Ok(());
220 };
221 let Some(token) = prompt_credential_update(reader, writer, "Token", &cur_token) else {
222 let _ = writeln!(writer, "\nAborted.");
223 return Ok(());
224 };
225 (url, token)
226 } else {
227 let Some(url) = prompt_required(
228 reader,
229 writer,
230 "Home Assistant URL",
231 "http://homeassistant.local:8123",
232 ) else {
233 let _ = writeln!(writer, "\nAborted.");
234 return Ok(());
235 };
236 let token_url = format!("{}/profile/security", url.trim_end_matches('/'));
237 let _ = writeln!(
238 writer,
239 " {} {} → Long-Lived Access Tokens → Create Token",
240 sym_dim("→"),
241 sym_dim(&token_url)
242 );
243 let Some(token) = prompt_required(
244 reader,
245 writer,
246 "Long-Lived Access Token",
247 "paste token here",
248 ) else {
249 let _ = writeln!(writer, "\nAborted.");
250 return Ok(());
251 };
252 (url, token)
253 };
254
255 let _ = write!(writer, "\n Verifying credentials...");
256 let _ = writer.flush();
257 let validation = validate(url.clone(), token.clone()).await;
258
259 let save = match validation {
260 Some(name) => {
261 let _ = writeln!(writer, " {} Connected to {}", sym_ok(), name.bold());
262 true
263 }
264 None => {
265 let _ = writeln!(writer, " {} Could not connect.", sym_fail());
266 prompt_confirm(reader, writer, "Save anyway?", false)
267 }
268 };
269
270 if !save {
271 let _ = writeln!(writer, "\nAborted. Config not saved.");
272 let _ = writer.flush();
273 return Ok(());
274 }
275
276 config::write_profile(config_path, &profile_name, &url, &token)?;
277
278 let pfx = if profile_name == "default" {
279 "ha".to_owned()
280 } else {
281 format!("ha --profile {}", profile_name)
282 };
283
284 let _ = writeln!(writer, "\n{SEP}");
285 let _ = writeln!(
286 writer,
287 " {} Configuration saved to {}",
288 sym_ok(),
289 config_path.display()
290 );
291 let _ = writeln!(writer);
292 let _ = writeln!(writer, " {}:", "Next steps".bold());
293 let _ = writeln!(
294 writer,
295 " {} entity list {}",
296 pfx,
297 sym_dim("# list all entities")
298 );
299 let _ = writeln!(
300 writer,
301 " {} service list {}",
302 pfx,
303 sym_dim("# list available services")
304 );
305 let _ = writeln!(
306 writer,
307 " {} event list {}",
308 pfx,
309 sym_dim("# list event types")
310 );
311 let _ = writeln!(
312 writer,
313 " {} completions zsh {}",
314 pfx,
315 sym_dim("# shell completions")
316 );
317 let _ = writeln!(writer);
318 let _ = writer.flush();
319
320 Ok(())
321}
322
323pub async fn init(profile_arg: Option<String>) {
325 let config_path = config::config_path();
326
327 if !std::io::stdout().is_terminal() {
328 print_json_schema(&config_path);
329 return;
330 }
331
332 let stdin = std::io::stdin();
333 let stdout = std::io::stdout();
334 let mut reader = std::io::BufReader::new(stdin.lock());
335 let mut writer = std::io::BufWriter::new(stdout.lock());
336
337 if let Err(e) = run_init(
338 &mut reader,
339 &mut writer,
340 &config_path,
341 profile_arg.as_deref(),
342 |url, token| async move {
343 let client = crate::api::HaClient::new(&url, &token);
344 client.validate().await.ok()
345 },
346 )
347 .await
348 {
349 eprintln!("{} {e}", sym_fail());
350 std::process::exit(crate::output::exit_codes::GENERAL_ERROR);
351 }
352}
353
#[cfg(test)]
mod tests {
    // End-to-end tests of `run_init` driven by scripted stdin and a fake validator.
    use super::*;
    use std::io::Cursor;
    use tempfile::TempDir;

    fn fake_path(dir: &TempDir) -> std::path::PathBuf {
        dir.path().join("config.toml")
    }

    /// Runs the wizard on the scripted `input` and requires it to return `Ok`.
    /// Returns everything the wizard printed, as text.
    async fn drive<V, Fut>(
        path: &std::path::Path,
        input: &[u8],
        profile: Option<&str>,
        validate: V,
    ) -> String
    where
        V: Fn(String, String) -> Fut,
        Fut: std::future::Future<Output = Option<String>>,
    {
        let mut reader = Cursor::new(input);
        let mut writer = Vec::<u8>::new();
        run_init(&mut reader, &mut writer, path, profile, validate)
            .await
            .unwrap();
        String::from_utf8_lossy(&writer).into_owned()
    }

    #[tokio::test]
    async fn init_writes_config_on_valid_credentials() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let output = drive(&path, b"http://ha.local:8123\nmytoken\n", None, |_url, _token| async {
            Some("Home Assistant".to_string())
        })
        .await;

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("http://ha.local:8123"));
        assert!(saved.contains("mytoken"));
        assert!(output.contains("http://ha.local:8123/profile/security"));
        assert!(
            output.contains("Long-Lived Access Tokens"),
            "should show token creation instructions"
        );
        assert!(
            output.contains("Configuration saved to"),
            "should show config saved confirmation"
        );
        assert!(output.contains("Next steps"), "should show next steps block");
    }

    #[tokio::test]
    async fn init_uses_default_profile_on_first_setup() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.local:8123\nmytoken\n", None, |_, _| async {
            Some("HA".into())
        })
        .await;

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("[default]"));
    }

    #[tokio::test]
    async fn init_aborts_when_validation_fails_and_user_declines() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.local:8123\nbadtoken\nn\n", None, |_, _| async { None }).await;

        assert!(!path.exists(), "config must not be written after abort");
    }

    #[tokio::test]
    async fn init_saves_when_validation_fails_but_user_forces() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.local:8123\nbadtoken\ny\n", None, |_, _| async { None }).await;

        assert!(path.exists());
    }

    #[tokio::test]
    async fn init_with_profile_arg_writes_named_profile() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.prod:8123\nprodtoken\n", Some("prod"), |_, _| async {
            Some("HA".into())
        })
        .await;

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("[prod]"));
    }

    #[tokio::test]
    async fn init_update_keeps_values_on_enter() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        std::fs::write(
            &path,
            "[default]\nurl = \"http://ha.local:8123\"\ntoken = \"existing-token\"\n",
        )
        .unwrap();

        // Three blank answers: action (update), URL (keep), token (keep).
        drive(&path, b"\n\n\n", None, |_, _| async { Some("HA".into()) }).await;

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("existing-token"));
    }

    #[tokio::test]
    async fn init_outro_includes_profile_flag_for_non_default() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let output = drive(&path, b"http://ha.local:8123\ntoken\n", Some("staging"), |_, _| async {
            Some("HA".into())
        })
        .await;

        for expected in [
            "--profile staging",
            "Next steps",
            "entity list",
            "service list",
            "completions zsh",
        ] {
            assert!(output.contains(expected), "missing {expected:?} in outro");
        }
    }

    #[tokio::test]
    async fn init_aborts_on_eof() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let output = drive(&path, b"", None, |_, _| async { Some("HA".into()) }).await;

        assert!(!path.exists());
        assert!(output.contains("Aborted"));
    }
}