1use crate::{
2 error::Error,
3 header::HttpHeader,
4 method::HttpMethod,
5 options::HttpClientOptions,
6 response::{HttpResponse, ResponseBody},
7 status_code::StatusCode,
8};
9#[cfg(feature = "tls")]
10use defmt::debug;
11use defmt::error;
12use embassy_net::{
13 Stack,
14 dns::{self, DnsSocket},
15 tcp::TcpSocket,
16};
17use embassy_net_08 as embassy_net;
18#[cfg(feature = "tls")]
19use embassy_time::Instant;
20use embassy_time::Timer;
21use embassy_time_05 as embassy_time;
22use embedded_io_async::Write as EmbeddedWrite;
23use embedded_io_async_07 as embedded_io_async;
24
25#[cfg(feature = "tls")]
26use embedded_tls::{Aes128GcmSha256, NoVerify, TlsConfig, TlsConnection, TlsContext};
27#[cfg(feature = "tls")]
28use embedded_tls_018 as embedded_tls;
29use heapless::Vec;
30#[cfg(feature = "tls")]
31use rand_chacha::ChaCha8Rng;
32#[cfg(feature = "tls")]
33use rand_chacha_03 as rand_chacha;
34#[cfg(feature = "tls")]
35use rand_core::SeedableRng;
36#[cfg(feature = "tls")]
37use rand_core_06 as rand_core;
38
/// Default capacity, in bytes, of the `heapless::String` the request head is serialized
/// into (request line, `Host`, caller headers, `Content-Length`, `Connection`).
/// The request body is written separately and does not count against this limit.
const REQUEST_SIZE: usize = 1024;
/// Maximum number of response headers kept when parsing; further headers are dropped.
const MAX_HEADERS: usize = 16;
/// Size, in bytes, of each TCP/TLS buffer used by [`SmallHttpClient`].
const SMALL_BUFFER_SIZE: usize = 1024;
/// Size, in bytes, of each TCP/TLS buffer used by [`DefaultHttpClient`] and by the
/// default const parameters of [`HttpClient`].
const MEDIUM_BUFFER_SIZE: usize = 4096;
43
/// [`HttpClient`] whose TCP receive/transmit buffers and TLS read/write record buffers
/// are each `MEDIUM_BUFFER_SIZE` (4096) bytes, with a `REQUEST_SIZE` (1024) byte
/// request-head buffer. The buffers are local arrays created for every request.
pub type DefaultHttpClient<'a> = HttpClient<
    'a,
    MEDIUM_BUFFER_SIZE,
    MEDIUM_BUFFER_SIZE,
    MEDIUM_BUFFER_SIZE,
    MEDIUM_BUFFER_SIZE,
    REQUEST_SIZE,
>;
53
/// [`HttpClient`] whose TCP receive/transmit buffers and TLS read/write record buffers
/// are each `SMALL_BUFFER_SIZE` (1024) bytes, with a `REQUEST_SIZE` (1024) byte
/// request-head buffer.
///
/// NOTE(review): 1 KiB TLS record buffers are smaller than a full TLS record; confirm
/// the target servers negotiate/emit records that fit before relying on this for HTTPS.
pub type SmallHttpClient<'a> = HttpClient<
    'a,
    SMALL_BUFFER_SIZE,
    SMALL_BUFFER_SIZE,
    SMALL_BUFFER_SIZE,
    SMALL_BUFFER_SIZE,
    REQUEST_SIZE,
