rten 0.24.0

Machine learning runtime
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
//! Functions for loading tensor data stored externally to the main model file.
//!
//! This is used for ONNX models when `TensorProto`s reference external data
//! files.

use std::cell::RefCell;
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::ops::Range;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;

#[cfg(feature = "mmap")]
use memmap2::Mmap;

use super::load_error::{LoadError, LoadErrorImpl};
use crate::constant_storage::ConstantStorage;

/// Specifies the location of tensor data which is stored externally from the
/// main model file.
///
/// A location names a file (`path`) and a byte range within it (`offset`,
/// `length`). It is the argument passed to [`DataLoader::load`].
#[derive(Clone, Debug)]
pub struct DataLocation {
    /// Name of the external data file.
    ///
    /// The loaders in this module reject any path that fails the
    /// allowed-path check (a bare file name with a recognized extension).
    pub path: String,

    /// Offset of the start of the tensor data in bytes.
    pub offset: u64,

    /// Length of the tensor data in bytes.
    pub length: u64,
}

/// A slice of a shared buffer, where the slice contains the data for one
/// tensor.
///
/// The buffer is reference-counted, so several `DataSlice`s may refer to
/// different ranges of the same underlying storage.
#[derive(Debug)]
pub struct DataSlice {
    /// The shared buffer. This may contain data for one or multiple tensors
    /// and may be a `Vec<u8>`, a memory-mapped file or static slice.
    pub storage: Arc<ConstantStorage>,

    /// The range of bytes within `storage` that contain the tensor data.
    ///
    /// Must lie within `storage`; [`DataSlice::data`] indexes the storage
    /// with this range.
    pub bytes: Range<usize>,
}

impl DataSlice {
    /// Return the bytes of `storage` selected by `bytes`.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is not a valid range within the storage.
    pub fn data(&self) -> &[u8] {
        let (start, end) = (self.bytes.start, self.bytes.end);
        let whole = self.storage.data();
        &whole[start..end]
    }
}

/// Errors reading tensor data from an external file.
///
/// Returned by [`DataLoader::load`] and by the loader constructors in this
/// module.
#[derive(Debug)]
pub enum ExternalDataError {
    /// An IO error occurred when accessing the external file.
    IoError(std::io::Error),

    /// The length of the external data is too large.
    InvalidLength,

    /// An invalid path was specified.
    InvalidPath(PathBuf),

    /// External data is not supported in the current environment.
    NotSupported,

    /// The length of the external data file is too short for the offset and
    /// length of the external data.
    TooShort {
        /// Minimum length the file would need to be in bytes.
        required_len: usize,
        /// Actual length of the file in bytes.
        actual_len: usize,
    },

    /// The external data file path is disallowed.
    ///
    /// Produced when a data file path fails the allowed-path check applied
    /// by the loaders before any file is opened.
    DisallowedPath(PathBuf),
}

/// Human-readable messages for each error variant.
impl std::fmt::Display for ExternalDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without payload are plain literals.
            Self::InvalidLength => f.write_str("invalid data length"),
            Self::NotSupported => f.write_str("external data not supported"),

            // Variants that embed their payload in the message.
            Self::IoError(err) => write!(f, "io error: {err}"),
            Self::InvalidPath(path) => write!(f, "invalid path \"{}\"", path.display()),
            Self::DisallowedPath(path) => write!(f, "disallowed path \"{}\"", path.display()),
            Self::TooShort {
                required_len,
                actual_len,
            } => write!(
                f,
                "file too short. required {required_len} actual {actual_len}"
            ),
        }
    }
}

// Uses the default `Error` methods; the wrapped IO error (if any) is already
// included in the `Display` output rather than exposed via `source()`.
impl std::error::Error for ExternalDataError {}

impl From<std::io::Error> for ExternalDataError {
    fn from(val: std::io::Error) -> Self {
        Self::IoError(val)
    }
}

/// Convert an external data error into the model-level [`LoadError`].
impl From<ExternalDataError> for LoadError {
    fn from(err: ExternalDataError) -> Self {
        // The error is boxed before being stored in the `LoadErrorImpl` variant.
        let wrapped = LoadErrorImpl::ExternalDataError(Box::new(err));
        wrapped.into()
    }
}

