1use crate::error::{RuntimeError, RuntimeResult};
4
/// Outcome of a single text-generation run.
///
/// Plain data container; it carries no behaviour of its own.
#[derive(Debug, Clone, PartialEq)]
pub struct GenerationResult {
    /// The generated text.
    pub text: String,
    /// Token ids belonging to the generation (presumably the completion only —
    /// confirm against the producer).
    pub token_ids: Vec<u32>,
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub generated_tokens: usize,
    /// Generation throughput in tokens per second.
    pub tokens_per_second: f64,
    /// Free-form reason generation stopped, e.g. `"stop"`.
    pub finish_reason: String,
}
21
/// Token accounting and timing figures for a request.
///
/// `Default` yields all-zero counters and timings.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TokenStats {
    /// Total token count (presumably `prompt_tokens + completion_tokens` —
    /// not enforced by this type).
    pub total_tokens: usize,
    /// Number of prompt tokens.
    pub prompt_tokens: usize,
    /// Number of completion (generated) tokens.
    pub completion_tokens: usize,
    /// Time until the first token was produced, in milliseconds.
    pub time_to_first_token_ms: f64,
    /// Throughput in tokens per second.
    pub tokens_per_second: f64,
}
36
/// Basic facts about a model file, as reported by `validate_model_file`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelFileInfo {
    /// The path exactly as supplied by the caller.
    pub path: String,
    /// File size in bytes, taken from filesystem metadata.
    pub size_bytes: u64,
    /// `"GGUF"` when the magic matched, otherwise a string starting with
    /// `"unknown"` (optionally including the observed magic value).
    pub format: String,
    /// `true` only when the file starts with the GGUF magic bytes; the rest of
    /// the header is not checked.
    pub is_valid_gguf: bool,
}
49
50pub fn validate_model_file(path: &str) -> RuntimeResult<ModelFileInfo> {
55 let metadata = std::fs::metadata(path).map_err(|e| {
56 if e.kind() == std::io::ErrorKind::NotFound {
57 RuntimeError::FileNotFound {
58 path: path.to_string(),
59 }
60 } else {
61 RuntimeError::Io(e)
62 }
63 })?;
64
65 if !metadata.is_file() {
66 return Err(RuntimeError::Config(format!(
67 "path '{}' is not a regular file",
68 path
69 )));
70 }
71
72 let size_bytes = metadata.len();
73
74 let mut is_valid_gguf = false;
76 let mut format = "unknown".to_string();
77
78 if size_bytes >= 4 {
79 let file = std::fs::File::open(path).map_err(RuntimeError::Io)?;
80 let mut reader = std::io::BufReader::new(file);
81 let mut magic_bytes = [0u8; 4];
82 use std::io::Read;
83 if reader.read_exact(&mut magic_bytes).is_ok() {
84 let magic = u32::from_le_bytes(magic_bytes);
85 if magic == 0x46554747 {
86 is_valid_gguf = true;
87 format = "GGUF".to_string();
88 } else {
89 format = format!("unknown (magic: 0x{:08X})", magic);
90 }
91 }
92 }
93
94 Ok(ModelFileInfo {
95 path: path.to_string(),
96 size_bytes,
97 format,
98 is_valid_gguf,
99 })
100}
101
/// Rough breakdown of the memory needed to run a model; all sizes in bytes.
///
/// Produced by [`estimate_memory_requirements`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Bytes attributed to the weights (taken as the model file size).
    pub model_weights_bytes: u64,
    /// Bytes for the key/value cache at the maximum sequence length.
    pub kv_cache_bytes: u64,
    /// 10% of the weight size plus a fixed 256 MiB allowance.
    pub runtime_overhead_bytes: u64,
    /// Sum of the three components above (saturating at `u64::MAX`).
    pub total_bytes: u64,
    /// Whether `total_bytes` is strictly below the 64 GiB budget.
    pub fits_in_memory: bool,
}

