lance_io/utils/
tracking_store.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Make assertions about IO operations to an [ObjectStore].
5//!
6//! When testing code that performs IO, you will often want to make assertions
7//! about the number of reads and writes performed, the amount of data read or
8//! written, and the number of disjoint periods where at least one IO is in-flight.
9//!
//! This module provides [`IOTracker`], which can be used to wrap any object store.
11use std::fmt::{Display, Formatter};
12use std::ops::Range;
13#[cfg(feature = "test-util")]
14use std::sync::atomic::AtomicU16;
15use std::sync::{Arc, Mutex};
16
17use bytes::Bytes;
18use futures::stream::BoxStream;
19use object_store::path::Path;
20use object_store::{
21    GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
22    PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
23};
24
25use crate::object_store::WrappingObjectStore;
26
/// Handle for collecting IO statistics from the object stores it wraps.
///
/// Cloning an `IOTracker` clones the inner `Arc`, so every clone (and every
/// store produced by [`WrappingObjectStore::wrap`]) reports into the same
/// [`IoStats`]. Read the totals with [`IOTracker::stats`] or drain them with
/// [`IOTracker::incremental_stats`].
#[derive(Debug, Default, Clone)]
pub struct IOTracker(Arc<Mutex<IoStats>>);
29
30impl IOTracker {
31    /// Get IO statistics and reset the counters (incremental pattern).
32    ///
33    /// This returns the accumulated statistics since the last call and resets
34    /// the internal counters to zero.
35    pub fn incremental_stats(&self) -> IoStats {
36        std::mem::take(&mut *self.0.lock().unwrap())
37    }
38
39    /// Get a snapshot of current IO statistics without resetting counters.
40    ///
41    /// This returns a clone of the current statistics without modifying the
42    /// internal state. Use this when you need to check stats without resetting.
43    pub fn stats(&self) -> IoStats {
44        self.0.lock().unwrap().clone()
45    }
46
47    /// Record a read operation for tracking.
48    ///
49    /// This is used by readers that bypass the ObjectStore layer (like LocalObjectReader)
50    /// to ensure their IO operations are still tracked.
51    pub fn record_read(
52        &self,
53        #[allow(unused_variables)] method: &'static str,
54        #[allow(unused_variables)] path: Path,
55        num_bytes: u64,
56        #[allow(unused_variables)] range: Option<Range<u64>>,
57    ) {
58        let mut stats = self.0.lock().unwrap();
59        stats.read_iops += 1;
60        stats.read_bytes += num_bytes;
61        #[cfg(feature = "test-util")]
62        stats.requests.push(IoRequestRecord {
63            method,
64            path,
65            range,
66        });
67    }
68}
69
70impl WrappingObjectStore for IOTracker {
71    fn wrap(&self, _store_prefix: &str, target: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
72        Arc::new(IoTrackingStore::new(target, self.0.clone()))
73    }
74}
75
/// IO counters accumulated by an [`IOTracker`].
///
/// [`IoTrackingStore`] records `get*`, `head` and `list*` calls as reads and
/// `put*`, multipart parts, `delete`, `copy*` and `rename*` calls as writes;
/// [`IOTracker::record_read`] lets readers outside the store add reads too.
#[derive(Debug, Default, Clone)]
pub struct IoStats {
    /// Number of recorded read operations.
    pub read_iops: u64,
    /// Total bytes returned by reads (`head` and `list*` calls add 0).
    pub read_bytes: u64,
    /// Number of recorded write operations; each multipart part counts once.
    pub write_iops: u64,
    /// Total payload bytes written (`delete`, `copy*` and `rename*` add 0).
    pub written_bytes: u64,
    // This is only really meaningful in tests where there isn't any concurrent IO.
    #[cfg(feature = "test-util")]
    /// Number of disjoint periods where at least one IO is in-flight.
    pub num_stages: u64,
    #[cfg(feature = "test-util")]
    /// Every recorded operation, in the order it was recorded.
    pub requests: Vec<IoRequestRecord>,
}
89
/// Assert that one [`IoStats`]-style counter equals an expected value.
///
/// On failure, the panic message names the field and shows the expected and
/// actual values. It also dumps every recorded request to help diagnose
/// unexpected IO. Optional trailing format arguments are appended to the
/// message.
///
/// `$io_stats` and `$expected` are each evaluated exactly once. This makes it
/// safe to pass a side-effecting expression such as
/// `tracker.incremental_stats()`, which resets the counters on every call.
///
/// ```ignore
/// assert_io_eq!(io_stats, read_iops, 1);
/// assert_io_eq!(io_stats, write_iops, 0, "should be no writes");
/// assert_io_eq!(io_stats, num_stages, 1, "should be just {}", "one stage");
/// ```
#[cfg(feature = "test-util")]
#[macro_export]
macro_rules! assert_io_eq {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once: the failure message must describe the same stats that
        // were compared, not a re-evaluated (possibly reset) copy.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert_eq!(
            io_stats.$field, expected,
            "Expected {} to be {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        let io_stats = &$io_stats;
        let expected = $expected;
        assert_eq!(
            io_stats.$field, expected,
            "Expected {} to be {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
119
/// Assert that one [`IoStats`]-style counter is strictly greater than a value.
///
/// On failure, the message names the field, shows both values and dumps the
/// recorded requests. Optional trailing format arguments are appended.
/// `$io_stats` and `$expected` are each evaluated exactly once.
///
/// ```ignore
/// assert_io_gt!(io_stats, read_iops, 0, "expected at least one read");
/// ```
#[cfg(feature = "test-util")]
#[macro_export]
macro_rules! assert_io_gt {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once so the failure message reflects the compared values.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field > expected,
            "Expected {} to be > {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field > expected,
            "Expected {} to be > {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
145
/// Assert that one [`IoStats`]-style counter is strictly less than a value.
///
/// On failure, the message names the field, shows both values and dumps the
/// recorded requests. Optional trailing format arguments are appended.
/// `$io_stats` and `$expected` are each evaluated exactly once.
///
/// ```ignore
/// assert_io_lt!(io_stats, read_bytes, 4096, "read too much");
/// ```
#[cfg(feature = "test-util")]
#[macro_export]
macro_rules! assert_io_lt {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once so the failure message reflects the compared values.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field < expected,
            "Expected {} to be < {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field < expected,
            "Expected {} to be < {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
171
// These fields are "dead code" because we just use them right now to display
// in test failure messages through Debug. (The lint ignores Debug impls.)
/// One recorded IO operation.
///
/// Records are only pushed into [`IoStats`] when the `test-util` feature is
/// enabled, where they are printed by the `assert_io_*` macros on failure.
#[allow(dead_code)]
#[derive(Clone)]
pub struct IoRequestRecord {
    /// Name of the method that issued the request, e.g. `"get_range"` or `"put_part"`.
    pub method: &'static str,
    /// Object the request targeted. For `list*` calls this is the prefix (empty
    /// when none was given); for `copy*`/`rename*` it is the source path.
    pub path: Path,
    /// Byte range of the request, when one was recorded.
    pub range: Option<Range<u64>>,
}
181
182impl std::fmt::Debug for IoRequestRecord {
183    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184        // For example: "put /path/to/file range: 0-100"
185        write!(
186            f,
187            "IORequest(method={}, path=\"{}\"",
188            self.method, self.path
189        )?;
190        if let Some(range) = &self.range {
191            write!(f, ", range={:?}", range)?;
192        }
193        write!(f, ")")?;
194        Ok(())
195    }
196}
197
198impl Display for IoStats {
199    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
200        write!(f, "{:#?}", self)
201    }
202}
203
/// [`ObjectStore`] wrapper that forwards calls to `target` while recording IO
/// counters into the shared `stats`.
#[derive(Debug)]
pub struct IoTrackingStore {
    /// The wrapped store that performs the actual IO.
    target: Arc<dyn ObjectStore>,
    /// Counters shared with whoever created this store (e.g. an [`IOTracker`]).
    stats: Arc<Mutex<IoStats>>,
    /// Number of requests currently in flight through this store; used by the
    /// stage guards to count `num_stages`.
    #[cfg(feature = "test-util")]
    active_requests: Arc<AtomicU16>,
}
211
212impl Display for IoTrackingStore {
213    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214        write!(f, "{:#?}", self)
215    }
216}
217
218impl IoTrackingStore {
219    pub fn new(target: Arc<dyn ObjectStore>, stats: Arc<Mutex<IoStats>>) -> Self {
220        Self {
221            target,
222            stats,
223            #[cfg(feature = "test-util")]
224            active_requests: Arc::new(AtomicU16::new(0)),
225        }
226    }
227
228    fn record_read(
229        &self,
230        method: &'static str,
231        path: Path,
232        num_bytes: u64,
233        range: Option<Range<u64>>,
234    ) {
235        let mut stats = self.stats.lock().unwrap();
236        stats.read_iops += 1;
237        stats.read_bytes += num_bytes;
238        #[cfg(feature = "test-util")]
239        stats.requests.push(IoRequestRecord {
240            method,
241            path,
242            range,
243        });
244        #[cfg(not(feature = "test-util"))]
245        let _ = (method, path, range); // Suppress unused variable warnings
246    }
247
248    fn record_write(&self, method: &'static str, path: Path, num_bytes: u64) {
249        let mut stats = self.stats.lock().unwrap();
250        stats.write_iops += 1;
251        stats.written_bytes += num_bytes;
252        #[cfg(feature = "test-util")]
253        stats.requests.push(IoRequestRecord {
254            method,
255            path,
256            range: None,
257        });
258        #[cfg(not(feature = "test-util"))]
259        let _ = (method, path); // Suppress unused variable warnings
260    }
261
262    #[cfg(feature = "test-util")]
263    fn stage_guard(&self) -> StageGuard {
264        StageGuard::new(self.active_requests.clone(), self.stats.clone())
265    }
266
267    #[cfg(not(feature = "test-util"))]
268    fn stage_guard(&self) -> StageGuard {
269        StageGuard
270    }
271}
272
273#[async_trait::async_trait]
274#[deny(clippy::missing_trait_methods)]
275impl ObjectStore for IoTrackingStore {
276    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
277        let _guard = self.stage_guard();
278        self.record_write("put", location.to_owned(), bytes.content_length() as u64);
279        self.target.put(location, bytes).await
280    }
281
282    async fn put_opts(
283        &self,
284        location: &Path,
285        bytes: PutPayload,
286        opts: PutOptions,
287    ) -> OSResult<PutResult> {
288        let _guard = self.stage_guard();
289        self.record_write(
290            "put_opts",
291            location.to_owned(),
292            bytes.content_length() as u64,
293        );
294        self.target.put_opts(location, bytes, opts).await
295    }
296
297    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
298        let _guard = self.stage_guard();
299        let target = self.target.put_multipart(location).await?;
300        Ok(Box::new(IoTrackingMultipartUpload {
301            target,
302            stats: self.stats.clone(),
303            #[cfg(feature = "test-util")]
304            path: location.to_owned(),
305            #[cfg(feature = "test-util")]
306            _guard,
307        }))
308    }
309
310    async fn put_multipart_opts(
311        &self,
312        location: &Path,
313        opts: PutMultipartOptions,
314    ) -> OSResult<Box<dyn MultipartUpload>> {
315        let _guard = self.stage_guard();
316        let target = self.target.put_multipart_opts(location, opts).await?;
317        Ok(Box::new(IoTrackingMultipartUpload {
318            target,
319            stats: self.stats.clone(),
320            #[cfg(feature = "test-util")]
321            path: location.to_owned(),
322            #[cfg(feature = "test-util")]
323            _guard,
324        }))
325    }
326
327    async fn get(&self, location: &Path) -> OSResult<GetResult> {
328        let _guard = self.stage_guard();
329        let result = self.target.get(location).await;
330        if let Ok(result) = &result {
331            let num_bytes = result.range.end - result.range.start;
332            self.record_read("get", location.to_owned(), num_bytes, None);
333        }
334        result
335    }
336
337    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
338        let _guard = self.stage_guard();
339        let range = match &options.range {
340            Some(GetRange::Bounded(range)) => Some(range.clone()),
341            _ => None, // TODO: fill in other options.
342        };
343        let result = self.target.get_opts(location, options).await;
344        if let Ok(result) = &result {
345            let num_bytes = result.range.end - result.range.start;
346
347            self.record_read("get_opts", location.to_owned(), num_bytes, range);
348        }
349        result
350    }
351
352    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
353        let _guard = self.stage_guard();
354        let result = self.target.get_range(location, range.clone()).await;
355        if let Ok(result) = &result {
356            self.record_read(
357                "get_range",
358                location.to_owned(),
359                result.len() as u64,
360                Some(range),
361            );
362        }
363        result
364    }
365
366    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
367        let _guard = self.stage_guard();
368        let result = self.target.get_ranges(location, ranges).await;
369        if let Ok(result) = &result {
370            self.record_read(
371                "get_ranges",
372                location.to_owned(),
373                result.iter().map(|b| b.len() as u64).sum(),
374                None,
375            );
376        }
377        result
378    }
379
380    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
381        let _guard = self.stage_guard();
382        self.record_read("head", location.to_owned(), 0, None);
383        self.target.head(location).await
384    }
385
386    async fn delete(&self, location: &Path) -> OSResult<()> {
387        let _guard = self.stage_guard();
388        self.record_write("delete", location.to_owned(), 0);
389        self.target.delete(location).await
390    }
391
392    fn delete_stream<'a>(
393        &'a self,
394        locations: BoxStream<'a, OSResult<Path>>,
395    ) -> BoxStream<'a, OSResult<Path>> {
396        self.target.delete_stream(locations)
397    }
398
399    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
400        let _guard = self.stage_guard();
401        self.record_read("list", prefix.cloned().unwrap_or_default(), 0, None);
402        self.target.list(prefix)
403    }
404
405    fn list_with_offset(
406        &self,
407        prefix: Option<&Path>,
408        offset: &Path,
409    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
410        self.record_read(
411            "list_with_offset",
412            prefix.cloned().unwrap_or_default(),
413            0,
414            None,
415        );
416        self.target.list_with_offset(prefix, offset)
417    }
418
419    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
420        let _guard = self.stage_guard();
421        self.record_read(
422            "list_with_delimiter",
423            prefix.cloned().unwrap_or_default(),
424            0,
425            None,
426        );
427        self.target.list_with_delimiter(prefix).await
428    }
429
430    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
431        let _guard = self.stage_guard();
432        self.record_write("copy", from.to_owned(), 0);
433        self.target.copy(from, to).await
434    }
435
436    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
437        let _guard = self.stage_guard();
438        self.record_write("rename", from.to_owned(), 0);
439        self.target.rename(from, to).await
440    }
441
442    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
443        let _guard = self.stage_guard();
444        self.record_write("rename_if_not_exists", from.to_owned(), 0);
445        self.target.rename_if_not_exists(from, to).await
446    }
447
448    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
449        let _guard = self.stage_guard();
450        self.record_write("copy_if_not_exists", from.to_owned(), 0);
451        self.target.copy_if_not_exists(from, to).await
452    }
453}
454
/// [`MultipartUpload`] wrapper that records each uploaded part as a write.
#[derive(Debug)]
struct IoTrackingMultipartUpload {
    /// The wrapped upload that performs the actual IO.
    target: Box<dyn MultipartUpload>,
    /// Object being uploaded; only used to label request records.
    #[cfg(feature = "test-util")]
    path: Path,
    /// Counters shared with the store that created this upload.
    stats: Arc<Mutex<IoStats>>,
    /// Keeps a request counted as in flight for the lifetime of the upload, so
    /// the whole upload falls inside a single stage.
    #[cfg(feature = "test-util")]
    _guard: StageGuard,
}
464
465#[async_trait::async_trait]
466impl MultipartUpload for IoTrackingMultipartUpload {
467    async fn abort(&mut self) -> OSResult<()> {
468        self.target.abort().await
469    }
470
471    async fn complete(&mut self) -> OSResult<PutResult> {
472        self.target.complete().await
473    }
474
475    fn put_part(&mut self, payload: PutPayload) -> UploadPart {
476        {
477            let mut stats = self.stats.lock().unwrap();
478            stats.write_iops += 1;
479            stats.written_bytes += payload.content_length() as u64;
480            #[cfg(feature = "test-util")]
481            stats.requests.push(IoRequestRecord {
482                method: "put_part",
483                path: self.path.to_owned(),
484                range: None,
485            });
486        }
487        self.target.put_part(payload)
488    }
489}
490
/// Guard marking one request as in flight. It is created by
/// [`StageGuard::new`] and, when the last outstanding guard for a store is
/// dropped, increments `IoStats::num_stages`.
#[cfg(feature = "test-util")]
#[derive(Debug)]
struct StageGuard {
    /// In-flight request counter shared by all guards of one store.
    active_requests: Arc<AtomicU16>,
    /// Counters that receive the `num_stages` increment.
    stats: Arc<Mutex<IoStats>>,
}

/// No-op stand-in when the `test-util` feature is disabled, so call sites can
/// create a guard unconditionally.
#[cfg(not(feature = "test-util"))]
struct StageGuard;
500
#[cfg(feature = "test-util")]
impl StageGuard {
    /// Register one more in-flight request and return the guard that will
    /// unregister it on drop.
    fn new(active_requests: Arc<AtomicU16>, stats: Arc<Mutex<IoStats>>) -> Self {
        let guard = Self {
            active_requests,
            stats,
        };
        guard
            .active_requests
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        guard
    }
}
511
#[cfg(feature = "test-util")]
impl Drop for StageGuard {
    /// Unregister this request. If it was the only one still in flight, the
    /// current stage has just ended, so `num_stages` is incremented.
    fn drop(&mut self) {
        let in_flight_before = self
            .active_requests
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
        let closes_stage = in_flight_before == 1;
        if closes_stage {
            self.stats.lock().unwrap().num_stages += 1;
        }
    }
}