homeassistant_cli/api/
mod.rs1pub mod entities;
2pub mod events;
3pub mod services;
4pub mod types;
5pub mod websocket;
6
7pub use types::*;
8
9use std::fmt;
10
11#[derive(Debug)]
12pub enum HaError {
13 Auth(String),
15 NotFound(String),
17 InvalidInput(String),
19 Connection(String),
21 Api { status: u16, message: String },
23 Http(reqwest::Error),
25 Other(String),
27}
28
29impl HaError {
30 pub fn error_code(&self) -> &str {
32 match self {
33 HaError::Auth(_) => "HA_AUTH_ERROR",
34 HaError::NotFound(_) => "HA_NOT_FOUND",
35 HaError::InvalidInput(_) => "HA_INVALID_INPUT",
36 HaError::Connection(_) => "HA_CONNECTION_ERROR",
37 HaError::Api { .. } => "HA_API_ERROR",
38 HaError::Http(_) => "HA_HTTP_ERROR",
39 HaError::Other(_) => "HA_ERROR",
40 }
41 }
42}
43
44impl fmt::Display for HaError {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 HaError::Auth(msg) => write!(
48 f,
49 "Authentication failed: {msg}\nCheck your token or run `ha init`."
50 ),
51 HaError::NotFound(msg) => write!(f, "Not found: {msg}"),
52 HaError::InvalidInput(msg) => write!(f, "{msg}"),
53 HaError::Connection(url) => write!(
54 f,
55 "Could not connect to Home Assistant at {url}\nCheck that HA is running and the URL is correct."
56 ),
57 HaError::Api { status, message } => write!(f, "API error {status}: {message}"),
58 HaError::Http(e) => write!(f, "HTTP error: {e}"),
59 HaError::Other(msg) => write!(f, "{msg}"),
60 }
61 }
62}
63
64impl std::error::Error for HaError {
65 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66 match self {
67 HaError::Http(e) => Some(e),
68 _ => None,
69 }
70 }
71}
72
73impl From<reqwest::Error> for HaError {
74 fn from(e: reqwest::Error) -> Self {
75 if e.is_connect() || e.is_timeout() {
76 HaError::Connection(
77 e.url()
78 .map(|u| u.to_string())
79 .unwrap_or_else(|| "unknown".into()),
80 )
81 } else {
82 HaError::Http(e)
83 }
84 }
85}
86
87pub struct HaClient {
89 pub base_url: String,
90 token: String,
91 pub(crate) client: reqwest::Client,
92}
93
94impl HaClient {
95 pub fn new(base_url: impl Into<String>, token: impl Into<String>) -> Self {
96 Self {
97 base_url: base_url.into().trim_end_matches('/').to_owned(),
98 token: token.into(),
99 client: reqwest::Client::builder()
100 .timeout(std::time::Duration::from_secs(30))
101 .build()
102 .expect("build reqwest client"),
103 }
104 }
105
106 pub fn token(&self) -> &str {
107 &self.token
108 }
109
110 pub fn get(&self, path: &str) -> reqwest::RequestBuilder {
112 self.client
113 .get(format!("{}{}", self.base_url, path))
114 .bearer_auth(&self.token)
115 }
116
117 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
119 self.client
120 .post(format!("{}{}", self.base_url, path))
121 .bearer_auth(&self.token)
122 }
123
124 pub async fn validate(&self) -> Result<String, HaError> {
126 let resp = self.get("/api/").send().await?;
127 match resp.status().as_u16() {
128 200 => {
129 let body: serde_json::Value = resp.json().await?;
130 Ok(body["message"]
131 .as_str()
132 .unwrap_or("API running.")
133 .to_owned())
134 }
135 401 | 403 => Err(HaError::Auth("Invalid token".into())),
136 status => Err(HaError::Api {
137 status,
138 message: resp.text().await.unwrap_or_default(),
139 }),
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Each constructible variant maps to its fixed identifier.
    #[test]
    fn error_code_returns_expected_strings() {
        let cases = [
            (HaError::Auth("x".into()), "HA_AUTH_ERROR"),
            (HaError::NotFound("x".into()), "HA_NOT_FOUND"),
            (HaError::InvalidInput("x".into()), "HA_INVALID_INPUT"),
            (HaError::Connection("x".into()), "HA_CONNECTION_ERROR"),
            (
                HaError::Api {
                    status: 500,
                    message: "x".into(),
                },
                "HA_API_ERROR",
            ),
            (HaError::Other("x".into()), "HA_ERROR"),
        ];

        for (err, expected) in cases {
            assert_eq!(err.error_code(), expected);
        }
    }

    /// The auth message names the failure and points at a way to fix it.
    #[test]
    fn auth_error_display_includes_guidance() {
        let msg = HaError::Auth("401 Unauthorized".into()).to_string();
        assert!(msg.contains("Authentication failed"));
        assert!(["ha init", "HA_TOKEN"]
            .iter()
            .any(|&hint| msg.contains(hint)));
    }

    /// The missing item's name is echoed back to the user.
    #[test]
    fn not_found_display_includes_entity() {
        let rendered = HaError::NotFound("light.missing".into()).to_string();
        assert!(rendered.contains("light.missing"));
    }

    /// The URL that could not be reached appears in the message.
    #[test]
    fn connection_error_mentions_url() {
        let rendered = HaError::Connection("http://ha.local:8123".into()).to_string();
        assert!(rendered.contains("http://ha.local:8123"));
    }

    /// A wrapped `reqwest` error is reachable through `Error::source`.
    #[test]
    fn http_error_source_is_reqwest() {
        // Obtain a real `reqwest::Error` by sending a request that is
        // expected to fail.
        let request = async {
            reqwest::Client::new()
                .get("http://127.0.0.1:1")
                .send()
                .await
                .unwrap_err()
        };
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let wrapped = HaError::Http(runtime.block_on(request));
        assert!(wrapped.source().is_some());
    }

    /// The constructor normalises away a trailing slash on the base URL.
    #[test]
    fn ha_client_new_trims_trailing_slash() {
        let HaClient { base_url, .. } = HaClient::new("http://ha.local:8123/", "token");
        assert_eq!(base_url, "http://ha.local:8123");
    }
}