/// Trait for loading data from an external file.
///
/// Implementations in this module read via standard file IO
/// ([`FileLoader`]), memory mapping (`MmapLoader`, with the `mmap` feature)
/// or in-memory buffers ([`MemLoader`]).
pub trait DataLoader {
    /// Load data from the file and offset specified by `location`.
    ///
    /// On success, the returned slice's [`DataSlice::data`] yields the
    /// `location.length` bytes starting at `location.offset`.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError>;
}

/// Decide whether `path` may be used as an external data file.
///
/// Data read from an external file can end up (directly or indirectly) in
/// inference results, so a model must not be able to point at arbitrary
/// files. The ONNX documentation only forbids parent directory components
/// ("..") in the path
/// (https://onnx.ai/onnx/repo-docs/ExternalData.html#external-data-field).
/// This check is stricter:
///
///  - The path must be relative and consist of a single filename component,
///    ie. the file lives in the same directory as the model.
///  - The file extension must start with one of the prefixes used for tensor
///    data files: "data" or "onnx_data" (which also covers split files such
///    as "onnx_data_N").
fn is_allowed_external_data_path(path: &Path) -> bool {
    // Exactly one component, and it must be a normal filename (not a root,
    // prefix, "." or "..").
    let mut parts = path.components();
    let is_bare_filename =
        matches!(parts.next(), Some(Component::Normal(_))) && parts.next().is_none();
    if !is_bare_filename {
        return false;
    }

    // Extensions which are not valid UTF-8 are rejected.
    const ALLOWED_PREFIXES: [&str; 2] = ["data", "onnx_data"];
    path.extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| ALLOWED_PREFIXES.iter().any(|prefix| ext.starts_with(prefix)))
}

/// External data loader that uses standard file IO.
///
/// Each load copies the requested bytes into a freshly allocated buffer.
/// Opened files are cached so that repeated loads from the same data file
/// reuse the same handle.
pub struct FileLoader {
    /// Path to directory containing external data.
    dir_path: PathBuf,

    /// Map of external data file name to open file.
    ///
    /// Wrapped in a `RefCell` because loading takes `&self` but needs to
    /// insert into the cache and seek within the file.
    files: RefCell<HashMap<PathBuf, File>>,
}

impl FileLoader {
    /// Create an external data loader which loads data for the model file
    /// specified by `model_path`.
    ///
    /// Data file paths will be resolved relative to the directory containing
    /// `model_path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory containing `model_path` cannot be
    /// determined, eg. because `model_path` is not a file.
    pub fn new(model_path: &Path) -> Result<Self, ExternalDataError> {
        let dir_path = dir_path_from_model_path(model_path)?;

        Ok(Self {
            dir_path,
            files: HashMap::new().into(),
        })
    }

    /// Read the byte range described by `location` into a new buffer.
    ///
    /// # Errors
    ///
    /// - `NotSupported` on big-endian targets.
    /// - `InvalidLength` if `location.length` exceeds what a buffer can hold
    ///   or the buffer cannot be allocated.
    /// - `DisallowedPath` / `IoError` if the file cannot be opened or read.
    /// - `TooShort` if the file ends before `location.length` bytes have been
    ///   read.
    fn read(&self, location: &DataLocation) -> Result<Vec<u8>, ExternalDataError> {
        // On a big-endian system we'd need to perform byte-swapping while loading.
        if cfg!(target_endian = "big") {
            return Err(ExternalDataError::NotSupported);
        }

        // On 32-bit systems assume we can't load more than 2GB of data.
        if location.length > isize::MAX as u64 {
            return Err(ExternalDataError::InvalidLength);
        }
        let vec_len = location.length as usize;

        let mut files = self.files.borrow_mut();
        let mut file = get_or_open_file(&mut files, &self.dir_path, Path::new(&location.path))?;
        file.seek(SeekFrom::Start(location.offset))
            .map_err(ExternalDataError::IoError)?;

        // Ideally we would fill the buffer in one call via [`Read::read_buf`].
        // Since that API is not stabilized yet, we fill in small chunks, which
        // requires extra copying.
        let mut remaining = vec_len;

        // `location.length` comes from the model file and is not trusted. An
        // infallible `Vec::with_capacity` would abort the process if the
        // allocation fails, so reserve fallibly and report an error instead.
        let mut buf: Vec<u8> = Vec::new();
        buf.try_reserve_exact(remaining)
            .map_err(|_| ExternalDataError::InvalidLength)?;

        // Buffer size chosen to match BufReader's default.
        const TMP_SIZE: usize = 8192;
        let mut tmp = [0u8; TMP_SIZE];

        loop {
            let tmp_size = remaining.min(TMP_SIZE);
            let n_read =
                read_fill(&mut file, &mut tmp[..tmp_size]).map_err(ExternalDataError::IoError)?;
            let chunk = &tmp[..n_read];
            remaining -= chunk.len();
            buf.extend_from_slice(chunk);

            // A read shorter than a full chunk means we either hit the end of
            // the file or just read the final (partial) chunk.
            if n_read < tmp.len() || remaining == 0 {
                break;
            }
        }

        if buf.len() != vec_len {
            return Err(ExternalDataError::TooShort {
                required_len: vec_len,
                actual_len: buf.len(),
            });
        }

        Ok(buf)
    }
}

