Skip to main content

grafeo_core/execution/spill/
partition.rs

1//! Hash partitioning for spillable aggregation.
2//!
3//! This module implements hash partitioning that allows aggregate state
4//! to be partitioned and spilled to disk when memory pressure is high.
5//!
6//! # Design
7//!
8//! - Groups are assigned to partitions based on their key's hash
9//! - In-memory partitions can be spilled to disk under memory pressure
10//! - Cold (least recently accessed) partitions are spilled first
11//! - When iterating results, spilled partitions are reloaded
12
13use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Partition count used when the caller has no better choice (2^8 = 256).
pub const DEFAULT_NUM_PARTITIONS: usize = 1 << 8;
23
24/// A serialized key for use as a HashMap key.
25/// We serialize Value vectors to bytes since Value doesn't implement Hash/Eq.
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30    fn from_values(values: &[Value]) -> Self {
31        let mut buf = Vec::new();
32        serialize_row(values, &mut buf).expect("serialization should not fail");
33        Self(buf)
34    }
35
36    fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37        deserialize_row(&mut self.0.as_slice(), num_columns)
38    }
39}
40
/// One group's stored state inside a partition.
///
/// The key itself is the map key (a [`SerializedKey`]). This entry records
/// how many columns that key had, so it can be decoded again.
struct PartitionEntry<V> {
    /// Aggregate value accumulated for the group.
    value: V,
    /// Column count of the original key, passed to `SerializedKey::to_values`.
    num_key_columns: usize,
}
46
47/// Partitioned accumulator state for spillable aggregation.
48///
49/// Manages aggregate state across multiple partitions, with the ability
50/// to spill cold partitions to disk under memory pressure.
51pub struct PartitionedState<V> {
52    /// Spill manager for file creation.
53    manager: Arc<SpillManager>,
54    /// Number of partitions.
55    num_partitions: usize,
56    /// In-memory partitions (None = spilled to disk).
57    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
58    /// Spill files for spilled partitions.
59    spill_files: Vec<Option<SpillFile>>,
60    /// Number of groups per partition (for spilled partitions too).
61    partition_sizes: Vec<usize>,
62    /// Access timestamps for LRU eviction.
63    access_times: Vec<u64>,
64    /// Global timestamp counter.
65    timestamp: u64,
66    /// Serializer for V values.
67    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
68    /// Deserializer for V values.
69    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
70}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73    /// Creates a new partitioned state with custom serialization.
74    pub fn new<S, D>(
75        manager: Arc<SpillManager>,
76        num_partitions: usize,
77        value_serializer: S,
78        value_deserializer: D,
79    ) -> Self
80    where
81        S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82        D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83    {
84        let mut partitions = Vec::with_capacity(num_partitions);
85        let mut spill_files = Vec::with_capacity(num_partitions);
86        for _ in 0..num_partitions {
87            partitions.push(Some(HashMap::new()));
88            spill_files.push(None);
89        }
90
91        let partition_sizes = vec![0; num_partitions];
92        let access_times = vec![0; num_partitions];
93
94        Self {
95            manager,
96            num_partitions,
97            partitions,
98            spill_files,
99            partition_sizes,
100            access_times,
101            timestamp: 0,
102            value_serializer: Box::new(value_serializer),
103            value_deserializer: Box::new(value_deserializer),
104        }
105    }
106
107    /// Returns the partition index for a key.
108    #[must_use]
109    pub fn partition_for(&self, key: &[Value]) -> usize {
110        let hash = hash_key(key);
111        // reason: on 64-bit targets u64 == usize; on 32-bit the modulo handles overflow
112        #[allow(clippy::cast_possible_truncation)]
113        {
114            hash as usize % self.num_partitions
115        }
116    }
117
118    /// Updates access time for a partition.
119    fn touch(&mut self, partition_idx: usize) {
120        self.timestamp += 1;
121        self.access_times[partition_idx] = self.timestamp;
122    }
123
124    /// Gets the in-memory partition, loading from disk if spilled.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if reading from disk fails.
129    fn get_partition_mut(
130        &mut self,
131        partition_idx: usize,
132    ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
133        self.touch(partition_idx);
134
135        // If partition is in memory, return it
136        if self.partitions[partition_idx].is_some() {
137            // Invariant: just checked is_some() above
138            return Ok(self.partitions[partition_idx]
139                .as_mut()
140                .expect("partition is Some: checked on previous line"));
141        }
142
143        // Load from disk
144        if let Some(spill_file) = self.spill_files[partition_idx].take() {
145            let loaded = self.load_partition(&spill_file)?;
146            // Delete the spill file after loading
147            let bytes = spill_file.bytes_written();
148            let _ = spill_file.delete();
149            self.manager.unregister_spilled_bytes(bytes);
150            self.partitions[partition_idx] = Some(loaded);
151        } else {
152            // Neither in memory nor on disk - create empty partition
153            self.partitions[partition_idx] = Some(HashMap::new());
154        }
155
156        // Invariant: partition was either loaded from disk or created empty above
157        Ok(self.partitions[partition_idx]
158            .as_mut()
159            .expect("partition is Some: set to Some in if/else branches above"))
160    }
161
162    /// Loads a partition from a spill file.
163    fn load_partition(
164        &self,
165        spill_file: &SpillFile,
166    ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
167        let mut reader = spill_file.reader()?;
168        let mut adapter = SpillReaderAdapter(&mut reader);
169
170        // reason: deserialized counts are bounded by available data
171        #[allow(clippy::cast_possible_truncation)]
172        let num_entries = read_u64(&mut adapter)? as usize;
173        let mut partition = HashMap::with_capacity(num_entries);
174
175        for _ in 0..num_entries {
176            // Read key
177            // reason: deserialized lengths are bounded by available data
178            #[allow(clippy::cast_possible_truncation)]
179            let key_len = read_u64(&mut adapter)? as usize;
180            let mut key_buf = vec![0u8; key_len];
181            adapter.read_exact(&mut key_buf)?;
182            let serialized_key = SerializedKey(key_buf);
183
184            // Read number of key columns
185            // reason: deserialized counts are bounded by available data
186            #[allow(clippy::cast_possible_truncation)]
187            let num_key_columns = read_u64(&mut adapter)? as usize;
188
189            // Read value
190            let value = (self.value_deserializer)(&mut adapter)?;
191
192            partition.insert(
193                serialized_key,
194                PartitionEntry {
195                    num_key_columns,
196                    value,
197                },
198            );
199        }
200
201        Ok(partition)
202    }
203
204    /// Returns whether a partition is in memory.
205    #[must_use]
206    pub fn is_in_memory(&self, partition_idx: usize) -> bool {
207        self.partitions[partition_idx].is_some()
208    }
209
210    /// Returns the number of groups in a partition.
211    #[must_use]
212    pub fn partition_size(&self, partition_idx: usize) -> usize {
213        self.partition_sizes[partition_idx]
214    }
215
216    /// Returns the total number of groups across all partitions.
217    #[must_use]
218    pub fn total_size(&self) -> usize {
219        self.partition_sizes.iter().sum()
220    }
221
222    /// Returns the number of in-memory partitions.
223    #[must_use]
224    pub fn in_memory_count(&self) -> usize {
225        self.partitions.iter().filter(|p| p.is_some()).count()
226    }
227
228    /// Returns the number of spilled partitions.
229    #[must_use]
230    pub fn spilled_count(&self) -> usize {
231        self.spill_files.iter().filter(|f| f.is_some()).count()
232    }
233
234    /// Spills a specific partition to disk.
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if writing to disk fails.
239    pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
240        // Get partition data
241        let Some(partition) = self.partitions[partition_idx].take() else {
242            return Ok(0); // Already spilled
243        };
244
245        if partition.is_empty() {
246            return Ok(0);
247        }
248
249        // Create spill file
250        let mut spill_file = self.manager.create_file("partition")?;
251
252        // Write partition data
253        let mut buf = Vec::new();
254        write_u64(&mut buf, partition.len() as u64)?;
255
256        for (key, entry) in &partition {
257            // Write key bytes
258            write_u64(&mut buf, key.0.len() as u64)?;
259            buf.extend_from_slice(&key.0);
260
261            // Write number of key columns
262            write_u64(&mut buf, entry.num_key_columns as u64)?;
263
264            // Write value
265            (self.value_serializer)(&entry.value, &mut buf)?;
266        }
267
268        spill_file.write_all(&buf)?;
269        spill_file.finish_write()?;
270
271        let bytes_written = spill_file.bytes_written();
272        self.manager.register_spilled_bytes(bytes_written);
273        self.partition_sizes[partition_idx] = partition.len();
274        self.spill_files[partition_idx] = Some(spill_file);
275
276        // reason: bytes written is bounded by available disk space, fits usize
277        #[allow(clippy::cast_possible_truncation)]
278        Ok(bytes_written as usize)
279    }
280
281    /// Spills the largest in-memory partition.
282    ///
283    /// Returns the number of bytes spilled, or 0 if no partition to spill.
284    ///
285    /// # Errors
286    ///
287    /// Returns an error if writing to disk fails.
288    pub fn spill_largest(&mut self) -> std::io::Result<usize> {
289        // Find largest in-memory partition
290        let largest_idx = self
291            .partitions
292            .iter()
293            .enumerate()
294            .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
295            .max_by_key(|(_, size)| *size)
296            .map(|(idx, _)| idx);
297
298        match largest_idx {
299            Some(idx) => self.spill_partition(idx),
300            None => Ok(0),
301        }
302    }
303
304    /// Spills the least recently used in-memory partition.
305    ///
306    /// Returns the number of bytes spilled, or 0 if no partition to spill.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if writing to disk fails.
311    pub fn spill_lru(&mut self) -> std::io::Result<usize> {
312        // Find LRU in-memory partition
313        let lru_idx = self
314            .partitions
315            .iter()
316            .enumerate()
317            .filter(|(_, p)| p.is_some())
318            .min_by_key(|(idx, _)| self.access_times[*idx])
319            .map(|(idx, _)| idx);
320
321        match lru_idx {
322            Some(idx) => self.spill_partition(idx),
323            None => Ok(0),
324        }
325    }
326
327    /// Inserts or updates a value for a key.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if loading from disk fails.
332    pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
333        let partition_idx = self.partition_for(&key);
334        let num_key_columns = key.len();
335        let serialized_key = SerializedKey::from_values(&key);
336        let partition = self.get_partition_mut(partition_idx)?;
337
338        let old = partition.insert(
339            serialized_key,
340            PartitionEntry {
341                num_key_columns,
342                value,
343            },
344        );
345
346        if old.is_none() {
347            self.partition_sizes[partition_idx] += 1;
348        }
349
350        Ok(old.map(|e| e.value))
351    }
352
353    /// Gets a value for a key.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if loading from disk fails.
358    pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
359        let partition_idx = self.partition_for(key);
360        let serialized_key = SerializedKey::from_values(key);
361        let partition = self.get_partition_mut(partition_idx)?;
362        Ok(partition.get(&serialized_key).map(|e| &e.value))
363    }
364
365    /// Gets a mutable value for a key, or inserts a default.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if loading from disk fails.
370    ///
371    /// # Panics
372    ///
373    /// Panics if the key was just inserted but cannot be found (invariant violation).
374    pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
375    where
376        F: FnOnce() -> V,
377    {
378        let partition_idx = self.partition_for(&key);
379        let num_key_columns = key.len();
380        let serialized_key = SerializedKey::from_values(&key);
381
382        let was_new;
383        {
384            let partition = self.get_partition_mut(partition_idx)?;
385            was_new = !partition.contains_key(&serialized_key);
386            if was_new {
387                partition.insert(
388                    serialized_key.clone(),
389                    PartitionEntry {
390                        num_key_columns,
391                        value: default(),
392                    },
393                );
394            }
395        }
396        if was_new {
397            self.partition_sizes[partition_idx] += 1;
398        }
399
400        let partition = self.get_partition_mut(partition_idx)?;
401        // Invariant: key was either already present or inserted in the block above
402        Ok(&mut partition
403            .get_mut(&serialized_key)
404            .expect("key exists: just inserted or already present in partition")
405            .value)
406    }
407
408    /// Drains all entries from all partitions.
409    ///
410    /// Loads spilled partitions as needed.
411    ///
412    /// # Errors
413    ///
414    /// Returns an error if loading from disk fails.
415    pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
416        let mut result = Vec::with_capacity(self.total_size());
417
418        for partition_idx in 0..self.num_partitions {
419            let partition = self.get_partition_mut(partition_idx)?;
420            for (serialized_key, entry) in partition.drain() {
421                let key = serialized_key.to_values(entry.num_key_columns)?;
422                result.push((key, entry.value));
423            }
424            self.partition_sizes[partition_idx] = 0;
425        }
426
427        // Clean up any remaining spill files
428        for spill_file in &mut self.spill_files {
429            if let Some(file) = spill_file.take() {
430                let bytes = file.bytes_written();
431                let _ = file.delete();
432                self.manager.unregister_spilled_bytes(bytes);
433            }
434        }
435
436        Ok(result)
437    }
438
439    /// Iterates over all entries without draining.
440    ///
441    /// Loads spilled partitions as needed.
442    ///
443    /// # Errors
444    ///
445    /// Returns an error if loading from disk fails.
446    pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
447        let mut result = Vec::with_capacity(self.total_size());
448
449        for partition_idx in 0..self.num_partitions {
450            let partition = self.get_partition_mut(partition_idx)?;
451            for (serialized_key, entry) in partition.iter() {
452                let key = serialized_key.to_values(entry.num_key_columns)?;
453                result.push((key, entry.value.clone()));
454            }
455        }
456
457        Ok(result)
458    }
459
460    /// Cleans up all spill files.
461    pub fn cleanup(&mut self) {
462        for file in self.spill_files.iter_mut().flatten() {
463            let bytes = file.bytes_written();
464            self.manager.unregister_spilled_bytes(bytes);
465        }
466
467        self.spill_files.clear();
468        self.partitions.clear();
469        for _ in 0..self.num_partitions {
470            self.spill_files.push(None);
471            self.partitions.push(Some(HashMap::new()));
472        }
473        self.partition_sizes = vec![0; self.num_partitions];
474    }
475}
476
477impl<V> Drop for PartitionedState<V> {
478    fn drop(&mut self) {
479        // Unregister spilled bytes
480        for file in self.spill_files.iter().flatten() {
481            let bytes = file.bytes_written();
482            self.manager.unregister_spilled_bytes(bytes);
483        }
484    }
485}
486
487/// Hashes a key (vector of values) to a u64.
488fn hash_key(key: &[Value]) -> u64 {
489    use std::hash::{Hash, Hasher};
490    let mut hasher = std::collections::hash_map::DefaultHasher::new();
491
492    for value in key {
493        match value {
494            Value::Null => 0u8.hash(&mut hasher),
495            Value::Bool(b) => {
496                1u8.hash(&mut hasher);
497                b.hash(&mut hasher);
498            }
499            Value::Int64(n) => {
500                2u8.hash(&mut hasher);
501                n.hash(&mut hasher);
502            }
503            Value::Float64(f) => {
504                3u8.hash(&mut hasher);
505                f.to_bits().hash(&mut hasher);
506            }
507            Value::String(s) => {
508                4u8.hash(&mut hasher);
509                s.hash(&mut hasher);
510            }
511            Value::Bytes(b) => {
512                5u8.hash(&mut hasher);
513                b.hash(&mut hasher);
514            }
515            Value::Timestamp(t) => {
516                6u8.hash(&mut hasher);
517                t.hash(&mut hasher);
518            }
519            Value::Date(d) => {
520                10u8.hash(&mut hasher);
521                d.hash(&mut hasher);
522            }
523            Value::Time(t) => {
524                11u8.hash(&mut hasher);
525                t.hash(&mut hasher);
526            }
527            Value::Duration(d) => {
528                12u8.hash(&mut hasher);
529                d.hash(&mut hasher);
530            }
531            Value::ZonedDatetime(zdt) => {
532                14u8.hash(&mut hasher);
533                zdt.hash(&mut hasher);
534            }
535            Value::List(l) => {
536                7u8.hash(&mut hasher);
537                l.len().hash(&mut hasher);
538            }
539            Value::Map(m) => {
540                8u8.hash(&mut hasher);
541                m.len().hash(&mut hasher);
542            }
543            Value::Vector(v) => {
544                9u8.hash(&mut hasher);
545                v.len().hash(&mut hasher);
546                // Hash first few elements for distribution
547                for &f in v.iter().take(4) {
548                    f.to_bits().hash(&mut hasher);
549                }
550            }
551            Value::Path { nodes, edges } => {
552                13u8.hash(&mut hasher);
553                nodes.len().hash(&mut hasher);
554                edges.len().hash(&mut hasher);
555            }
556            Value::GCounter(counts) => {
557                15u8.hash(&mut hasher);
558                counts.len().hash(&mut hasher);
559            }
560            Value::OnCounter { pos, neg } => {
561                16u8.hash(&mut hasher);
562                pos.len().hash(&mut hasher);
563                neg.len().hash(&mut hasher);
564            }
565            _ => {
566                255u8.hash(&mut hasher);
567            }
568        }
569    }
570
571    hasher.finish()
572}
573
/// Reads exactly eight bytes and decodes them as a little-endian `u64`.
///
/// # Errors
///
/// Propagates the reader's error, including `UnexpectedEof` when fewer than
/// eight bytes remain.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut bytes = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut bytes)
        .map(|()| u64::from_le_bytes(bytes))
}
580
/// Encodes `value` as eight little-endian bytes and writes all of them.
///
/// # Errors
///
/// Propagates any error from the writer.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
585
586/// Adapter to read from SpillFileReader through std::io::Read.
587struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
588
589impl Read for SpillReaderAdapter<'_> {
590    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
591        self.0.read_exact(buf)?;
592        Ok(buf.len())
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a test manager. Returns (TempDir, manager). TempDir must be kept alive.
    fn create_manager() -> (TempDir, Arc<SpillManager>) {
        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
        (temp_dir, manager)
    }

    /// Simple i64 serializer for tests.
    #[allow(clippy::trivially_copy_pass_by_ref)] // Required by PartitionedState::new signature
    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&value.to_le_bytes())
    }

    /// Simple i64 deserializer for tests.
    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    fn key(values: &[i64]) -> Vec<Value> {
        values.iter().map(|&v| Value::Int64(v)).collect()
    }

    #[test]
    fn test_partition_for() {
        let (_temp_dir, manager) = create_manager();
        let state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // Same key should always go to same partition
        let k1 = key(&[1, 2, 3]);
        let p1 = state.partition_for(&k1);
        let p2 = state.partition_for(&k1);
        assert_eq!(p1, p2);

        // Partition should be in range
        assert!(p1 < 16);
    }

    #[test]
    fn test_insert_and_get() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        state.insert(key(&[2]), 200).unwrap();
        state.insert(key(&[3]), 300).unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
        assert_eq!(state.get(&key(&[4])).unwrap(), None);
    }

    #[test]
    fn test_get_or_insert_with() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        // First access creates the entry
        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
        assert_eq!(*v1, 42);

        // Second access returns existing value
        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
        assert_eq!(*v2, 42);

        // Mutate via returned reference
        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));

        // Repeated access to one key counts as a single group
        assert_eq!(state.total_size(), 1);
    }

    #[test]
    fn test_get_or_insert_with_after_spill() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        *state.get_or_insert_with(key(&[7]), || 1).unwrap() += 1;
        assert_eq!(state.total_size(), 1);

        // Only one partition holds data, so it is the largest and gets spilled.
        assert!(state.spill_largest().unwrap() > 0);
        assert_eq!(state.spilled_count(), 1);

        // Updating through a reload must keep the spilled value and not recount the group.
        *state.get_or_insert_with(key(&[7]), || 100).unwrap() += 1;
        assert_eq!(state.total_size(), 1);
        assert_eq!(state.get(&key(&[7])).unwrap(), Some(&3));
    }

    #[test]
    fn test_spill_and_reload() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let initial_total = state.total_size();
        assert!(initial_total > 0);

        let bytes_spilled = state.spill_largest().unwrap();
        assert!(bytes_spilled > 0);
        assert!(state.spilled_count() > 0);

        // Values should still be accessible (reloads from disk)
        for i in 0..20 {
            let expected = i * 10;
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
        }
        assert_eq!(state.total_size(), initial_total);
    }

    #[test]
    fn test_spill_empty_partition_is_noop() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        // Nothing has been inserted, so spilling writes no file.
        assert_eq!(state.spill_partition(0).unwrap(), 0);
        assert_eq!(state.spilled_count(), 0);

        // Partitions remain usable afterwards.
        state.insert(key(&[1]), 10).unwrap();
        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&10));
    }

    #[test]
    fn test_spill_lru() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 10).unwrap();
        state.insert(key(&[2]), 20).unwrap();
        state.insert(key(&[3]), 30).unwrap();

        // Access key 3 to make it recently used
        state.get(&key(&[3])).unwrap();

        // Spill LRU - should not spill partition containing key 3
        state.spill_lru().unwrap();

        let partition_idx = state.partition_for(&key(&[3]));
        assert!(state.is_in_memory(partition_idx));
    }

    #[test]
    fn test_drain_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..10 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let entries = state.drain_all().unwrap();
        assert_eq!(entries.len(), 10);

        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
        values.sort_unstable();
        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);

        // State should be empty and nothing should remain on disk
        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
    }

    #[test]
    fn test_iter_all() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        for i in 0..5 {
            state.insert(key(&[i]), i * 10).unwrap();
        }

        let entries = state.iter_all().unwrap();
        assert_eq!(entries.len(), 5);

        // State should still have values
        assert_eq!(state.total_size(), 5);

        // Should be able to iterate again
        let entries2 = state.iter_all().unwrap();
        assert_eq!(entries2.len(), 5);
    }

    #[test]
    fn test_many_groups() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);

        for i in 0..1000 {
            state.insert(key(&[i]), i).unwrap();
        }

        assert_eq!(state.total_size(), 1000);

        for _ in 0..8 {
            state.spill_largest().unwrap();
        }

        assert!(state.spilled_count() >= 8);

        for i in 0..1000 {
            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
        }
    }

    #[test]
    fn test_cleanup() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);

        for i in 0..20 {
            state.insert(key(&[i]), i).unwrap();
        }
        state.spill_largest().unwrap();
        state.spill_largest().unwrap();

        let spilled_before = manager.spilled_bytes();
        assert!(spilled_before > 0);

        state.cleanup();

        assert_eq!(state.total_size(), 0);
        assert_eq!(state.spilled_count(), 0);
        // Every spilled byte registered by this state must have been released.
        assert_eq!(manager.spilled_bytes(), 0);
    }

    #[test]
    fn test_multi_column_key() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);

        state
            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
            .unwrap();
        state
            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
            .unwrap();
        state
            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
            .unwrap();

        assert_eq!(state.total_size(), 3);

        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(1)])
                .unwrap(),
            Some(&100)
        );
        assert_eq!(
            state
                .get(&[Value::String("a".into()), Value::Int64(2)])
                .unwrap(),
            Some(&200)
        );
        assert_eq!(
            state
                .get(&[Value::String("b".into()), Value::Int64(1)])
                .unwrap(),
            Some(&300)
        );
    }

    #[test]
    fn test_update_existing() {
        let (_temp_dir, manager) = create_manager();
        let mut state: PartitionedState<i64> =
            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);

        state.insert(key(&[1]), 100).unwrap();
        assert_eq!(state.total_size(), 1);

        let old = state.insert(key(&[1]), 200).unwrap();
        assert_eq!(old, Some(100));
        assert_eq!(state.total_size(), 1); // Size shouldn't increase

        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
    }
}