Skip to main content

homeassistant_cli/api/
websocket.rs

1//! WebSocket client for the Home Assistant WebSocket API.
2//!
3//! Unlike REST, the registry endpoints (`config/entity_registry/*`,
4//! `config/device_registry/*`, `config/area_registry/*`) are only reachable
5//! over the WebSocket API. This module provides a minimal id-multiplexed
6//! request/response client that authenticates once and exchanges JSON
7//! messages with Home Assistant.
8//!
9//! Protocol reference: <https://developers.home-assistant.io/docs/api/websocket>
10
11use futures_util::{SinkExt, StreamExt};
12use tokio_tungstenite::tungstenite::Message;
13
14use crate::api::HaError;
15
16/// Derive the WebSocket URL from a REST base URL.
17/// `http://host/` → `ws://host/api/websocket`, `https://host/` → `wss://host/api/websocket`.
18/// Preserves any base path (e.g. `https://ha.example.com/ha` for reverse-proxied installs).
19pub(crate) fn derive_ws_url(base_url: &str) -> Result<String, HaError> {
20    let base = base_url.trim_end_matches('/');
21    let (scheme, rest) = if let Some(rest) = base.strip_prefix("https://") {
22        ("wss://", rest)
23    } else if let Some(rest) = base.strip_prefix("http://") {
24        ("ws://", rest)
25    } else {
26        return Err(HaError::InvalidInput(format!(
27            "URL must start with http:// or https://: {base_url}"
28        )));
29    };
30    Ok(format!("{scheme}{rest}/api/websocket"))
31}
32
33type WsStream =
34    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
35
36/// Authenticated Home Assistant WebSocket client.
37///
38/// The connection is opened and authenticated in [`HaWs::connect`]; afterwards
39/// [`HaWs::call`] sends a command and returns the matching `result` payload.
40/// Ids are assigned monotonically per client.
41pub struct HaWs {
42    stream: WsStream,
43    next_id: u64,
44}
45
46impl HaWs {
47    /// Open a WebSocket connection and complete the auth handshake.
48    pub async fn connect(base_url: &str, token: &str) -> Result<Self, HaError> {
49        let ws_url = derive_ws_url(base_url)?;
50        let (stream, _response) = tokio_tungstenite::connect_async(&ws_url)
51            .await
52            .map_err(|e| HaError::Connection(format!("{ws_url}: {e}")))?;
53        let mut client = Self { stream, next_id: 1 };
54        client.authenticate(token).await?;
55        Ok(client)
56    }
57
58    async fn authenticate(&mut self, token: &str) -> Result<(), HaError> {
59        let msg = self.recv_json().await?;
60        match msg.get("type").and_then(|v| v.as_str()) {
61            Some("auth_required") => {}
62            Some(other) => {
63                return Err(HaError::Other(format!(
64                    "expected auth_required, got {other}"
65                )));
66            }
67            None => return Err(HaError::Other("missing type on first message".into())),
68        }
69
70        self.send_json(&serde_json::json!({
71            "type": "auth",
72            "access_token": token,
73        }))
74        .await?;
75
76        let msg = self.recv_json().await?;
77        match msg.get("type").and_then(|v| v.as_str()) {
78            Some("auth_ok") => Ok(()),
79            Some("auth_invalid") => {
80                let m = msg
81                    .get("message")
82                    .and_then(|v| v.as_str())
83                    .unwrap_or("invalid token");
84                Err(HaError::Auth(m.to_owned()))
85            }
86            _ => Err(HaError::Other(format!("unexpected auth response: {msg}"))),
87        }
88    }
89
90    /// Send a command and return its `result` payload.
91    ///
92    /// `extra` is merged into the command envelope alongside `id` and `type`
93    /// (e.g. `{"entity_id": "light.x"}` for a `config/entity_registry/remove`).
94    /// HA error codes map to [`HaError`]: `not_found` → `NotFound`, everything
95    /// else → `Api { status: 0, message: "<code>: <message>" }`.
96    pub async fn call(
97        &mut self,
98        msg_type: &str,
99        extra: serde_json::Value,
100    ) -> Result<serde_json::Value, HaError> {
101        let id = self.next_id;
102        self.next_id += 1;
103
104        let mut cmd = serde_json::json!({ "id": id, "type": msg_type });
105        if let serde_json::Value::Object(extras) = extra
106            && let serde_json::Value::Object(ref mut obj) = cmd
107        {
108            for (k, v) in extras {
109                obj.insert(k, v);
110            }
111        }
112        self.send_json(&cmd).await?;
113
114        loop {
115            let msg = self.recv_json().await?;
116            let is_result = msg.get("type").and_then(|v| v.as_str()) == Some("result");
117            let matches_id = msg.get("id").and_then(|v| v.as_u64()) == Some(id);
118            if !(is_result && matches_id) {
119                continue;
120            }
121            if msg.get("success").and_then(|v| v.as_bool()) == Some(true) {
122                return Ok(msg
123                    .get("result")
124                    .cloned()
125                    .unwrap_or(serde_json::Value::Null));
126            }
127            let err = msg.get("error").cloned().unwrap_or(serde_json::Value::Null);
128            let code = err
129                .get("code")
130                .and_then(|v| v.as_str())
131                .unwrap_or("unknown")
132                .to_owned();
133            let message = err
134                .get("message")
135                .and_then(|v| v.as_str())
136                .unwrap_or("")
137                .to_owned();
138            return Err(match code.as_str() {
139                "not_found" => HaError::NotFound(message),
140                "unauthorized" => HaError::Auth(message),
141                _ => HaError::Api {
142                    status: 0,
143                    message: format!("{code}: {message}"),
144                },
145            });
146        }
147    }
148
149    /// Close the WebSocket cleanly. Errors on close are ignored.
150    pub async fn close(mut self) {
151        let _ = self.stream.close(None).await;
152    }
153
154    async fn send_json(&mut self, value: &serde_json::Value) -> Result<(), HaError> {
155        let text = value.to_string();
156        self.stream
157            .send(Message::Text(text))
158            .await
159            .map_err(|e| HaError::Connection(format!("send failed: {e}")))
160    }
161
162    async fn recv_json(&mut self) -> Result<serde_json::Value, HaError> {
163        loop {
164            let msg = self
165                .stream
166                .next()
167                .await
168                .ok_or_else(|| HaError::Connection("connection closed".into()))?
169                .map_err(|e| HaError::Connection(format!("recv failed: {e}")))?;
170            match msg {
171                Message::Text(text) => {
172                    return serde_json::from_str(&text)
173                        .map_err(|e| HaError::Other(format!("invalid JSON from server: {e}")));
174                }
175                Message::Binary(_) => {
176                    return Err(HaError::Other("unexpected binary frame".into()));
177                }
178                Message::Close(_) => {
179                    return Err(HaError::Connection("server closed connection".into()));
180                }
181                Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
182            }
183        }
184    }
185}
186
187#[cfg(test)]
188mod tests {
189    use super::*;
190
191    #[test]
192    fn derive_ws_url_http_to_ws() {
193        assert_eq!(
194            derive_ws_url("http://ha.local:8123").unwrap(),
195            "ws://ha.local:8123/api/websocket"
196        );
197    }
198
199    #[test]
200    fn derive_ws_url_https_to_wss() {
201        assert_eq!(
202            derive_ws_url("https://ha.example.com").unwrap(),
203            "wss://ha.example.com/api/websocket"
204        );
205    }
206
207    #[test]
208    fn derive_ws_url_strips_trailing_slash() {
209        assert_eq!(
210            derive_ws_url("http://ha.local:8123/").unwrap(),
211            "ws://ha.local:8123/api/websocket"
212        );
213    }
214
215    #[test]
216    fn derive_ws_url_preserves_base_path() {
217        assert_eq!(
218            derive_ws_url("https://example.com/ha").unwrap(),
219            "wss://example.com/ha/api/websocket"
220        );
221    }
222
223    #[test]
224    fn derive_ws_url_rejects_invalid_scheme() {
225        assert!(matches!(
226            derive_ws_url("ftp://ha.local").unwrap_err(),
227            HaError::InvalidInput(_)
228        ));
229    }
230
231    /// Spawn a tiny WebSocket server that runs `handler` against exactly one
232    /// client connection, then returns the HTTP base URL the client should use.
233    async fn spawn_mock_server<F, Fut>(handler: F) -> (String, tokio::task::JoinHandle<()>)
234    where
235        F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
236            + Send
237            + 'static,
238        Fut: std::future::Future<Output = ()> + Send + 'static,
239    {
240        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
241        let port = listener.local_addr().unwrap().port();
242        let base_url = format!("http://127.0.0.1:{port}");
243        let handle = tokio::spawn(async move {
244            if let Ok((stream, _)) = listener.accept().await
245                && let Ok(ws) = tokio_tungstenite::accept_async(stream).await
246            {
247                handler(ws).await;
248            }
249        });
250        (base_url, handle)
251    }
252
253    async fn recv_text(
254        ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
255    ) -> serde_json::Value {
256        let msg = ws.next().await.unwrap().unwrap();
257        let text = match msg {
258            Message::Text(t) => t.to_string(),
259            other => panic!("expected text frame, got {other:?}"),
260        };
261        serde_json::from_str(&text).unwrap()
262    }
263
264    async fn send_text(
265        ws: &mut tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
266        v: serde_json::Value,
267    ) {
268        ws.send(Message::Text(v.to_string())).await.unwrap();
269    }
270
271    #[tokio::test]
272    async fn connect_completes_auth_handshake() {
273        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
274            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
275            let auth = recv_text(&mut ws).await;
276            assert_eq!(auth["type"], "auth");
277            assert_eq!(auth["access_token"], "tok");
278            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
279        })
280        .await;
281
282        let client = HaWs::connect(&base_url, "tok").await.unwrap();
283        client.close().await;
284        handle.await.unwrap();
285    }
286
287    #[tokio::test]
288    async fn connect_auth_invalid_maps_to_auth_error() {
289        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
290            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
291            let _ = recv_text(&mut ws).await;
292            send_text(
293                &mut ws,
294                serde_json::json!({"type": "auth_invalid", "message": "Invalid access token"}),
295            )
296            .await;
297        })
298        .await;
299
300        let result = HaWs::connect(&base_url, "tok").await;
301        match result {
302            Err(HaError::Auth(_)) => {}
303            Err(e) => panic!("expected Auth error, got {e:?}"),
304            Ok(_) => panic!("expected Auth error, got Ok"),
305        }
306        handle.await.unwrap();
307    }
308
309    #[tokio::test]
310    async fn call_returns_result_payload() {
311        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
312            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
313            let _ = recv_text(&mut ws).await;
314            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
315
316            let cmd = recv_text(&mut ws).await;
317            assert_eq!(cmd["type"], "config/entity_registry/list");
318            let id = cmd["id"].as_u64().unwrap();
319            send_text(
320                &mut ws,
321                serde_json::json!({
322                    "id": id,
323                    "type": "result",
324                    "success": true,
325                    "result": [{"entity_id": "light.x"}]
326                }),
327            )
328            .await;
329        })
330        .await;
331
332        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
333        let result = client
334            .call("config/entity_registry/list", serde_json::json!({}))
335            .await
336            .unwrap();
337        assert_eq!(result[0]["entity_id"], "light.x");
338        client.close().await;
339        handle.await.unwrap();
340    }
341
342    #[tokio::test]
343    async fn call_not_found_error_maps_to_not_found() {
344        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
345            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
346            let _ = recv_text(&mut ws).await;
347            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
348
349            let cmd = recv_text(&mut ws).await;
350            let id = cmd["id"].as_u64().unwrap();
351            send_text(
352                &mut ws,
353                serde_json::json!({
354                    "id": id,
355                    "type": "result",
356                    "success": false,
357                    "error": {"code": "not_found", "message": "Entity not found"}
358                }),
359            )
360            .await;
361        })
362        .await;
363
364        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
365        let err = client
366            .call(
367                "config/entity_registry/remove",
368                serde_json::json!({"entity_id": "light.missing"}),
369            )
370            .await
371            .unwrap_err();
372        assert!(matches!(err, HaError::NotFound(_)));
373        client.close().await;
374        handle.await.unwrap();
375    }
376
377    #[tokio::test]
378    async fn call_merges_extra_fields() {
379        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
380            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
381            let _ = recv_text(&mut ws).await;
382            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
383
384            let cmd = recv_text(&mut ws).await;
385            assert_eq!(cmd["type"], "config/entity_registry/remove");
386            assert_eq!(cmd["entity_id"], "light.kitchen");
387            let id = cmd["id"].as_u64().unwrap();
388            send_text(
389                &mut ws,
390                serde_json::json!({
391                    "id": id,
392                    "type": "result",
393                    "success": true,
394                    "result": null
395                }),
396            )
397            .await;
398        })
399        .await;
400
401        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
402        client
403            .call(
404                "config/entity_registry/remove",
405                serde_json::json!({"entity_id": "light.kitchen"}),
406            )
407            .await
408            .unwrap();
409        client.close().await;
410        handle.await.unwrap();
411    }
412
413    #[tokio::test]
414    async fn call_ignores_interleaved_unrelated_messages() {
415        let (base_url, handle) = spawn_mock_server(|mut ws| async move {
416            send_text(&mut ws, serde_json::json!({"type": "auth_required"})).await;
417            let _ = recv_text(&mut ws).await;
418            send_text(&mut ws, serde_json::json!({"type": "auth_ok"})).await;
419
420            let cmd = recv_text(&mut ws).await;
421            let id = cmd["id"].as_u64().unwrap();
422            // Send a spurious event, then a result with a mismatched id, then the real result.
423            send_text(&mut ws, serde_json::json!({"type": "event", "event": {}})).await;
424            send_text(
425                &mut ws,
426                serde_json::json!({
427                    "id": 9999,
428                    "type": "result",
429                    "success": true,
430                    "result": "wrong"
431                }),
432            )
433            .await;
434            send_text(
435                &mut ws,
436                serde_json::json!({
437                    "id": id,
438                    "type": "result",
439                    "success": true,
440                    "result": "correct"
441                }),
442            )
443            .await;
444        })
445        .await;
446
447        let mut client = HaWs::connect(&base_url, "tok").await.unwrap();
448        let result = client
449            .call("config/entity_registry/list", serde_json::json!({}))
450            .await
451            .unwrap();
452        assert_eq!(result, "correct");
453        client.close().await;
454        handle.await.unwrap();
455    }
456}