/// Read from `src` repeatedly until `buf` is filled or we reach the end of the
/// file.
///
/// Returns the number of bytes written to the start of `buf`. This is less
/// than `buf.len()` only if `src` reported end-of-file first.
///
/// Reads which fail with [`std::io::ErrorKind::Interrupted`] are retried, as
/// that error indicates a transient condition rather than a failure. Any other
/// error is returned to the caller.
fn read_fill<R: Read>(mut src: R, buf: &mut [u8]) -> std::io::Result<usize> {
    let mut total_read = 0;
    while total_read < buf.len() {
        match src.read(&mut buf[total_read..]) {
            // Zero bytes read into a non-empty buffer signals end-of-file.
            Ok(0) => break,
            Ok(n) => total_read += n,
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
    Ok(total_read)
}

impl DataLoader for FileLoader {
    /// Read the requested range into a new buffer and return a slice which
    /// covers the whole of that buffer.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
        self.read(location).map(|buf| {
            let len = buf.len();
            DataSlice {
                storage: Arc::new(ConstantStorage::Buffer(buf)),
                bytes: 0..len,
            }
        })
    }
}

/// Return the open file for `data_path`, opening and caching it in `files` on
/// first use.
///
/// `data_path` is resolved relative to `dir_path`. It is validated before the
/// cache is consulted, so a disallowed path always fails with
/// `DisallowedPath`. A failure to open the file is returned as `IoError` and
/// nothing is added to the cache.
fn get_or_open_file<'a>(
    files: &'a mut HashMap<PathBuf, File>,
    dir_path: &Path,
    data_path: &Path,
) -> Result<&'a mut File, ExternalDataError> {
    use std::collections::hash_map::Entry;

    if !is_allowed_external_data_path(data_path) {
        return Err(ExternalDataError::DisallowedPath(data_path.to_path_buf()));
    }

    match files.entry(data_path.to_path_buf()) {
        // Already opened by an earlier load.
        Entry::Occupied(cached) => Ok(cached.into_mut()),
        Entry::Vacant(slot) => {
            let opened =
                File::open(dir_path.join(data_path)).map_err(ExternalDataError::IoError)?;
            Ok(slot.insert(opened))
        }
    }
}

/// Return the directory which contains the model file at `model_path`.
///
/// Fails with `IoError` if the path cannot be canonicalized and with
/// `InvalidPath` if it does not refer to a file.
fn dir_path_from_model_path(model_path: &Path) -> Result<PathBuf, ExternalDataError> {
    let resolved = if cfg!(target_arch = "wasm32") {
        // On WASM / WASI `Path::canonicalize` is not available.
        model_path.to_path_buf()
    } else {
        // Resolve the path up front. Otherwise a relative `model_path` could
        // end up referring to an unexpected location if the current working
        // directory changes before a data file is loaded.
        model_path.canonicalize()?
    };

    if !resolved.is_file() {
        return Err(ExternalDataError::InvalidPath(resolved));
    }

    // A path to a file cannot be the root ("/"), so it always has a parent.
    let parent = resolved.parent().expect("should have parent dir");
    Ok(parent.to_path_buf())
}

