1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
use core::cmp::Ordering;
use core::hash::{Hash, Hasher};
use core::ops::Deref;
use serde::{de::Visitor, Deserialize, Serialize};

/// Represents the current state of firmware and firmware being written on a device.
///
/// This is the status message a device reports to the update service; the service answers
/// with a [`Command`]. Field declaration order is part of the wire format: packed serde
/// encodings (such as `serde_cbor::ser::to_vec_packed`) identify struct fields by index
/// rather than by name, so do not reorder these fields.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Status<'a> {
    /// The current version of the firmware.
    #[serde(borrow)]
    pub version: Bytes<'a>,
    /// The max firmware block size to be sent back. The update service must ensure it does not send larger blocks.
    pub mtu: Option<u32>,
    /// A correlation id which the update service will use when sending commands back. Used mainly when you need to multiplex multiple devices (in a gateway).
    pub correlation_id: Option<u32>,
    /// The status of the firmware being written to a device, or `None` if no update is in progress.
    pub update: Option<UpdateStatus<'a>>,
}

/// The status of the firmware being written to a device.
///
/// Carried in [`Status::update`]. Field declaration order is part of the wire format under
/// packed serde encodings (which identify fields by index), so do not reorder these fields.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct UpdateStatus<'a> {
    /// The version of the firmware being written to the device.
    #[serde(borrow)]
    pub version: Bytes<'a>,
    /// The expected next block offset to be written.
    pub offset: u32,
}

impl<'a> Status<'a> {
    /// Create an initial status update where no firmware have been written yet.
    pub fn first(version: &'a [u8], mtu: Option<u32>, correlation_id: Option<u32>) -> Self {
        Self {
            version: Bytes::new(version),
            mtu,
            correlation_id,
            update: None,
        }
    }

    /// Create a status update containing information about the firmware being written in addition to the existing firmware.
    pub fn update(
        version: &'a [u8],
        mtu: Option<u32>,
        offset: u32,
        next_version: &'a [u8],
        correlation_id: Option<u32>,
    ) -> Self {
        Self {
            version: Bytes::new(version),
            mtu,
            correlation_id,
            update: Some(UpdateStatus {
                offset,
                version: Bytes::new(next_version),
            }),
        }
    }
}

/// Represents a command issued from the update service to a device.
///
/// Variant and field declaration order is part of the wire format: packed serde encodings
/// (such as `serde_cbor::ser::to_vec_packed`) identify variants and fields by index rather
/// than by name, so do not reorder variants or their fields.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Command<'a> {
    /// Instruct the device to wait and send its status update at a later time.
    Wait {
        /// Correlation id matching the id sent in the status update.
        correlation_id: Option<u32>,
        /// The number of seconds the device should wait before sending another status update.
        poll: Option<u32>,
    },
    /// Tell the device that it is up to date and that it can send its status update at a later time.
    Sync {
        /// The version that was used for deciding the device was up to date. The device should check it matches its own version.
        #[serde(borrow)]
        version: Bytes<'a>,
        /// Correlation id matching the id sent in the status update.
        correlation_id: Option<u32>,
        /// The number of seconds the device should wait before sending another status update.
        poll: Option<u32>,
    },
    /// A block of firmware data that should be written to the device at a given offset.
    Write {
        /// The firmware version that this block corresponds to. The device should check that this matches the version it has been writing so far.
        #[serde(borrow)]
        version: Bytes<'a>,
        /// Correlation id matching the id sent in the status update.
        correlation_id: Option<u32>,
        /// The offset where this block should be written.
        offset: u32,
        /// The firmware data to write.
        #[serde(borrow)]
        data: Bytes<'a>,
    },
    /// Tell the device that it has now written all of the firmware and that it can commence the swap/update operation.
    Swap {
        /// The version that was used for deciding the device is ready to swap. The device should check it matches the version being written.
        #[serde(borrow)]
        version: Bytes<'a>,
        /// Correlation id matching the id sent in the status update.
        correlation_id: Option<u32>,
        /// The full checksum of the firmware being written. The device should compare this with the checksum of the firmware it has written before swapping.
        #[serde(borrow)]
        checksum: Bytes<'a>,
    },
}

impl<'a> Command<'a> {
    /// Create a new Wait command
    pub fn new_wait(poll: Option<u32>, correlation_id: Option<u32>) -> Self {
        Self::Wait { correlation_id, poll }
    }

    /// Create a new Sync command.
    pub fn new_sync(version: &'a [u8], poll: Option<u32>, correlation_id: Option<u32>) -> Self {
        Self::Sync {
            version: Bytes::new(version),
            correlation_id,
            poll,
        }
    }

    /// Create a new Swap command
    pub fn new_swap(version: &'a [u8], checksum: &'a [u8], correlation_id: Option<u32>) -> Self {
        Self::Swap {
            version: Bytes::new(version),
            correlation_id,
            checksum: Bytes::new(checksum),
        }
    }

    /// Create a new Write command.
    pub fn new_write(version: &'a [u8], offset: u32, data: &'a [u8], correlation_id: Option<u32>) -> Self {
        Self::Write {
            version: Bytes::new(version),
            correlation_id,
            offset,
            data: Bytes::new(data),
        }
    }
}

