esphome_native_api/
esphomeapi.rs

1use crate::frame::FrameCodec;
2use crate::packet_encrypted;
3use crate::packet_plaintext;
4use crate::parser::ProtoMessage;
5use crate::proto::version_2025_12_1::DeviceInfoResponse;
6use crate::proto::version_2025_12_1::DisconnectResponse;
7use crate::proto::version_2025_12_1::HelloResponse;
8use crate::proto::version_2025_12_1::PingResponse;
9use base64::prelude::*;
10use futures::sink::SinkExt;
11use log::debug;
12use log::error;
13use log::info;
14use log::trace;
15use noise_protocol::CipherState;
16use noise_protocol::ErrorKind;
17use noise_protocol::HandshakeState;
18use noise_protocol::patterns::noise_nn_psk0;
19use noise_rust_crypto::ChaCha20Poly1305;
20use noise_rust_crypto::Sha256;
21use noise_rust_crypto::X25519;
22use std::sync::Arc;
23use std::sync::atomic::AtomicBool;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26use tokio::net::tcp::OwnedWriteHalf;
27use tokio::sync::Mutex;
28use tokio::sync::broadcast;
29use tokio::sync::mpsc;
30use tokio::sync::oneshot;
31use tokio_stream::StreamExt;
32use tokio_util::codec::FramedRead;
33use tokio_util::codec::FramedWrite;
34use typed_builder::TypedBuilder;
35
36async fn write_error_and_disconnect(
37    mut writer: FramedWrite<OwnedWriteHalf, FrameCodec>,
38    message: &str,
39) {
40    error!("API Failure: {}. Disconnecting.", message);
41    let packet = [[1].to_vec(), message.as_bytes().to_vec()].concat();
42    writer.send(packet).await.unwrap();
43    writer.flush().await.unwrap();
44    let mut tcp_write = writer.into_inner();
45    if let Err(err) = tcp_write.shutdown().await {
46        error!("failed to shutdown socket: {:?}", err);
47    }
48}
49
/// Error text sent to a plaintext client while an encryption key is configured.
const ERROR_ONLY_ENCRYPTED: &str = "Only key encryption is enabled";