/// External data loader that uses memory mapping.
///
/// # Alignment requirements
///
/// The ONNX Protocol Buffers schema and external data documentation state that
/// data offsets for individual tensors should be aligned to the page size,
/// which is the granularity of the `offset` argument to `mmap`. RTen only
/// requires offsets to be a multiple of the tensor element type's alignment.
/// This is because RTen creates only one memory map for each external data
/// file, covering the whole file. Each tensor using data from that file will
/// then share the memory mapping.
#[cfg(feature = "mmap")]
pub struct MmapLoader {
    /// Path to directory containing external data.
    dir_path: PathBuf,

    /// Map of filename to open mmap-ed content.
    ///
    /// Entries are reference-counted so that the slices handed out by the
    /// loader can share a single mapping per file.
    mmaps: RefCell<HashMap<PathBuf, Arc<ConstantStorage>>>,
}

#[cfg(feature = "mmap")]
impl MmapLoader {
    /// Create a data loader which memory-maps data files located in the same
    /// directory as `model_path`.
    ///
    /// A single memory map is created for each external file. The map stays
    /// open for as long as the `MmapLoader`, or any tensor using data from
    /// it, is alive.
    ///
    /// # Safety
    ///
    /// This method is marked as unsafe because truncating the file on disk
    /// while the file is mapped could cause undefined behavior. Applications
    /// must decide this is an acceptable risk for their use. See the notes for
    /// [`Model::load_mmap`](crate::model::Model::load_mmap).
    pub unsafe fn new(model_path: &Path) -> Result<Self, ExternalDataError> {
        dir_path_from_model_path(model_path).map(|dir_path| Self {
            dir_path,
            mmaps: RefCell::new(HashMap::new()),
        })
    }

    /// Return the shared mapping for `data_path`, creating and caching it on
    /// first use.
    ///
    /// Fails with `DisallowedPath` if the path is not permitted, or with
    /// `IoError` if the file cannot be opened or mapped.
    fn get_or_open_mmap(
        &self,
        data_path: &Path,
    ) -> Result<Arc<ConstantStorage>, ExternalDataError> {
        let mut open_maps = self.mmaps.borrow_mut();

        if !is_allowed_external_data_path(data_path) {
            return Err(ExternalDataError::DisallowedPath(data_path.to_path_buf()));
        }

        // Reuse the mapping if this file was opened by an earlier load.
        if let Some(existing) = open_maps.get(data_path) {
            return Ok(Arc::clone(existing));
        }

        let file =
            File::open(self.dir_path.join(data_path)).map_err(ExternalDataError::IoError)?;

        // Safety: By constructing an instance of `Self`, the caller has
        // accepted the risks that come with mmap.
        let mmap = unsafe { Mmap::map(&file) }?;

        let storage = Arc::new(ConstantStorage::Mmap(mmap));
        open_maps.insert(data_path.to_path_buf(), Arc::clone(&storage));
        Ok(storage)
    }
}

#[cfg(feature = "mmap")]
impl DataLoader for MmapLoader {
    /// Return a slice of the file's shared memory map covering the requested
    /// range. Fails with `TooShort` if the range extends past the end of the
    /// mapped file.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
        let storage = self.get_or_open_mmap(Path::new(&location.path))?;
        let actual_len = storage.data().len();

        // Saturate so that an overflowing offset + length is reported as
        // being out of range.
        let end_offset = location.offset.saturating_add(location.length);
        if end_offset > actual_len as u64 {
            return Err(ExternalDataError::TooShort {
                required_len: end_offset as usize,
                actual_len,
            });
        }

        let start = location.offset as usize;
        let end = start + location.length as usize;
        Ok(DataSlice {
            storage,
            bytes: start..end,
        })
    }
}