/// Estimates memory needed to run a model with the given KV-cache geometry.
///
/// * KV cache = `2 (K and V) * num_layers * num_kv_heads * head_dim *
///   max_seq_len * 4` bytes. This assumes 4-byte (f32-sized) cache elements.
/// * Runtime overhead = `model_size_bytes / 10 + 256 MiB`.
/// * `fits_in_memory` compares the total against a fixed 64 GiB budget.
///
/// All arithmetic saturates at `u64::MAX`. Absurdly large inputs therefore
/// produce a huge estimate that does not fit. Without saturation they would
/// panic in debug builds, or wrap around to a small value in release builds.
pub fn estimate_memory_requirements(
    model_size_bytes: u64,
    max_seq_len: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
) -> MemoryEstimate {
    /// Bytes per KV-cache element (assumed f32).
    const KV_ELEMENT_BYTES: u64 = 4;
    /// Fixed runtime allowance on top of the proportional overhead.
    const FIXED_OVERHEAD_BYTES: u64 = 256 * 1024 * 1024;
    /// Memory budget used to decide `fits_in_memory`.
    const MEMORY_BUDGET_BYTES: u64 = 64 * 1024 * 1024 * 1024;

    let model_weights_bytes = model_size_bytes;

    // Factor 2 = one K and one V tensor per layer. The product is
    // order-independent, so fold the dimensions in with saturating multiplies.
    let kv_cache_bytes = [num_layers, num_kv_heads, head_dim, max_seq_len]
        .iter()
        .fold(2 * KV_ELEMENT_BYTES, |acc, &dim| {
            // usize -> u64 is lossless on every platform Rust supports.
            acc.saturating_mul(dim as u64)
        });

    let runtime_overhead_bytes = (model_weights_bytes / 10).saturating_add(FIXED_OVERHEAD_BYTES);

    let total_bytes = model_weights_bytes
        .saturating_add(kv_cache_bytes)
        .saturating_add(runtime_overhead_bytes);

    let fits_in_memory = total_bytes < MEMORY_BUDGET_BYTES;

    MemoryEstimate {
        model_weights_bytes,
        kv_cache_bytes,
        runtime_overhead_bytes,
        total_bytes,
        fits_in_memory,
    }
}
175
/// Renders a token count using decimal (SI-style) suffixes.
///
/// Values below 1,000 are printed exactly. Larger values are scaled to K, M
/// or B with one decimal place. Rounding is applied after scaling, so values
/// just under a boundary can print as e.g. `"1000.0K tokens"`.
pub fn format_token_count(count: usize) -> String {
    let value = count as f64;
    match count {
        0..=999 => format!("{} tokens", count),
        1_000..=999_999 => format!("{:.1}K tokens", value / 1_000.0),
        1_000_000..=999_999_999 => format!("{:.1}M tokens", value / 1_000_000.0),
        _ => format!("{:.1}B tokens", value / 1_000_000_000.0),
    }
}
198
/// Renders a byte count using binary (1024-based) units.
///
/// The units are labelled `KB`/`MB`/`GB`/`TB` but are powers of 1024. Values
/// below 1024 print as whole bytes. Everything else uses the largest unit not
/// exceeding the value, with two decimals, capping at `TB`.
pub fn format_bytes(bytes: u64) -> String {
    const STEP: u64 = 1024;
    // Largest unit first, so the first threshold `bytes` reaches wins.
    const SCALES: [(u64, &str); 4] = [
        (STEP * STEP * STEP * STEP, "TB"),
        (STEP * STEP * STEP, "GB"),
        (STEP * STEP, "MB"),
        (STEP, "KB"),
    ];

    if bytes < STEP {
        return format!("{} B", bytes);
    }

    // Here `bytes >= STEP`, so the "KB" entry always matches; the fallback
    // merely keeps this path panic-free.
    let (divisor, unit) = SCALES
        .iter()
        .copied()
        .find(|&(threshold, _)| bytes >= threshold)
        .unwrap_or((STEP, "KB"));

    format!("{:.2} {}", bytes as f64 / divisor as f64, unit)
}
229
/// Renders a duration at a granularity suited to its magnitude.
///
/// * under 1 s: whole milliseconds (`"123ms"`)
/// * under 1 min: seconds with two decimals (`"1.23s"`)
/// * under 1 h: whole minutes and seconds (`"5m 30s"`)
/// * otherwise: whole hours and minutes (`"1h 15m"`), hours unbounded
pub fn format_duration(duration: std::time::Duration) -> String {
    let millis = duration.as_millis();
    let secs = duration.as_secs();

    match millis {
        0..=999 => format!("{}ms", millis),
        1_000..=59_999 => format!("{:.2}s", duration.as_secs_f64()),
        60_000..=3_599_999 => format!("{}m {}s", secs / 60, secs % 60),
        _ => format!("{}h {}m", secs / 3600, secs % 3600 / 60),
    }
}
250
/// Renders a throughput figure in tokens per second (`"t/s"`).
///
/// Precision shrinks as the value grows:
/// * below 10: two decimals
/// * below 1000: one decimal
/// * otherwise: no decimals
///
/// Negative rates and NaN are not meaningful, so both render as
/// `"0.0 t/s"`. NaN arises from e.g. `0.0 / 0.0` when no tokens were produced
/// in zero time, and it would otherwise slip past the `< 0.0` guard.
/// Negative zero is normalised so it does not render as `"-0.00 t/s"`.
/// Positive infinity is passed through and renders as `"inf t/s"`.
pub fn format_tokens_per_second(tps: f64) -> String {
    if tps.is_nan() || tps < 0.0 {
        return "0.0 t/s".to_string();
    }

    // After the guard `tps` is >= 0.0 or -0.0; `abs` turns -0.0 into 0.0.
    let tps = tps.abs();

    if tps < 10.0 {
        format!("{:.2} t/s", tps)
    } else if tps < 1000.0 {
        format!("{:.1} t/s", tps)
    } else {
        format!("{:.0} t/s", tps)
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Temporary fixture file that is removed when dropped.
    ///
    /// The file name embeds the current process id, so concurrent test
    /// processes sharing the system temp dir cannot overwrite or delete each
    /// other's fixtures. `Drop` removes the file even when an assertion
    /// panics part-way through a test.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Creates the file `name` (made process-unique) with `contents`.
        fn new(name: &str, contents: &[u8]) -> Self {
            let path = unique_temp_path(name);
            std::fs::write(&path, contents).expect("write temp file");
            TempFile(path)
        }

        /// The fixture path as a `String`, as `validate_model_file` expects.
        fn path_string(&self) -> String {
            self.0.to_string_lossy().into_owned()
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file must not fail the run.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Returns a temp-dir path whose file name is unique to this process.
    fn unique_temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("oxibonsai_{}_{}", std::process::id(), name))
    }

    // ---- format_token_count ----

    #[test]
    fn format_token_count_small() {
        assert_eq!(format_token_count(0), "0 tokens");
        assert_eq!(format_token_count(42), "42 tokens");
        assert_eq!(format_token_count(999), "999 tokens");
    }

    #[test]
    fn format_token_count_thousands() {
        assert_eq!(format_token_count(1_000), "1.0K tokens");
        assert_eq!(format_token_count(1_234), "1.2K tokens");
        // Rounding happens after scaling, so this stays in the K bucket.
        assert_eq!(format_token_count(999_999), "1000.0K tokens");
    }

    #[test]
    fn format_token_count_millions() {
        assert_eq!(format_token_count(1_000_000), "1.0M tokens");
        assert_eq!(format_token_count(3_500_000), "3.5M tokens");
    }

    #[test]
    fn format_token_count_billions() {
        assert_eq!(format_token_count(1_000_000_000), "1.0B tokens");
    }

    // ---- format_bytes ----

    #[test]
    fn format_bytes_small() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1023), "1023 B");
    }

    #[test]
    fn format_bytes_kb() {
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1536), "1.50 KB");
    }

    #[test]
    fn format_bytes_mb() {
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(512 * 1024 * 1024), "512.00 MB");
    }

    #[test]
    fn format_bytes_gb() {
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
        assert_eq!(
            format_bytes(2 * 1024 * 1024 * 1024 + 300 * 1024 * 1024),
            "2.29 GB"
        );
    }

    #[test]
    fn format_bytes_tb() {
        assert_eq!(format_bytes(1024u64 * 1024 * 1024 * 1024), "1.00 TB");
    }

    // ---- format_duration ----

    #[test]
    fn format_duration_ms() {
        assert_eq!(format_duration(std::time::Duration::from_millis(0)), "0ms");
        assert_eq!(
            format_duration(std::time::Duration::from_millis(123)),
            "123ms"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_millis(999)),
            "999ms"
        );
    }

    #[test]
    fn format_duration_seconds() {
        assert_eq!(
            format_duration(std::time::Duration::from_millis(1_000)),
            "1.00s"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_millis(1_230)),
            "1.23s"
        );
    }

    #[test]
    fn format_duration_minutes() {
        assert_eq!(
            format_duration(std::time::Duration::from_secs(90)),
            "1m 30s"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_secs(330)),
            "5m 30s"
        );
    }

    #[test]
    fn format_duration_hours() {
        assert_eq!(
            format_duration(std::time::Duration::from_secs(4500)),
            "1h 15m"
        );
    }

    // ---- format_tokens_per_second ----

    #[test]
    fn format_tps() {
        assert_eq!(format_tokens_per_second(-1.0), "0.0 t/s");
        assert_eq!(format_tokens_per_second(0.0), "0.00 t/s");
        assert_eq!(format_tokens_per_second(0.5), "0.50 t/s");
        assert_eq!(format_tokens_per_second(23.4), "23.4 t/s");
        assert_eq!(format_tokens_per_second(150.0), "150.0 t/s");
        assert_eq!(format_tokens_per_second(1500.0), "1500 t/s");
    }

    // ---- estimate_memory_requirements ----

    #[test]
    fn estimate_memory_basic() {
        let est = estimate_memory_requirements(
            1_000_000_000, // model size
            4096,          // max_seq_len
            8,             // num_kv_heads
            128,           // head_dim
            36,            // num_layers
        );

        assert_eq!(est.model_weights_bytes, 1_000_000_000);
        assert_eq!(est.kv_cache_bytes, 2 * 36 * 8 * 128 * 4096 * 4);
        assert!(est.total_bytes > est.model_weights_bytes + est.kv_cache_bytes);
        assert!(est.fits_in_memory);
    }

    #[test]
    fn estimate_memory_large_model() {
        let est = estimate_memory_requirements(
            100_000_000_000, // model size
            32768,           // max_seq_len
            64,              // num_kv_heads
            128,             // head_dim
            80,              // num_layers
        );
        assert!(!est.fits_in_memory);
    }

    // ---- validate_model_file ----

    #[test]
    fn validate_model_file_nonexistent() {
        let path = unique_temp_path("nonexistent_model.gguf");
        let result = validate_model_file(path.to_str().expect("path is valid UTF-8"));
        assert!(matches!(result, Err(RuntimeError::FileNotFound { .. })));
    }

    #[test]
    fn validate_model_file_directory_is_rejected() {
        let dir = std::env::temp_dir();
        let result = validate_model_file(&dir.to_string_lossy());
        assert!(matches!(result, Err(RuntimeError::Config(_))));
    }

    #[test]
    fn validate_model_file_not_gguf() {
        let file = TempFile::new("not_gguf.bin", b"this is not a gguf file");

        let info = validate_model_file(&file.path_string()).expect("should return info");
        assert!(!info.is_valid_gguf);
        // "this" read as a little-endian u32.
        assert_eq!(info.format, "unknown (magic: 0x73696874)");
    }

    #[test]
    fn validate_model_file_valid_gguf_magic() {
        // GGUF magic followed by padding; only the first four bytes matter.
        let mut data = vec![0x47u8, 0x47, 0x55, 0x46];
        data.extend_from_slice(&[0u8; 100]);
        let file = TempFile::new("gguf_magic.bin", &data);

        let info = validate_model_file(&file.path_string()).expect("should return info");
        assert!(info.is_valid_gguf);
        assert_eq!(info.format, "GGUF");
        assert_eq!(info.size_bytes, 104);
    }

    #[test]
    fn validate_model_file_shorter_than_magic() {
        let file = TempFile::new("three_bytes.bin", b"GGU");

        let info = validate_model_file(&file.path_string()).expect("should return info");
        assert!(!info.is_valid_gguf);
        assert_eq!(info.format, "unknown");
        assert_eq!(info.size_bytes, 3);
    }

    #[test]
    fn validate_model_file_empty() {
        let file = TempFile::new("empty.bin", b"");

        let info = validate_model_file(&file.path_string()).expect("should return info");
        assert!(!info.is_valid_gguf);
        assert_eq!(info.format, "unknown");
        assert_eq!(info.size_bytes, 0);
    }

    // ---- data types ----

    #[test]
    fn generation_result_clone() {
        let result = GenerationResult {
            text: "hello".to_string(),
            token_ids: vec![1, 2, 3],
            prompt_tokens: 5,
            generated_tokens: 3,
            tokens_per_second: 10.0,
            finish_reason: "stop".to_string(),
        };
        let cloned = result.clone();
        assert_eq!(cloned.text, "hello");
        assert_eq!(cloned.generated_tokens, 3);
    }

    #[test]
    fn token_stats_default() {
        let stats = TokenStats::default();
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.prompt_tokens, 0);
        assert_eq!(stats.completion_tokens, 0);
        assert!((stats.time_to_first_token_ms - 0.0).abs() < f64::EPSILON);
        assert!((stats.tokens_per_second - 0.0).abs() < f64::EPSILON);
    }
}