Skip to main content

exiftool_rs_wrapper/
stream.rs

1//! 流式处理和性能优化模块
2//!
3//! 支持大文件流式处理、进度回调、内存池优化
4
5use crate::ExifTool;
6use crate::error::Result;
7use std::fmt;
8use std::io::{self, Read};
9use std::path::Path;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Progress observer, invoked as `callback(processed, total)`.
///
/// Returning `true` lets the operation continue; returning `false` requests
/// cancellation (the associated [`ProgressTracker`] is marked as cancelled).
pub type ProgressCallback = Arc<dyn Fn(usize, usize) -> bool + Send + Sync>;

/// Configuration for streaming operations, assembled with a fluent builder API.
#[derive(Clone)]
pub struct StreamOptions {
    /// Buffer size in bytes (64 KiB unless overridden).
    pub buffer_size: usize,
    /// Optional progress observer; see [`ProgressCallback`].
    pub progress_callback: Option<ProgressCallback>,
    /// Timeout in seconds, if any.
    ///
    /// NOTE(review): nothing in this module reads this value yet — confirm before relying on it.
    pub timeout: Option<u64>,
}

impl fmt::Debug for StreamOptions {
    // Closures have no `Debug` impl, so only report whether one is installed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_callback = self.progress_callback.is_some();
        f.debug_struct("StreamOptions")
            .field("buffer_size", &self.buffer_size)
            .field("has_callback", &has_callback)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl Default for StreamOptions {
    fn default() -> Self {
        Self {
            buffer_size: Self::DEFAULT_BUFFER_SIZE,
            progress_callback: None,
            timeout: None,
        }
    }
}

impl StreamOptions {
    /// Buffer size used when none is configured: 64 KiB.
    const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;

    /// Returns the default options (64 KiB buffer, no callback, no timeout).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the buffer size in bytes.
    pub fn buffer_size(self, size: usize) -> Self {
        Self {
            buffer_size: size,
            ..self
        }
    }

    /// Installs a progress observer, replacing any previously configured one.
    ///
    /// The closure receives `(processed, total)` and returns `false` to request cancellation.
    pub fn on_progress<F>(self, callback: F) -> Self
    where
        F: Fn(usize, usize) -> bool + Send + Sync + 'static,
    {
        let callback: ProgressCallback = Arc::new(callback);
        Self {
            progress_callback: Some(callback),
            ..self
        }
    }