/// Error text sent when the client's Noise handshake message fails to decrypt
/// (`noise_protocol::ErrorKind::Decryption`).
const ERROR_HANDSHAKE_MAC_FAILURE: &str = "Handshake MAC failure";
52
53#[derive(TypedBuilder, Clone)]
54pub struct EspHomeApi {
55    // Private fields
56    #[builder(default=Arc::new(AtomicBool::new(false)))]
57    pub(crate) first_message_received: Arc<AtomicBool>,
58
59    #[builder(default=Arc::new(AtomicBool::new(true)))]
60    pub(crate) plaintext_communication: Arc<AtomicBool>,
61
62    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
63    pub(crate) encrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
64    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
65    pub(crate) decrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
66
67    name: String,
68
69    #[builder(default = None, setter(strip_option(fallback=encryption_key_opt)))]
70    encryption_key: Option<String>,
71
72    #[builder(default = 1)]
73    api_version_major: u32,
74    #[builder(default = 10)]
75    api_version_minor: u32,
76    #[builder(default="Rust: esphome-native-api".to_string())]
77    server_info: String,
78
79    #[builder(default = None, setter(strip_option(fallback=friendly_name_opt)))]
80    friendly_name: Option<String>,
81
82    #[builder(default = None, setter(strip_option(fallback=mac_opt)))]
83    mac: Option<String>,
84
85    #[builder(default = None, setter(strip_option(fallback=model_opt)))]
86    model: Option<String>,
87
88    #[builder(default = None, setter(strip_option(fallback=manufacturer_opt)))]
89    manufacturer: Option<String>,
90    #[builder(default = None, setter(strip_option(fallback=suggested_area_opt)))]
91    suggested_area: Option<String>,
92    #[builder(default = None, setter(strip_option(fallback=bluetooth_mac_address_opt)))]
93    bluetooth_mac_address: Option<String>,
94
95    #[builder(default = None, setter(strip_option(fallback=project_name_opt)))]
96    project_name: Option<String>,
97
98    #[builder(default = None, setter(strip_option(fallback=project_version_opt)))]
99    project_version: Option<String>,
100    #[builder(default = None, setter(strip_option(fallback=compilation_time_opt)))]
101    compilation_time: Option<String>,
102
103    #[builder(default = 0)]
104    legacy_bluetooth_proxy_version: u32,
105    #[builder(default = 0)]
106    bluetooth_proxy_feature_flags: u32,
107    #[builder(default = 0)]
108    legacy_voice_assistant_version: u32,
109    #[builder(default = 0)]
110    voice_assistant_feature_flags: u32,
111
112    #[builder(default = "2025.4.0".to_string())]
113    esphome_version: String,
114}
115
116/// Handles the EspHome Api, with encryption etc.
117impl EspHomeApi {
118    /// Starts the server and returns a broadcast channel for messages, and a
119    /// broadcast receiver for all messages not handled by the abstraction
120    pub async fn start(
121        &mut self,
122        tcp_stream: TcpStream,
123    ) -> Result<
124        (
125            mpsc::Sender<ProtoMessage>,
126            broadcast::Receiver<ProtoMessage>,
127        ),
128        Box<dyn std::error::Error>,
129    > {
130        // Channel for messages
131        let (answer_messages_tx, mut answer_messages_rx) = mpsc::channel::<ProtoMessage>(16);
132        let (outgoing_messages_tx, outgoing_messages_rx) = broadcast::channel::<ProtoMessage>(16);
133
134        #[allow(deprecated)]
135        let device_info = DeviceInfoResponse {
136            api_encryption_supported: self.encryption_key.is_some(),
137            uses_password: false,
138            name: self.name.clone(),
139            mac_address: self.mac.clone().unwrap_or_default(),
140            esphome_version: self.esphome_version.clone(),
141            compilation_time: self.compilation_time.clone().unwrap_or_default(),
142            model: self.model.clone().unwrap_or_default(),
143            has_deep_sleep: false,
144            project_name: self.project_name.clone().unwrap_or_default(),
145            project_version: self.project_version.clone().unwrap_or_default(),
146            webserver_port: 0,
147            // See https://github.com/esphome/aioesphomeapi/blob/c1fee2f4eaff84d13ca71996bb272c28b82314fc/aioesphomeapi/model.py#L154
148            legacy_bluetooth_proxy_version: self.legacy_bluetooth_proxy_version,
149            bluetooth_proxy_feature_flags: self.bluetooth_proxy_feature_flags,
150            manufacturer: self.manufacturer.clone().unwrap_or_default(),
151            friendly_name: self.friendly_name.clone().unwrap_or(self.name.clone()),
152            legacy_voice_assistant_version: self.legacy_voice_assistant_version,
153            voice_assistant_feature_flags: self.voice_assistant_feature_flags,
154            suggested_area: self.suggested_area.clone().unwrap_or_default(),
155            bluetooth_mac_address: self.bluetooth_mac_address.clone().unwrap_or_default(),
156            areas: vec![],
157            devices: vec![],
158            area: None,
159            zwave_proxy_feature_flags: 0,
160            zwave_home_id: 0,
161        };
162
163        let hello_response = HelloResponse {
164            api_version_major: self.api_version_major,
165            api_version_minor: self.api_version_minor,
166            server_info: self.server_info.clone(),
167            name: self.name.clone(),
168        };
169
170        let encrypt_cypher_clone = self.encrypt_cypher.clone();
171        let decrypt_cypher_clone = self.decrypt_cypher.clone();
172
173        // Stage 1: Initialization
174        trace!("Init Connection: Stage 1");
175        let encryption_key = self.encryption_key.clone();
176
177        let mut buf = vec![0; 1];
178        let n = tcp_stream
179            .peek(&mut buf)
180            .await
181            .expect("failed to read data from socket");
182
183        if n == 0 {
184            return Err("No data".into());
185        }
186
187        trace!("TCP Peeked: {:02X?}", &buf[0..n]);
188
189        let preamble = buf[0] as usize;
190
191        let first_message_received = self
192            .first_message_received
193            .load(std::sync::atomic::Ordering::Relaxed);
194
195        if !first_message_received {
196            match preamble {
197                0 => {
198                    debug!("Cleartext messaging");
199
200                    self.plaintext_communication
201                        .store(true, std::sync::atomic::Ordering::Relaxed);
202                }
203                1 => {
204                    trace!("Encrypted messaging");
205
206                    self.plaintext_communication
207                        .store(false, std::sync::atomic::Ordering::Relaxed);
208                }
209                _ => {
210                    return Err(format!("Invalid marker byte {}", preamble).into());
211                }
212            }
213            self.first_message_received
214                .store(true, std::sync::atomic::Ordering::Relaxed);
215        }
216
217        let plaintext_communication = self
218            .plaintext_communication
219            .load(std::sync::atomic::Ordering::Relaxed);
220        let encrypted = !plaintext_communication;
221
222        let (tcp_read, tcp_write) = tcp_stream.into_split();
223        let decoder = FrameCodec::new(encrypted);
224        let encoder = FrameCodec::new(encrypted);
225        let mut reader = FramedRead::new(tcp_read, decoder);
226        let mut writer = FramedWrite::new(tcp_write, encoder);
227
228        if plaintext_communication {
229            if self.encryption_key.is_some() {
230                let encoder = FrameCodec::new(true);
231                let writer = FramedWrite::new(writer.into_inner(), encoder);
232                write_error_and_disconnect(writer, ERROR_ONLY_ENCRYPTED).await;
233                return Err(ERROR_ONLY_ENCRYPTED.into());
234            }
235        } else {
236            if self.encryption_key.is_none() {
237                write_error_and_disconnect(writer, "No encrypted communication allowed").await;
238                return Err("No encryption key set, but encrypted communication requested.".into());
239            }
240
241            let frame_noise_hello = reader.next().await.unwrap().unwrap();
242            trace!("Frame 1: {:02X?}", &frame_noise_hello);
243
244            let message_server_hello =
245                packet_encrypted::generate_server_hello_frame(self.name.clone(), self.mac.clone());
246
247            writer.send(message_server_hello.clone()).await.unwrap();
248            writer.flush().await.unwrap();
249
250            let frame_handshake_request = reader.next().await.unwrap().unwrap();
251            info!("Frame 2: {:02X?}", &frame_handshake_request);
252
253            // Similar to https://github.com/esphome/aioesphomeapi/blob/60bcd1698dd622aeac6f4b5ec448bab0e3467c4f/aioesphomeapi/_frame_helper/noise.py#L248C17-L255
254            let mut handshake_state: HandshakeState<X25519, ChaCha20Poly1305, Sha256> =
255                HandshakeState::new(
256                    noise_nn_psk0(),
257                    false,
258                    // NEXT: This is somehow set from the first api message?
259                    b"NoiseAPIInit\0\0",
260                    None,
261                    None,
262                    None,
263                    None,
264                );
265
266            let noise_psk = BASE64_STANDARD
267                .decode(encryption_key.as_ref().unwrap())
268                .unwrap();
269
270            handshake_state.push_psk(&noise_psk);
271            // Ignore message type byte
272            match handshake_state.read_message_vec(&frame_handshake_request[1..]) {
273                Ok(_) => {}
274                Err(e) => match e.kind() {
275                    ErrorKind::Decryption => {
276                        write_error_and_disconnect(writer, ERROR_HANDSHAKE_MAC_FAILURE).await;
277                        return Err(ERROR_HANDSHAKE_MAC_FAILURE.into());
278                    }
279                    _ => {
280                        debug!("Failed to read message: {}", e);
281                    }
282                },
283            }
284
285            let out = handshake_state.write_message_vec(b"").unwrap();
286            {
287                let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
288                let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
289                let (decrypt_cipher, encrypt_cipher) = handshake_state.get_ciphers();
290                *encrypt_cipher_changer = Some(encrypt_cipher);
291                *decrypt_cipher_changer = Some(decrypt_cipher);
292            }
293
294            let mut message_handshake = vec![0];
295            message_handshake.extend(out);
296
297            debug!("Sending handshake");
298            writer.send(message_handshake.clone()).await.unwrap();
299            writer.flush().await.unwrap();
300        }
301
302        debug!("Initialization done.");
303
304        // Asynchronously wait for an inbound socket.
305        let (cancellation_write_tx, mut cancellation_write_rx) = oneshot::channel();
306
307        // Write Loop
308        let plaintext_communication = self.plaintext_communication.clone();
309        tokio::spawn(async move {
310            loop {
311                let answer_message: ProtoMessage;
312
313                // Wait for any new message
314                tokio::select! {
315                    biased; // Poll cancellation_write_rx first
316                    cancel_message = &mut cancellation_write_rx => {
317                        debug!("Write loop received cancellation signal ({}), exiting.", cancel_message.unwrap());
318                        break;
319                    }
320                    message = answer_messages_rx.recv() => {
321                        answer_message = message.unwrap();
322                    }
323                };
324
325                debug!("Answer message: {:?}", answer_message);
326
327                if plaintext_communication.load(std::sync::atomic::Ordering::Relaxed) {
328                    writer
329                        .send(packet_plaintext::message_to_packet(&answer_message).unwrap())
330                        .await
331                        .unwrap();
332                    // answer_buf =
333                    //     [answer_buf, to_unencrypted_frame(&answer_message).unwrap()].concat();
334                } else {
335                    // Use normal messaging
336                    let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
337                    writer
338                        .send(
339                            packet_encrypted::message_to_packet(
340                                &answer_message,
341                                &mut *encrypt_cipher_changer.as_mut().unwrap(),
342                            )
343                            .unwrap(),
344                        )
345                        .await
346                        .unwrap();
347                }
348                writer.flush().await.unwrap();
349
350                if matches!(answer_message, ProtoMessage::DisconnectResponse(_)) {
351                    debug!("Disconnecting");
352                    let mut tcp_write = writer.into_inner();
353                    match tcp_write.shutdown().await {
354                        Err(err) => {
355                            error!("failed to shutdown socket: {:?}", err);
356                            break;
357                        }
358                        _ => break,
359                    }
360                }
361            }
362        });
363
364        // Clone all necessary data before spawning the task
365        let answer_messages_tx_clone = answer_messages_tx.clone();
366        let decrypt_cypher_clone = self.decrypt_cypher.clone();
367        // Read Loop
368        tokio::spawn(async move {
369            loop {
370                let next = reader.next().await;
371                if next.is_none() {
372                    info!("Read loop stopped because stream finished");
373                    // If sending fails, the write loop is probably already closed
374                    let _ = cancellation_write_tx.send("read loop finished");
375                    break;
376                }
377                let frame = next.unwrap().unwrap();
378                trace!("TCP Receive: {:02X?}", &frame);
379
380                let message;
381                if encrypted {
382                    let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
383                    message = packet_encrypted::packet_to_message(
384                        &frame,
385                        &mut *decrypt_cipher_changer.as_mut().unwrap(),
386                    )
387                    .unwrap();
388                } else {
389                    message = packet_plaintext::packet_to_message(&frame).unwrap();
390                }
391
392                // Authenticated Messages
393                match &message {
394                    ProtoMessage::DisconnectRequest(disconnect_request) => {
395                        debug!("DisconnectRequest: {:?}", disconnect_request);
396                        let response_message = DisconnectResponse {};
397                        answer_messages_tx_clone
398                            .send(ProtoMessage::DisconnectResponse(response_message))
399                            .await
400                            .unwrap();
401                        continue;
402                    }
403                    ProtoMessage::PingRequest(ping_request) => {
404                        debug!("PingRequest: {:?}", ping_request);
405                        let response_message = PingResponse {};
406                        answer_messages_tx_clone
407                            .send(ProtoMessage::PingResponse(response_message))
408                            .await
409                            .unwrap();
410                    }
411                    ProtoMessage::DeviceInfoRequest(device_info_request) => {
412                        debug!("DeviceInfoRequest: {:?}", device_info_request);
413                        answer_messages_tx_clone
414                            .send(ProtoMessage::DeviceInfoResponse(device_info.clone()))
415                            .await
416                            .unwrap();
417                    }
418                    ProtoMessage::HelloRequest(hello_request) => {
419                        debug!("HelloRequest: {:?}", hello_request);
420
421                        answer_messages_tx_clone
422                            .send(ProtoMessage::HelloResponse(hello_response.clone()))
423                            .await
424                            .unwrap();
425                    }
426                    ProtoMessage::AuthenticationRequest(_) => {
427                        info!("Password Authentication is not supported");
428                    }
429                    message => {
430                        outgoing_messages_tx.send(message.clone()).unwrap();
431                    }
432                }
433            }
434        });
435
436        Ok((answer_messages_tx.clone(), outgoing_messages_rx))
437    }
438}
439
#[cfg(test)]
mod tests {
    // Bring the parent module's items (including private struct fields) into scope.
    use super::*;
    use std::sync::atomic::Ordering;

