batuta/agent/driver/
validate.rs1use std::io::Read as _;
12use std::path::Path;
13
14use crate::agent::result::{AgentError, DriverError};
15
/// First four bytes expected in an `.apr` file: ASCII `APR` followed by a NUL byte.
const APR_MAGIC: [u8; 4] = *b"APR\0";
/// First four bytes expected in a `.gguf` file: ASCII `GGUF`.
const GGUF_MAGIC: [u8; 4] = *b"GGUF";
20
21fn read_header(path: &Path, limit: usize) -> Result<Vec<u8>, AgentError> {
23 let file = std::fs::File::open(path).map_err(|e| {
24 AgentError::Driver(DriverError::InferenceFailed(format!("cannot read model file: {e}")))
25 })?;
26 let mut buf = vec![0u8; limit];
27 let n = file.take(limit as u64).read(&mut buf).map_err(|e| {
28 AgentError::Driver(DriverError::InferenceFailed(format!("cannot read model header: {e}")))
29 })?;
30 buf.truncate(n);
31 Ok(buf)
32}
33
34pub(crate) fn validate_model_file(path: &Path) -> Result<(), AgentError> {
40 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
41
42 let header = read_header(path, 65536)?;
44
45 if header.len() < 4 {
46 return Err(AgentError::Driver(DriverError::InferenceFailed(
47 "model file too small (< 4 bytes)".into(),
48 )));
49 }
50
51 match ext {
52 "apr" => validate_apr_header(&header, path),
53 "gguf" => validate_gguf_header(&header, path),
54 _ => Ok(()), }
56}
57
58pub fn is_valid_model_file(path: &Path) -> bool {
64 validate_model_file(path).is_ok()
65}
66
67fn validate_apr_header(header: &[u8], path: &Path) -> Result<(), AgentError> {
69 if header[..4] != APR_MAGIC {
71 return Err(AgentError::Driver(DriverError::InferenceFailed(format!(
72 "invalid APR file (wrong magic bytes): {}",
73 path.display()
74 ))));
75 }
76
77 let header_str = String::from_utf8_lossy(&header[4..]);
89 let has_tokenizer = header_str.contains("tokenizer.merges")
90 || header_str.contains("tokenizer.vocabulary")
91 || header_str.contains("tokenizer.ggml")
92 || header_str.contains("bpe_ranks")
93 || header_str.contains("token_to_id");
94
95 if !has_tokenizer {
96 return Err(AgentError::Driver(DriverError::InferenceFailed(format!(
97 "APR file missing embedded tokenizer: {}\n\
98 APR format requires a self-contained tokenizer (Jidoka: fail-fast).\n\
99 Re-convert with: apr convert {} -o {}",
100 path.display(),
101 path.with_extension("gguf").display(),
102 path.display(),
103 ))));
104 }
105
106 Ok(())
107}
108
109fn validate_gguf_header(header: &[u8], path: &Path) -> Result<(), AgentError> {
111 if header[..4] != GGUF_MAGIC {
112 return Err(AgentError::Driver(DriverError::InferenceFailed(format!(
113 "invalid GGUF file (wrong magic bytes): {}",
114 path.display()
115 ))));
116 }
117 Ok(())
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a temporary file whose name ends in `suffix` and writes
    /// `contents` to it.
    ///
    /// The handle is returned so the caller controls how long the file lives.
    fn model_file(suffix: &str, contents: &[u8]) -> tempfile::NamedTempFile {
        let tmp = tempfile::NamedTempFile::with_suffix(suffix).expect("tmpfile");
        std::fs::write(tmp.path(), contents).expect("write");
        tmp
    }

    /// Returns `magic` immediately followed by `rest`.
    fn with_magic(magic: &[u8; 4], rest: &[u8]) -> Vec<u8> {
        magic.iter().chain(rest).copied().collect()
    }

    #[test]
    fn test_apr_without_tokenizer_rejected_at_boundary() {
        let tmp = model_file(".apr", &with_magic(&APR_MAGIC, &[0u8; 100]));

        let result = validate_model_file(tmp.path());
        assert!(result.is_err(), "APR without tokenizer must be rejected");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing embedded tokenizer"), "error must mention tokenizer: {err}");
        assert!(err.contains("apr convert"), "error must include fix command: {err}");
    }

    #[test]
    fn test_apr_with_tokenizer_passes_validation() {
        let tmp = model_file(
            ".apr",
            &with_magic(
                &APR_MAGIC,
                br#"{"tokenizer.merges":["a b"],"tokenizer.vocabulary":["hi"]}"#,
            ),
        );

        let result = validate_model_file(tmp.path());
        assert!(result.is_ok(), "APR with tokenizer should pass: {result:?}");
    }

    #[test]
    fn test_apr_with_ggml_tokenizer_passes() {
        let tmp = model_file(
            ".apr",
            &with_magic(&APR_MAGIC, b"tokenizer.ggml.tokens present in this header"),
        );

        let result = validate_model_file(tmp.path());
        assert!(result.is_ok(), "APR with tokenizer.ggml should pass: {result:?}");
    }

    #[test]
    fn test_apr_with_vocab_size_only_rejected() {
        let tmp = model_file(
            ".apr",
            &with_magic(
                &APR_MAGIC,
                br#"{"architecture":"qwen2","vocab_size":151936,"hidden_size":1536}"#,
            ),
        );

        let result = validate_model_file(tmp.path());
        assert!(result.is_err(), "APR with only vocab_size (no tokenizer data) must be rejected");
    }

    #[test]
    fn test_gguf_valid_magic_passes() {
        let tmp = model_file(".gguf", &with_magic(&GGUF_MAGIC, &[0u8; 100]));

        let result = validate_model_file(tmp.path());
        assert!(result.is_ok(), "valid GGUF should pass: {result:?}");
    }

    #[test]
    fn test_gguf_invalid_magic_rejected() {
        let tmp = model_file(".gguf", b"NOT_GGUF_DATA_HERE");

        let result = validate_model_file(tmp.path());
        assert!(result.is_err(), "invalid GGUF must be rejected");
        assert!(result.unwrap_err().to_string().contains("wrong magic bytes"));
    }

    #[test]
    fn test_empty_file_rejected() {
        let tmp = model_file(".apr", b"");

        let result = validate_model_file(tmp.path());
        assert!(result.is_err(), "empty file must be rejected");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("too small"), "error must mention the size problem: {err}");
    }

    #[test]
    fn test_truncated_file_rejected() {
        // Three bytes: one short of the magic length, for both known formats.
        for (suffix, bytes) in [(".apr", &b"APR"[..]), (".gguf", &b"GGU"[..])] {
            let tmp = model_file(suffix, bytes);

            let result = validate_model_file(tmp.path());
            assert!(result.is_err(), "{suffix} file shorter than the magic must be rejected");
            let err = result.unwrap_err().to_string();
            assert!(err.contains("too small"), "error must mention the size problem: {err}");
        }
    }

    #[test]
    fn test_unknown_extension_is_not_inspected() {
        let tmp = model_file(".safetensors", b"arbitrary bytes");

        let result = validate_model_file(tmp.path());
        assert!(result.is_ok(), "unknown extensions should pass through: {result:?}");
    }

    #[test]
    fn test_missing_file_rejected() {
        let tmp = model_file(".gguf", &GGUF_MAGIC);
        let path = tmp.path().to_path_buf();
        // Closing the handle removes the file, leaving a path that no longer exists.
        tmp.close().expect("remove tmpfile");

        let result = validate_model_file(&path);
        assert!(result.is_err(), "missing file must be rejected");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("cannot read model file"), "error must mention the read failure: {err}");
    }

    #[test]
    fn test_is_valid_model_file_public_api() {
        let invalid = model_file(".apr", &with_magic(&APR_MAGIC, &[0u8; 100]));
        assert!(!is_valid_model_file(invalid.path()), "invalid APR should return false");

        let valid = model_file(".gguf", &with_magic(&GGUF_MAGIC, &[0u8; 100]));
        assert!(is_valid_model_file(valid.path()), "valid GGUF should return true");
    }
}