lance_io/utils/
tracking_store.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Make assertions about IO operations to an [ObjectStore].
5//!
6//! When testing code that performs IO, you will often want to make assertions
7//! about the number of reads and writes performed, the amount of data read or
8//! written, and the number of disjoint periods where at least one IO is in-flight.
9//!
//! This module provides [`IOTracker`], which can be used to wrap any object store.
11use std::fmt::{Display, Formatter};
12use std::ops::Range;
13use std::sync::{atomic::AtomicU16, Arc, Mutex};
14
15use bytes::Bytes;
16use futures::stream::BoxStream;
17use object_store::path::Path;
18use object_store::{
19    GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
20    PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
21};
22
23use crate::object_store::WrappingObjectStore;
24
25#[derive(Debug, Default, Clone)]
26pub struct IOTracker(Arc<Mutex<IoStats>>);
27
28impl IOTracker {
29    pub fn incremental_stats(&self) -> IoStats {
30        std::mem::take(&mut *self.0.lock().unwrap())
31    }
32}
33
34impl WrappingObjectStore for IOTracker {
35    fn wrap(
36        &self,
37        target: Arc<dyn ObjectStore>,
38        _storage_options: Option<&std::collections::HashMap<String, String>>,
39    ) -> Arc<dyn ObjectStore> {
40        Arc::new(IoTrackingStore::new(target, self.0.clone()))
41    }
42}
43
44#[derive(Debug, Default)]
45pub struct IoStats {
46    pub read_iops: u64,
47    pub read_bytes: u64,
48    pub write_iops: u64,
49    pub write_bytes: u64,
50    /// Number of disjoint periods where at least one IO is in-flight.
51    pub num_hops: u64,
52    pub requests: Vec<IoRequestRecord>,
53}
54
/// Assert that a numeric field of an `IoStats` value equals an expected value.
///
/// On failure the message includes the field name, the expected and actual
/// values, and the recorded requests. An optional trailing format string and
/// arguments are appended to the message.
///
/// Both `$io_stats` and `$expected` are evaluated exactly once, so it is safe
/// to pass an expression with side effects (for example a call that drains
/// the statistics): the failure message describes the same value that was
/// compared.
///
/// ```ignore
/// assert_io_eq!(io_stats, read_iops, 1);
/// assert_io_eq!(io_stats, write_iops, 0, "should be no writes");
/// assert_io_eq!(io_stats, num_hops, 1, "should be just {}", "one hop");
/// ```
#[macro_export]
macro_rules! assert_io_eq {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert_eq!(
            io_stats.$field, expected,
            "Expected {} to be {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert_eq!(
            io_stats.$field, expected,
            "Expected {} to be {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
83
/// Assert that a numeric field of an `IoStats` value is strictly greater than
/// an expected value.
///
/// On failure the message includes the field name, the bound and the actual
/// value, and the recorded requests. An optional trailing format string and
/// arguments are appended to the message.
///
/// Both `$io_stats` and `$expected` are evaluated exactly once, so the
/// failure message describes the same value that was compared.
///
/// ```ignore
/// assert_io_gt!(io_stats, read_bytes, 0);
/// assert_io_gt!(io_stats, read_iops, 1, "expected more than {} read", 1);
/// ```
#[macro_export]
macro_rules! assert_io_gt {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field > expected,
            "Expected {} to be > {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field > expected,
            "Expected {} to be > {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
108
/// Assert that a numeric field of an `IoStats` value is strictly less than an
/// expected value.
///
/// On failure the message includes the field name, the bound and the actual
/// value, and the recorded requests. An optional trailing format string and
/// arguments are appended to the message.
///
/// Both `$io_stats` and `$expected` are evaluated exactly once, so the
/// failure message describes the same value that was compared.
///
/// ```ignore
/// assert_io_lt!(io_stats, read_iops, 10);
/// assert_io_lt!(io_stats, write_bytes, 1024, "wrote too much to {}", "disk");
/// ```
#[macro_export]
macro_rules! assert_io_lt {
    ($io_stats:expr, $field:ident, $expected:expr) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field < expected,
            "Expected {} to be < {}, got {}. Requests: {:#?}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests
        );
    }};
    ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => {{
        // Bind once: the arguments would otherwise be re-evaluated while
        // formatting the failure message.
        let io_stats = &$io_stats;
        let expected = $expected;
        assert!(
            io_stats.$field < expected,
            "Expected {} to be < {}, got {}. Requests: {:#?} {}",
            stringify!($field),
            expected,
            io_stats.$field,
            io_stats.requests,
            format_args!($($arg)+)
        );
    }};
}
133
134// These fields are "dead code" because we just use them right now to display
135// in test failure messages through Debug. (The lint ignores Debug impls.)
136#[allow(dead_code)]
137#[derive(Clone)]
138pub struct IoRequestRecord {
139    pub method: &'static str,
140    pub path: Path,
141    pub range: Option<Range<u64>>,
142}
143
144impl std::fmt::Debug for IoRequestRecord {
145    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
146        // For example: "put /path/to/file range: 0-100"
147        write!(
148            f,
149            "IORequest(method={}, path=\"{}\"",
150            self.method, self.path
151        )?;
152        if let Some(range) = &self.range {
153            write!(f, ", range={:?}", range)?;
154        }
155        write!(f, ")")?;
156        Ok(())
157    }
158}
159
160impl Display for IoStats {
161    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
162        write!(f, "{:#?}", self)
163    }
164}
165
166#[derive(Debug)]
167pub struct IoTrackingStore {
168    target: Arc<dyn ObjectStore>,
169    stats: Arc<Mutex<IoStats>>,
170    active_requests: Arc<AtomicU16>,
171}
172
173impl Display for IoTrackingStore {
174    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
175        write!(f, "{:#?}", self)
176    }
177}
178
179impl IoTrackingStore {
180    fn new(target: Arc<dyn ObjectStore>, stats: Arc<Mutex<IoStats>>) -> Self {
181        Self {
182            target,
183            stats,
184            active_requests: Arc::new(AtomicU16::new(0)),
185        }
186    }
187
188    fn record_read(
189        &self,
190        method: &'static str,
191        path: Path,
192        num_bytes: u64,
193        range: Option<Range<u64>>,
194    ) {
195        let mut stats = self.stats.lock().unwrap();
196        stats.read_iops += 1;
197        stats.read_bytes += num_bytes;
198        stats.requests.push(IoRequestRecord {
199            method,
200            path,
201            range,
202        });
203    }
204
205    fn record_write(&self, method: &'static str, path: Path, num_bytes: u64) {
206        let mut stats = self.stats.lock().unwrap();
207        stats.write_iops += 1;
208        stats.write_bytes += num_bytes;
209        stats.requests.push(IoRequestRecord {
210            method,
211            path,
212            range: None,
213        });
214    }
215
216    fn hop_guard(&self) -> HopGuard {
217        HopGuard::new(self.active_requests.clone(), self.stats.clone())
218    }
219}
220
221#[async_trait::async_trait]
222#[deny(clippy::missing_trait_methods)]
223impl ObjectStore for IoTrackingStore {
224    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
225        let _guard = self.hop_guard();
226        self.record_write("put", location.to_owned(), bytes.content_length() as u64);
227        self.target.put(location, bytes).await
228    }
229
230    async fn put_opts(
231        &self,
232        location: &Path,
233        bytes: PutPayload,
234        opts: PutOptions,
235    ) -> OSResult<PutResult> {
236        let _guard = self.hop_guard();
237        self.record_write(
238            "put_opts",
239            location.to_owned(),
240            bytes.content_length() as u64,
241        );
242        self.target.put_opts(location, bytes, opts).await
243    }
244
245    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
246        let _guard = self.hop_guard();
247        let target = self.target.put_multipart(location).await?;
248        Ok(Box::new(IoTrackingMultipartUpload {
249            target,
250            stats: self.stats.clone(),
251            path: location.to_owned(),
252            _guard,
253        }))
254    }
255
256    async fn put_multipart_opts(
257        &self,
258        location: &Path,
259        opts: PutMultipartOptions,
260    ) -> OSResult<Box<dyn MultipartUpload>> {
261        let _guard = self.hop_guard();
262        let target = self.target.put_multipart_opts(location, opts).await?;
263        Ok(Box::new(IoTrackingMultipartUpload {
264            target,
265            stats: self.stats.clone(),
266            path: location.to_owned(),
267            _guard,
268        }))
269    }
270
271    async fn get(&self, location: &Path) -> OSResult<GetResult> {
272        let _guard = self.hop_guard();
273        let result = self.target.get(location).await;
274        if let Ok(result) = &result {
275            let num_bytes = result.range.end - result.range.start;
276            self.record_read("get", location.to_owned(), num_bytes, None);
277        }
278        result
279    }
280
281    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
282        let _guard = self.hop_guard();
283        let range = match &options.range {
284            Some(GetRange::Bounded(range)) => Some(range.clone()),
285            _ => None, // TODO: fill in other options.
286        };
287        let result = self.target.get_opts(location, options).await;
288        if let Ok(result) = &result {
289            let num_bytes = result.range.end - result.range.start;
290
291            self.record_read("get_opts", location.to_owned(), num_bytes, range);
292        }
293        result
294    }
295
296    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
297        let _guard = self.hop_guard();
298        let result = self.target.get_range(location, range.clone()).await;
299        if let Ok(result) = &result {
300            self.record_read(
301                "get_range",
302                location.to_owned(),
303                result.len() as u64,
304                Some(range),
305            );
306        }
307        result
308    }
309
310    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
311        let _guard = self.hop_guard();
312        let result = self.target.get_ranges(location, ranges).await;
313        if let Ok(result) = &result {
314            self.record_read(
315                "get_ranges",
316                location.to_owned(),
317                result.iter().map(|b| b.len() as u64).sum(),
318                None,
319            );
320        }
321        result
322    }
323
324    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
325        let _guard = self.hop_guard();
326        self.record_read("head", location.to_owned(), 0, None);
327        self.target.head(location).await
328    }
329
330    async fn delete(&self, location: &Path) -> OSResult<()> {
331        let _guard = self.hop_guard();
332        self.record_write("delete", location.to_owned(), 0);
333        self.target.delete(location).await
334    }
335
336    fn delete_stream<'a>(
337        &'a self,
338        locations: BoxStream<'a, OSResult<Path>>,
339    ) -> BoxStream<'a, OSResult<Path>> {
340        self.target.delete_stream(locations)
341    }
342
343    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
344        let _guard = self.hop_guard();
345        self.record_read("list", prefix.cloned().unwrap_or_default(), 0, None);
346        self.target.list(prefix)
347    }
348
349    fn list_with_offset(
350        &self,
351        prefix: Option<&Path>,
352        offset: &Path,
353    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
354        self.record_read(
355            "list_with_offset",
356            prefix.cloned().unwrap_or_default(),
357            0,
358            None,
359        );
360        self.target.list_with_offset(prefix, offset)
361    }
362
363    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
364        let _guard = self.hop_guard();
365        self.record_read(
366            "list_with_delimiter",
367            prefix.cloned().unwrap_or_default(),
368            0,
369            None,
370        );
371        self.target.list_with_delimiter(prefix).await
372    }
373
374    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
375        let _guard = self.hop_guard();
376        self.record_write("copy", from.to_owned(), 0);
377        self.target.copy(from, to).await
378    }
379
380    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
381        let _guard = self.hop_guard();
382        self.record_write("rename", from.to_owned(), 0);
383        self.target.rename(from, to).await
384    }
385
386    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
387        let _guard = self.hop_guard();
388        self.record_write("rename_if_not_exists", from.to_owned(), 0);
389        self.target.rename_if_not_exists(from, to).await
390    }
391
392    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
393        let _guard = self.hop_guard();
394        self.record_write("copy_if_not_exists", from.to_owned(), 0);
395        self.target.copy_if_not_exists(from, to).await
396    }
397}
398
399#[derive(Debug)]
400struct IoTrackingMultipartUpload {
401    target: Box<dyn MultipartUpload>,
402    path: Path,
403    stats: Arc<Mutex<IoStats>>,
404    _guard: HopGuard,
405}
406
407#[async_trait::async_trait]
408impl MultipartUpload for IoTrackingMultipartUpload {
409    async fn abort(&mut self) -> OSResult<()> {
410        self.target.abort().await
411    }
412
413    async fn complete(&mut self) -> OSResult<PutResult> {
414        self.target.complete().await
415    }
416
417    fn put_part(&mut self, payload: PutPayload) -> UploadPart {
418        {
419            let mut stats = self.stats.lock().unwrap();
420            stats.write_iops += 1;
421            stats.write_bytes += payload.content_length() as u64;
422            stats.requests.push(IoRequestRecord {
423                method: "put_part",
424                path: self.path.to_owned(),
425                range: None,
426            });
427        }
428        self.target.put_part(payload)
429    }
430}
431
432#[derive(Debug)]
433struct HopGuard {
434    active_requests: Arc<AtomicU16>,
435    stats: Arc<Mutex<IoStats>>,
436}
437
438impl HopGuard {
439    fn new(active_requests: Arc<AtomicU16>, stats: Arc<Mutex<IoStats>>) -> Self {
440        active_requests.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
441        Self {
442            active_requests,
443            stats,
444        }
445    }
446}
447
448impl Drop for HopGuard {
449    fn drop(&mut self) {
450        if self
451            .active_requests
452            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
453            == 1
454        {
455            let mut stats = self.stats.lock().unwrap();
456            stats.num_hops += 1;
457        }
458    }
459}