1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4use tracing::trace;
5
6use crate::{
7    changes::Changes, nearest_node, wchildren, Child, HyperbeeError, KeyValue, KeyValueData, Node,
8    NodePath, SharedNode, Tree, MAX_KEYS,
9};
10
11#[tracing::instrument(skip(changes, path))]
14pub async fn propagate_changes_up_tree(
15    changes: &mut Changes,
16    mut path: NodePath,
17    new_child: Child,
18) {
19    let mut cur_child = new_child;
20    loop {
21        let (node, index) = match path.pop() {
24            None => break,
25            Some(x) => x,
26        };
27        wchildren!(node)[index] = cur_child;
28        cur_child = changes.add_node(node);
29    }
30}
31
32impl Node {
33    #[tracing::instrument(skip(self))]
36    async fn split(&mut self) -> (SharedNode, KeyValue, SharedNode) {
37        let key_median_index = self.keys.len() >> 1;
38        let children_median_index = self.children.len().await >> 1;
39        trace!(
40            "
41    splitting at key index: {key_median_index}
42    splitting at child index: {children_median_index}
43"
44        );
45        let left = Node::new(
46            self.keys.splice(0..key_median_index, vec![]).collect(),
47            self.children
48                .children
49                .write()
50                .await
51                .splice(0..children_median_index, vec![])
52                .collect(),
53            self.blocks.clone(),
54        );
55        let mid_key = self.keys.remove(0);
56        let right = Node::new(
57            self.keys.drain(..).collect(),
58            self.children.children.write().await.drain(..).collect(),
59            self.blocks.clone(),
60        );
61        (
62            Arc::new(RwLock::new(left)),
63            mid_key,
64            Arc::new(RwLock::new(right)),
65        )
66    }
67}
68
69fn cas_always_true(_prev: Option<&KeyValueData>, _next: &KeyValueData) -> bool {
70    true
71}
72impl Tree {
73    #[tracing::instrument(level = "trace", skip(self, cas), ret)]
74    pub async fn put_compare_and_swap(
75        &self,
76        key: &[u8],
77        value: Option<&[u8]>,
78        cas: impl FnOnce(Option<&KeyValueData>, &KeyValueData) -> bool,
79    ) -> Result<(Option<u64>, Option<u64>), HyperbeeError> {
80        let maybe_root = self.get_root(true).await?;
82
83        let seq = self.version().await;
84        let mut changes: Changes = Changes::new(seq, key, value);
85        let mut cur_key = KeyValue::new(seq);
86        let mut children: Vec<Child> = vec![];
87
88        let new_key_data = KeyValueData {
89            seq,
90            key: key.to_vec(),
91            value: value.map(|v| v.to_vec()),
92        };
93
94        let matched = 'new_root: {
95            let root = match maybe_root {
97                None => {
98                    if !cas(None, &new_key_data) {
99                        return Ok((None, None));
100                    }
101                    break 'new_root None;
102                }
103                Some(node) => node,
104            };
105
106            let (matched, mut path) = nearest_node(root, key).await?;
107
108            let old_key_data = if matched.is_some() {
109                let (node, index) = &path[path.len() - 1];
110                Some(node.read().await.get_key_value(*index).await?)
111            } else {
112                None
113            };
114
115            if !cas(old_key_data.as_ref(), &new_key_data) {
116                return Ok((matched, None));
117            }
118
119            loop {
120                let (cur_node, cur_index) = match path.pop() {
121                    None => break 'new_root matched,
122                    Some(cur) => cur,
123                };
124
125                let room_for_more_keys = cur_node.read().await.keys.len() < MAX_KEYS;
128                if matched.is_some() || room_for_more_keys {
129                    trace!("room for more keys or key matched");
130                    let stop = match matched.is_some() {
131                        true => cur_index + 1,
132                        false => cur_index,
133                    };
134
135                    cur_node
136                        .write()
137                        .await
138                        .insert(cur_key, children, cur_index..stop)
139                        .await;
140
141                    let child = changes.add_node(cur_node.clone());
142                    if !path.is_empty() {
143                        trace!("inserted into some child");
144                        propagate_changes_up_tree(&mut changes, path, child).await;
145                        let _ = self.blocks.read().await.add_changes(changes).await?;
146                        return Ok((matched, Some(seq)));
147                    };
148
149                    let _ = self.blocks.read().await.add_changes(changes).await?;
150                    return Ok((matched, Some(seq)));
151                }
152
153                cur_node
155                    .write()
156                    .await
157                    .insert(cur_key, children, cur_index..cur_index)
158                    .await;
159
160                let (left, mid_key, right) = cur_node.write().await.split().await;
161
162                children = vec![
163                    changes.add_node(left.clone()),
164                    changes.add_node(right.clone()),
165                ];
166
167                cur_key = mid_key;
168            }
169        };
170
171        trace!(
172            "creating a new root with key = [{:#?}] and # children = [{}]",
173            &cur_key,
174            children.len(),
175        );
176        let new_root = Arc::new(RwLock::new(Node::new(
177            vec![cur_key.clone()],
178            children,
179            self.blocks.clone(),
180        )));
181
182        changes.add_node(new_root);
185        let _ = self.blocks.read().await.add_changes(changes).await?;
186
187        Ok((matched, Some(seq)))
188    }
189
190    #[tracing::instrument(level = "trace", skip(self), ret)]
198    pub async fn put(
199        &self,
200        key: &[u8],
201        value: Option<&[u8]>,
202    ) -> Result<(Option<u64>, u64), HyperbeeError> {
203        let (old, new) = self
204            .put_compare_and_swap(key, value, cas_always_true)
205            .await?;
206        return Ok((
207            old,
208            new.expect("with cas_always_true this should never be none"),
209        ));
210    }
211}
212
#[cfg(test)]
mod test {
    use crate::{
        test::{check_tree, i32_key_vec, Rand},
        Hyperbee, Tree,
    };

