1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Default partition count (2^8 = 256); presumably what callers pass to
/// `PartitionedState::new` when they have no better estimate.
pub const DEFAULT_NUM_PARTITIONS: usize = 1 << 8;
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
/// What a partition stores for a single key.
struct PartitionEntry<V> {
    /// Number of `Value` columns encoded in the owning `SerializedKey`; needed
    /// to decode the key bytes back into a row (see `SerializedKey::to_values`).
    num_key_columns: usize,
    /// The state associated with the key.
    value: V,
}
46
/// Hash-partitioned key/value state whose partitions can be written to disk
/// individually (through the `SpillManager`) and are reloaded on access.
pub struct PartitionedState<V> {
    /// Creates spill files and is told how many bytes are currently spilled.
    manager: Arc<SpillManager>,
    /// Number of partitions; a key goes to `hash_key(key) % num_partitions`.
    num_partitions: usize,
    /// Per-partition maps. `None` means the partition is not resident
    /// (spilled to disk, or released because it was empty).
    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
    /// Per-partition spill file, present while that partition lives on disk.
    spill_files: Vec<Option<SpillFile>>,
    /// Entry count per partition, maintained whether or not it is resident.
    partition_sizes: Vec<usize>,
    /// Logical time of each partition's last access, used by `spill_lru`.
    access_times: Vec<u64>,
    /// Counter that produces the values stored in `access_times`.
    timestamp: u64,
    /// Writes one value into a partition's spill byte stream.
    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
    /// Reads one value back from a partition's spill byte stream.
    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 hash as usize % self.num_partitions
112 }
113
114 fn touch(&mut self, partition_idx: usize) {
116 self.timestamp += 1;
117 self.access_times[partition_idx] = self.timestamp;
118 }
119
120 fn get_partition_mut(
126 &mut self,
127 partition_idx: usize,
128 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129 self.touch(partition_idx);
130
131 if self.partitions[partition_idx].is_some() {
133 return Ok(self.partitions[partition_idx]
135 .as_mut()
136 .expect("partition is Some: checked on previous line"));
137 }
138
139 if let Some(spill_file) = self.spill_files[partition_idx].take() {
141 let loaded = self.load_partition(&spill_file)?;
142 let bytes = spill_file.bytes_written();
144 let _ = spill_file.delete();
145 self.manager.unregister_spilled_bytes(bytes);
146 self.partitions[partition_idx] = Some(loaded);
147 } else {
148 self.partitions[partition_idx] = Some(HashMap::new());
150 }
151
152 Ok(self.partitions[partition_idx]
154 .as_mut()
155 .expect("partition is Some: set to Some in if/else branches above"))
156 }
157
158 fn load_partition(
160 &self,
161 spill_file: &SpillFile,
162 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163 let mut reader = spill_file.reader()?;
164 let mut adapter = SpillReaderAdapter(&mut reader);
165
166 let num_entries = read_u64(&mut adapter)? as usize;
167 let mut partition = HashMap::with_capacity(num_entries);
168
169 for _ in 0..num_entries {
170 let key_len = read_u64(&mut adapter)? as usize;
172 let mut key_buf = vec![0u8; key_len];
173 adapter.read_exact(&mut key_buf)?;
174 let serialized_key = SerializedKey(key_buf);
175
176 let num_key_columns = read_u64(&mut adapter)? as usize;
178
179 let value = (self.value_deserializer)(&mut adapter)?;
181
182 partition.insert(
183 serialized_key,
184 PartitionEntry {
185 num_key_columns,
186 value,
187 },
188 );
189 }
190
191 Ok(partition)
192 }
193
194 #[must_use]
196 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197 self.partitions[partition_idx].is_some()
198 }
199
200 #[must_use]
202 pub fn partition_size(&self, partition_idx: usize) -> usize {
203 self.partition_sizes[partition_idx]
204 }
205
206 #[must_use]
208 pub fn total_size(&self) -> usize {
209 self.partition_sizes.iter().sum()
210 }
211
212 #[must_use]
214 pub fn in_memory_count(&self) -> usize {
215 self.partitions.iter().filter(|p| p.is_some()).count()
216 }
217
218 #[must_use]
220 pub fn spilled_count(&self) -> usize {
221 self.spill_files.iter().filter(|f| f.is_some()).count()
222 }
223
224 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230 let Some(partition) = self.partitions[partition_idx].take() else {
232 return Ok(0); };
234
235 if partition.is_empty() {
236 return Ok(0);
237 }
238
239 let mut spill_file = self.manager.create_file("partition")?;
241
242 let mut buf = Vec::new();
244 write_u64(&mut buf, partition.len() as u64)?;
245
246 for (key, entry) in &partition {
247 write_u64(&mut buf, key.0.len() as u64)?;
249 buf.extend_from_slice(&key.0);
250
251 write_u64(&mut buf, entry.num_key_columns as u64)?;
253
254 (self.value_serializer)(&entry.value, &mut buf)?;
256 }
257
258 spill_file.write_all(&buf)?;
259 spill_file.finish_write()?;
260
261 let bytes_written = spill_file.bytes_written();
262 self.manager.register_spilled_bytes(bytes_written);
263 self.partition_sizes[partition_idx] = partition.len();
264 self.spill_files[partition_idx] = Some(spill_file);
265
266 Ok(bytes_written as usize)
267 }
268
269 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
277 let largest_idx = self
279 .partitions
280 .iter()
281 .enumerate()
282 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
283 .max_by_key(|(_, size)| *size)
284 .map(|(idx, _)| idx);
285
286 match largest_idx {
287 Some(idx) => self.spill_partition(idx),
288 None => Ok(0),
289 }
290 }
291
292 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
300 let lru_idx = self
302 .partitions
303 .iter()
304 .enumerate()
305 .filter(|(_, p)| p.is_some())
306 .min_by_key(|(idx, _)| self.access_times[*idx])
307 .map(|(idx, _)| idx);
308
309 match lru_idx {
310 Some(idx) => self.spill_partition(idx),
311 None => Ok(0),
312 }
313 }
314
315 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
321 let partition_idx = self.partition_for(&key);
322 let num_key_columns = key.len();
323 let serialized_key = SerializedKey::from_values(&key);
324 let partition = self.get_partition_mut(partition_idx)?;
325
326 let old = partition.insert(
327 serialized_key,
328 PartitionEntry {
329 num_key_columns,
330 value,
331 },
332 );
333
334 if old.is_none() {
335 self.partition_sizes[partition_idx] += 1;
336 }
337
338 Ok(old.map(|e| e.value))
339 }
340
341 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
347 let partition_idx = self.partition_for(key);
348 let serialized_key = SerializedKey::from_values(key);
349 let partition = self.get_partition_mut(partition_idx)?;
350 Ok(partition.get(&serialized_key).map(|e| &e.value))
351 }
352
353 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
359 where
360 F: FnOnce() -> V,
361 {
362 let partition_idx = self.partition_for(&key);
363 let num_key_columns = key.len();
364 let serialized_key = SerializedKey::from_values(&key);
365
366 let was_new;
367 {
368 let partition = self.get_partition_mut(partition_idx)?;
369 was_new = !partition.contains_key(&serialized_key);
370 if was_new {
371 partition.insert(
372 serialized_key.clone(),
373 PartitionEntry {
374 num_key_columns,
375 value: default(),
376 },
377 );
378 }
379 }
380 if was_new {
381 self.partition_sizes[partition_idx] += 1;
382 }
383
384 let partition = self.get_partition_mut(partition_idx)?;
385 Ok(&mut partition
387 .get_mut(&serialized_key)
388 .expect("key exists: just inserted or already present in partition")
389 .value)
390 }
391
392 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
400 let mut result = Vec::with_capacity(self.total_size());
401
402 for partition_idx in 0..self.num_partitions {
403 let partition = self.get_partition_mut(partition_idx)?;
404 for (serialized_key, entry) in partition.drain() {
405 let key = serialized_key.to_values(entry.num_key_columns)?;
406 result.push((key, entry.value));
407 }
408 self.partition_sizes[partition_idx] = 0;
409 }
410
411 for spill_file in &mut self.spill_files {
413 if let Some(file) = spill_file.take() {
414 let bytes = file.bytes_written();
415 let _ = file.delete();
416 self.manager.unregister_spilled_bytes(bytes);
417 }
418 }
419
420 Ok(result)
421 }
422
423 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
431 let mut result = Vec::with_capacity(self.total_size());
432
433 for partition_idx in 0..self.num_partitions {
434 let partition = self.get_partition_mut(partition_idx)?;
435 for (serialized_key, entry) in partition.iter() {
436 let key = serialized_key.to_values(entry.num_key_columns)?;
437 result.push((key, entry.value.clone()));
438 }
439 }
440
441 Ok(result)
442 }
443
444 pub fn cleanup(&mut self) {
446 for file in self.spill_files.iter_mut().flatten() {
447 let bytes = file.bytes_written();
448 self.manager.unregister_spilled_bytes(bytes);
449 }
450
451 self.spill_files.clear();
452 self.partitions.clear();
453 for _ in 0..self.num_partitions {
454 self.spill_files.push(None);
455 self.partitions.push(Some(HashMap::new()));
456 }
457 self.partition_sizes = vec![0; self.num_partitions];
458 }
459}
460
461impl<V> Drop for PartitionedState<V> {
462 fn drop(&mut self) {
463 for file in self.spill_files.iter().flatten() {
465 let bytes = file.bytes_written();
466 self.manager.unregister_spilled_bytes(bytes);
467 }
468 }
469}
470
471fn hash_key(key: &[Value]) -> u64 {
473 use std::hash::{Hash, Hasher};
474 let mut hasher = std::collections::hash_map::DefaultHasher::new();
475
476 for value in key {
477 match value {
478 Value::Null => 0u8.hash(&mut hasher),
479 Value::Bool(b) => {
480 1u8.hash(&mut hasher);
481 b.hash(&mut hasher);
482 }
483 Value::Int64(n) => {
484 2u8.hash(&mut hasher);
485 n.hash(&mut hasher);
486 }
487 Value::Float64(f) => {
488 3u8.hash(&mut hasher);
489 f.to_bits().hash(&mut hasher);
490 }
491 Value::String(s) => {
492 4u8.hash(&mut hasher);
493 s.hash(&mut hasher);
494 }
495 Value::Bytes(b) => {
496 5u8.hash(&mut hasher);
497 b.hash(&mut hasher);
498 }
499 Value::Timestamp(t) => {
500 6u8.hash(&mut hasher);
501 t.hash(&mut hasher);
502 }
503 Value::List(l) => {
504 7u8.hash(&mut hasher);
505 l.len().hash(&mut hasher);
506 }
507 Value::Map(m) => {
508 8u8.hash(&mut hasher);
509 m.len().hash(&mut hasher);
510 }
511 Value::Vector(v) => {
512 9u8.hash(&mut hasher);
513 v.len().hash(&mut hasher);
514 for &f in v.iter().take(4) {
516 f.to_bits().hash(&mut hasher);
517 }
518 }
519 }
520 }
521
522 hasher.finish()
523}
524
/// Reads one little-endian `u64` from `reader`.
///
/// # Errors
///
/// Propagates any error from `read_exact`. For standard readers this is
/// `UnexpectedEof` when fewer than eight bytes remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut bytes = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut bytes)
        .map(|()| u64::from_le_bytes(bytes))
}
531
/// Writes `value` to `writer` as eight little-endian bytes.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded = value.to_le_bytes();
    writer.write_all(&encoded)
}
536
537struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
539
540impl<'a> Read for SpillReaderAdapter<'a> {
541 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
542 self.0.read_exact(buf)?;
543 Ok(buf.len())
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a spill manager rooted in a fresh temporary directory; the
    /// directory guard must be kept alive for the duration of the test.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    // The by-reference signature is dictated by the value-serializer callback type.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        assert!(p1 < 16);
    }

    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
        assert_eq!(state.total_size(), 1);
    }

    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);

        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
    }

    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        // Choose three keys in three distinct partitions so the outcome does
        // not depend on how the hash happens to spread small integers.
        let mut chosen: Vec<(i64, usize)> = Vec::new();
        for candidate in 0..1000 {
            let idx = state.partition_for(&key(&[candidate]));
            if chosen.iter().all(|&(_, used)| used != idx) {
                chosen.push((candidate, idx));
                if chosen.len() == 3 {
                    break;
                }
            }
        }
        assert_eq!(chosen.len(), 3);

        for &(k, _) in &chosen {
            state.insert(key(&[k]), k * 10).unwrap();
        }

        // Make the first-inserted key's partition the most recently used.
        let (recent, recent_idx) = chosen[0];
        state.get(&key(&[recent])).unwrap();

        state.spill_lru().unwrap();

        assert!(state.is_in_memory(recent_idx));

        // Whatever was spilled must still be readable.
        for &(k, _) in &chosen {
            assert_eq!(state.get(&key(&[k])).unwrap(), Some(&(k * 10)));
        }
    }

    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
    }

    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        assert_eq!(state.total_size(), 5);

        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    #[test]
    fn test_multi_column_key_survives_spill() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        let keys = [
            vec![Value::String("a".into()), Value::Int64(1)],
            vec![Value::String("b".into()), Value::Null],
            vec![Value::Bool(true), Value::Float64(2.5)],
        ];
        for (i, k) in keys.iter().enumerate() {
            state.insert(k.clone(), i as i64).unwrap();
        }

        for idx in 0..4 {
            state.spill_partition(idx).unwrap();
        }
        assert!(state.spilled_count() > 0);
        assert_eq!(state.total_size(), keys.len());

        // Keys must decode back with their full column count.
        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), keys.len());
        assert!(entries.iter().all(|(k, _)| k.len() == 2));

        for (i, k) in keys.iter().enumerate() {
            assert_eq!(state.get(k).unwrap(), Some(&(i as i64)));
        }
    }

    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}