use std::path::Path;
/// Container formats a model file may be stored in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// APR container; `detect_format` recognises it by an `APRN` or `APR2` prefix.
    Apr,
    /// GGUF container; `detect_format` recognises it by a `GGUF` prefix.
    Gguf,
    /// SafeTensors container; `detect_format` recognises it by a plausible
    /// little-endian `u64` header length in the first 8 bytes.
    SafeTensors,
}

impl std::fmt::Display for ModelFormat {
    /// Writes the canonical display name of the format (e.g. `GGUF`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Apr => "APR",
            Self::Gguf => "GGUF",
            Self::SafeTensors => "SafeTensors",
        };
        f.write_str(name)
    }
}
/// Reasons a model format could not be detected or verified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormatError {
    /// Fewer than the 8 bytes needed for detection were supplied.
    TooShort {
        /// Number of bytes actually supplied.
        len: usize,
    },
    /// Neither the data nor the file extension matched a known format.
    UnknownFormat,
    /// The little-endian `u64` read from the first 8 bytes is at or above the
    /// accepted SafeTensors header limit.
    HeaderTooLarge {
        /// The header length that was read.
        size: u64,
    },
    /// The format detected from the data disagrees with the file extension.
    ExtensionMismatch {
        /// Format detected from the data itself.
        detected: ModelFormat,
        /// The file's extension as written, without the leading dot.
        extension: String,
    },
}

impl std::fmt::Display for FormatError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownFormat => f.write_str("Unknown model format (no magic bytes matched)"),
            Self::TooShort { len } => {
                write!(f, "Data too short for format detection: {len} bytes (need 8)")
            }
            Self::HeaderTooLarge { size } => write!(
                f,
                "SafeTensors header too large: {size} bytes (max 100MB for DOS protection)"
            ),
            Self::ExtensionMismatch { detected, extension } => write!(
                f,
                "Extension mismatch: detected {detected} but file has extension .{extension}"
            ),
        }
    }
}

