1use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use graphos_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Default number of partitions (256).
///
/// NOTE(review): not referenced anywhere else in this file; presumably meant to
/// be passed as `num_partitions` to `PartitionedState::new` by callers — confirm.
pub const DEFAULT_NUM_PARTITIONS: usize = 256;
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30 fn from_values(values: &[Value]) -> Self {
31 let mut buf = Vec::new();
32 serialize_row(values, &mut buf).expect("serialization should not fail");
33 Self(buf)
34 }
35
36 fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37 deserialize_row(&mut self.0.as_slice(), num_columns)
38 }
39}
40
/// Payload stored in a partition map for a single key.
struct PartitionEntry<V> {
    /// Number of columns in the original key; handed to
    /// `SerializedKey::to_values` when the key is decoded again.
    num_key_columns: usize,
    /// The caller-supplied value associated with the key.
    value: V,
}
46
/// Key/value state split across a fixed number of hash partitions, each of
/// which is either resident in memory or written out to a spill file.
pub struct PartitionedState<V> {
    /// Creates spill files and is told how many bytes are currently spilled.
    manager: Arc<SpillManager>,
    /// Number of partitions; the divisor used by `partition_for`.
    num_partitions: usize,
    /// Per-partition map: `Some` while the partition is resident in memory,
    /// `None` after it has been taken out by `spill_partition`.
    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
    /// Per-partition spill file: `Some` only while that partition's entries
    /// live on disk.
    spill_files: Vec<Option<SpillFile>>,
    /// Entry count per partition; kept up to date on insert and drain, and
    /// retained while the partition is spilled.
    partition_sizes: Vec<usize>,
    /// Value of `timestamp` at each partition's most recent access
    /// (0 means the partition has never been accessed).
    access_times: Vec<u64>,
    /// Logical clock, incremented on every partition access.
    timestamp: u64,
    /// Writes one value into the spill buffer.
    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
    /// Reads one value back from a spill file.
    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73 pub fn new<S, D>(
75 manager: Arc<SpillManager>,
76 num_partitions: usize,
77 value_serializer: S,
78 value_deserializer: D,
79 ) -> Self
80 where
81 S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82 D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83 {
84 let mut partitions = Vec::with_capacity(num_partitions);
85 let mut spill_files = Vec::with_capacity(num_partitions);
86 for _ in 0..num_partitions {
87 partitions.push(Some(HashMap::new()));
88 spill_files.push(None);
89 }
90
91 let partition_sizes = vec![0; num_partitions];
92 let access_times = vec![0; num_partitions];
93
94 Self {
95 manager,
96 num_partitions,
97 partitions,
98 spill_files,
99 partition_sizes,
100 access_times,
101 timestamp: 0,
102 value_serializer: Box::new(value_serializer),
103 value_deserializer: Box::new(value_deserializer),
104 }
105 }
106
107 #[must_use]
109 pub fn partition_for(&self, key: &[Value]) -> usize {
110 let hash = hash_key(key);
111 hash as usize % self.num_partitions
112 }
113
114 fn touch(&mut self, partition_idx: usize) {
116 self.timestamp += 1;
117 self.access_times[partition_idx] = self.timestamp;
118 }
119
120 fn get_partition_mut(
126 &mut self,
127 partition_idx: usize,
128 ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129 self.touch(partition_idx);
130
131 if self.partitions[partition_idx].is_some() {
133 return Ok(self.partitions[partition_idx]
135 .as_mut()
136 .expect("partition is Some: checked on previous line"));
137 }
138
139 if let Some(spill_file) = self.spill_files[partition_idx].take() {
141 let loaded = self.load_partition(&spill_file)?;
142 let bytes = spill_file.bytes_written();
144 let _ = spill_file.delete();
145 self.manager.unregister_spilled_bytes(bytes);
146 self.partitions[partition_idx] = Some(loaded);
147 } else {
148 self.partitions[partition_idx] = Some(HashMap::new());
150 }
151
152 Ok(self.partitions[partition_idx]
154 .as_mut()
155 .expect("partition is Some: set to Some in if/else branches above"))
156 }
157
158 fn load_partition(
160 &self,
161 spill_file: &SpillFile,
162 ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163 let mut reader = spill_file.reader()?;
164 let mut adapter = SpillReaderAdapter(&mut reader);
165
166 let num_entries = read_u64(&mut adapter)? as usize;
167 let mut partition = HashMap::with_capacity(num_entries);
168
169 for _ in 0..num_entries {
170 let key_len = read_u64(&mut adapter)? as usize;
172 let mut key_buf = vec![0u8; key_len];
173 adapter.read_exact(&mut key_buf)?;
174 let serialized_key = SerializedKey(key_buf);
175
176 let num_key_columns = read_u64(&mut adapter)? as usize;
178
179 let value = (self.value_deserializer)(&mut adapter)?;
181
182 partition.insert(
183 serialized_key,
184 PartitionEntry {
185 num_key_columns,
186 value,
187 },
188 );
189 }
190
191 Ok(partition)
192 }
193
194 #[must_use]
196 pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197 self.partitions[partition_idx].is_some()
198 }
199
200 #[must_use]
202 pub fn partition_size(&self, partition_idx: usize) -> usize {
203 self.partition_sizes[partition_idx]
204 }
205
206 #[must_use]
208 pub fn total_size(&self) -> usize {
209 self.partition_sizes.iter().sum()
210 }
211
212 #[must_use]
214 pub fn in_memory_count(&self) -> usize {
215 self.partitions.iter().filter(|p| p.is_some()).count()
216 }
217
218 #[must_use]
220 pub fn spilled_count(&self) -> usize {
221 self.spill_files.iter().filter(|f| f.is_some()).count()
222 }
223
224 pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230 let partition = match self.partitions[partition_idx].take() {
232 Some(p) => p,
233 None => return Ok(0), };
235
236 if partition.is_empty() {
237 return Ok(0);
238 }
239
240 let mut spill_file = self.manager.create_file("partition")?;
242
243 let mut buf = Vec::new();
245 write_u64(&mut buf, partition.len() as u64)?;
246
247 for (key, entry) in &partition {
248 write_u64(&mut buf, key.0.len() as u64)?;
250 buf.extend_from_slice(&key.0);
251
252 write_u64(&mut buf, entry.num_key_columns as u64)?;
254
255 (self.value_serializer)(&entry.value, &mut buf)?;
257 }
258
259 spill_file.write_all(&buf)?;
260 spill_file.finish_write()?;
261
262 let bytes_written = spill_file.bytes_written();
263 self.manager.register_spilled_bytes(bytes_written);
264 self.partition_sizes[partition_idx] = partition.len();
265 self.spill_files[partition_idx] = Some(spill_file);
266
267 Ok(bytes_written as usize)
268 }
269
270 pub fn spill_largest(&mut self) -> std::io::Result<usize> {
278 let largest_idx = self
280 .partitions
281 .iter()
282 .enumerate()
283 .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
284 .max_by_key(|(_, size)| *size)
285 .map(|(idx, _)| idx);
286
287 match largest_idx {
288 Some(idx) => self.spill_partition(idx),
289 None => Ok(0),
290 }
291 }
292
293 pub fn spill_lru(&mut self) -> std::io::Result<usize> {
301 let lru_idx = self
303 .partitions
304 .iter()
305 .enumerate()
306 .filter(|(_, p)| p.is_some())
307 .min_by_key(|(idx, _)| self.access_times[*idx])
308 .map(|(idx, _)| idx);
309
310 match lru_idx {
311 Some(idx) => self.spill_partition(idx),
312 None => Ok(0),
313 }
314 }
315
316 pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
322 let partition_idx = self.partition_for(&key);
323 let num_key_columns = key.len();
324 let serialized_key = SerializedKey::from_values(&key);
325 let partition = self.get_partition_mut(partition_idx)?;
326
327 let old = partition.insert(
328 serialized_key,
329 PartitionEntry {
330 num_key_columns,
331 value,
332 },
333 );
334
335 if old.is_none() {
336 self.partition_sizes[partition_idx] += 1;
337 }
338
339 Ok(old.map(|e| e.value))
340 }
341
342 pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
348 let partition_idx = self.partition_for(key);
349 let serialized_key = SerializedKey::from_values(key);
350 let partition = self.get_partition_mut(partition_idx)?;
351 Ok(partition.get(&serialized_key).map(|e| &e.value))
352 }
353
354 pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
360 where
361 F: FnOnce() -> V,
362 {
363 let partition_idx = self.partition_for(&key);
364 let num_key_columns = key.len();
365 let serialized_key = SerializedKey::from_values(&key);
366
367 let was_new;
368 {
369 let partition = self.get_partition_mut(partition_idx)?;
370 was_new = !partition.contains_key(&serialized_key);
371 if was_new {
372 partition.insert(
373 serialized_key.clone(),
374 PartitionEntry {
375 num_key_columns,
376 value: default(),
377 },
378 );
379 }
380 }
381 if was_new {
382 self.partition_sizes[partition_idx] += 1;
383 }
384
385 let partition = self.get_partition_mut(partition_idx)?;
386 Ok(&mut partition.get_mut(&serialized_key).unwrap().value)
387 }
388
389 pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
397 let mut result = Vec::with_capacity(self.total_size());
398
399 for partition_idx in 0..self.num_partitions {
400 let partition = self.get_partition_mut(partition_idx)?;
401 for (serialized_key, entry) in partition.drain() {
402 let key = serialized_key.to_values(entry.num_key_columns)?;
403 result.push((key, entry.value));
404 }
405 self.partition_sizes[partition_idx] = 0;
406 }
407
408 for spill_file in self.spill_files.iter_mut() {
410 if let Some(file) = spill_file.take() {
411 let bytes = file.bytes_written();
412 let _ = file.delete();
413 self.manager.unregister_spilled_bytes(bytes);
414 }
415 }
416
417 Ok(result)
418 }
419
420 pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
428 let mut result = Vec::with_capacity(self.total_size());
429
430 for partition_idx in 0..self.num_partitions {
431 let partition = self.get_partition_mut(partition_idx)?;
432 for (serialized_key, entry) in partition.iter() {
433 let key = serialized_key.to_values(entry.num_key_columns)?;
434 result.push((key, entry.value.clone()));
435 }
436 }
437
438 Ok(result)
439 }
440
441 pub fn cleanup(&mut self) {
443 for file in self.spill_files.iter_mut().flatten() {
444 let bytes = file.bytes_written();
445 self.manager.unregister_spilled_bytes(bytes);
446 }
447
448 self.spill_files.clear();
449 self.partitions.clear();
450 for _ in 0..self.num_partitions {
451 self.spill_files.push(None);
452 self.partitions.push(Some(HashMap::new()));
453 }
454 self.partition_sizes = vec![0; self.num_partitions];
455 }
456}
457
458impl<V> Drop for PartitionedState<V> {
459 fn drop(&mut self) {
460 for file in self.spill_files.iter().flatten() {
462 let bytes = file.bytes_written();
463 self.manager.unregister_spilled_bytes(bytes);
464 }
465 }
466}
467
468fn hash_key(key: &[Value]) -> u64 {
470 use std::hash::{Hash, Hasher};
471 let mut hasher = std::collections::hash_map::DefaultHasher::new();
472
473 for value in key {
474 match value {
475 Value::Null => 0u8.hash(&mut hasher),
476 Value::Bool(b) => {
477 1u8.hash(&mut hasher);
478 b.hash(&mut hasher);
479 }
480 Value::Int64(n) => {
481 2u8.hash(&mut hasher);
482 n.hash(&mut hasher);
483 }
484 Value::Float64(f) => {
485 3u8.hash(&mut hasher);
486 f.to_bits().hash(&mut hasher);
487 }
488 Value::String(s) => {
489 4u8.hash(&mut hasher);
490 s.hash(&mut hasher);
491 }
492 Value::Bytes(b) => {
493 5u8.hash(&mut hasher);
494 b.hash(&mut hasher);
495 }
496 Value::Timestamp(t) => {
497 6u8.hash(&mut hasher);
498 t.hash(&mut hasher);
499 }
500 Value::List(l) => {
501 7u8.hash(&mut hasher);
502 l.len().hash(&mut hasher);
503 }
504 Value::Map(m) => {
505 8u8.hash(&mut hasher);
506 m.len().hash(&mut hasher);
507 }
508 }
509 }
510
511 hasher.finish()
512}
513
/// Reads exactly eight bytes from `reader` and decodes them as a
/// little-endian `u64`.
///
/// # Errors
///
/// Returns the error from `read_exact`, e.g. when fewer than eight bytes are
/// available.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    match reader.read_exact(&mut raw) {
        Ok(()) => Ok(u64::from_le_bytes(raw)),
        Err(err) => Err(err),
    }
}
520
/// Writes `value` to `writer` as eight little-endian bytes.
///
/// # Errors
///
/// Returns the error from `write_all`.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
525
526struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
528
529impl<'a> Read for SpillReaderAdapter<'a> {
530 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
531 self.0.read_exact(buf)?;
532 Ok(buf.len())
533 }
534}
535
// Unit tests for `PartitionedState`, using `i64` values and `Value::Int64` keys
// backed by a `SpillManager` rooted in a temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `SpillManager` on a fresh temporary directory. The `TempDir`
    /// is returned as well so callers can keep it alive for the whole test.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    /// Value serializer for the tests: eight little-endian bytes.
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    /// Value deserializer for the tests: inverse of `serialize_i64`.
    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    /// Builds a key whose columns are the given integers.
    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    /// The partition index is stable for a given key and within range.
    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // The same key must always map to the same partition.
        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        // The index must be smaller than the partition count.
        assert!(p1 < 16);
    }

    /// Inserted values can be read back; missing keys yield `None`.
    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    /// `get_or_insert_with` inserts once, then returns the existing value,
    /// and the returned reference can be used to mutate it.
    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // First call inserts the default.
        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Second call must ignore the new default.
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        // Mutation through the returned reference is visible to `get`.
        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
    }

    /// Values survive a spill and are transparently reloaded on access.
    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        // Spill the largest partition to disk.
        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);

        // Every value must still be reachable, including the spilled ones.
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
    }

    /// After an LRU spill the most recently accessed partition is expected to
    /// remain in memory.
    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Touch key 3 so its partition is the most recently used.
        state.get(&key(&[3])).unwrap();

        state.spill_lru().unwrap();

        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));
    }

    /// `drain_all` returns every entry, spilled or not, and empties the state.
    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        // Push some partitions to disk before draining.
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        // Order is unspecified, so compare the sorted values.
        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        assert_eq!(state.total_size(), 0);
    }

    /// `iter_all` returns every entry without removing anything.
    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        // The state must be unchanged after iterating.
        assert_eq!(state.total_size(), 5);

        // A second pass must see the same number of entries.
        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    /// A larger workload stays fully readable after spilling half of the
    /// partitions.
    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        // Reading every key reloads spilled partitions as needed.
        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    /// `cleanup` leaves the state empty with no spill files.
    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        // Spilling must have been registered with the manager.
        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    /// Keys made of several columns of different types are kept distinct.
    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    /// Re-inserting an existing key replaces the value, returns the old one
    /// and does not change the entry count.
    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        // Size must be unchanged and the new value visible.
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}