1use base64::prelude::*;
49use futures::sink::SinkExt;
50use log::debug;
51use log::error;
52use log::info;
53use log::trace;
54use noise_protocol::CipherState;
55use noise_protocol::ErrorKind;
56use noise_protocol::HandshakeState;
57use noise_protocol::patterns::noise_nn_psk0;
58use noise_rust_crypto::ChaCha20Poly1305;
59use noise_rust_crypto::Sha256;
60use noise_rust_crypto::X25519;
61use std::sync::Arc;
62use std::sync::atomic::AtomicBool;
63use tokio::io::AsyncWriteExt;
64use tokio::net::TcpStream;
65use tokio::net::tcp::OwnedWriteHalf;
66use tokio::sync::Mutex;
67use tokio::sync::broadcast;
68use tokio::sync::mpsc;
69use tokio::sync::oneshot;
70use tokio_stream::StreamExt;
71use tokio_util::codec::FramedRead;
72use tokio_util::codec::FramedWrite;
73use typed_builder::TypedBuilder;
74
75use crate::frame::FrameCodec;
76use crate::packet_encrypted;
77use crate::packet_plaintext;
78use crate::parser::ProtoMessage;
79use crate::proto::{
80 self, AuthenticationResponse, DeviceInfoResponse, DisconnectResponse, HelloResponse,
81 PingResponse,
82};
83
84async fn write_error_and_disconnect(
85 mut writer: FramedWrite<OwnedWriteHalf, FrameCodec>,
86 message: &str,
87) {
88 error!("API Failure: {}. Disconnecting.", message);
89 let packet = [[1].to_vec(), message.as_bytes().to_vec()].concat();
90 writer.send(packet).await.unwrap();
91 writer.flush().await.unwrap();
92 let mut tcp_write = writer.into_inner();
93 if let Err(err) = tcp_write.shutdown().await {
94 error!("failed to shutdown socket: {:?}", err);
95 }
96}
97
/// Error-frame text sent (and returned as the error) when a plaintext client connects while an
/// encryption key is configured.
const ERROR_ONLY_ENCRYPTED: &str = "Only key encryption is enabled";
/// Error-frame text sent (and returned as the error) when reading the client's Noise handshake
/// message fails with `ErrorKind::Decryption`.
const ERROR_HANDSHAKE_MAC_FAILURE: &str = "Handshake MAC failure";
100
/// Configuration and shared connection state for an ESPHome native API endpoint.
///
/// Built with the `TypedBuilder`-generated `EspHomeApi::builder()`; `name` is the only field
/// without a builder default. The connection state (`first_message_received`,
/// `plaintext_communication` and both cipher states) lives behind `Arc`s, so clones of an
/// `EspHomeApi` share it.
#[derive(TypedBuilder, Clone)]
pub struct EspHomeApi {
    /// Set once `start` has classified the transport from a connection's first byte. While it is
    /// set, later `start` calls reuse the recorded mode instead of re-detecting it.
    #[builder(default=Arc::new(AtomicBool::new(false)))]
    pub(crate) first_message_received: Arc<AtomicBool>,

    /// `true` for plaintext framing, `false` for Noise-encrypted framing (defaults to `true`).
    #[builder(default=Arc::new(AtomicBool::new(true)))]
    pub(crate) plaintext_communication: Arc<AtomicBool>,

    /// Cipher state for outgoing frames; `None` until the Noise handshake in `start` installs it.
    /// Not settable through the builder.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) encrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
    /// Cipher state for incoming frames; `None` until the Noise handshake in `start` installs it.
    /// Not settable through the builder.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) decrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,

    /// Device name reported in `HelloResponse` / `DeviceInfoResponse` and used for the server
    /// hello frame of the encrypted handshake.
    name: String,

    /// Base64-encoded Noise pre-shared key. When set, plaintext clients are rejected; when
    /// `None`, encrypted clients are rejected.
    #[builder(default = None, setter(strip_option(fallback=encryption_key_opt)))]
    encryption_key: Option<String>,

