zencan_node/
persist.rs

1use core::{
2    cell::RefCell,
3    convert::Infallible,
4    future::Future,
5    pin::{pin, Pin},
6    task::Context,
7};
8
9use crate::object_dict::{find_object, ODEntry};
10use futures::{pending, task::noop_waker_ref};
11
12use defmt_or_log::{debug, warn};
13
/// Identifies the kind of record stored in serialized persistence data.
///
/// The discriminant is the type byte that begins each node (immediately after the node's length
/// prefix).
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum NodeType {
    /// Node holding the saved value of one sub-object
    ObjectValue = 1,
    /// A type byte this code does not recognize
    Unknown = 2,
}
23
24impl NodeType {
25    /// Create a `NodeType` from an ID byte
26    pub fn from_byte(b: u8) -> Self {
27        match b {
28            1 => Self::ObjectValue,
29            _ => Self::Unknown,
30        }
31    }
32}
33
34async fn write_bytes(bytes: &[u8], reg: &RefCell<u8>) {
35    for b in bytes {
36        *reg.borrow_mut() = *b;
37        pending!()
38    }
39}
40
41async fn serialize_object(obj: &ODEntry<'_>, sub: u8, reg: &RefCell<u8>) {
42    // Unwrap safety: This can only fail if the sub doesn't exist, and we already
43    // checked for that above
44    let data_size = obj.data.read_size(sub).unwrap() as u16;
45    // Serialized node size is the variable length object data, plus node type (u8), index (u16), and sub index (u8)
46    let node_size = data_size + 4;
47
48    write_bytes(&node_size.to_le_bytes(), reg).await;
49    write_bytes(&[NodeType::ObjectValue as u8], reg).await;
50    write_bytes(&obj.index.to_le_bytes(), reg).await;
51    write_bytes(&[sub], reg).await;
52
53    const CHUNK_SIZE: usize = 32;
54    let mut buf = [0u8; CHUNK_SIZE];
55    let mut read_pos = 0;
56    loop {
57        // Note: returned read_len is not checked, on purpose. We already committed above to writing
58        // a certain number of bytes, and we must write them. This is fine for fields which are
59        // shorter than CHUNK_SIZE. This is a problem for fields which are larger than CHUNK_SIZE,
60        // and can lead to "torn reads", where different chunks come from different values because
61        // the value changed between the two chunks. This is only a problem for large fields which
62        // can be modified on a different thread than `Node::process()` is called. Fixing it
63        // requires an object locking mechanism, which may be worth considering in the future.
64        obj.data.read(sub, read_pos, &mut buf).unwrap();
65        let copy_len = data_size as usize - read_pos;
66        read_pos += copy_len;
67        write_bytes(&buf[0..copy_len], reg).await;
68        if read_pos >= data_size as usize {
69            break;
70        }
71    }
72}
73
74async fn serialize_sm(objects: &[ODEntry<'_>], reg: &RefCell<u8>) {
75    for obj in objects {
76        let max_sub = obj.data.max_sub_number();
77
78        for sub in 0..max_sub + 1 {
79            let info = obj.data.sub_info(sub);
80            // On a record, some subs may not be present. Just skip these.
81            if info.is_err() {
82                continue;
83            }
84            let info = info.unwrap();
85            if !info.persist {
86                continue;
87            }
88            serialize_object(obj, sub, reg).await;
89        }
90    }
91}
92
93pub fn serialized_size(objects: &[ODEntry]) -> usize {
94    const OVERHEAD_SIZE: usize = 6;
95    let mut size = 0;
96    for obj in objects {
97        let max_sub = obj.data.max_sub_number();
98        for sub in 0..max_sub + 1 {
99            let info = obj.data.sub_info(sub);
100            // On a record, some subs may not be present. Just skip these.
101            if info.is_err() {
102                continue;
103            }
104            let info = info.unwrap();
105            if !info.persist {
106                continue;
107            }
108            // Unwrap safety: This can only fail if the sub doesn't exist, and we already
109            // checked for that above
110            let data_size = obj.data.read_size(sub).unwrap();
111            // Serialized node size is the variable length object data, plus node type (u8),
112            // index (u16), and sub index (u8), plus a length header (u16)
113            size += data_size + OVERHEAD_SIZE;
114        }
115    }
116
117    size
118}
119
120struct PersistSerializer<'a, 'b, F: Future> {
121    f: Pin<&'a mut F>,
122    reg: &'b RefCell<u8>,
123}
124
125impl<'a, 'b, F: Future> PersistSerializer<'a, 'b, F> {
126    pub fn new(f: Pin<&'a mut F>, reg: &'b RefCell<u8>) -> Self {
127        Self { f, reg }
128    }
129}
130
131impl<F: Future> embedded_io::ErrorType for PersistSerializer<'_, '_, F> {
132    type Error = Infallible;
133}
134
135impl<F: Future> embedded_io::Read for PersistSerializer<'_, '_, F> {
136    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Infallible> {
137        let mut cx = Context::from_waker(noop_waker_ref());
138
139        let mut pos = 0;
140        loop {
141            if pos >= buf.len() {
142                return Ok(pos);
143            }
144
145            match self.f.as_mut().poll(&mut cx) {
146                core::task::Poll::Ready(_) => return Ok(pos),
147                core::task::Poll::Pending => {
148                    buf[pos] = *self.reg.borrow();
149                    pos += 1;
150                }
151            }
152        }
153    }
154}
155
156/// Serialize node data
157pub fn serialize(
158    od: &[ODEntry],
159    callback: &dyn Fn(&mut dyn embedded_io::Read<Error = Infallible>, usize),
160) {
161    let reg = RefCell::new(0);
162    let fut = pin!(serialize_sm(od, &reg));
163    let mut serializer = PersistSerializer::new(fut, &reg);
164    let size = serialized_size(od);
165    callback(&mut serializer, size)
166}
167
/// Error which can be returned while reading persisted data
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PersistReadError {
    /// Not enough bytes were present to construct the node
    NodeLengthShort,
}
173
/// Decoded contents of an `ObjectValue` node: the sub-object it targets and its saved bytes.
#[derive(PartialEq, Debug)]
pub struct ObjectValue<'a> {
    /// Index of the object the saved value belongs to
    pub index: u16,
    /// Sub-index within that object
    pub sub: u8,
    /// Saved bytes, borrowed from the serialized buffer, to be written back to the sub-object
    pub data: &'a [u8],
}
184
185/// A reference to a single node within a slice of serialized data
186///
187/// Returned by the PersistNodeReader iterator.
188#[derive(Debug, PartialEq)]
189pub enum PersistNodeRef<'a> {
190    /// A saved value for a sub-object
191    ObjectValue(ObjectValue<'a>),
192    /// An unrecognized node type was encountered. Either the serialized data is malformed, or
193    /// perhaps it was written with a future version of code that supports more node types
194    ///
195    /// The bytes of the node are stored in the contained slice, including the node type in the
196    /// first byte
197    Unknown(&'a [u8]),
198}
199
200impl<'a> PersistNodeRef<'a> {
201    /// Create a PersistNodeRef from a slice of bytes
202    pub fn from_slice(data: &'a [u8]) -> Result<Self, PersistReadError> {
203        if data.is_empty() {
204            return Err(PersistReadError::NodeLengthShort);
205        }
206
207        match NodeType::from_byte(data[0]) {
208            NodeType::ObjectValue => {
209                if data.len() < 5 {
210                    return Err(PersistReadError::NodeLengthShort);
211                }
212                Ok(Self::ObjectValue(ObjectValue {
213                    index: u16::from_le_bytes(data[1..3].try_into().unwrap()),
214                    sub: data[3],
215                    data: &data[4..],
216                }))
217            }
218            NodeType::Unknown => Ok(PersistNodeRef::Unknown(data)),
219        }
220    }
221}
222
/// Iterates over the nodes stored in a buffer of serialized persistence data.
///
/// Each item is a [`PersistNodeRef`] for one length-prefixed node in the buffer.
struct PersistNodeReader<'a> {
    buf: &'a [u8],
    pos: usize,
}

