1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Default partition count (256).
///
/// Not referenced in the visible part of this file; presumably what callers pass as
/// `num_partitions` to `PartitionedState::new` — confirm at the call sites.
pub const DEFAULT_NUM_PARTITIONS: usize = 256;
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
/// The record kept for one key inside a partition.
struct PartitionEntry<V> {
    /// Number of columns in the key; required to decode the serialized key again.
    num_key_columns: usize,
    /// Payload associated with the key.
    value: V,
}
46
47pub struct PartitionedState<V> {
52 manager: Arc<SpillManager>,
54 num_partitions: usize,
56 partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
58 spill_files: Vec<Option<SpillFile>>,
60 partition_sizes: Vec<usize>,
62 access_times: Vec<u64>,
64 timestamp: u64,
66 value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
68 value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
70}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 #[allow(clippy::cast_possible_truncation)]
113 {
114 hash as usize % self.num_partitions
115 }
116 }
117
118 fn touch(&mut self, partition_idx: usize) {
120 self.timestamp += 1;
121 self.access_times[partition_idx] = self.timestamp;
122 }
123
124 fn get_partition_mut(
130 &mut self,
131 partition_idx: usize,
132 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
133 self.touch(partition_idx);
134
135 if self.partitions[partition_idx].is_some() {
137 return Ok(self.partitions[partition_idx]
139 .as_mut()
140 .expect("partition is Some: checked on previous line"));
141 }
142
143 if let Some(spill_file) = self.spill_files[partition_idx].take() {
145 let loaded = self.load_partition(&spill_file)?;
146 let bytes = spill_file.bytes_written();
148 let _ = spill_file.delete();
149 self.manager.unregister_spilled_bytes(bytes);
150 self.partitions[partition_idx] = Some(loaded);
151 } else {
152 self.partitions[partition_idx] = Some(HashMap::new());
154 }
155
156 Ok(self.partitions[partition_idx]
158 .as_mut()
159 .expect("partition is Some: set to Some in if/else branches above"))
160 }
161
162 fn load_partition(
164 &self,
165 spill_file: &SpillFile,
166 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
167 let mut reader = spill_file.reader()?;
168 let mut adapter = SpillReaderAdapter(&mut reader);
169
170 #[allow(clippy::cast_possible_truncation)]
172 let num_entries = read_u64(&mut adapter)? as usize;
173 let mut partition = HashMap::with_capacity(num_entries);
174
175 for _ in 0..num_entries {
176 #[allow(clippy::cast_possible_truncation)]
179 let key_len = read_u64(&mut adapter)? as usize;
180 let mut key_buf = vec![0u8; key_len];
181 adapter.read_exact(&mut key_buf)?;
182 let serialized_key = SerializedKey(key_buf);
183
184 #[allow(clippy::cast_possible_truncation)]
187 let num_key_columns = read_u64(&mut adapter)? as usize;
188
189 let value = (self.value_deserializer)(&mut adapter)?;
191
192 partition.insert(
193 serialized_key,
194 PartitionEntry {
195 num_key_columns,
196 value,
197 },
198 );
199 }
200
201 Ok(partition)
202 }
203
204 #[must_use]
206 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
207 self.partitions[partition_idx].is_some()
208 }
209
210 #[must_use]
212 pub fn partition_size(&self, partition_idx: usize) -> usize {
213 self.partition_sizes[partition_idx]
214 }
215
216 #[must_use]
218 pub fn total_size(&self) -> usize {
219 self.partition_sizes.iter().sum()
220 }
221
222 #[must_use]
224 pub fn in_memory_count(&self) -> usize {
225 self.partitions.iter().filter(|p| p.is_some()).count()
226 }
227
228 #[must_use]
230 pub fn spilled_count(&self) -> usize {
231 self.spill_files.iter().filter(|f| f.is_some()).count()
232 }
233
234 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
240 let Some(partition) = self.partitions[partition_idx].take() else {
242 return Ok(0); };
244
245 if partition.is_empty() {
246 return Ok(0);
247 }
248
249 let mut spill_file = self.manager.create_file("partition")?;
251
252 let mut buf = Vec::new();
254 write_u64(&mut buf, partition.len() as u64)?;
255
256 for (key, entry) in &partition {
257 write_u64(&mut buf, key.0.len() as u64)?;
259 buf.extend_from_slice(&key.0);
260
261 write_u64(&mut buf, entry.num_key_columns as u64)?;
263
264 (self.value_serializer)(&entry.value, &mut buf)?;
266 }
267
268 spill_file.write_all(&buf)?;
269 spill_file.finish_write()?;
270
271 let bytes_written = spill_file.bytes_written();
272 self.manager.register_spilled_bytes(bytes_written);
273 self.partition_sizes[partition_idx] = partition.len();
274 self.spill_files[partition_idx] = Some(spill_file);
275
276 #[allow(clippy::cast_possible_truncation)]
278 Ok(bytes_written as usize)
279 }
280
281 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
289 let largest_idx = self
291 .partitions
292 .iter()
293 .enumerate()
294 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
295 .max_by_key(|(_, size)| *size)
296 .map(|(idx, _)| idx);
297
298 match largest_idx {
299 Some(idx) => self.spill_partition(idx),
300 None => Ok(0),
301 }
302 }
303
304 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
312 let lru_idx = self
314 .partitions
315 .iter()
316 .enumerate()
317 .filter(|(_, p)| p.is_some())
318 .min_by_key(|(idx, _)| self.access_times[*idx])
319 .map(|(idx, _)| idx);
320
321 match lru_idx {
322 Some(idx) => self.spill_partition(idx),
323 None => Ok(0),
324 }
325 }
326
327 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
333 let partition_idx = self.partition_for(&key);
334 let num_key_columns = key.len();
335 let serialized_key = SerializedKey::from_values(&key);
336 let partition = self.get_partition_mut(partition_idx)?;
337
338 let old = partition.insert(
339 serialized_key,
340 PartitionEntry {
341 num_key_columns,
342 value,
343 },
344 );
345
346 if old.is_none() {
347 self.partition_sizes[partition_idx] += 1;
348 }
349
350 Ok(old.map(|e| e.value))
351 }
352
353 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
359 let partition_idx = self.partition_for(key);
360 let serialized_key = SerializedKey::from_values(key);
361 let partition = self.get_partition_mut(partition_idx)?;
362 Ok(partition.get(&serialized_key).map(|e| &e.value))
363 }
364
365 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
375 where
376 F: FnOnce() -> V,
377 {
378 let partition_idx = self.partition_for(&key);
379 let num_key_columns = key.len();
380 let serialized_key = SerializedKey::from_values(&key);
381
382 let was_new;
383 {
384 let partition = self.get_partition_mut(partition_idx)?;
385 was_new = !partition.contains_key(&serialized_key);
386 if was_new {
387 partition.insert(
388 serialized_key.clone(),
389 PartitionEntry {
390 num_key_columns,
391 value: default(),
392 },
393 );
394 }
395 }
396 if was_new {
397 self.partition_sizes[partition_idx] += 1;
398 }
399
400 let partition = self.get_partition_mut(partition_idx)?;
401 Ok(&mut partition
403 .get_mut(&serialized_key)
404 .expect("key exists: just inserted or already present in partition")
405 .value)
406 }
407
408 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
416 let mut result = Vec::with_capacity(self.total_size());
417
418 for partition_idx in 0..self.num_partitions {
419 let partition = self.get_partition_mut(partition_idx)?;
420 for (serialized_key, entry) in partition.drain() {
421 let key = serialized_key.to_values(entry.num_key_columns)?;
422 result.push((key, entry.value));
423 }
424 self.partition_sizes[partition_idx] = 0;
425 }
426
427 for spill_file in &mut self.spill_files {
429 if let Some(file) = spill_file.take() {
430 let bytes = file.bytes_written();
431 let _ = file.delete();
432 self.manager.unregister_spilled_bytes(bytes);
433 }
434 }
435
436 Ok(result)
437 }
438
439 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
447 let mut result = Vec::with_capacity(self.total_size());
448
449 for partition_idx in 0..self.num_partitions {
450 let partition = self.get_partition_mut(partition_idx)?;
451 for (serialized_key, entry) in partition.iter() {
452 let key = serialized_key.to_values(entry.num_key_columns)?;
453 result.push((key, entry.value.clone()));
454 }
455 }
456
457 Ok(result)
458 }
459
460 pub fn cleanup(&mut self) {
462 for file in self.spill_files.iter_mut().flatten() {
463 let bytes = file.bytes_written();
464 self.manager.unregister_spilled_bytes(bytes);
465 }
466
467 self.spill_files.clear();
468 self.partitions.clear();
469 for _ in 0..self.num_partitions {
470 self.spill_files.push(None);
471 self.partitions.push(Some(HashMap::new()));
472 }
473 self.partition_sizes = vec![0; self.num_partitions];
474 }
475}
476
477impl<V> Drop for PartitionedState<V> {
478 fn drop(&mut self) {
479 for file in self.spill_files.iter().flatten() {
481 let bytes = file.bytes_written();
482 self.manager.unregister_spilled_bytes(bytes);
483 }
484 }
485}
486
487fn hash_key(key: &[Value]) -> u64 {
489 use std::hash::{Hash, Hasher};
490 let mut hasher = std::collections::hash_map::DefaultHasher::new();
491
492 for value in key {
493 match value {
494 Value::Null => 0u8.hash(&mut hasher),
495 Value::Bool(b) => {
496 1u8.hash(&mut hasher);
497 b.hash(&mut hasher);
498 }
499 Value::Int64(n) => {
500 2u8.hash(&mut hasher);
501 n.hash(&mut hasher);
502 }
503 Value::Float64(f) => {
504 3u8.hash(&mut hasher);
505 f.to_bits().hash(&mut hasher);
506 }
507 Value::String(s) => {
508 4u8.hash(&mut hasher);
509 s.hash(&mut hasher);
510 }
511 Value::Bytes(b) => {
512 5u8.hash(&mut hasher);
513 b.hash(&mut hasher);
514 }
515 Value::Timestamp(t) => {
516 6u8.hash(&mut hasher);
517 t.hash(&mut hasher);
518 }
519 Value::Date(d) => {
520 10u8.hash(&mut hasher);
521 d.hash(&mut hasher);
522 }
523 Value::Time(t) => {
524 11u8.hash(&mut hasher);
525 t.hash(&mut hasher);
526 }
527 Value::Duration(d) => {
528 12u8.hash(&mut hasher);
529 d.hash(&mut hasher);
530 }
531 Value::ZonedDatetime(zdt) => {
532 14u8.hash(&mut hasher);
533 zdt.hash(&mut hasher);
534 }
535 Value::List(l) => {
536 7u8.hash(&mut hasher);
537 l.len().hash(&mut hasher);
538 }
539 Value::Map(m) => {
540 8u8.hash(&mut hasher);
541 m.len().hash(&mut hasher);
542 }
543 Value::Vector(v) => {
544 9u8.hash(&mut hasher);
545 v.len().hash(&mut hasher);
546 for &f in v.iter().take(4) {
548 f.to_bits().hash(&mut hasher);
549 }
550 }
551 Value::Path { nodes, edges } => {
552 13u8.hash(&mut hasher);
553 nodes.len().hash(&mut hasher);
554 edges.len().hash(&mut hasher);
555 }
556 Value::GCounter(counts) => {
557 15u8.hash(&mut hasher);
558 counts.len().hash(&mut hasher);
559 }
560 Value::OnCounter { pos, neg } => {
561 16u8.hash(&mut hasher);
562 pos.len().hash(&mut hasher);
563 neg.len().hash(&mut hasher);
564 }
565 _ => {
566 255u8.hash(&mut hasher);
567 }
568 }
569 }
570
571 hasher.finish()
572}
573
/// Reads a little-endian `u64` from `reader`.
///
/// # Errors
///
/// Returns the error from `read_exact`, e.g. when fewer than eight bytes remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut raw)
        .map(|()| u64::from_le_bytes(raw))
}
580
/// Writes `value` to `writer` as eight little-endian bytes.
///
/// # Errors
///
/// Returns the error from `write_all`.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded = value.to_le_bytes();
    writer.write_all(&encoded)
}
585
586struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
588
589impl Read for SpillReaderAdapter<'_> {
590 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
591 self.0.read_exact(buf)?;
592 Ok(buf.len())
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a spill manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the directory outlives the manager in each test.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    /// Builds an `i64`-valued state with its own manager and temporary directory.
    fn new_state(num_partitions: usize) -> (TempDir, PartitionedState<i64>) {
        let (temp_dir, manager) = create_manager();
        let state = PartitionedState::new(manager, num_partitions, serialize_i64, deserialize_i64);
        (temp_dir, state)
    }

    // The `&i64` parameter is dictated by the serializer callback signature.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    /// Builds a key made of `Int64` columns.
    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    #[test]
    fn test_partition_for() {
        let (_temp_dir, state) = new_state(16);

        // The mapping is deterministic and stays within bounds.
        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);
        assert!(p1 < 16);
    }

    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, mut state) = new_state(16);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, mut state) = new_state(16);

        // First call inserts the default.
        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Second call returns the existing value and ignores the default.
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        // The returned reference allows in-place mutation.
        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
    }

    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, mut state) = new_state(4);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);

        // Every value is still reachable; spilled partitions reload transparently.
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }

        // Spilling and reloading must not change the entry count.
        assert_eq!(state.total_size(), initial_total);
    }

    #[test]
    fn test_spill_lru() {
        let (_temp_dir, mut state) = new_state(4);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Make key 3's partition the most recently used one.
        state.get(&key(&[3])).unwrap();

        state.spill_lru().unwrap();

        // The most recently used partition must not have been the victim.
        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));
    }

    #[test]
    fn test_drain_all() {
        let (_temp_dir, mut state) = new_state(4);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        // Draining returns resident and spilled entries alike.
        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
    }

    #[test]
    fn test_iter_all() {
        let (_temp_dir, mut state) = new_state(4);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        // Iterating does not consume the entries.
        assert_eq!(state.total_size(), 5);

        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    #[test]
    fn test_many_groups() {
        let (_temp_dir, mut state) = new_state(16);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
        // Cleanup must hand every spilled byte back to the manager.
        assert_eq!(manager.spilled_bytes(), 0);
    }

    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, mut state) = new_state(8);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    #[test]
    fn test_update_existing() {
        let (_temp_dir, mut state) = new_state(4);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        // Overwriting returns the old value and leaves the entry count unchanged.
        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}