use std::io::{Cursor, Read};
use crate::error::{Error, Result};
/// Name of the archive entry that [`ModelArchive::edgefirst_json`] reads.
pub const EDGEFIRST_JSON: &str = "edgefirst.json";
/// Name of the archive entry that [`ModelArchive::labels`] reads (one label per line).
pub const LABELS_TXT: &str = "labels.txt";
/// Name of the `metadata.json` archive entry.
///
/// No reader in this section consumes it; it is exported so callers can pass it to
/// [`ModelArchive::contains`] or [`ModelArchive::read`].
pub const METADATA_JSON: &str = "metadata.json";
/// Read-only view over a ZIP archive located inside a borrowed byte buffer.
///
/// The archive borrows the buffer for `'a` and never copies it; entry contents are
/// only materialised when one of the `read*` methods is called.
pub struct ModelArchive<'a> {
/// ZIP reader positioned over the borrowed bytes.
inner: zip::ZipArchive<Cursor<&'a [u8]>>,
}
impl std::fmt::Debug for ModelArchive<'_> {
    /// Renders the archive as a `ModelArchive` struct with a single `len` field
    /// holding the entry count; entry contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.inner.len();
        let mut builder = f.debug_struct("ModelArchive");
        builder.field("len", &entry_count);
        builder.finish()
    }
}
impl<'a> ModelArchive<'a> {
    /// Largest buffer capacity reserved up front from an entry's declared size.
    ///
    /// The declared size is read from archive metadata, so a corrupt or hostile
    /// archive can claim an arbitrarily large value. It is therefore used only as
    /// a bounded hint: larger entries are still read in full, the buffer simply
    /// grows as data arrives.
    const MAX_PREALLOC: usize = 16 * 1024 * 1024;

    /// Converts an entry's declared size into a bounded preallocation hint.
    fn capacity_hint(declared: u64) -> usize {
        usize::try_from(declared).map_or(Self::MAX_PREALLOC, |n| n.min(Self::MAX_PREALLOC))
    }

    /// Opens the ZIP archive embedded in `data`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when the ZIP reader cannot open `data`
    /// as an archive.
    pub fn new(data: &'a [u8]) -> Result<Self> {
        let inner = zip::ZipArchive::new(Cursor::new(data))
            .map_err(|e| Error::invalid_argument(format!("no embedded ZIP archive: {e}")))?;
        Ok(Self { inner })
    }

    /// Number of entries reported by the underlying archive.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the underlying archive reports no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Iterates over the names of the entries in the archive.
    pub fn entry_names(&self) -> impl Iterator<Item = &str> {
        self.inner.file_names()
    }

    /// Reads the entry called `name` and returns its bytes.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when the entry cannot be opened by name
    /// or when reading its contents fails.
    pub fn read(&mut self, name: &str) -> Result<Vec<u8>> {
        let mut file = self
            .inner
            .by_name(name)
            .map_err(|e| Error::invalid_argument(format!("archive entry {name:?}: {e}")))?;
        // Clamp the hint: the declared size must not be able to force a huge
        // allocation (or a capacity-overflow panic) before any byte is read.
        let mut buf = Vec::with_capacity(Self::capacity_hint(file.size()));
        file.read_to_end(&mut buf)
            .map_err(|e| Error::invalid_argument(format!("read archive entry {name:?}: {e}")))?;
        Ok(buf)
    }

    /// Reads the entry called `name` and returns its contents as a `String`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when the entry cannot be opened by name
    /// or when reading it as UTF-8 text fails.
    pub fn read_to_string(&mut self, name: &str) -> Result<String> {
        let mut file = self
            .inner
            .by_name(name)
            .map_err(|e| Error::invalid_argument(format!("archive entry {name:?}: {e}")))?;
        // Same bounded hint as `read`; see `MAX_PREALLOC`.
        let mut s = String::with_capacity(Self::capacity_hint(file.size()));
        file.read_to_string(&mut s).map_err(|e| {
            Error::invalid_argument(format!("read archive entry {name:?} as utf-8: {e}"))
        })?;
        Ok(s)
    }