/// Represents a serde serializable byte slice.
///
/// It is serialized with `serialize_bytes` (a byte string) and deserialized by borrowing the
/// bytes directly from the input buffer. Because it only holds a shared reference, it is
/// `Clone` and `Copy`, so it can be duplicated freely without re-borrowing the source.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Bytes<'a> {
    data: &'a [u8],
}

impl<'a> Bytes<'a> {
    /// Wraps `data` without copying it.
    ///
    /// This is a `const fn` so a `Bytes` can also be built in constant contexts, for example
    /// a `const` item holding a fixed firmware version.
    const fn new(data: &'a [u8]) -> Self {
        Self { data }
    }
}

impl<'a> Serialize for Bytes<'a> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(self.data)
    }
}

impl<'a, 'de: 'a> Deserialize<'de> for Bytes<'a> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_bytes(BytesVisitor)
    }
}

impl<'a> AsRef<[u8]> for Bytes<'a> {
    /// Exposes the wrapped bytes as a plain slice.
    fn as_ref(&self) -> &[u8] {
        &self.data[..]
    }
}

impl<'a> Deref for Bytes<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.data
    }
}

impl<'a> Default for Bytes<'a> {
    fn default() -> Self {
        Bytes::new(&[])
    }
}

impl<'a, Rhs> PartialEq<Rhs> for Bytes<'a>
where
    Rhs: ?Sized + AsRef<[u8]>,
{
    fn eq(&self, other: &Rhs) -> bool {
        self.as_ref().eq(other.as_ref())
    }
}

impl<'a, Rhs> PartialOrd<Rhs> for Bytes<'a>
where
    Rhs: ?Sized + AsRef<[u8]>,
{
    /// Orders lexicographically, byte by byte, against anything viewable as a byte slice.
    ///
    /// Byte slices are totally ordered, so this always returns `Some`.
    fn partial_cmp(&self, other: &Rhs) -> Option<Ordering> {
        let rhs: &[u8] = other.as_ref();
        Some(Ord::cmp(self.data, rhs))
    }
}

impl<'a> Hash for Bytes<'a> {
    /// Hashes exactly like the underlying `[u8]`, matching `PartialEq`, which also compares
    /// the raw bytes.
    fn hash<H: Hasher>(&self, state: &mut H) {
        <[u8] as Hash>::hash(self.data, state);
    }
}

struct BytesVisitor;

impl<'de> Visitor<'de> for BytesVisitor {
    type Value = Bytes<'de>;

    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
        formatter.write_str("a byte slice")
    }

    fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(Bytes::new(v))
    }
}

#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use std::println;
    use std::vec::Vec;

    /// Round-trips a `Write` command through CBOR and checks every field survives, with the
    /// `Bytes` fields borrowing from the encoded buffer.
    #[test]
    fn deserialize_ref() {
        let s = Command::new_write(b"1234", 0, &[1, 2, 3, 4], None);
        let out = serde_cbor::to_vec(&s).unwrap();

        let s: Command = serde_cbor::from_slice(&out).unwrap();
        println!("Out: {:?}", s);
        match s {
            Command::Write {
                version,
                correlation_id,
                offset,
                data,
            } => {
                assert_eq!(version, b"1234");
                assert_eq!(correlation_id, None);
                assert_eq!(offset, 0);
                assert_eq!(data, [1u8, 2, 3, 4]);
            }
            other => panic!("expected Command::Write, got {:?}", other),
        }
    }

    #[test]
    fn serialized_status_size() {
        // 1 byte current version, 1 byte next version, MTU of 4, no correlation id
        let version = &[1];
        let mtu = Some(4);
        let cid = None;
        let offset = 0;
        let next_version = &[2];

        let s = Status::first(version, mtu, cid);
        let first = encode(&s);

        let s = Status::update(version, mtu, offset, next_version, cid);
        let update = encode(&s);
        println!("Serialized size:\n FIRST:\t{}\nUPDATE:\t{}", first.len(), update.len(),);
        // An update status carries an extra `UpdateStatus` instead of a null.
        assert!(update.len() > first.len());
    }

    #[test]
    fn serialized_command_size() {
        // 1 byte version, 4 byte payload, 4 byte checksum
        let version = &[1];
        let payload = &[1, 2, 3, 4];
        let checksum = &[1, 2, 3, 4];

        let s = Command::new_write(version, 0, payload, None);
        let write = encode(&s);

        let s = Command::new_wait(Some(1), None);
        let wait = encode(&s);

        let s = Command::new_sync(version, Some(1), None);
        let sync = encode(&s);

        let s = Command::new_swap(version, checksum, None);
        let swap = encode(&s);
        println!(
            "Serialized size:\n WRITE:\t{}\nWAIT:\t{}\nSYNC:\t{}\nSWAP:\t{}",
            write.len(),
            wait.len(),
            sync.len(),
            swap.len()
        );
    }

    /// Encodes `value` with the packed CBOR format used to measure on-the-wire sizes.
    fn encode<T>(value: &T) -> Vec<u8>
    where
        T: serde::Serialize,
    {
        serde_cbor::ser::to_vec_packed(value).unwrap()
    }
}