1use base64::prelude::*;
49use futures::sink::SinkExt;
50use log::debug;
51use log::error;
52use log::info;
53use log::trace;
54use noise_protocol::CipherState;
55use noise_protocol::ErrorKind;
56use noise_protocol::HandshakeState;
57use noise_protocol::patterns::noise_nn_psk0;
58use noise_rust_crypto::ChaCha20Poly1305;
59use noise_rust_crypto::Sha256;
60use noise_rust_crypto::X25519;
61use std::sync::Arc;
62use std::sync::atomic::AtomicBool;
63use tokio::io::AsyncWriteExt;
64use tokio::net::TcpStream;
65use tokio::net::tcp::OwnedWriteHalf;
66use tokio::sync::Mutex;
67use tokio::sync::broadcast;
68use tokio::sync::mpsc;
69use tokio::sync::oneshot;
70use tokio_stream::StreamExt;
71use tokio_util::codec::FramedRead;
72use tokio_util::codec::FramedWrite;
73use typed_builder::TypedBuilder;
74
75use crate::frame::FrameCodec;
76use crate::packet_encrypted;
77use crate::packet_plaintext;
78use crate::parser::ProtoMessage;
79use crate::proto::{self, DeviceInfoResponse, DisconnectResponse, HelloResponse, PingResponse};
80
81async fn write_error_and_disconnect(
82 mut writer: FramedWrite<OwnedWriteHalf, FrameCodec>,
83 message: &str,
84) {
85 error!("API Failure: {}. Disconnecting.", message);
86 let packet = [[1].to_vec(), message.as_bytes().to_vec()].concat();
87 writer.send(packet).await.unwrap();
88 writer.flush().await.unwrap();
89 let mut tcp_write = writer.into_inner();
90 if let Err(err) = tcp_write.shutdown().await {
91 error!("failed to shutdown socket: {:?}", err);
92 }
93}
94
/// Error sent to a plaintext client (and returned from `start`) when an encryption key
/// is configured, i.e. only encrypted connections are accepted.
const ERROR_ONLY_ENCRYPTED: &str = "Only key encryption is enabled";
/// Error sent to the client (and returned from `start`) when the Noise handshake message
/// fails to decrypt — presumably because the client used a different pre-shared key.
const ERROR_HANDSHAKE_MAC_FAILURE: &str = "Handshake MAC failure";
97
/// Configuration and shared connection state for an ESPHome native API server.
///
/// Construct it with `EspHomeApi::builder()`; `name` is the only field without a builder
/// default. Unset `Option<String>` device fields are reported as empty strings in
/// `DeviceInfoResponse`.
///
/// The derived `Clone` copies the `Arc` handles, so clones share the negotiation flags and
/// the Noise cipher slots.
/// NOTE(review): `start` overwrites the shared cipher slots on every encrypted handshake, so
/// serving concurrent connections from clones of one instance looks unsafe — confirm intent.
#[derive(TypedBuilder, Clone)]
pub struct EspHomeApi {
    /// Set by `start` after it has inspected a connection's marker byte; while `true`,
    /// `start` skips the marker check and reuses `plaintext_communication` as-is.
    #[builder(default=Arc::new(AtomicBool::new(false)))]
    pub(crate) first_message_received: Arc<AtomicBool>,

    /// `true` for plaintext framing (marker byte `0`), `false` for encrypted framing
    /// (marker byte `1`).
    #[builder(default=Arc::new(AtomicBool::new(true)))]
    pub(crate) plaintext_communication: Arc<AtomicBool>,

    /// Cipher used to encrypt outgoing packets; filled in by `start` after the Noise handshake.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) encrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
    /// Cipher used to decrypt incoming packets; filled in by `start` after the Noise handshake.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) decrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,

    /// Device name sent in `HelloResponse`, `DeviceInfoResponse` and the Noise server hello;
    /// also used as the fallback for `friendly_name`.
    name: String,

    /// Base64-encoded Noise pre-shared key. When set, plaintext clients are rejected; when
    /// unset, encrypted clients are rejected.
    #[builder(default = None, setter(strip_option(fallback=encryption_key_opt)))]
    encryption_key: Option<String>,

    /// API major version reported in `HelloResponse`.
    #[builder(default = 1)]
    api_version_major: u32,
    /// API minor version reported in `HelloResponse`.
    #[builder(default = 10)]
    api_version_minor: u32,
    /// Server description reported in `HelloResponse`.
    #[builder(default="Rust: esphome-native-api".to_string())]
    server_info: String,

    /// Friendly name reported in `DeviceInfoResponse`; `name` is used when unset.
    #[builder(default = None, setter(strip_option(fallback=friendly_name_opt)))]
    friendly_name: Option<String>,