    /// Major API version reported in `HelloResponse`.
    #[builder(default = 1)]
    api_version_major: u32,
    /// Minor API version reported in `HelloResponse`.
    #[builder(default = 10)]
    api_version_minor: u32,
    /// Server identification string reported in `HelloResponse`.
    #[builder(default="Rust: esphome-native-api".to_string())]
    server_info: String,

    /// Friendly name for `DeviceInfoResponse`; falls back to `name` when `None`.
    #[builder(default = None, setter(strip_option(fallback=friendly_name_opt)))]
    friendly_name: Option<String>,

    /// MAC address for `DeviceInfoResponse` (empty string when `None`); also passed to the
    /// server hello frame of the encrypted handshake.
    #[builder(default = None, setter(strip_option(fallback=mac_opt)))]
    mac: Option<String>,

    /// Model reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=model_opt)))]
    model: Option<String>,

    /// Manufacturer reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=manufacturer_opt)))]
    manufacturer: Option<String>,
    /// Suggested area reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=suggested_area_opt)))]
    suggested_area: Option<String>,
    /// Bluetooth MAC address reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=bluetooth_mac_address_opt)))]
    bluetooth_mac_address: Option<String>,

    /// Project name reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=project_name_opt)))]
    project_name: Option<String>,

    /// Project version reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=project_version_opt)))]
    project_version: Option<String>,
    /// Compilation time reported in `DeviceInfoResponse` (empty string when `None`).
    #[builder(default = None, setter(strip_option(fallback=compilation_time_opt)))]
    compilation_time: Option<String>,

