esphome_native_api/
esphomeapi.rs

1use base64::prelude::*;
2use futures::sink::SinkExt;
3use log::debug;
4use log::error;
5use log::info;
6use log::trace;
7use noise_protocol::CipherState;
8use noise_protocol::ErrorKind;
9use noise_protocol::HandshakeState;
10use noise_protocol::patterns::noise_nn_psk0;
11use noise_rust_crypto::ChaCha20Poly1305;
12use noise_rust_crypto::Sha256;
13use noise_rust_crypto::X25519;
14use std::sync::Arc;
15use std::sync::atomic::AtomicBool;
16use tokio::io::AsyncWriteExt;
17use tokio::net::TcpStream;
18use tokio::net::tcp::OwnedWriteHalf;
19use tokio::sync::Mutex;
20use tokio::sync::broadcast;
21use tokio::sync::mpsc;
22use tokio::sync::oneshot;
23use tokio_stream::StreamExt;
24use tokio_util::codec::FramedRead;
25use tokio_util::codec::FramedWrite;
26use typed_builder::TypedBuilder;
27
28use crate::frame::FrameCodec;
29use crate::packet_encrypted;
30use crate::packet_plaintext;
31use crate::parser::ProtoMessage;
32use crate::proto::{self, DeviceInfoResponse, DisconnectResponse, HelloResponse, PingResponse};
33
34async fn write_error_and_disconnect(
35    mut writer: FramedWrite<OwnedWriteHalf, FrameCodec>,
36    message: &str,
37) {
38    error!("API Failure: {}. Disconnecting.", message);
39    let packet = [[1].to_vec(), message.as_bytes().to_vec()].concat();
40    writer.send(packet).await.unwrap();
41    writer.flush().await.unwrap();
42    let mut tcp_write = writer.into_inner();
43    if let Err(err) = tcp_write.shutdown().await {
44        error!("failed to shutdown socket: {:?}", err);
45    }
46}
47
/// Rejection text sent when the Noise handshake payload fails to decrypt
/// (`ErrorKind::Decryption`); presumably the client's key differs from ours.
const ERROR_HANDSHAKE_MAC_FAILURE: &str = "Handshake MAC failure";
/// Rejection text sent to a plaintext client while an encryption key is configured.
const ERROR_ONLY_ENCRYPTED: &str = "Only key encryption is enabled";
50
/// Configuration and connection state for serving the ESPHome native API on a
/// TCP stream (see [`EspHomeApi::start`]).
///
/// Construct it with the generated builder, e.g.
/// `EspHomeApi::builder().name("my_device".to_string()).build()`; every field
/// other than `name` has a builder default. The `Arc`-wrapped fields are shared
/// between clones of the value.
#[derive(TypedBuilder, Clone)]
pub struct EspHomeApi {
    // Private fields
    /// Set to `true` by `start` once a connection's preamble byte has been inspected.
    #[builder(default=Arc::new(AtomicBool::new(false)))]
    pub(crate) first_message_received: Arc<AtomicBool>,

    /// `true` when cleartext framing is in use (preamble byte `0`), `false` for
    /// Noise-encrypted framing (preamble byte `1`). Defaults to `true`.
    #[builder(default=Arc::new(AtomicBool::new(true)))]
    pub(crate) plaintext_communication: Arc<AtomicBool>,

    /// Noise transport cipher for outgoing frames; `None` until the encrypted
    /// handshake completes. Not settable through the builder.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) encrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,
    /// Noise transport cipher for incoming frames; `None` until the encrypted
    /// handshake completes. Not settable through the builder.
    #[builder(default=Arc::new(Mutex::new(None)), setter(skip))]
    pub(crate) decrypt_cypher: Arc<Mutex<Option<CipherState<ChaCha20Poly1305>>>>,

    /// Device name: reported in the hello and device-info responses, passed to
    /// the encrypted server-hello frame, and used when `friendly_name` is unset.
    name: String,