impl std::error::Error for FormatError {}
/// Leading magic bytes of an APR file (`APRN`).
pub const APR_MAGIC: &[u8; 4] = &[b'A', b'P', b'R', b'N'];
/// Leading magic bytes of an APR v2 file (`APR2`).
pub const APR_V2_MAGIC: &[u8; 4] = &[b'A', b'P', b'R', b'2'];
/// Leading magic bytes of a GGUF file (`GGUF`).
pub const GGUF_MAGIC: &[u8; 4] = &[b'G', b'G', b'U', b'F'];
/// Exclusive upper bound, in bytes, on an accepted SafeTensors header length.
/// Lengths at or above this are rejected as a denial-of-service safeguard.
pub const MAX_SAFETENSORS_HEADER: u64 = 100 * 1000 * 1000;
/// Detects a model's container format from its leading bytes.
///
/// APR (`APRN` / `APR2`) and GGUF files are recognised by their 4-byte magic.
/// Anything else is probed as SafeTensors: the first 8 bytes are read as a
/// little-endian `u64` header length, and a non-zero length below
/// `MAX_SAFETENSORS_HEADER` is accepted.
///
/// # Errors
///
/// - `FormatError::TooShort` if `data` holds fewer than 8 bytes.
/// - `FormatError::HeaderTooLarge` if no magic matched and the length is
///   `>= MAX_SAFETENSORS_HEADER`.
/// - `FormatError::UnknownFormat` if no magic matched and the length is 0.
pub fn detect_format(data: &[u8]) -> Result<ModelFormat, FormatError> {
    let prefix: [u8; 8] = match data.get(..8) {
        Some(head) => head.try_into().expect("slice is exactly 8 bytes"),
        None => return Err(FormatError::TooShort { len: data.len() }),
    };

    // Magic-byte formats take priority over the SafeTensors length probe.
    let magic = &prefix[..4];
    if magic == APR_MAGIC || magic == APR_V2_MAGIC {
        return Ok(ModelFormat::Apr);
    }
    if magic == GGUF_MAGIC {
        return Ok(ModelFormat::Gguf);
    }

    // A zero length is not treated as SafeTensors.
    match u64::from_le_bytes(prefix) {
        0 => Err(FormatError::UnknownFormat),
        size if size >= MAX_SAFETENSORS_HEADER => Err(FormatError::HeaderTooLarge { size }),
        _ => Ok(ModelFormat::SafeTensors),
    }
}
/// Infers a model format from the file extension of `path`.
///
/// Matching is case-insensitive (`MODEL.APR` is recognised as APR). Only the
/// last extension is considered, so `model.apr.bak` is not recognised.
///
/// # Errors
///
/// `FormatError::UnknownFormat` if the extension is missing, not valid UTF-8,
/// or not one of `apr`, `gguf`, `safetensors`.
pub fn detect_format_from_path(path: &Path) -> Result<ModelFormat, FormatError> {
    const BY_EXTENSION: [(&str, ModelFormat); 3] = [
        ("apr", ModelFormat::Apr),
        ("gguf", ModelFormat::Gguf),
        ("safetensors", ModelFormat::SafeTensors),
    ];

    // A missing or non-UTF-8 extension becomes "" and therefore matches nothing.
    let ext = path
        .extension()
        .and_then(|raw| raw.to_str())
        .map(str::to_lowercase)
        .unwrap_or_default();

    BY_EXTENSION
        .iter()
        .find(|(known, _)| ext == *known)
        .map(|&(_, format)| format)
        .ok_or(FormatError::UnknownFormat)
}
/// Detects the format from `data` and cross-checks it against `path`'s extension.
///
/// The content decides: the returned format is always the one detected from
/// `data`. The extension is used only as a consistency check. A missing or
/// unrecognised extension is not an error.
///
/// # Errors
///
/// - Any error from `detect_format` on `data` (checked before the extension).
/// - `FormatError::ExtensionMismatch` if the extension names a known format
///   that differs from the detected one. `extension` carries the extension as
///   written (original case, no leading dot).
pub fn detect_and_verify_format(path: &Path, data: &[u8]) -> Result<ModelFormat, FormatError> {
    let detected = detect_format(data)?;

    match detect_format_from_path(path) {
        Ok(claimed) if claimed != detected => Err(FormatError::ExtensionMismatch {
            detected,
            // The extension was recognised, so it is valid UTF-8; the
            // fallback is purely defensive.
            extension: path
                .extension()
                .and_then(|e| e.to_str())
                .unwrap_or("unknown")
                .to_string(),
        }),
        _ => Ok(detected),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_apr_format() {
        let data = b"APRNxxxxxxxxxxxxxxxx";
        assert_eq!(detect_format(data).unwrap(), ModelFormat::Apr);
    }

    #[test]
    fn test_detect_apr_v2_format() {
        // The v2 magic must map to the same format as the v1 magic.
        let data = b"APR2xxxxxxxxxxxxxxxx";
        assert_eq!(detect_format(data).unwrap(), ModelFormat::Apr);
    }

    #[test]
    fn test_detect_gguf_format() {
        let data = b"GGUFxxxxxxxxxxxxxxxx";
        assert_eq!(detect_format(data).unwrap(), ModelFormat::Gguf);
    }

    #[test]
    fn test_detect_safetensors_format() {
        let mut data = vec![0u8; 16];
        let header_size: u64 = 1000;
        data[0..8].copy_from_slice(&header_size.to_le_bytes());
        assert_eq!(detect_format(&data).unwrap(), ModelFormat::SafeTensors);
    }

    #[test]
    fn test_detect_format_too_short() {
        let data = b"APR";
        let result = detect_format(data);
        assert!(matches!(result, Err(FormatError::TooShort { len: 3 })));
    }

    #[test]
    fn test_detect_format_seven_bytes_too_short() {
        // One byte below the 8-byte minimum, even with a valid magic.
        let data = b"GGUFxxx";
        assert_eq!(detect_format(data), Err(FormatError::TooShort { len: 7 }));
    }

    #[test]
    fn test_detect_format_empty() {
        let data: &[u8] = &[];
        let result = detect_format(data);
        assert!(matches!(result, Err(FormatError::TooShort { len: 0 })));
    }

    #[test]
    fn test_detect_safetensors_header_too_large() {
        let mut data = vec![0u8; 16];
        let header_size: u64 = 200_000_000;
        data[0..8].copy_from_slice(&header_size.to_le_bytes());
        let result = detect_format(&data);
        assert!(matches!(
            result,
            Err(FormatError::HeaderTooLarge { size: 200_000_000 })
        ));
    }

    #[test]
    fn test_detect_unknown_format() {
        let data = b"\x00\x00\x00\x00\x00\x00\x00\x00xxxx";
        let result = detect_format(data);
        assert!(matches!(result, Err(FormatError::UnknownFormat)));
    }

    #[test]
    fn test_detect_format_from_path_apr() {
        let path = Path::new("model.apr");
        assert_eq!(detect_format_from_path(path).unwrap(), ModelFormat::Apr);
    }

    #[test]
    fn test_detect_format_from_path_gguf() {
        let path = Path::new("llama-7b-q4.gguf");
        assert_eq!(detect_format_from_path(path).unwrap(), ModelFormat::Gguf);
    }

    #[test]
    fn test_detect_format_from_path_safetensors() {
        let path = Path::new("model.safetensors");
        assert_eq!(
            detect_format_from_path(path).unwrap(),
            ModelFormat::SafeTensors
        );
    }

    #[test]
    fn test_detect_format_from_path_unknown() {
        let path = Path::new("model.bin");
        let result = detect_format_from_path(path);
        assert!(matches!(result, Err(FormatError::UnknownFormat)));
    }

    #[test]
    fn test_detect_format_from_path_no_extension() {
        let path = Path::new("model");
        assert_eq!(
            detect_format_from_path(path),
            Err(FormatError::UnknownFormat)
        );
    }

    #[test]
    fn test_detect_format_from_path_uppercase() {
        let path = Path::new("MODEL.APR");
        assert_eq!(detect_format_from_path(path).unwrap(), ModelFormat::Apr);
    }

    #[test]
    fn test_detect_and_verify_format_match() {
        let path = Path::new("model.apr");
        let data = b"APRNxxxxxxxxxxxxxxxx";
        assert_eq!(
            detect_and_verify_format(path, data).unwrap(),
            ModelFormat::Apr
        );
    }

    #[test]
    fn test_detect_and_verify_format_mismatch() {
        let path = Path::new("model.apr");
        let data = b"GGUFxxxxxxxxxxxxxxxx";
        let result = detect_and_verify_format(path, data);
        assert_eq!(
            result,
            Err(FormatError::ExtensionMismatch {
                detected: ModelFormat::Gguf,
                extension: "apr".to_string(),
            })
        );
    }

    #[test]
    fn test_detect_and_verify_mismatch_keeps_extension_case() {
        // The mismatch payload reports the extension as written, not lowercased.
        let path = Path::new("MODEL.APR");
        let data = b"GGUFxxxxxxxxxxxxxxxx";
        assert_eq!(
            detect_and_verify_format(path, data),
            Err(FormatError::ExtensionMismatch {
                detected: ModelFormat::Gguf,
                extension: "APR".to_string(),
            })
        );
    }

    #[test]
    fn test_detect_and_verify_data_error_takes_precedence() {
        // A data-level error is reported even when the extension is recognised.
        let path = Path::new("model.apr");
        assert_eq!(
            detect_and_verify_format(path, b"APR"),
            Err(FormatError::TooShort { len: 3 })
        );
    }

    #[test]
    fn test_detect_and_verify_unknown_extension_ok() {
        let path = Path::new("model.bin");
        let data = b"APRNxxxxxxxxxxxxxxxx";
        assert_eq!(
            detect_and_verify_format(path, data).unwrap(),
            ModelFormat::Apr
        );
    }

    #[test]
    fn test_model_format_display() {
        assert_eq!(format!("{}", ModelFormat::Apr), "APR");
        assert_eq!(format!("{}", ModelFormat::Gguf), "GGUF");
        assert_eq!(format!("{}", ModelFormat::SafeTensors), "SafeTensors");
    }

    #[test]
    fn test_format_error_display() {
        let err = FormatError::TooShort { len: 5 };
        assert!(err.to_string().contains("5 bytes"));
        let err = FormatError::UnknownFormat;
        assert!(err.to_string().contains("Unknown"));
        let err = FormatError::HeaderTooLarge { size: 999 };
        assert!(err.to_string().contains("999 bytes"));
        let err = FormatError::ExtensionMismatch {
            detected: ModelFormat::Gguf,
            extension: "apr".to_string(),
        };
        assert!(err.to_string().contains("GGUF"));
        assert!(err.to_string().contains(".apr"));
    }

    #[test]
    fn test_format_error_is_std_error() {
        let err = FormatError::UnknownFormat;
        let as_error: &dyn std::error::Error = &err;
        assert!(as_error.to_string().contains("Unknown"));
    }

    #[test]
    fn test_magic_constants() {
        assert_eq!(APR_MAGIC, b"APRN");
        assert_eq!(APR_V2_MAGIC, b"APR2");
        assert_eq!(GGUF_MAGIC, b"GGUF");
        assert_eq!(MAX_SAFETENSORS_HEADER, 100_000_000);
    }

    #[test]
    fn test_exactly_8_bytes_safetensors() {
        let header_size: u64 = 500;
        let data = header_size.to_le_bytes();
        assert_eq!(detect_format(&data).unwrap(), ModelFormat::SafeTensors);
    }

    #[test]
    fn test_apr_with_trailing_data() {
        let mut data = b"APRN".to_vec();
        data.extend_from_slice(&[0u8; 1000]);
        assert_eq!(detect_format(&data).unwrap(), ModelFormat::Apr);
    }

    #[test]
    fn test_gguf_with_trailing_data() {
        let mut data = b"GGUF".to_vec();
        data.extend_from_slice(&[0u8; 1000]);
        assert_eq!(detect_format(&data).unwrap(), ModelFormat::Gguf);
    }

    #[test]
    fn test_safetensors_boundary_header_size() {
        let mut data = vec![0u8; 16];
        let header_size: u64 = MAX_SAFETENSORS_HEADER - 1;
        data[0..8].copy_from_slice(&header_size.to_le_bytes());
        assert_eq!(detect_format(&data).unwrap(), ModelFormat::SafeTensors);
    }

    #[test]
    fn test_safetensors_exactly_at_limit() {
        let mut data = vec![0u8; 16];
        let header_size: u64 = MAX_SAFETENSORS_HEADER;
        data[0..8].copy_from_slice(&header_size.to_le_bytes());
        let result = detect_format(&data);
        assert!(matches!(result, Err(FormatError::HeaderTooLarge { .. })));
    }
}