>;
63
/// Evaluates `$expr` (a `Result`) and, if it is `Err`, returns
/// `Err(Error::InvalidResponse("Request buffer overflow"))` from the enclosing function.
///
/// Used while serializing the request head into a fixed-capacity `heapless::String`,
/// where every push fails once the buffer is full.
macro_rules! try_push {
    ($expr:expr) => {
        $expr.map_err(|_| Error::InvalidResponse("Request buffer overflow"))?
    };
}
71
/// Minimal HTTP/1.1 client running on an `embassy-net` [`Stack`].
///
/// Every request resolves the host, opens a fresh TCP (optionally TLS) connection and
/// sends `Connection: close`. The const parameters size the buffers created per request:
///
/// - `TCP_RX` / `TCP_TX`: TCP socket receive / transmit buffers.
/// - `TLS_READ` / `TLS_WRITE`: TLS record buffers (only used with the `tls` feature).
/// - `RQ`: capacity of the string the request head is serialized into.
pub struct HttpClient<
    'a,
    const TCP_RX: usize = MEDIUM_BUFFER_SIZE,
    const TCP_TX: usize = MEDIUM_BUFFER_SIZE,
    const TLS_READ: usize = MEDIUM_BUFFER_SIZE,
    const TLS_WRITE: usize = MEDIUM_BUFFER_SIZE,
    const RQ: usize = REQUEST_SIZE,
