1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Partition count callers can pass to [`PartitionedState::new`] when they have
/// no better estimate (2^8 = 256 partitions).
pub const DEFAULT_NUM_PARTITIONS: usize = 1 << 8;
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
41struct PartitionEntry<V> {
43 num_key_columns: usize,
44 value: V,
45}
46
47pub struct PartitionedState<V> {
52 manager: Arc<SpillManager>,
54 num_partitions: usize,
56 partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
58 spill_files: Vec<Option<SpillFile>>,
60 partition_sizes: Vec<usize>,
62 access_times: Vec<u64>,
64 timestamp: u64,
66 value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
68 value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
70}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 hash as usize % self.num_partitions
112 }
113
114 fn touch(&mut self, partition_idx: usize) {
116 self.timestamp += 1;
117 self.access_times[partition_idx] = self.timestamp;
118 }
119
120 fn get_partition_mut(
126 &mut self,
127 partition_idx: usize,
128 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129 self.touch(partition_idx);
130
131 if self.partitions[partition_idx].is_some() {
133 return Ok(self.partitions[partition_idx]
135 .as_mut()
136 .expect("partition is Some: checked on previous line"));
137 }
138
139 if let Some(spill_file) = self.spill_files[partition_idx].take() {
141 let loaded = self.load_partition(&spill_file)?;
142 let bytes = spill_file.bytes_written();
144 let _ = spill_file.delete();
145 self.manager.unregister_spilled_bytes(bytes);
146 self.partitions[partition_idx] = Some(loaded);
147 } else {
148 self.partitions[partition_idx] = Some(HashMap::new());
150 }
151
152 Ok(self.partitions[partition_idx]
154 .as_mut()
155 .expect("partition is Some: set to Some in if/else branches above"))
156 }
157
158 fn load_partition(
160 &self,
161 spill_file: &SpillFile,
162 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163 let mut reader = spill_file.reader()?;
164 let mut adapter = SpillReaderAdapter(&mut reader);
165
166 let num_entries = read_u64(&mut adapter)? as usize;
167 let mut partition = HashMap::with_capacity(num_entries);
168
169 for _ in 0..num_entries {
170 let key_len = read_u64(&mut adapter)? as usize;
172 let mut key_buf = vec![0u8; key_len];
173 adapter.read_exact(&mut key_buf)?;
174 let serialized_key = SerializedKey(key_buf);
175
176 let num_key_columns = read_u64(&mut adapter)? as usize;
178
179 let value = (self.value_deserializer)(&mut adapter)?;
181
182 partition.insert(
183 serialized_key,
184 PartitionEntry {
185 num_key_columns,
186 value,
187 },
188 );
189 }
190
191 Ok(partition)
192 }
193
194 #[must_use]
196 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197 self.partitions[partition_idx].is_some()
198 }
199
200 #[must_use]
202 pub fn partition_size(&self, partition_idx: usize) -> usize {
203 self.partition_sizes[partition_idx]
204 }
205
206 #[must_use]
208 pub fn total_size(&self) -> usize {
209 self.partition_sizes.iter().sum()
210 }
211
212 #[must_use]
214 pub fn in_memory_count(&self) -> usize {
215 self.partitions.iter().filter(|p| p.is_some()).count()
216 }
217
218 #[must_use]
220 pub fn spilled_count(&self) -> usize {
221 self.spill_files.iter().filter(|f| f.is_some()).count()
222 }
223
224 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230 let Some(partition) = self.partitions[partition_idx].take() else {
232 return Ok(0); };
234
235 if partition.is_empty() {
236 return Ok(0);
237 }
238
239 let mut spill_file = self.manager.create_file("partition")?;
241
242 let mut buf = Vec::new();
244 write_u64(&mut buf, partition.len() as u64)?;
245
246 for (key, entry) in &partition {
247 write_u64(&mut buf, key.0.len() as u64)?;
249 buf.extend_from_slice(&key.0);
250
251 write_u64(&mut buf, entry.num_key_columns as u64)?;
253
254 (self.value_serializer)(&entry.value, &mut buf)?;
256 }
257
258 spill_file.write_all(&buf)?;
259 spill_file.finish_write()?;
260
261 let bytes_written = spill_file.bytes_written();
262 self.manager.register_spilled_bytes(bytes_written);
263 self.partition_sizes[partition_idx] = partition.len();
264 self.spill_files[partition_idx] = Some(spill_file);
265
266 Ok(bytes_written as usize)
267 }
268
269 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
277 let largest_idx = self
279 .partitions
280 .iter()
281 .enumerate()
282 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
283 .max_by_key(|(_, size)| *size)
284 .map(|(idx, _)| idx);
285
286 match largest_idx {
287 Some(idx) => self.spill_partition(idx),
288 None => Ok(0),
289 }
290 }
291
292 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
300 let lru_idx = self
302 .partitions
303 .iter()
304 .enumerate()
305 .filter(|(_, p)| p.is_some())
306 .min_by_key(|(idx, _)| self.access_times[*idx])
307 .map(|(idx, _)| idx);
308
309 match lru_idx {
310 Some(idx) => self.spill_partition(idx),
311 None => Ok(0),
312 }
313 }
314
315 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
321 let partition_idx = self.partition_for(&key);
322 let num_key_columns = key.len();
323 let serialized_key = SerializedKey::from_values(&key);
324 let partition = self.get_partition_mut(partition_idx)?;
325
326 let old = partition.insert(
327 serialized_key,
328 PartitionEntry {
329 num_key_columns,
330 value,
331 },
332 );
333
334 if old.is_none() {
335 self.partition_sizes[partition_idx] += 1;
336 }
337
338 Ok(old.map(|e| e.value))
339 }
340
341 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
347 let partition_idx = self.partition_for(key);
348 let serialized_key = SerializedKey::from_values(key);
349 let partition = self.get_partition_mut(partition_idx)?;
350 Ok(partition.get(&serialized_key).map(|e| &e.value))
351 }
352
353 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
359 where
360 F: FnOnce() -> V,
361 {
362 let partition_idx = self.partition_for(&key);
363 let num_key_columns = key.len();
364 let serialized_key = SerializedKey::from_values(&key);
365
366 let was_new;
367 {
368 let partition = self.get_partition_mut(partition_idx)?;
369 was_new = !partition.contains_key(&serialized_key);
370 if was_new {
371 partition.insert(
372 serialized_key.clone(),
373 PartitionEntry {
374 num_key_columns,
375 value: default(),
376 },
377 );
378 }
379 }
380 if was_new {
381 self.partition_sizes[partition_idx] += 1;
382 }
383
384 let partition = self.get_partition_mut(partition_idx)?;
385 Ok(&mut partition
387 .get_mut(&serialized_key)
388 .expect("key exists: just inserted or already present in partition")
389 .value)
390 }
391
392 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
400 let mut result = Vec::with_capacity(self.total_size());
401
402 for partition_idx in 0..self.num_partitions {
403 let partition = self.get_partition_mut(partition_idx)?;
404 for (serialized_key, entry) in partition.drain() {
405 let key = serialized_key.to_values(entry.num_key_columns)?;
406 result.push((key, entry.value));
407 }
408 self.partition_sizes[partition_idx] = 0;
409 }
410
411 for spill_file in &mut self.spill_files {
413 if let Some(file) = spill_file.take() {
414 let bytes = file.bytes_written();
415 let _ = file.delete();
416 self.manager.unregister_spilled_bytes(bytes);
417 }
418 }
419
420 Ok(result)
421 }
422
423 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
431 let mut result = Vec::with_capacity(self.total_size());
432
433 for partition_idx in 0..self.num_partitions {
434 let partition = self.get_partition_mut(partition_idx)?;
435 for (serialized_key, entry) in partition.iter() {
436 let key = serialized_key.to_values(entry.num_key_columns)?;
437 result.push((key, entry.value.clone()));
438 }
439 }
440
441 Ok(result)
442 }
443
444 pub fn cleanup(&mut self) {
446 for file in self.spill_files.iter_mut().flatten() {
447 let bytes = file.bytes_written();
448 self.manager.unregister_spilled_bytes(bytes);
449 }
450
451 self.spill_files.clear();
452 self.partitions.clear();
453 for _ in 0..self.num_partitions {
454 self.spill_files.push(None);
455 self.partitions.push(Some(HashMap::new()));
456 }
457 self.partition_sizes = vec![0; self.num_partitions];
458 }
459}
460
461impl<V> Drop for PartitionedState<V> {
462 fn drop(&mut self) {
463 for file in self.spill_files.iter().flatten() {
465 let bytes = file.bytes_written();
466 self.manager.unregister_spilled_bytes(bytes);
467 }
468 }
469}
470
471fn hash_key(key: &[Value]) -> u64 {
473 use std::hash::{Hash, Hasher};
474 let mut hasher = std::collections::hash_map::DefaultHasher::new();
475
476 for value in key {
477 match value {
478 Value::Null => 0u8.hash(&mut hasher),
479 Value::Bool(b) => {
480 1u8.hash(&mut hasher);
481 b.hash(&mut hasher);
482 }
483 Value::Int64(n) => {
484 2u8.hash(&mut hasher);
485 n.hash(&mut hasher);
486 }
487 Value::Float64(f) => {
488 3u8.hash(&mut hasher);
489 f.to_bits().hash(&mut hasher);
490 }
491 Value::String(s) => {
492 4u8.hash(&mut hasher);
493 s.hash(&mut hasher);
494 }
495 Value::Bytes(b) => {
496 5u8.hash(&mut hasher);
497 b.hash(&mut hasher);
498 }
499 Value::Timestamp(t) => {
500 6u8.hash(&mut hasher);
501 t.hash(&mut hasher);
502 }
503 Value::Date(d) => {
504 10u8.hash(&mut hasher);
505 d.hash(&mut hasher);
506 }
507 Value::Time(t) => {
508 11u8.hash(&mut hasher);
509 t.hash(&mut hasher);
510 }
511 Value::Duration(d) => {
512 12u8.hash(&mut hasher);
513 d.hash(&mut hasher);
514 }
515 Value::ZonedDatetime(zdt) => {
516 14u8.hash(&mut hasher);
517 zdt.hash(&mut hasher);
518 }
519 Value::List(l) => {
520 7u8.hash(&mut hasher);
521 l.len().hash(&mut hasher);
522 }
523 Value::Map(m) => {
524 8u8.hash(&mut hasher);
525 m.len().hash(&mut hasher);
526 }
527 Value::Vector(v) => {
528 9u8.hash(&mut hasher);
529 v.len().hash(&mut hasher);
530 for &f in v.iter().take(4) {
532 f.to_bits().hash(&mut hasher);
533 }
534 }
535 Value::Path { nodes, edges } => {
536 13u8.hash(&mut hasher);
537 nodes.len().hash(&mut hasher);
538 edges.len().hash(&mut hasher);
539 }
540 }
541 }
542
543 hasher.finish()
544}
545
/// Reads one little-endian `u64` from `reader`.
///
/// # Errors
/// Propagates any error from `read_exact`, e.g. `UnexpectedEof` when fewer than
/// eight bytes remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut bytes = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut bytes)
        .map(|()| u64::from_le_bytes(bytes))
}
552
/// Writes `value` to `writer` as eight little-endian bytes.
///
/// # Errors
/// Propagates any error from `write_all`.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded = value.to_le_bytes();
    writer.write_all(&encoded)
}
557
558struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
560
561impl Read for SpillReaderAdapter<'_> {
562 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
563 self.0.read_exact(buf)?;
564 Ok(buf.len())
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a spill manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so it outlives the manager in each test.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    // The signature is dictated by the `Fn(&V, &mut dyn Write)` bound of
    // `PartitionedState::new`, so the value has to be taken by reference.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    /// Builds a key made only of `Int64` columns.
    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    /// Extracts the `Int64` columns of a decoded key.
    fn int_columns(key: &[Value]) -> Vec<i64> {
        key.iter()
            .map(|value| match value {
                Value::Int64(n) => *n,
                _ => panic!("expected only Int64 key columns"),
            })
            .collect()
    }

    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // Same key, same partition.
        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        // Always within range.
        assert!(p1 < 16);
    }

    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Existing key: the default must not be used.
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
        assert_eq!(state.total_size(), 1);
    }

    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);
        assert_eq!(state.total_size(), initial_total);

        // Every key is still reachable; spilled partitions reload on demand.
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
    }

    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Most recent access: key 3's partition must survive an LRU spill.
        state.get(&key(&[3])).unwrap();

        state.spill_lru().unwrap();

        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));
    }

    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_drain_all_roundtrips_spilled_keys() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i, i + 1]), i).unwrap();
        }

        // Push every partition out of memory so drain_all has to reload them.
        for partition_idx in 0..4 {
            state.spill_partition(partition_idx).unwrap();
        }
        assert_eq!(state.in_memory_count(), 0);

        let mut entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        entries.sort_unstable_by_key(|(_, v)| *v);
        for (expected, (k, v)) in (0..10).zip(&entries) {
            assert_eq!(*v, expected);
            assert_eq!(int_columns(k), vec![expected, expected + 1]);
        }
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        // iter_all is non-destructive.
        assert_eq!(state.total_size(), 5);

        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
        // All spilled bytes must have been returned to the manager.
        assert_eq!(manager.spilled_bytes(), 0);
    }

    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        // Overwriting returns the old value and does not change the count.
        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}