    /// Reports whether the archive's name index has an entry for `name`.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.inner.index_for_name(name).is_some()
    }

    /// Reads the [`EDGEFIRST_JSON`] entry as text.
    ///
    /// The text is returned as stored; it is not parsed or validated here.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ModelArchive::read_to_string`].
    pub fn edgefirst_json(&mut self) -> Result<String> {
        self.read_to_string(EDGEFIRST_JSON)
    }

    /// Reads the [`LABELS_TXT`] entry and returns one label per non-blank line.
    ///
    /// Trailing whitespace (including a carriage return) is removed from every
    /// line, and lines that are empty after trimming are skipped. Leading
    /// whitespace is kept.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ModelArchive::read_to_string`].
    pub fn labels(&mut self) -> Result<Vec<String>> {
        let raw = self.read_to_string(LABELS_TXT)?;
        Ok(raw
            .lines()
            .map(str::trim_end)
            .filter(|line| !line.is_empty())
            .map(str::to_owned)
            .collect())
    }
}
/// Reports whether `data` can be opened as a ZIP archive.
///
/// Only the success of opening is checked; no entry is read.
#[must_use]
pub fn has_archive(data: &[u8]) -> bool {
    let cursor = Cursor::new(data);
    let opened = zip::ZipArchive::new(cursor);
    opened.is_ok()
}
/// One-shot helper: opens the archive in `data` and returns the text of its
/// [`EDGEFIRST_JSON`] entry.
///
/// # Errors
///
/// Fails when the archive cannot be opened or the entry cannot be read as text.
pub fn edgefirst_json(data: &[u8]) -> Result<String> {
    let mut archive = ModelArchive::new(data)?;
    archive.edgefirst_json()
}
/// One-shot helper: opens the archive in `data` and returns the labels listed
/// in its [`LABELS_TXT`] entry.
///
/// # Errors
///
/// Fails when the archive cannot be opened or the entry cannot be read as text.
pub fn labels(data: &[u8]) -> Result<Vec<String>> {
    let mut archive = ModelArchive::new(data)?;
    archive.labels()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture models, embedded at compile time from the repository's testdata directory.
    static MODEL_WITH_ARCHIVE: &[u8] = include_bytes!("../../../testdata/yolov8n-seg-int8.tflite");
    static MINIMAL_MODEL: &[u8] = include_bytes!("../../../testdata/minimal.tflite");
    static SCHEMA_V2_COMBINED: &[u8] =
        include_bytes!("../../../testdata/yolov8n-seg-combined-int8.tflite");
    static SCHEMA_V2_LOGICAL: &[u8] =
        include_bytes!("../../../testdata/yolov8n-seg-logical-int8.tflite");
    static SCHEMA_V2_SMART: &[u8] =
        include_bytes!("../../../testdata/yolov8n-seg-smart-int8.tflite");

    /// Returns the schema-v2 fixtures paired with a short name used in failure messages.
    fn schema_v2_models() -> [(&'static str, &'static [u8]); 3] {
        [
            ("combined", SCHEMA_V2_COMBINED),
            ("logical", SCHEMA_V2_LOGICAL),
            ("smart", SCHEMA_V2_SMART),
        ]
    }

    /// Returns at most the first `max_chars` characters of `text`.
    ///
    /// Counting characters rather than slicing at a byte offset means the preview
    /// can never split a multi-byte UTF-8 sequence and panic while a failure
    /// message is being built.
    fn preview(text: &str, max_chars: usize) -> String {
        text.chars().take(max_chars).collect()
    }

    /// `has_archive` accepts the archive fixture and rejects the minimal model,
    /// empty input and zero-filled input.
    #[test]
    fn detects_archive_presence() {
        assert!(has_archive(MODEL_WITH_ARCHIVE));
        assert!(!has_archive(MINIMAL_MODEL));
        assert!(!has_archive(&[]));
        assert!(!has_archive(&[0u8; 16]));
    }

    /// The archive fixture exposes the three well-known entries and nothing bogus.
    #[test]
    fn lists_expected_entries() {
        let archive = ModelArchive::new(MODEL_WITH_ARCHIVE).unwrap();
        assert!(archive.contains(EDGEFIRST_JSON));
        assert!(archive.contains(LABELS_TXT));
        assert!(archive.contains(METADATA_JSON));
        assert!(!archive.contains("missing.txt"));
        assert!(!archive.is_empty());
    }

    /// `edgefirst.json` reads back as a JSON object mentioning `decoder_version`.
    #[test]
    fn reads_edgefirst_json() {
        let mut archive = ModelArchive::new(MODEL_WITH_ARCHIVE).unwrap();
        let json = archive.edgefirst_json().unwrap();
        assert!(json.contains("\"decoder_version\""));
        assert!(json.starts_with('{') && json.trim_end().ends_with('}'));
    }

    /// `labels.txt` yields 80 labels, starting with `person`.
    #[test]
    fn reads_labels() {
        let mut archive = ModelArchive::new(MODEL_WITH_ARCHIVE).unwrap();
        let labels = archive.labels().unwrap();
        assert_eq!(labels.len(), 80);
        assert_eq!(labels[0], "person");
    }

    /// Reading an entry that does not exist reports an invalid-argument error.
    #[test]
    fn missing_entry_is_invalid_argument() {
        let mut archive = ModelArchive::new(MODEL_WITH_ARCHIVE).unwrap();
        let err = archive.read("does-not-exist").unwrap_err();
        assert!(err.is_invalid_argument());
    }

    /// The free-function helpers return the same data as the methods.
    #[test]
    fn one_shot_helpers_roundtrip() {
        let json = edgefirst_json(MODEL_WITH_ARCHIVE).unwrap();
        assert!(json.contains("decoder_version"));
        let labels = labels(MODEL_WITH_ARCHIVE).unwrap();
        assert_eq!(labels.len(), 80);
    }

    /// Opening a model without an archive reports an invalid-argument error.
    #[test]
    fn no_archive_is_invalid_argument() {
        let err = ModelArchive::new(MINIMAL_MODEL).unwrap_err();
        assert!(err.is_invalid_argument());
    }

    /// Every schema-v2 fixture opens and carries the expected entries and labels.
    #[test]
    fn schema_v2_fixtures_open_cleanly() {
        for (name, bytes) in schema_v2_models() {
            assert!(has_archive(bytes), "{name}: no embedded archive detected");
            let mut archive =
                ModelArchive::new(bytes).unwrap_or_else(|e| panic!("{name}: open failed: {e}"));
            assert!(
                archive.contains(EDGEFIRST_JSON),
                "{name}: missing edgefirst.json"
            );
            assert!(archive.contains(LABELS_TXT), "{name}: missing labels.txt");
            let labels = archive
                .labels()
                .unwrap_or_else(|e| panic!("{name}: labels read failed: {e}"));
            assert_eq!(labels.len(), 80, "{name}: expected 80 COCO labels");
            assert_eq!(labels[0], "person", "{name}: first label");
        }
    }

    /// Every schema-v2 fixture declares `schema_version` 2 and the `yolov8` decoder.
    ///
    /// The checks are textual and accept both `"key": value` and `"key":value` spacing.
    #[test]
    fn schema_v2_fixtures_advertise_schema_v2() {
        for (name, bytes) in schema_v2_models() {
            let mut archive =
                ModelArchive::new(bytes).unwrap_or_else(|e| panic!("{name}: open failed: {e}"));
            let json = archive
                .edgefirst_json()
                .unwrap_or_else(|e| panic!("{name}: edgefirst.json read: {e}"));
            assert!(
                json.contains("\"schema_version\": 2") || json.contains("\"schema_version\":2"),
                "{name}: edgefirst.json is not schema v2:\n{}",
                // Character-based preview: a fixed byte slice could land inside a
                // multi-byte character and panic, masking this message.
                preview(&json, 200),
            );
            assert!(
                json.contains("\"decoder_version\": \"yolov8\"")
                    || json.contains("\"decoder_version\":\"yolov8\""),
                "{name}: decoder_version != yolov8"
            );
        }
    }

    /// One layout expectation: fixture name, fixture bytes, substrings that must
    /// appear in `edgefirst.json`, and substrings that must not appear.
    type LayoutCase = (
        &'static str,
        &'static [u8],
        &'static [&'static str],
        &'static [&'static str],
    );

    /// Each schema-v2 fixture's `edgefirst.json` contains the substrings expected
    /// for its layout and none of the forbidden ones.
    #[test]
    fn schema_v2_fixtures_match_expected_layout() {
        let cases: &[LayoutCase] = &[
            (
                "combined",
                SCHEMA_V2_COMBINED,
                &["\"type\": \"detection\"", "\"type\": \"protos\""],
                &["\"type\": \"boxes\"", "\"scale_index\""],
            ),
            (
                "logical",
                SCHEMA_V2_LOGICAL,
                &[
                    "\"type\": \"boxes\"",
                    "\"type\": \"scores\"",
                    "\"type\": \"mask_coefs\"",
                    "\"type\": \"protos\"",
                ],
                &["\"scale_index\""],
            ),
            (
                "smart",
                SCHEMA_V2_SMART,
                &["\"scale_index\"", "\"type\": \"protos\""],
                &[],
            ),
        ];
        for (name, bytes, must_contain, must_not_contain) in cases {
            let mut archive =
                ModelArchive::new(bytes).unwrap_or_else(|e| panic!("{name}: open failed: {e}"));
            let json = archive
                .edgefirst_json()
                .unwrap_or_else(|e| panic!("{name}: edgefirst.json read: {e}"));
            for needle in *must_contain {
                assert!(
                    json.contains(needle),
                    "{name}: expected {needle:?} in edgefirst.json"
                );
            }
            for needle in *must_not_contain {
                assert!(
                    !json.contains(needle),
                    "{name}: did not expect {needle:?} in edgefirst.json"
                );
            }
        }
    }
}