> {
    /// Network stack used for DNS queries and TCP sockets.
    stack: &'a Stack<'a>,
    /// Socket timeout, retry count and delays applied to every request.
    options: HttpClientOptions,
}
102
103impl<
104 'a,
105 const TCP_RX: usize,
106 const TCP_TX: usize,
107 const TLS_READ: usize,
108 const TLS_WRITE: usize,
109 const RQ: usize,
110> HttpClient<'a, TCP_RX, TCP_TX, TLS_READ, TLS_WRITE, RQ>
111{
112 #[must_use]
114 pub fn new(stack: &'a Stack<'a>) -> Self {
115 Self {
116 stack,
117 options: HttpClientOptions::default(),
118 }
119 }
120
121 #[must_use]
123 pub fn with_options(stack: &'a Stack<'a>, options: HttpClientOptions) -> Self {
124 Self { stack, options }
125 }
126
127 pub async fn request<'b>(
183 &self,
184 method: HttpMethod,
185 endpoint: &str,
186 headers: &[HttpHeader<'_>],
187 body: Option<&[u8]>,
188 response_buffer: &'b mut [u8],
189 ) -> Result<(HttpResponse<'b>, usize), Error> {
190 let (scheme, host_port) = if let Some(rest) = endpoint.strip_prefix("http://") {
191 ("http", rest)
192 } else if let Some(rest) = endpoint.strip_prefix("https://") {
193 ("https", rest)
194 } else {
195 return Err(Error::InvalidUrl);
196 };
197
198 let mut url_parts = heapless::Vec::<&str, 8>::new();
199 for part in host_port.splitn(8, '/') {
200 if url_parts.push(part).is_err() {
201 break;
202 }
203 }
204 if url_parts.is_empty() {
205 return Err(Error::InvalidUrl);
206 }
207
208 let host = url_parts[0];
209 let path = &host_port[host.len()..];
210
211 let (host, port) = if let Some(colon_pos) = host.rfind(':') {
212 if let Ok(port) = host[colon_pos + 1..].parse::<u16>() {
213 (&host[..colon_pos], port)
214 } else {
215 (host, if scheme == "https" { 443 } else { 80 })
216 }
217 } else {
218 (host, if scheme == "https" { 443 } else { 80 })
219 };
220
221 let total_read = match scheme {
222 #[cfg(feature = "tls")]
223 "https" => {
224 self.make_https_request(method, (host, port), path, headers, body, response_buffer)
225 .await?
226 }
227 #[cfg(not(feature = "tls"))]
228 "https" => return Err(Error::UnsupportedScheme("https (TLS support not enabled)")),
229 "http" => {
230 self.make_http_request(method, (host, port), path, headers, body, response_buffer)
231 .await?
232 }
233 _ => return Err(Error::UnsupportedScheme(scheme)),
234 };
235
236 let response = Self::parse_http_response_zero_copy(&response_buffer[..total_read])?;
237 Ok((response, total_read))
238 }
239
240 #[cfg(feature = "tls")]
242 async fn make_https_request(
243 &self,
244 method: HttpMethod,
245 host_port: (&str, u16),
246 path: &str,
247 headers: &[HttpHeader<'_>],
248 body: Option<&[u8]>,
249 response_buffer: &mut [u8],
250 ) -> Result<usize, Error> {
251 use embedded_tls_018::UnsecureProvider;
252
253 let (host, port) = host_port;
254 let mut rx_buffer = [0; TCP_RX];
255 let mut tx_buffer = [0; TCP_TX];
256 let mut socket = TcpSocket::new(*self.stack, &mut rx_buffer, &mut tx_buffer);
257 socket.set_timeout(Some(self.options.socket_timeout));
258
259 let dns_socket = DnsSocket::new(*self.stack);
260 let ip_addresses = dns_socket.query(host, dns::DnsQueryType::A).await?;
261
262 if ip_addresses.is_empty() {
263 return Err(Error::IpAddressEmpty);
264 }
265
266 let ip_addr = ip_addresses[0];
267 let remote_endpoint = (ip_addr, port);
268
269 socket
270 .connect(remote_endpoint)
271 .await
272 .map_err(|e: embassy_net::tcp::ConnectError| {
273 socket.abort();
274 Error::from(e)
275 })?;
276
277 let mut read_record_buffer = [0; TLS_READ];
278 let mut write_record_buffer = [0; TLS_WRITE];
279
280 let tls_config = TlsConfig::new().with_server_name(host);
281 let mut tls = TlsConnection::new(socket, &mut read_record_buffer, &mut write_record_buffer);
282 let mut rng = ChaCha8Rng::from_seed(timeseed());
283
284 tls.open(TlsContext::new(
288 &tls_config,
289 UnsecureProvider::new::<Aes128GcmSha256>(rng),
290 ))
291 .await
292 .expect("error establishing TLS connection");
293
294 let http_request = Self::build_http_request(method, host, path, headers, body)?;
295
296 tls.write_all(http_request.as_bytes()).await?;
297
298 if let Some(body_data) = body {
299 tls.write_all(body_data).await?;
300 }
301
302 tls.flush().await?;
303
304 let mut total_read = 0;
305 let mut retries = self.options.max_retries;
306
307 while total_read < response_buffer.len() && retries > 0 {
308 match tls.read(&mut response_buffer[total_read..]).await {
309 Ok(0) => {
310 break;
311 }
312 Ok(n) => {
313 total_read += n;
314 if Self::is_response_complete(&response_buffer[..total_read]) {
315 break;
316 }
317 }
318 Err(e) => {
319 retries -= 1;
320 if retries > 0 {
321 Timer::after(self.options.retry_delay).await;
322 } else {
323 return Err(Error::TlsError(e));
324 }
325 }
326 }
327 }
328
329 if let Err((_, e)) = tls.close().await {
330 debug!("Error closing TLS connection: {:?}", Error::from(e));
331 }
332
333 Timer::after(self.options.socket_close_delay).await;
334
335 if total_read == 0 {
336 return Err(Error::NoResponse);
337 }
338
339 Ok(total_read)
340 }
341
342 async fn make_http_request(
344 &self,
345 method: HttpMethod,
346 host_port: (&str, u16),
347 path: &str,
348 headers: &[HttpHeader<'_>],
349 body: Option<&[u8]>,
350 response_buffer: &mut [u8],
351 ) -> Result<usize, Error> {
352 let (host, port) = host_port;
353 let mut rx_buffer = [0; TCP_RX];
354 let mut tx_buffer = [0; TCP_TX];
355 let mut socket = TcpSocket::new(*self.stack, &mut rx_buffer, &mut tx_buffer);
356 socket.set_timeout(Some(self.options.socket_timeout));
357
358 let dns_socket = DnsSocket::new(*self.stack);
359 let ip_addresses = dns_socket.query(host, dns::DnsQueryType::A).await?;
360
361 if ip_addresses.is_empty() {
362 return Err(Error::IpAddressEmpty);
363 }
364
365 let ip_addr = ip_addresses[0];
366 let remote_endpoint = (ip_addr, port);
367
368 socket
369 .connect(remote_endpoint)
370 .await
371 .map_err(|e: embassy_net::tcp::ConnectError| {
372 socket.abort();
373 Error::from(e)
374 })?;
375
376 let http_request = Self::build_http_request(method, host, path, headers, body)?;
377
378 socket
379 .write_all(http_request.as_bytes())
380 .await
381 .map_err(|e| {
382 socket.abort();
383 Error::from(e)
384 })?;
385
386 if let Some(body_data) = body {
387 socket.write_all(body_data).await.map_err(|e| {
388 socket.abort();
389 Error::from(e)
390 })?;
391 }
392
393 let mut total_read = 0;
394 let mut retries = self.options.max_retries;
395
396 while total_read < response_buffer.len() && retries > 0 {
397 match socket.read(&mut response_buffer[total_read..]).await {
398 Ok(0) => {
399 break;
400 }
401 Ok(n) => {
402 total_read += n;
403 if Self::is_response_complete(&response_buffer[..total_read]) {
404 break;
405 }
406 }
407 Err(e) => {
408 error!("Socket read error: {:?}", defmt::Debug2Format(&e));
409 retries -= 1;
410 if retries > 0 {
411 Timer::after(self.options.retry_delay).await;
412 }
413 }
414 }
415 }
416
417 socket.close();
418 Timer::after(self.options.socket_close_delay).await;
419
420 if total_read == 0 {
421 return Err(Error::NoResponse);
422 }
423
424 Ok(total_read)
425 }
426
427 pub async fn patch<'b>(
442 &self,
443 endpoint: &str,
444 headers: &[HttpHeader<'_>],
445 body: &[u8],
446 response_buffer: &'b mut [u8],
447 ) -> Result<(HttpResponse<'b>, usize), Error> {
448 self.request(
449 HttpMethod::PATCH,
450 endpoint,
451 headers,
452 Some(body),
453 response_buffer,
454 )
455 .await
456 }
457
458 pub async fn head<'b>(
472 &self,
473 endpoint: &str,
474 headers: &[HttpHeader<'_>],
475 response_buffer: &'b mut [u8],
476 ) -> Result<(HttpResponse<'b>, usize), Error> {
477 self.request(HttpMethod::HEAD, endpoint, headers, None, response_buffer)
478 .await
479 }
480
481 pub async fn options<'b>(
495 &self,
496 endpoint: &str,
497 headers: &[HttpHeader<'_>],
498 response_buffer: &'b mut [u8],
499 ) -> Result<(HttpResponse<'b>, usize), Error> {
500 self.request(
501 HttpMethod::OPTIONS,
502 endpoint,
503 headers,
504 None,
505 response_buffer,
506 )
507 .await
508 }
509
510 pub async fn trace<'b>(
524 &self,
525 endpoint: &str,
526 headers: &[HttpHeader<'_>],
527 response_buffer: &'b mut [u8],
528 ) -> Result<(HttpResponse<'b>, usize), Error> {
529 self.request(HttpMethod::TRACE, endpoint, headers, None, response_buffer)
530 .await
531 }
532
533 pub async fn connect<'b>(
547 &self,
548 endpoint: &str,
549 headers: &[HttpHeader<'_>],
550 response_buffer: &'b mut [u8],
551 ) -> Result<(HttpResponse<'b>, usize), Error> {
552 self.request(
553 HttpMethod::CONNECT,
554 endpoint,
555 headers,
556 None,
557 response_buffer,
558 )
559 .await
560 }
561
562 pub async fn get<'b>(
576 &self,
577 endpoint: &str,
578 headers: &[HttpHeader<'_>],
579 response_buffer: &'b mut [u8],
580 ) -> Result<(HttpResponse<'b>, usize), Error> {
581 self.request(HttpMethod::GET, endpoint, headers, None, response_buffer)
582 .await
583 }
584
585 pub async fn post<'b>(
600 &self,
601 endpoint: &str,
602 headers: &[HttpHeader<'_>],
603 body: &[u8],
604 response_buffer: &'b mut [u8],
605 ) -> Result<(HttpResponse<'b>, usize), Error> {
606 self.request(
607 HttpMethod::POST,
608 endpoint,
609 headers,
610 Some(body),
611 response_buffer,
612 )
613 .await
614 }
615
616 pub async fn put<'b>(
631 &self,
632 endpoint: &str,
633 headers: &[HttpHeader<'_>],
634 body: &[u8],
635 response_buffer: &'b mut [u8],
636 ) -> Result<(HttpResponse<'b>, usize), Error> {
637 self.request(
638 HttpMethod::PUT,
639 endpoint,
640 headers,
641 Some(body),
642 response_buffer,
643 )
644 .await
645 }
646
647 pub async fn delete<'b>(
661 &self,
662 endpoint: &str,
663 headers: &[HttpHeader<'_>],
664 response_buffer: &'b mut [u8],
665 ) -> Result<(HttpResponse<'b>, usize), Error> {
666 self.request(HttpMethod::DELETE, endpoint, headers, None, response_buffer)
667 .await
668 }
669
670 fn parse_http_response_zero_copy(data: &[u8]) -> Result<HttpResponse<'_>, Error> {
672 let response_str = core::str::from_utf8(data)
673 .map_err(|_| Error::InvalidResponse("Invalid HTTP response encoding"))?;
674
675 let status_line_end = response_str
676 .find("\r\n")
677 .ok_or(Error::InvalidResponse("Invalid HTTP response format"))?;
678
679 let status_line = &response_str[..status_line_end];
680 let status_code_str = status_line
681 .split_whitespace()
682 .nth(1)
683 .ok_or(Error::InvalidResponse("Invalid HTTP status line"))?;
684
685 let status_code: StatusCode = status_code_str.try_into()?;
686
687 let headers_end = response_str
688 .find("\r\n\r\n")
689 .ok_or(Error::InvalidResponse("Invalid HTTP response format"))?
690 + 4;
691
692 let headers_section = &response_str[status_line_end + 2..headers_end - 4];
693 let mut headers = Vec::<HttpHeader<'_>, MAX_HEADERS>::new();
694
695 for header_line in headers_section.split("\r\n") {
696 if let Some(colon_pos) = header_line.find(':') {
697 let name = header_line[..colon_pos].trim();
698 let value = header_line[colon_pos + 1..].trim();
699
700 let header = HttpHeader::new(name, value);
701 if headers.push(header).is_err() {
702 break;
703 }
704 }
705 }
706
707 let body_data = if headers_end < data.len() {
708 &data[headers_end..]
709 } else {
710 &[]
711 };
712
713 let body = Self::parse_response_body(&headers, body_data);
715
716 Ok(HttpResponse {
717 status_code,
718 headers,
719 body,
720 })
721 }
722
723 fn parse_response_body<'b>(
725 headers: &[HttpHeader<'_>],
726 body_data: &'b [u8],
727 ) -> ResponseBody<'b> {
728 if body_data.is_empty() {
729 return ResponseBody::Empty;
730 }
731
732 if let Some(content_type) = Self::get_content_type(headers) {
734 if Self::is_text_content_type(content_type) {
735 Self::parse_as_text_or_binary(body_data)
736 } else {
737 ResponseBody::Binary(body_data)
738 }
739 } else {
740 Self::parse_as_text_or_binary(body_data)
742 }
743 }
744
745 fn get_content_type<'h>(headers: &'h [HttpHeader<'_>]) -> Option<&'h str> {
747 headers
748 .iter()
749 .find(|h| h.name.eq_ignore_ascii_case("Content-Type"))
750 .map(|h| h.value)
751 }
752
753 fn is_text_content_type(content_type: &str) -> bool {
755 content_type.starts_with("text/")
756 || content_type.starts_with("application/json")
757 || content_type.starts_with("application/xml")
758 || content_type.starts_with("application/x-www-form-urlencoded")
759 }
760
761 fn parse_as_text_or_binary(body_data: &[u8]) -> ResponseBody<'_> {
763 if let Ok(text) = core::str::from_utf8(body_data) {
764 ResponseBody::Text(text)
765 } else {
766 Self::parse_as_binary(body_data)
767 }
768 }
769
770 fn parse_as_binary(body_data: &[u8]) -> ResponseBody<'_> {
772 ResponseBody::Binary(body_data)
773 }
774
775 fn build_http_request(
777 method: HttpMethod,
778 host: &str,
779 path: &str,
780 headers: &[HttpHeader<'_>],
781 body: Option<&[u8]>,
782 ) -> Result<heapless::String<RQ>, Error> {
783 let mut http_request = heapless::String::<RQ>::new();
784
785 try_push!(http_request.push_str(method.as_str()));
786 try_push!(http_request.push_str(" "));
787 try_push!(http_request.push_str(path));
788 try_push!(http_request.push_str(" HTTP/1.1\r\n"));
789 try_push!(http_request.push_str("Host: "));
790 try_push!(http_request.push_str(host));
791 try_push!(http_request.push_str("\r\n"));
792
793 let mut content_length_present = false;
794
795 for header in headers {
796 try_push!(http_request.push_str(header.name));
797 try_push!(http_request.push_str(": "));
798 try_push!(http_request.push_str(header.value));
799 try_push!(http_request.push_str("\r\n"));
800
801 if header.name.eq_ignore_ascii_case("Content-Length") {
802 content_length_present = true;
803 }
804 }
805
806 if !content_length_present && body.is_some() {
808 try_push!(http_request.push_str("Content-Length: "));
809 let mut len_str = heapless::String::<8>::new();
810 if core::fmt::write(
811 &mut len_str,
812 format_args!("{}", body.unwrap_or_default().len()),
813 )
814 .is_err()
815 {
816 return Err(Error::InvalidResponse("Failed to write content length"));
817 }
818 try_push!(http_request.push_str(&len_str));
819 try_push!(http_request.push_str("\r\n"));
820 }
821
822 try_push!(http_request.push_str("Connection: close\r\n"));
823 try_push!(http_request.push_str("\r\n"));
824
825 Ok(http_request)
826 }
827
828 fn is_response_complete(data: &[u8]) -> bool {
830 let response_str = core::str::from_utf8(data).unwrap_or_default();
831
832 if !response_str.contains("\r\n\r\n") {
833 return false;
834 }
835
836 if let Some(content_length_pos) = response_str.find("Content-Length:") {
838 let content_length_end = response_str[content_length_pos..]
839 .find("\r\n")
840 .unwrap_or_default()
841 + content_length_pos;
842 let content_length_str =
843 &response_str[content_length_pos + 15..content_length_end].trim();
844
845 if let Ok(content_length) = content_length_str.parse::<usize>() {
846 let headers_end = response_str.find("\r\n\r\n").unwrap_or_default() + 4;
847 let body_received = data.len().saturating_sub(headers_end);
848 return body_received >= content_length;
849 }
850 }
851
852 true
853 }
854}
855
/// Builds the 32-byte seed for the TLS `ChaCha8Rng` from the current `embassy_time` tick count.
///
/// The tick count, encoded big-endian, fills bytes 0..8; bytes 8..32 are zero.
///
/// NOTE(review): the seed's only varying input is the tick counter, which is weak
/// entropy for TLS key material — consider mixing in a hardware RNG if the target has one.
#[cfg(feature = "tls")]
fn timeseed() -> [u8; 32] {
    let tick_bytes = Instant::now().as_ticks().to_be_bytes();
    core::array::from_fn(|index| tick_bytes.get(index).copied().unwrap_or(0))
}
863
//! Unit tests for response-completion detection and the client constructors.
#[cfg(test)]
mod tests {
    use super::*;
    use embassy_net::Stack;