    /// MAC address reported in `DeviceInfoResponse` and passed to the Noise server hello.
    #[builder(default = None, setter(strip_option(fallback=mac_opt)))]
    mac: Option<String>,

    /// Model reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=model_opt)))]
    model: Option<String>,

    /// Manufacturer reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=manufacturer_opt)))]
    manufacturer: Option<String>,
    /// Suggested area reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=suggested_area_opt)))]
    suggested_area: Option<String>,
    /// Bluetooth MAC address reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=bluetooth_mac_address_opt)))]
    bluetooth_mac_address: Option<String>,

    /// Project name reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=project_name_opt)))]
    project_name: Option<String>,

    /// Project version reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=project_version_opt)))]
    project_version: Option<String>,
    /// Compilation time string reported in `DeviceInfoResponse`.
    #[builder(default = None, setter(strip_option(fallback=compilation_time_opt)))]
    compilation_time: Option<String>,

    /// Copied verbatim into `DeviceInfoResponse::legacy_bluetooth_proxy_version`.
    #[builder(default = 0)]
    legacy_bluetooth_proxy_version: u32,
    /// Copied verbatim into `DeviceInfoResponse::bluetooth_proxy_feature_flags`.
    #[builder(default = 0)]
    bluetooth_proxy_feature_flags: u32,
    /// Copied verbatim into `DeviceInfoResponse::legacy_voice_assistant_version`.
    #[builder(default = 0)]
    legacy_voice_assistant_version: u32,
    /// Copied verbatim into `DeviceInfoResponse::voice_assistant_feature_flags`.
    #[builder(default = 0)]
    voice_assistant_feature_flags: u32,
}
192
193impl EspHomeApi {
195 pub async fn start(
234 &mut self,
235 tcp_stream: TcpStream,
236 ) -> Result<
237 (
238 mpsc::Sender<ProtoMessage>,
239 broadcast::Receiver<ProtoMessage>,
240 ),
241 Box<dyn std::error::Error>,
242 > {
243 let (answer_messages_tx, mut answer_messages_rx) = mpsc::channel::<ProtoMessage>(16);
245 let (outgoing_messages_tx, outgoing_messages_rx) = broadcast::channel::<ProtoMessage>(16);
246
247 #[allow(deprecated)]
248 let device_info = DeviceInfoResponse {
249 api_encryption_supported: self.encryption_key.is_some(),
250 uses_password: false,
251 name: self.name.clone(),
252 mac_address: self.mac.clone().unwrap_or_default(),
253 esphome_version: proto::VERSION.to_owned(),
254 compilation_time: self.compilation_time.clone().unwrap_or_default(),
255 model: self.model.clone().unwrap_or_default(),
256 has_deep_sleep: false,
257 project_name: self.project_name.clone().unwrap_or_default(),
258 project_version: self.project_version.clone().unwrap_or_default(),
259 webserver_port: 0,
260 legacy_bluetooth_proxy_version: self.legacy_bluetooth_proxy_version,
262 bluetooth_proxy_feature_flags: self.bluetooth_proxy_feature_flags,
263 manufacturer: self.manufacturer.clone().unwrap_or_default(),
264 friendly_name: self.friendly_name.clone().unwrap_or(self.name.clone()),
265 legacy_voice_assistant_version: self.legacy_voice_assistant_version,
266 voice_assistant_feature_flags: self.voice_assistant_feature_flags,
267 suggested_area: self.suggested_area.clone().unwrap_or_default(),
268 bluetooth_mac_address: self.bluetooth_mac_address.clone().unwrap_or_default(),
269 areas: vec![],
270 devices: vec![],
271 area: None,
272 zwave_proxy_feature_flags: 0,
273 zwave_home_id: 0,
274 };
275
276 let hello_response = HelloResponse {
277 api_version_major: self.api_version_major,
278 api_version_minor: self.api_version_minor,
279 server_info: self.server_info.clone(),
280 name: self.name.clone(),
281 };
282
283 let encrypt_cypher_clone = self.encrypt_cypher.clone();
284 let decrypt_cypher_clone = self.decrypt_cypher.clone();
285
286 trace!("Init Connection: Stage 1");
288 let encryption_key = self.encryption_key.clone();
289
290 let mut buf = vec![0; 1];
291 let n = tcp_stream
292 .peek(&mut buf)
293 .await
294 .expect("failed to read data from socket");
295
296 if n == 0 {
297 return Err("No data".into());
298 }
299
300 trace!("TCP Peeked: {:02X?}", &buf[0..n]);
301
302 let preamble = buf[0] as usize;
303
304 let first_message_received = self
305 .first_message_received
306 .load(std::sync::atomic::Ordering::Relaxed);
307
308 if !first_message_received {
309 match preamble {
310 0 => {
311 debug!("Cleartext messaging");
312
313 self.plaintext_communication
314 .store(true, std::sync::atomic::Ordering::Relaxed);
315 }
316 1 => {
317 trace!("Encrypted messaging");
318
319 self.plaintext_communication
320 .store(false, std::sync::atomic::Ordering::Relaxed);
321 }
322 _ => {
323 return Err(format!("Invalid marker byte {}", preamble).into());
324 }
325 }
326 self.first_message_received
327 .store(true, std::sync::atomic::Ordering::Relaxed);
328 }
329
330 let plaintext_communication = self
331 .plaintext_communication
332 .load(std::sync::atomic::Ordering::Relaxed);
333 let encrypted = !plaintext_communication;
334
335 let (tcp_read, tcp_write) = tcp_stream.into_split();
336 let decoder = FrameCodec::new(encrypted);
337 let encoder = FrameCodec::new(encrypted);
338 let mut reader = FramedRead::new(tcp_read, decoder);
339 let mut writer = FramedWrite::new(tcp_write, encoder);
340
341 if plaintext_communication {
342 if self.encryption_key.is_some() {
343 let encoder = FrameCodec::new(true);
344 let writer = FramedWrite::new(writer.into_inner(), encoder);
345 write_error_and_disconnect(writer, ERROR_ONLY_ENCRYPTED).await;
346 return Err(ERROR_ONLY_ENCRYPTED.into());
347 }
348 } else {
349 if self.encryption_key.is_none() {
350 write_error_and_disconnect(writer, "No encrypted communication allowed").await;
351 return Err("No encryption key set, but encrypted communication requested.".into());
352 }
353
354 let frame_noise_hello = reader.next().await.unwrap().unwrap();
355 debug!("Frame 1: {:02X?}", &frame_noise_hello);
356
357 let message_server_hello =
358 packet_encrypted::generate_server_hello_frame(self.name.clone(), self.mac.clone());
359
360 writer.send(message_server_hello.clone()).await.unwrap();
361 writer.flush().await.unwrap();
362
363 let frame_handshake_request = reader.next().await.unwrap().unwrap();
364 debug!("Frame 2: {:02X?}", &frame_handshake_request);
365
366 let mut handshake_state: HandshakeState<X25519, ChaCha20Poly1305, Sha256> =
368 HandshakeState::new(
369 noise_nn_psk0(),
370 false,
371 b"NoiseAPIInit\0\0",
373 None,
374 None,
375 None,
376 None,
377 );
378
379 let noise_psk = BASE64_STANDARD
380 .decode(encryption_key.as_ref().unwrap())
381 .unwrap();
382
383 handshake_state.push_psk(&noise_psk);
384 match handshake_state.read_message_vec(&frame_handshake_request[1..]) {
386 Ok(_) => {}
387 Err(e) => match e.kind() {
388 ErrorKind::Decryption => {
389 write_error_and_disconnect(writer, ERROR_HANDSHAKE_MAC_FAILURE).await;
390 return Err(ERROR_HANDSHAKE_MAC_FAILURE.into());
391 }
392 _ => {
393 debug!("Failed to read message: {}", e);
394 }
395 },
396 }
397
398 let out = handshake_state.write_message_vec(b"").unwrap();
399 {
400 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
401 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
402 let (decrypt_cipher, encrypt_cipher) = handshake_state.get_ciphers();
403 *encrypt_cipher_changer = Some(encrypt_cipher);
404 *decrypt_cipher_changer = Some(decrypt_cipher);
405 }
406
407 let mut message_handshake = vec![0];
408 message_handshake.extend(out);
409
410 debug!("Sending handshake");
411 writer.send(message_handshake.clone()).await.unwrap();
412 writer.flush().await.unwrap();
413 }
414
415 debug!("Initialization done.");
416
417 let (cancellation_write_tx, mut cancellation_write_rx) = oneshot::channel();
419
420 let plaintext_communication = self.plaintext_communication.clone();
422 tokio::spawn(async move {
423 loop {
424 let answer_message: ProtoMessage;
425
426 tokio::select! {
428 biased; cancel_message = &mut cancellation_write_rx => {
430 debug!("Write loop received cancellation signal ({}), exiting.", cancel_message.unwrap());
431 break;
432 }
433 message = answer_messages_rx.recv() => {
434 answer_message = message.unwrap();
435 }
436 };
437
438 debug!("Answer message: {:?}", answer_message);
439
440 if plaintext_communication.load(std::sync::atomic::Ordering::Relaxed) {
441 writer
442 .send(packet_plaintext::message_to_packet(&answer_message).unwrap())
443 .await
444 .unwrap();
445 } else {
448 let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
450 writer
451 .send(
452 packet_encrypted::message_to_packet(
453 &answer_message,
454 &mut *encrypt_cipher_changer.as_mut().unwrap(),
455 )
456 .unwrap(),
457 )
458 .await
459 .unwrap();
460 }
461 writer.flush().await.unwrap();
462
463 if matches!(answer_message, ProtoMessage::DisconnectResponse(_)) {
464 debug!("Disconnecting");
465 let mut tcp_write = writer.into_inner();
466 match tcp_write.shutdown().await {
467 Err(err) => {
468 error!("failed to shutdown socket: {:?}", err);
469 break;
470 }
471 _ => break,
472 }
473 }
474 }
475 });
476
477 let answer_messages_tx_clone = answer_messages_tx.clone();
479 let decrypt_cypher_clone = self.decrypt_cypher.clone();
480 tokio::spawn(async move {
482 loop {
483 let next = reader.next().await;
484 if next.is_none() {
485 info!("Read loop stopped because stream finished");
486 let _ = cancellation_write_tx.send("read loop finished");
488 break;
489 }
490 let frame = next.unwrap().unwrap();
491 trace!("TCP Receive: {:02X?}", &frame);
492
493 let message;
494 if encrypted {
495 let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
496 message = packet_encrypted::packet_to_message(
497 &frame,
498 &mut *decrypt_cipher_changer.as_mut().unwrap(),
499 )
500 .unwrap();
501 } else {
502 message = packet_plaintext::packet_to_message(&frame).unwrap();
503 }
504
505 match &message {
507 ProtoMessage::DisconnectRequest(disconnect_request) => {
508 debug!("DisconnectRequest: {:?}", disconnect_request);
509 let response_message = DisconnectResponse {};
510 answer_messages_tx_clone
511 .send(ProtoMessage::DisconnectResponse(response_message))
512 .await
513 .unwrap();
514 continue;
515 }
516 ProtoMessage::PingRequest(ping_request) => {
517 debug!("PingRequest: {:?}", ping_request);
518 let response_message = PingResponse {};
519 answer_messages_tx_clone
520 .send(ProtoMessage::PingResponse(response_message))
521 .await
522 .unwrap();
523 }
524 ProtoMessage::DeviceInfoRequest(device_info_request) => {
525 debug!("DeviceInfoRequest: {:?}", device_info_request);
526 answer_messages_tx_clone
527 .send(ProtoMessage::DeviceInfoResponse(device_info.clone()))
528 .await
529 .unwrap();
530 }
531 ProtoMessage::HelloRequest(hello_request) => {
532 debug!("HelloRequest: {:?}", hello_request);
533
534 answer_messages_tx_clone
535 .send(ProtoMessage::HelloResponse(hello_response.clone()))
536 .await
537 .unwrap();
538 }
539 ProtoMessage::AuthenticationRequest(_) => {
540 info!("Password Authentication is not supported");
541 }
542 message => {
543 outgoing_messages_tx.send(message.clone()).unwrap();
544 }
545 }
546 }
547 });
548
549 Ok((answer_messages_tx.clone(), outgoing_messages_rx))
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Building with only the required `name` yields the documented builder defaults.
    #[test]
    fn test_basic_server_instantiation() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();

        assert_eq!(api.name, "test_device");
        assert!(api.encryption_key.is_none());
        assert_eq!(api.api_version_major, 1);
        assert_eq!(api.api_version_minor, 10);
        assert_eq!(api.server_info, "Rust: esphome-native-api");
        assert!(!api.first_message_received.load(Ordering::Relaxed));
        assert!(api.plaintext_communication.load(Ordering::Relaxed));
        assert!(api
            .encrypt_cypher
            .try_lock()
            .expect("fresh mutex is uncontended")
            .is_none());
        assert!(api
            .decrypt_cypher
            .try_lock()
            .expect("fresh mutex is uncontended")
            .is_none());
    }

    /// The `strip_option` setter stores the key as `Some`.
    #[test]
    fn test_builder_sets_encryption_key() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .encryption_key("dGVzdA==".to_string())
            .build();

        assert_eq!(api.encryption_key.as_deref(), Some("dGVzdA=="));
    }

    /// Clones share the `Arc`-backed negotiation flags.
    #[test]
    fn test_clones_share_connection_flags() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();
        let clone = api.clone();

        api.first_message_received.store(true, Ordering::Relaxed);
        api.plaintext_communication.store(false, Ordering::Relaxed);

        assert!(clone.first_message_received.load(Ordering::Relaxed));
        assert!(!clone.plaintext_communication.load(Ordering::Relaxed));
    }
}
564}