    /// Reported verbatim as `DeviceInfoResponse::legacy_bluetooth_proxy_version`.
    #[builder(default = 0)]
    legacy_bluetooth_proxy_version: u32,
    /// Reported verbatim as `DeviceInfoResponse::bluetooth_proxy_feature_flags`.
    #[builder(default = 0)]
    bluetooth_proxy_feature_flags: u32,
    /// Reported verbatim as `DeviceInfoResponse::legacy_voice_assistant_version`.
    #[builder(default = 0)]
    legacy_voice_assistant_version: u32,
    /// Reported verbatim as `DeviceInfoResponse::voice_assistant_feature_flags`.
    #[builder(default = 0)]
    voice_assistant_feature_flags: u32,
}
195
196impl EspHomeApi {
198 pub async fn start(
237 &mut self,
238 tcp_stream: TcpStream,
239 ) -> Result<
240 (
241 mpsc::Sender<ProtoMessage>,
242 broadcast::Receiver<ProtoMessage>,
243 ),
244 Box<dyn std::error::Error>,
245 > {
246 let (answer_messages_tx, mut answer_messages_rx) = mpsc::channel::<ProtoMessage>(16);
248 let (outgoing_messages_tx, outgoing_messages_rx) = broadcast::channel::<ProtoMessage>(16);
249
250 #[allow(deprecated)]
251 let device_info = DeviceInfoResponse {
252 api_encryption_supported: self.encryption_key.is_some(),
253 uses_password: false,
254 name: self.name.clone(),
255 mac_address: self.mac.clone().unwrap_or_default(),
256 esphome_version: proto::VERSION.to_owned(),
257 compilation_time: self.compilation_time.clone().unwrap_or_default(),
258 model: self.model.clone().unwrap_or_default(),
259 has_deep_sleep: false,
260 project_name: self.project_name.clone().unwrap_or_default(),
261 project_version: self.project_version.clone().unwrap_or_default(),
262 webserver_port: 0,
263 legacy_bluetooth_proxy_version: self.legacy_bluetooth_proxy_version,
265 bluetooth_proxy_feature_flags: self.bluetooth_proxy_feature_flags,
266 manufacturer: self.manufacturer.clone().unwrap_or_default(),
267 friendly_name: self.friendly_name.clone().unwrap_or(self.name.clone()),
268 legacy_voice_assistant_version: self.legacy_voice_assistant_version,
269 voice_assistant_feature_flags: self.voice_assistant_feature_flags,
270 suggested_area: self.suggested_area.clone().unwrap_or_default(),
271 bluetooth_mac_address: self.bluetooth_mac_address.clone().unwrap_or_default(),
272 areas: vec![],
273 devices: vec![],
274 area: None,
275 zwave_proxy_feature_flags: 0,
276 zwave_home_id: 0,
277 };
278
279 let hello_response = HelloResponse {
280 api_version_major: self.api_version_major,
281 api_version_minor: self.api_version_minor,
282 server_info: self.server_info.clone(),
283 name: self.name.clone(),
284 };
285
286 let encrypt_cypher_clone = self.encrypt_cypher.clone();
287 let decrypt_cypher_clone = self.decrypt_cypher.clone();
288
289 trace!("Init Connection: Stage 1");
291 let encryption_key = self.encryption_key.clone();
292
293 let mut buf = vec![0; 1];
294 let n = tcp_stream
295 .peek(&mut buf)
296 .await
297 .expect("failed to read data from socket");
298
299 if n == 0 {
300 return Err("No data".into());
301 }
302
303 trace!("TCP Peeked: {:02X?}", &buf[0..n]);
304
305 let preamble = buf[0] as usize;
306
307 let first_message_received = self
308 .first_message_received
309 .load(std::sync::atomic::Ordering::Relaxed);
310
311 if !first_message_received {
312 match preamble {
313 0 => {
314 debug!("Cleartext messaging");
315
316 self.plaintext_communication
317 .store(true, std::sync::atomic::Ordering::Relaxed);
318 }
319 1 => {
320 trace!("Encrypted messaging");
321
322 self.plaintext_communication
323 .store(false, std::sync::atomic::Ordering::Relaxed);
324 }
325 _ => {
326 return Err(format!("Invalid marker byte {}", preamble).into());
327 }
328 }
329 self.first_message_received
330 .store(true, std::sync::atomic::Ordering::Relaxed);
331 }
332
333 let plaintext_communication = self
334 .plaintext_communication
335 .load(std::sync::atomic::Ordering::Relaxed);
336 let encrypted = !plaintext_communication;
337
338 let (tcp_read, tcp_write) = tcp_stream.into_split();
339 let decoder = FrameCodec::new(encrypted);
340 let encoder = FrameCodec::new(encrypted);
341 let mut reader = FramedRead::new(tcp_read, decoder);
342 let mut writer = FramedWrite::new(tcp_write, encoder);
343
344 if plaintext_communication {
345 if self.encryption_key.is_some() {
346 let encoder = FrameCodec::new(true);
347 let writer = FramedWrite::new(writer.into_inner(), encoder);
348 write_error_and_disconnect(writer, ERROR_ONLY_ENCRYPTED).await;
349 return Err(ERROR_ONLY_ENCRYPTED.into());
350 }
351 } else {
352 if self.encryption_key.is_none() {
353 write_error_and_disconnect(writer, "No encrypted communication allowed").await;
354 return Err("No encryption key set, but encrypted communication requested.".into());
355 }
356
357 let frame_noise_hello = reader.next().await.unwrap().unwrap();
358 debug!("Frame 1: {:02X?}", &frame_noise_hello);
359
360 let message_server_hello =
361 packet_encrypted::generate_server_hello_frame(self.name.clone(), self.mac.clone());
362
363 writer.send(message_server_hello.clone()).await.unwrap();
364 writer.flush().await.unwrap();
365
366 let frame_handshake_request = reader.next().await.unwrap().unwrap();
367 debug!("Frame 2: {:02X?}", &frame_handshake_request);
368
369 let mut handshake_state: HandshakeState<X25519, ChaCha20Poly1305, Sha256> =
371 HandshakeState::new(
372 noise_nn_psk0(),
373 false,
374 b"NoiseAPIInit\0\0",
376 None,
377 None,
378 None,
379 None,
380 );
381
382 let noise_psk = BASE64_STANDARD
383 .decode(encryption_key.as_ref().unwrap())
384 .unwrap();
385
386 handshake_state.push_psk(&noise_psk);
387 match handshake_state.read_message_vec(&frame_handshake_request[1..]) {
389 Ok(_) => {}
390 Err(e) => match e.kind() {
391 ErrorKind::Decryption => {
392 write_error_and_disconnect(writer, ERROR_HANDSHAKE_MAC_FAILURE).await;
393 return Err(ERROR_HANDSHAKE_MAC_FAILURE.into());
394 }
395 _ => {
396 debug!("Failed to read message: {}", e);
397 }
398 },
399 }
400
401 let out = handshake_state.write_message_vec(b"").unwrap();
402 {
403 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
404 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
405 let (decrypt_cipher, encrypt_cipher) = handshake_state.get_ciphers();
406 *encrypt_cipher_changer = Some(encrypt_cipher);
407 *decrypt_cipher_changer = Some(decrypt_cipher);
408 }
409
410 let mut message_handshake = vec![0];
411 message_handshake.extend(out);
412
413 debug!("Sending handshake");
414 writer.send(message_handshake.clone()).await.unwrap();
415 writer.flush().await.unwrap();
416 }
417
418 debug!("Initialization done.");
419
420 let (cancellation_write_tx, mut cancellation_write_rx) = oneshot::channel();
422
423 let plaintext_communication = self.plaintext_communication.clone();
425 tokio::spawn(async move {
426 loop {
427 let answer_message: ProtoMessage;
428
429 tokio::select! {
431 biased; cancel_message = &mut cancellation_write_rx => {
433 debug!("Write loop received cancellation signal ({}), exiting.", cancel_message.unwrap());
434 break;
435 }
436 message = answer_messages_rx.recv() => {
437 answer_message = message.unwrap();
438 }
439 };
440
441 debug!("Answer message: {:?}", answer_message);
442
443 if plaintext_communication.load(std::sync::atomic::Ordering::Relaxed) {
444 writer
445 .send(packet_plaintext::message_to_packet(&answer_message).unwrap())
446 .await
447 .unwrap();
448 } else {
451 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
453 writer
454 .send(
455 packet_encrypted::message_to_packet(
456 &answer_message,
457 &mut *encrypt_cipher_changer.as_mut().unwrap(),
458 )
459 .unwrap(),
460 )
461 .await
462 .unwrap();
463 }
464 writer.flush().await.unwrap();
465
466 if matches!(answer_message, ProtoMessage::DisconnectResponse(_)) {
467 debug!("Disconnecting");
468 let mut tcp_write = writer.into_inner();
469 match tcp_write.shutdown().await {
470 Err(err) => {
471 error!("failed to shutdown socket: {:?}", err);
472 break;
473 }
474 _ => break,
475 }
476 }
477 }
478 });
479
480 let answer_messages_tx_clone = answer_messages_tx.clone();
482 let decrypt_cypher_clone = self.decrypt_cypher.clone();
483 tokio::spawn(async move {
485 loop {
486 let next = reader.next().await;
487 if next.is_none() {
488 info!("Read loop stopped because stream finished");
489 let _ = cancellation_write_tx.send("read loop finished");
491 break;
492 }
493 let frame = next.unwrap().unwrap();
494 trace!("TCP Receive: {:02X?}", &frame);
495
496 let message;
497 if encrypted {
498 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
499 message = packet_encrypted::packet_to_message(
500 &frame,
501 &mut *decrypt_cipher_changer.as_mut().unwrap(),
502 )
503 .unwrap();
504 } else {
505 message = packet_plaintext::packet_to_message(&frame).unwrap();
506 }
507
508 match &message {
510 ProtoMessage::DisconnectRequest(disconnect_request) => {
511 debug!("DisconnectRequest: {:?}", disconnect_request);
512 let response_message = DisconnectResponse {};
513 answer_messages_tx_clone
514 .send(ProtoMessage::DisconnectResponse(response_message))
515 .await
516 .unwrap();
517 continue;
518 }
519 ProtoMessage::PingRequest(ping_request) => {
520 debug!("PingRequest: {:?}", ping_request);
521 let response_message = PingResponse {};
522 answer_messages_tx_clone
523 .send(ProtoMessage::PingResponse(response_message))
524 .await
525 .unwrap();
526 }
527 ProtoMessage::DeviceInfoRequest(device_info_request) => {
528 debug!("DeviceInfoRequest: {:?}", device_info_request);
529 answer_messages_tx_clone
530 .send(ProtoMessage::DeviceInfoResponse(device_info.clone()))
531 .await
532 .unwrap();
533 }
534 ProtoMessage::HelloRequest(hello_request) => {
535 debug!("HelloRequest: {:?}", hello_request);
536
537 answer_messages_tx_clone
538 .send(ProtoMessage::HelloResponse(hello_response.clone()))
539 .await
540 .unwrap();
541 }
542 ProtoMessage::AuthenticationRequest(authentication_request) => {
543 debug!("AuthenticationRequest: {:?}", authentication_request);
544
545 if authentication_request.password != "" {
546 info!("Password Authentication is not supported");
547 } else {
548 let response_message = AuthenticationResponse {
549 invalid_password: false,
550 };
551 answer_messages_tx_clone
552 .send(ProtoMessage::AuthenticationResponse(response_message))
553 .await
554 .unwrap();
555 }
556 }
557 message => {
558 outgoing_messages_tx.send(message.clone()).unwrap();
559 }
560 }
561 }
562 });
563
564 Ok((answer_messages_tx.clone(), outgoing_messages_rx))
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Building with only the required `name` must succeed.
    #[test]
    fn test_basic_server_instantiation() {
        EspHomeApi::builder()
            .name("test_device".to_string())
            .build();
    }

    /// Every field other than `name` falls back to its `#[builder(default ...)]` value.
    #[test]
    fn test_builder_defaults() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();

        assert_eq!(api.name, "test_device");
        assert!(api.encryption_key.is_none());
        assert_eq!(api.api_version_major, 1);
        assert_eq!(api.api_version_minor, 10);
        assert_eq!(api.server_info, "Rust: esphome-native-api");
        assert!(api.friendly_name.is_none());
        assert!(api.mac.is_none());
        assert_eq!(api.legacy_bluetooth_proxy_version, 0);
        assert_eq!(api.bluetooth_proxy_feature_flags, 0);
        assert_eq!(api.legacy_voice_assistant_version, 0);
        assert_eq!(api.voice_assistant_feature_flags, 0);

        // Connection state starts "undetected, plaintext, no ciphers".
        assert!(!api.first_message_received.load(Ordering::Relaxed));
        assert!(api.plaintext_communication.load(Ordering::Relaxed));
        assert!(api.encrypt_cypher.try_lock().unwrap().is_none());
        assert!(api.decrypt_cypher.try_lock().unwrap().is_none());
    }

    /// `strip_option` setters take the bare value; the `*_opt` fallbacks take an `Option`.
    #[test]
    fn test_builder_optional_setters() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .encryption_key("a2V5".to_string())
            .friendly_name("Kitchen".to_string())
            .mac_opt(None)
            .build();

        assert_eq!(api.encryption_key.as_deref(), Some("a2V5"));
        assert_eq!(api.friendly_name.as_deref(), Some("Kitchen"));
        assert!(api.mac.is_none());
    }
}
579}