Skip to main content

homeassistant_cli/api/
mod.rs

1pub mod entities;
2pub mod events;
3pub mod services;
4pub mod types;
5pub mod websocket;
6
7pub use types::*;
8
9use std::fmt;
10
11#[derive(Debug)]
12pub enum HaError {
13    /// 401/403 from HA API.
14    Auth(String),
15    /// 404 — entity, service, or resource not found.
16    NotFound(String),
17    /// Missing or invalid config/input.
18    InvalidInput(String),
19    /// Could not reach Home Assistant.
20    Connection(String),
21    /// Non-2xx response.
22    Api { status: u16, message: String },
23    /// Network/TLS error from reqwest.
24    Http(reqwest::Error),
25    /// Any other error.
26    Other(String),
27}
28
29impl HaError {
30    /// Machine-readable error code for JSON error envelopes.
31    pub fn error_code(&self) -> &str {
32        match self {
33            HaError::Auth(_) => "HA_AUTH_ERROR",
34            HaError::NotFound(_) => "HA_NOT_FOUND",
35            HaError::InvalidInput(_) => "HA_INVALID_INPUT",
36            HaError::Connection(_) => "HA_CONNECTION_ERROR",
37            HaError::Api { .. } => "HA_API_ERROR",
38            HaError::Http(_) => "HA_HTTP_ERROR",
39            HaError::Other(_) => "HA_ERROR",
40        }
41    }
42}
43
44impl fmt::Display for HaError {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        match self {
47            HaError::Auth(msg) => write!(
48                f,
49                "Authentication failed: {msg}\nCheck your token or run `ha init`."
50            ),
51            HaError::NotFound(msg) => write!(f, "Not found: {msg}"),
52            HaError::InvalidInput(msg) => write!(f, "{msg}"),
53            HaError::Connection(url) => write!(
54                f,
55                "Could not connect to Home Assistant at {url}\nCheck that HA is running and the URL is correct."
56            ),
57            HaError::Api { status, message } => write!(f, "API error {status}: {message}"),
58            HaError::Http(e) => write!(f, "HTTP error: {e}"),
59            HaError::Other(msg) => write!(f, "{msg}"),
60        }
61    }
62}
63
64impl std::error::Error for HaError {
65    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66        match self {
67            HaError::Http(e) => Some(e),
68            _ => None,
69        }
70    }
71}
72
73impl From<reqwest::Error> for HaError {
74    fn from(e: reqwest::Error) -> Self {
75        if e.is_connect() || e.is_timeout() {
76            HaError::Connection(
77                e.url()
78                    .map(|u| u.to_string())
79                    .unwrap_or_else(|| "unknown".into()),
80            )
81        } else {
82            HaError::Http(e)
83        }
84    }
85}
86
87/// HTTP client for the Home Assistant REST API.
88pub struct HaClient {
89    pub base_url: String,
90    token: String,
91    pub(crate) client: reqwest::Client,
92}
93
94impl HaClient {
95    pub fn new(base_url: impl Into<String>, token: impl Into<String>) -> Self {
96        Self {
97            base_url: base_url.into().trim_end_matches('/').to_owned(),
98            token: token.into(),
99            client: reqwest::Client::builder()
100                .timeout(std::time::Duration::from_secs(30))
101                .build()
102                .expect("build reqwest client"),
103        }
104    }
105
106    pub fn token(&self) -> &str {
107        &self.token
108    }
109
110    /// Returns a GET request builder pre-configured with Bearer auth.
111    pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
112        self.client
113            .get(format!("{}{}", self.base_url, path))
114            .bearer_auth(&self.token)
115    }
116
117    /// Returns a POST request builder pre-configured with Bearer auth.
118    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
119        self.client
120            .post(format!("{}{}", self.base_url, path))
121            .bearer_auth(&self.token)
122    }
123
124    /// Validate the connection by calling GET /api/
125    pub async fn validate(&self) -> Result<String, HaError> {
126        let resp = self.get("/api/").send().await?;
127        match resp.status().as_u16() {
128            200 => {
129                let body: serde_json::Value = resp.json().await?;
130                Ok(body["message"]
131                    .as_str()
132                    .unwrap_or("API running.")
133                    .to_owned())
134            }
135            401 | 403 => Err(HaError::Auth("Invalid token".into())),
136            status => Err(HaError::Api {
137                status,
138                message: resp.text().await.unwrap_or_default(),
139            }),
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Produces a real `reqwest::Error` by sending a request to port 1 on
    /// localhost, where nothing is expected to be listening.
    fn refused_request_error() -> reqwest::Error {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            reqwest::Client::new()
                .get("http://127.0.0.1:1")
                .send()
                .await
                .unwrap_err()
        })
    }

    #[test]
    fn error_code_returns_expected_strings() {
        assert_eq!(HaError::Auth("x".into()).error_code(), "HA_AUTH_ERROR");
        assert_eq!(HaError::NotFound("x".into()).error_code(), "HA_NOT_FOUND");
        assert_eq!(
            HaError::InvalidInput("x".into()).error_code(),
            "HA_INVALID_INPUT"
        );
        assert_eq!(
            HaError::Connection("x".into()).error_code(),
            "HA_CONNECTION_ERROR"
        );
        assert_eq!(
            HaError::Api {
                status: 500,
                message: "x".into()
            }
            .error_code(),
            "HA_API_ERROR"
        );
        assert_eq!(HaError::Other("x".into()).error_code(), "HA_ERROR");
    }

    #[test]
    fn auth_error_display_includes_guidance() {
        let err = HaError::Auth("401 Unauthorized".into());
        let msg = err.to_string();
        assert!(msg.contains("Authentication failed"));
        assert!(msg.contains("ha init") || msg.contains("HA_TOKEN"));
    }

    #[test]
    fn not_found_display_includes_entity() {
        let err = HaError::NotFound("light.missing".into());
        assert!(err.to_string().contains("light.missing"));
    }

    #[test]
    fn connection_error_mentions_url() {
        let err = HaError::Connection("http://ha.local:8123".into());
        assert!(err.to_string().contains("http://ha.local:8123"));
    }

    #[test]
    fn display_formats_api_and_passthrough_variants() {
        assert_eq!(
            HaError::Api {
                status: 500,
                message: "boom".into()
            }
            .to_string(),
            "API error 500: boom"
        );
        assert_eq!(
            HaError::InvalidInput("bad entity id".into()).to_string(),
            "bad entity id"
        );
        assert_eq!(
            HaError::Other("something else".into()).to_string(),
            "something else"
        );
    }

    #[test]
    fn http_error_source_is_reqwest() {
        let api_err = HaError::Http(refused_request_error());
        assert!(api_err.source().is_some());
        assert_eq!(api_err.error_code(), "HA_HTTP_ERROR");
    }

    #[test]
    fn refused_connection_converts_to_connection_error() {
        let err = HaError::from(refused_request_error());
        assert!(matches!(err, HaError::Connection(_)), "got {err:?}");
        assert_eq!(err.error_code(), "HA_CONNECTION_ERROR");
        // The conversion keeps only the URL, so no underlying cause remains.
        assert!(err.source().is_none());
    }

    #[test]
    fn ha_client_new_trims_trailing_slash() {
        let client = HaClient::new("http://ha.local:8123/", "token");
        assert_eq!(client.base_url, "http://ha.local:8123");
    }
}