    /// A head with no `Content-Length` is treated as a complete response.
    #[test]
    fn test_is_response_complete_headers_only() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert!(DefaultHttpClient::is_response_complete(data));
    }

    /// Exactly `Content-Length` body bytes completes the response.
    #[test]
    fn test_is_response_complete_with_content_length() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
        assert!(DefaultHttpClient::is_response_complete(data));
    }

    /// Fewer body bytes than `Content-Length` is still incomplete.
    #[test]
    fn test_is_response_complete_incomplete() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nshort";
        assert!(!DefaultHttpClient::is_response_complete(data));
    }

    /// `new` uses default options (expected to have `max_retries == 5`, as defined by
    /// `HttpClientOptions::default()` in the options module); `with_options` keeps the
    /// caller's options.
    ///
    /// NOTE(review): `&*fake_stack` creates a `&Stack` from `NonNull::dangling()`, which is
    /// not dereferenceable. Forming such a reference is undefined behavior even though the
    /// stack is never used here — consider constructing a real (or test) stack instead.
    #[test]
    fn test_new_and_with_options() {
        let fake_stack: *const Stack = core::ptr::NonNull::dangling().as_ptr();
        let client = DefaultHttpClient::new(unsafe { &*fake_stack });
        let opts = HttpClientOptions {
            max_retries: 1,
            socket_timeout: embassy_time::Duration::from_secs(1),
            retry_delay: embassy_time::Duration::from_millis(1),
            socket_close_delay: embassy_time::Duration::from_millis(1),
        };
        let client2 = DefaultHttpClient::with_options(unsafe { &*fake_stack }, opts);
        assert_eq!(client.options.max_retries, 5);
        assert_eq!(client2.options.max_retries, 1);
    }

    /// Same constructor checks for `DefaultHttpClient` (same dangling-reference caveat).
    #[test]
    fn test_default_http_client_constructors() {
        let fake_stack: *const Stack = core::ptr::NonNull::dangling().as_ptr();
        let client_default = DefaultHttpClient::new(unsafe { &*fake_stack });
        assert_eq!(client_default.options.max_retries, 5);

        let client_custom = DefaultHttpClient::with_options(
            unsafe { &*fake_stack },
            HttpClientOptions {
                max_retries: 3,
                socket_timeout: embassy_time::Duration::from_secs(2),
                retry_delay: embassy_time::Duration::from_millis(10),
                socket_close_delay: embassy_time::Duration::from_millis(5),
            },
        );
        assert_eq!(client_custom.options.max_retries, 3);
    }

    /// Constructor checks for `SmallHttpClient` (same dangling-reference caveat).
    #[test]
    fn test_small_http_client_constructors() {
        let fake_stack: *const Stack = core::ptr::NonNull::dangling().as_ptr();
        let client_small = SmallHttpClient::new(unsafe { &*fake_stack });
        assert_eq!(client_small.options.max_retries, 5);

        let client_small_custom = SmallHttpClient::with_options(
            unsafe { &*fake_stack },
            HttpClientOptions {
                max_retries: 2,
                socket_timeout: embassy_time::Duration::from_secs(1),
                retry_delay: embassy_time::Duration::from_millis(5),
                socket_close_delay: embassy_time::Duration::from_millis(2),
            },
        );
        assert_eq!(client_small_custom.options.max_retries, 2);
    }
}