1pub mod commands;
2mod live_data;
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::{anyhow, ensure, Context, Result};
8use futures_util::{SinkExt, StreamExt};
9use log::debug;
10use rustls::client::danger;
11use rustls::crypto::ring::default_provider;
12use rustls::crypto::{verify_tls12_signature, verify_tls13_signature, WebPkiSupportedAlgorithms};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, Error, SignatureScheme};
15use serde::de::DeserializeOwned;
16use serde::Deserialize;
17use serde::Serialize;
18use serde_json::{json, Value};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21use tokio_tungstenite::tungstenite::Message;
22use tokio_tungstenite::{
23 connect_async_tls_with_config, Connector, MaybeTlsStream, WebSocketStream,
24};
25
26pub use live_data::LiveData;
27
/// Websocket stream produced by `connect` and held in `Client::conn`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
29
/// Websocket client for the hub's command API.
///
/// The connection is opened lazily by the first command and then reused by
/// later ones.
pub struct Client {
    /// Websocket URL handed to `connect`.
    url: String,
    /// Token embedded in every command envelope sent to the hub.
    token: String,
    /// Open connection; `None` before the first command and after `disconnect`.
    conn: Option<WsStream>,
    /// Tunables such as the per-command timeout.
    opts: Opts,
}
36
/// Tunables for a [`Client`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use a
/// struct literal. Start from [`Opts::default()`] and assign the public fields
/// you want to change:
///
/// ```ignore
/// let mut opts = Opts::default();
/// opts.timeout = Duration::from_secs(5);
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Opts {
    /// Upper bound on each command round-trip made by `Client::raw_message`.
    /// This includes connecting first, if there is no open connection.
    /// It is also the upper bound on `Client::disconnect`.
    pub timeout: Duration,
}

impl Default for Opts {
    /// Returns options with a 15 second `timeout`.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(15),
        }
    }
}
49
50impl Client {
51 pub fn from_env() -> Result<Self> {
52 Self::new(env_var("NEOHUB_URL")?, env_var("NEOHUB_TOKEN")?)
53 }
54
55 pub fn new(url: impl ToString, token: impl ToString) -> Result<Self> {
56 Self::new_opts(url, token, Opts::default())
57 }
58
59 pub fn new_opts(url: impl ToString, token: impl ToString, opts: Opts) -> Result<Self> {
60 Ok(Client {
61 url: url.to_string(),
62 token: token.to_string(),
63 conn: None,
64 opts,
65 })
66 }
67
68 #[inline]
69 async fn ensure_connected(&mut self) -> Result<&mut WsStream> {
70 if self.conn.is_none() {
71 self.conn = Some(connect(&self.url).await?);
72 }
73 Ok(self.conn.as_mut().expect("we just set it"))
74 }
75
76 pub async fn raw_message(&mut self, msg: &str) -> Result<(String, String)> {
77 timeout(self.opts.timeout, self.raw_message_inner(msg))
78 .await
79 .with_context(|| "timeout sending raw message")?
80 }
81
82 async fn raw_message_inner(&mut self, msg: &str) -> Result<(String, String)> {
83 let middle = serde_json::to_string(&json!({
84 "token": self.token,
85 "COMMANDS": [
86 { "COMMAND": msg, "COMMANDID": 1, }
87 ]
88 }))?;
89 let outer = json!({
90 "message_type": "hm_get_command_queue",
91 "message": middle,
92 });
93 let to_send = serde_json::to_string(&outer)?;
94
95 let conn = self.ensure_connected().await?;
96 debug!("sending: {}", to_send);
97
98 conn.feed(Message::Text(to_send)).await?;
99 conn.flush().await?;
100
101 debug!("receiving");
102 let buf = conn
103 .next()
104 .await
105 .ok_or_else(|| anyhow!("no response received to command"))?
106 .with_context(|| "unpacking websocket message")?
107 .into_data();
108 let resp: CommandResponse =
109 serde_json::from_slice(&buf).with_context(|| "JSON-deserializing response")?;
110 ensure!(
111 resp.message_type == "hm_set_command_response" && resp.command_id == 1,
112 "unexpected response type or id: {:?}",
113 resp
114 );
115 Ok((resp.device_id, resp.response))
116 }
117
118 pub async fn command_void<T: DeserializeOwned>(&mut self, command: &str) -> Result<T> {
119 let (_, resp) = self.raw_message(&serialise_void(command)).await?;
120 serde_json::from_str(&resp).with_context(|| anyhow!("reading {:?}", resp))
121 }
122
123 pub async fn command_str<T: DeserializeOwned>(
124 &mut self,
125 command: &str,
126 arg: &str,
127 ) -> Result<T> {
128 let (_, resp) = self
129 .raw_message(&format!("{{'{}':'{}'}}", command, arg))
130 .await?;
131 serde_json::from_str(&resp).with_context(|| anyhow!("reading {:?}", resp))
132 }
133
134 pub async fn identify(&mut self) -> Result<Identity> {
135 let (device_id, resp) = self
136 .raw_message(&serialise_void("FIRMWARE"))
137 .await
138 .with_context(|| "requesting FIRMWARE version")?;
139 let firmware: Value = serde_json::from_str(&resp)?;
140 Ok(Identity {
141 device_id,
142 firmware_version: firmware
143 .get("firmware version")
144 .and_then(|v| v.as_str())
145 .map(str::to_owned),
146 })
147 }
148
149 pub async fn disconnect(&mut self) -> Result<()> {
150 let conn = match self.conn.as_mut() {
151 None => return Ok(()),
152 Some(conn) => conn,
153 };
154
155 let shutdown_result = timeout(self.opts.timeout, conn.close(None))
156 .await
157 .with_context(|| "timeout disconnecting");
158
159 self.conn = None;
160
161 Ok(shutdown_result??)
162 }
163}
164
/// Renders a no-argument hub command in the single-quoted form `{'COMMAND':0}`.
///
/// The quoting is not standard JSON, so the string is assembled by hand
/// rather than through serde.
fn serialise_void(command: &str) -> String {
    let mut rendered = String::with_capacity(command.len() + 6);
    rendered.push_str("{'");
    rendered.push_str(command);
    rendered.push_str("':0}");
    rendered
}
169
/// Envelope the hub sends back for a command. It is decoded by
/// `raw_message_inner`.
///
/// No `#[serde(rename)]` is applied, so the JSON keys must match the field
/// names exactly.
#[derive(Deserialize, Debug)]
struct CommandResponse {
    /// Echo of the `COMMANDID` that was sent; `raw_message_inner` requires `1`.
    command_id: i64,

