Skip to main content

homeassistant_cli/commands/
init.rs

1use std::future::Future;
2use std::io::{BufRead, IsTerminal, Write};
3use std::path::Path;
4
5use owo_colors::OwoColorize;
6
7use crate::api::HaError;
8use crate::config;
9use crate::output;
10
11const SEP: &str = "──────────────────────────────────────";
12
13fn sym_q() -> String {
14    "?".green().bold().to_string()
15}
16
17fn sym_ok() -> String {
18    "✔".green().to_string()
19}
20
21fn sym_fail() -> String {
22    "✖".red().to_string()
23}
24
25fn sym_dim(s: &str) -> String {
26    s.dimmed().to_string()
27}
28
29fn prompt_optional<R: BufRead, W: Write>(
30    r: &mut R,
31    w: &mut W,
32    label: &str,
33    default: &str,
34) -> String {
35    let _ = write!(w, "{} {}  [{}]: ", sym_q(), label, sym_dim(default));
36    let _ = w.flush();
37    let mut input = String::new();
38    r.read_line(&mut input).unwrap_or(0);
39    let trimmed = input.trim().to_owned();
40    if trimmed.is_empty() {
41        default.to_owned()
42    } else {
43        trimmed
44    }
45}
46
47fn prompt_required<R: BufRead, W: Write>(
48    r: &mut R,
49    w: &mut W,
50    label: &str,
51    hint: &str,
52) -> Option<String> {
53    loop {
54        let _ = write!(
55            w,
56            "{} {}  {}: ",
57            sym_q(),
58            label,
59            sym_dim(&format!("[{hint}]"))
60        );
61        let _ = w.flush();
62        let mut input = String::new();
63        match r.read_line(&mut input) {
64            Ok(0) | Err(_) => return None,
65            Ok(_) => {}
66        }
67        let trimmed = input.trim().to_owned();
68        if !trimmed.is_empty() {
69            return Some(trimmed);
70        }
71        let _ = writeln!(w, "  {} {} is required.", sym_fail(), label);
72    }
73}
74
75fn prompt_credential_update<R: BufRead, W: Write>(
76    r: &mut R,
77    w: &mut W,
78    label: &str,
79    current: &str,
80) -> Option<String> {
81    let hint = format!("{} (Enter to keep)", output::mask_credential(current));
82    let _ = write!(w, "{} {}  {}: ", sym_q(), label, sym_dim(&hint));
83    let _ = w.flush();
84    let mut input = String::new();
85    match r.read_line(&mut input) {
86        Ok(0) | Err(_) => return None,
87        Ok(_) => {}
88    }
89    let trimmed = input.trim().to_owned();
90    Some(if trimmed.is_empty() {
91        current.to_owned()
92    } else {
93        trimmed
94    })
95}
96
97fn prompt_confirm<R: BufRead, W: Write>(
98    r: &mut R,
99    w: &mut W,
100    label: &str,
101    default_yes: bool,
102) -> bool {
103    let hint = if default_yes { "Y/n" } else { "y/N" };
104    let _ = write!(w, "{} {}  [{}]: ", sym_q(), label, sym_dim(hint));
105    let _ = w.flush();
106    let mut input = String::new();
107    r.read_line(&mut input).unwrap_or(0);
108    match input.trim().to_lowercase().as_str() {
109        "y" | "yes" => true,
110        "n" | "no" => false,
111        _ => default_yes,
112    }
113}
114
115fn print_json_schema(config_path: &Path) {
116    let path_str = config_path.to_string_lossy();
117    let schema = serde_json::json!({
118        "configPath": path_str,
119        "pathResolution": config::schema_config_path_description(),
120        "recommendedPermissions": config::recommended_permissions(config_path),
121        "tokenInstructions": {
122            "steps": [
123                "Open Home Assistant in your browser",
124                "Go to Settings → Profile (bottom left)",
125                "Scroll to 'Long-Lived Access Tokens'",
126                "Click 'Create Token', give it a name, copy it"
127            ]
128        },
129        "requiredFields": ["url", "token"],
130        "example": {
131            "configFile": path_str,
132            "format": "[default]\nurl = \"http://homeassistant.local:8123\"\ntoken = \"YOUR_LONG_LIVED_TOKEN\""
133        }
134    });
135    println!(
136        "{}",
137        serde_json::to_string_pretty(&schema).expect("serialize")
138    );
139}
140
141/// Interactive init flow with injectable IO and async validator for testing.
142///
143/// `validate` receives (url, token) and returns `Some(display_name)` on success
144/// or `None` on auth failure.
145pub async fn run_init<R, W, Fut>(
146    reader: &mut R,
147    writer: &mut W,
148    config_path: &Path,
149    profile_arg: Option<&str>,
150    validate: impl Fn(String, String) -> Fut,
151) -> Result<(), HaError>
152where
153    R: BufRead,
154    W: Write,
155    Fut: Future<Output = Option<String>>,
156{
157    let _ = writeln!(writer, "\nHome Assistant CLI");
158    let _ = writeln!(writer, "{SEP}\n");
159
160    let existing_profiles = config::read_profile_names(config_path);
161    let is_first_setup = existing_profiles.is_empty();
162
163    let (profile_name, is_update) = if let Some(p) = profile_arg {
164        let is_update = existing_profiles.contains(&p.to_owned());
165        (p.to_owned(), is_update)
166    } else if is_first_setup {
167        ("default".to_owned(), false)
168    } else {
169        if existing_profiles.len() == 1 {
170            let p = &existing_profiles[0];
171            let cred = config::read_profile_credentials(config_path, p)
172                .map(|(url, _)| format!("  {}", output::mask_credential(&url)))
173                .unwrap_or_default();
174            let _ = writeln!(writer, "  Profile: {}{}\n", p.bold(), sym_dim(&cred));
175        } else {
176            let _ = writeln!(writer, "  Profiles:");
177            for p in &existing_profiles {
178                let cred = config::read_profile_credentials(config_path, p)
179                    .map(|(url, _)| format!("  {}", output::mask_credential(&url)))
180                    .unwrap_or_default();
181                let _ = writeln!(writer, "    {}{}", p, sym_dim(&cred));
182            }
183            let _ = writeln!(writer);
184        }
185
186        let action = prompt_optional(reader, writer, "Action  [update/add]", "update");
187        let _ = writeln!(writer);
188
189        if action.trim().eq_ignore_ascii_case("add") {
190            let Some(name) = prompt_required(reader, writer, "Profile name", "e.g. prod") else {
191                let _ = writeln!(writer, "\nAborted.");
192                return Ok(());
193            };
194            (name, false)
195        } else if existing_profiles.len() == 1 {
196            (existing_profiles[0].clone(), true)
197        } else {
198            let options = existing_profiles.join("/");
199            let chosen = prompt_optional(
200                reader,
201                writer,
202                &format!("Profile  [{}]", options),
203                &existing_profiles[0],
204            );
205            let profile = chosen.trim().to_owned();
206            if !existing_profiles.contains(&profile) {
207                let _ = writeln!(writer, "\n  {} Unknown profile '{}'.", sym_fail(), profile);
208                return Ok(());
209            }
210            (profile, true)
211        }
212    };
213
214    let (url, token) = if is_update {
215        let (cur_url, cur_token) = config::read_profile_credentials(config_path, &profile_name)
216            .expect("update mode requires existing credentials");
217        let Some(url) = prompt_credential_update(reader, writer, "URL", &cur_url) else {
218            let _ = writeln!(writer, "\nAborted.");
219            return Ok(());
220        };
221        let Some(token) = prompt_credential_update(reader, writer, "Token", &cur_token) else {
222            let _ = writeln!(writer, "\nAborted.");
223            return Ok(());
224        };
225        (url, token)
226    } else {
227        let Some(url) = prompt_required(
228            reader,
229            writer,
230            "Home Assistant URL",
231            "http://homeassistant.local:8123",
232        ) else {
233            let _ = writeln!(writer, "\nAborted.");
234            return Ok(());
235        };
236        let token_url = format!("{}/profile/security", url.trim_end_matches('/'));
237        let _ = writeln!(
238            writer,
239            "  {} {} → Long-Lived Access Tokens → Create Token",
240            sym_dim("→"),
241            sym_dim(&token_url)
242        );
243        let Some(token) = prompt_required(
244            reader,
245            writer,
246            "Long-Lived Access Token",
247            "paste token here",
248        ) else {
249            let _ = writeln!(writer, "\nAborted.");
250            return Ok(());
251        };
252        (url, token)
253    };
254
255    let _ = write!(writer, "\n  Verifying credentials...");
256    let _ = writer.flush();
257    let validation = validate(url.clone(), token.clone()).await;
258
259    let save = match validation {
260        Some(name) => {
261            let _ = writeln!(writer, " {} Connected to {}", sym_ok(), name.bold());
262            true
263        }
264        None => {
265            let _ = writeln!(writer, " {} Could not connect.", sym_fail());
266            prompt_confirm(reader, writer, "Save anyway?", false)
267        }
268    };
269
270    if !save {
271        let _ = writeln!(writer, "\nAborted. Config not saved.");
272        let _ = writer.flush();
273        return Ok(());
274    }
275
276    config::write_profile(config_path, &profile_name, &url, &token)?;
277
278    let pfx = if profile_name == "default" {
279        "ha".to_owned()
280    } else {
281        format!("ha --profile {}", profile_name)
282    };
283
284    let _ = writeln!(writer, "\n{SEP}");
285    let _ = writeln!(
286        writer,
287        "  {} Configuration saved to {}",
288        sym_ok(),
289        config_path.display()
290    );
291    let _ = writeln!(writer);
292    let _ = writeln!(writer, "  {}:", "Next steps".bold());
293    let _ = writeln!(
294        writer,
295        "    {} entity list          {}",
296        pfx,
297        sym_dim("# list all entities")
298    );
299    let _ = writeln!(
300        writer,
301        "    {} service list         {}",
302        pfx,
303        sym_dim("# list available services")
304    );
305    let _ = writeln!(
306        writer,
307        "    {} event list           {}",
308        pfx,
309        sym_dim("# list event types")
310    );
311    let _ = writeln!(
312        writer,
313        "    {} completions zsh      {}",
314        pfx,
315        sym_dim("# shell completions")
316    );
317    let _ = writeln!(writer);
318    let _ = writer.flush();
319
320    Ok(())
321}
322
323/// Entry point from main — uses real stdin/stdout and live API validation.
324pub async fn init(profile_arg: Option<String>) {
325    let config_path = config::config_path();
326
327    if !std::io::stdout().is_terminal() {
328        print_json_schema(&config_path);
329        return;
330    }
331
332    let stdin = std::io::stdin();
333    let stdout = std::io::stdout();
334    let mut reader = std::io::BufReader::new(stdin.lock());
335    let mut writer = std::io::BufWriter::new(stdout.lock());
336
337    if let Err(e) = run_init(
338        &mut reader,
339        &mut writer,
340        &config_path,
341        profile_arg.as_deref(),
342        |url, token| async move {
343            let client = crate::api::HaClient::new(&url, &token);
344            client.validate().await.ok()
345        },
346    )
347    .await
348    {
349        eprintln!("{} {e}", sym_fail());
350        std::process::exit(crate::output::exit_codes::GENERAL_ERROR);
351    }
352}
353
// Tests drive `run_init` with scripted stdin and a stub validator, using a
// config path inside a fresh temp dir. Stdin holds one answer per line, in the
// order the prompts appear.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tempfile::TempDir;

    /// Path of `config.toml` inside `dir`; this helper does not create the file.
    fn fake_path(dir: &TempDir) -> std::path::PathBuf {
        dir.path().join("config.toml")
    }

    // First-time setup: stdin answers URL then token, and the validator succeeds.
    // Checks the saved file, the token-creation hint and the outro text.
    #[tokio::test]
    async fn init_writes_config_on_valid_credentials() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.local:8123\nmytoken\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(
            &mut reader,
            &mut writer,
            &path,
            None,
            |_url, _token| async { Some("Home Assistant".to_string()) },
        )
        .await
        .unwrap();

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("http://ha.local:8123"));
        assert!(saved.contains("mytoken"));
        let output = String::from_utf8_lossy(&writer);
        assert!(output.contains("http://ha.local:8123/profile/security"));
        assert!(
            output.contains("Long-Lived Access Tokens"),
            "should show token creation instructions"
        );
        assert!(
            output.contains("Configuration saved to"),
            "should show config saved confirmation"
        );
        assert!(
            output.contains("Next steps"),
            "should show next steps block"
        );
    }

    // With no existing profiles and no profile argument, credentials are saved
    // under `[default]`.
    #[tokio::test]
    async fn init_uses_default_profile_on_first_setup() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.local:8123\nmytoken\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(&mut reader, &mut writer, &path, None, |_, _| async {
            Some("HA".into())
        })
        .await
        .unwrap();

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("[default]"));
    }

    // Validation fails, and the trailing "n" answers the "Save anyway?" prompt,
    // so nothing is written.
    #[tokio::test]
    async fn init_aborts_when_validation_fails_and_user_declines() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.local:8123\nbadtoken\nn\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(&mut reader, &mut writer, &path, None, |_, _| async { None })
            .await
            .unwrap();

        assert!(!path.exists(), "config must not be written after abort");
    }

    // Validation fails, but the trailing "y" forces the save.
    #[tokio::test]
    async fn init_saves_when_validation_fails_but_user_forces() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.local:8123\nbadtoken\ny\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(&mut reader, &mut writer, &path, None, |_, _| async { None })
            .await
            .unwrap();

        assert!(path.exists());
    }

    // An explicit profile argument names the section that gets written.
    #[tokio::test]
    async fn init_with_profile_arg_writes_named_profile() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.prod:8123\nprodtoken\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(
            &mut reader,
            &mut writer,
            &path,
            Some("prod"),
            |_, _| async { Some("HA".into()) },
        )
        .await
        .unwrap();

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("[prod]"));
    }

    // Update flow on a pre-existing single `[default]` profile: blank answers
    // keep the stored token.
    #[tokio::test]
    async fn init_update_keeps_values_on_enter() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        std::fs::write(
            &path,
            "[default]\nurl = \"http://ha.local:8123\"\ntoken = \"existing-token\"\n",
        )
        .unwrap();

        // \n accepts "update" default at action prompt, then Enter to keep both fields
        let input = b"\n\n\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(&mut reader, &mut writer, &path, None, |_, _| async {
            Some("HA".into())
        })
        .await
        .unwrap();

        let saved = std::fs::read_to_string(&path).unwrap();
        assert!(saved.contains("existing-token"));
    }

    // For a non-default profile, the "Next steps" commands include
    // `--profile <name>`.
    #[tokio::test]
    async fn init_outro_includes_profile_flag_for_non_default() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"http://ha.local:8123\ntoken\n";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(
            &mut reader,
            &mut writer,
            &path,
            Some("staging"),
            |_, _| async { Some("HA".into()) },
        )
        .await
        .unwrap();

        let output = String::from_utf8_lossy(&writer);
        assert!(output.contains("--profile staging"));
        assert!(output.contains("Next steps"));
        assert!(output.contains("entity list"));
        assert!(output.contains("service list"));
        assert!(output.contains("completions zsh"));
    }

    // Empty stdin hits EOF at the first required prompt: "Aborted" is printed
    // and no config file is created.
    #[tokio::test]
    async fn init_aborts_on_eof() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        let input = b"";
        let mut reader = Cursor::new(input.as_ref());
        let mut writer = Vec::<u8>::new();

        run_init(&mut reader, &mut writer, &path, None, |_, _| async {
            Some("HA".into())
        })
        .await
        .unwrap();

        assert!(!path.exists());
        let output = String::from_utf8_lossy(&writer);
        assert!(output.contains("Aborted"));
    }
}