Skip to main content

grafeo_core/execution/spill/
partition.rs

1//! Hash partitioning for spillable aggregation.
2//!
3//! This module implements hash partitioning that allows aggregate state
4//! to be partitioned and spilled to disk when memory pressure is high.
5//!
6//! # Design
7//!
8//! - Groups are assigned to partitions based on their key's hash
9//! - In-memory partitions can be spilled to disk under memory pressure
10//! - Cold (least recently accessed) partitions are spilled first
11//! - When iterating results, spilled partitions are reloaded
12
13use super::file::SpillFile;
14use super::manager::SpillManager;
15use super::serializer::{deserialize_row, serialize_row};
16use grafeo_common::types::Value;
17use std::collections::HashMap;
18use std::io::{Read, Write};
19use std::sync::Arc;
20
/// Partition count to use for hash partitioning when no other value is chosen.
pub const DEFAULT_NUM_PARTITIONS: usize = 256;
23
24/// A serialized key for use as a HashMap key.
25/// We serialize Value vectors to bytes since Value doesn't implement Hash/Eq.
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct SerializedKey(Vec<u8>);
28
29impl SerializedKey {
30    fn from_values(values: &[Value]) -> Self {
31        let mut buf = Vec::new();
32        serialize_row(values, &mut buf).expect("serialization should not fail");
33        Self(buf)
34    }
35
36    fn to_values(&self, num_columns: usize) -> std::io::Result<Vec<Value>> {
37        deserialize_row(&mut self.0.as_slice(), num_columns)
38    }
39}
40
/// What a partition stores per group: the aggregate value together with the
/// number of columns of the key it belongs to.
struct PartitionEntry<V> {
    /// Column count of the original key; needed to decode the serialized key again.
    num_key_columns: usize,
    /// Aggregate state held for the group.
    value: V,
}
46
47/// Partitioned accumulator state for spillable aggregation.
48///
49/// Manages aggregate state across multiple partitions, with the ability
50/// to spill cold partitions to disk under memory pressure.
51pub struct PartitionedState<V> {
52    /// Spill manager for file creation.
53    manager: Arc<SpillManager>,
54    /// Number of partitions.
55    num_partitions: usize,
56    /// In-memory partitions (None = spilled to disk).
57    partitions: Vec<Option<HashMap<SerializedKey, PartitionEntry<V>>>>,
58    /// Spill files for spilled partitions.
59    spill_files: Vec<Option<SpillFile>>,
60    /// Number of groups per partition (for spilled partitions too).
61    partition_sizes: Vec<usize>,
62    /// Access timestamps for LRU eviction.
63    access_times: Vec<u64>,
64    /// Global timestamp counter.
65    timestamp: u64,
66    /// Serializer for V values.
67    value_serializer: Box<dyn Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync>,
68    /// Deserializer for V values.
69    value_deserializer: Box<dyn Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync>,
70}
71
72impl<V: Clone + Send + Sync + 'static> PartitionedState<V> {
73    /// Creates a new partitioned state with custom serialization.
74    pub fn new<S, D>(
75        manager: Arc<SpillManager>,
76        num_partitions: usize,
77        value_serializer: S,
78        value_deserializer: D,
79    ) -> Self
80    where
81        S: Fn(&V, &mut dyn Write) -> std::io::Result<()> + Send + Sync + 'static,
82        D: Fn(&mut dyn Read) -> std::io::Result<V> + Send + Sync + 'static,
83    {
84        let mut partitions = Vec::with_capacity(num_partitions);
85        let mut spill_files = Vec::with_capacity(num_partitions);
86        for _ in 0..num_partitions {
87            partitions.push(Some(HashMap::new()));
88            spill_files.push(None);
89        }
90
91        let partition_sizes = vec![0; num_partitions];
92        let access_times = vec![0; num_partitions];
93
94        Self {
95            manager,
96            num_partitions,
97            partitions,
98            spill_files,
99            partition_sizes,
100            access_times,
101            timestamp: 0,
102            value_serializer: Box::new(value_serializer),
103            value_deserializer: Box::new(value_deserializer),
104        }
105    }
106
107    /// Returns the partition index for a key.
108    #[must_use]
109    pub fn partition_for(&self, key: &[Value]) -> usize {
110        let hash = hash_key(key);
111        hash as usize % self.num_partitions
112    }
113
114    /// Updates access time for a partition.
115    fn touch(&mut self, partition_idx: usize) {
116        self.timestamp += 1;
117        self.access_times[partition_idx] = self.timestamp;
118    }
119
120    /// Gets the in-memory partition, loading from disk if spilled.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if reading from disk fails.
125    fn get_partition_mut(
126        &mut self,
127        partition_idx: usize,
128    ) -> std::io::Result<&mut HashMap<SerializedKey, PartitionEntry<V>>> {
129        self.touch(partition_idx);
130
131        // If partition is in memory, return it
132        if self.partitions[partition_idx].is_some() {
133            // Invariant: just checked is_some() above
134            return Ok(self.partitions[partition_idx]
135                .as_mut()
136                .expect("partition is Some: checked on previous line"));
137        }
138
139        // Load from disk
140        if let Some(spill_file) = self.spill_files[partition_idx].take() {
141            let loaded = self.load_partition(&spill_file)?;
142            // Delete the spill file after loading
143            let bytes = spill_file.bytes_written();
144            let _ = spill_file.delete();
145            self.manager.unregister_spilled_bytes(bytes);
146            self.partitions[partition_idx] = Some(loaded);
147        } else {
148            // Neither in memory nor on disk - create empty partition
149            self.partitions[partition_idx] = Some(HashMap::new());
150        }
151
152        // Invariant: partition was either loaded from disk or created empty above
153        Ok(self.partitions[partition_idx]
154            .as_mut()
155            .expect("partition is Some: set to Some in if/else branches above"))
156    }
157
158    /// Loads a partition from a spill file.
159    fn load_partition(
160        &self,
161        spill_file: &SpillFile,
162    ) -> std::io::Result<HashMap<SerializedKey, PartitionEntry<V>>> {
163        let mut reader = spill_file.reader()?;
164        let mut adapter = SpillReaderAdapter(&mut reader);
165
166        let num_entries = read_u64(&mut adapter)? as usize;
167        let mut partition = HashMap::with_capacity(num_entries);
168
169        for _ in 0..num_entries {
170            // Read key
171            let key_len = read_u64(&mut adapter)? as usize;
172            let mut key_buf = vec![0u8; key_len];
173            adapter.read_exact(&mut key_buf)?;
174            let serialized_key = SerializedKey(key_buf);
175
176            // Read number of key columns
177            let num_key_columns = read_u64(&mut adapter)? as usize;
178
179            // Read value
180            let value = (self.value_deserializer)(&mut adapter)?;
181
182            partition.insert(
183                serialized_key,
184                PartitionEntry {
185                    num_key_columns,
186                    value,
187                },
188            );
189        }
190
191        Ok(partition)
192    }
193
194    /// Returns whether a partition is in memory.
195    #[must_use]
196    pub fn is_in_memory(&self, partition_idx: usize) -> bool {
197        self.partitions[partition_idx].is_some()
198    }
199
200    /// Returns the number of groups in a partition.
201    #[must_use]
202    pub fn partition_size(&self, partition_idx: usize) -> usize {
203        self.partition_sizes[partition_idx]
204    }
205
206    /// Returns the total number of groups across all partitions.
207    #[must_use]
208    pub fn total_size(&self) -> usize {
209        self.partition_sizes.iter().sum()
210    }
211
212    /// Returns the number of in-memory partitions.
213    #[must_use]
214    pub fn in_memory_count(&self) -> usize {
215        self.partitions.iter().filter(|p| p.is_some()).count()
216    }
217
218    /// Returns the number of spilled partitions.
219    #[must_use]
220    pub fn spilled_count(&self) -> usize {
221        self.spill_files.iter().filter(|f| f.is_some()).count()
222    }
223
224    /// Spills a specific partition to disk.
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if writing to disk fails.
229    pub fn spill_partition(&mut self, partition_idx: usize) -> std::io::Result<usize> {
230        // Get partition data
231        let Some(partition) = self.partitions[partition_idx].take() else {
232            return Ok(0); // Already spilled
233        };
234
235        if partition.is_empty() {
236            return Ok(0);
237        }
238
239        // Create spill file
240        let mut spill_file = self.manager.create_file("partition")?;
241
242        // Write partition data
243        let mut buf = Vec::new();
244        write_u64(&mut buf, partition.len() as u64)?;
245
246        for (key, entry) in &partition {
247            // Write key bytes
248            write_u64(&mut buf, key.0.len() as u64)?;
249            buf.extend_from_slice(&key.0);
250
251            // Write number of key columns
252            write_u64(&mut buf, entry.num_key_columns as u64)?;
253
254            // Write value
255            (self.value_serializer)(&entry.value, &mut buf)?;
256        }
257
258        spill_file.write_all(&buf)?;
259        spill_file.finish_write()?;
260
261        let bytes_written = spill_file.bytes_written();
262        self.manager.register_spilled_bytes(bytes_written);
263        self.partition_sizes[partition_idx] = partition.len();
264        self.spill_files[partition_idx] = Some(spill_file);
265
266        Ok(bytes_written as usize)
267    }
268
269    /// Spills the largest in-memory partition.
270    ///
271    /// Returns the number of bytes spilled, or 0 if no partition to spill.
272    ///
273    /// # Errors
274    ///
275    /// Returns an error if writing to disk fails.
276    pub fn spill_largest(&mut self) -> std::io::Result<usize> {
277        // Find largest in-memory partition
278        let largest_idx = self
279            .partitions
280            .iter()
281            .enumerate()
282            .filter_map(|(idx, p)| p.as_ref().map(|m| (idx, m.len())))
283            .max_by_key(|(_, size)| *size)
284            .map(|(idx, _)| idx);
285
286        match largest_idx {
287            Some(idx) => self.spill_partition(idx),
288            None => Ok(0),
289        }
290    }
291
292    /// Spills the least recently used in-memory partition.
293    ///
294    /// Returns the number of bytes spilled, or 0 if no partition to spill.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if writing to disk fails.
299    pub fn spill_lru(&mut self) -> std::io::Result<usize> {
300        // Find LRU in-memory partition
301        let lru_idx = self
302            .partitions
303            .iter()
304            .enumerate()
305            .filter(|(_, p)| p.is_some())
306            .min_by_key(|(idx, _)| self.access_times[*idx])
307            .map(|(idx, _)| idx);
308
309        match lru_idx {
310            Some(idx) => self.spill_partition(idx),
311            None => Ok(0),
312        }
313    }
314
315    /// Inserts or updates a value for a key.
316    ///
317    /// # Errors
318    ///
319    /// Returns an error if loading from disk fails.
320    pub fn insert(&mut self, key: Vec<Value>, value: V) -> std::io::Result<Option<V>> {
321        let partition_idx = self.partition_for(&key);
322        let num_key_columns = key.len();
323        let serialized_key = SerializedKey::from_values(&key);
324        let partition = self.get_partition_mut(partition_idx)?;
325
326        let old = partition.insert(
327            serialized_key,
328            PartitionEntry {
329                num_key_columns,
330                value,
331            },
332        );
333
334        if old.is_none() {
335            self.partition_sizes[partition_idx] += 1;
336        }
337
338        Ok(old.map(|e| e.value))
339    }
340
341    /// Gets a value for a key.
342    ///
343    /// # Errors
344    ///
345    /// Returns an error if loading from disk fails.
346    pub fn get(&mut self, key: &[Value]) -> std::io::Result<Option<&V>> {
347        let partition_idx = self.partition_for(key);
348        let serialized_key = SerializedKey::from_values(key);
349        let partition = self.get_partition_mut(partition_idx)?;
350        Ok(partition.get(&serialized_key).map(|e| &e.value))
351    }
352
353    /// Gets a mutable value for a key, or inserts a default.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if loading from disk fails.
358    pub fn get_or_insert_with<F>(&mut self, key: Vec<Value>, default: F) -> std::io::Result<&mut V>
359    where
360        F: FnOnce() -> V,
361    {
362        let partition_idx = self.partition_for(&key);
363        let num_key_columns = key.len();
364        let serialized_key = SerializedKey::from_values(&key);
365
366        let was_new;
367        {
368            let partition = self.get_partition_mut(partition_idx)?;
369            was_new = !partition.contains_key(&serialized_key);
370            if was_new {
371                partition.insert(
372                    serialized_key.clone(),
373                    PartitionEntry {
374                        num_key_columns,
375                        value: default(),
376                    },
377                );
378            }
379        }
380        if was_new {
381            self.partition_sizes[partition_idx] += 1;
382        }
383
384        let partition = self.get_partition_mut(partition_idx)?;
385        // Invariant: key was either already present or inserted in the block above
386        Ok(&mut partition
387            .get_mut(&serialized_key)
388            .expect("key exists: just inserted or already present in partition")
389            .value)
390    }
391
392    /// Drains all entries from all partitions.
393    ///
394    /// Loads spilled partitions as needed.
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if loading from disk fails.
399    pub fn drain_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
400        let mut result = Vec::with_capacity(self.total_size());
401
402        for partition_idx in 0..self.num_partitions {
403            let partition = self.get_partition_mut(partition_idx)?;
404            for (serialized_key, entry) in partition.drain() {
405                let key = serialized_key.to_values(entry.num_key_columns)?;
406                result.push((key, entry.value));
407            }
408            self.partition_sizes[partition_idx] = 0;
409        }
410
411        // Clean up any remaining spill files
412        for spill_file in &mut self.spill_files {
413            if let Some(file) = spill_file.take() {
414                let bytes = file.bytes_written();
415                let _ = file.delete();
416                self.manager.unregister_spilled_bytes(bytes);
417            }
418        }
419
420        Ok(result)
421    }
422
423    /// Iterates over all entries without draining.
424    ///
425    /// Loads spilled partitions as needed.
426    ///
427    /// # Errors
428    ///
429    /// Returns an error if loading from disk fails.
430    pub fn iter_all(&mut self) -> std::io::Result<Vec<(Vec<Value>, V)>> {
431        let mut result = Vec::with_capacity(self.total_size());
432
433        for partition_idx in 0..self.num_partitions {
434            let partition = self.get_partition_mut(partition_idx)?;
435            for (serialized_key, entry) in partition.iter() {
436                let key = serialized_key.to_values(entry.num_key_columns)?;
437                result.push((key, entry.value.clone()));
438            }
439        }
440
441        Ok(result)
442    }
443
444    /// Cleans up all spill files.
445    pub fn cleanup(&mut self) {
446        for file in self.spill_files.iter_mut().flatten() {
447            let bytes = file.bytes_written();
448            self.manager.unregister_spilled_bytes(bytes);
449        }
450
451        self.spill_files.clear();
452        self.partitions.clear();
453        for _ in 0..self.num_partitions {
454            self.spill_files.push(None);
455            self.partitions.push(Some(HashMap::new()));
456        }
457        self.partition_sizes = vec![0; self.num_partitions];
458    }
459}
460
461impl<V> Drop for PartitionedState<V> {
462    fn drop(&mut self) {
463        // Unregister spilled bytes
464        for file in self.spill_files.iter().flatten() {
465            let bytes = file.bytes_written();
466            self.manager.unregister_spilled_bytes(bytes);
467        }
468    }
469}
470
471/// Hashes a key (vector of values) to a u64.
472fn hash_key(key: &[Value]) -> u64 {
473    use std::hash::{Hash, Hasher};
474    let mut hasher = std::collections::hash_map::DefaultHasher::new();
475
476    for value in key {
477        match value {
478            Value::Null => 0u8.hash(&mut hasher),
479            Value::Bool(b) => {
480                1u8.hash(&mut hasher);
481                b.hash(&mut hasher);
482            }
483            Value::Int64(n) => {
484                2u8.hash(&mut hasher);
485                n.hash(&mut hasher);
486            }
487            Value::Float64(f) => {
488                3u8.hash(&mut hasher);
489                f.to_bits().hash(&mut hasher);
490            }
491            Value::String(s) => {
492                4u8.hash(&mut hasher);
493                s.hash(&mut hasher);
494            }
495            Value::Bytes(b) => {
496                5u8.hash(&mut hasher);
497                b.hash(&mut hasher);
498            }
499            Value::Timestamp(t) => {
500                6u8.hash(&mut hasher);
501                t.hash(&mut hasher);
502            }
503            Value::Date(d) => {
504                10u8.hash(&mut hasher);
505                d.hash(&mut hasher);
506            }
507            Value::Time(t) => {
508                11u8.hash(&mut hasher);
509                t.hash(&mut hasher);
510            }
511            Value::Duration(d) => {
512                12u8.hash(&mut hasher);
513                d.hash(&mut hasher);
514            }
515            Value::ZonedDatetime(zdt) => {
516                14u8.hash(&mut hasher);
517                zdt.hash(&mut hasher);
518            }
519            Value::List(l) => {
520                7u8.hash(&mut hasher);
521                l.len().hash(&mut hasher);
522            }
523            Value::Map(m) => {
524                8u8.hash(&mut hasher);
525                m.len().hash(&mut hasher);
526            }
527            Value::Vector(v) => {
528                9u8.hash(&mut hasher);
529                v.len().hash(&mut hasher);
530                // Hash first few elements for distribution
531                for &f in v.iter().take(4) {
532                    f.to_bits().hash(&mut hasher);
533                }
534            }
535            Value::Path { nodes, edges } => {
536                13u8.hash(&mut hasher);
537                nodes.len().hash(&mut hasher);
538                edges.len().hash(&mut hasher);
539            }
540        }
541    }
542
543    hasher.finish()
544}
545
/// Reads eight bytes from `reader` and decodes them as a little-endian `u64`.
///
/// # Errors
///
/// Returns the error of `read_exact` if eight bytes cannot be read.
fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    reader
        .read_exact(&mut raw)
        .map(|()| u64::from_le_bytes(raw))
}
552
/// Writes `value` to `writer` as eight little-endian bytes.
///
/// # Errors
///
/// Returns the error of `write_all` if the bytes cannot be written.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
557
558/// Adapter to read from SpillFileReader through std::io::Read.
559struct SpillReaderAdapter<'a>(&'a mut super::file::SpillFileReader);
560
561impl Read for SpillReaderAdapter<'_> {
562    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
563        self.0.read_exact(buf)?;
564        Ok(buf.len())
565    }
566}
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571    use tempfile::TempDir;
572
573    /// Creates a test manager. Returns (TempDir, manager). TempDir must be kept alive.
574    fn create_manager() -> (TempDir, Arc<SpillManager>) {
575        let temp_dir = TempDir::new().unwrap();
576        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
577        (temp_dir, manager)
578    }
579
580    /// Simple i64 serializer for tests.
581    #[allow(clippy::trivially_copy_pass_by_ref)] // Required by PartitionedState::new signature
582    fn serialize_i64(value: &i64, w: &mut dyn Write) -> std::io::Result<()> {
583        w.write_all(&value.to_le_bytes())
584    }
585
586    /// Simple i64 deserializer for tests.
587    fn deserialize_i64(r: &mut dyn Read) -> std::io::Result<i64> {
588        let mut buf = [0u8; 8];
589        r.read_exact(&mut buf)?;
590        Ok(i64::from_le_bytes(buf))
591    }
592
593    fn key(values: &[i64]) -> Vec<Value> {
594        values.iter().map(|&v| Value::Int64(v)).collect()
595    }
596
597    #[test]
598    fn test_partition_for() {
599        let (_temp_dir, manager) = create_manager();
600        let state: PartitionedState<i64> =
601            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);
602
603        // Same key should always go to same partition
604        let k1 = key(&[1, 2, 3]);
605        let p1 = state.partition_for(&k1);
606        let p2 = state.partition_for(&k1);
607        assert_eq!(p1, p2);
608
609        // Partition should be in range
610        assert!(p1 < 16);
611    }
612
613    #[test]
614    fn test_insert_and_get() {
615        let (_temp_dir, manager) = create_manager();
616        let mut state: PartitionedState<i64> =
617            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);
618
619        // Insert some values
620        state.insert(key(&[1]), 100).unwrap();
621        state.insert(key(&[2]), 200).unwrap();
622        state.insert(key(&[3]), 300).unwrap();
623
624        assert_eq!(state.total_size(), 3);
625
626        // Get values
627        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&100));
628        assert_eq!(state.get(&key(&[2])).unwrap(), Some(&200));
629        assert_eq!(state.get(&key(&[3])).unwrap(), Some(&300));
630        assert_eq!(state.get(&key(&[4])).unwrap(), None);
631    }
632
633    #[test]
634    fn test_get_or_insert_with() {
635        let (_temp_dir, manager) = create_manager();
636        let mut state: PartitionedState<i64> =
637            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);
638
639        // First access creates the entry
640        let v1 = state.get_or_insert_with(key(&[1]), || 42).unwrap();
641        assert_eq!(*v1, 42);
642
643        // Second access returns existing value
644        let v2 = state.get_or_insert_with(key(&[1]), || 100).unwrap();
645        assert_eq!(*v2, 42);
646
647        // Mutate via returned reference
648        *state.get_or_insert_with(key(&[1]), || 0).unwrap() = 999;
649        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&999));
650    }
651
652    #[test]
653    fn test_spill_and_reload() {
654        let (_temp_dir, manager) = create_manager();
655        let mut state: PartitionedState<i64> =
656            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);
657
658        // Insert values that go to different partitions
659        for i in 0..20 {
660            state.insert(key(&[i]), i * 10).unwrap();
661        }
662
663        let initial_total = state.total_size();
664        assert!(initial_total > 0);
665
666        // Spill the largest partition
667        let bytes_spilled = state.spill_largest().unwrap();
668        assert!(bytes_spilled > 0);
669        assert!(state.spilled_count() > 0);
670
671        // Values should still be accessible (reloads from disk)
672        for i in 0..20 {
673            let expected = i * 10;
674            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&expected));
675        }
676    }
677
678    #[test]
679    fn test_spill_lru() {
680        let (_temp_dir, manager) = create_manager();
681        let mut state: PartitionedState<i64> =
682            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);
683
684        // Insert values
685        state.insert(key(&[1]), 10).unwrap();
686        state.insert(key(&[2]), 20).unwrap();
687        state.insert(key(&[3]), 30).unwrap();
688
689        // Access key 3 to make it recently used
690        state.get(&key(&[3])).unwrap();
691
692        // Spill LRU - should not spill partition containing key 3
693        state.spill_lru().unwrap();
694
695        // Key 3 should still be in memory
696        let partition_idx = state.partition_for(&key(&[3]));
697        assert!(state.is_in_memory(partition_idx));
698    }
699
700    #[test]
701    fn test_drain_all() {
702        let (_temp_dir, manager) = create_manager();
703        let mut state: PartitionedState<i64> =
704            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);
705
706        // Insert values
707        for i in 0..10 {
708            state.insert(key(&[i]), i * 10).unwrap();
709        }
710
711        // Spill some partitions
712        state.spill_largest().unwrap();
713        state.spill_largest().unwrap();
714
715        // Drain all
716        let entries = state.drain_all().unwrap();
717        assert_eq!(entries.len(), 10);
718
719        // Verify all entries are present
720        let mut values: Vec<i64> = entries.iter().map(|(_, v)| *v).collect();
721        values.sort_unstable();
722        assert_eq!(values, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);
723
724        // State should be empty
725        assert_eq!(state.total_size(), 0);
726    }
727
728    #[test]
729    fn test_iter_all() {
730        let (_temp_dir, manager) = create_manager();
731        let mut state: PartitionedState<i64> =
732            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);
733
734        // Insert values
735        for i in 0..5 {
736            state.insert(key(&[i]), i * 10).unwrap();
737        }
738
739        // Iterate without draining
740        let entries = state.iter_all().unwrap();
741        assert_eq!(entries.len(), 5);
742
743        // State should still have values
744        assert_eq!(state.total_size(), 5);
745
746        // Should be able to iterate again
747        let entries2 = state.iter_all().unwrap();
748        assert_eq!(entries2.len(), 5);
749    }
750
751    #[test]
752    fn test_many_groups() {
753        let (_temp_dir, manager) = create_manager();
754        let mut state: PartitionedState<i64> =
755            PartitionedState::new(manager, 16, serialize_i64, deserialize_i64);
756
757        // Insert many groups
758        for i in 0..1000 {
759            state.insert(key(&[i]), i).unwrap();
760        }
761
762        assert_eq!(state.total_size(), 1000);
763
764        // Spill multiple partitions
765        for _ in 0..8 {
766            state.spill_largest().unwrap();
767        }
768
769        assert!(state.spilled_count() >= 8);
770
771        // All values should still be retrievable
772        for i in 0..1000 {
773            assert_eq!(state.get(&key(&[i])).unwrap(), Some(&i));
774        }
775    }
776
777    #[test]
778    fn test_cleanup() {
779        let (_temp_dir, manager) = create_manager();
780        let mut state: PartitionedState<i64> =
781            PartitionedState::new(Arc::clone(&manager), 4, serialize_i64, deserialize_i64);
782
783        // Insert and spill
784        for i in 0..20 {
785            state.insert(key(&[i]), i).unwrap();
786        }
787        state.spill_largest().unwrap();
788        state.spill_largest().unwrap();
789
790        let spilled_before = manager.spilled_bytes();
791        assert!(spilled_before > 0);
792
793        // Cleanup
794        state.cleanup();
795
796        assert_eq!(state.total_size(), 0);
797        assert_eq!(state.spilled_count(), 0);
798    }
799
800    #[test]
801    fn test_multi_column_key() {
802        let (_temp_dir, manager) = create_manager();
803        let mut state: PartitionedState<i64> =
804            PartitionedState::new(manager, 8, serialize_i64, deserialize_i64);
805
806        // Insert with multi-column keys
807        state
808            .insert(vec![Value::String("a".into()), Value::Int64(1)], 100)
809            .unwrap();
810        state
811            .insert(vec![Value::String("a".into()), Value::Int64(2)], 200)
812            .unwrap();
813        state
814            .insert(vec![Value::String("b".into()), Value::Int64(1)], 300)
815            .unwrap();
816
817        assert_eq!(state.total_size(), 3);
818
819        // Retrieve by multi-column key
820        assert_eq!(
821            state
822                .get(&[Value::String("a".into()), Value::Int64(1)])
823                .unwrap(),
824            Some(&100)
825        );
826        assert_eq!(
827            state
828                .get(&[Value::String("a".into()), Value::Int64(2)])
829                .unwrap(),
830            Some(&200)
831        );
832        assert_eq!(
833            state
834                .get(&[Value::String("b".into()), Value::Int64(1)])
835                .unwrap(),
836            Some(&300)
837        );
838    }
839
840    #[test]
841    fn test_update_existing() {
842        let (_temp_dir, manager) = create_manager();
843        let mut state: PartitionedState<i64> =
844            PartitionedState::new(manager, 4, serialize_i64, deserialize_i64);
845
846        // Insert
847        state.insert(key(&[1]), 100).unwrap();
848        assert_eq!(state.total_size(), 1);
849
850        // Update
851        let old = state.insert(key(&[1]), 200).unwrap();
852        assert_eq!(old, Some(100));
853        assert_eq!(state.total_size(), 1); // Size shouldn't increase
854
855        // Verify update
856        assert_eq!(state.get(&key(&[1])).unwrap(), Some(&200));
857    }
858}