zencan_node/
persist.rs

1use core::{
2    cell::RefCell,
3    convert::Infallible,
4    future::Future,
5    pin::{pin, Pin},
6    task::Context,
7};
8
9use crate::object_dict::{find_object, ODEntry};
10use futures::{pending, task::noop_waker_ref};
11
12use defmt_or_log::{debug, warn};
13
/// Tag byte identifying what kind of record a serialized node holds.
///
/// Only node types in this enum can be written to persistent storage.
#[derive(Debug, Copy, Clone, PartialEq)]
#[repr(u8)]
pub enum NodeType {
    /// A node containing a saved sub-object value
    ObjectValue = 1,
    /// An unrecognized node type
    Unknown,
}

impl NodeType {
    /// Decode a `NodeType` from its tag byte.
    ///
    /// The `ObjectValue` tag (1) is the only recognized value; every other byte maps to
    /// [`NodeType::Unknown`].
    pub fn from_byte(b: u8) -> Self {
        if b == NodeType::ObjectValue as u8 {
            NodeType::ObjectValue
        } else {
            NodeType::Unknown
        }
    }
}
33
34async fn write_bytes(bytes: &[u8], reg: &RefCell<u8>) {
35    for b in bytes {
36        *reg.borrow_mut() = *b;
37        pending!()
38    }
39}
40
41async fn serialize_object(obj: &ODEntry<'_>, sub: u8, reg: &RefCell<u8>) {
42    // Unwrap safety: This can only fail if the sub doesn't exist, and we already
43    // checked for that above
44    let data_size = obj.data.read_size(sub).unwrap() as u16;
45    // Serialized node size is the variable length object data, plus node type (u8), index (u16), and sub index (u8)
46    let node_size = data_size + 4;
47
48    write_bytes(&node_size.to_le_bytes(), reg).await;
49    write_bytes(&[NodeType::ObjectValue as u8], reg).await;
50    write_bytes(&obj.index.to_le_bytes(), reg).await;
51    write_bytes(&[sub], reg).await;
52
53    const CHUNK_SIZE: usize = 32;
54    let mut buf = [0u8; CHUNK_SIZE];
55    let mut read_pos = 0;
56    loop {
57        // Note: returned read_len is not checked, on purpose. We already committed above to writing
58        // a certain number of bytes, and we must write them. This is fine for fields which are
59        // shorter than CHUNK_SIZE. This is a problem for fields which are larger than CHUNK_SIZE,
60        // and can lead to "torn reads", where different chunks come from different values because
61        // the value changed between the two chunks. This is only a problem for large fields which
62        // can be modified on a different thread than `Node::process()` is called. Fixing it
63        // requires an object locking mechanism, which may be worth considering in the future.
64        obj.data.read(sub, read_pos, &mut buf).unwrap();
65        let copy_len = data_size as usize - read_pos;
66        read_pos += copy_len;
67        write_bytes(&buf[0..copy_len], reg).await;
68        if read_pos >= data_size as usize {
69            break;
70        }
71    }
72}
73
74async fn serialize_sm(objects: &[ODEntry<'static>], reg: &RefCell<u8>) {
75    for obj in objects {
76        let max_sub = obj.data.max_sub_number();
77
78        for sub in 0..max_sub + 1 {
79            let info = obj.data.sub_info(sub);
80            // On a record, some subs may not be present. Just skip these.
81            if info.is_err() {
82                continue;
83            }
84            let info = info.unwrap();
85            if !info.persist {
86                continue;
87            }
88            serialize_object(obj, sub, reg).await;
89        }
90    }
91}
92
93pub fn serialized_size(objects: &[ODEntry]) -> usize {
94    const OVERHEAD_SIZE: usize = 6;
95    let mut size = 0;
96    for obj in objects {
97        let max_sub = obj.data.max_sub_number();
98        for sub in 0..max_sub + 1 {
99            let info = obj.data.sub_info(sub);
100            // On a record, some subs may not be present. Just skip these.
101            if info.is_err() {
102                continue;
103            }
104            let info = info.unwrap();
105            if !info.persist {
106                continue;
107            }
108            // Unwrap safety: This can only fail if the sub doesn't exist, and we already
109            // checked for that above
110            let data_size = obj.data.read_size(sub).unwrap();
111            // Serialized node size is the variable length object data, plus node type (u8),
112            // index (u16), and sub index (u8), plus a length header (u16)
113            size += data_size + OVERHEAD_SIZE;
114        }
115    }
116
117    size
118}
119
120struct PersistSerializer<'a, 'b, F: Future> {
121    f: Pin<&'a mut F>,
122    reg: &'b RefCell<u8>,
123}
124
125impl<'a, 'b, F: Future> PersistSerializer<'a, 'b, F> {
126    pub fn new(f: Pin<&'a mut F>, reg: &'b RefCell<u8>) -> Self {
127        Self { f, reg }
128    }
129}
130
131impl<F: Future> embedded_io::ErrorType for PersistSerializer<'_, '_, F> {
132    type Error = Infallible;
133}
134
135impl<F: Future> embedded_io::Read for PersistSerializer<'_, '_, F> {
136    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Infallible> {
137        let mut cx = Context::from_waker(noop_waker_ref());
138
139        let mut pos = 0;
140        loop {
141            if pos >= buf.len() {
142                return Ok(pos);
143            }
144
145            match self.f.as_mut().poll(&mut cx) {
146                core::task::Poll::Ready(_) => return Ok(pos),
147                core::task::Poll::Pending => {
148                    buf[pos] = *self.reg.borrow();
149                    pos += 1;
150                }
151            }
152        }
153    }
154}
155
156/// Serialize node data
157pub fn serialize<F: Fn(&mut dyn embedded_io::Read<Error = Infallible>, usize)>(
158    od: &'static [ODEntry],
159    callback: F,
160) {
161    let reg = RefCell::new(0);
162    let fut = pin!(serialize_sm(od, &reg));
163    let mut serializer = PersistSerializer::new(fut, &reg);
164    let size = serialized_size(od);
165    callback(&mut serializer, size)
166}
167
/// Error which can be returned while reading persisted data
///
/// Derives `Debug` so results carrying it can be `unwrap`ped and formatted, like other public
/// error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PersistReadError {
    /// Not enough bytes were present to construct the node
    NodeLengthShort,
}
173
/// Decoded contents of an `ObjectValue` node
#[derive(Debug, PartialEq)]
pub struct ObjectValue<'a> {
    /// Index of the object that owns this value
    pub index: u16,
    /// Sub-index, within that object, that the value belongs to
    pub sub: u8,
    /// Raw value bytes, to be written back into the sub-object on restore
    pub data: &'a [u8],
}
184
185/// A reference to a single node within a slice of serialized data
186///
187/// Returned by the PersistNodeReader iterator.
188#[derive(Debug, PartialEq)]
189pub enum PersistNodeRef<'a> {
190    /// A saved value for a sub-object
191    ObjectValue(ObjectValue<'a>),
192    /// An unrecognized node type was encountered. Either the serialized data is malformed, or
193    /// perhaps it was written with a future version of code that supports more node types
194    ///
195    /// The bytes of the node are stored in the contained slice, including the node type in the
196    /// first byte
197    Unknown(&'a [u8]),
198}
199
200impl<'a> PersistNodeRef<'a> {
201    /// Create a PersistNodeRef from a slice of bytes
202    pub fn from_slice(data: &'a [u8]) -> Result<Self, PersistReadError> {
203        if data.is_empty() {
204            return Err(PersistReadError::NodeLengthShort);
205        }
206
207        match NodeType::from_byte(data[0]) {
208            NodeType::ObjectValue => {
209                if data.len() < 5 {
210                    return Err(PersistReadError::NodeLengthShort);
211                }
212                Ok(Self::ObjectValue(ObjectValue {
213                    index: u16::from_le_bytes(data[1..3].try_into().unwrap()),
214                    sub: data[3],
215                    data: &data[4..],
216                }))
217            }
218            NodeType::Unknown => Ok(PersistNodeRef::Unknown(data)),
219        }
220    }
221}
222
/// Iterator over the nodes stored in a slice of serialized object data
///
/// Yields one `PersistNodeRef` per length-prefixed node in the slice, starting from the
/// beginning.
struct PersistNodeReader<'a> {
    /// Serialized bytes being walked
    buf: &'a [u8],
    /// Offset of the next node's length header within `buf`
    pos: usize,
}