    /// Sets the timeout in seconds.
    pub fn timeout(self, seconds: u64) -> Self {
        Self {
            timeout: Some(seconds),
            ..self
        }
    }
}
74
75/// 进度追踪器
76pub struct ProgressTracker {
77    /// 总字节数
78    total: AtomicU64,
79    /// 已处理字节数
80    processed: AtomicU64,
81    /// 回调函数
82    callback: Option<ProgressCallback>,
83    /// 是否取消
84    cancelled: std::sync::atomic::AtomicBool,
85}
86
87impl fmt::Debug for ProgressTracker {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("ProgressTracker")
90            .field("total", &self.total.load(Ordering::SeqCst))
91            .field("processed", &self.processed.load(Ordering::SeqCst))
92            .field("has_callback", &self.callback.is_some())
93            .field("cancelled", &self.cancelled.load(Ordering::SeqCst))
94            .finish()
95    }
96}
97
98impl ProgressTracker {
99    /// 创建新的进度追踪器
100    pub fn new(total: usize, callback: Option<ProgressCallback>) -> Self {
101        Self {
102            total: AtomicU64::new(total as u64),
103            processed: AtomicU64::new(0),
104            callback,
105            cancelled: std::sync::atomic::AtomicBool::new(false),
106        }
107    }
108
109    /// 更新进度
110    pub fn update(&self, bytes: usize) {
111        let processed = self.processed.fetch_add(bytes as u64, Ordering::SeqCst) + bytes as u64;
112        let total = self.total.load(Ordering::SeqCst);
113
114        if let Some(ref callback) = self.callback
115            && !callback(processed as usize, total as usize)
116        {
117            self.cancelled.store(true, Ordering::SeqCst);
118        }
119    }
120
121    /// 检查是否已取消
122    pub fn is_cancelled(&self) -> bool {
123        self.cancelled.load(Ordering::SeqCst)
124    }
125
126    /// 获取进度百分比
127    pub fn percentage(&self) -> f64 {
128        let processed = self.processed.load(Ordering::SeqCst);
129        let total = self.total.load(Ordering::SeqCst);
130
131        if total == 0 {
132            0.0
133        } else {
134            (processed as f64 / total as f64) * 100.0
135        }
136    }
137
138    /// 获取已处理字节数
139    pub fn processed(&self) -> u64 {
140        self.processed.load(Ordering::SeqCst)
141    }
142
143    /// 获取总字节数
144    pub fn total(&self) -> u64 {
145        self.total.load(Ordering::SeqCst)
146    }
147}
148
149/// 缓冲读取器(支持进度追踪)
150pub struct ProgressReader<R: Read> {
151    inner: R,
152    tracker: Arc<ProgressTracker>,
153    #[allow(dead_code)]
154    buffer_size: usize,
155}
156
157impl<R: Read> ProgressReader<R> {
158    /// 创建新的进度读取器
159    pub fn new(inner: R, tracker: Arc<ProgressTracker>, buffer_size: usize) -> Self {
160        Self {
161            inner,
162            tracker,
163            buffer_size,
164        }
165    }
166
167    /// 检查是否已取消
168    pub fn is_cancelled(&self) -> bool {
169        self.tracker.is_cancelled()
170    }
171}
172
173impl<R: Read> Read for ProgressReader<R> {
174    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175        if self.is_cancelled() {
176            return Err(io::Error::new(
177                io::ErrorKind::Interrupted,
178                "Operation cancelled",
179            ));
180        }
181
182        let n = self.inner.read(buf)?;
183        self.tracker.update(n);
184        Ok(n)
185    }
186}
187
188/// 流式处理 trait
189pub trait StreamingOperations {
190    /// 流式处理大文件
191    fn process_streaming<P, F, R>(
192        &self,
193        path: P,
194        options: &StreamOptions,
195        processor: F,
196    ) -> Result<R>
197    where
198        P: AsRef<Path>,
199        F: FnMut(&mut dyn Read) -> Result<R>;
200
201    /// 批量处理带进度回调
202    fn process_batch_with_progress<P, F>(
203        &self,
204        paths: &[P],
205        options: &StreamOptions,
206        processor: F,
207    ) -> Vec<Result<()>>
208    where
209        P: AsRef<Path>,
210        F: FnMut(&ExifTool, &Path, &ProgressTracker) -> Result<()>;
211}
212
213impl StreamingOperations for ExifTool {
214    fn process_streaming<P, F, R>(
215        &self,
216        path: P,
217        options: &StreamOptions,
218        mut processor: F,
219    ) -> Result<R>
220    where
221        P: AsRef<Path>,
222        F: FnMut(&mut dyn Read) -> Result<R>,
223    {
224        // 使用标准文件读取实现流式处理
225        let file = std::fs::File::open(path.as_ref()).map_err(crate::error::Error::Io)?;
226
227        let tracker = Arc::new(ProgressTracker::new(1, options.progress_callback.clone()));
228
229        let mut reader = ProgressReader::new(file, tracker, options.buffer_size);
230
231        processor(&mut reader)
232    }
233
234    fn process_batch_with_progress<P, F>(
235        &self,
236        paths: &[P],
237        options: &StreamOptions,
238        processor: F,
239    ) -> Vec<Result<()>>
240    where
241        P: AsRef<Path>,
242        F: FnMut(&ExifTool, &Path, &ProgressTracker) -> Result<()>,
243    {
244        let total = paths.len();
245        let tracker = Arc::new(ProgressTracker::new(
246            total,
247            options.progress_callback.clone(),
248        ));
249
250        let mut results = Vec::with_capacity(total);
251        let mut processor = processor;
252
253        for path in paths {
254            let result = processor(self, path.as_ref(), &tracker);
255            tracker.update(1);
256            results.push(result);
257        }
258
259        results
260    }
261}
262
263/// 缓存管理器
264pub struct Cache<K, V> {
265    /// 内部缓存
266    inner: std::sync::Mutex<lru::LruCache<K, V>>,
267    /// 命中率统计
268    hits: AtomicU64,
269    /// 未命中统计
270    misses: AtomicU64,
271}
272
273impl<K, V> fmt::Debug for Cache<K, V> {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        f.debug_struct("Cache")
276            .field("hits", &self.hits.load(Ordering::SeqCst))
277            .field("misses", &self.misses.load(Ordering::SeqCst))
278            .finish()
279    }
280}
281
282impl<K: std::hash::Hash + Eq + Clone, V: Clone> Cache<K, V> {
283    /// 创建新的缓存
284    pub fn new(capacity: usize) -> Self {
285        use std::num::NonZeroUsize;
286        let capacity = NonZeroUsize::new(capacity.max(1)).unwrap();
287        Self {
288            inner: std::sync::Mutex::new(lru::LruCache::new(capacity)),
289            hits: AtomicU64::new(0),
290            misses: AtomicU64::new(0),
291        }
292    }
293
294    /// 获取值
295    pub fn get(&self, key: &K) -> Option<V> {
296        let mut cache = self.inner.lock().ok()?;
297
298        if let Some(value) = cache.get(key) {
299            self.hits.fetch_add(1, Ordering::SeqCst);
300            Some(value.clone())
301        } else {
302            self.misses.fetch_add(1, Ordering::SeqCst);
303            None
304        }
305    }
306
307    /// 插入值
308    pub fn put(&self, key: K, value: V) {
309        if let Ok(mut cache) = self.inner.lock() {
310            cache.put(key, value);
311        }
312    }
313
314    /// 获取命中率
315    pub fn hit_rate(&self) -> f64 {
316        let hits = self.hits.load(Ordering::SeqCst);
317        let misses = self.misses.load(Ordering::SeqCst);
318        let total = hits + misses;
319
320        if total == 0 {
321            0.0
322        } else {
323            (hits as f64 / total as f64) * 100.0
324        }
325    }
326
327    /// 清空缓存
328    pub fn clear(&self) {
329        if let Ok(mut cache) = self.inner.lock() {
330            cache.clear();
331        }
332        self.hits.store(0, Ordering::SeqCst);
333        self.misses.store(0, Ordering::SeqCst);
334    }
335}
336
/// Lock-free counters for operation outcomes and cumulative latency.
///
/// Every field is an atomic, so one instance can be shared (e.g. behind `Arc`) and
/// updated concurrently through `&self`.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Total number of recorded operations.
    pub total_operations: AtomicU64,
    /// Operations recorded as successful.
    pub successful_operations: AtomicU64,
    /// Operations recorded as failed.
    pub failed_operations: AtomicU64,
    /// Sum of all recorded durations, in microseconds.
    pub total_time_us: AtomicU64,
}

