1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::{Command, Status, UpdateService};
use core::future::Future;
use embedded_nal_async::{SocketAddr, TcpConnect};
use rand_core::{CryptoRng, RngCore};
use reqwless::{
    client::{Error as HttpError, HttpClient},
    request::{ContentType, Request, Status as ResponseStatus},
};
use serde::Serialize;

#[cfg(feature = "tls")]
use embedded_tls::*;

/// An update service implementation for the Drogue Cloud update service.
pub struct DrogueHttp<'a, T, RNG, const MTU: usize>
where
    T: TcpConnect + 'a,
    RNG: RngCore + CryptoRng + 'a,
{
    client: T,
    rng: RNG,
    addr: SocketAddr,
    host: &'a str,
    username: &'a str,
    password: &'a str,
    buf: [u8; MTU],
}

impl<'a, T, RNG, const MTU: usize> DrogueHttp<'a, T, RNG, MTU>
where
    T: TcpConnect + 'a,
    RNG: RngCore + CryptoRng + 'a,
{
    /// Create a client for the Drogue Cloud update service.
    ///
    /// * `client` - connector used to reach the service at `addr`
    /// * `rng` - random number generator kept for the connection setup
    /// * `host` - the service's host name
    /// * `username` / `password` - HTTP basic-auth credentials
    ///
    /// The internal `MTU`-byte buffer starts out zeroed.
    pub fn new(client: T, rng: RNG, addr: SocketAddr, host: &'a str, username: &'a str, password: &'a str) -> Self {
        let zeroed = [0u8; MTU];
        Self {
            buf: zeroed,
            password,
            username,
            host,
            addr,
            rng,
            client,
        }
    }
}

/// Failure reported by the Drogue Cloud update service client.
///
/// The generic parameters, in order, are the error types of the network layer (`N`),
/// the HTTP client (`H`), the payload codec (`S`) and the TLS layer (`T`).
#[derive(Debug)]
pub enum Error<N, H, S, T> {
    /// The underlying network connection failed.
    Network(N),
    /// The HTTP client reported a failure.
    Http(H),
    /// Encoding the outgoing payload or decoding the incoming payload failed.
    Codec(S),
    /// The TLS layer reported a failure.
    Tls(T),
    /// The firmware update protocol was violated (for example, the server answered with
    /// a status that is not accepted).
    Protocol,
}

impl<'a, T, RNG, const MTU: usize> UpdateService for DrogueHttp<'a, T, RNG, MTU>
where
    T: TcpConnect + 'a,
    RNG: RngCore + CryptoRng + 'a,
{
    #[cfg(feature = "tls")]
    type Error = Error<T::Error, HttpError, serde_cbor::Error, TlsError>;

    #[cfg(not(feature = "tls"))]
    type Error = Error<T::Error, HttpError, serde_cbor::Error, ()>;

    type RequestFuture<'m> = impl Future<Output = Result<Command<'m>, Self::Error>> + 'm where Self: 'm;
    fn request<'m>(&'m mut self, status: &'m Status<'m>) -> Self::RequestFuture<'m> {
        async move {
            #[allow(unused_mut)]
            let mut connection = self.client.connect(self.addr).await.map_err(Error::Network)?;

            #[cfg(feature = "tls")]
            let mut tls_buffer = [0; 6000];

            #[cfg(feature = "tls")]
            let mut connection = {
                let mut connection: TlsConnection<'_, _, Aes128GcmSha256> =
                    TlsConnection::new(connection, &mut tls_buffer);
                connection
                    .open::<_, NoClock, 1>(TlsContext::new(
                        &TlsConfig::new().with_server_name(self.host),
                        &mut self.rng,
                    ))
                    .await
                    .map_err(Error::Tls)?;
                connection
            };
            let mut client = HttpClient::new(&mut connection, self.host);

            let mut payload = [0; 64];
            let writer = serde_cbor::ser::SliceWrite::new(&mut payload[..]);
            let mut ser = serde_cbor::Serializer::new(writer).packed_format();
            status.serialize(&mut ser).map_err(Error::Codec)?;
            let writer = ser.into_inner();
            let size = writer.bytes_written();
            debug!("Status payload is {} bytes", size);

            let request = Request::post()
                .path("/v1/dfu?ct=30")
                .payload(&payload[..size])
                .basic_auth(self.username, self.password)
                .content_type(ContentType::ApplicationCbor)
                .build();

            let mut rx_buf = [0; MTU];
            let response = client.request(request, &mut rx_buf).await.map_err(Error::Http)?;

            if response.status == ResponseStatus::Ok
                || response.status == ResponseStatus::Accepted
                || response.status == ResponseStatus::Created
            {
                if let Some(payload) = response.payload {
                    self.buf[..payload.len()].copy_from_slice(payload);
                    let command: Command<'m> =
                        serde_cbor::de::from_mut_slice(&mut self.buf[..payload.len()]).map_err(Error::Codec)?;
                    Ok(command)
                } else {
                    Ok(Command::new_wait(Some(10), None))
                }
            } else {
                Err(Error::Protocol)
            }
        }
    }
}