    /// Identifier of the replying device; surfaced as `Identity::device_id`.
    device_id: String,

    /// Envelope kind; `raw_message_inner` requires `"hm_set_command_response"`.
    message_type: String,

    /// The command's result, itself a JSON document carried as a string
    /// (callers parse it again with `serde_json::from_str`).
    response: String,
}
184
/// Identification details returned by `Client::identify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Identity {
    /// The `device_id` from the hub's command-response envelope.
    pub device_id: String,
    /// The `"firmware version"` entry of the `FIRMWARE` reply. It is `None`
    /// when that entry is missing or is not a JSON string.
    pub firmware_version: Option<String>,
}
190
/// A named weekly schedule profile, (de)serialised with serde.
#[derive(Deserialize, Serialize, Debug)]
pub struct Profile {
    /// Serialised under the key `PROFILE_ID`.
    #[serde(rename = "PROFILE_ID")]
    pub profile_id: u16,
    /// Serialised under the key `P_TYPE`.
    /// NOTE(review): meaning of the values is not visible here — confirm.
    #[serde(rename = "P_TYPE")]
    pub p_type: u16,
    /// Per-weekday schedule entries.
    pub info: ProfileInfo,
    /// Human-readable profile name.
    pub name: String,
}
202
/// The schedule of a [`Profile`]: one [`ProfileInfoDay`] per weekday, keyed by
/// the lowercase weekday name.
#[derive(Deserialize, Serialize, Debug)]
pub struct ProfileInfo {
    pub monday: ProfileInfoDay,
    pub tuesday: ProfileInfoDay,
    pub wednesday: ProfileInfoDay,
    pub thursday: ProfileInfoDay,
    pub friday: ProfileInfoDay,
    pub saturday: ProfileInfoDay,
    pub sunday: ProfileInfoDay,
}
213
/// One schedule entry: exactly four raw JSON values, left undecoded here.
/// NOTE(review): presumably a time plus temperature settings — confirm against
/// the hub's API documentation.
type TempSpec = [Value; 4];
215
216#[derive(Deserialize, Serialize, Debug)]
217pub struct ProfileInfoDay {
218 wake: TempSpec,
219 leave: TempSpec,
220 #[serde(rename = "return")]
221 ret: TempSpec,
222 sleep: TempSpec,
223}
224
/// A rustls server-certificate verifier that accepts every certificate.
///
/// SECURITY: `verify_server_cert` performs no checks at all (no chain,
/// validity-period or hostname validation). The server is therefore not
/// authenticated, and anyone able to intercept the traffic can impersonate it.
/// Handshake signatures are still verified against whatever certificate was
/// presented, using the wrapped algorithm set.
/// NOTE(review): presumably required because the hub uses a self-signed
/// certificate — confirm, and consider pinning that certificate instead.
#[derive(Debug)]
struct IgnoreAllCertificateSecurity(WebPkiSupportedAlgorithms);

impl danger::ServerCertVerifier for IgnoreAllCertificateSecurity {
    /// Unconditionally accepts the presented certificate chain.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<danger::ServerCertVerified, Error> {
        Ok(danger::ServerCertVerified::assertion())
    }

    /// Verifies a TLS 1.2 handshake signature with rustls' standard helper.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<danger::HandshakeSignatureValid, Error> {
        verify_tls12_signature(message, cert, dss, &self.0)
    }

    /// Verifies a TLS 1.3 handshake signature with rustls' standard helper.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<danger::HandshakeSignatureValid, Error> {
        verify_tls13_signature(message, cert, dss, &self.0)
    }

    /// Advertises the signature schemes supported by the wrapped algorithm set.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.0.supported_schemes()
    }
}
263
264async fn connect(url: &str) -> Result<WsStream> {
265 debug!("attempting connection");
266 let connector = Connector::Rustls(Arc::new(
267 rustls::ClientConfig::builder()
268 .dangerous()
269 .with_custom_certificate_verifier(Arc::new(IgnoreAllCertificateSecurity(
270 default_provider().signature_verification_algorithms,
271 )))
272 .with_no_client_auth(),
273 ));
274 let (conn, _) = connect_async_tls_with_config(url, None, true, Some(connector)).await?;
275 debug!("connected");
276 Ok(conn)
277}
278
279fn env_var(key: &'static str) -> Result<String> {
280 std::env::var(key).with_context(|| anyhow!("env var required: {key:?}"))
281}