1use std::sync::Arc;
32
33use redis::aio::ConnectionManager;
34use tokio::sync::mpsc;
35
36use crate::error::Result;
37
/// Default bounded capacity for both the key-batch and result channels.
const DEFAULT_CHANNEL_SIZE: usize = 16;
40
/// A batch of Redis keys queued for one worker fetch, tagged with the
/// submission order so results can be re-sequenced by the consumer.
#[derive(Debug)]
pub struct KeyBatch {
    // Keys to fetch together in a single `ParallelFetch::fetch` call.
    pub keys: Vec<String>,
    // Monotonically increasing submission index, assigned by
    // `ParallelFetcher::submit`.
    pub sequence: u64,
}
49
/// The output of fetching one [`KeyBatch`], carrying the batch's original
/// sequence number so ordering can be restored when `preserve_order` is set.
#[derive(Debug)]
pub struct FetchResult<T> {
    // Fetched values for the batch. NOTE(review): nothing here guarantees
    // one element per key — that depends on the `ParallelFetch` impl.
    pub data: Vec<T>,
    // Sequence number copied from the corresponding `KeyBatch`.
    pub sequence: u64,
}
58
/// Tuning knobs for a [`ParallelFetcher`] worker pool.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    // Number of concurrent worker tasks spawned by `start()`.
    pub workers: usize,
    // Bounded capacity of the key and result channels (backpressure limit).
    pub channel_size: usize,
    // When true, `collect_all()` sorts results by submission sequence.
    pub preserve_order: bool,
}
69
70impl Default for ParallelConfig {
71 fn default() -> Self {
72 Self {
73 workers: 4,
74 channel_size: DEFAULT_CHANNEL_SIZE,
75 preserve_order: false,
76 }
77 }
78}
79
80impl ParallelConfig {
81 pub fn new(workers: usize) -> Self {
83 Self {
84 workers: workers.max(1),
85 ..Default::default()
86 }
87 }
88
89 pub fn with_channel_size(mut self, size: usize) -> Self {
91 self.channel_size = size;
92 self
93 }
94
95 pub fn with_preserve_order(mut self, preserve: bool) -> Self {
97 self.preserve_order = preserve;
98 self
99 }
100}
101
/// Strategy trait: how to fetch a batch of keys from Redis.
///
/// Bounded `Send + Sync + 'static` because one implementor is shared
/// (via `Arc`) across spawned worker tasks.
pub trait ParallelFetch: Send + Sync + 'static {
    /// Item type produced by a fetch; must be `Send` so results can cross
    /// the worker-to-consumer channel.
    type Output: Send + 'static;

    /// Fetches values for `keys` over `conn`.
    ///
    /// NOTE(review): callers presumably expect one `Output` per key, but the
    /// trait does not enforce it — confirm against implementations.
    fn fetch(
        &self,
        conn: ConnectionManager,
        keys: Vec<String>,
    ) -> impl std::future::Future<Output = Result<Vec<Self::Output>>> + Send;
}
118
/// Fans key batches out to a pool of worker tasks that fetch from Redis
/// concurrently through a shared connection handle.
pub struct ParallelFetcher<F: ParallelFetch> {
    // Worker-pool sizing and ordering options.
    config: ParallelConfig,
    // Cloned into each worker — presumably a cheap handle clone; verify
    // against redis-rs `ConnectionManager` docs.
    conn: ConnectionManager,
    // Shared fetch strategy; `Arc` so every worker holds a reference.
    fetcher: Arc<F>,
    // Sender for queued batches. `None` before `start()` and again after
    // `finish_submitting()` — dropping it closes the workers' input.
    key_tx: Option<mpsc::Sender<KeyBatch>>,
    // Receiver for worker output. `None` before `start()`.
    result_rx: Option<mpsc::Receiver<FetchResult<F::Output>>>,
    // Sequence number to stamp onto the next submitted batch.
    next_sequence: u64,
}
134
135impl<F: ParallelFetch> ParallelFetcher<F> {
136 pub fn new(conn: ConnectionManager, fetcher: F, config: ParallelConfig) -> Self {
138 Self {
139 config,
140 conn,
141 fetcher: Arc::new(fetcher),
142 key_tx: None,
143 result_rx: None,
144 next_sequence: 0,
145 }
146 }
147
148 pub fn start(&mut self) {
152 let (key_tx, key_rx) = mpsc::channel::<KeyBatch>(self.config.channel_size);
153 let (result_tx, result_rx) =
154 mpsc::channel::<FetchResult<F::Output>>(self.config.channel_size);
155
156 let key_rx = Arc::new(tokio::sync::Mutex::new(key_rx));
158
159 for _ in 0..self.config.workers {
161 let conn = self.conn.clone();
162 let fetcher = Arc::clone(&self.fetcher);
163 let key_rx = Arc::clone(&key_rx);
164 let result_tx = result_tx.clone();
165
166 tokio::spawn(async move {
167 loop {
168 let batch = {
170 let mut rx = key_rx.lock().await;
171 rx.recv().await
172 };
173
174 match batch {
175 Some(KeyBatch { keys, sequence }) => {
176 match fetcher.fetch(conn.clone(), keys).await {
178 Ok(data) => {
179 let _ = result_tx.send(FetchResult { data, sequence }).await;
180 }
181 Err(_e) => {
182 }
185 }
186 }
187 None => break, }
189 }
190 });
191 }
192
193 self.key_tx = Some(key_tx);
194 self.result_rx = Some(result_rx);
195 }
196
197 pub async fn submit(&mut self, keys: Vec<String>) -> Result<()> {
199 if let Some(tx) = &self.key_tx {
200 let batch = KeyBatch {
201 keys,
202 sequence: self.next_sequence,
203 };
204 self.next_sequence += 1;
205 tx.send(batch)
206 .await
207 .map_err(|_| crate::error::Error::Channel("Channel closed".to_string()))?;
208 }
209 Ok(())
210 }
211
212 pub fn finish_submitting(&mut self) {
214 self.key_tx = None;
215 }
216
217 pub async fn recv(&mut self) -> Option<FetchResult<F::Output>> {
219 if let Some(rx) = &mut self.result_rx {
220 rx.recv().await
221 } else {
222 None
223 }
224 }
225
226 pub async fn collect_all(&mut self) -> Vec<FetchResult<F::Output>> {
228 let mut results = Vec::new();
229 while let Some(result) = self.recv().await {
230 results.push(result);
231 }
232
233 if self.config.preserve_order {
235 results.sort_by_key(|r| r.sequence);
236 }
237
238 results
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_config_default() {
        let cfg = ParallelConfig::default();
        assert_eq!(cfg.workers, 4);
        assert_eq!(cfg.channel_size, DEFAULT_CHANNEL_SIZE);
        assert!(!cfg.preserve_order);
    }

    #[test]
    fn test_parallel_config_builder() {
        // Chained builder calls should all take effect.
        let cfg = ParallelConfig::new(8)
            .with_channel_size(32)
            .with_preserve_order(true);
        assert_eq!((cfg.workers, cfg.channel_size), (8, 32));
        assert!(cfg.preserve_order);
    }

    #[test]
    fn test_parallel_config_min_workers() {
        // Zero workers is clamped up to one.
        assert_eq!(ParallelConfig::new(0).workers, 1);
    }

    #[test]
    fn test_key_batch() {
        let keys: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let batch = KeyBatch { keys, sequence: 42 };
        assert_eq!(batch.keys.len(), 2);
        assert_eq!(batch.sequence, 42);
    }
}