    #[tokio::test]
    async fn test_cas() -> Result<(), Box<dyn std::error::Error>> {
        let hb = Hyperbee::from_ram().await?;
        let k = b"foo";
        // Rejected write on an absent key: nothing matched, nothing written.
        let res = hb.put_compare_and_swap(k, None, |_old, _new| false).await?;
        assert_eq!(res, (None, None));

        // Accepted write on an absent key.
        let res = hb.put_compare_and_swap(k, None, |_old, _new| true).await?;
        assert_eq!(res, (None, Some(1)));

        // Rejected write on an existing key still reports the match.
        let res = hb.put_compare_and_swap(k, None, |_old, _new| false).await?;
        assert_eq!(res, (Some(1), None));

        // Accepted write on an existing key replaces it.
        let res = hb.put_compare_and_swap(k, None, |_old, _new| true).await?;
        assert_eq!(res, (Some(1), Some(2)));
        Ok(())
    }

    #[tokio::test]
    async fn test_old_seq() -> Result<(), Box<dyn std::error::Error>> {
        let hb = Tree::from_ram().await?;
        let (None, first_seq) = hb.put(b"a", None).await? else {
            panic!("should be None")
        };
        assert_eq!(first_seq, hb.version().await - 1);

        let (Some(old_seq), _second_seq) = hb.put(b"a", None).await? else {
            panic!("should be Some")
        };
        assert_eq!(first_seq, old_seq);
        Ok(())
    }

    #[tokio::test]
    async fn basic_put() -> Result<(), Box<dyn std::error::Error>> {
        let hb = Tree::from_ram().await?;
        for i in 0..4 {
            let key = vec![i];
            let val = vec![i];
            hb.put(&key, Some(&val)).await?;
            for j in 0..=i {
                let key = vec![j];
                let res = hb.get(&key).await?.unwrap();
                assert_eq!(res.1, Some(key));
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn basic_put_with_replace() -> Result<(), Box<dyn std::error::Error>> {
        let hb = Tree::from_ram().await?;
        for i in 0..4 {
            let key = vec![i];
            let val = vec![i];
            hb.put(&key, Some(&val)).await?;
            // Overwrite the same key with a different value.
            let val = vec![i + 1_u8];
            hb.put(&key, Some(&val)).await?;
            for j in 0..=i {
                let key = vec![j];
                let res = hb.get(&key).await?.unwrap();
                assert_eq!(res.1, Some(vec![j + 1]));
            }
        }
        Ok(())
    }

    #[cfg(feature = "debug")]
    #[tokio::test]
    async fn print_put() -> Result<(), Box<dyn std::error::Error>> {
        let hb = Tree::from_ram().await?;
        for i in 0..3 {
            let key = i.to_string().into_bytes();
            hb.put(&key, Some(key.as_slice())).await?;
        }
        let tree = hb.print().await?;
        assert_eq!(
            tree,
            "0
1
2
"
        );
        Ok(())
    }

    #[tokio::test]
    async fn multi_put() -> Result<(), Box<dyn std::error::Error>> {
        let mut hb = Tree::from_ram().await?;
        for i in 0..100 {
            let key = i.to_string().into_bytes();
            hb.put(&key, Some(key.as_slice())).await?;
            hb = check_tree(hb).await?;

            // Every key written so far must still be readable.
            for j in 0..=i {
                let key = j.to_string().into_bytes();
                let res = hb.get(&key).await?.unwrap();
                assert_eq!(res.1, Some(key));
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn shuffled_put() -> Result<(), Box<dyn std::error::Error>> {
        let rand = Rand::default();
        let mut hb = Tree::from_ram().await?;

        let keys: Vec<Vec<u8>> = (0..100).map(i32_key_vec).collect();
        let keys = rand.shuffle(keys);
        let mut used: Vec<Vec<u8>> = vec![];

        for k in keys {
            hb.put(&k, Some(k.as_slice())).await?;
            used.push(k);

            // Every key written so far must still be readable.
            for kj in &used {
                let res = hb.get(kj).await?.unwrap();
                assert_eq!(res.1.as_deref(), Some(kj.as_slice()));
            }

            hb = check_tree(hb).await?;
        }
        Ok(())
    }
}