#![cfg(feature = "dataset")]

//! The dataset API that accesses multiple TFRecord files.
//!
//! This module is available when the `dataset` feature is enabled.
//! The [Dataset] type is constructed with the [DatasetInit] initializer.
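//!
//! A minimal usage sketch, not from the original docs: it assumes the crate
//! re-exports these types at its root and that `Vec<u8>` implements
//! [GenericRecord](crate::markers::GenericRecord) for raw record bytes.
//!
//! ```ignore
//! use futures::stream::TryStreamExt;
//! use tfrecord::{Dataset, DatasetInit};
//!
//! async_std::task::block_on(async {
//!     // index every file whose name starts with "train" under /path/to/data
//!     let dataset: Dataset = DatasetInit::default()
//!         .from_prefix("/path/to/data/train")
//!         .await?;
//!
//!     // stream the records back as raw bytes
//!     let mut stream = Box::pin(dataset.stream::<Vec<u8>>());
//!     while let Some(bytes) = stream.try_next().await? {
//!         println!("read a record of {} bytes", bytes.len());
//!     }
//!     Ok::<_, tfrecord::Error>(())
//! });
//! ```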

use crate::{error::Error, markers::GenericRecord};
use async_std::{
    fs::File,
    io::BufReader,
    path::{Path, PathBuf, MAIN_SEPARATOR},
};
use futures::{
    io::{AsyncReadExt, AsyncSeekExt},
    stream::{StreamExt, TryStream, TryStreamExt},
};
use std::{io::SeekFrom, mem, num::NonZeroUsize, sync::Arc};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

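/// The location of a single record: the file that contains it, plus the byte
/// offset and length of its payload.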
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RecordIndex {
    path: Arc<PathBuf>,
    offset: u64,
    len: usize,
}

/// The dataset initializer.
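///
/// A configuration sketch; the limit values below are illustrative, not
/// recommendations:
///
/// ```ignore
/// use std::num::NonZeroUsize;
/// use tfrecord::DatasetInit;
///
/// let init = DatasetInit {
///     // verify checksums while indexing
///     check_integrity: true,
///     // keep at most 64 files open at once (hypothetical limit)
///     max_open_files: NonZeroUsize::new(64),
///     // cap the indexing workers at 4 instead of the num_cpus default
///     max_workers: NonZeroUsize::new(4),
/// };
/// ```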
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DatasetInit {
    /// Whether to verify the checksum.
    pub check_integrity: bool,
    /// Maximum number of open files.
    ///
    /// The number of open files is limited if it is `Some(_)`,
    /// and is unlimited if it is `None`.
    pub max_open_files: Option<NonZeroUsize>,
    /// Maximum number of concurrent workers.
    ///
    /// If it is `None`, it defaults to [num_cpus::get].
    pub max_workers: Option<NonZeroUsize>,
}

impl Default for DatasetInit {
    fn default() -> Self {
        Self {
            check_integrity: true,
            max_open_files: None,
            max_workers: None,
        }
    }
}

impl DatasetInit {
    /// Open TFRecord files by a path prefix.
    ///
    /// If the prefix ends with the path separator, it searches for all files
    /// under that directory. Otherwise, it lists the files in the parent
    /// directory whose names start with the prefix's file-name component.
    /// The enumerated paths are sorted in alphabetical order.
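    ///
    /// A usage sketch with an illustrative prefix:
    ///
    /// ```ignore
    /// // matches e.g. /data/train-00000.tfrecord, /data/train-00001.tfrecord, ...
    /// let dataset = DatasetInit::default().from_prefix("/data/train-").await?;
    /// ```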
    pub async fn from_prefix(self, prefix: &str) -> Result<Dataset, Error> {
        // get parent dir and file name prefix
        let prefix_path: &Path = prefix.as_ref();

        // assume the prefix is a directory if it ends with the separator
        let (dir, file_name_prefix_opt) = if prefix.ends_with(MAIN_SEPARATOR) {
            (prefix_path, None)
        } else {
            let dir = prefix_path.parent().expect("please report bug");
            let file_name_prefix = prefix_path
                .file_name()
                .expect("please report bug")
                .to_str()
                .expect("please report bug");
            (dir, Some(file_name_prefix))
        };

        // filter paths
        let mut paths = dir
            .read_dir()
            .await?
            .map(|result| result.map_err(Error::from))
            .try_filter_map(|entry| async move {
                if !entry.metadata().await?.is_file() {
                    return Ok(None);
                }

                let path = entry.path();
                let file_name =
                    entry
                        .file_name()
                        .into_string()
                        .map_err(|_| Error::UnicodeError {
                            desc: format!(r#"the file path "{}" is not Unicode"#, path.display()),
                        })?;

                match file_name_prefix_opt {
                    Some(file_name_prefix) => {
                        if file_name.starts_with(file_name_prefix) {
                            Result::<_, Error>::Ok(Some(path))
                        } else {
                            Ok(None)
                        }
                    }
                    None => Ok(Some(path)),
                }
            })
            .try_collect::<Vec<_>>()
            .await?;

        // sort paths
        paths.sort();

        // construct dataset
        self.from_paths(&paths).await
    }

    /// Open TFRecord files by a set of paths.
    ///
    /// It assumes that every path points to a TFRecord file; otherwise it returns an error.
    /// The order of the paths determines the order of the record indexes.
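    ///
    /// A usage sketch with illustrative paths:
    ///
    /// ```ignore
    /// let paths = ["/data/part-0.tfrecord", "/data/part-1.tfrecord"];
    /// let dataset = DatasetInit::default().from_paths(&paths).await?;
    /// ```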
    pub async fn from_paths<P>(self, paths: &[P]) -> Result<Dataset, Error>
    where
        P: AsRef<Path>,
    {
        let Self {
            check_integrity,
            max_open_files,
            max_workers,
        } = self;

        let max_open_files = max_open_files.map(|num| num.get());
        let max_workers = max_workers
            .map(|num| num.get())
            .unwrap_or_else(num_cpus::get);
        let open_file_semaphore = max_open_files.map(|num| Arc::new(Semaphore::new(num)));

        // build record index
        let record_indexes = {
            // spawn indexing worker per path
            let future_iter = paths
                .iter()
                .map(|path| Arc::new(path.as_ref().to_owned()))
                .map(|path| {
                    let open_file_semaphore = open_file_semaphore.clone();

                    async move {
                        // acquire an open-file permit if the limit is set
                        let permit = match open_file_semaphore {
                            Some(semaphore) => {
                                // the semaphore is never closed, so acquire_owned() cannot fail
                                Some(Arc::new(semaphore.acquire_owned().await.unwrap()))
                            }
                            None => None,
                        };

                        let index_stream = {
                            // open index stream
                            let reader = BufReader::new(File::open(&*path).await?);
                            let stream = record_index_stream(reader, check_integrity);

                            // add path to index
                            let stream = stream.map_ok(move |(offset, len)| RecordIndex {
                                path: Arc::clone(&path),
                                offset,
                                len,
                            });

                            // attach the permit so the open-file slot stays
                            // held until this file's indexes are consumed
                            let stream =
                                stream.map_ok(move |index| (permit.clone(), index));

                            stream
                        };

                        Result::<_, Error>::Ok(index_stream)
                    }
                })
                .map(async_std::task::spawn);

            // limit workers by max_workers
            let future_stream = futures::stream::iter(future_iter).buffered(max_workers);

            // drop the per-index permit clones; the open-file slot is
            // released once a file's stream finishes and all clones are gone
            future_stream
                .try_flatten()
                .map_ok(|(permit, index)| {
                    mem::drop(permit);
                    index
                })
                .try_collect::<Vec<RecordIndex>>()
                .await?
        };

        let dataset = Dataset {
            state: Arc::new(DatasetState {
                record_indexes,
                max_workers,
                open_file_semaphore,
            }),
            open_file: None,
        };

        Ok(dataset)
    }
}

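/// The dataset state shared by every clone of a [Dataset] handle.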
#[derive(Debug)]
struct DatasetState {
    pub record_indexes: Vec<RecordIndex>,
    pub max_workers: usize,
    pub open_file_semaphore: Option<Arc<Semaphore>>,
}

/// The dataset type.
#[derive(Debug)]
pub struct Dataset {
    state: Arc<DatasetState>,
    open_file: Option<(PathBuf, BufReader<File>, Option<OwnedSemaphorePermit>)>,
}

impl Clone for Dataset {
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
            open_file: None,
        }
    }
}

impl Dataset {
    /// Get the number of indexed records.
    pub fn num_records(&self) -> usize {
        self.state.record_indexes.len()
    }

    /// Get a record by its index number.
    ///
    /// It returns `None` if the index is greater than or equal to [num_records](Dataset::num_records).
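    ///
    /// A random-access sketch, assuming `Vec<u8>` implements
    /// [GenericRecord](crate::markers::GenericRecord):
    ///
    /// ```ignore
    /// let third: Option<Vec<u8>> = dataset.get(2).await?;
    /// ```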
    pub async fn get<T>(&mut self, index: usize) -> Result<Option<T>, Error>
    where
        T: GenericRecord,
    {
        // try to get record index
        let record_index = match self.state.record_indexes.get(index) {
            Some(record_index) => record_index.to_owned(),
            None => return Ok(None),
        };
        let RecordIndex { offset, len, path } = record_index;

        let reader = self.open_file(&*path).await?;
        let bytes = try_read_record_at(reader, offset, len).await?;
        let record = T::from_bytes(bytes)?;
        Ok(Some(record))
    }

    /// Get the record stream.
    ///
    /// The stream clones the dataset, so it maintains its own file handle and
    /// reads independently of `self`.
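    ///
    /// A streaming sketch, assuming `Vec<u8>` implements
    /// [GenericRecord](crate::markers::GenericRecord):
    ///
    /// ```ignore
    /// use futures::stream::TryStreamExt;
    ///
    /// let mut stream = Box::pin(dataset.stream::<Vec<u8>>());
    /// while let Some(bytes) = stream.try_next().await? {
    ///     /* process the record */
    /// }
    /// ```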
    pub fn stream<T>(&self) -> impl TryStream<Ok = T, Error = Error> + Send
    where
        T: GenericRecord,
    {
        let dataset = self.clone();
        futures::stream::try_unfold((dataset, 0), |state| async move {
            let (mut dataset, index) = state;
            Ok(dataset.get::<T>(index).await?.map(|record| {
                let new_state = (dataset, index + 1);
                (record, new_state)
            }))
        })
    }

    async fn open_file<P>(&mut self, path: P) -> Result<&mut BufReader<File>, Error>
    where
        P: AsRef<Path>,
    {
        let path = path.as_ref();

        // reuse the opened file if the path matches; otherwise close it and
        // open the requested one
        match self.open_file.take() {
            Some((opened_path, reader, permit)) if opened_path == path => {
                self.open_file = Some((opened_path, reader, permit));
            }
            args => {
                // drop the previous reader and permit before acquiring a new permit
                mem::drop(args);
                let semaphore_opt = self.state.open_file_semaphore.clone();
                let permit = match semaphore_opt {
                    // the semaphore is never closed, so acquire_owned() cannot fail
                    Some(semaphore) => Some(semaphore.acquire_owned().await.unwrap()),
                    None => None,
                };
                let reader = BufReader::new(File::open(&path).await?);
                self.open_file = Some((path.to_owned(), reader, permit));
            }
        }

        Ok(&mut self.open_file.as_mut().unwrap().1)
    }
}

static_assertions::assert_impl_all!(Dataset: Send, Sync);

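/// Scans the reader from its current position and yields the `(offset, len)`
/// pair of every record payload, optionally verifying checksums.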
fn record_index_stream<R>(
    reader: R,
    check_integrity: bool,
) -> impl TryStream<Ok = (u64, usize), Error = Error>
where
    R: AsyncReadExt + AsyncSeekExt + Unpin,
{
    futures::stream::try_unfold((reader, check_integrity), |args| async move {
        let (mut reader, check_integrity) = args;

        let len = match crate::io::async_::try_read_len(&mut reader, check_integrity).await? {
            Some(len) => len,
            None => return Ok(None),
        };

        // the current position after reading the length is the payload offset
        let offset = reader.seek(SeekFrom::Current(0)).await?;
        crate::io::async_::try_read_record_data(&mut reader, len, check_integrity).await?;

        let index = (offset, len);
        let args = (reader, check_integrity);
        Result::<_, Error>::Ok(Some((index, args)))
    })
}

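/// Seeks to `offset` and reads a record payload of `len` bytes. The integrity
/// check is skipped here because it is (optionally) performed during indexing.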
async fn try_read_record_at<R>(reader: &mut R, offset: u64, len: usize) -> Result<Vec<u8>, Error>
where
    R: AsyncReadExt + AsyncSeekExt + Unpin,
{
    reader.seek(SeekFrom::Start(offset)).await?;
    let bytes = crate::io::async_::try_read_record_data(reader, len, false).await?;

    Ok(bytes)
}