    /// Base64-encoded pre-shared key for the Noise handshake. When set, only
    /// encrypted connections are accepted; when unset, only plaintext ones.
    #[builder(default = None, setter(strip_option(fallback=encryption_key_opt)))]
    encryption_key: Option<String>,

    /// Major API version reported in `HelloResponse`.
    #[builder(default = 1)]
    api_version_major: u32,
    /// Minor API version reported in `HelloResponse`.
    #[builder(default = 10)]
    api_version_minor: u32,
    /// Server description reported in `HelloResponse`.
    #[builder(default="Rust: esphome-native-api".to_string())]
    server_info: String,

    /// Friendly name reported in `DeviceInfoResponse`; falls back to `name`.
    #[builder(default = None, setter(strip_option(fallback=friendly_name_opt)))]
    friendly_name: Option<String>,

    /// MAC address reported in `DeviceInfoResponse` (empty when unset) and
    /// passed to the encrypted server-hello frame.
    #[builder(default = None, setter(strip_option(fallback=mac_opt)))]
    mac: Option<String>,

    // The following optional strings are copied into `DeviceInfoResponse`,
    // with an empty string reported when unset.
    #[builder(default = None, setter(strip_option(fallback=model_opt)))]
    model: Option<String>,

    #[builder(default = None, setter(strip_option(fallback=manufacturer_opt)))]
    manufacturer: Option<String>,
    #[builder(default = None, setter(strip_option(fallback=suggested_area_opt)))]
    suggested_area: Option<String>,
    #[builder(default = None, setter(strip_option(fallback=bluetooth_mac_address_opt)))]
    bluetooth_mac_address: Option<String>,

    #[builder(default = None, setter(strip_option(fallback=project_name_opt)))]
    project_name: Option<String>,

    #[builder(default = None, setter(strip_option(fallback=project_version_opt)))]
    project_version: Option<String>,
    #[builder(default = None, setter(strip_option(fallback=compilation_time_opt)))]
    compilation_time: Option<String>,

