// polars_redis/parallel.rs

//! Parallel batch fetching infrastructure.
//!
//! This module provides utilities for parallelizing Redis fetch operations
//! while maintaining a single SCAN cursor for key discovery.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────┐     ┌──────────────┐     ┌─────────────────┐
//! │  SCAN Keys  │────▶│  Key Buffer  │────▶│  Fetch Workers  │
//! │  (single)   │     │  (channel)   │     │  (N parallel)   │
//! └─────────────┘     └──────────────┘     └─────────────────┘
//!                                                   │
//!                                                   ▼
//!                                          ┌─────────────────┐
//!                                          │  Result Buffer  │
//!                                          │  (channel)      │
//!                                          └─────────────────┘
//! ```
//!
//! # Example
//!
//! ```ignore
//! use polars_redis::parallel::{ParallelConfig, ParallelFetcher};
//!
//! let mut fetcher = ParallelFetcher::new(conn, my_fetch_impl, ParallelConfig::new(4));
//! fetcher.start();
//! fetcher.submit(keys).await?;
//! fetcher.finish_submitting();
//! let results = fetcher.collect_all().await;
//! ```
30
31use std::sync::Arc;
32
33use redis::aio::ConnectionManager;
34use tokio::sync::mpsc;
35
36use crate::error::Result;
37
/// Default channel buffer size for key batches.
///
/// Used as the default `ParallelConfig::channel_size`, which bounds both the
/// key and result channels created in `ParallelFetcher::start`.
const DEFAULT_CHANNEL_SIZE: usize = 16;
40
/// A batch of keys to be fetched.
///
/// `Clone` is derived so callers can retain a copy of a batch after
/// submitting it (e.g. to retry on failure).
#[derive(Debug, Clone)]
pub struct KeyBatch {
    /// The keys to fetch.
    pub keys: Vec<String>,
    /// Batch sequence number for ordering.
    pub sequence: u64,
}
49
/// Result from a fetch operation.
///
/// `Clone` is derived (available when `T: Clone`) so results can be
/// duplicated, e.g. fanned out to multiple consumers.
#[derive(Debug, Clone)]
pub struct FetchResult<T> {
    /// The fetched data.
    pub data: Vec<T>,
    /// Batch sequence number for ordering.
    pub sequence: u64,
}
58
/// Configuration for parallel fetching.
///
/// All fields are plain scalar values, so the type additionally derives
/// `Copy`, `PartialEq`, and `Eq`: configs are cheap to pass by value and
/// easy to compare in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Number of worker tasks.
    pub workers: usize,
    /// Channel buffer size.
    pub channel_size: usize,
    /// Whether to preserve ordering.
    pub preserve_order: bool,
}
69
70impl Default for ParallelConfig {
71    fn default() -> Self {
72        Self {
73            workers: 4,
74            channel_size: DEFAULT_CHANNEL_SIZE,
75            preserve_order: false,
76        }
77    }
78}
79
80impl ParallelConfig {
81    /// Create a new config with the given worker count.
82    pub fn new(workers: usize) -> Self {
83        Self {
84            workers: workers.max(1),
85            ..Default::default()
86        }
87    }
88
89    /// Set the channel buffer size.
90    pub fn with_channel_size(mut self, size: usize) -> Self {
91        self.channel_size = size;
92        self
93    }
94
95    /// Enable ordering preservation (slower but deterministic).
96    pub fn with_preserve_order(mut self, preserve: bool) -> Self {
97        self.preserve_order = preserve;
98        self
99    }
100}
101
/// Trait for types that can be fetched in parallel.
///
/// Implementors define how to fetch a batch of keys from Redis. The
/// `Send + Sync + 'static` bounds allow a single implementor to be shared
/// behind an `Arc` across the spawned worker tasks.
pub trait ParallelFetch: Send + Sync + 'static {
    /// The output type for each fetched item.
    type Output: Send + 'static;

    /// Fetch data for the given keys.
    ///
    /// This is called by worker tasks and should be efficient for batch
    /// operations. Each worker passes its own clone of the connection.
    /// NOTE(review): presumably returns one `Output` per key, but the trait
    /// does not enforce that — confirm with implementors.
    fn fetch(
        &self,
        conn: ConnectionManager,
        keys: Vec<String>,
    ) -> impl std::future::Future<Output = Result<Vec<Self::Output>>> + Send;
}
118
/// Parallel fetcher that distributes key batches across worker tasks.
///
/// Lifecycle: construct with `new`, call `start` to spawn workers, `submit`
/// key batches, then `finish_submitting` to close the input channel and
/// drain results via `recv` / `collect_all`.
pub struct ParallelFetcher<F: ParallelFetch> {
    /// Configuration.
    config: ParallelConfig,
    /// Redis connection (cloned for each worker).
    conn: ConnectionManager,
    /// The fetch implementation, shared with every worker task.
    fetcher: Arc<F>,
    /// Sender for submitting key batches. `None` before `start` is called
    /// and again after `finish_submitting` drops it to close the channel.
    key_tx: Option<mpsc::Sender<KeyBatch>>,
    /// Receiver for fetch results. `None` before `start` is called.
    result_rx: Option<mpsc::Receiver<FetchResult<F::Output>>>,
    /// Next sequence number assigned to a submitted batch.
    next_sequence: u64,
}
134
135impl<F: ParallelFetch> ParallelFetcher<F> {
136    /// Create a new parallel fetcher.
137    pub fn new(conn: ConnectionManager, fetcher: F, config: ParallelConfig) -> Self {
138        Self {
139            config,
140            conn,
141            fetcher: Arc::new(fetcher),
142            key_tx: None,
143            result_rx: None,
144            next_sequence: 0,
145        }
146    }
147
148    /// Start the worker tasks.
149    ///
150    /// This must be called before submitting batches.
151    pub fn start(&mut self) {
152        let (key_tx, key_rx) = mpsc::channel::<KeyBatch>(self.config.channel_size);
153        let (result_tx, result_rx) =
154            mpsc::channel::<FetchResult<F::Output>>(self.config.channel_size);
155
156        // Share the receiver across workers
157        let key_rx = Arc::new(tokio::sync::Mutex::new(key_rx));
158
159        // Spawn worker tasks
160        for _ in 0..self.config.workers {
161            let conn = self.conn.clone();
162            let fetcher = Arc::clone(&self.fetcher);
163            let key_rx = Arc::clone(&key_rx);
164            let result_tx = result_tx.clone();
165
166            tokio::spawn(async move {
167                loop {
168                    // Get next batch from shared receiver
169                    let batch = {
170                        let mut rx = key_rx.lock().await;
171                        rx.recv().await
172                    };
173
174                    match batch {
175                        Some(KeyBatch { keys, sequence }) => {
176                            // Fetch the data
177                            match fetcher.fetch(conn.clone(), keys).await {
178                                Ok(data) => {
179                                    let _ = result_tx.send(FetchResult { data, sequence }).await;
180                                }
181                                Err(_e) => {
182                                    // Error in fetch, continue processing
183                                    // TODO: Consider adding error channel for reporting
184                                }
185                            }
186                        }
187                        None => break, // Channel closed, exit worker
188                    }
189                }
190            });
191        }
192
193        self.key_tx = Some(key_tx);
194        self.result_rx = Some(result_rx);
195    }
196
197    /// Submit a batch of keys for fetching.
198    pub async fn submit(&mut self, keys: Vec<String>) -> Result<()> {
199        if let Some(tx) = &self.key_tx {
200            let batch = KeyBatch {
201                keys,
202                sequence: self.next_sequence,
203            };
204            self.next_sequence += 1;
205            tx.send(batch)
206                .await
207                .map_err(|_| crate::error::Error::Channel("Channel closed".to_string()))?;
208        }
209        Ok(())
210    }
211
212    /// Close the input channel and signal workers to finish.
213    pub fn finish_submitting(&mut self) {
214        self.key_tx = None;
215    }
216
217    /// Receive the next result.
218    pub async fn recv(&mut self) -> Option<FetchResult<F::Output>> {
219        if let Some(rx) = &mut self.result_rx {
220            rx.recv().await
221        } else {
222            None
223        }
224    }
225
226    /// Collect all remaining results.
227    pub async fn collect_all(&mut self) -> Vec<FetchResult<F::Output>> {
228        let mut results = Vec::new();
229        while let Some(result) = self.recv().await {
230            results.push(result);
231        }
232
233        // Sort by sequence if ordering is required
234        if self.config.preserve_order {
235            results.sort_by_key(|r| r.sequence);
236        }
237
238        results
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_config_default() {
        let cfg = ParallelConfig::default();
        assert_eq!(cfg.workers, 4);
        assert_eq!(cfg.channel_size, DEFAULT_CHANNEL_SIZE);
        assert!(!cfg.preserve_order);
    }

    #[test]
    fn test_parallel_config_builder() {
        // Builder calls are independent, so order doesn't matter.
        let cfg = ParallelConfig::new(8)
            .with_preserve_order(true)
            .with_channel_size(32);
        assert!(cfg.preserve_order);
        assert_eq!(cfg.channel_size, 32);
        assert_eq!(cfg.workers, 8);
    }

    #[test]
    fn test_parallel_config_min_workers() {
        // A request for zero workers is clamped up to one.
        assert_eq!(ParallelConfig::new(0).workers, 1);
    }

    #[test]
    fn test_key_batch() {
        let keys: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let batch = KeyBatch { keys, sequence: 42 };
        assert_eq!(batch.sequence, 42);
        assert_eq!(batch.keys.len(), 2);
    }
}