neohub/
lib.rs

1pub mod commands;
2mod live_data;
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::{anyhow, ensure, Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use log::debug;
10use rustls::client::danger;
11use rustls::crypto::ring::default_provider;
12use rustls::crypto::{verify_tls12_signature, verify_tls13_signature, WebPkiSupportedAlgorithms};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, Error, SignatureScheme};
15use serde::de::DeserializeOwned;
16use serde::Deserialize;
17use serde::Serialize;
18use serde_json::{json, Value};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21use tokio_tungstenite::tungstenite::Message;
22use tokio_tungstenite::{
23    connect_async_tls_with_config, Connector, MaybeTlsStream, WebSocketStream,
24};
25
26pub use live_data::LiveData;
27
28type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
29
30pub struct Client {
31    url: String,
32    token: String,
33    conn: Option<WsStream>,
34    opts: Opts,
35}
36
/// Tunable options for a [`Client`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; start from [`Opts::default`] and overwrite the fields you need.
#[non_exhaustive]
pub struct Opts {
    /// Upper bound on a single command round-trip and on the close handshake.
    pub timeout: Duration,
}

impl Default for Opts {
    /// Returns the default options: a 15 second timeout.
    fn default() -> Self {
        let timeout = Duration::from_secs(15);
        Self { timeout }
    }
}
49
50impl Client {
51    pub fn from_env() -> Result<Self> {
52        Self::new(env_var("NEOHUB_URL")?, env_var("NEOHUB_TOKEN")?)
53    }
54
55    pub fn new(url: impl ToString, token: impl ToString) -> Result<Self> {
56        Self::new_opts(url, token, Opts::default())
57    }
58
59    pub fn new_opts(url: impl ToString, token: impl ToString, opts: Opts) -> Result<Self> {
60        Ok(Client {
61            url: url.to_string(),
62            token: token.to_string(),
63            conn: None,
64            opts,
65        })
66    }
67
68    #[inline]
69    async fn ensure_connected(&mut self) -> Result<&mut WsStream> {
70        if self.conn.is_none() {
71            self.conn = Some(connect(&self.url).await?);
72        }
73        Ok(self.conn.as_mut().expect("we just set it"))
74    }
75
76    pub async fn raw_message(&mut self, msg: &str) -> Result<(String, String)> {
77        timeout(self.opts.timeout, self.raw_message_inner(msg))
78            .await
79            .with_context(|| "timeout sending raw message")?
80    }
81
82    async fn raw_message_inner(&mut self, msg: &str) -> Result<(String, String)> {
83        let middle = serde_json::to_string(&json!({
84            "token": self.token,
85            "COMMANDS": [
86                { "COMMAND": msg, "COMMANDID": 1, }
87            ]
88        }))?;
89        let outer = json!({
90            "message_type": "hm_get_command_queue",
91            "message": middle,
92        });
93        let to_send = serde_json::to_string(&outer)?;
94
95        let conn = self.ensure_connected().await?;
96        debug!("sending: {}", to_send);
97
98        conn.feed(Message::Text(to_send)).await?;
99        conn.flush().await?;
100
101        debug!("receiving");
102        let buf = conn
103            .next()
104            .await
105            .ok_or_else(|| anyhow!("no response received to command"))?
106            .with_context(|| "unpacking websocket message")?
107            .into_data();
108        let resp: CommandResponse =
109            serde_json::from_slice(&buf).with_context(|| "JSON-deserializing response")?;
110        ensure!(
111            resp.message_type == "hm_set_command_response" && resp.command_id == 1,
112            "unexpected response type or id: {:?}",
113            resp
114        );
115        Ok((resp.device_id, resp.response))
116    }
117
118    pub async fn command_void<T: DeserializeOwned>(&mut self, command: &str) -> Result<T> {
119        let (_, resp) = self.raw_message(&serialise_void(command)).await?;
120        serde_json::from_str(&resp).with_context(|| anyhow!("reading {:?}", resp))
121    }
122
123    pub async fn command_str<T: DeserializeOwned>(
124        &mut self,
125        command: &str,
126        arg: &str,
127    ) -> Result<T> {
128        let (_, resp) = self
129            .raw_message(&format!("{{'{}':'{}'}}", command, arg))
130            .await?;
131        serde_json::from_str(&resp).with_context(|| anyhow!("reading {:?}", resp))
132    }
133
134    pub async fn identify(&mut self) -> Result<Identity> {
135        let (device_id, resp) = self
136            .raw_message(&serialise_void("FIRMWARE"))
137            .await
138            .with_context(|| "requesting FIRMWARE version")?;
139        let firmware: Value = serde_json::from_str(&resp)?;
140        Ok(Identity {
141            device_id,
142            firmware_version: firmware
143                .get("firmware version")
144                .and_then(|v| v.as_str())
145                .map(str::to_owned),
146        })
147    }
148
149    pub async fn disconnect(&mut self) -> Result<()> {
150        let conn = match self.conn.as_mut() {
151            None => return Ok(()),
152            Some(conn) => conn,
153        };
154
155        let shutdown_result = timeout(self.opts.timeout, conn.close(None))
156            .await
157            .with_context(|| "timeout disconnecting");
158
159        self.conn = None;
160
161        Ok(shutdown_result??)
162    }
163}
164
/// Builds the wire text for a command that takes no argument: `{'<command>':0}`.
///
/// `command` is inserted verbatim, without any escaping.
#[inline]
fn serialise_void(command: &str) -> String {
    const PREFIX: &str = "{'";
    const SUFFIX: &str = "':0}";

    let mut text = String::with_capacity(PREFIX.len() + command.len() + SUFFIX.len());
    text.push_str(PREFIX);
    text.push_str(command);
    text.push_str(SUFFIX);
    text
}
169
170#[derive(Deserialize, Debug)]
171struct CommandResponse {
172    // we always send a fixed value (1)
173    command_id: i64,
174
175    // mac-address-like string
176    device_id: String,
177
178    // hm_set_command_response
179    message_type: String,
180
181    // json, in a string
182    response: String,
183}
184
/// What [`Client::identify`] learns about the hub.
#[derive(Debug, Clone)]
pub struct Identity {
    /// The `device_id` reported alongside the reply to the `FIRMWARE` command.
    pub device_id: String,

