state_tree/
tree.rs

/// Number of extra words stored in front of a delay's data in the flat layout:
/// the read index followed by the write index.
pub const DELAY_ADDITIONAL_OFFSET: usize = 2;
2
/// State tree structure.
///
/// Every leaf owns its contents as a vector of `u64` words; `FnCall` nodes only
/// group child nodes. `serialize_tree_untagged` flattens a tree into one word
/// array and `deserialize_tree_untagged` rebuilds it from a matching
/// `StateTreeSkeleton`.
// On attributes, see https://github.com/rkyv/rkyv/blob/main/rkyv/examples/json_like_schema.rs
#[derive(Clone, PartialEq, Eq)]
pub enum StateTree {
    /// State of a delay: two indices plus the buffer contents.
    /// Serialized as `readidx`, `writeidx`, then `data`.
    Delay {
        /// Read index of the delay buffer.
        readidx: u64,
        /// Write index of the delay buffer.
        writeidx: u64,
        /// Buffer words. Author's note: assumes only mono f64 data
        /// (presumably stored as raw bit patterns in `u64` — TODO confirm).
        data: Vec<u64>,
    },
    /// State words of a `Mem` node.
    Mem {
        /// Stored words. Author's note: assumes only mono f64 data
        /// (presumably stored as raw bit patterns in `u64` — TODO confirm).
        data: Vec<u64>,
    },
    /// State words of a `Feed` node.
    Feed {
        /// Stored words. Author's note: assumes generic data, which might be a
        /// tuple of floats.
        data: Vec<u64>,
    },
    /// Inner node holding an ordered list of child states.
    FnCall(Vec<StateTree>),
}
21
22impl std::fmt::Display for StateTree {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            StateTree::Delay {
26                readidx,
27                writeidx,
28                data,
29            } => write!(
30                f,
31                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
32                readidx,
33                writeidx,
34                data.iter().take(10).collect::<Vec<&u64>>()
35            ),
36            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
37            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
38            StateTree::FnCall(children) => {
39                let children_str: Vec<String> = children.iter().map(|c| format!("{c}")).collect();
40                write!(f, "FnCall([{}])", children_str.join(", "))
41            }
42        }
43    }
44}
45impl std::fmt::Debug for StateTree {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            StateTree::Delay {
49                readidx,
50                writeidx,
51                data,
52            } => write!(
53                f,
54                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
55                readidx,
56                writeidx,
57                data.iter().take(10).collect::<Vec<&u64>>()
58            ),
59            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
60            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
61            StateTree::FnCall(children) => {
62                let children_str: Vec<String> = children.iter().map(|c| format!("{c:?}")).collect();
63                write!(f, "FnCall([{}])", children_str.join(", "))
64            }
65        }
66    }
67}
68impl<T: SizedType> From<StateTreeSkeleton<T>> for StateTree {
69    //create empty StateTree from StateTreeSkeleton
70    fn from(skeleton: StateTreeSkeleton<T>) -> Self {
71        match skeleton {
72            StateTreeSkeleton::Delay { len } => StateTree::Delay {
73                readidx: 0,
74                writeidx: 0,
75                data: vec![0; len as usize],
76            },
77            StateTreeSkeleton::Mem(t) => StateTree::Mem {
78                data: vec![0; t.word_size() as usize],
79            },
80            StateTreeSkeleton::Feed(t) => StateTree::Feed {
81                data: vec![0; t.word_size() as usize],
82            },
83            StateTreeSkeleton::FnCall(children_layout) => StateTree::FnCall(
84                children_layout
85                    .into_iter()
86                    .map(|child_layout| StateTree::from(*child_layout))
87                    .collect(),
88            ),
89        }
90    }
91}
92
93impl StateTree {
94    /// パスを指定して、イミュータブルなノードへの参照を取得する
95    pub fn get_node(&self, path: &[usize]) -> Option<&StateTree> {
96        let mut current = self;
97        for &index in path {
98            if let StateTree::FnCall(children) = current {
99                current = children.get(index)?;
100            } else {
101                // パスが深すぎるか、FnCallではないノードを指している
102                return None;
103            }
104        }
105        Some(current)
106    }
107
108    /// パスを指定して、ミュータブルなノードへの参照を取得する
109    pub fn get_node_mut(&mut self, path: &[usize]) -> Option<&mut StateTree> {
110        let mut current = self;
111        for &index in path {
112            if let StateTree::FnCall(children) = current {
113                current = children.get_mut(index)?;
114            } else {
115                // パスが深すぎるか、FnCallではないノードを指している
116                return None;
117            }
118        }
119        Some(current)
120    }
121
122    /// StateTree から StateTreeSkeleton への変換(データを除いた構造のみ)
123    pub fn to_skeleton(&self) -> StateTreeSkeleton<u64> {
124        match self {
125            StateTree::Delay { data, .. } => StateTreeSkeleton::Delay {
126                len: data.len() as u64,
127            },
128            StateTree::Mem { data } => StateTreeSkeleton::Mem(data.len() as u64),
129            StateTree::Feed { data } => StateTreeSkeleton::Feed(data.len() as u64),
130            StateTree::FnCall(children) => StateTreeSkeleton::FnCall(
131                children
132                    .iter()
133                    .map(|child| Box::new(child.to_skeleton()))
134                    .collect(),
135            ),
136        }
137    }
138}
139
140pub fn serialize_tree_untagged(tree: StateTree) -> Vec<u64> {
141    match tree {
142        StateTree::Delay {
143            readidx,
144            writeidx,
145            data,
146        } => itertools::concat([vec![readidx, writeidx], data]),
147        StateTree::Mem { data } | StateTree::Feed { data } => data,
148        StateTree::FnCall(state_trees) => {
149            itertools::concat(state_trees.into_iter().map(serialize_tree_untagged))
150        }
151    }
152}
153
/// A type descriptor that can report how many `u64` words its values occupy.
///
/// Used as the payload of the `Mem` and `Feed` variants of `StateTreeSkeleton`.
pub trait SizedType: std::fmt::Debug {
    /// Returns the size of the described value, counted in words.
    fn word_size(&self) -> u64;
}
157
/// A plain `u64` is interpreted directly as a word count.
impl SizedType for u64 {
    /// Returns the value itself.
    fn word_size(&self) -> u64 {
        *self
    }
}
163
/// A plain `usize` is interpreted directly as a word count.
impl SizedType for usize {
    /// Returns the value converted to `u64` with an `as` cast.
    fn word_size(&self) -> u64 {
        *self as u64
    }
}
169
/// Describes only the memory layout of a state tree on a flat array; it does not
/// own any of the actual data.
#[derive(Debug, Clone)]
pub enum StateTreeSkeleton<T: SizedType> {
    /// Layout of a delay. `total_size` counts `DELAY_ADDITIONAL_OFFSET` words on
    /// top of `len`.
    Delay {
        /// Number of data words. Author's note: assumes only mono f64 data.
        len: u64,
    },
    /// Layout of a `Mem` node; occupies `word_size()` words.
    Mem(T),
    /// Layout of a `Feed` node; occupies `word_size()` words.
    Feed(T),
    /// Inner node whose children are laid out one after another, in order.
    FnCall(Vec<Box<StateTreeSkeleton<T>>>),
}
180impl<T: SizedType> StateTreeSkeleton<T> {
181    pub fn total_size(&self) -> u64 {
182        match self {
183            StateTreeSkeleton::Delay { len } => DELAY_ADDITIONAL_OFFSET as u64 + *len,
184            StateTreeSkeleton::Mem(t) | StateTreeSkeleton::Feed(t) => t.word_size(),
185            StateTreeSkeleton::FnCall(children_layout) => children_layout
186                .iter()
187                .map(|child_layout| child_layout.total_size())
188                .sum(),
189        }
190    }
191
192    /// Convert a path (position in the tree) to an address (offset) in a flat array.
193    /// 
194    /// # Arguments
195    /// * `path` - Path in the tree. Empty means root, [0] is the first child, [0, 1] is the second child of the first child.
196    /// 
197    /// # Returns
198    /// Returns the start address of the node pointed to by the path and the size of that node.
199    /// Returns None if the path is invalid.
200    pub fn path_to_address(&self, path: &[usize]) -> Option<(usize, usize)> {
201        if path.is_empty() {
202            // Root node case
203            return Some((0, self.total_size() as usize));
204        }
205
206        match self {
207            StateTreeSkeleton::FnCall(children) => {
208                let child_idx = path[0];
209                if child_idx >= children.len() {
210                    return None;
211                }
212
213                // Calculate offset to the child node
214                let offset: u64 = children
215                    .iter()
216                    .take(child_idx)
217                    .map(|child| child.total_size())
218                    .sum();
219
220                // Recursively resolve the path within the child node
221                let (child_offset, size) = children[child_idx].path_to_address(&path[1..])?;
222                Some((offset as usize + child_offset, size))
223            }
224            // Error if path remains on a leaf node
225            _ => None,
226        }
227    }
228}
229impl<T: SizedType> PartialEq for StateTreeSkeleton<T> {
230    fn eq(&self, other: &Self) -> bool {
231        match (self, other) {
232            (Self::Delay { len: l_len }, Self::Delay { len: r_len }) => l_len == r_len,
233            (Self::Mem(l0), Self::Mem(r0)) => l0.word_size() == r0.word_size(),
234            (Self::Feed(l0), Self::Feed(r0)) => l0.word_size() == r0.word_size(),
235            (Self::FnCall(l0), Self::FnCall(r0)) => l0 == r0,
236            _ => false,
237        }
238    }
239}
240
241fn deserialize_tree_untagged_rec<T: SizedType>(
242    data: &[u64],
243    data_layout: &StateTreeSkeleton<T>,
244) -> Option<(StateTree, usize)> {
245    match data_layout {
246        StateTreeSkeleton::Delay { len } => {
247            let readidx = data.first().copied()?;
248            let writeidx = data.get(1).copied()?;
249            let d = data
250                .get(DELAY_ADDITIONAL_OFFSET..DELAY_ADDITIONAL_OFFSET + (*len as usize))?
251                .to_vec();
252            Some((
253                StateTree::Delay {
254                    readidx,
255                    writeidx,
256                    data: d,
257                },
258                DELAY_ADDITIONAL_OFFSET + (*len as usize),
259            ))
260        }
261        StateTreeSkeleton::Mem(t) => {
262            let size = t.word_size() as usize;
263            let data = data.get(0..size)?.to_vec();
264            Some((StateTree::Mem { data }, size))
265        }
266        StateTreeSkeleton::Feed(t) => {
267            let size = t.word_size() as usize;
268            let data = data.get(0..size)?.to_vec();
269            Some((StateTree::Feed { data }, size))
270        }
271        StateTreeSkeleton::FnCall(children_layout) => {
272            let (children, used) =
273                children_layout
274                    .iter()
275                    .try_fold((vec![], 0), |(v, last_used), child_layout| {
276                        let (child, used) =
277                            deserialize_tree_untagged_rec(&data[last_used..], child_layout)?;
278
279                        Some(([v, vec![child]].concat(), last_used + used))
280                    })?;
281
282            Some((StateTree::FnCall(children), used))
283        }
284    }
285}
286
287pub fn deserialize_tree_untagged<T: SizedType>(
288    data: &[u64],
289    data_layout: &StateTreeSkeleton<T>,
290) -> Option<StateTree> {
291    log::trace!("Deserializing  with layout: {data_layout:?}");
292    if let Some((tree, used)) = deserialize_tree_untagged_rec(data, data_layout) {
293        if used == data.len() { Some(tree) } else { None }
294    } else {
295        None
296    }
297}