    // Capability numbers copied verbatim into `DeviceInfoResponse`; all default to 0.
    #[builder(default = 0)]
    legacy_bluetooth_proxy_version: u32,
    #[builder(default = 0)]
    bluetooth_proxy_feature_flags: u32,
    #[builder(default = 0)]
    legacy_voice_assistant_version: u32,
    #[builder(default = 0)]
    voice_assistant_feature_flags: u32,
}
110
111/// Handles the EspHome Api, with encryption etc.
112impl EspHomeApi {
113    /// Starts the server and returns a broadcast channel for messages, and a
114    /// broadcast receiver for all messages not handled by the abstraction
115    pub async fn start(
116        &mut self,
117        tcp_stream: TcpStream,
118    ) -> Result<
119        (
120            mpsc::Sender<ProtoMessage>,
121            broadcast::Receiver<ProtoMessage>,
122        ),
123        Box<dyn std::error::Error>,
124    > {
125        // Channel for messages
126        let (answer_messages_tx, mut answer_messages_rx) = mpsc::channel::<ProtoMessage>(16);
127        let (outgoing_messages_tx, outgoing_messages_rx) = broadcast::channel::<ProtoMessage>(16);
128
129        #[allow(deprecated)]
130        let device_info = DeviceInfoResponse {
131            api_encryption_supported: self.encryption_key.is_some(),
132            uses_password: false,
133            name: self.name.clone(),
134            mac_address: self.mac.clone().unwrap_or_default(),
135            esphome_version: proto::VERSION.to_owned(),
136            compilation_time: self.compilation_time.clone().unwrap_or_default(),
137            model: self.model.clone().unwrap_or_default(),
138            has_deep_sleep: false,
139            project_name: self.project_name.clone().unwrap_or_default(),
140            project_version: self.project_version.clone().unwrap_or_default(),
141            webserver_port: 0,
142            // See https://github.com/esphome/aioesphomeapi/blob/c1fee2f4eaff84d13ca71996bb272c28b82314fc/aioesphomeapi/model.py#L154
143            legacy_bluetooth_proxy_version: self.legacy_bluetooth_proxy_version,
144            bluetooth_proxy_feature_flags: self.bluetooth_proxy_feature_flags,
145            manufacturer: self.manufacturer.clone().unwrap_or_default(),
146            friendly_name: self.friendly_name.clone().unwrap_or(self.name.clone()),
147            legacy_voice_assistant_version: self.legacy_voice_assistant_version,
148            voice_assistant_feature_flags: self.voice_assistant_feature_flags,
149            suggested_area: self.suggested_area.clone().unwrap_or_default(),
150            bluetooth_mac_address: self.bluetooth_mac_address.clone().unwrap_or_default(),
151            areas: vec![],
152            devices: vec![],
153            area: None,
154            zwave_proxy_feature_flags: 0,
155            zwave_home_id: 0,
156        };
157
158        let hello_response = HelloResponse {
159            api_version_major: self.api_version_major,
160            api_version_minor: self.api_version_minor,
161            server_info: self.server_info.clone(),
162            name: self.name.clone(),
163        };
164
165        let encrypt_cypher_clone = self.encrypt_cypher.clone();
166        let decrypt_cypher_clone = self.decrypt_cypher.clone();
167
168        // Stage 1: Initialization
169        trace!("Init Connection: Stage 1");
170        let encryption_key = self.encryption_key.clone();
171
172        let mut buf = vec![0; 1];
173        let n = tcp_stream
174            .peek(&mut buf)
175            .await
176            .expect("failed to read data from socket");
177
178        if n == 0 {
179            return Err("No data".into());
180        }
181
182        trace!("TCP Peeked: {:02X?}", &buf[0..n]);
183
184        let preamble = buf[0] as usize;
185
186        let first_message_received = self
187            .first_message_received
188            .load(std::sync::atomic::Ordering::Relaxed);
189
190        if !first_message_received {
191            match preamble {
192                0 => {
193                    debug!("Cleartext messaging");
194
195                    self.plaintext_communication
196                        .store(true, std::sync::atomic::Ordering::Relaxed);
197                }
198                1 => {
199                    trace!("Encrypted messaging");
200
201                    self.plaintext_communication
202                        .store(false, std::sync::atomic::Ordering::Relaxed);
203                }
204                _ => {
205                    return Err(format!("Invalid marker byte {}", preamble).into());
206                }
207            }
208            self.first_message_received
209                .store(true, std::sync::atomic::Ordering::Relaxed);
210        }
211
212        let plaintext_communication = self
213            .plaintext_communication
214            .load(std::sync::atomic::Ordering::Relaxed);
215        let encrypted = !plaintext_communication;
216
217        let (tcp_read, tcp_write) = tcp_stream.into_split();
218        let decoder = FrameCodec::new(encrypted);
219        let encoder = FrameCodec::new(encrypted);
220        let mut reader = FramedRead::new(tcp_read, decoder);
221        let mut writer = FramedWrite::new(tcp_write, encoder);
222
223        if plaintext_communication {
224            if self.encryption_key.is_some() {
225                let encoder = FrameCodec::new(true);
226                let writer = FramedWrite::new(writer.into_inner(), encoder);
227                write_error_and_disconnect(writer, ERROR_ONLY_ENCRYPTED).await;
228                return Err(ERROR_ONLY_ENCRYPTED.into());
229            }
230        } else {
231            if self.encryption_key.is_none() {
232                write_error_and_disconnect(writer, "No encrypted communication allowed").await;
233                return Err("No encryption key set, but encrypted communication requested.".into());
234            }
235
236            let frame_noise_hello = reader.next().await.unwrap().unwrap();
237            debug!("Frame 1: {:02X?}", &frame_noise_hello);
238
239            let message_server_hello =
240                packet_encrypted::generate_server_hello_frame(self.name.clone(), self.mac.clone());
241
242            writer.send(message_server_hello.clone()).await.unwrap();
243            writer.flush().await.unwrap();
244
245            let frame_handshake_request = reader.next().await.unwrap().unwrap();
246            debug!("Frame 2: {:02X?}", &frame_handshake_request);
247
248            // Similar to https://github.com/esphome/aioesphomeapi/blob/60bcd1698dd622aeac6f4b5ec448bab0e3467c4f/aioesphomeapi/_frame_helper/noise.py#L248C17-L255
249            let mut handshake_state: HandshakeState<X25519, ChaCha20Poly1305, Sha256> =
250                HandshakeState::new(
251                    noise_nn_psk0(),
252                    false,
253                    // NEXT: This is somehow set from the first api message?
254                    b"NoiseAPIInit\0\0",
255                    None,
256                    None,
257                    None,
258                    None,
259                );
260
261            let noise_psk = BASE64_STANDARD
262                .decode(encryption_key.as_ref().unwrap())
263                .unwrap();
264
265            handshake_state.push_psk(&noise_psk);
266            // Ignore message type byte
267            match handshake_state.read_message_vec(&frame_handshake_request[1..]) {
268                Ok(_) => {}
269                Err(e) => match e.kind() {
270                    ErrorKind::Decryption => {
271                        write_error_and_disconnect(writer, ERROR_HANDSHAKE_MAC_FAILURE).await;
272                        return Err(ERROR_HANDSHAKE_MAC_FAILURE.into());
273                    }
274                    _ => {
275                        debug!("Failed to read message: {}", e);
276                    }
277                },
278            }
279
280            let out = handshake_state.write_message_vec(b"").unwrap();
281            {
282                let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
283                let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
284                let (decrypt_cipher, encrypt_cipher) = handshake_state.get_ciphers();
285                *encrypt_cipher_changer = Some(encrypt_cipher);
286                *decrypt_cipher_changer = Some(decrypt_cipher);
287            }
288
289            let mut message_handshake = vec![0];
290            message_handshake.extend(out);
291
292            debug!("Sending handshake");
293            writer.send(message_handshake.clone()).await.unwrap();
294            writer.flush().await.unwrap();
295        }
296
297        debug!("Initialization done.");
298
299        // Asynchronously wait for an inbound socket.
300        let (cancellation_write_tx, mut cancellation_write_rx) = oneshot::channel();
301
302        // Write Loop
303        let plaintext_communication = self.plaintext_communication.clone();
304        tokio::spawn(async move {
305            loop {
306                let answer_message: ProtoMessage;
307
308                // Wait for any new message
309                tokio::select! {
310                    biased; // Poll cancellation_write_rx first
311                    cancel_message = &mut cancellation_write_rx => {
312                        debug!("Write loop received cancellation signal ({}), exiting.", cancel_message.unwrap());
313                        break;
314                    }
315                    message = answer_messages_rx.recv() => {
316                        answer_message = message.unwrap();
317                    }
318                };
319
320                debug!("Answer message: {:?}", answer_message);
321
322                if plaintext_communication.load(std::sync::atomic::Ordering::Relaxed) {
323                    writer
324                        .send(packet_plaintext::message_to_packet(&answer_message).unwrap())
325                        .await
326                        .unwrap();
327                    // answer_buf =
328                    //     [answer_buf, to_unencrypted_frame(&answer_message).unwrap()].concat();
329                } else {
330                    // Use normal messaging
331                    let mut encrypt_cipher_changer = encrypt_cypher_clone.lock().await;
332                    writer
333                        .send(
334                            packet_encrypted::message_to_packet(
335                                &answer_message,
336                                &mut *encrypt_cipher_changer.as_mut().unwrap(),
337                            )
338                            .unwrap(),
339                        )
340                        .await
341                        .unwrap();
342                }
343                writer.flush().await.unwrap();
344
345                if matches!(answer_message, ProtoMessage::DisconnectResponse(_)) {
346                    debug!("Disconnecting");
347                    let mut tcp_write = writer.into_inner();
348                    match tcp_write.shutdown().await {
349                        Err(err) => {
350                            error!("failed to shutdown socket: {:?}", err);
351                            break;
352                        }
353                        _ => break,
354                    }
355                }
356            }
357        });
358
359        // Clone all necessary data before spawning the task
360        let answer_messages_tx_clone = answer_messages_tx.clone();
361        let decrypt_cypher_clone = self.decrypt_cypher.clone();
362        // Read Loop
363        tokio::spawn(async move {
364            loop {
365                let next = reader.next().await;
366                if next.is_none() {
367                    info!("Read loop stopped because stream finished");
368                    // If sending fails, the write loop is probably already closed
369                    let _ = cancellation_write_tx.send("read loop finished");
370                    break;
371                }
372                let frame = next.unwrap().unwrap();
373                trace!("TCP Receive: {:02X?}", &frame);
374
375                let message;
376                if encrypted {
377                    let mut decrypt_cipher_changer = decrypt_cypher_clone.lock().await;
378                    message = packet_encrypted::packet_to_message(
379                        &frame,
380                        &mut *decrypt_cipher_changer.as_mut().unwrap(),
381                    )
382                    .unwrap();
383                } else {
384                    message = packet_plaintext::packet_to_message(&frame).unwrap();
385                }
386
387                // Authenticated Messages
388                match &message {
389                    ProtoMessage::DisconnectRequest(disconnect_request) => {
390                        debug!("DisconnectRequest: {:?}", disconnect_request);
391                        let response_message = DisconnectResponse {};
392                        answer_messages_tx_clone
393                            .send(ProtoMessage::DisconnectResponse(response_message))
394                            .await
395                            .unwrap();
396                        continue;
397                    }
398                    ProtoMessage::PingRequest(ping_request) => {
399                        debug!("PingRequest: {:?}", ping_request);
400                        let response_message = PingResponse {};
401                        answer_messages_tx_clone
402                            .send(ProtoMessage::PingResponse(response_message))
403                            .await
404                            .unwrap();
405                    }
406                    ProtoMessage::DeviceInfoRequest(device_info_request) => {
407                        debug!("DeviceInfoRequest: {:?}", device_info_request);
408                        answer_messages_tx_clone
409                            .send(ProtoMessage::DeviceInfoResponse(device_info.clone()))
410                            .await
411                            .unwrap();
412                    }
413                    ProtoMessage::HelloRequest(hello_request) => {
414                        debug!("HelloRequest: {:?}", hello_request);
415
416                        answer_messages_tx_clone
417                            .send(ProtoMessage::HelloResponse(hello_response.clone()))
418                            .await
419                            .unwrap();
420                    }
421                    ProtoMessage::AuthenticationRequest(_) => {
422                        info!("Password Authentication is not supported");
423                    }
424                    message => {
425                        outgoing_messages_tx.send(message.clone()).unwrap();
426                    }
427                }
428            }
429        });
430
431        Ok((answer_messages_tx.clone(), outgoing_messages_rx))
432    }
433}
434
#[cfg(test)]
mod tests {
    // The child module can read the parent's private fields.
    use super::*;

    /// Building with only a name must apply every documented builder default.
    #[test]
    fn test_basic_server_instantiation() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .build();

        assert_eq!(api.name, "test_device");
        assert!(api.encryption_key.is_none());
        assert_eq!(api.api_version_major, 1);
        assert_eq!(api.api_version_minor, 10);
        assert_eq!(api.server_info, "Rust: esphome-native-api");
        assert!(api.friendly_name.is_none());
        assert!(!api
            .first_message_received
            .load(std::sync::atomic::Ordering::Relaxed));
        assert!(api
            .plaintext_communication
            .load(std::sync::atomic::Ordering::Relaxed));
        assert!(api
            .encrypt_cypher
            .try_lock()
            .expect("mutex is uncontended")
            .is_none());
        assert!(api
            .decrypt_cypher
            .try_lock()
            .expect("mutex is uncontended")
            .is_none());
    }

    /// The `strip_option` setter stores the key wrapped in `Some`.
    #[test]
    fn test_encryption_key_setter() {
        let api = EspHomeApi::builder()
            .name("test_device".to_string())
            .encryption_key("AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=".to_string())
            .build();

        assert_eq!(
            api.encryption_key.as_deref(),
            Some("AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=")
        );
    }
}