// homeassistant_cli/api/mod.rs

1pub mod entities;
2pub mod events;
3pub mod services;
4pub mod types;
5
6pub use types::*;
7
8use std::fmt;
9
10#[derive(Debug)]
11pub enum HaError {
12    /// 401/403 from HA API.
13    Auth(String),
14    /// 404 — entity, service, or resource not found.
15    NotFound(String),
16    /// Missing or invalid config/input.
17    InvalidInput(String),
18    /// Could not reach Home Assistant.
19    Connection(String),
20    /// Non-2xx response.
21    Api { status: u16, message: String },
22    /// Network/TLS error from reqwest.
23    Http(reqwest::Error),
24    /// Any other error.
25    Other(String),
26}
27
28impl HaError {
29    /// Machine-readable error code for JSON error envelopes.
30    pub fn error_code(&self) -> &str {
31        match self {
32            HaError::Auth(_) => "HA_AUTH_ERROR",
33            HaError::NotFound(_) => "HA_NOT_FOUND",
34            HaError::InvalidInput(_) => "HA_INVALID_INPUT",
35            HaError::Connection(_) => "HA_CONNECTION_ERROR",
36            HaError::Api { .. } => "HA_API_ERROR",
37            HaError::Http(_) => "HA_HTTP_ERROR",
38            HaError::Other(_) => "HA_ERROR",
39        }
40    }
41}
42
43impl fmt::Display for HaError {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            HaError::Auth(msg) => write!(
47                f,
48                "Authentication failed: {msg}\nCheck your token or run `ha init`."
49            ),
50            HaError::NotFound(msg) => write!(f, "Not found: {msg}"),
51            HaError::InvalidInput(msg) => write!(f, "{msg}"),
52            HaError::Connection(url) => write!(
53                f,
54                "Could not connect to Home Assistant at {url}\nCheck that HA is running and the URL is correct."
55            ),
56            HaError::Api { status, message } => write!(f, "API error {status}: {message}"),
57            HaError::Http(e) => write!(f, "HTTP error: {e}"),
58            HaError::Other(msg) => write!(f, "{msg}"),
59        }
60    }
61}
62
63impl std::error::Error for HaError {
64    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
65        match self {
66            HaError::Http(e) => Some(e),
67            _ => None,
68        }
69    }
70}
71
72impl From<reqwest::Error> for HaError {
73    fn from(e: reqwest::Error) -> Self {
74        if e.is_connect() || e.is_timeout() {
75            HaError::Connection(
76                e.url()
77                    .map(|u| u.to_string())
78                    .unwrap_or_else(|| "unknown".into()),
79            )
80        } else {
81            HaError::Http(e)
82        }
83    }
84}
85
86/// HTTP client for the Home Assistant REST API.
87pub struct HaClient {
88    pub base_url: String,
89    token: String,
90    pub(crate) client: reqwest::Client,
91}
92
93impl HaClient {
94    pub fn new(base_url: impl Into<String>, token: impl Into<String>) -> Self {
95        Self {
96            base_url: base_url.into().trim_end_matches('/').to_owned(),
97            token: token.into(),
98            client: reqwest::Client::builder()
99                .timeout(std::time::Duration::from_secs(30))
100                .build()
101                .expect("build reqwest client"),
102        }
103    }
104
105    pub fn token(&self) -> &str {
106        &self.token
107    }
108
109    /// Returns a GET request builder pre-configured with Bearer auth.
110    pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
111        self.client
112            .get(format!("{}{}", self.base_url, path))
113            .bearer_auth(&self.token)
114    }
115
116    /// Returns a POST request builder pre-configured with Bearer auth.
117    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
118        self.client
119            .post(format!("{}{}", self.base_url, path))
120            .bearer_auth(&self.token)
121    }
122
123    /// Validate the connection by calling GET /api/
124    pub async fn validate(&self) -> Result<String, HaError> {
125        let resp = self.get("/api/").send().await?;
126        match resp.status().as_u16() {
127            200 => {
128                let body: serde_json::Value = resp.json().await?;
129                Ok(body["message"]
130                    .as_str()
131                    .unwrap_or("API running.")
132                    .to_owned())
133            }
134            401 | 403 => Err(HaError::Auth("Invalid token".into())),
135            status => Err(HaError::Api {
136                status,
137                message: resp.text().await.unwrap_or_default(),
138            }),
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Produces a genuine `reqwest::Error` without any network access: an
    /// unparsable URL makes `RequestBuilder::build` fail with a builder error.
    fn builder_error() -> reqwest::Error {
        reqwest::Client::new()
            .get("not a url")
            .build()
            .unwrap_err()
    }

    #[test]
    fn error_code_returns_expected_strings() {
        assert_eq!(HaError::Auth("x".into()).error_code(), "HA_AUTH_ERROR");
        assert_eq!(HaError::NotFound("x".into()).error_code(), "HA_NOT_FOUND");
        assert_eq!(
            HaError::InvalidInput("x".into()).error_code(),
            "HA_INVALID_INPUT"
        );
        assert_eq!(
            HaError::Connection("x".into()).error_code(),
            "HA_CONNECTION_ERROR"
        );
        assert_eq!(
            HaError::Api {
                status: 500,
                message: "x".into()
            }
            .error_code(),
            "HA_API_ERROR"
        );
        assert_eq!(HaError::Http(builder_error()).error_code(), "HA_HTTP_ERROR");
        assert_eq!(HaError::Other("x".into()).error_code(), "HA_ERROR");
    }

    #[test]
    fn auth_error_display_includes_guidance() {
        let err = HaError::Auth("401 Unauthorized".into());
        let msg = err.to_string();
        assert!(msg.contains("Authentication failed"));
        assert!(msg.contains("ha init") || msg.contains("HA_TOKEN"));
    }

    #[test]
    fn not_found_display_includes_entity() {
        let err = HaError::NotFound("light.missing".into());
        assert!(err.to_string().contains("light.missing"));
    }

    #[test]
    fn connection_error_mentions_url() {
        let err = HaError::Connection("http://ha.local:8123".into());
        assert!(err.to_string().contains("http://ha.local:8123"));
    }

    #[test]
    fn api_error_display_includes_status_and_message() {
        let err = HaError::Api {
            status: 500,
            message: "boom".into(),
        };
        assert_eq!(err.to_string(), "API error 500: boom");
    }

    #[test]
    fn http_error_source_is_reqwest() {
        let api_err = HaError::Http(builder_error());
        assert!(api_err.source().is_some());
    }

    #[test]
    fn non_connection_reqwest_error_maps_to_http() {
        // A builder error is neither a connect failure nor a timeout, so the
        // conversion must fall through to the `Http` variant.
        assert!(matches!(HaError::from(builder_error()), HaError::Http(_)));
    }

    #[test]
    fn ha_client_new_trims_trailing_slash() {
        let client = HaClient::new("http://ha.local:8123/", "token");
        assert_eq!(client.base_url, "http://ha.local:8123");
    }

    #[test]
    fn get_builds_url_from_base_and_path() {
        let client = HaClient::new("http://ha.local:8123/", "token");
        let request = client.get("/api/").build().unwrap();
        assert_eq!(request.url().as_str(), "http://ha.local:8123/api/");
    }
}