    /// Building with only the required `name` yields the documented defaults.
    #[test]
    fn test_basic_server_instantiation() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();

        assert_eq!(api.name, "test_device");
        assert!(api.encryption_key.is_none());
        assert!(api.friendly_name.is_none());
        assert_eq!(api.api_version_major, 1);
        assert_eq!(api.api_version_minor, 10);
        assert_eq!(api.server_info, "Rust: esphome-native-api");
        assert_eq!(api.esphome_version, "2025.4.0");
        assert!(api.plaintext_communication.load(Ordering::Relaxed));
        assert!(!api.first_message_received.load(Ordering::Relaxed));
    }

    /// The `strip_option` setters and their `_opt` fallbacks populate optional fields.
    #[test]
    fn test_builder_optional_fields() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .encryption_key("AAAA".to_string())
            .friendly_name_opt(None)
            .mac("00:11:22:33:44:55".to_string())
            .build();

        assert_eq!(api.encryption_key.as_deref(), Some("AAAA"));
        assert!(api.friendly_name.is_none());
        assert_eq!(api.mac.as_deref(), Some("00:11:22:33:44:55"));
    }

    /// Clones share the `Arc`-wrapped connection state.
    #[test]
    fn test_clones_share_connection_state() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();
        let clone = api.clone();

        clone.first_message_received.store(true, Ordering::Relaxed);
        assert!(api.first_message_received.load(Ordering::Relaxed));
    }
}