impl<'a> PersistNodeReader<'a> {
    /// Create a reader positioned at the first node of `data`
    pub fn new(data: &'a [u8]) -> Self {
        PersistNodeReader { pos: 0, buf: data }
    }
}
238
239impl<'a> Iterator for PersistNodeReader<'a> {
240    type Item = PersistNodeRef<'a>;
241
242    fn next(&mut self) -> Option<Self::Item> {
243        if self.buf.len() - self.pos < 2 {
244            return None;
245        }
246        let length = u16::from_le_bytes(self.buf[self.pos..self.pos + 2].try_into().unwrap());
247        self.pos += 2;
248        let node_slice = &self.buf[self.pos..self.pos + length as usize];
249        self.pos += length as usize;
250
251        PersistNodeRef::from_slice(node_slice).ok()
252    }
253}
254
255/// Load values of objects previously persisted in serialized format with limited range
256///
257/// All saved objects where `start_index <= saved object index <= end_index` will be restored to the
258/// object dictionary. Saved objects outside this range will be dropped.
259///
260/// # Arguments
261/// - `od`: The object dictionary where objects will be updated
262/// - `stored_data`: A slice of bytes, as previously provided to the store_objects callback.
263/// - 'start_index
264pub fn restore_stored_objects_ranged(
265    od: &[ODEntry],
266    stored_data: &[u8],
267    start_index: u16,
268    end_index: u16,
269) {
270    let reader = PersistNodeReader::new(stored_data);
271    for item in reader {
272        match item {
273            PersistNodeRef::ObjectValue(restore) => {
274                if restore.index < start_index || restore.index > end_index {
275                    continue;
276                }
277                if let Some(obj) = find_object(od, restore.index) {
278                    if let Ok(_sub_info) = obj.sub_info(restore.sub) {
279                        debug!(
280                            "Restoring 0x{:x}sub{} with {:?}",
281                            restore.index, restore.sub, restore.data
282                        );
283                        if let Err(abort_code) = obj.write(restore.sub, restore.data) {
284                            warn!(
285                                "Error restoring object 0x{:x}sub{}: {:x}",
286                                restore.index, restore.sub, abort_code as u32
287                            );
288                        }
289                    } else {
290                        warn!(
291                            "Saved object 0x{:x}sub{} not found in OD",
292                            restore.index, restore.sub
293                        );
294                    }
295                } else {
296                    warn!("Saved object 0x{:x} not found in OD", restore.index);
297                }
298            }
299            PersistNodeRef::Unknown(id) => warn!("Unknown persisted object read: {}", id[0]),
300        }
301    }
302}
303
304/// Restore all stored objects in stored data to the object dict
305pub fn restore_stored_objects(od: &[ODEntry], stored_data: &[u8]) {
306    restore_stored_objects_ranged(od, stored_data, 0, u16::MAX);
307}
308
309/// Restore only communications objects from the stored data to the object dict
310///
311/// Communications objects are objects 0x1000-0x1fff.
312pub fn restore_stored_comm_objects(od: &[ODEntry], stored_data: &[u8]) {
313    restore_stored_objects_ranged(od, stored_data, 0x1000, 0x1fff);
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::object_dict::{
        ConstField, NullTermByteField, ODEntry, ProvidesSubObjects, ScalarField, SubObjectAccess,
    };
    use zencan_common::objects::{DataType, ObjectCode, SubInfo};

    use crate::persist::serialize;

    #[test]
    fn test_serialize_deserialize() {
        #[derive(Default)]
        struct Object100 {
            value1: ScalarField<u32>,
            value2: ScalarField<u16>,
        }

        impl ProvidesSubObjects for Object100 {
            fn get_sub_object(&self, sub: u8) -> Option<(SubInfo, &dyn SubObjectAccess)> {
                match sub {
                    0 => Some((
                        SubInfo::MAX_SUB_NUMBER,
                        const { &ConstField::new(2u8.to_le_bytes()) },
                    )),
                    1 => Some((
                        SubInfo {
                            size: 4,
                            data_type: DataType::UInt32,
                            persist: true,
                            ..Default::default()
                        },
                        &self.value1,
                    )),
                    // value2 is a u16 field, so describe it as one. It is not persisted, so it
                    // must not appear in the serialized output.
                    2 => Some((
                        SubInfo {
                            size: 2,
                            data_type: DataType::UInt16,
                            persist: false,
                            ..Default::default()
                        },
                        &self.value2,
                    )),
                    _ => None,
                }
            }

            fn object_code(&self) -> ObjectCode {
                ObjectCode::Record
            }
        }

        #[derive(Default)]
        struct Object200 {
            string: NullTermByteField<15>,
        }

        impl ProvidesSubObjects for Object200 {
            fn get_sub_object(&self, sub: u8) -> Option<(SubInfo, &dyn SubObjectAccess)> {
                match sub {
                    0 => Some((
                        SubInfo::new_visibile_str(self.string.len()).persist(true),
                        &self.string,
                    )),
                    _ => None,
                }
            }

            fn object_code(&self) -> ObjectCode {
                ObjectCode::Var
            }
        }

        let inst100 = Box::leak(Box::new(Object100::default()));
        let inst200 = Box::leak(Box::new(Object200::default()));

        let od = Box::leak(Box::new([
            ODEntry {
                index: 0x100,
                data: inst100,
            },
            ODEntry {
                index: 0x200,
                data: inst200,
            },
        ]));
        inst100.value1.store(42);
        inst200.string.set_str("test".as_bytes()).unwrap();

        let data = RefCell::new(Vec::new());
        serialize(od, &|reader, _size| {
            // A small chunk size forces many read() calls, each resuming the serializer.
            const CHUNK_SIZE: usize = 2;
            let mut buf = [0; CHUNK_SIZE];
            loop {
                let n = reader.read(&mut buf).unwrap();
                data.borrow_mut().extend_from_slice(&buf[..n]);
                if n < buf.len() {
                    break;
                }
            }
        });

        let data = data.take();
        // Two nodes, each with 6 bytes of overhead: 0x100sub1 (4-byte u32) + 0x200sub0 ("test")
        assert_eq!(20, data.len());
        assert_eq!(data.len(), serialized_size(od));

        let mut deser = PersistNodeReader::new(&data);
        assert_eq!(
            deser.next().unwrap(),
            PersistNodeRef::ObjectValue(ObjectValue {
                index: 0x100,
                sub: 1,
                data: &42u32.to_le_bytes()
            })
        );
        assert_eq!(
            deser.next().unwrap(),
            PersistNodeRef::ObjectValue(ObjectValue {
                index: 0x200,
                sub: 0,
                data: "test".as_bytes()
            })
        );
        assert_eq!(deser.next(), None);
    }
}