Skip to main content

oxibonsai_runtime/
convenience.rs

1//! High-level convenience functions for common OxiBonsai operations.
2
3use crate::error::{RuntimeError, RuntimeResult};
4
/// Result of a text-generation call, bundled with basic metadata.
///
/// Derives `PartialEq` so results can be compared directly (for example in
/// tests). `Eq` is intentionally not derived because `tokens_per_second` is
/// an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct GenerationResult {
    /// The generated text.
    pub text: String,
    /// Token IDs of the generated tokens.
    pub token_ids: Vec<u32>,
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Number of generated tokens.
    pub generated_tokens: usize,
    /// Generation speed in tokens per second.
    pub tokens_per_second: f64,
    /// Reason generation stopped (e.g. "stop", "length", "error").
    pub finish_reason: String,
}
21
/// Simple token generation statistics.
///
/// `Default` yields all-zero statistics. `PartialEq` is derived for easy
/// comparison; `Eq` is not, because two fields are `f64`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TokenStats {
    /// Total tokens (prompt + completion). The type does not enforce this
    /// relationship; whoever fills the struct is responsible for it.
    pub total_tokens: usize,
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Number of generated tokens.
    pub completion_tokens: usize,
    /// Time to first token in milliseconds.
    pub time_to_first_token_ms: f64,
    /// Average generation speed in tokens per second.
    pub tokens_per_second: f64,
}
36
/// Information about a model file, as gathered by a shallow header check.
///
/// All fields are plain data, so both `PartialEq` and `Eq` are derived.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelFileInfo {
    /// File path, exactly as supplied by the caller.
    pub path: String,
    /// File size in bytes.
    pub size_bytes: u64,
    /// Detected format description.
    pub format: String,
    /// Whether the file appears to be a valid GGUF file.
    pub is_valid_gguf: bool,
}
49
50/// Validate that a model file exists and has the correct format.
51///
52/// Checks for file existence, reads the magic number, and verifies
53/// it matches the GGUF format (magic = 0x46554747).
54pub fn validate_model_file(path: &str) -> RuntimeResult<ModelFileInfo> {
55    let metadata = std::fs::metadata(path).map_err(|e| {
56        if e.kind() == std::io::ErrorKind::NotFound {
57            RuntimeError::FileNotFound {
58                path: path.to_string(),
59            }
60        } else {
61            RuntimeError::Io(e)
62        }
63    })?;
64
65    if !metadata.is_file() {
66        return Err(RuntimeError::Config(format!(
67            "path '{}' is not a regular file",
68            path
69        )));
70    }
71
72    let size_bytes = metadata.len();
73
74    // Check GGUF magic number (first 4 bytes)
75    let mut is_valid_gguf = false;
76    let mut format = "unknown".to_string();
77
78    if size_bytes >= 4 {
79        let file = std::fs::File::open(path).map_err(RuntimeError::Io)?;
80        let mut reader = std::io::BufReader::new(file);
81        let mut magic_bytes = [0u8; 4];
82        use std::io::Read;
83        if reader.read_exact(&mut magic_bytes).is_ok() {
84            let magic = u32::from_le_bytes(magic_bytes);
85            if magic == 0x46554747 {
86                is_valid_gguf = true;
87                format = "GGUF".to_string();
88            } else {
89                format = format!("unknown (magic: 0x{:08X})", magic);
90            }
91        }
92    }
93
94    Ok(ModelFileInfo {
95        path: path.to_string(),
96        size_bytes,
97        format,
98        is_valid_gguf,
99    })
100}
101
/// Memory usage estimate for model inference, as returned by
/// [`estimate_memory_requirements`].
///
/// All fields are plain integers/booleans, so `PartialEq` and `Eq` are
/// derived.
///
/// # Example
///
/// ```
/// use oxibonsai_runtime::convenience::estimate_memory_requirements;
///
/// let est = estimate_memory_requirements(
///     1_000_000_000, // 1 GB model
///     4096,          // max sequence length
///     8,             // KV heads
///     128,           // head dim
///     36,            // layers
/// );
/// assert!(est.total_bytes > est.model_weights_bytes);
/// assert!(est.fits_in_memory);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryEstimate {
    /// Memory required for model weights, in bytes.
    pub model_weights_bytes: u64,
    /// Memory required for the KV cache, in bytes.
    pub kv_cache_bytes: u64,
    /// Estimated runtime overhead (buffers, activations, etc.), in bytes.
    pub runtime_overhead_bytes: u64,
    /// Total estimated memory requirement, in bytes.
    pub total_bytes: u64,
    /// Heuristic check: whether `total_bytes` is under the fixed budget used
    /// by [`estimate_memory_requirements`]. This is not a query of the host's
    /// actual free memory.
    pub fits_in_memory: bool,
}
132
133/// Estimate memory requirements for inference.
134///
135/// This provides a rough estimate based on model dimensions.
136/// For 1-bit models, weight memory is significantly reduced compared
137/// to FP16/FP32 models.
138///
139/// # Parameters
140/// - `model_size_bytes`: Size of the model file on disk.
141/// - `max_seq_len`: Maximum sequence length for KV cache.
142/// - `num_kv_heads`: Number of KV attention heads.
143/// - `head_dim`: Dimension of each attention head.
144/// - `num_layers`: Number of transformer layers.
145pub fn estimate_memory_requirements(
146    model_size_bytes: u64,
147    max_seq_len: usize,
148    num_kv_heads: usize,
149    head_dim: usize,
150    num_layers: usize,
151) -> MemoryEstimate {
152    let model_weights_bytes = model_size_bytes;
153
154    // KV cache: 2 (K+V) * num_layers * num_kv_heads * head_dim * max_seq_len * 4 bytes (f32)
155    let kv_cache_bytes =
156        2u64 * num_layers as u64 * num_kv_heads as u64 * head_dim as u64 * max_seq_len as u64 * 4;
157
158    // Runtime overhead: ~10% of model weights + some fixed overhead for activations
159    let runtime_overhead_bytes = model_weights_bytes / 10 + 256 * 1024 * 1024; // +256MB base
160
161    let total_bytes = model_weights_bytes + kv_cache_bytes + runtime_overhead_bytes;
162
163    // Heuristic: check against a reasonable memory budget (e.g. 90% of typical systems)
164    // For now, we just check if total is under 64GB which covers most systems
165    let fits_in_memory = total_bytes < 64 * 1024 * 1024 * 1024;
166
167    MemoryEstimate {
168        model_weights_bytes,
169        kv_cache_bytes,
170        runtime_overhead_bytes,
171        total_bytes,
172        fits_in_memory,
173    }
174}
175
/// Render a token count as a short, human-friendly string.
///
/// Counts below one thousand are printed verbatim. Larger counts are scaled
/// to thousands (`K`), millions (`M`) or billions (`B`) using decimal (base-10)
/// units and shown with one decimal place. The unit is chosen from the exact
/// integer count *before* rounding, so values just below a boundary can
/// render as e.g. `"1000.0K tokens"` for `999_999`.
///
/// # Example
///
/// ```
/// use oxibonsai_runtime::convenience::format_token_count;
///
/// assert_eq!(format_token_count(42), "42 tokens");
/// assert_eq!(format_token_count(1_500), "1.5K tokens");
/// assert_eq!(format_token_count(3_500_000), "3.5M tokens");
/// ```
pub fn format_token_count(count: usize) -> String {
    let as_float = count as f64;
    match count {
        0..=999 => format!("{} tokens", count),
        1_000..=999_999 => format!("{:.1}K tokens", as_float / 1_000.0),
        1_000_000..=999_999_999 => format!("{:.1}M tokens", as_float / 1_000_000.0),
        _ => format!("{:.1}B tokens", as_float / 1_000_000_000.0),
    }
}
198
/// Render a byte count as a short, human-friendly string.
///
/// Sizes below 1024 bytes are printed as an integer followed by `B`. Larger
/// sizes use binary (1024-based) units labelled `KB`, `MB`, `GB` and `TB`,
/// with two decimal places. The unit is picked from the exact byte count
/// before rounding, so `1024 * 1024 - 1` renders as `"1024.00 KB"`.
///
/// # Example
///
/// ```
/// use oxibonsai_runtime::convenience::format_bytes;
///
/// assert_eq!(format_bytes(512), "512 B");
/// assert_eq!(format_bytes(1024), "1.00 KB");
/// assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
/// assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
/// ```
pub fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1 << 10;
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;
    const TIB: u64 = 1 << 40;

    // `bytes` expressed as a (fractional) multiple of `unit`.
    let scaled = |unit: u64| bytes as f64 / unit as f64;

    match bytes {
        b if b < KIB => format!("{} B", b),
        b if b < MIB => format!("{:.2} KB", scaled(KIB)),
        b if b < GIB => format!("{:.2} MB", scaled(MIB)),
        b if b < TIB => format!("{:.2} GB", scaled(GIB)),
        _ => format!("{:.2} TB", scaled(TIB)),
    }
}
229
/// Render a duration as a short, human-friendly string.
///
/// The representation depends on the magnitude:
/// - under one second: whole milliseconds, e.g. `"123ms"`;
/// - under one minute: seconds with two decimals, e.g. `"1.23s"`;
/// - under one hour: `"<m>m <s>s"`, e.g. `"5m 30s"`;
/// - otherwise: `"<h>h <m>m"`, e.g. `"1h 15m"`. Hours are not rolled up into
///   days.
///
/// Smaller units are truncated, not rounded, in the minute and hour forms.
/// The seconds form rounds to two decimals, so 59.999 s prints as `"60.00s"`.
pub fn format_duration(duration: std::time::Duration) -> String {
    let millis = duration.as_millis();
    let whole_secs = duration.as_secs();

    match millis {
        0..=999 => format!("{}ms", millis),
        1_000..=59_999 => format!("{:.2}s", duration.as_secs_f64()),
        60_000..=3_599_999 => format!("{}m {}s", whole_secs / 60, whole_secs % 60),
        _ => format!("{}h {}m", whole_secs / 3600, whole_secs % 3600 / 60),
    }
}
250
/// Format a tokens-per-second rate for display.
///
/// Precision shrinks as the rate grows:
/// - below 10 t/s: two decimals;
/// - below 1000 t/s: one decimal;
/// - otherwise: none.
///
/// Negative rates and NaN (e.g. from `0.0 / 0.0` when nothing was generated
/// in zero elapsed time) are shown as `"0.0 t/s"`. Negative zero is
/// normalised so it never prints as `"-0.00 t/s"`. Positive infinity is
/// printed as `"inf t/s"`.
///
/// Examples: "23.4 t/s", "0.50 t/s", "150.0 t/s", "1500 t/s"
pub fn format_tokens_per_second(tps: f64) -> String {
    // NaN fails every ordered comparison, so without the explicit check it
    // would fall through to the last branch and print "NaN t/s".
    if tps.is_nan() || tps < 0.0 {
        return "0.0 t/s".to_string();
    }
    // Only +0.0 and -0.0 are affected here (all other values are positive);
    // this stops -0.0 from formatting as "-0.00".
    let tps = tps.abs();

    if tps < 10.0 {
        format!("{:.2} t/s", tps)
    } else if tps < 1000.0 {
        format!("{:.1} t/s", tps)
    } else {
        format!("{:.0} t/s", tps)
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixture file in the system temp directory that deletes itself on drop.
    ///
    /// Because cleanup happens in `Drop`, it also runs when an assertion
    /// panics, so failing tests do not leave stray files behind. The process
    /// id is part of the file name so concurrent test runs sharing a temp
    /// directory cannot clobber each other's fixtures.
    struct TempFile {
        path: std::path::PathBuf,
    }

    impl TempFile {
        /// Write `contents` to a fresh process-unique fixture named after `name`.
        fn new(name: &str, contents: &[u8]) -> Self {
            let path = temp_path(name);
            std::fs::write(&path, contents).expect("write temp file");
            Self { path }
        }

        /// The fixture path as a string, in the form `validate_model_file` takes.
        fn path_str(&self) -> String {
            self.path.to_string_lossy().into_owned()
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: a leftover file in the temp dir is harmless.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Process-unique path in the system temp directory for a fixture name.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "oxibonsai_test_{}_{}",
            std::process::id(),
            name
        ))
    }

    // ── format_token_count ──

    #[test]
    fn format_token_count_small() {
        assert_eq!(format_token_count(0), "0 tokens");
        assert_eq!(format_token_count(42), "42 tokens");
        assert_eq!(format_token_count(999), "999 tokens");
    }

    #[test]
    fn format_token_count_thousands() {
        assert_eq!(format_token_count(1_000), "1.0K tokens");
        assert_eq!(format_token_count(1_234), "1.2K tokens");
        assert_eq!(format_token_count(999_999), "1000.0K tokens");
    }

    #[test]
    fn format_token_count_millions() {
        assert_eq!(format_token_count(1_000_000), "1.0M tokens");
        assert_eq!(format_token_count(3_500_000), "3.5M tokens");
    }

    #[test]
    fn format_token_count_billions() {
        assert_eq!(format_token_count(1_000_000_000), "1.0B tokens");
    }

    // ── format_bytes ──

    #[test]
    fn format_bytes_small() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1023), "1023 B");
    }

    #[test]
    fn format_bytes_kb() {
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1536), "1.50 KB");
    }

    #[test]
    fn format_bytes_mb() {
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(512 * 1024 * 1024), "512.00 MB");
    }

    #[test]
    fn format_bytes_gb() {
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
        assert_eq!(
            format_bytes(2 * 1024 * 1024 * 1024 + 300 * 1024 * 1024),
            "2.29 GB"
        );
    }

    #[test]
    fn format_bytes_tb() {
        assert_eq!(format_bytes(1024u64 * 1024 * 1024 * 1024), "1.00 TB");
    }

    // ── format_duration ──

    #[test]
    fn format_duration_ms() {
        assert_eq!(format_duration(std::time::Duration::from_millis(0)), "0ms");
        assert_eq!(
            format_duration(std::time::Duration::from_millis(123)),
            "123ms"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_millis(999)),
            "999ms"
        );
    }

    #[test]
    fn format_duration_seconds() {
        assert_eq!(
            format_duration(std::time::Duration::from_millis(1_000)),
            "1.00s"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_millis(1_230)),
            "1.23s"
        );
    }

    #[test]
    fn format_duration_minutes() {
        assert_eq!(
            format_duration(std::time::Duration::from_secs(90)),
            "1m 30s"
        );
        assert_eq!(
            format_duration(std::time::Duration::from_secs(330)),
            "5m 30s"
        );
    }

    #[test]
    fn format_duration_hours() {
        assert_eq!(
            format_duration(std::time::Duration::from_secs(4500)),
            "1h 15m"
        );
    }

    // ── format_tokens_per_second ──

    #[test]
    fn format_tps() {
        assert_eq!(format_tokens_per_second(-1.0), "0.0 t/s");
        assert_eq!(format_tokens_per_second(0.0), "0.00 t/s");
        assert_eq!(format_tokens_per_second(0.5), "0.50 t/s");
        assert_eq!(format_tokens_per_second(23.4), "23.4 t/s");
        assert_eq!(format_tokens_per_second(150.0), "150.0 t/s");
        assert_eq!(format_tokens_per_second(1500.0), "1500 t/s");
    }

    // ── memory estimation ──

    #[test]
    fn estimate_memory_basic() {
        let est = estimate_memory_requirements(
            1_000_000_000, // ~1GB model
            4096,          // max_seq_len
            8,             // num_kv_heads
            128,           // head_dim
            36,            // num_layers
        );

        assert_eq!(est.model_weights_bytes, 1_000_000_000);
        // KV cache: 2 * 36 * 8 * 128 * 4096 * 4 = 1,207,959,552
        assert_eq!(est.kv_cache_bytes, 2 * 36 * 8 * 128 * 4096 * 4);
        assert!(est.total_bytes > est.model_weights_bytes + est.kv_cache_bytes);
        assert!(est.fits_in_memory);
    }

    #[test]
    fn estimate_memory_large_model() {
        let est = estimate_memory_requirements(
            100_000_000_000, // 100GB
            32768,
            64,
            128,
            80,
        );
        // This should not fit in 64GB
        assert!(!est.fits_in_memory);
    }

    // ── validate_model_file ──

    #[test]
    fn validate_model_file_nonexistent() {
        let path = temp_path("nonexistent_model.gguf");
        let result = validate_model_file(path.to_str().expect("path is valid UTF-8"));
        assert!(matches!(result, Err(RuntimeError::FileNotFound { .. })));
    }

    #[test]
    fn validate_model_file_directory() {
        // A directory exists but is not a regular file.
        let dir = std::env::temp_dir();
        let result = validate_model_file(&dir.to_string_lossy());
        assert!(matches!(result, Err(RuntimeError::Config(_))));
    }

    #[test]
    fn validate_model_file_not_gguf() {
        let contents = b"this is not a gguf file";
        let file = TempFile::new("not_gguf.bin", contents);

        let info = validate_model_file(&file.path_str()).expect("should return info");
        assert!(!info.is_valid_gguf);
        // The bytes "this" read as a little-endian u32 are 0x73696874.
        assert_eq!(info.format, "unknown (magic: 0x73696874)");
        assert_eq!(info.size_bytes, contents.len() as u64);
    }

    #[test]
    fn validate_model_file_valid_gguf_magic() {
        // GGUF magic = 0x46554747 = little-endian bytes [0x47, 0x47, 0x55, 0x46]
        let mut data = vec![0x47u8, 0x47, 0x55, 0x46];
        data.extend_from_slice(&[0u8; 100]); // pad with zeros
        let file = TempFile::new("gguf_magic.bin", &data);

        let info = validate_model_file(&file.path_str()).expect("should return info");
        assert!(info.is_valid_gguf);
        assert_eq!(info.format, "GGUF");
        assert_eq!(info.size_bytes, data.len() as u64);
        assert_eq!(info.path, file.path_str());
    }

    #[test]
    fn validate_model_file_empty() {
        let file = TempFile::new("empty.bin", b"");

        let info = validate_model_file(&file.path_str()).expect("should return info");
        assert!(!info.is_valid_gguf);
        assert_eq!(info.format, "unknown");
        assert_eq!(info.size_bytes, 0);
    }

    #[test]
    fn validate_model_file_shorter_than_magic() {
        // Two bytes are too few to hold the 4-byte magic number.
        let file = TempFile::new("short.bin", b"GG");

        let info = validate_model_file(&file.path_str()).expect("should return info");
        assert!(!info.is_valid_gguf);
        assert_eq!(info.format, "unknown");
        assert_eq!(info.size_bytes, 2);
    }

    // ── GenerationResult / TokenStats ──

    #[test]
    fn generation_result_clone() {
        let result = GenerationResult {
            text: "hello".to_string(),
            token_ids: vec![1, 2, 3],
            prompt_tokens: 5,
            generated_tokens: 3,
            tokens_per_second: 10.0,
            finish_reason: "stop".to_string(),
        };
        let cloned = result.clone();
        assert_eq!(cloned.text, "hello");
        assert_eq!(cloned.token_ids, vec![1, 2, 3]);
        assert_eq!(cloned.generated_tokens, 3);
        assert_eq!(cloned.finish_reason, "stop");
    }

    #[test]
    fn token_stats_default() {
        let stats = TokenStats::default();
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.prompt_tokens, 0);
        assert_eq!(stats.completion_tokens, 0);
        assert!((stats.time_to_first_token_ms - 0.0).abs() < f64::EPSILON);
        assert!((stats.tokens_per_second - 0.0).abs() < f64::EPSILON);
    }
}