impl<'a> PersistNodeReader<'a> {
    /// Start reading nodes from the beginning of `data`
    pub fn new(data: &'a [u8]) -> Self {
        PersistNodeReader { pos: 0, buf: data }
    }
}
238
239impl<'a> Iterator for PersistNodeReader<'a> {
240    type Item = PersistNodeRef<'a>;
241
242    fn next(&mut self) -> Option<Self::Item> {
243        if self.buf.len() - self.pos < 2 {
244            return None;
245        }
246        let length = u16::from_le_bytes(self.buf[self.pos..self.pos + 2].try_into().unwrap());
247        self.pos += 2;
248        let node_slice = &self.buf[self.pos..self.pos + length as usize];
249        self.pos += length as usize;
250
251        PersistNodeRef::from_slice(node_slice).ok()
252    }
253}
254
255/// Load values of objects previously persisted in serialized format
256///
257/// # Arguments
258/// - `od`: The object dictionary where objects will be updated
259/// - `stored_data`: A slice of bytes, as previously provided to the store_objects callback.
260pub fn restore_stored_objects(od: &[ODEntry], stored_data: &[u8]) {
261    let reader = PersistNodeReader::new(stored_data);
262    for item in reader {
263        match item {
264            PersistNodeRef::ObjectValue(restore) => {
265                if let Some(obj) = find_object(od, restore.index) {
266                    if let Ok(_sub_info) = obj.sub_info(restore.sub) {
267                        debug!(
268                            "Restoring 0x{:x}sub{} with {:?}",
269                            restore.index, restore.sub, restore.data
270                        );
271                        if let Err(abort_code) = obj.write(restore.sub, restore.data) {
272                            warn!(
273                                "Error restoring object 0x{:x}sub{}: {:x}",
274                                restore.index, restore.sub, abort_code as u32
275                            );
276                        }
277                    } else {
278                        warn!(
279                            "Saved object 0x{:x}sub{} not found in OD",
280                            restore.index, restore.sub
281                        );
282                    }
283                } else {
284                    warn!("Saved object 0x{:x} not found in OD", restore.index);
285                }
286            }
287            PersistNodeRef::Unknown(id) => warn!("Unknown persisted object read: {}", id[0]),
288        }
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::object_dict::{
        ConstField, NullTermByteField, ODEntry, ProvidesSubObjects, ScalarField, SubObjectAccess,
    };
    use zencan_common::objects::{DataType, ObjectCode, SubInfo};

    /// Round-trip a small object dictionary through `serialize` and `PersistNodeReader`.
    ///
    /// Checks that only sub-objects flagged `persist` are emitted, that the byte count matches
    /// `serialized_size`, and that each node decodes back to the value that was stored.
    #[test]
    fn test_serialize_deserialize() {
        // A record with a persisted u32 at sub 1 and a non-persisted u16 at sub 2
        #[derive(Default)]
        struct Object100 {
            value1: ScalarField<u32>,
            value2: ScalarField<u16>,
        }

        impl ProvidesSubObjects for Object100 {
            fn get_sub_object(&self, sub: u8) -> Option<(SubInfo, &dyn SubObjectAccess)> {
                match sub {
                    0 => Some((
                        SubInfo::MAX_SUB_NUMBER,
                        const { &ConstField::new(2u8.to_le_bytes()) },
                    )),
                    1 => Some((
                        SubInfo {
                            size: 4,
                            data_type: DataType::UInt32,
                            persist: true,
                            ..Default::default()
                        },
                        &self.value1,
                    )),
                    // Describe value2 as what it is (a u16). It is not persisted, so it must
                    // not appear in the serialized output.
                    2 => Some((
                        SubInfo {
                            size: 2,
                            data_type: DataType::UInt16,
                            persist: false,
                            ..Default::default()
                        },
                        &self.value2,
                    )),
                    _ => None,
                }
            }

            fn object_code(&self) -> ObjectCode {
                ObjectCode::Record
            }
        }

        // A persisted string variable with room for 15 bytes
        #[derive(Default)]
        struct Object200 {
            string: NullTermByteField<15>,
        }

        impl ProvidesSubObjects for Object200 {
            fn get_sub_object(&self, sub: u8) -> Option<(SubInfo, &dyn SubObjectAccess)> {
                match sub {
                    0 => Some((
                        SubInfo::new_visibile_str(self.string.len()).persist(true),
                        &self.string,
                    )),
                    _ => None,
                }
            }

            fn object_code(&self) -> ObjectCode {
                ObjectCode::Var
            }
        }

        let inst100 = Box::leak(Box::new(Object100::default()));
        let inst200 = Box::leak(Box::new(Object200::default()));

        let od = Box::leak(Box::new([
            ODEntry {
                index: 0x100,
                data: inst100,
            },
            ODEntry {
                index: 0x200,
                data: inst200,
            },
        ]));
        inst100.value1.store(42);
        inst200.string.set_str("test".as_bytes()).unwrap();

        let data = RefCell::new(Vec::new());
        serialize(od, |reader, _size| {
            // Small reads so that single reads straddle node and field boundaries
            const CHUNK_SIZE: usize = 2;
            let mut buf = [0; CHUNK_SIZE];
            loop {
                let n = reader.read(&mut buf).unwrap();
                data.borrow_mut().extend_from_slice(&buf[..n]);
                if n < buf.len() {
                    break;
                }
            }
        });

        let data = data.take();
        // Two nodes, each 6 bytes of overhead plus 4 bytes of data (the u32, and "test")
        assert_eq!(20, data.len());
        assert_eq!(data.len(), serialized_size(od));

        let mut deser = PersistNodeReader::new(&data);
        assert_eq!(
            deser.next().unwrap(),
            PersistNodeRef::ObjectValue(ObjectValue {
                index: 0x100,
                sub: 1,
                data: &42u32.to_le_bytes()
            })
        );
        assert_eq!(
            deser.next().unwrap(),
            PersistNodeRef::ObjectValue(ObjectValue {
                index: 0x200,
                sub: 0,
                data: "test".as_bytes()
            })
        );
        assert_eq!(deser.next(), None);
    }
}