1use crate::types::{
4 Ask, Auth, CallService, Command, HassConfig, HassEntity, HassPanels, HassServices, Response,
5 HassRegistryArea, HassRegistryDevice, HassRegistryEntity,
6 Subscribe, WSEvent,
7};
8use crate::{HassError, HassResult};
9
10use futures_util::{stream::SplitStream, SinkExt, StreamExt};
11use parking_lot::Mutex;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::sync::mpsc::{channel, Receiver, Sender};
18use tokio::sync::oneshot::{channel as oneshot, Sender as OneShotSender};
19use tokio_tungstenite::tungstenite::{Error, Message};
20use tokio_tungstenite::{connect_async, WebSocketStream};
21
22pub struct HassClient {
25 last_sequence: AtomicU64,
27
28 rx_state: Arc<ReceiverState>,
29
30 message_tx: Arc<Sender<Message>>,
32}
33
34#[derive(Default)]
35struct ReceiverState {
36 subscriptions: Mutex<HashMap<u64, Sender<WSEvent>>>,
37 pending_requests: Mutex<HashMap<u64, OneShotSender<Response>>>,
38 untagged_request: Mutex<Option<OneShotSender<Response>>>,
39}
40
41impl ReceiverState {
42 fn get_tx(self: &Arc<Self>, id: u64) -> Option<Sender<WSEvent>> {
43 self.subscriptions.lock().get(&id).map(|tx| tx.clone())
44 }
45
46 fn rm_subscription(self: &Arc<Self>, id: u64) {
47 self.subscriptions.lock().remove(&id);
48 }
49
50 fn take_responder(self: &Arc<Self>, id: u64) -> Option<OneShotSender<Response>> {
51 self.pending_requests.lock().remove(&id)
52 }
53
54 fn take_untagged(self: &Arc<Self>) -> Option<OneShotSender<Response>> {
55 self.untagged_request.lock().take()
56 }
57}
58
59async fn ws_incoming_messages(
60 mut stream: SplitStream<WebSocketStream<impl AsyncRead + AsyncWrite + Unpin>>,
61 rx_state: Arc<ReceiverState>,
62 message_tx: Arc<Sender<Message>>,
63) {
64 while let Some(message) = stream.next().await {
65 log::trace!("incoming: {message:#?}");
66 match check_if_event(message) {
67 Ok(event) => {
68 let id = event.id;
70 if let Some(tx) = rx_state.get_tx(id) {
71 if tx.send(event).await.is_err() {
72 rx_state.rm_subscription(id);
73 }
75 }
76 }
77 Err(message) => match message {
78 Ok(Message::Text(data)) => {
79 let payload: Result<Response, HassError> = serde_json::from_str(data.as_str())
80 .map_err(|err| HassError::UnableToDeserialize(err));
81
82 match payload {
83 Ok(response) => match response.id() {
84 Some(id) => {
85 if let Some(tx) = rx_state.take_responder(id) {
86 tx.send(response).ok();
87 } else {
88 log::error!("no responder for id={id} {response:#?}");
89 }
90 }
91 None => {
92 if matches!(&response, Response::AuthRequired(_)) {
93 log::trace!("Ignoring {response:?}");
97 continue;
98 }
99
100 if let Some(tx) = rx_state.take_untagged() {
101 tx.send(response).ok();
102 } else {
103 log::error!("no untagged responder for {response:#?}");
104 }
105 }
106 },
107 Err(err) => {
108 log::error!("Error deserializing response: {err:#} {data}");
109 }
110 }
111 }
112 Ok(Message::Ping(data)) => {
113 if let Err(err) = message_tx.send(Message::Pong(data)).await {
114 log::error!("Error responding to ping: {err:#}");
115 break;
116 }
117 }
118 unexpected => log::error!("Unexpected message: {unexpected:#?}"),
119 },
120 }
121 }
122}
123
124impl HassClient {
125 pub async fn new(url: &str) -> HassResult<Self> {
126 let (wsclient, _) = connect_async(url).await?;
127 let (mut sink, stream) = wsclient.split();
128 let (message_tx, mut message_rx) = channel(20);
129
130 let message_tx = Arc::new(message_tx);
131
132 let rx_state = Arc::new(ReceiverState::default());
133
134 tokio::spawn(async move {
135 while let Some(msg) = message_rx.recv().await {
136 if let Err(err) = sink.send(msg).await {
137 log::error!("sink error: {err:#}");
138 break;
139 }
140 }
141 });
142 tokio::spawn(ws_incoming_messages(
143 stream,
144 rx_state.clone(),
145 message_tx.clone(),
146 ));
147
148 let last_sequence = AtomicU64::new(1);
149
150 Ok(Self {
151 last_sequence,
152 rx_state,
153 message_tx,
154 })
155 }
156
157 pub async fn auth_with_longlivedtoken(&mut self, token: &str) -> HassResult<()> {
165 let auth_message = Command::AuthInit(Auth {
166 msg_type: "auth".to_owned(),
167 access_token: token.to_owned(),
168 });
169
170 let response = self.command(auth_message, None).await?;
171
172 match response {
174 Response::AuthOk(_) => Ok(()),
175 Response::AuthInvalid(err) => Err(HassError::AuthenticationFailed(err.message)),
176 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
177 }
178 }
179
180 pub async fn ping(&mut self) -> HassResult<()> {
183 let id = self.next_seq();
184
185 let ping_req = Command::Ping(Ask {
186 id,
187 msg_type: "ping".to_owned(),
188 });
189
190 let response = self.command(ping_req, Some(id)).await?;
191
192 match response {
193 Response::Pong(_v) => Ok(()),
194 Response::Result(err) => Err(HassError::ResponseError(err)),
195 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
196 }
197 }
198
199 pub async fn get_config(&mut self) -> HassResult<HassConfig> {
203 let id = self.next_seq();
204
205 let config_req = Command::GetConfig(Ask {
206 id,
207 msg_type: "get_config".to_owned(),
208 });
209 let response = self.command(config_req, Some(id)).await?;
210
211 match response {
212 Response::Result(data) => {
213 let value = data.result()?;
214 let config: HassConfig = serde_json::from_value(value)?;
215 Ok(config)
216 }
217 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
218 }
219 }
220
221 pub async fn get_states(&mut self) -> HassResult<Vec<HassEntity>> {
226 let id = self.next_seq();
227
228 let states_req = Command::GetStates(Ask {
229 id,
230 msg_type: "get_states".to_owned(),
231 });
232 let response = self.command(states_req, Some(id)).await?;
233
234 match response {
235 Response::Result(data) => {
236 let value = data.result()?;
237 let states: Vec<HassEntity> = serde_json::from_value(value)?;
238 Ok(states)
239 }
240 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
241 }
242 }
243
244 pub async fn get_services(&mut self) -> HassResult<HassServices> {
249 let id = self.next_seq();
250 let services_req = Command::GetServices(Ask {
251 id,
252 msg_type: "get_services".to_owned(),
253 });
254 let response = self.command(services_req, Some(id)).await?;
255
256 match response {
257 Response::Result(data) => {
258 let value = data.result()?;
259 let services: HassServices = serde_json::from_value(value)?;
260 Ok(services)
261 }
262 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
263 }
264 }
265
266 pub async fn get_panels(&mut self) -> HassResult<HassPanels> {
271 let id = self.next_seq();
272
273 let services_req = Command::GetPanels(Ask {
274 id,
275 msg_type: "get_panels".to_owned(),
276 });
277 let response = self.command(services_req, Some(id)).await?;
278
279 match response {
280 Response::Result(data) => {
281 let value = data.result()?;
282 let services: HassPanels = serde_json::from_value(value)?;
283 Ok(services)
284 }
285 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
286 }
287 }
288
289 pub async fn get_area_registry_list(&mut self) -> HassResult<Vec<HassRegistryArea>> {
293 let id = self.next_seq();
294
295 let area_req = Command::GetAreaRegistryList(Ask {
296 id,
297 msg_type: "config/area_registry/list".to_owned(),
298 });
299 let response = self.command(area_req, Some(id)).await?;
300
301 match response {
302 Response::Result(data) => {
303 let value = data.result()?;
304 let areas: Vec<HassRegistryArea> = serde_json::from_value(value)?;
305 Ok(areas)
306 }
307 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
308 }
309 }
310
311 pub async fn get_device_registry_list(&mut self) -> HassResult<Vec<HassRegistryDevice>> {
315 let id = self.next_seq();
316
317 let device_req = Command::GetDeviceRegistryList(Ask {
318 id,
319 msg_type: "config/device_registry/list".to_owned(),
320 });
321 let response = self.command(device_req, Some(id)).await?;
322
323 match response {
324 Response::Result(data) => {
325 let value = data.result()?;
326 let devices: Vec<HassRegistryDevice> = serde_json::from_value(value)?;
327 Ok(devices)
328 }
329 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
330 }
331 }
332
333 pub async fn get_entity_registry_list(&mut self) -> HassResult<Vec<HassRegistryEntity>> {
337 let id = self.next_seq();
338
339 let entity_req = Command::GetEntityRegistryList(Ask {
340 id,
341 msg_type: "config/entity_registry/list".to_owned(),
342 });
343 let response = self.command(entity_req, Some(id)).await?;
344
345 match response {
346 Response::Result(data) => {
347 let value = data.result()?;
348 let entities: Vec<HassRegistryEntity> = serde_json::from_value(value)?;
349 Ok(entities)
350 }
351 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
352 }
353 }
354
355 pub async fn call_service(
363 &mut self,
364 domain: String,
365 service: String,
366 service_data: Option<Value>,
367 ) -> HassResult<()> {
368 let id = self.next_seq();
369
370 let services_req = Command::CallService(CallService {
371 id,
372 msg_type: "call_service".to_owned(),
373 domain,
374 service,
375 service_data,
376 });
377 let response = self.command(services_req, Some(id)).await?;
378
379 match response {
380 Response::Result(data) => {
381 data.result()?;
382 Ok(())
383 }
384 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
385 }
386 }
387
388 pub async fn subscribe_event(&mut self, event_name: &str) -> HassResult<Receiver<WSEvent>> {
392 let id = self.next_seq();
393
394 let cmd = Command::SubscribeEvent(Subscribe {
395 id,
396 msg_type: "subscribe_events".to_owned(),
397 event_type: event_name.to_owned(),
398 });
399
400 let response = self.command(cmd, Some(id)).await?;
401
402 match response {
403 Response::Result(v) if v.is_ok() => {
404 let (tx, rx) = channel(20);
405 self.rx_state.subscriptions.lock().insert(v.id, tx);
406 return Ok(rx);
407 }
408 Response::Result(v) => Err(HassError::ResponseError(v)),
409 unknown => Err(HassError::UnknownPayloadReceived(unknown)),
410 }
411 }
412
413 pub(crate) async fn command(&mut self, cmd: Command, id: Option<u64>) -> HassResult<Response> {
415 let cmd_tungstenite = cmd.to_tungstenite_message();
416
417 let (tx, rx) = oneshot();
418
419 match id {
420 Some(id) => {
421 self.rx_state.pending_requests.lock().insert(id, tx);
422 }
423 None => {
424 self.rx_state.untagged_request.lock().replace(tx);
425 }
426 }
427
428 self.message_tx
430 .send(cmd_tungstenite)
431 .await
432 .map_err(|err| HassError::SendError(err.to_string()))?;
433
434 rx.await
435 .map_err(|err| HassError::RecvError(err.to_string()))
436 }
437
438 fn next_seq(&self) -> u64 {
440 self.last_sequence.fetch_add(1, Ordering::Relaxed)
441 }
442}
443
444fn check_if_event(result: Result<Message, Error>) -> Result<WSEvent, Result<Message, Error>> {
447 match result {
448 Ok(Message::Text(data)) => {
449 let payload: Result<Response, HassError> =
450 serde_json::from_str(data.as_str()).map_err(|err| HassError::from(err));
451
452 if let Ok(Response::Event(event)) = payload {
453 Ok(event)
454 } else {
455 Err(Ok(Message::Text(data)))
456 }
457 }
458 result => Err(result),
459 }
460}