1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Default number of hash partitions (256).
///
/// Nothing in this part of the file reads the constant; presumably callers pass it to
/// `PartitionedState::new` — confirm against call sites.
pub const DEFAULT_NUM_PARTITIONS: usize = 256;
23
/// A group key held in its serialized byte form.
///
/// The bytes are either produced by `serialize_row` (see `from_values`) or read back verbatim
/// from a spill file. Equality and hashing are derived, so two keys compare equal exactly when
/// their byte strings are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
/// The payload stored under one serialized key inside a partition.
struct PartitionEntry<V> {
    /// Number of columns the key was built from; needed to decode the key bytes again.
    num_key_columns: usize,
    /// The caller's value associated with the key.
    value: V,
}
46
/// Key/value state split into hash partitions that can each be written to a spill file and
/// loaded back on demand.
///
/// Each partition is either resident (`partitions[i]` is `Some`) or not; a partition that was
/// spilled with data has its file recorded in `spill_files[i]` until it is reloaded.
pub struct PartitionedState<V> {
    /// Creates spill files and keeps the running total of spilled bytes.
    manager: Arc<SpillManager>,
    /// Number of partition slots; the divisor used when mapping a key hash to a slot.
    num_partitions: usize,
    /// Resident partition maps; `None` marks a partition that is not in memory.
    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
    /// Spill file per partition; `Some` while that partition's data lives on disk.
    spill_files: Vec<Option<SpillFile>>,
    /// Entry count per partition, maintained on insert, spill and drain.
    partition_sizes: Vec<usize>,
    /// Logical time of the most recent access per partition (0 means never accessed).
    access_times: Vec<u64>,
    /// Logical clock; incremented on every partition access.
    timestamp: u64,
    /// Writes one value into a spill stream.
    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
    /// Reads one value back from a spill stream.
    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 hash as usize % self.num_partitions
112 }
113
114 fn touch(&mut self, partition_idx: usize) {
116 self.timestamp += 1;
117 self.access_times[partition_idx] = self.timestamp;
118 }
119
120 fn get_partition_mut(
126 &mut self,
127 partition_idx: usize,
128 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129 self.touch(partition_idx);
130
131 if self.partitions[partition_idx].is_some() {
133 return Ok(self.partitions[partition_idx]
135 .as_mut()
136 .expect("partition is Some: checked on previous line"));
137 }
138
139 if let Some(spill_file) = self.spill_files[partition_idx].take() {
141 let loaded = self.load_partition(&spill_file)?;
142 let bytes = spill_file.bytes_written();
144 let _ = spill_file.delete();
145 self.manager.unregister_spilled_bytes(bytes);
146 self.partitions[partition_idx] = Some(loaded);
147 } else {
148 self.partitions[partition_idx] = Some(HashMap::new());
150 }
151
152 Ok(self.partitions[partition_idx]
154 .as_mut()
155 .expect("partition is Some: set to Some in if/else branches above"))
156 }
157
158 fn load_partition(
160 &self,
161 spill_file: &SpillFile,
162 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163 let mut reader = spill_file.reader()?;
164 let mut adapter = SpillReaderAdapter(&mut reader);
165
166 let num_entries = read_u64(&mut adapter)? as usize;
167 let mut partition = HashMap::with_capacity(num_entries);
168
169 for _ in 0..num_entries {
170 let key_len = read_u64(&mut adapter)? as usize;
172 let mut key_buf = vec![0u8; key_len];
173 adapter.read_exact(&mut key_buf)?;
174 let serialized_key = SerializedKey(key_buf);
175
176 let num_key_columns = read_u64(&mut adapter)? as usize;
178
179 let value = (self.value_deserializer)(&mut adapter)?;
181
182 partition.insert(
183 serialized_key,
184 PartitionEntry {
185 num_key_columns,
186 value,
187 },
188 );
189 }
190
191 Ok(partition)
192 }
193
194 #[must_use]
196 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197 self.partitions[partition_idx].is_some()
198 }
199
200 #[must_use]
202 pub fn partition_size(&self, partition_idx: usize) -> usize {
203 self.partition_sizes[partition_idx]
204 }
205
206 #[must_use]
208 pub fn total_size(&self) -> usize {
209 self.partition_sizes.iter().sum()
210 }
211
212 #[must_use]
214 pub fn in_memory_count(&self) -> usize {
215 self.partitions.iter().filter(|p| p.is_some()).count()
216 }
217
218 #[must_use]
220 pub fn spilled_count(&self) -> usize {
221 self.spill_files.iter().filter(|f| f.is_some()).count()
222 }
223
224 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230 let partition = match self.partitions[partition_idx].take() {
232 Some(p) => p,
233 None => return Ok(0), };
235
236 if partition.is_empty() {
237 return Ok(0);
238 }
239
240 let mut spill_file = self.manager.create_file("partition")?;
242
243 let mut buf = Vec::new();
245 write_u64(&mut buf, partition.len() as u64)?;
246
247 for (key, entry) in &partition {
248 write_u64(&mut buf, key.0.len() as u64)?;
250 buf.extend_from_slice(&key.0);
251
252 write_u64(&mut buf, entry.num_key_columns as u64)?;
254
255 (self.value_serializer)(&entry.value, &mut buf)?;
257 }
258
259 spill_file.write_all(&buf)?;
260 spill_file.finish_write()?;
261
262 let bytes_written = spill_file.bytes_written();
263 self.manager.register_spilled_bytes(bytes_written);
264 self.partition_sizes[partition_idx] = partition.len();
265 self.spill_files[partition_idx] = Some(spill_file);
266
267 Ok(bytes_written as usize)
268 }
269
270 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
278 let largest_idx = self
280 .partitions
281 .iter()
282 .enumerate()
283 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
284 .max_by_key(|(_, size)| *size)
285 .map(|(idx, _)| idx);
286
287 match largest_idx {
288 Some(idx) => self.spill_partition(idx),
289 None => Ok(0),
290 }
291 }
292
293 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
301 let lru_idx = self
303 .partitions
304 .iter()
305 .enumerate()
306 .filter(|(_, p)| p.is_some())
307 .min_by_key(|(idx, _)| self.access_times[*idx])
308 .map(|(idx, _)| idx);
309
310 match lru_idx {
311 Some(idx) => self.spill_partition(idx),
312 None => Ok(0),
313 }
314 }
315
316 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
322 let partition_idx = self.partition_for(&key);
323 let num_key_columns = key.len();
324 let serialized_key = SerializedKey::from_values(&key);
325 let partition = self.get_partition_mut(partition_idx)?;
326
327 let old = partition.insert(
328 serialized_key,
329 PartitionEntry {
330 num_key_columns,
331 value,
332 },
333 );
334
335 if old.is_none() {
336 self.partition_sizes[partition_idx] += 1;
337 }
338
339 Ok(old.map(|e| e.value))
340 }
341
342 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
348 let partition_idx = self.partition_for(key);
349 let serialized_key = SerializedKey::from_values(key);
350 let partition = self.get_partition_mut(partition_idx)?;
351 Ok(partition.get(&serialized_key).map(|e| &e.value))
352 }
353
354 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
360 where
361 F: FnOnce() -> V,
362 {
363 let partition_idx = self.partition_for(&key);
364 let num_key_columns = key.len();
365 let serialized_key = SerializedKey::from_values(&key);
366
367 let was_new;
368 {
369 let partition = self.get_partition_mut(partition_idx)?;
370 was_new = !partition.contains_key(&serialized_key);
371 if was_new {
372 partition.insert(
373 serialized_key.clone(),
374 PartitionEntry {
375 num_key_columns,
376 value: default(),
377 },
378 );
379 }
380 }
381 if was_new {
382 self.partition_sizes[partition_idx] += 1;
383 }
384
385 let partition = self.get_partition_mut(partition_idx)?;
386 Ok(&mut partition
388 .get_mut(&serialized_key)
389 .expect("key exists: just inserted or already present in partition")
390 .value)
391 }
392
393 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
401 let mut result = Vec::with_capacity(self.total_size());
402
403 for partition_idx in 0..self.num_partitions {
404 let partition = self.get_partition_mut(partition_idx)?;
405 for (serialized_key, entry) in partition.drain() {
406 let key = serialized_key.to_values(entry.num_key_columns)?;
407 result.push((key, entry.value));
408 }
409 self.partition_sizes[partition_idx] = 0;
410 }
411
412 for spill_file in self.spill_files.iter_mut() {
414 if let Some(file) = spill_file.take() {
415 let bytes = file.bytes_written();
416 let _ = file.delete();
417 self.manager.unregister_spilled_bytes(bytes);
418 }
419 }
420
421 Ok(result)
422 }
423
424 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
432 let mut result = Vec::with_capacity(self.total_size());
433
434 for partition_idx in 0..self.num_partitions {
435 let partition = self.get_partition_mut(partition_idx)?;
436 for (serialized_key, entry) in partition.iter() {
437 let key = serialized_key.to_values(entry.num_key_columns)?;
438 result.push((key, entry.value.clone()));
439 }
440 }
441
442 Ok(result)
443 }
444
445 pub fn cleanup(&mut self) {
447 for file in self.spill_files.iter_mut().flatten() {
448 let bytes = file.bytes_written();
449 self.manager.unregister_spilled_bytes(bytes);
450 }
451
452 self.spill_files.clear();
453 self.partitions.clear();
454 for _ in 0..self.num_partitions {
455 self.spill_files.push(None);
456 self.partitions.push(Some(HashMap::new()));
457 }
458 self.partition_sizes = vec![0; self.num_partitions];
459 }
460}
461
462impl<V> Drop for PartitionedState<V> {
463 fn drop(&mut self) {
464 for file in self.spill_files.iter().flatten() {
466 let bytes = file.bytes_written();
467 self.manager.unregister_spilled_bytes(bytes);
468 }
469 }
470}
471
472fn hash_key(key: &[Value]) -> u64 {
474 use std::hash::{Hash, Hasher};
475 let mut hasher = std::collections::hash_map::DefaultHasher::new();
476
477 for value in key {
478 match value {
479 Value::Null => 0u8.hash(&mut hasher),
480 Value::Bool(b) => {
481 1u8.hash(&mut hasher);
482 b.hash(&mut hasher);
483 }
484 Value::Int64(n) => {
485 2u8.hash(&mut hasher);
486 n.hash(&mut hasher);
487 }
488 Value::Float64(f) => {
489 3u8.hash(&mut hasher);
490 f.to_bits().hash(&mut hasher);
491 }
492 Value::String(s) => {
493 4u8.hash(&mut hasher);
494 s.hash(&mut hasher);
495 }
496 Value::Bytes(b) => {
497 5u8.hash(&mut hasher);
498 b.hash(&mut hasher);
499 }
500 Value::Timestamp(t) => {
501 6u8.hash(&mut hasher);
502 t.hash(&mut hasher);
503 }
504 Value::List(l) => {
505 7u8.hash(&mut hasher);
506 l.len().hash(&mut hasher);
507 }
508 Value::Map(m) => {
509 8u8.hash(&mut hasher);
510 m.len().hash(&mut hasher);
511 }
512 }
513 }
514
515 hasher.finish()
516}
517
/// Reads exactly eight bytes from `reader` and decodes them as a little-endian `u64`.
///
/// # Errors
///
/// Returns the error from `read_exact`, e.g. `UnexpectedEof` when fewer than eight bytes
/// remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    match reader.read_exact(&mut raw) {
        Ok(()) => Ok(u64::from_le_bytes(raw)),
        Err(err) => Err(err),
    }
}
524
/// Writes `value` to `writer` as eight little-endian bytes.
///
/// # Errors
///
/// Returns the error from `write_all`.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
529
/// Borrows a `SpillFileReader` so it can be used where a `std::io::Read` is expected.
struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
532
533impl<'a> Read for SpillReaderAdapter<'a> {
534 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
535 self.0.read_exact(buf)?;
536 Ok(buf.len())
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a spill manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside so the caller keeps the directory alive.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    /// Value serializer used by the tests: eight little-endian bytes.
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    /// Value deserializer used by the tests; mirrors `serialize_i64`.
    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    /// Builds a key whose columns are all `Value::Int64`.
    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    /// The same key always maps to the same partition, and the index is in range.
    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        assert!(p1 < 16);
    }

    /// Inserted values can be read back; a missing key yields `None`.
    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    /// The default is used only for a missing key, and the returned reference is writable.
    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // First call inserts the default.
        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Second call finds the existing value and ignores the new default.
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        // Writing through the reference updates the stored value.
        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
    }

    /// Every value is still readable after the largest partition has been spilled.
    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);

        // Lookups transparently reload the spilled partition.
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
    }

    /// After an LRU spill, the partition of the most recently read key is still resident.
    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Make key 3's partition the most recently used one.
        state.get(&key(&[3])).unwrap();

        state.spill_lru().unwrap();

        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));
    }

    /// Draining returns every entry, including spilled ones, and empties the state.
    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        // Entry order is unspecified, so compare the sorted values.
        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
    }

    /// Iterating returns every entry and leaves the state unchanged, so it can be repeated.
    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        assert_eq!(state.total_size(), 5);

        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    /// A thousand keys survive eight rounds of spilling the largest partition.
    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    /// `cleanup` leaves no entries and no attached spill files.
    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    /// Keys with several columns of mixed types are kept distinct.
    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    /// Re-inserting a key replaces the value, returns the old one and keeps the size at one.
    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}