1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use std::ffi::OsStr;
use std::path::Path;
/// File type of a machine learning model.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FileType {
    /// RTen file format.
    ///
    /// See `docs/rten-file-format.md` in this repository.
    Rten,
    /// ONNX file format.
    Onnx,
}

impl FileType {
    /// Return the file type that corresponds to the extension of `path`.
    ///
    /// The extension is compared ASCII case-insensitively, so `model.ONNX`
    /// and `model.onnx` are both recognized. Returns `None` if `path` has no
    /// extension (this includes dotfiles such as `.rten`, which
    /// [`Path::extension`] treats as having no extension) or if the extension
    /// is not a known model format.
    pub fn from_path(path: &Path) -> Option<Self> {
        let ext = path.extension()?;
        if ext.eq_ignore_ascii_case(OsStr::new("rten")) {
            Some(FileType::Rten)
        } else if ext.eq_ignore_ascii_case(OsStr::new("onnx")) {
            Some(FileType::Onnx)
        } else {
            None
        }
    }

    /// Infer file type from the content of a file.
    ///
    /// Checks are applied from most to least reliable:
    ///
    /// 1. A leading `RTEN` file type identifier (rten format v2 and later).
    /// 2. A lightweight protobuf parse to detect ONNX models. This check is
    ///    only compiled in when the `onnx_format` feature is enabled.
    /// 3. A plausibility check of the FlatBuffers root table offset that
    ///    rten v1 files start with, since they have no file type identifier.
    ///
    /// Returns `None` if none of the checks match.
    pub fn from_buffer(data: &[u8]) -> Option<Self> {
        // Size in bytes of a FlatBuffers `uoffset_t` (the leading root offset)
        // and of the `soffset_t` that every table starts with.
        const OFFSET_SIZE: usize = 4;

        let magic: Option<[u8; 4]> = data.first_chunk::<4>().copied();

        // rten files using the v2 format and later start with a 4-byte file
        // type identifier.
        if magic == Some(*b"RTEN") {
            return Some(FileType::Rten);
        }

        #[cfg(feature = "onnx_format")]
        {
            use rten_onnx::onnx::is_onnx_model;
            use rten_onnx::protobuf::ValueReader;

            // ONNX models are serialized Protocol Buffers messages with no file
            // type identifier, so we attempt some lightweight protobuf parsing.
            if is_onnx_model(ValueReader::from_buf(data)) {
                return Some(FileType::Onnx);
            }
        }

        // rten files using the v1 format don't have a file type identifier.
        // They are FlatBuffers messages which start with a u32 offset pointing
        // to the root table, as described at
        // https://flatbuffers.dev/internals/#encoding-example.
        //
        // A valid root offset must point past the offset field itself and
        // leave room for the root table's leading 4-byte vtable offset inside
        // the buffer. This rejects e.g. all-zero headers and offsets that point
        // at (or beyond) the end of the buffer.
        //
        // `u32 as usize` is lossless on the 32- and 64-bit targets this runs on.
        let root_offset = u32::from_le_bytes(magic?) as usize;
        let root_table_in_bounds = root_offset >= OFFSET_SIZE
            && root_offset
                .checked_add(OFFSET_SIZE)
                .is_some_and(|table_end| table_end <= data.len());

        root_table_in_bounds.then_some(FileType::Rten)
    }
}
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;
    use std::path::Path;

    use super::FileType;

    #[test]
    fn test_file_type_from_path() {
        #[derive(Debug)]
        struct Case<'a> {
            path: &'a Path,
            file_type: Option<FileType>,
        }

        let cases = [
            Case {
                path: Path::new("foo.rten"),
                file_type: Some(FileType::Rten),
            },
            Case {
                path: Path::new("foo.onnx"),
                file_type: Some(FileType::Onnx),
            },
            Case {
                path: Path::new("foo.md"),
                file_type: None,
            },
            // Extension matching is ASCII case-insensitive.
            Case {
                path: Path::new("foo.ONNX"),
                file_type: Some(FileType::Onnx),
            },
            Case {
                path: Path::new("foo.RTeN"),
                file_type: Some(FileType::Rten),
            },
            // Only the final extension is considered.
            Case {
                path: Path::new("foo.tar.onnx"),
                file_type: Some(FileType::Onnx),
            },
            // No extension at all.
            Case {
                path: Path::new("foo"),
                file_type: None,
            },
            // `Path::extension` treats a leading-dot file name as having no
            // extension.
            Case {
                path: Path::new(".rten"),
                file_type: None,
            },
            // An extension on a parent directory is not the file's extension.
            Case {
                path: Path::new("models.onnx/foo"),
                file_type: None,
            },
        ];

        cases.test_each(|case| {
            assert_eq!(FileType::from_path(case.path), case.file_type);
        });
    }

    #[test]
    fn test_file_type_from_buffer() {
        #[derive(Debug)]
        struct Case {
            buf: Vec<u8>,
            expected: Option<FileType>,
        }

        let cases = [
            Case {
                buf: b"RTEN".into(),
                expected: Some(FileType::Rten),
            },
            // The identifier only needs to be a prefix of the buffer.
            Case {
                buf: b"RTEN\x02\x00\x00\x00trailing data".into(),
                expected: Some(FileType::Rten),
            },
            Case {
                buf: b"".into(),
                expected: None,
            },
            // v1 FlatBuffers file: root offset followed by the root table.
            Case {
                buf: {
                    (128u32)
                        .to_le_bytes()
                        .into_iter()
                        .chain(std::iter::repeat(0).take(128))
                        .collect()
                },
                expected: Some(FileType::Rten),
            },
            Case {
                buf: b"unknown format".into(),
                expected: None,
            },
        ];

        cases.test_each(|case| {
            assert_eq!(FileType::from_buffer(&case.buf), case.expected);
        });
    }
}