impl PerformanceStats {
    /// Records one operation outcome together with its duration in microseconds.
    pub fn record(&self, success: bool, elapsed_us: u64) {
        // `total_operations` is incremented first. The readers below load the other
        // counters before the total and rely on this order for consistent snapshots.
        self.total_operations.fetch_add(1, Ordering::SeqCst);
        self.total_time_us.fetch_add(elapsed_us, Ordering::SeqCst);

        if success {
            self.successful_operations.fetch_add(1, Ordering::SeqCst);
        } else {
            self.failed_operations.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Mean duration per operation in microseconds (integer division); 0 when empty.
    pub fn avg_time_us(&self) -> u64 {
        // Load the time sum before the count. Every duration included in `time` had its
        // count increment published first, so a concurrent `record` cannot add a duration
        // whose operation is missing from `total`.
        let time = self.total_time_us.load(Ordering::SeqCst);
        let total = self.total_operations.load(Ordering::SeqCst);

        if total == 0 { 0 } else { time / total }
    }

    /// Percentage (0.0–100.0) of recorded operations that succeeded; 0.0 when empty.
    pub fn success_rate(&self) -> f64 {
        // Load successes before the total. Under SeqCst, each counted success implies its
        // earlier `total_operations` increment is visible, so `success <= total` and the rate
        // cannot exceed 100% while `record` runs concurrently. The previous order (total
        // first) could report more than 100%.
        let success = self.successful_operations.load(Ordering::SeqCst);
        let total = self.total_operations.load(Ordering::SeqCst);

        if total == 0 {
            0.0
        } else {
            (success as f64 / total as f64) * 100.0
        }
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_progress_tracker() {
        let tracker = ProgressTracker::new(100, None);

        tracker.update(25);
        assert_eq!(tracker.processed(), 25);
        assert_eq!(tracker.percentage(), 25.0);

        tracker.update(50);
        assert_eq!(tracker.processed(), 75);
        assert_eq!(tracker.percentage(), 75.0);
    }

    #[test]
    fn test_progress_tracker_zero_total() {
        // A zero total must not divide by zero.
        let tracker = ProgressTracker::new(0, None);
        tracker.update(10);
        assert_eq!(tracker.percentage(), 0.0);
    }

    #[test]
    fn test_progress_tracker_callback() {
        let called = Arc::new(AtomicU64::new(0));
        let called_clone = Arc::clone(&called);

        let callback: ProgressCallback = Arc::new(move |processed, total| {
            called_clone.store(processed as u64, Ordering::SeqCst);
            assert_eq!(total, 100);
            true
        });

        let tracker = ProgressTracker::new(100, Some(callback));
        tracker.update(50);

        assert_eq!(called.load(Ordering::SeqCst), 50);
        assert!(!tracker.is_cancelled());
    }

    #[test]
    fn test_progress_tracker_cancelled_by_callback() {
        // Returning `false` from the callback latches the tracker into the cancelled state.
        let callback: ProgressCallback = Arc::new(|processed, _total| processed < 5);
        let tracker = ProgressTracker::new(10, Some(callback));

        tracker.update(3);
        assert!(!tracker.is_cancelled());
        tracker.update(3);
        assert!(tracker.is_cancelled());
    }

    #[test]
    fn test_progress_reader() {
        let data = b"Hello, World!";
        let tracker = Arc::new(ProgressTracker::new(data.len(), None));

        let mut reader = ProgressReader::new(Cursor::new(data), tracker.clone(), 1024);

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();

        assert_eq!(buf, data);
        assert_eq!(tracker.processed(), data.len() as u64);
    }

    #[test]
    fn test_progress_reader_stops_after_cancel() {
        // Read a single chunk, let the callback cancel, then confirm the next read fails
        // without consuming or counting more bytes.
        let cancel: ProgressCallback = Arc::new(|_, _| false);
        let tracker = Arc::new(ProgressTracker::new(4, Some(cancel)));
        let mut reader =
            ProgressReader::new(Cursor::new(b"data".to_vec()), Arc::clone(&tracker), 16);

        let mut buf = [0u8; 2];
        assert_eq!(reader.read(&mut buf).unwrap(), 2);
        assert!(reader.is_cancelled());
        assert!(reader.read(&mut buf).is_err());
        assert_eq!(tracker.processed(), 2);
    }

    #[test]
    fn test_stream_options_builder() {
        let opts = StreamOptions::new()
            .buffer_size(8)
            .timeout(3)
            .on_progress(|_, _| true);

        assert_eq!(opts.buffer_size, 8);
        assert_eq!(opts.timeout, Some(3));
        assert!(opts.progress_callback.is_some());
        assert_eq!(StreamOptions::default().buffer_size, 64 * 1024);
    }

    #[test]
    fn test_cache_eviction_and_hit_rate() {
        let cache: Cache<u32, &str> = Cache::new(2);
        cache.put(1, "one");
        cache.put(2, "two");
        cache.put(3, "three"); // capacity is 2, so key 1 (least recently used) is evicted

        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&3), Some("three"));
        assert_eq!(cache.hit_rate(), 50.0);

        cache.clear();
        assert_eq!(cache.hit_rate(), 0.0);
        assert_eq!(cache.get(&3), None);
    }

    #[test]
    fn test_cache_zero_capacity_holds_one_entry() {
        let cache = Cache::new(0);
        cache.put("a", 1);
        cache.put("b", 2);

        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some(2));
    }

    #[test]
    fn test_performance_stats() {
        let stats = PerformanceStats::default();
        assert_eq!(stats.avg_time_us(), 0);
        assert_eq!(stats.success_rate(), 0.0);

        stats.record(true, 1000);
        stats.record(true, 2000);
        stats.record(false, 500);

        assert_eq!(stats.total_operations.load(Ordering::SeqCst), 3);
        assert_eq!(stats.successful_operations.load(Ordering::SeqCst), 2);
        assert_eq!(stats.failed_operations.load(Ordering::SeqCst), 1);
        assert_eq!(stats.avg_time_us(), 1166); // (1000+2000+500)/3
        assert!((stats.success_rate() - 200.0 / 3.0).abs() < 1e-9);
    }
}