1use futures_util::{SinkExt, StreamExt};
12use tokio_tungstenite::tungstenite::Message;
13
14use crate::api::HaError;
15
16pub(crate) fn derive_ws_url(base_url: &str) -> Result<String, HaError> {
20 let base = base_url.trim_end_matches('/');
21 let (scheme, rest) = if let Some(rest) = base.strip_prefix("https://") {
22 ("wss://", rest)
23 } else if let Some(rest) = base.strip_prefix("http://") {
24 ("ws://", rest)
25 } else {
26 return Err(HaError::InvalidInput(format!(
27 "URL must start with http:// or https://: {base_url}"
28 )));
29 };
30 Ok(format!("{scheme}{rest}/api/websocket"))
31}
32
/// Concrete stream type produced by `tokio_tungstenite::connect_async` in [`HaWs::connect`].
///
/// `MaybeTlsStream` lets one type carry both plain (`ws://`) and TLS (`wss://`) connections.
type WsStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
35
/// An authenticated connection to the Home Assistant WebSocket API.
///
/// Created with [`HaWs::connect`], which completes the `auth_required` → `auth` → `auth_ok`
/// exchange before returning. Commands are sent with [`HaWs::call`]. That method takes
/// `&mut self`, so only one command is in flight at a time. Shut down with [`HaWs::close`].
pub struct HaWs {
    /// The underlying WebSocket connection.
    stream: WsStream,
    /// Id for the next command sent by [`HaWs::call`]; starts at 1. `call` matches the
    /// server's `result` message to its request by this id.
    next_id: u64,
}
45
46impl HaWs {
47 pub async fn connect(base_url: &str, token: &str) -> Result<Self, HaError> {
49 let ws_url = derive_ws_url(base_url)?;
50 let (stream, _response) = tokio_tungstenite::connect_async(&ws_url)
51 .await
52 .map_err(|e| HaError::Connection(format!("{ws_url}: {e}")))?;
53 let mut client = Self { stream, next_id: 1 };
54 client.authenticate(token).await?;
55 Ok(client)
56 }
57
58 async fn authenticate(&mut self, token: &str) -> Result<(), HaError> {
59 let msg = self.recv_json().await?;
60 match msg.get("type").and_then(|v| v.as_str()) {
61 Some("auth_required") => {}
62 Some(other) => {
63 return Err(HaError::Other(format!(
64 "expected auth_required, got {other}"
65 )));
66 }
67 None => return Err(HaError::Other("missing type on first message".into())),
68 }
69
70 self.send_json(&serde_json::json!({
71 "type": "auth",
72 "access_token": token,
73 }))
74 .await?;
75
76 let msg = self.recv_json().await?;
77 match msg.get("type").and_then(|v| v.as_str()) {
78 Some("auth_ok") => Ok(()),
79 Some("auth_invalid") => {
80 let m = msg
81 .get("message")
82 .and_then(|v| v.as_str())
83 .unwrap_or("invalid token");
84 Err(HaError::Auth(m.to_owned()))
85 }
86 _ => Err(HaError::Other(format!("unexpected auth response: {msg}"))),
87 }
88 }
89
90 pub async fn call(
97 &mut self,
98 msg_type: &str,
99 extra: serde_json::Value,
100 ) -> Result<serde_json::Value, HaError> {
101 let id = self.next_id;
102 self.next_id += 1;
103
104 let mut cmd = serde_json::json!({ "id": id, "type": msg_type });
105 if let serde_json::Value::Object(extras) = extra
106 && let serde_json::Value::Object(ref mut obj) = cmd
107 {
108 for (k, v) in extras {
109 obj.insert(k, v);
110 }
111 }
112 self.send_json(&cmd).await?;
113
114 loop {
115 let msg = self.recv_json().await?;
116 let is_result = msg.get("type").and_then(|v| v.as_str()) == Some("result");
117 let matches_id = msg.get("id").and_then(|v| v.as_u64()) == Some(id);
118 if !(is_result && matches_id) {
119 continue;
120 }
121 if msg.get("success").and_then(|v| v.as_bool()) == Some(true) {
122 return Ok(msg
123 .get("result")
124 .cloned()
125 .unwrap_or(serde_json::Value::Null));
126 }
127 let err = msg.get("error").cloned().unwrap_or(serde_json::Value::Null);
128 let code = err
129 .get("code")
130 .and_then(|v| v.as_str())
131 .unwrap_or("unknown")
132 .to_owned();
133 let message = err
134 .get("message")
135 .and_then(|v| v.as_str())
136 .unwrap_or("")
137 .to_owned();
138 return Err(match code.as_str() {
139 "not_found" => HaError::NotFound(message),
140 "unauthorized" => HaError::Auth(message),
141 _ => HaError::Api {
142 status: 0,
143 message: format!("{code}: {message}"),
144 },
145 });
146 }
147 }
148
149 pub async fn close(mut self) {
151 let _ = self.stream.close(None).await;
152 }
153
154 async fn send_json(&mut self, value: &serde_json::Value) -> Result<(), HaError> {
155 let text = value.to_string();
156 self.stream
157 .send(Message::Text(text))
158 .await
159 .map_err(|e| HaError::Connection(format!("send failed: {e}")))
160 }
161
162 async fn recv_json(&mut self) -> Result<serde_json::Value, HaError> {
163 loop {
164 let msg = self
165 .stream
166 .next()
167 .await
168 .ok_or_else(|| HaError::Connection("connection closed".into()))?
169 .map_err(|e| HaError::Connection(format!("recv failed: {e}")))?;
170 match msg {
171 Message::Text(text) => {
172 return serde_json::from_str(&text)
173 .map_err(|e| HaError::Other(format!("invalid JSON from server: {e}")));
174 }
175 Message::Binary(_) => {
176 return Err(HaError::Other("unexpected binary frame".into()));
177 }
178 Message::Close(_) => {
179 return Err(HaError::Connection("server closed connection".into()));
180 }
181 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
182 }
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    // ---- derive_ws_url: pure URL rewriting, no network involved ----

    #[test]
    fn derive_ws_url_http_to_ws() {
        assert_eq!(
            derive_ws_url("http://ha.local:8123").unwrap(),
            "ws://ha.local:8123/api/websocket"
        );
    }

    #[test]
    fn derive_ws_url_https_to_wss() {
        assert_eq!(
            derive_ws_url("https://ha.example.com").unwrap(),
            "wss://ha.example.com/api/websocket"
        );
    }

    #[test]
    fn derive_ws_url_strips_trailing_slash() {
        assert_eq!(
            derive_ws_url("http://ha.local:8123/").unwrap(),
            "ws://ha.local:8123/api/websocket"
        );
    }

    // A path on the base URL is kept in front of `/api/websocket`.
    #[test]
    fn derive_ws_url_preserves_base_path() {
        assert_eq!(
            derive_ws_url("https://example.com/ha").unwrap(),
            "wss://example.com/ha/api/websocket"
        );
    }

    #[test]
    fn derive_ws_url_rejects_invalid_scheme() {
        assert!(matches!(
            derive_ws_url("ftp://ha.local").unwrap_err(),
            HaError::InvalidInput(_)
        ));
    }

    // ---- HaWs against an in-process mock WebSocket server ----

    /// Starts a mock WebSocket server on an OS-assigned localhost port (`127.0.0.1:0`).
    ///
    /// Returns the `http://` base URL to hand to [`HaWs::connect`] plus the server task's
    /// handle. The task accepts one TCP connection, performs the WebSocket handshake and runs
    /// `handler` on the resulting stream. If accepting or the handshake fails, the task ends
    /// without calling `handler`. The tests `handle.await.unwrap()` so that a panic (failed
    /// assertion) inside `handler` fails the test.
    async fn spawn_mock_server<F, Fut>(handler: F) -> (String, tokio::task::JoinHandle<()>)
    where
        F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
            + Send
            + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let base_url = format!("http://127.0.0.1:{port}");
        let handle = tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await
                && let Ok(ws) = tokio_tungstenite::accept_async(stream).await
            {
                handler(ws).await;
            }
        });
        (base_url, handle)
    }

    /// Server side: reads the next frame and parses it as JSON; panics unless it is a text
    /// frame.
    async fn recv_text(
        ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    ) -> serde_json::Value {
        let msg = ws.next().await.unwrap().unwrap();
        let text = match msg {
            Message::Text(t) => t.to_string(),
            other => panic!("expected text frame, got {other:?}"),
        };
        serde_json::from_str(&text).unwrap()
    }

    /// Server side: sends `v` to the client as a JSON text frame.
    async fn send_text(
        ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
        v: serde_json::Value,
    ) {
        ws.send(Message::Text(v.to_string())).await.unwrap();
    }

    // Happy path: the mock runs the full handshake and checks the `auth` message the client
    // sends.
    #[tokio::test]
    async fn connect_completes_auth_handshake() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let auth = recv_text(&mut ws).await;
            assert_eq!(auth["type"], "auth");
            assert_eq!(auth["access_token"], "tok");
            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
        })
        .await;

        let client = HaWs::connect(&base_url, "tok").await.unwrap();
        client.close().await;
        handle.await.unwrap();
    }

    // An `auth_invalid` reply must surface as `HaError::Auth`.
    #[tokio::test]
    async fn connect_auth_invalid_maps_to_auth_error() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let _ = recv_text(&mut ws).await;
            send_text(
                &mut ws,
                serde_json::json!({"type": "auth_invalid", "message": "Invalid access token"}),
            )
            .await;
        })
        .await;

        let result = HaWs::connect(&base_url, "tok").await;
        match result {
            Err(HaError::Auth(_)) => {}
            Err(e) => panic!("expected Auth error, got {e:?}"),
            Ok(_) => panic!("expected Auth error, got Ok"),
        }
        handle.await.unwrap();
    }

    // `call` returns the `result` field of the successful reply that echoes the command id.
    #[tokio::test]
    async fn call_returns_result_payload() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let _ = recv_text(&mut ws).await;
            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;

            let cmd = recv_text(&mut ws).await;
            assert_eq!(cmd["type"], "config/entity_registry/list");
            let id = cmd["id"].as_u64().unwrap();
            send_text(
                &mut ws,
                serde_json::json!({
                    "id": id,
                    "type": "result",
                    "success": true,
                    "result": [{"entity_id": "light.x"}]
                }),
            )
            .await;
        })
        .await;

        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
        let result = client
            .call("config/entity_registry/list", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result[0]["entity_id"], "light.x");
        client.close().await;
        handle.await.unwrap();
    }

    // A failed reply with error code `not_found` maps to `HaError::NotFound`.
    #[tokio::test]
    async fn call_not_found_error_maps_to_not_found() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let _ = recv_text(&mut ws).await;
            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;

            let cmd = recv_text(&mut ws).await;
            let id = cmd["id"].as_u64().unwrap();
            send_text(
                &mut ws,
                serde_json::json!({
                    "id": id,
                    "type": "result",
                    "success": false,
                    "error": {"code": "not_found", "message": "Entity not found"}
                }),
            )
            .await;
        })
        .await;

        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
        let err = client
            .call(
                "config/entity_registry/remove",
                serde_json::json!({"entity_id": "light.missing"}),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, HaError::NotFound(_)));
        client.close().await;
        handle.await.unwrap();
    }

    // Fields from `extra` must appear at the top level of the command the server receives.
    #[tokio::test]
    async fn call_merges_extra_fields() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let _ = recv_text(&mut ws).await;
            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;

            let cmd = recv_text(&mut ws).await;
            assert_eq!(cmd["type"], "config/entity_registry/remove");
            assert_eq!(cmd["entity_id"], "light.kitchen");
            let id = cmd["id"].as_u64().unwrap();
            send_text(
                &mut ws,
                serde_json::json!({
                    "id": id,
                    "type": "result",
                    "success": true,
                    "result": null
                }),
            )
            .await;
        })
        .await;

        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
        client
            .call(
                "config/entity_registry/remove",
                serde_json::json!({"entity_id": "light.kitchen"}),
            )
            .await
            .unwrap();
        client.close().await;
        handle.await.unwrap();
    }

    // Events and results for other ids arriving first must be skipped; only the reply whose
    // id matches the command is returned.
    #[tokio::test]
    async fn call_ignores_interleaved_unrelated_messages() {
        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
            let _ = recv_text(&mut ws).await;
            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;

            let cmd = recv_text(&mut ws).await;
            let id = cmd["id"].as_u64().unwrap();
            send_text(&mut ws, serde_json::json!({"type": "event", "event": {}})).await;
            // A `result` for a different id: must not be mistaken for our reply.
            send_text(
                &mut ws,
                serde_json::json!({
                    "id": 9999,
                    "type": "result",
                    "success": true,
                    "result": "wrong"
                }),
            )
            .await;
            send_text(
                &mut ws,
                serde_json::json!({
                    "id": id,
                    "type": "result",
                    "success": true,
                    "result": "correct"
                }),
            )
            .await;
        })
        .await;

        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
        let result = client
            .call("config/entity_registry/list", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result, "correct");
        client.close().await;
        handle.await.unwrap();
    }
}
456}