1use std::{
2 ops::Bound,
3 sync::{
4 Arc,
5 atomic::{AtomicU64, Ordering},
6 },
7};
8
9use async_broadcast::{Receiver, broadcast};
10use derive_more::Debug;
11use xxhash_rust::xxh3::Xxh3Default;
12
13use super::{ByteValue, KeyValue};
14use crate::{
15 error::{DbError, DbResult},
16 replication::{ChangeEvent, ReplicationEvent, ReplicationEventHandler, Subscribers},
17 types::{ArmourError, attribute::EntityAttribute, num_ops::g4bits},
18 utils::{CheckSumVec, CollectionInfo, GroupVal, HashPoints},
19};
20
/// State shared (through an `Arc`) by all clones of a [`RawTree`].
#[derive(Debug)]
pub(crate) struct InnerFields {
    /// Collection metadata; not read in this part of the file.
    pub(crate) info: CollectionInfo,
    /// Per-group hash bookkeeping: group id (`u32`) -> [`GroupVal`] holding the last
    /// computed hash and a `changed` flag that marks it stale.
    pub(crate) hashpoints: HashPoints,
    /// Number of live entries, maintained by `RawTree` and returned by `RawTree::count`.
    pub(crate) seq: AtomicU64,
    /// Optional callback invoked with replication events when changes are applied;
    /// the `Debug` output only shows whether one is installed.
    #[debug("{}", self.replication_handler.is_some())]
    pub(crate) replication_handler: ReplicationEventHandler<ByteValue, ByteValue>,
}
31
32impl InnerFields {
33 pub(crate) fn invalidate_hash(&self, group_id: u32) {
34 self.hashpoints.insert(
35 group_id,
36 GroupVal {
37 hash: 0,
38 changed: true,
39 },
40 );
41 }
42}
43
/// A named collection of raw key/value bytes.
///
/// `Clone` is derived; `inner` and `subscribers` are `Arc`s, so clones share the entry
/// counter, the group-hash bookkeeping and the subscriber channel.
#[derive(Clone, Debug)]
pub struct RawTree {
    /// Collection name; used in tracing spans, metric labels and the `next_id-{name}` key.
    pub name: String,
    /// Static per-entity attributes (group bits, type hash, version) of this collection.
    pub attributes: &'static EntityAttribute,
    /// Tree holding the `next_id-{name}` id counter.
    #[debug(skip)]
    pub(crate) seq_tree: super::Tree,
    /// Main data tree holding the collection's entries.
    #[debug(skip)]
    pub(crate) tree: super::Tree,
    /// Not accessed in this part of the file; presumably records removed entries — TODO confirm.
    #[debug(skip)]
    pub(crate) removed: super::Tree,
    /// State shared by every clone of this handle.
    pub(crate) inner: Arc<InnerFields>,
    /// Broadcast channel for change events, created lazily on the first `subscribe`.
    pub(crate) subscribers: Arc<Subscribers<ByteValue, ByteValue>>,
}
58
59impl RawTree {
60 #[instrument(level = "info", skip(self), fields(name = self.name, ret))]
61 pub fn checksum(&self) -> u32 {
62 let start = std::time::Instant::now();
63 let checksum = self.tree.checksum().expect("cannot calculate checksum");
64 if checksum != 0 {
65 debug!("checksum: {checksum:#X}");
66 }
67
68 histogram!("armour_sled_rawtree_checksum_duration", "name" => self.name.clone())
69 .record(start.elapsed().as_secs_f64());
70 counter!("armour_sled_rawtree_checksum_total", "name" => self.name.clone()).increment(1);
71
72 checksum
73 }
74
75 pub fn is_empty(&self) -> bool {
76 self.tree.is_empty().expect("cannot check if tree is empty")
77 }
78
79 #[instrument(level = "trace", skip(self), fields(name = self.name), ret)]
80 pub fn count(&self) -> u64 {
81 self.inner.seq.load(Ordering::Relaxed)
82 }
83
84 #[instrument(level = "info", skip(self), fields(name = self.name))]
85 pub fn subscribe(&self) -> Receiver<ChangeEvent<ByteValue, ByteValue>> {
86 let (_, r) = self.subscribers.get_or_init(|| {
97 let (mut s, r) = broadcast(1024);
98 let r = r.deactivate();
99 s.set_await_active(false);
100 s.set_overflow(true);
101 (s, r)
102 });
103 r.activate_cloned()
104 }
105
106 #[instrument(level = "info", skip(self), fields(name = self.name, ret))]
107 pub fn hashpoints(&self) -> CheckSumVec {
108 let start = std::time::Instant::now();
109 let res = self
110 .inner
111 .hashpoints
112 .iter()
113 .map(|entry| {
114 let key = *entry.key();
115 (key, entry.value().hash)
116 })
117 .collect();
118
119 histogram!("armour_sled_rawtree_hashpoints_duration", "name" => self.name.clone())
120 .record(start.elapsed().as_secs_f64());
121 counter!("armour_sled_rawtree_hashpoints_total", "name" => self.name.clone()).increment(1);
122
123 res
124 }
125
126 #[instrument(level = "trace", skip(self), fields(name = self.name))]
127 pub fn scan_group(&self, group: u32) -> impl DoubleEndedIterator<Item = KeyValue> + use<> {
128 let start_bytes = group.to_be_bytes();
132
133 let group_bits = self.attributes.group_bits;
138 let bits_sub = u32::BITS - group_bits;
142 let bits_pow_of_two = 2u32.pow(bits_sub);
146 let end = group + bits_pow_of_two;
150 let end_bytes = end.to_be_bytes();
151
152 let start_bound = Bound::Included(start_bytes.as_ref());
153 let end_bound = Bound::Excluded(end_bytes.as_ref());
154
155 let res = self
156 .tree
157 .range::<&[u8], _>((start_bound, end_bound))
158 .filter_map(|item| match item {
159 Ok((key, value)) => {
160 Some((key, value))
162 }
163 Err(e) => {
164 error!(%e);
165 None
166 }
167 });
168
169 counter!("armour_sled_rawtree_scan_group_total", "name" => self.name.clone()).increment(1);
170 res
171 }
172
173 #[instrument(level = "info", skip(self), fields(name = self.name, ret))]
174 pub fn recalcucate_hash(&self) -> u64 {
175 let start = std::time::Instant::now();
176 let hash = self
177 .inner
178 .hashpoints
179 .iter()
180 .map(|item| {
181 let group_val = item.value();
182
183 if group_val.changed {
184 let group = *item.key();
185 drop(item);
186 let mut hash_val = Xxh3Default::new();
187
188 for (key, value) in self.scan_group(group) {
189 hash_val.update(&key);
190 hash_val.update(&value);
191 }
192 let hash = hash_val.digest();
193 self.inner.hashpoints.insert(
195 group,
196 GroupVal {
197 hash,
198 changed: false,
199 },
200 );
201 hash
202 } else {
203 item.value().hash
204 }
205 })
206 .fold(Xxh3Default::new(), |mut hasher, item| {
207 hasher.update(&item.to_le_bytes());
208 hasher
209 });
210
211 histogram!("armour_sled_rawtree_recalcucate_hash_duration", "name" => self.name.clone())
212 .record(start.elapsed().as_secs_f64());
213 counter!("armour_sled_rawtree_recalcucate_hash_total", "name" => self.name.clone())
214 .increment(1);
215
216 hash.digest()
217 }
218
219 #[instrument(skip(self), fields(name = self.name))]
220 pub fn tree_info(&self) -> CollectionInfo {
221 let seq = self.inner.seq.load(Ordering::SeqCst);
222
223 if seq != 0 {
224 debug!(seq, "close seq");
225 }
226
227 let typ_hash = self.attributes.ty.h();
228 let version = self.attributes.version;
229
230 CollectionInfo {
231 seq,
232 typ_hash,
233 version,
234 }
235 }
236
237 #[instrument(level = "trace", skip(self), fields(name = self.name, ret))]
239 pub fn get(&self, id: &[u8]) -> DbResult<Option<Vec<u8>>> {
240 let start = std::time::Instant::now();
241 let res = self
242 .tree
243 .get(id)
244 .map(|item| {
245 item.map(|item| {
246 let len = id.len() + item.len();
247 let mut v = vec![0; len];
248 v[..id.len()].copy_from_slice(id);
249 v[id.len()..].copy_from_slice(&item);
250 v
251 })
252 })
253 .map_err(DbError::from);
254
255 histogram!("armour_sled_rawtree_get_duration", "name" => self.name.clone())
256 .record(start.elapsed().as_secs_f64());
257 counter!("armour_sled_rawtree_get_total", "name" => self.name.clone()).increment(1);
258
259 res
260 }
261
262 #[instrument(level = "trace", skip(self), fields(name = self.name))]
263 pub fn range(
264 &self,
265 range: (Bound<&[u8]>, Bound<&[u8]>),
266 ) -> impl DoubleEndedIterator<Item = KeyValue> {
267 let iter = self.tree.range::<&[u8], _>(range);
268 let iter = iter.filter_map(|item| match item {
269 Ok((key, value)) => Some((key, value)),
270 Err(e) => {
271 error!(%e);
272 None
273 }
274 });
275
276 counter!("armour_sled_rawtree_range_total", "name" => self.name.clone()).increment(1);
277
278 iter
279 }
280
281 #[instrument(level = "trace", skip(self), fields(name = self.name))]
282 fn invalidate_hash(&self, id: &[u8]) {
283 let start = std::time::Instant::now();
284 let mut bytes = [0; 4];
285 bytes.copy_from_slice(&id[..4]);
286 let group = u32::from_be_bytes(bytes);
287 let group = g4bits(group, self.attributes.group_bits);
288 self.inner.invalidate_hash(group);
289
290 histogram!("armour_sled_rawtree_invalidate_hash_duration", "name" => self.name.clone())
291 .record(start.elapsed().as_secs_f64());
292 counter!("armour_sled_rawtree_invalidate_hash_total", "name" => self.name.clone())
293 .increment(1);
294 }
295
296 #[instrument(level = "debug", skip(self), fields(name = self.name), err, ret)]
297 pub fn next_id(&self) -> DbResult<u64> {
298 let start = std::time::Instant::now();
299 let next_id_key = format!("next_id-{}", self.name);
300
301 let key_ref = next_id_key.as_bytes();
302 let mut current = self.seq_tree.get(key_ref)?;
303
304 let res = (|| loop {
305 let tmp = current.as_ref().map(AsRef::as_ref);
306
307 let mut id = 1u64;
308
309 let next = match tmp {
310 Some(bytes) => {
311 let bytes = bytes
312 .try_into()
313 .map_err(|err| DbError::Armour(ArmourError::from(err)))?;
314 let old = u64::from_le_bytes(bytes);
315 id = old + 1;
316 id.to_le_bytes().to_vec()
317 }
318 None => id.to_le_bytes().to_vec(),
319 };
320
321 match self.seq_tree.compare_and_swap(key_ref, tmp, Some(next))? {
322 Ok(_) => return Ok(id),
323 Err(sled::CompareAndSwapError { current: cur, .. }) => {
324 current = cur;
325 }
326 }
327 })();
328
329 histogram!("armour_sled_rawtree_next_id_duration", "name" => self.name.clone())
330 .record(start.elapsed().as_secs_f64());
331 counter!("armour_sled_rawtree_next_id_total", "name" => self.name.clone()).increment(1);
332
333 res
334 }
335
336 #[instrument(level = "debug", skip(self, event), fields(name = self.name, event = event.variant()), err)]
337 pub fn apply_event(&self, event: ChangeEvent<ByteValue, ByteValue>) -> DbResult<()> {
338 let start = std::time::Instant::now();
339 match &event {
340 ChangeEvent::Upsert((key, val)) => {
341 let old = self.tree.insert(key, val.clone())?;
342 if old.is_none() {
343 self.inner.seq.fetch_add(1, Ordering::Relaxed);
344 }
345 if let Some(f) = self.inner.replication_handler.as_ref() {
346 f(ReplicationEvent::Upsert {
347 key,
348 val,
349 old_val: old.as_ref(),
350 });
351 }
352 }
353 ChangeEvent::Delete(key) => {
354 match self.tree.remove(key)? {
355 Some(val) => {
356 self.inner.seq.fetch_sub(1, Ordering::AcqRel);
357 if let Some(f) = self.inner.replication_handler.as_ref() {
358 f(ReplicationEvent::Delete { key, val: &val });
359 }
360 }
361 _ => {
362 error!(?key, "delete not found");
363 }
364 };
365 }
366 ChangeEvent::ChangeId(old, new) => {
367 let val = self.tree.remove(old)?.ok_or(DbError::NotFound)?;
368 if self.tree.insert(new, val.clone())?.is_some() {
369 error!(?new, "change_id already exists");
370 };
371 if let Some(f) = self.inner.replication_handler.as_ref() {
372 f(ReplicationEvent::IdChange {
373 old_key: old,
374 new_key: new,
375 val: &val,
376 });
377 }
378 self.invalidate_hash(new);
379 }
380 }
381 self.invalidate_hash(event.key());
382 if let Some((sender, _)) = self.subscribers.get() {
383 match sender.try_broadcast(event) {
384 Ok(None) => {}
385 Ok(Some(item)) => {
386 error!(?item, "broadcast is full");
387 }
388 Err(e) => {
389 error!(%e);
390 }
391 }
392 }
393
394 histogram!("armour_sled_rawtree_apply_event_duration", "name" => self.name.clone())
395 .record(start.elapsed().as_secs_f64());
396 counter!("armour_sled_rawtree_apply_event_total", "name" => self.name.clone()).increment(1);
397
398 Ok(())
399 }
400
401 #[instrument(level = "info", skip(self, iter), fields(name = self.name), err)]
403 pub fn apply_batch(
404 &self,
405 iter: impl Iterator<Item = (ByteValue, Option<ByteValue>)>,
406 ) -> DbResult<()> {
407 let start = std::time::Instant::now();
408 let mut batch = sled::Batch::default();
409
410 for (key, val) in iter {
412 self.invalidate_hash(&key);
413
414 match &val {
415 Some(val) => {
416 batch.insert(key.clone(), val.clone());
417
418 if let Some(f) = self.inner.replication_handler.as_ref() {
419 let old = self.tree.get(&key)?;
420 f(ReplicationEvent::Upsert {
421 key: &key,
422 val,
423 old_val: old.as_ref(),
424 });
425 }
426 }
427 None => {
428 batch.remove(key.clone());
429 if let Some(f) = self.inner.replication_handler.as_ref() {
430 match self.tree.get(&key)? {
431 Some(item) => {
432 f(ReplicationEvent::Delete {
433 key: &key,
434 val: &item,
435 });
436 }
437 _ => {
438 error!(?key, "delete not found");
439 }
440 }
441 }
442 }
443 }
444 if let Some((sender, _)) = self.subscribers.get() {
445 let msg = ChangeEvent::from_kv(key, val);
446 match sender.try_broadcast(msg) {
447 Ok(None) => {}
448 Ok(Some(item)) => {
449 error!(?item, "broadcast is full");
450 }
451 Err(e) => {
452 error!(%e);
453 }
454 }
455 }
456 }
457
458 self.tree.apply_batch(batch)?;
459
460 histogram!("armour_sled_rawtree_apply_batch_duration", "name" => self.name.clone())
461 .record(start.elapsed().as_secs_f64());
462 counter!("armour_sled_rawtree_apply_batch_total", "name" => self.name.clone()).increment(1);
463
464 Ok(())
465 }
466
467 #[doc(hidden)]
468 pub fn inner(&self) -> &super::Tree {
469 &self.tree
470 }
471}