hass_rs/
client.rs

1//! Home Assistant client implementation
2
3use crate::types::{
4    Ask, Auth, CallService, Command, HassConfig, HassEntity, HassPanels, HassServices, Response,
5    HassRegistryArea, HassRegistryDevice, HassRegistryEntity,
6    Subscribe, WSEvent,
7};
8use crate::{HassError, HassResult};
9
10use futures_util::{stream::SplitStream, SinkExt, StreamExt};
11use parking_lot::Mutex;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::sync::mpsc::{channel, Receiver, Sender};
18use tokio::sync::oneshot::{channel as oneshot, Sender as OneShotSender};
19use tokio_tungstenite::tungstenite::{Error, Message};
20use tokio_tungstenite::{connect_async, WebSocketStream};
21
22/// HassClient is a library that is meant to simplify the conversation with HomeAssistant Web Socket Server
23/// it provides a number of convenient functions that creates the requests and read the messages from server
24pub struct HassClient {
25    // holds the id of the WS message
26    last_sequence: AtomicU64,
27
28    rx_state: Arc<ReceiverState>,
29
30    /// Client --> Gateway (send "Commands" msg to the Gateway)
31    message_tx: Arc<Sender<Message>>,
32}
33
34#[derive(Default)]
35struct ReceiverState {
36    subscriptions: Mutex<HashMap<u64, Sender<WSEvent>>>,
37    pending_requests: Mutex<HashMap<u64, OneShotSender<Response>>>,
38    untagged_request: Mutex<Option<OneShotSender<Response>>>,
39}
40
41impl ReceiverState {
42    fn get_tx(self: &Arc<Self>, id: u64) -> Option<Sender<WSEvent>> {
43        self.subscriptions.lock().get(&id).map(|tx| tx.clone())
44    }
45
46    fn rm_subscription(self: &Arc<Self>, id: u64) {
47        self.subscriptions.lock().remove(&id);
48    }
49
50    fn take_responder(self: &Arc<Self>, id: u64) -> Option<OneShotSender<Response>> {
51        self.pending_requests.lock().remove(&id)
52    }
53
54    fn take_untagged(self: &Arc<Self>) -> Option<OneShotSender<Response>> {
55        self.untagged_request.lock().take()
56    }
57}
58
59async fn ws_incoming_messages(
60    mut stream: SplitStream<WebSocketStream<impl AsyncRead + AsyncWrite + Unpin>>,
61    rx_state: Arc<ReceiverState>,
62    message_tx: Arc<Sender<Message>>,
63) {
64    while let Some(message) = stream.next().await {
65        log::trace!("incoming: {message:#?}");
66        match check_if_event(message) {
67            Ok(event) => {
68                // Dispatch to subscriber
69                let id = event.id;
70                if let Some(tx) = rx_state.get_tx(id) {
71                    if tx.send(event).await.is_err() {
72                        rx_state.rm_subscription(id);
73                        // TODO: send unsub request here
74                    }
75                }
76            }
77            Err(message) => match message {
78                Ok(Message::Text(data)) => {
79                    let payload: Result<Response, HassError> = serde_json::from_str(data.as_str())
80                        .map_err(|err| HassError::UnableToDeserialize(err));
81
82                    match payload {
83                        Ok(response) => match response.id() {
84                            Some(id) => {
85                                if let Some(tx) = rx_state.take_responder(id) {
86                                    tx.send(response).ok();
87                                } else {
88                                    log::error!("no responder for id={id} {response:#?}");
89                                }
90                            }
91                            None => {
92                                if matches!(&response, Response::AuthRequired(_)) {
93                                    // AuthRequired is always sent unilaterally at connect time.
94                                    // It is never a response to one of our commands, so the
95                                    // simplest way to deal with it is to ignore it.
96                                    log::trace!("Ignoring {response:?}");
97                                    continue;
98                                }
99
100                                if let Some(tx) = rx_state.take_untagged() {
101                                    tx.send(response).ok();
102                                } else {
103                                    log::error!("no untagged responder for {response:#?}");
104                                }
105                            }
106                        },
107                        Err(err) => {
108                            log::error!("Error deserializing response: {err:#} {data}");
109                        }
110                    }
111                }
112                Ok(Message::Ping(data)) => {
113                    if let Err(err) = message_tx.send(Message::Pong(data)).await {
114                        log::error!("Error responding to ping: {err:#}");
115                        break;
116                    }
117                }
118                unexpected => log::error!("Unexpected message: {unexpected:#?}"),
119            },
120        }
121    }
122}
123
124impl HassClient {
125    pub async fn new(url: &str) -> HassResult<Self> {
126        let (wsclient, _) = connect_async(url).await?;
127        let (mut sink, stream) = wsclient.split();
128        let (message_tx, mut message_rx) = channel(20);
129
130        let message_tx = Arc::new(message_tx);
131
132        let rx_state = Arc::new(ReceiverState::default());
133
134        tokio::spawn(async move {
135            while let Some(msg) = message_rx.recv().await {
136                if let Err(err) = sink.send(msg).await {
137                    log::error!("sink error: {err:#}");
138                    break;
139                }
140            }
141        });
142        tokio::spawn(ws_incoming_messages(
143            stream,
144            rx_state.clone(),
145            message_tx.clone(),
146        ));
147
148        let last_sequence = AtomicU64::new(1);
149
150        Ok(Self {
151            last_sequence,
152            rx_state,
153            message_tx,
154        })
155    }
156
157    /// authenticate the session using a long-lived access token
158    ///
159    /// When a client connects to the server, the server sends out auth_required.
160    /// The first message from the client should be an auth message. You can authorize with an access token.
161    /// If the client supplies valid authentication, the authentication phase will complete by the server sending the auth_ok message.
162    /// If the data is incorrect, the server will reply with auth_invalid message and disconnect the session.
163
164    pub async fn auth_with_longlivedtoken(&mut self, token: &str) -> HassResult<()> {
165        let auth_message = Command::AuthInit(Auth {
166            msg_type: "auth".to_owned(),
167            access_token: token.to_owned(),
168        });
169
170        let response = self.command(auth_message, None).await?;
171
172        // Check if the authetication was succefully, should receive {"type": "auth_ok"}
173        match response {
174            Response::AuthOk(_) => Ok(()),
175            Response::AuthInvalid(err) => Err(HassError::AuthenticationFailed(err.message)),
176            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
177        }
178    }
179
180    /// The API supports receiving a ping from the client and returning a pong.
181    /// This serves as a heartbeat to ensure the connection is still alive.
182    pub async fn ping(&mut self) -> HassResult<()> {
183        let id = self.next_seq();
184
185        let ping_req = Command::Ping(Ask {
186            id,
187            msg_type: "ping".to_owned(),
188        });
189
190        let response = self.command(ping_req, Some(id)).await?;
191
192        match response {
193            Response::Pong(_v) => Ok(()),
194            Response::Result(err) => Err(HassError::ResponseError(err)),
195            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
196        }
197    }
198
199    /// This will get the current config of the Home Assistant.
200    ///
201    /// The server will respond with a result message containing the config.
202    pub async fn get_config(&mut self) -> HassResult<HassConfig> {
203        let id = self.next_seq();
204
205        let config_req = Command::GetConfig(Ask {
206            id,
207            msg_type: "get_config".to_owned(),
208        });
209        let response = self.command(config_req, Some(id)).await?;
210
211        match response {
212            Response::Result(data) => {
213                let value = data.result()?;
214                let config: HassConfig = serde_json::from_value(value)?;
215                Ok(config)
216            }
217            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
218        }
219    }
220
221    /// This will get all the current states from Home Assistant.
222    ///
223    /// The server will respond with a result message containing the states.
224
225    pub async fn get_states(&mut self) -> HassResult<Vec<HassEntity>> {
226        let id = self.next_seq();
227
228        let states_req = Command::GetStates(Ask {
229            id,
230            msg_type: "get_states".to_owned(),
231        });
232        let response = self.command(states_req, Some(id)).await?;
233
234        match response {
235            Response::Result(data) => {
236                let value = data.result()?;
237                let states: Vec<HassEntity> = serde_json::from_value(value)?;
238                Ok(states)
239            }
240            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
241        }
242    }
243
244    /// This will get all the services from Home Assistant.
245    ///
246    /// The server will respond with a result message containing the services.
247
248    pub async fn get_services(&mut self) -> HassResult<HassServices> {
249        let id = self.next_seq();
250        let services_req = Command::GetServices(Ask {
251            id,
252            msg_type: "get_services".to_owned(),
253        });
254        let response = self.command(services_req, Some(id)).await?;
255
256        match response {
257            Response::Result(data) => {
258                let value = data.result()?;
259                let services: HassServices = serde_json::from_value(value)?;
260                Ok(services)
261            }
262            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
263        }
264    }
265
266    /// This will get all the registered panels from Home Assistant.
267    ///
268    /// The server will respond with a result message containing the current registered panels.
269
270    pub async fn get_panels(&mut self) -> HassResult<HassPanels> {
271        let id = self.next_seq();
272
273        let services_req = Command::GetPanels(Ask {
274            id,
275            msg_type: "get_panels".to_owned(),
276        });
277        let response = self.command(services_req, Some(id)).await?;
278
279        match response {
280            Response::Result(data) => {
281                let value = data.result()?;
282                let services: HassPanels = serde_json::from_value(value)?;
283                Ok(services)
284            }
285            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
286        }
287    }
288
289    /// This will get the current area registry list from Home Assistant.
290    ///
291    /// The server will respond with a result message containing the area registry list.
292    pub async fn get_area_registry_list(&mut self) -> HassResult<Vec<HassRegistryArea>> {
293        let id = self.next_seq();
294
295        let area_req = Command::GetAreaRegistryList(Ask {
296            id,
297            msg_type: "config/area_registry/list".to_owned(),
298        });
299        let response = self.command(area_req, Some(id)).await?;
300
301        match response {
302            Response::Result(data) => {
303                let value = data.result()?;
304                let areas: Vec<HassRegistryArea> = serde_json::from_value(value)?;
305                Ok(areas)
306            }
307            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
308        }
309    }
310
311    /// This will get the current device registry list from Home Assistant.
312    ///
313    /// The server will respond with a result message containing the device registry list.
314    pub async fn get_device_registry_list(&mut self) -> HassResult<Vec<HassRegistryDevice>> {
315        let id = self.next_seq();
316
317        let device_req = Command::GetDeviceRegistryList(Ask {
318            id,
319            msg_type: "config/device_registry/list".to_owned(),
320        });
321        let response = self.command(device_req, Some(id)).await?;
322
323        match response {
324            Response::Result(data) => {
325                let value = data.result()?;
326                let devices: Vec<HassRegistryDevice> = serde_json::from_value(value)?;
327                Ok(devices)
328            }
329            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
330        }
331    }
332
333    /// This will get the current entity registry list from Home Assistant.
334    ///
335    /// The server will respond with a result message containing the entity registry list.
336    pub async fn get_entity_registry_list(&mut self) -> HassResult<Vec<HassRegistryEntity>> {
337        let id = self.next_seq();
338
339        let entity_req = Command::GetEntityRegistryList(Ask {
340            id,
341            msg_type: "config/entity_registry/list".to_owned(),
342        });
343        let response = self.command(entity_req, Some(id)).await?;
344
345        match response {
346            Response::Result(data) => {
347                let value = data.result()?;
348                let entities: Vec<HassRegistryEntity> = serde_json::from_value(value)?;
349                Ok(entities)
350            }
351            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
352        }
353    }
354
355    ///This will call a service in Home Assistant. Right now there is no return value.
356    ///The client can listen to state_changed events if it is interested in changed entities as a result of a service call.
357    ///
358    /// The server will indicate with a message indicating that the service is done executing.
359    /// <https://developers.home-assistant.io/docs/api/websocket#calling-a-service>
360    /// additional info : <https://developers.home-assistant.io/docs/api/rest> ==> Post `/api/services/<domain>/<service>`
361
362    pub async fn call_service(
363        &mut self,
364        domain: String,
365        service: String,
366        service_data: Option<Value>,
367    ) -> HassResult<()> {
368        let id = self.next_seq();
369
370        let services_req = Command::CallService(CallService {
371            id,
372            msg_type: "call_service".to_owned(),
373            domain,
374            service,
375            service_data,
376        });
377        let response = self.command(services_req, Some(id)).await?;
378
379        match response {
380            Response::Result(data) => {
381                data.result()?;
382                Ok(())
383            }
384            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
385        }
386    }
387
388    /// The command subscribe_event will subscribe your client to the event bus.
389    ///
390    /// Returns a channel that will receive the subscription messages.
391    pub async fn subscribe_event(&mut self, event_name: &str) -> HassResult<Receiver<WSEvent>> {
392        let id = self.next_seq();
393
394        let cmd = Command::SubscribeEvent(Subscribe {
395            id,
396            msg_type: "subscribe_events".to_owned(),
397            event_type: event_name.to_owned(),
398        });
399
400        let response = self.command(cmd, Some(id)).await?;
401
402        match response {
403            Response::Result(v) if v.is_ok() => {
404                let (tx, rx) = channel(20);
405                self.rx_state.subscriptions.lock().insert(v.id, tx);
406                return Ok(rx);
407            }
408            Response::Result(v) => Err(HassError::ResponseError(v)),
409            unknown => Err(HassError::UnknownPayloadReceived(unknown)),
410        }
411    }
412
413    /// send commands and receive responses from the gateway
414    pub(crate) async fn command(&mut self, cmd: Command, id: Option<u64>) -> HassResult<Response> {
415        let cmd_tungstenite = cmd.to_tungstenite_message();
416
417        let (tx, rx) = oneshot();
418
419        match id {
420            Some(id) => {
421                self.rx_state.pending_requests.lock().insert(id, tx);
422            }
423            None => {
424                self.rx_state.untagged_request.lock().replace(tx);
425            }
426        }
427
428        // Send the auth command to gateway
429        self.message_tx
430            .send(cmd_tungstenite)
431            .await
432            .map_err(|err| HassError::SendError(err.to_string()))?;
433
434        rx.await
435            .map_err(|err| HassError::RecvError(err.to_string()))
436    }
437
438    /// get message sequence required by the Websocket server
439    fn next_seq(&self) -> u64 {
440        self.last_sequence.fetch_add(1, Ordering::Relaxed)
441    }
442}
443
444/// convenient function that validates if the message received is an Event
445/// the Events should be processed by used in a separate async task
446fn check_if_event(result: Result<Message, Error>) -> Result<WSEvent, Result<Message, Error>> {
447    match result {
448        Ok(Message::Text(data)) => {
449            let payload: Result<Response, HassError> =
450                serde_json::from_str(data.as_str()).map_err(|err| HassError::from(err));
451
452            if let Ok(Response::Event(event)) = payload {
453                Ok(event)
454            } else {
455                Err(Ok(Message::Text(data)))
456            }
457        }
458        result => Err(result),
459    }
460}