use std::io::{IsTerminal, Write};
use crate::api::HaError;
use crate::api::websocket::HaWs;
use crate::output::{self, OutputConfig, exit_codes};
pub async fn entity_list(
out: &OutputConfig,
base_url: &str,
token: &str,
integration: Option<&str>,
domain: Option<&str>,
) -> Result<(), HaError> {
let mut ws = HaWs::connect(base_url, token).await?;
let raw = ws
.call("config/entity_registry/list", serde_json::json!({}))
.await?;
ws.close().await;
let mut entries: Vec<serde_json::Value> = match raw {
serde_json::Value::Array(a) => a,
_ => Vec::new(),
};
if let Some(platform) = integration {
entries.retain(|e| e.get("platform").and_then(|v| v.as_str()) == Some(platform));
}
if let Some(d) = domain {
let prefix = format!("{d}.");
entries.retain(|e| {
e.get("entity_id")
.and_then(|v| v.as_str())
.is_some_and(|id| id.starts_with(&prefix))
});
}
entries.sort_by(|a, b| {
let ka = a.get("entity_id").and_then(|v| v.as_str()).unwrap_or("");
let kb = b.get("entity_id").and_then(|v| v.as_str()).unwrap_or("");
ka.cmp(kb)
});
if out.is_json() {
out.print_data(
&serde_json::to_string_pretty(&serde_json::json!({
"ok": true,
"data": entries,
}))
.expect("serialize"),
);
} else {
let rows: Vec<Vec<String>> = entries
.iter()
.map(|e| {
let entity_id = e
.get("entity_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_owned();
let name = e
.get("name")
.and_then(|v| v.as_str())
.or_else(|| e.get("original_name").and_then(|v| v.as_str()))
.unwrap_or("")
.to_owned();
let platform = e
.get("platform")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_owned();
let disabled_by = e
.get("disabled_by")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_owned();
vec![
output::colored_entity_id(&entity_id),
name,
platform,
disabled_by,
]
})
.collect();
out.print_data(&output::table(
&["ENTITY", "NAME", "INTEGRATION", "DISABLED_BY"],
&rows,
));
}
Ok(())
}
pub async fn entity_remove(
out: &OutputConfig,
base_url: &str,
token: &str,
entity_ids: &[String],
dry_run: bool,
yes: bool,
) -> Result<(), HaError> {
if entity_ids.is_empty() {
return Err(HaError::InvalidInput(
"at least one entity_id is required".into(),
));
}
if dry_run {
let data: Vec<serde_json::Value> = entity_ids
.iter()
.map(|id| serde_json::json!({"entity_id": id, "status": "dry_run"}))
.collect();
if out.is_json() {
out.print_data(
&serde_json::to_string_pretty(&serde_json::json!({
"ok": true,
"data": data,
}))
.expect("serialize"),
);
} else {
out.print_message(&format!(
"[dry-run] would remove {} entit{}:",
entity_ids.len(),
if entity_ids.len() == 1 { "y" } else { "ies" }
));
for id in entity_ids {
out.print_data(&format!(" {id}"));
}
}
return Ok(());
}
let auto_confirm = yes || out.is_json() || !std::io::stdin().is_terminal();
if !auto_confirm {
eprintln!(
"About to remove {} entit{} from the Home Assistant registry:",
entity_ids.len(),
if entity_ids.len() == 1 { "y" } else { "ies" }
);
for id in entity_ids {
eprintln!(" {id}");
}
eprint!("Proceed? [y/N] ");
let _ = std::io::stderr().flush();
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.map_err(|e| HaError::Other(format!("failed to read stdin: {e}")))?;
let answer = input.trim().to_ascii_lowercase();
if answer != "y" && answer != "yes" {
return Err(HaError::InvalidInput("aborted by user".into()));
}
}
let mut ws = HaWs::connect(base_url, token).await?;
let mut results = Vec::with_capacity(entity_ids.len());
let mut failed = 0usize;
for id in entity_ids {
let outcome = ws
.call(
"config/entity_registry/remove",
serde_json::json!({"entity_id": id}),
)
.await;
match outcome {
Ok(_) => results.push(serde_json::json!({
"entity_id": id,
"status": "removed",
})),
Err(HaError::NotFound(msg)) => {
failed += 1;
results.push(serde_json::json!({
"entity_id": id,
"status": "not_found",
"error": msg,
}));
}
Err(e) => {
failed += 1;
results.push(serde_json::json!({
"entity_id": id,
"status": "error",
"error": e.to_string(),
}));
}
}
}
ws.close().await;
let any_failed = failed > 0;
if out.is_json() {
out.print_data(
&serde_json::to_string_pretty(&serde_json::json!({
"ok": !any_failed,
"data": results,
}))
.expect("serialize"),
);
} else {
for r in &results {
let id = r.get("entity_id").and_then(|v| v.as_str()).unwrap_or("");
let status = r.get("status").and_then(|v| v.as_str()).unwrap_or("");
let err = r.get("error").and_then(|v| v.as_str()).unwrap_or("");
if err.is_empty() {
out.print_data(&format!("{id}: {status}"));
} else {
out.print_data(&format!("{id}: {status} ({err})"));
}
}
out.print_message(&format!(
"{} removed, {} failed",
entity_ids.len() - failed,
failed
));
}
if any_failed {
std::process::exit(exit_codes::PARTIAL_FAILURE);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::output::OutputFormat;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
fn json_out() -> OutputConfig {
OutputConfig::new(Some(OutputFormat::Json), false)
}
async fn spawn_mock<F, Fut>(handler: F) -> (String, tokio::task::JoinHandle<()>)
where
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let base_url = format!("http://127.0.0.1:{port}");
let handle = tokio::spawn(async move {
if let Ok((stream, _)) = listener.accept().await
&& let Ok(ws) = tokio_tungstenite::accept_async(stream).await
{
handler(ws).await;
}
});
(base_url, handle)
}
async fn do_auth(ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) {
ws.send(Message::Text(
serde_json::json!({"type": "auth_required"}).to_string(),
))
.await
.unwrap();
let _ = ws.next().await.unwrap().unwrap();
ws.send(Message::Text(
serde_json::json!({"type": "auth_ok"}).to_string(),
))
.await
.unwrap();
}
async fn recv_cmd(
ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
) -> serde_json::Value {
let msg = ws.next().await.unwrap().unwrap();
match msg {
Message::Text(t) => serde_json::from_str(&t).unwrap(),
other => panic!("expected text frame, got {other:?}"),
}
}
async fn send_result(
ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
id: u64,
result: serde_json::Value,
) {
ws.send(Message::Text(
serde_json::json!({
"id": id,
"type": "result",
"success": true,
"result": result,
})
.to_string(),
))
.await
.unwrap();
}
async fn send_error(
ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
id: u64,
code: &str,
message: &str,
) {
ws.send(Message::Text(
serde_json::json!({
"id": id,
"type": "result",
"success": false,
"error": {"code": code, "message": message},
})
.to_string(),
))
.await
.unwrap();
}
#[tokio::test]
async fn entity_list_calls_registry_endpoint() {
let (base, handle) = spawn_mock(|mut ws| async move {
do_auth(&mut ws).await;
let cmd = recv_cmd(&mut ws).await;
assert_eq!(cmd["type"], "config/entity_registry/list");
let id = cmd["id"].as_u64().unwrap();
send_result(
&mut ws,
id,
serde_json::json!([
{"entity_id": "light.a", "platform": "hue", "name": "A"},
{"entity_id": "switch.b", "platform": "zha"},
{"entity_id": "light.c", "platform": "hue"},
]),
)
.await;
})
.await;
entity_list(&json_out(), &base, "tok", None, None)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn entity_list_filters_by_domain_and_integration() {
let (base, handle) = spawn_mock(|mut ws| async move {
do_auth(&mut ws).await;
let cmd = recv_cmd(&mut ws).await;
let id = cmd["id"].as_u64().unwrap();
send_result(
&mut ws,
id,
serde_json::json!([
{"entity_id": "light.a", "platform": "hue"},
{"entity_id": "switch.b", "platform": "hue"},
{"entity_id": "light.c", "platform": "zha"},
]),
)
.await;
})
.await;
entity_list(&json_out(), &base, "tok", Some("hue"), Some("light"))
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn entity_remove_dry_run_makes_no_network_calls() {
let unused_url = "http://127.0.0.1:1";
let ids = vec!["light.a".to_string(), "light.b".to_string()];
entity_remove(&json_out(), unused_url, "tok", &ids, true, true)
.await
.unwrap();
}
#[tokio::test]
async fn entity_remove_empty_list_errors() {
let err = entity_remove(&json_out(), "http://example.com", "tok", &[], false, true)
.await
.unwrap_err();
assert!(matches!(err, HaError::InvalidInput(_)));
}
#[tokio::test]
async fn entity_remove_sends_one_call_per_id() {
let (base, handle) = spawn_mock(|mut ws| async move {
do_auth(&mut ws).await;
for expected in ["light.a", "light.b"] {
let cmd = recv_cmd(&mut ws).await;
assert_eq!(cmd["type"], "config/entity_registry/remove");
assert_eq!(cmd["entity_id"], expected);
let id = cmd["id"].as_u64().unwrap();
send_result(&mut ws, id, serde_json::Value::Null).await;
}
})
.await;
let ids = vec!["light.a".to_string(), "light.b".to_string()];
entity_remove(&json_out(), &base, "tok", &ids, false, true)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn entity_remove_reports_not_found_per_entity() {
let (base, handle) = spawn_mock(|mut ws| async move {
do_auth(&mut ws).await;
let cmd = recv_cmd(&mut ws).await;
let id = cmd["id"].as_u64().unwrap();
send_error(&mut ws, id, "not_found", "Entity not found").await;
let _ = ws.next().await;
})
.await;
drop(base);
handle.abort();
}
}