    /// The reply's `"firmware version"` entry, when present as a string.
    pub firmware_version: Option<String>,
}
190
191#[derive(Deserialize, Serialize, Debug)]
192pub struct Profile {
193    // 1-..
194    #[serde(rename = "PROFILE_ID")]
195    pub profile_id: u16,
196    // 0
197    #[serde(rename = "P_TYPE")]
198    pub p_type: u16,
199    pub info: ProfileInfo,
200    pub name: String,
201}
202
203#[derive(Deserialize, Serialize, Debug)]
204pub struct ProfileInfo {
205    pub monday: ProfileInfoDay,
206    pub tuesday: ProfileInfoDay,
207    pub wednesday: ProfileInfoDay,
208    pub thursday: ProfileInfoDay,
209    pub friday: ProfileInfoDay,
210    pub saturday: ProfileInfoDay,
211    pub sunday: ProfileInfoDay,
212}
213
214type TempSpec = [Value; 4];
215
216#[derive(Deserialize, Serialize, Debug)]
217pub struct ProfileInfoDay {
218    wake: TempSpec,
219    leave: TempSpec,
220    #[serde(rename = "return")]
221    ret: TempSpec,
222    sleep: TempSpec,
223}
224
225#[derive(Debug)]
226struct IgnoreAllCertificateSecurity(WebPkiSupportedAlgorithms);
227
228impl danger::ServerCertVerifier for IgnoreAllCertificateSecurity {
229    fn verify_server_cert(
230        &self,
231        _end_entity: &CertificateDer<'_>,
232        _intermediates: &[CertificateDer<'_>],
233        _server_name: &ServerName<'_>,
234        _ocsp_response: &[u8],
235        _now: UnixTime,
236    ) -> std::result::Result<danger::ServerCertVerified, Error> {
237        Ok(danger::ServerCertVerified::assertion())
238    }
239
240    // copy-paste of WebPkiServerVerifier, the only other implementation of this trait
241    fn verify_tls12_signature(
242        &self,
243        message: &[u8],
244        cert: &CertificateDer<'_>,
245        dss: &DigitallySignedStruct,
246    ) -> std::result::Result<danger::HandshakeSignatureValid, Error> {
247        verify_tls12_signature(message, cert, dss, &self.0)
248    }
249
250    fn verify_tls13_signature(
251        &self,
252        message: &[u8],
253        cert: &CertificateDer<'_>,
254        dss: &DigitallySignedStruct,
255    ) -> std::result::Result<danger::HandshakeSignatureValid, Error> {
256        verify_tls13_signature(message, cert, dss, &self.0)
257    }
258
259    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
260        self.0.supported_schemes()
261    }
262}
263
264async fn connect(url: &str) -> Result<WsStream> {
265    debug!("attempting connection");
266    let connector = Connector::Rustls(Arc::new(
267        rustls::ClientConfig::builder()
268            .dangerous()
269            .with_custom_certificate_verifier(Arc::new(IgnoreAllCertificateSecurity(
270                default_provider().signature_verification_algorithms,
271            )))
272            .with_no_client_auth(),
273    ));
274    let (conn, _) = connect_async_tls_with_config(url, None, true, Some(connector)).await?;
275    debug!("connected");
276    Ok(conn)
277}
278
279fn env_var(key: &'static str) -> Result<String> {
280    std::env::var(key).with_context(|| anyhow!("env var required: {key:?}"))
281}