1use base64::prelude::*;
2use futures::sink::SinkExt;
3use log::debug;
4use log::error;
5use log::info;
6use log::trace;
7use noise_protocol::CipherState;
8use noise_protocol::ErrorKind;
9use noise_protocol::HandshakeState;
10use noise_protocol::patterns::noise_nn_psk0;
11use noise_rust_crypto::ChaCha20Poly1305;
12use noise_rust_crypto::Sha256;
13use noise_rust_crypto::X25519;
14use std::sync::Arc;
15use std::sync::atomic::AtomicBool;
16use tokio::io::AsyncWriteExt;
17use tokio::net::TcpStream;
18use tokio::net::tcp::OwnedWriteHalf;
19use tokio::sync::Mutex;
20use tokio::sync::broadcast;
21use tokio::sync::mpsc;
22use tokio::sync::oneshot;
23use tokio_stream::StreamExt;
24use tokio_util::codec::FramedRead;
25use tokio_util::codec::FramedWrite;
26use typed_builder::TypedBuilder;
27
28use crate::frame::FrameCodec;
29use crate::packet_encrypted;
30use crate::packet_plaintext;
31use crate::parser::ProtoMessage;
32use crate::proto::{self, DeviceInfoResponse, DisconnectResponse, HelloResponse, PingResponse};
33
34async fn write_error_and_disconnect(
35 mut writer: FramedWrite<OwnedWriteHalf, FrameCodec>,
36 message: &str,
37) {
38 error!("API Failure: {}. Disconnecting.", message);
39 let packet = [[1].to_vec(), message.as_bytes().to_vec()].concat();
40 writer.send(packet).await.unwrap();
41 writer.flush().await.unwrap();
42 let mut tcp_write = writer.into_inner();
43 if let Err(err) = tcp_write.shutdown().await {
44 error!("failed to shutdown socket: {:?}", err);
45 }
46}
47
/// Rejection reason sent when the client's Noise handshake message fails to
/// authenticate (reported when the handshake read fails with a decryption error).
const ERROR_HANDSHAKE_MAC_FAILURE: &str = "Handshake MAC failure";
/// Rejection reason sent to a plaintext client when an encryption key is configured.
const ERROR_ONLY_ENCRYPTED: &str = "Only key encryption is enabled";
50
51#[derive(TypedBuilder, Clone)]
52pub struct EspHomeApi {
53 #[builder(default=Arc::new(AtomicBool::new(false)))]
55 pub(crate) first_message_received: Arc<AtomicBool>,
56
57 #[builder(default=Arc::new(AtomicBool::new(true)))]
58 pub(crate) plaintext_communication: Arc<AtomicBool>,
59
60 #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
61 pub(crate) encrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
62 #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
63 pub(crate) decrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
64
65 name: String,
66
67 #[builder(default = None, setter(strip_option(fallback=encryption_key_opt)))]
68 encryption_key: Option<String>,
69
70 #[builder(default = 1)]
71 api_version_major: u32,
72 #[builder(default = 10)]
73 api_version_minor: u32,
74 #[builder(default="Rust: esphome-native-api".to_string())]
75 server_info: String,
76
77 #[builder(default = None, setter(strip_option(fallback=friendly_name_opt)))]
78 friendly_name: Option<String>,
79
80 #[builder(default = None, setter(strip_option(fallback=mac_opt)))]
81 mac: Option<String>,
82
83 #[builder(default = None, setter(strip_option(fallback=model_opt)))]
84 model: Option<String>,
85
86 #[builder(default = None, setter(strip_option(fallback=manufacturer_opt)))]
87 manufacturer: Option<String>,
88 #[builder(default = None, setter(strip_option(fallback=suggested_area_opt)))]
89 suggested_area: Option<String>,
90 #[builder(default = None, setter(strip_option(fallback=bluetooth_mac_address_opt)))]
91 bluetooth_mac_address: Option<String>,
92
93 #[builder(default = None, setter(strip_option(fallback=project_name_opt)))]
94 project_name: Option<String>,
95
96 #[builder(default = None, setter(strip_option(fallback=project_version_opt)))]
97 project_version: Option<String>,
98 #[builder(default = None, setter(strip_option(fallback=compilation_time_opt)))]
99 compilation_time: Option<String>,
100
101 #[builder(default = 0)]
102 legacy_bluetooth_proxy_version: u32,
103 #[builder(default = 0)]
104 bluetooth_proxy_feature_flags: u32,
105 #[builder(default = 0)]
106 legacy_voice_assistant_version: u32,
107 #[builder(default = 0)]
108 voice_assistant_feature_flags: u32,
109}
110
111impl EspHomeApi {
113 pub async fn start(
116 &mut self,
117 tcp_stream: TcpStream,
118 ) -> Result<
119 (
120 mpsc::Sender<ProtoMessage>,
121 broadcast::Receiver<ProtoMessage>,
122 ),
123 Box<dyn std::error::Error>,
124 > {
125 let (answer_messages_tx, mut answer_messages_rx) = mpsc::channel::<ProtoMessage>(16);
127 let (outgoing_messages_tx, outgoing_messages_rx) = broadcast::channel::<ProtoMessage>(16);
128
129 #[allow(deprecated)]
130 let device_info = DeviceInfoResponse {
131 api_encryption_supported: self.encryption_key.is_some(),
132 uses_password: false,
133 name: self.name.clone(),
134 mac_address: self.mac.clone().unwrap_or_default(),
135 esphome_version: proto::VERSION.to_owned(),
136 compilation_time: self.compilation_time.clone().unwrap_or_default(),
137 model: self.model.clone().unwrap_or_default(),
138 has_deep_sleep: false,
139 project_name: self.project_name.clone().unwrap_or_default(),
140 project_version: self.project_version.clone().unwrap_or_default(),
141 webserver_port: 0,
142 legacy_bluetooth_proxy_version: self.legacy_bluetooth_proxy_version,
144 bluetooth_proxy_feature_flags: self.bluetooth_proxy_feature_flags,
145 manufacturer: self.manufacturer.clone().unwrap_or_default(),
146 friendly_name: self.friendly_name.clone().unwrap_or(self.name.clone()),
147 legacy_voice_assistant_version: self.legacy_voice_assistant_version,
148 voice_assistant_feature_flags: self.voice_assistant_feature_flags,
149 suggested_area: self.suggested_area.clone().unwrap_or_default(),
150 bluetooth_mac_address: self.bluetooth_mac_address.clone().unwrap_or_default(),
151 areas: vec![],
152 devices: vec![],
153 area: None,
154 zwave_proxy_feature_flags: 0,
155 zwave_home_id: 0,
156 };
157
158 let hello_response = HelloResponse {
159 api_version_major: self.api_version_major,
160 api_version_minor: self.api_version_minor,
161 server_info: self.server_info.clone(),
162 name: self.name.clone(),
163 };
164
165 let encrypt_cypher_clone = self.encrypt_cypher.clone();
166 let decrypt_cypher_clone = self.decrypt_cypher.clone();
167
168 trace!("Init Connection: Stage 1");
170 let encryption_key = self.encryption_key.clone();
171
172 let mut buf = vec![0; 1];
173 let n = tcp_stream
174 .peek(&mut buf)
175 .await
176 .expect("failed to read data from socket");
177
178 if n == 0 {
179 return Err("No data".into());
180 }
181
182 trace!("TCP Peeked: {:02X?}", &buf[0..n]);
183
184 let preamble = buf[0] as usize;
185
186 let first_message_received = self
187 .first_message_received
188 .load(std::sync::atomic::Ordering::Relaxed);
189
190 if !first_message_received {
191 match preamble {
192 0 => {
193 debug!("Cleartext messaging");
194
195 self.plaintext_communication
196 .store(true, std::sync::atomic::Ordering::Relaxed);
197 }
198 1 => {
199 trace!("Encrypted messaging");
200
201 self.plaintext_communication
202 .store(false, std::sync::atomic::Ordering::Relaxed);
203 }
204 _ => {
205 return Err(format!("Invalid marker byte {}", preamble).into());
206 }
207 }
208 self.first_message_received
209 .store(true, std::sync::atomic::Ordering::Relaxed);
210 }
211
212 let plaintext_communication = self
213 .plaintext_communication
214 .load(std::sync::atomic::Ordering::Relaxed);
215 let encrypted = !plaintext_communication;
216
217 let (tcp_read, tcp_write) = tcp_stream.into_split();
218 let decoder = FrameCodec::new(encrypted);
219 let encoder = FrameCodec::new(encrypted);
220 let mut reader = FramedRead::new(tcp_read, decoder);
221 let mut writer = FramedWrite::new(tcp_write, encoder);
222
223 if plaintext_communication {
224 if self.encryption_key.is_some() {
225 let encoder = FrameCodec::new(true);
226 let writer = FramedWrite::new(writer.into_inner(), encoder);
227 write_error_and_disconnect(writer, ERROR_ONLY_ENCRYPTED).await;
228 return Err(ERROR_ONLY_ENCRYPTED.into());
229 }
230 } else {
231 if self.encryption_key.is_none() {
232 write_error_and_disconnect(writer, "No encrypted communication allowed").await;
233 return Err("No encryption key set, but encrypted communication requested.".into());
234 }
235
236 let frame_noise_hello = reader.next().await.unwrap().unwrap();
237 debug!("Frame 1: {:02X?}", &frame_noise_hello);
238
239 let message_server_hello =
240 packet_encrypted::generate_server_hello_frame(self.name.clone(), self.mac.clone());
241
242 writer.send(message_server_hello.clone()).await.unwrap();
243 writer.flush().await.unwrap();
244
245 let frame_handshake_request = reader.next().await.unwrap().unwrap();
246 debug!("Frame 2: {:02X?}", &frame_handshake_request);
247
248 let mut handshake_state: HandshakeState<X25519, ChaCha20Poly1305, Sha256> =
250 HandshakeState::new(
251 noise_nn_psk0(),
252 false,
253 b"NoiseAPIInit\0\0",
255 None,
256 None,
257 None,
258 None,
259 );
260
261 let noise_psk = BASE64_STANDARD
262 .decode(encryption_key.as_ref().unwrap())
263 .unwrap();
264
265 handshake_state.push_psk(&noise_psk);
266 match handshake_state.read_message_vec(&frame_handshake_request[1..]) {
268 Ok(_) => {}
269 Err(e) => match e.kind() {
270 ErrorKind::Decryption => {
271 write_error_and_disconnect(writer, ERROR_HANDSHAKE_MAC_FAILURE).await;
272 return Err(ERROR_HANDSHAKE_MAC_FAILURE.into());
273 }
274 _ => {
275 debug!("Failed to read message: {}", e);
276 }
277 },
278 }
279
280 let out = handshake_state.write_message_vec(b"").unwrap();
281 {
282 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
283 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
284 let (decrypt_cipher, encrypt_cipher) = handshake_state.get_ciphers();
285 *encrypt_cipher_changer = Some(encrypt_cipher);
286 *decrypt_cipher_changer = Some(decrypt_cipher);
287 }
288
289 let mut message_handshake = vec![0];
290 message_handshake.extend(out);
291
292 debug!("Sending handshake");
293 writer.send(message_handshake.clone()).await.unwrap();
294 writer.flush().await.unwrap();
295 }
296
297 debug!("Initialization done.");
298
299 let (cancellation_write_tx, mut cancellation_write_rx) = oneshot::channel();
301
302 let plaintext_communication = self.plaintext_communication.clone();
304 tokio::spawn(async move {
305 loop {
306 let answer_message: ProtoMessage;
307
308 tokio::select! {
310 biased; cancel_message = &mut cancellation_write_rx => {
312 debug!("Write loop received cancellation signal ({}), exiting.", cancel_message.unwrap());
313 break;
314 }
315 message = answer_messages_rx.recv() => {
316 answer_message = message.unwrap();
317 }
318 };
319
320 debug!("Answer message: {:?}", answer_message);
321
322 if plaintext_communication.load(std::sync::atomic::Ordering::Relaxed) {
323 writer
324 .send(packet_plaintext::message_to_packet(&answer_message).unwrap())
325 .await
326 .unwrap();
327 } else {
330 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
332 writer
333 .send(
334 packet_encrypted::message_to_packet(
335 &answer_message,
336 &mut *encrypt_cipher_changer.as_mut().unwrap(),
337 )
338 .unwrap(),
339 )
340 .await
341 .unwrap();
342 }
343 writer.flush().await.unwrap();
344
345 if matches!(answer_message, ProtoMessage::DisconnectResponse(_)) {
346 debug!("Disconnecting");
347 let mut tcp_write = writer.into_inner();
348 match tcp_write.shutdown().await {
349 Err(err) => {
350 error!("failed to shutdown socket: {:?}", err);
351 break;
352 }
353 _ => break,
354 }
355 }
356 }
357 });
358
359 let answer_messages_tx_clone = answer_messages_tx.clone();
361 let decrypt_cypher_clone = self.decrypt_cypher.clone();
362 tokio::spawn(async move {
364 loop {
365 let next = reader.next().await;
366 if next.is_none() {
367 info!("Read loop stopped because stream finished");
368 let _ = cancellation_write_tx.send("read loop finished");
370 break;
371 }
372 let frame = next.unwrap().unwrap();
373 trace!("TCP Receive: {:02X?}", &frame);
374
375 let message;
376 if encrypted {
377 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
378 message = packet_encrypted::packet_to_message(
379 &frame,
380 &mut *decrypt_cipher_changer.as_mut().unwrap(),
381 )
382 .unwrap();
383 } else {
384 message = packet_plaintext::packet_to_message(&frame).unwrap();
385 }
386
387 match &message {
389 ProtoMessage::DisconnectRequest(disconnect_request) => {
390 debug!("DisconnectRequest: {:?}", disconnect_request);
391 let response_message = DisconnectResponse {};
392 answer_messages_tx_clone
393 .send(ProtoMessage::DisconnectResponse(response_message))
394 .await
395 .unwrap();
396 continue;
397 }
398 ProtoMessage::PingRequest(ping_request) => {
399 debug!("PingRequest: {:?}", ping_request);
400 let response_message = PingResponse {};
401 answer_messages_tx_clone
402 .send(ProtoMessage::PingResponse(response_message))
403 .await
404 .unwrap();
405 }
406 ProtoMessage::DeviceInfoRequest(device_info_request) => {
407 debug!("DeviceInfoRequest: {:?}", device_info_request);
408 answer_messages_tx_clone
409 .send(ProtoMessage::DeviceInfoResponse(device_info.clone()))
410 .await
411 .unwrap();
412 }
413 ProtoMessage::HelloRequest(hello_request) => {
414 debug!("HelloRequest: {:?}", hello_request);
415
416 answer_messages_tx_clone
417 .send(ProtoMessage::HelloResponse(hello_response.clone()))
418 .await
419 .unwrap();
420 }
421 ProtoMessage::AuthenticationRequest(_) => {
422 info!("Password Authentication is not supported");
423 }
424 message => {
425 outgoing_messages_tx.send(message.clone()).unwrap();
426 }
427 }
428 }
429 });
430
431 Ok((answer_messages_tx.clone(), outgoing_messages_rx))
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// A builder given only a name yields the declared builder defaults.
    #[test]
    fn test_basic_server_instantiation() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();

        assert_eq!(api.name, "test_device");
        assert!(api.encryption_key.is_none());
        assert!(api.friendly_name.is_none());
        assert_eq!(api.api_version_major, 1);
        assert_eq!(api.api_version_minor, 10);
        assert_eq!(api.server_info, "Rust: esphome-native-api");
        assert!(!api.first_message_received.load(Ordering::Relaxed));
        assert!(api.plaintext_communication.load(Ordering::Relaxed));
    }

    /// The `encryption_key` setter stores the key string unchanged.
    #[test]
    fn test_encryption_key_is_stored() {
        let key = "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=";
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .encryption_key(key.to_string())
            .build();

        assert_eq!(api.encryption_key.as_deref(), Some(key));
    }
}