1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Default number of partitions for a [`PartitionedState`] (2^8 = 256).
pub const DEFAULT_NUM_PARTITIONS: usize = 1 << 8;
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
/// A stored value together with the information needed to decode its key.
struct PartitionEntry<V> {
    /// The state value associated with the key.
    value: V,
    /// Number of `Value` columns in the serialized key (passed to `SerializedKey::to_values`).
    num_key_columns: usize,
}
46
/// Hash-partitioned key → value state whose partitions can be individually spilled to disk.
///
/// Every key is assigned to one of `num_partitions` partitions by
/// [`PartitionedState::partition_for`]. A partition is resident when `partitions[i]` is
/// `Some`. A spilled partition's contents live in `spill_files[i]` and are reloaded on the
/// next access.
pub struct PartitionedState<V> {
    // Creates spill files and tracks the number of spilled bytes.
    manager: Arc<SpillManager>,
    // Number of partitions; fixed at construction.
    num_partitions: usize,
    // Resident partitions; `None` means the partition is not currently in memory.
    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
    // Spill file per partition; `Some` while that partition's data lives on disk.
    spill_files: Vec<Option<SpillFile>>,
    // Entry count per partition, kept for resident and spilled partitions alike.
    partition_sizes: Vec<usize>,
    // Logical time of each partition's last access (drives LRU spilling).
    access_times: Vec<u64>,
    // Monotonic counter that feeds `access_times`.
    timestamp: u64,
    // Encodes a single value when a partition is spilled.
    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
    // Decodes a single value when a spilled partition is reloaded.
    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 hash as usize % self.num_partitions
112 }
113
114 fn touch(&mut self, partition_idx: usize) {
116 self.timestamp += 1;
117 self.access_times[partition_idx] = self.timestamp;
118 }
119
120 fn get_partition_mut(
126 &mut self,
127 partition_idx: usize,
128 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129 self.touch(partition_idx);
130
131 if self.partitions[partition_idx].is_some() {
133 return Ok(self.partitions[partition_idx]
135 .as_mut()
136 .expect("partition is Some: checked on previous line"));
137 }
138
139 if let Some(spill_file) = self.spill_files[partition_idx].take() {
141 let loaded = self.load_partition(&spill_file)?;
142 let bytes = spill_file.bytes_written();
144 let _ = spill_file.delete();
145 self.manager.unregister_spilled_bytes(bytes);
146 self.partitions[partition_idx] = Some(loaded);
147 } else {
148 self.partitions[partition_idx] = Some(HashMap::new());
150 }
151
152 Ok(self.partitions[partition_idx]
154 .as_mut()
155 .expect("partition is Some: set to Some in if/else branches above"))
156 }
157
158 fn load_partition(
160 &self,
161 spill_file: &SpillFile,
162 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163 let mut reader = spill_file.reader()?;
164 let mut adapter = SpillReaderAdapter(&mut reader);
165
166 let num_entries = read_u64(&mut adapter)? as usize;
167 let mut partition = HashMap::with_capacity(num_entries);
168
169 for _ in 0..num_entries {
170 let key_len = read_u64(&mut adapter)? as usize;
172 let mut key_buf = vec![0u8; key_len];
173 adapter.read_exact(&mut key_buf)?;
174 let serialized_key = SerializedKey(key_buf);
175
176 let num_key_columns = read_u64(&mut adapter)? as usize;
178
179 let value = (self.value_deserializer)(&mut adapter)?;
181
182 partition.insert(
183 serialized_key,
184 PartitionEntry {
185 num_key_columns,
186 value,
187 },
188 );
189 }
190
191 Ok(partition)
192 }
193
194 #[must_use]
196 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197 self.partitions[partition_idx].is_some()
198 }
199
200 #[must_use]
202 pub fn partition_size(&self, partition_idx: usize) -> usize {
203 self.partition_sizes[partition_idx]
204 }
205
206 #[must_use]
208 pub fn total_size(&self) -> usize {
209 self.partition_sizes.iter().sum()
210 }
211
212 #[must_use]
214 pub fn in_memory_count(&self) -> usize {
215 self.partitions.iter().filter(|p| p.is_some()).count()
216 }
217
218 #[must_use]
220 pub fn spilled_count(&self) -> usize {
221 self.spill_files.iter().filter(|f| f.is_some()).count()
222 }
223
224 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230 let partition = match self.partitions[partition_idx].take() {
232 Some(p) => p,
233 None => return Ok(0), };
235
236 if partition.is_empty() {
237 return Ok(0);
238 }
239
240 let mut spill_file = self.manager.create_file("partition")?;
242
243 let mut buf = Vec::new();
245 write_u64(&mut buf, partition.len() as u64)?;
246
247 for (key, entry) in &partition {
248 write_u64(&mut buf, key.0.len() as u64)?;
250 buf.extend_from_slice(&key.0);
251
252 write_u64(&mut buf, entry.num_key_columns as u64)?;
254
255 (self.value_serializer)(&entry.value, &mut buf)?;
257 }
258
259 spill_file.write_all(&buf)?;
260 spill_file.finish_write()?;
261
262 let bytes_written = spill_file.bytes_written();
263 self.manager.register_spilled_bytes(bytes_written);
264 self.partition_sizes[partition_idx] = partition.len();
265 self.spill_files[partition_idx] = Some(spill_file);
266
267 Ok(bytes_written as usize)
268 }
269
270 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
278 let largest_idx = self
280 .partitions
281 .iter()
282 .enumerate()
283 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
284 .max_by_key(|(_, size)| *size)
285 .map(|(idx, _)| idx);
286
287 match largest_idx {
288 Some(idx) => self.spill_partition(idx),
289 None => Ok(0),
290 }
291 }
292
293 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
301 let lru_idx = self
303 .partitions
304 .iter()
305 .enumerate()
306 .filter(|(_, p)| p.is_some())
307 .min_by_key(|(idx, _)| self.access_times[*idx])
308 .map(|(idx, _)| idx);
309
310 match lru_idx {
311 Some(idx) => self.spill_partition(idx),
312 None => Ok(0),
313 }
314 }
315
316 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
322 let partition_idx = self.partition_for(&key);
323 let num_key_columns = key.len();
324 let serialized_key = SerializedKey::from_values(&key);
325 let partition = self.get_partition_mut(partition_idx)?;
326
327 let old = partition.insert(
328 serialized_key,
329 PartitionEntry {
330 num_key_columns,
331 value,
332 },
333 );
334
335 if old.is_none() {
336 self.partition_sizes[partition_idx] += 1;
337 }
338
339 Ok(old.map(|e| e.value))
340 }
341
342 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
348 let partition_idx = self.partition_for(key);
349 let serialized_key = SerializedKey::from_values(key);
350 let partition = self.get_partition_mut(partition_idx)?;
351 Ok(partition.get(&serialized_key).map(|e| &e.value))
352 }
353
354 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
360 where
361 F: FnOnce() -> V,
362 {
363 let partition_idx = self.partition_for(&key);
364 let num_key_columns = key.len();
365 let serialized_key = SerializedKey::from_values(&key);
366
367 let was_new;
368 {
369 let partition = self.get_partition_mut(partition_idx)?;
370 was_new = !partition.contains_key(&serialized_key);
371 if was_new {
372 partition.insert(
373 serialized_key.clone(),
374 PartitionEntry {
375 num_key_columns,
376 value: default(),
377 },
378 );
379 }
380 }
381 if was_new {
382 self.partition_sizes[partition_idx] += 1;
383 }
384
385 let partition = self.get_partition_mut(partition_idx)?;
386 Ok(&mut partition
388 .get_mut(&serialized_key)
389 .expect("key exists: just inserted or already present in partition")
390 .value)
391 }
392
393 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
401 let mut result = Vec::with_capacity(self.total_size());
402
403 for partition_idx in 0..self.num_partitions {
404 let partition = self.get_partition_mut(partition_idx)?;
405 for (serialized_key, entry) in partition.drain() {
406 let key = serialized_key.to_values(entry.num_key_columns)?;
407 result.push((key, entry.value));
408 }
409 self.partition_sizes[partition_idx] = 0;
410 }
411
412 for spill_file in self.spill_files.iter_mut() {
414 if let Some(file) = spill_file.take() {
415 let bytes = file.bytes_written();
416 let _ = file.delete();
417 self.manager.unregister_spilled_bytes(bytes);
418 }
419 }
420
421 Ok(result)
422 }
423
424 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
432 let mut result = Vec::with_capacity(self.total_size());
433
434 for partition_idx in 0..self.num_partitions {
435 let partition = self.get_partition_mut(partition_idx)?;
436 for (serialized_key, entry) in partition.iter() {
437 let key = serialized_key.to_values(entry.num_key_columns)?;
438 result.push((key, entry.value.clone()));
439 }
440 }
441
442 Ok(result)
443 }
444
445 pub fn cleanup(&mut self) {
447 for file in self.spill_files.iter_mut().flatten() {
448 let bytes = file.bytes_written();
449 self.manager.unregister_spilled_bytes(bytes);
450 }
451
452 self.spill_files.clear();
453 self.partitions.clear();
454 for _ in 0..self.num_partitions {
455 self.spill_files.push(None);
456 self.partitions.push(Some(HashMap::new()));
457 }
458 self.partition_sizes = vec![0; self.num_partitions];
459 }
460}
461
462impl<V> Drop for PartitionedState<V> {
463 fn drop(&mut self) {
464 for file in self.spill_files.iter().flatten() {
466 let bytes = file.bytes_written();
467 self.manager.unregister_spilled_bytes(bytes);
468 }
469 }
470}
471
/// Hashes a key's values to choose its partition.
///
/// Each value feeds a one-byte type tag followed by its payload, so values of different types
/// with equal payloads hash differently. `DefaultHasher::new()` is used instead of a randomly
/// seeded hasher, so a given key maps to the same partition every time it is looked up.
///
/// Key equality inside a partition is decided by the serialized key bytes, not by this hash.
/// A collision only places two keys in the same partition.
///
/// NOTE(review): `List` and `Map` contribute only their length, and `Vector` only its length
/// plus its first four components. Keys that differ only inside such values therefore share
/// a partition: correct, but it can skew partition sizes.
fn hash_key(key: &[Value]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();

    for value in key {
        match value {
            Value::Null => 0u8.hash(&mut hasher),
            Value::Bool(b) => {
                1u8.hash(&mut hasher);
                b.hash(&mut hasher);
            }
            Value::Int64(n) => {
                2u8.hash(&mut hasher);
                n.hash(&mut hasher);
            }
            Value::Float64(f) => {
                3u8.hash(&mut hasher);
                // f64 is not `Hash`; hash the raw bit pattern instead.
                f.to_bits().hash(&mut hasher);
            }
            Value::String(s) => {
                4u8.hash(&mut hasher);
                s.hash(&mut hasher);
            }
            Value::Bytes(b) => {
                5u8.hash(&mut hasher);
                b.hash(&mut hasher);
            }
            Value::Timestamp(t) => {
                6u8.hash(&mut hasher);
                t.hash(&mut hasher);
            }
            Value::List(l) => {
                7u8.hash(&mut hasher);
                // Only the length is hashed; element contents are ignored.
                l.len().hash(&mut hasher);
            }
            Value::Map(m) => {
                8u8.hash(&mut hasher);
                // Only the length is hashed; entries are ignored.
                m.len().hash(&mut hasher);
            }
            Value::Vector(v) => {
                9u8.hash(&mut hasher);
                v.len().hash(&mut hasher);
                // Sample at most the first four components to bound the cost for long vectors.
                for &f in v.iter().take(4) {
                    f.to_bits().hash(&mut hasher);
                }
            }
        }
    }

    hasher.finish()
}
525
/// Reads one little-endian `u64` from `reader`.
///
/// # Errors
///
/// Propagates the error from `read_exact`, e.g. `UnexpectedEof` when fewer than 8 bytes remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut raw)
        .map(|()| u64::from_le_bytes(raw))
}
532
/// Writes `value` to `writer` as 8 little-endian bytes.
///
/// # Errors
///
/// Propagates the error from `write_all`.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
537
538struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
540
541impl<'a> Read for SpillReaderAdapter<'a> {
542 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
543 self.0.read_exact(buf)?;
544 Ok(buf.len())
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a spill manager rooted in a fresh temp dir; keep the `TempDir` alive for the test.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    // `PartitionedState::new` requires `Fn(&V, &mut dyn Write)`, so the value must be taken
    // by reference even though `i64` is `Copy`.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    /// Builds a key made of `Int64` columns.
    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        let k1 = key(&[1, 2, 3]);
        // The same key always maps to the same partition...
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        // ...and the index is always in range.
        assert!(p1 < 16);
    }

    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // First call inserts the default.
        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Second call returns the existing value and ignores the default.
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        // The returned reference allows in-place updates.
        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
        assert_eq!(state.total_size(), 1);
    }

    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);
        // Spilling moves entries to disk but does not change the logical size.
        assert_eq!(state.total_size(), initial_total);

        // Every value survives the round trip through the spill file.
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
        // Reading every key reloaded the spilled partition and consumed its file.
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Key 3's partition becomes the most recently used.
        state.get(&key(&[3])).unwrap();

        state.spill_lru().unwrap();

        // The most recently used partition must not be the one evicted.
        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));

        // Whatever was spilled is still reachable.
        for (k, v) in [(1, 10), (2, 20), (3, 30)] {
            assert_eq!(state.get(&key(&[k])).unwrap(), Some(&v));
        }
    }

    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        // Drain must include entries from both resident and spilled partitions.
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        // Keys round-trip through serialization and stay paired with their values.
        assert!(entries
            .iter()
            .all(|(k, v)| matches!(k.as_slice(), [Value::Int64(n)] if *n * 10 == *v)));

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40]);

        // Iterating is non-destructive and repeatable.
        assert_eq!(state.total_size(), 5);

        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
        assert_eq!(state.in_memory_count(), 4);
        // All spilled bytes are returned to the manager's accounting.
        assert_eq!(manager.spilled_bytes(), 0);

        // The state is reusable after cleanup.
        state.insert(key(&[1]), 1).unwrap();
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&1));
    }

    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        // Overwriting an existing key must not change the entry count.
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}