1mod noise;
8mod plain;
9
10mod stream_reader;
11mod stream_writer;
12use std::{fmt::Debug, time::Duration};
13
14use stream_reader::StreamReader;
15use stream_writer::StreamWriter;
16use tokio::time::timeout;
17
18use crate::{
19 error::{ClientError, ProtocolError},
20 proto::{ConnectRequest, DisconnectRequest, EspHomeMessage, HelloRequest, PingResponse},
21 API_VERSION,
22};
23
/// Reader and writer halves of one transport connection, as produced by
/// `plain::connect` / `noise::connect`. Index `.0` is read from, `.1` is written to.
type StreamPair = (StreamReader, StreamWriter);
25
/// An open connection to an ESPHome API endpoint.
///
/// Obtained through [`EspHomeClient::builder`]. Messages are sent with
/// [`EspHomeClient::try_write`] and received with [`EspHomeClient::try_read`].
#[derive(Debug)]
pub struct EspHomeClient {
    /// Reader (`.0`) and writer (`.1`) halves of the transport.
    streams: StreamPair,
    /// When `true`, `try_read` answers incoming `PingRequest`s itself instead of
    /// returning them to the caller.
    handle_ping: bool,
}
31
32impl EspHomeClient {
33 #[must_use]
35 pub fn builder() -> EspHomeClientBuilder {
36 EspHomeClientBuilder::new()
37 }
38
39 pub async fn try_write<M>(&mut self, message: M) -> Result<(), ClientError>
45 where
46 M: Into<EspHomeMessage> + Debug,
47 {
48 tracing::debug!("Send: {message:?}");
49 let message: EspHomeMessage = message.into();
50 let payload: Vec<u8> = message.into();
51 self.streams.1.write_message(payload).await
52 }
53
54 pub async fn try_read(&mut self) -> Result<EspHomeMessage, ClientError> {
62 loop {
63 let payload = self.streams.0.read_next_message().await?;
64 let message: EspHomeMessage =
65 payload
66 .clone()
67 .try_into()
68 .map_err(|e| ProtocolError::ValidationFailed {
69 reason: format!("Failed to decode EspHomeMessage: {e}"),
70 })?;
71 tracing::debug!("Receive: {message:?}");
72 match message {
73 EspHomeMessage::PingRequest(_) if self.handle_ping => {
74 self.try_write(PingResponse {}).await?;
75 }
76 msg => return Ok(msg),
77 }
78 }
79 }
80
81 pub async fn close(mut self) -> Result<(), ClientError> {
87 self.try_write(DisconnectRequest {}).await?;
88 Ok(())
90 }
91
92 #[must_use]
94 pub fn write_stream(&self) -> EspHomeClientWriteStream {
95 EspHomeClientWriteStream {
96 writer: self.streams.1.clone(),
97 }
98 }
99}
100
/// Cloneable, write-only handle obtained from [`EspHomeClient::write_stream`].
///
/// Unlike [`EspHomeClient::try_write`], its `try_write` takes `&self`.
#[derive(Debug, Clone)]
pub struct EspHomeClientWriteStream {
    /// Clone of the originating client's stream writer.
    writer: StreamWriter,
}
105impl EspHomeClientWriteStream {
106 pub async fn try_write<M>(&self, message: M) -> Result<(), ClientError>
112 where
113 M: Into<EspHomeMessage> + Debug,
114 {
115 tracing::debug!("Send: {message:?}");
116 let message: EspHomeMessage = message.into();
117 let payload: Vec<u8> = message.into();
118 self.writer.write_message(payload).await
119 }
120}
121
/// Builder that configures and opens an [`EspHomeClient`].
///
/// Created with [`EspHomeClient::builder`]. Finished with
/// [`EspHomeClientBuilder::connect`].
#[derive(Debug)]
pub struct EspHomeClientBuilder {
    /// Address passed to the transport's connect function. Required; `connect`
    /// fails with a configuration error when it is `None`.
    addr: Option<String>,
    /// When set, `connect` uses `noise::connect` with this key; otherwise
    /// `plain::connect` is used.
    key: Option<String>,
    /// Password sent in the `ConnectRequest`; `None` sends an empty string.
    password: Option<String>,
    /// Client identification sent in the `HelloRequest`. Defaults to
    /// `"<crate name>:<crate version>"`.
    client_info: String,
    /// Time limit applied by `connect` (default 30 seconds).
    timeout: Duration,
    /// Whether `connect` performs the Hello/Connect handshake (default `true`).
    connection_setup: bool,
    /// Whether the resulting client auto-answers `PingRequest`s (default `true`).
    handle_ping: bool,
}
132
133impl EspHomeClientBuilder {
134 fn new() -> Self {
135 Self {
136 addr: None,
137 key: None,
138 password: None,
139 client_info: format!("{}:{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")),
140 timeout: Duration::from_secs(30),
141 connection_setup: true,
142 handle_ping: true,
143 }
144 }
145
146 #[must_use]
150 pub fn address(mut self, addr: &str) -> Self {
151 self.addr = Some(addr.to_owned());
152 self
153 }
154
155 #[must_use]
160 pub fn key(mut self, key: &str) -> Self {
161 self.key = Some(key.to_owned());
162 self
163 }
164
165 #[must_use]
170 pub fn password(mut self, password: &str) -> Self {
171 self.password = Some(password.to_owned());
172 self
173 }
174
175 #[must_use]
177 pub const fn timeout(mut self, timeout: Duration) -> Self {
178 self.timeout = timeout;
179 self
180 }
181
182 #[must_use]
187 pub fn client_info(mut self, client_info: &str) -> Self {
188 client_info.clone_into(&mut self.client_info);
189 self
190 }
191
192 #[must_use]
200 pub const fn without_connection_setup(mut self) -> Self {
201 self.connection_setup = false;
202 self
203 }
204
205 #[must_use]
210 pub const fn without_ping_handling(mut self) -> Self {
211 self.handle_ping = false;
212 self
213 }
214
215 pub async fn connect(self) -> Result<EspHomeClient, ClientError> {
221 let addr = self.addr.ok_or_else(|| ClientError::Configuration {
222 message: "Address is not set".into(),
223 })?;
224
225 let streams = timeout(self.timeout, async {
226 match self.key {
227 Some(key) => noise::connect(&addr, &key).await,
228 None => plain::connect(&addr).await,
229 }
230 })
231 .await
232 .map_err(|_e| ClientError::Timeout {
233 timeout_ms: self.timeout.as_millis(),
234 })??;
235
236 let mut stream = EspHomeClient {
237 streams,
238 handle_ping: self.handle_ping,
239 };
240 if self.connection_setup {
241 Self::connection_setup(&mut stream, self.client_info, self.password).await?;
242 }
243 Ok(stream)
244 }
245
246 async fn connection_setup(
250 stream: &mut EspHomeClient,
251 client_info: String,
252 password: Option<String>,
253 ) -> Result<(), ClientError> {
254 stream
255 .try_write(HelloRequest {
256 client_info,
257 api_version_major: API_VERSION.0,
258 api_version_minor: API_VERSION.1,
259 })
260 .await?;
261 loop {
262 let response = stream.try_read().await?;
263 match response {
264 EspHomeMessage::HelloResponse(response) => {
265 if response.api_version_major != API_VERSION.0 {
266 return Err(ClientError::ProtocolMismatch {
267 expected: format!("{}.{}", API_VERSION.0, API_VERSION.1),
268 actual: format!(
269 "{}.{}",
270 response.api_version_major, response.api_version_minor
271 ),
272 });
273 }
274 if response.api_version_minor != API_VERSION.1 {
275 tracing::warn!(
276 "API version mismatch: expected {}.{}, got {}.{}, expect breaking changes in messages",
277 API_VERSION.0,
278 API_VERSION.1,
279 response.api_version_major,
280 response.api_version_minor
281 );
282 }
283 break;
284 }
285 _ => {
286 tracing::debug!("Unexpected response during connection setup: {response:?}");
287 }
288 }
289 }
290 stream
291 .try_write(ConnectRequest {
292 password: password.unwrap_or_default(),
293 })
294 .await?;
295 loop {
296 let response = stream.try_read().await?;
297 match response {
298 EspHomeMessage::ConnectResponse(response) => {
299 if response.invalid_password {
300 return Err(ClientError::Authentication {
301 reason: "Invalid password".to_owned(),
302 });
303 }
304 tracing::info!("Connection to ESPHome API established successfully.");
305 break;
306 }
307 _ => {
308 tracing::debug!("Unexpected response during connection setup: {response:?}");
309 }
310 }
311 }
312 Ok(())
313 }
314}