/// A [`DataLoader`] which loads data from in-memory buffers.
///
/// This supports loading models with external data in contexts where a file
/// system is not available (eg. a browser) or inconvenient to use (eg. in
/// tests).
///
/// The wrapped map is keyed by data file path, as it appears in
/// [`DataLocation::path`].
pub struct MemLoader(HashMap<String, Arc<ConstantStorage>>);

impl MemLoader {
    /// Create a loader from a map of data file path to the file's content.
    pub fn new(map: HashMap<String, Arc<ConstantStorage>>) -> Self {
        Self(map)
    }

    /// Create a loader from `(path, content)` pairs, wrapping each buffer in
    /// shared storage.
    #[cfg(test)]
    pub fn from_entries(entries: impl IntoIterator<Item = (String, Vec<u8>)>) -> Self {
        Self(
            entries
                .into_iter()
                .map(|(path, buf)| (path, Arc::new(ConstantStorage::Buffer(buf))))
                .collect(),
        )
    }
}

impl DataLoader for MemLoader {
    /// Return a slice of the in-memory buffer registered for `location.path`.
    ///
    /// # Errors
    ///
    /// - `DisallowedPath` if the path is not an allowed data file path.
    /// - `IoError` (with kind `NotFound`) if no buffer is registered for the
    ///   path.
    /// - `TooShort` if the requested range extends past the end of the
    ///   buffer.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
        // The restrictions on allowed data paths don't matter for the in-memory
        // data loader, but we apply them for consistency with other loaders.
        if !is_allowed_external_data_path(Path::new(&location.path)) {
            return Err(ExternalDataError::DisallowedPath(
                location.path.clone().into(),
            ));
        }
        let Some(storage) = self.0.get(&location.path) else {
            // Error message chosen to match the file loader.
            return Err(ExternalDataError::IoError(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "No such file or directory".to_string(),
            )));
        };

        // `offset` and `length` come from the model and are not trusted. A
        // plain `+` would panic on overflow in debug builds and wrap in
        // release builds, where the wrapped value could pass the bounds check
        // below. Saturating makes an overflowing range fail that check.
        let end_offset = location.offset.saturating_add(location.length);
        if end_offset > storage.data().len() as u64 {
            return Err(ExternalDataError::TooShort {
                required_len: end_offset as usize,
                actual_len: storage.data().len(),
            });
        }

        // The check above bounds `end_offset` (and hence `offset`) by the
        // buffer length, so these conversions are lossless.
        let bytes = (location.offset as usize)..end_offset as usize;
        Ok(DataSlice {
            storage: storage.clone(),
            bytes,
        })
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::panic::RefUnwindSafe;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;

    use super::{DataLoader, DataLocation, ExternalDataError, FileLoader, MemLoader};
    use crate::constant_storage::ConstantStorage;
    use rten_testing::TestCases;

    /// Directory in which tests create temporary files.
    fn temp_dir() -> PathBuf {
        if !cfg!(target_arch = "wasm32") {
            return std::env::temp_dir();
        }

        // temp_dir is not available on WASM / WASI, so fall back to the
        // current directory (an empty relative path).
        PathBuf::new()
    }

    /// A file in the test temp directory which is deleted when the value is
    /// dropped.
    struct TempFile {
        /// Full path of the file on disk.
        path: PathBuf,
    }

    impl TempFile {
        /// Create a file called `name` in the test temp directory and write
        /// `content` to it.
        fn new(name: impl AsRef<Path>, content: &[u8]) -> std::io::Result<Self> {
            let path = temp_dir().join(name);
            std::fs::write(&path, content).map(|()| Self { path })
        }

        /// Full path of the file.
        fn path(&self) -> &Path {
            self.path.as_path()
        }
    }

    impl Drop for TempFile {
        /// Delete the file from disk, panicking if removal fails.
        fn drop(&mut self) {
            let removed = std::fs::remove_file(&self.path);
            removed.expect("should remove file");
        }
    }

    // Run common tests for `DataLoader` impls. The `base_name` must be unique
    // for each test.
    //
    // This creates an empty model file `{base_name}.onnx` and a 32-byte data
    // file `{base_name}.onnx.data` (containing the bytes 0..32) in the temp
    // directory. `make_loader` is called with the model file path once per
    // case to construct the loader under test.
    fn test_loader<L: DataLoader>(
        base_name: &str,
        make_loader: impl Fn(&Path) -> Result<L, ExternalDataError> + RefUnwindSafe,
    ) {
        let bytes: Vec<u8> = (0..32).collect();
        // Both files are removed when these values are dropped at the end of
        // the function.
        let model_file = TempFile::new(format!("{base_name}.onnx"), &[]).unwrap();
        let data_file = TempFile::new(format!("{base_name}.onnx.data"), &bytes).unwrap();

        // Data locations refer to the file by name only, not by full path.
        let data_filename = data_file
            .path()
            .file_name()
            .unwrap()
            .to_string_lossy()
            .to_string();

        #[derive(Debug)]
        struct Case {
            // Location passed to the loader.
            location: DataLocation,
            // Expected bytes, or a substring of the expected error message.
            expected: Result<Vec<u8>, String>,
        }

        let cases = [
            // Part of file
            Case {
                location: DataLocation {
                    path: data_filename.clone(),
                    offset: 8,
                    length: 8,
                },
                expected: Ok(bytes[8..16].into()),
            },
            // Full file
            Case {
                location: DataLocation {
                    path: data_filename.clone(),
                    offset: 0,
                    length: bytes.len() as u64,
                },
                expected: Ok(bytes.clone()),
            },
            // Empty path
            Case {
                location: DataLocation {
                    path: String::new(),
                    offset: 0,
                    length: 0,
                },
                expected: Err("disallowed path".into()),
            },
            // Path containing parent directory
            Case {
                location: DataLocation {
                    path: "../foo.data".into(),
                    offset: 0,
                    length: 0,
                },
                expected: Err("disallowed path".into()),
            },
            // Path with disallowed extension
            Case {
                location: DataLocation {
                    path: "not_a_data_file.md".into(),
                    offset: 0,
                    length: 0,
                },
                expected: Err("disallowed path".into()),
            },
            // File does not exist
            Case {
                location: DataLocation {
                    path: "file_does_not_exist.data".into(),
                    offset: 0,
                    length: 0,
                },
                expected: Err("No such file or directory".into()),
            },
            // Range extends beyond end of file
            Case {
                location: DataLocation {
                    path: data_filename,
                    offset: 0,
                    length: 36,
                },
                expected: Err("file too short".into()),
            },
        ];

        cases.test_each(|case| {
            // A fresh loader is created for every case.
            let loader = make_loader(model_file.path()).unwrap();
            let data = loader.load(&case.location).map_err(|e| e.to_string());
            match (&data, &case.expected) {
                (Ok(actual), Ok(expected)) => assert_eq!(actual.data(), expected),
                // Errors are compared by substring of the display message.
                (Err(actual), Err(expected)) => assert!(
                    actual.contains(expected),
                    "{} does not contain {}",
                    actual,
                    expected
                ),
                // Mismatched Ok/Err: this assertion always fails and reports
                // which side succeeded.
                (actual, expected) => assert_eq!(actual.is_ok(), expected.is_ok()),
            }
        });
    }

    // Run the shared loader test cases against `FileLoader`.
    #[test]
    fn test_file_loader() {
        test_loader("test_file_loader", FileLoader::new)
    }

    // Run the shared loader test cases against `MmapLoader`. Only built when
    // the `mmap` feature is enabled.
    #[cfg(feature = "mmap")]
    #[test]
    fn test_mmap_loader() {
        use super::MmapLoader;
        // `MmapLoader::new` is an unsafe fn, hence the unsafe closure body.
        test_loader("test_mmap_loader", |model_path| unsafe {
            MmapLoader::new(model_path)
        })
    }

    // Run the shared loader test cases against `MemLoader`. The in-memory
    // entry uses the same name and content as the data file created by
    // `test_loader`.
    #[test]
    fn test_mem_loader() {
        test_loader("test_mem_loader", |_model_path| {
            let storage = Arc::new(ConstantStorage::Buffer((0..32).collect()));
            let map = HashMap::from([("test_mem_loader.onnx.data".to_string(), storage)]);
            Ok(MemLoader::new(map))
        })
    }
}