Skip to main content

homeassistant_cli/commands/
init.rs

1use std::future::Future;
2use std::io::{BufRead, IsTerminal, Write};
3use std::path::Path;
4
5use owo_colors::OwoColorize;
6
7use crate::api::HaError;
8use crate::config;
9use crate::output;
10
11const SEP: &str = "──────────────────────────────────────";
12
13fn sym_q() -> String {
14    "?".green().bold().to_string()
15}
16
17fn sym_ok() -> String {
18    "✔".green().to_string()
19}
20
21fn sym_fail() -> String {
22    "✖".red().to_string()
23}
24
25fn sym_dim(s: &str) -> String {
26    s.dimmed().to_string()
27}
28
29fn prompt_optional<R: BufRead, W: Write>(
30    r: &mut R,
31    w: &mut W,
32    label: &str,
33    default: &str,
34) -> String {
35    let _ = write!(w, "{} {}  [{}]: ", sym_q(), label, sym_dim(default));
36    let _ = w.flush();
37    let mut input = String::new();
38    r.read_line(&mut input).unwrap_or(0);
39    let trimmed = input.trim().to_owned();
40    if trimmed.is_empty() {
41        default.to_owned()
42    } else {
43        trimmed
44    }
45}
46
47fn prompt_required<R: BufRead, W: Write>(
48    r: &mut R,
49    w: &mut W,
50    label: &str,
51    hint: &str,
52) -> Option<String> {
53    loop {
54        let _ = write!(
55            w,
56            "{} {}  {}: ",
57            sym_q(),
58            label,
59            sym_dim(&format!("[{hint}]"))
60        );
61        let _ = w.flush();
62        let mut input = String::new();
63        match r.read_line(&mut input) {
64            Ok(0) | Err(_) => return None,
65            Ok(_) => {}
66        }
67        let trimmed = input.trim().to_owned();
68        if !trimmed.is_empty() {
69            return Some(trimmed);
70        }
71        let _ = writeln!(w, "  {} {} is required.", sym_fail(), label);
72    }
73}
74
75fn prompt_credential_update<R: BufRead, W: Write>(
76    r: &mut R,
77    w: &mut W,
78    label: &str,
79    current: &str,
80) -> Option<String> {
81    let hint = format!("{} (Enter to keep)", output::mask_credential(current));
82    let _ = write!(w, "{} {}  {}: ", sym_q(), label, sym_dim(&hint));
83    let _ = w.flush();
84    let mut input = String::new();
85    match r.read_line(&mut input) {
86        Ok(0) | Err(_) => return None,
87        Ok(_) => {}
88    }
89    let trimmed = input.trim().to_owned();
90    Some(if trimmed.is_empty() {
91        current.to_owned()
92    } else {
93        trimmed
94    })
95}
96
97fn prompt_confirm<R: BufRead, W: Write>(
98    r: &mut R,
99    w: &mut W,
100    label: &str,
101    default_yes: bool,
102) -> bool {
103    let hint = if default_yes { "Y/n" } else { "y/N" };
104    let _ = write!(w, "{} {}  [{}]: ", sym_q(), label, sym_dim(hint));
105    let _ = w.flush();
106    let mut input = String::new();
107    r.read_line(&mut input).unwrap_or(0);
108    match input.trim().to_lowercase().as_str() {
109        "y" | "yes" => true,
110        "n" | "no" => false,
111        _ => default_yes,
112    }
113}
114
115fn print_json_schema(config_path: &Path) {
116    let path_str = config_path.to_string_lossy();
117    let schema = serde_json::json!({
118        "configPath": path_str,
119        "pathResolution": config::schema_config_path_description(),
120        "recommendedPermissions": config::recommended_permissions(config_path),
121        "tokenInstructions": {
122            "steps": [
123                "Open Home Assistant in your browser",
124                "Go to Settings → Profile (bottom left)",
125                "Scroll to 'Long-Lived Access Tokens'",
126                "Click 'Create Token', give it a name, copy it"
127            ]
128        },
129        "requiredFields": ["url", "token"],
130        "example": {
131            "configFile": path_str,
132            "format": "[default]\nurl = \"http://homeassistant.local:8123\"\ntoken = \"YOUR_LONG_LIVED_TOKEN\""
133        }
134    });
135    println!(
136        "{}",
137        serde_json::to_string_pretty(&schema).expect("serialize")
138    );
139}
140
141/// Interactive init flow with injectable IO and async validator for testing.
142///
143/// `validate` receives (url, token) and returns `Some(display_name)` on success
144/// or `None` on auth failure.
145pub async fn run_init<R, W, Fut>(
146    reader: &mut R,
147    writer: &mut W,
148    config_path: &Path,
149    profile_arg: Option<&str>,
150    validate: impl Fn(String, String) -> Fut,
151) -> Result<(), HaError>
152where
153    R: BufRead,
154    W: Write,
155    Fut: Future<Output = Option<String>>,
156{
157    let _ = writeln!(writer, "\nHome Assistant CLI");
158    let _ = writeln!(writer, "{SEP}\n");
159
160    let existing_profiles = config::read_profile_names(config_path);
161    let is_first_setup = existing_profiles.is_empty();
162
163    let (profile_name, is_update) = if let Some(p) = profile_arg {
164        let is_update = existing_profiles.contains(&p.to_owned());
165        (p.to_owned(), is_update)
166    } else if is_first_setup {
167        ("default".to_owned(), false)
168    } else {
169        if existing_profiles.len() == 1 {
170            let p = &existing_profiles[0];
171            let cred = config::read_profile_credentials(config_path, p)
172                .map(|(url, _)| format!("  {}", output::mask_credential(&url)))
173                .unwrap_or_default();
174            let _ = writeln!(writer, "  Profile: {}{}\n", p.bold(), sym_dim(&cred));
175        } else {
176            let _ = writeln!(writer, "  Profiles:");
177            for p in &existing_profiles {
178                let cred = config::read_profile_credentials(config_path, p)
179                    .map(|(url, _)| format!("  {}", output::mask_credential(&url)))
180                    .unwrap_or_default();
181                let _ = writeln!(writer, "    {}{}", p, sym_dim(&cred));
182            }
183            let _ = writeln!(writer);
184        }
185
186        let action = prompt_optional(reader, writer, "Action  [update/add]", "update");
187        let _ = writeln!(writer);
188
189        if action.trim().eq_ignore_ascii_case("add") {
190            let Some(name) = prompt_required(reader, writer, "Profile name", "e.g. prod") else {
191                let _ = writeln!(writer, "\nAborted.");
192                return Ok(());
193            };
194            (name, false)
195        } else if existing_profiles.len() == 1 {
196            (existing_profiles[0].clone(), true)
197        } else {
198            let options = existing_profiles.join("/");
199            let chosen = prompt_optional(
200                reader,
201                writer,
202                &format!("Profile  [{}]", options),
203                &existing_profiles[0],
204            );
205            let profile = chosen.trim().to_owned();
206            if !existing_profiles.contains(&profile) {
207                let _ = writeln!(writer, "\n  {} Unknown profile '{}'.", sym_fail(), profile);
208                return Ok(());
209            }
210            (profile, true)
211        }
212    };
213
214    let (url, token) = if is_update {
215        let (cur_url, cur_token) = config::read_profile_credentials(config_path, &profile_name)
216            .expect("update mode requires existing credentials");
217        let Some(url) = prompt_credential_update(reader, writer, "URL", &cur_url) else {
218            let _ = writeln!(writer, "\nAborted.");
219            return Ok(());
220        };
221        let Some(token) = prompt_credential_update(reader, writer, "Token", &cur_token) else {
222            let _ = writeln!(writer, "\nAborted.");
223            return Ok(());
224        };
225        (url, token)
226    } else {
227        let Some(url) = prompt_required(
228            reader,
229            writer,
230            "Home Assistant URL",
231            "http://homeassistant.local:8123",
232        ) else {
233            let _ = writeln!(writer, "\nAborted.");
234            return Ok(());
235        };
236        let token_url = format!("{}/profile/security", url.trim_end_matches('/'));
237        let _ = writeln!(
238            writer,
239            "  {} Create a token at: {}",
240            sym_ok(),
241            sym_dim(&token_url)
242        );
243        let Some(token) = prompt_required(
244            reader,
245            writer,
246            "Long-Lived Access Token",
247            "paste token here",
248        ) else {
249            let _ = writeln!(writer, "\nAborted.");
250            return Ok(());
251        };
252        (url, token)
253    };
254
255    let _ = write!(writer, "\n  Verifying credentials...");
256    let _ = writer.flush();
257    let validation = validate(url.clone(), token.clone()).await;
258
259    let save = match validation {
260        Some(name) => {
261            let _ = writeln!(writer, " {} Connected to {}", sym_ok(), name.bold());
262            true
263        }
264        None => {
265            let _ = writeln!(writer, " {} Could not connect.", sym_fail());
266            prompt_confirm(reader, writer, "Save anyway?", false)
267        }
268    };
269
270    if !save {
271        let _ = writeln!(writer, "\nAborted. Config not saved.");
272        let _ = writer.flush();
273        return Ok(());
274    }
275
276    config::write_profile(config_path, &profile_name, &url, &token)?;
277
278    let run_cmd = if profile_name == "default" {
279        "ha entity list".to_owned()
280    } else {
281        format!("ha --profile {} entity list", profile_name)
282    };
283
284    let _ = writeln!(writer, "\n{SEP}");
285    let _ = writeln!(
286        writer,
287        "  {} Config saved to {}",
288        sym_ok(),
289        sym_dim(&config_path.display().to_string())
290    );
291    let _ = writeln!(writer, "  Run: {}", run_cmd.bold());
292    let _ = writer.flush();
293
294    Ok(())
295}
296
297/// Entry point from main — uses real stdin/stdout and live API validation.
298pub async fn init(profile_arg: Option<String>) {
299    let config_path = config::config_path();
300
301    if !std::io::stdout().is_terminal() {
302        print_json_schema(&config_path);
303        return;
304    }
305
306    let stdin = std::io::stdin();
307    let stdout = std::io::stdout();
308    let mut reader = std::io::BufReader::new(stdin.lock());
309    let mut writer = std::io::BufWriter::new(stdout.lock());
310
311    if let Err(e) = run_init(
312        &mut reader,
313        &mut writer,
314        &config_path,
315        profile_arg.as_deref(),
316        |url, token| async move {
317            let client = crate::api::HaClient::new(&url, &token);
318            client.validate().await.ok()
319        },
320    )
321    .await
322    {
323        eprintln!("{} {e}", sym_fail());
324        std::process::exit(crate::output::exit_codes::GENERAL_ERROR);
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tempfile::TempDir;

    /// Path of the scratch config file inside `dir` (not created up front).
    fn fake_path(dir: &TempDir) -> std::path::PathBuf {
        dir.path().join("config.toml")
    }

    /// Runs `run_init` with `script` as stdin and returns everything it printed.
    ///
    /// The fake validator reports a successful connection when `connects` is
    /// true and an auth failure otherwise. `run_init` itself must return `Ok`.
    async fn drive(
        path: &std::path::Path,
        script: &[u8],
        profile: Option<&str>,
        connects: bool,
    ) -> String {
        let mut stdin = Cursor::new(script);
        let mut stdout = Vec::<u8>::new();
        run_init(&mut stdin, &mut stdout, path, profile, move |_url, _token| async move {
            connects.then(|| "Home Assistant".to_string())
        })
        .await
        .unwrap();
        String::from_utf8_lossy(&stdout).into_owned()
    }

    /// Contents of the config file written by `run_init`.
    fn saved(path: &std::path::Path) -> String {
        std::fs::read_to_string(path).unwrap()
    }

    #[tokio::test]
    async fn init_writes_config_on_valid_credentials() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let out = drive(&path, b"http://ha.local:8123\nmytoken\n", None, true).await;

        let config = saved(&path);
        assert!(config.contains("http://ha.local:8123"));
        assert!(config.contains("mytoken"));
        assert!(out.contains("http://ha.local:8123/profile/security"));
    }

    #[tokio::test]
    async fn init_uses_default_profile_on_first_setup() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.local:8123\nmytoken\n", None, true).await;

        assert!(saved(&path).contains("[default]"));
    }

    #[tokio::test]
    async fn init_aborts_when_validation_fails_and_user_declines() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        // Trailing "n" answers the "Save anyway?" prompt.
        drive(&path, b"http://ha.local:8123\nbadtoken\nn\n", None, false).await;

        assert!(!path.exists(), "config must not be written after abort");
    }

    #[tokio::test]
    async fn init_saves_when_validation_fails_but_user_forces() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        // Trailing "y" answers the "Save anyway?" prompt.
        drive(&path, b"http://ha.local:8123\nbadtoken\ny\n", None, false).await;

        assert!(path.exists());
    }

    #[tokio::test]
    async fn init_with_profile_arg_writes_named_profile() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        drive(&path, b"http://ha.prod:8123\nprodtoken\n", Some("prod"), true).await;

        assert!(saved(&path).contains("[prod]"));
    }

    #[tokio::test]
    async fn init_update_keeps_values_on_enter() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);
        std::fs::write(
            &path,
            "[default]\nurl = \"http://ha.local:8123\"\ntoken = \"existing-token\"\n",
        )
        .unwrap();

        // First \n accepts the "update" action; the next two keep URL and token.
        drive(&path, b"\n\n\n", None, true).await;

        assert!(saved(&path).contains("existing-token"));
    }

    #[tokio::test]
    async fn init_outro_includes_profile_flag_for_non_default() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let out = drive(&path, b"http://ha.local:8123\ntoken\n", Some("staging"), true).await;

        assert!(out.contains("--profile staging"));
    }

    #[tokio::test]
    async fn init_aborts_on_eof() {
        let dir = TempDir::new().unwrap();
        let path = fake_path(&dir);

        let out = drive(&path, b"", None, true).await;

        assert!(!path.exists());
        assert!(out.contains("Aborted"));
    }
}