/// Model dimensions needed to estimate GPU memory (VRAM) requirements.
///
/// Every field is a count or a dimension (not bytes); the `estimate_*` methods
/// convert these into byte counts.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Model width; sizes every projection matrix, the norms and the LM head.
    pub hidden_dim: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of query heads; `hidden_dim / num_heads` is the per-head dimension.
    pub num_heads: usize,
    /// Number of key/value heads; `num_kv_heads * head_dim` sizes the K/V
    /// projections and the KV cache.
    pub num_kv_heads: usize,
    /// Vocabulary size (the LM head is `hidden_dim x vocab_size`).
    pub vocab_size: usize,
    /// Width of the feed-forward (gate/up/down) matrices.
    pub intermediate_dim: usize,
    /// Sequence length the KV cache is sized for.
    pub max_seq_len: usize,
}

/// Bytes per `f32` element, used by the full-precision estimates.
const F32_SIZE: usize = 4;

/// Fraction of the reported free VRAM that mode selection treats as usable.
const VRAM_SAFETY_MARGIN: f64 = 0.90;

impl StreamingConfig {
    /// Per-head dimension, `hidden_dim / num_heads`.
    ///
    /// `num_heads` is clamped to at least 1 so a degenerate config produces an
    /// estimate instead of a division-by-zero panic (matching the behavior
    /// `estimate_prefill_cache_vram` always had).
    fn head_dim(&self) -> usize {
        self.hidden_dim / self.num_heads.max(1)
    }

    /// Combined width of all key (or, equivalently, all value) heads.
    fn kv_dim(&self) -> usize {
        self.num_kv_heads * self.head_dim()
    }

    /// Bytes of the f32 LM head plus the f32 final output norm; these are
    /// budgeted in both full-cache and streaming modes.
    fn resident_head_bytes(&self) -> usize {
        let lm_head_bytes = self.hidden_dim * self.vocab_size * F32_SIZE;
        let output_norm_bytes = self.hidden_dim * F32_SIZE;
        lm_head_bytes + output_norm_bytes
    }

    /// Bytes of the f32 KV cache: one K and one V buffer per layer, each
    /// `max_seq_len x kv_dim`.
    fn kv_cache_bytes(&self) -> usize {
        2 * self.num_layers * self.max_seq_len * self.kv_dim() * F32_SIZE
    }

    /// Estimates VRAM in bytes when the weights of *all* layers are cached on the GPU.
    ///
    /// This is the sum of the LM head, the output norm, `num_layers` copies of
    /// [`Self::estimate_layer_vram`] and the KV cache, all in f32.
    #[must_use]
    pub fn estimate_full_cache_vram(&self) -> usize {
        self.resident_head_bytes()
            + self.num_layers * self.estimate_layer_vram()
            + self.kv_cache_bytes()
    }

    /// Estimates VRAM in bytes for layer streaming.
    ///
    /// Same as [`Self::estimate_full_cache_vram`], except that it budgets a
    /// single layer buffer instead of all `num_layers` layers.
    #[must_use]
    pub fn estimate_streaming_vram(&self) -> usize {
        self.resident_head_bytes() + self.estimate_layer_vram() + self.kv_cache_bytes()
    }

    /// Estimates the f32 weight bytes of one transformer layer.
    ///
    /// Counts the QKV projection (`hidden_dim x (hidden_dim + 2 * kv_dim)`), the
    /// output projection, the three FFN matrices (gate, up, down) and two norm
    /// vectors.
    #[must_use]
    pub fn estimate_layer_vram(&self) -> usize {
        let h = self.hidden_dim;
        let qkv_out_dim = h + 2 * self.kv_dim();
        let qkv = h * qkv_out_dim;
        let o_proj = h * h;
        let ffn = 3 * h * self.intermediate_dim; // gate + up + down
        let norms = 2 * h;
        (qkv + o_proj + ffn + norms) * F32_SIZE
    }

    /// Estimates bytes of a reduced-precision prefill weight cache.
    ///
    /// Uses 1 byte per element when `fp8` is true and 2 bytes otherwise
    /// (presumably FP16/BF16 — confirm against the cache implementation).
    /// Counts the QKV, output and three FFN matrices of every layer plus the LM
    /// head. Norms and the KV cache are *not* included.
    #[must_use]
    pub fn estimate_prefill_cache_vram(&self, fp8: bool) -> usize {
        let bytes_per_elem: usize = if fp8 { 1 } else { 2 };
        let h = self.hidden_dim;
        let qkv_out = h + 2 * self.kv_dim();
        let per_layer = h * qkv_out + h * h + h * self.intermediate_dim * 3;
        let lm_head = h * self.vocab_size;
        (per_layer * self.num_layers + lm_head) * bytes_per_elem
    }
}
/// Reports whether layer streaming should be used instead of full caching.
///
/// The budget is `free_vram` scaled by [`VRAM_SAFETY_MARGIN`]. The result is
/// `true` only when the full cache exceeds that budget *and* streaming fits
/// within it. It is `false` both when the full cache fits and when not even
/// streaming fits.
#[must_use]
pub fn should_use_streaming(free_vram: usize, config: &StreamingConfig) -> bool {
    let budget = (free_vram as f64 * VRAM_SAFETY_MARGIN) as usize;
    // `&&` short-circuits: the streaming estimate is only computed when the
    // full cache does not fit.
    config.estimate_full_cache_vram() > budget && config.estimate_streaming_vram() <= budget
}
/// Chooses the [`StreamingMode`] that fits in the currently free VRAM.
///
/// Only `free_vram * VRAM_SAFETY_MARGIN` bytes are treated as usable. Full
/// caching is preferred when it fits; otherwise layer streaming is chosen if it
/// fits.
///
/// * `free_vram` - free device memory, in bytes.
/// * `total_vram` - total device memory, in bytes; only used in the error message.
/// * `config` - model dimensions used for the estimates.
///
/// # Errors
///
/// Returns a human-readable message when neither mode fits in the usable VRAM.
/// The message reports the required sizes, the raw free/total VRAM and the
/// usable amount after the safety margin.
pub fn check_vram_sufficient(
    free_vram: usize,
    total_vram: usize,
    config: &StreamingConfig,
) -> Result<StreamingMode, String> {
    const MB: usize = 1024 * 1024;

    let full_cache_required = config.estimate_full_cache_vram();
    let streaming_required = config.estimate_streaming_vram();
    let usable_vram = (free_vram as f64 * VRAM_SAFETY_MARGIN) as usize;

    if full_cache_required <= usable_vram {
        return Ok(StreamingMode::FullCache);
    }
    if streaming_required <= usable_vram {
        return Ok(StreamingMode::LayerStreaming);
    }

    // The MB figures are only needed for the diagnostic.
    let full_mb = full_cache_required / MB;
    let streaming_mb = streaming_required / MB;
    let free_mb = free_vram / MB;
    let total_mb = total_vram / MB;
    // The decision uses the post-margin amount. Report it as well, so the
    // message can't show "Available" >= required while still failing.
    let usable_mb = usable_vram / MB;
    Err(format!(
        "Insufficient VRAM for GPU inference (GH-201). \
         Full cache: {full_mb} MB, Streaming: {streaming_mb} MB, \
         Available: {free_mb} MB (of {total_mb} MB total; \
         {usable_mb} MB usable after safety margin). \
         Solutions: (1) Use GGUF format: `apr run model.gguf`, \
         (2) Use CPU inference: `--device cpu`, \
         (3) Free GPU memory by closing other applications."
    ))
}
/// How model weights are kept on the GPU during inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// All layers' weights are pre-cached on the GPU.
    FullCache,
    /// Layer weights are loaded per layer rather than all cached at once.
    LayerStreaming,
}

impl StreamingMode {
    /// Returns a short human-readable label for this mode (e.g. for logs).
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            StreamingMode::FullCache => "Full Cache (all layers pre-cached on GPU)",
            StreamingMode::LayerStreaming => "Layer Streaming (weights loaded per-layer)",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: usize = 1024 * 1024;
    const GIB: usize = 1024 * MIB;

    /// Qwen2-1.5B dimensions: a realistic model with fewer KV heads than query heads.
    fn qwen2_1_5b_config() -> StreamingConfig {
        StreamingConfig {
            hidden_dim: 1536,
            num_layers: 28,
            num_heads: 12,
            num_kv_heads: 2,
            vocab_size: 151936,
            intermediate_dim: 8960,
            max_seq_len: 2048,
        }
    }

    /// Tiny config whose estimates are easy to compute by hand.
    fn small_config() -> StreamingConfig {
        StreamingConfig {
            hidden_dim: 256,
            num_layers: 4,
            num_heads: 4,
            num_kv_heads: 4,
            vocab_size: 1000,
            intermediate_dim: 512,
            max_seq_len: 512,
        }
    }

    #[test]
    fn test_full_cache_vram_qwen2_1_5b() {
        let vram_mb = qwen2_1_5b_config().estimate_full_cache_vram() / MIB;
        assert!(
            vram_mb > 5500 && vram_mb < 7000,
            "Expected 5.5-7 GB, got {vram_mb} MB"
        );
    }

    #[test]
    fn test_streaming_vram_much_smaller() {
        let config = qwen2_1_5b_config();
        let full = config.estimate_full_cache_vram();
        let streaming = config.estimate_streaming_vram();
        assert!(
            streaming < full / 2,
            "Streaming ({} MB) should be < half of full cache ({} MB)",
            streaming / MIB,
            full / MIB
        );
    }

    #[test]
    fn test_streaming_vram_includes_lm_head_and_kv() {
        let config = qwen2_1_5b_config();
        let streaming = config.estimate_streaming_vram();
        let lm_head = config.hidden_dim * config.vocab_size * F32_SIZE;
        assert!(
            streaming > lm_head,
            "Streaming VRAM should include more than just LM head"
        );
    }

    #[test]
    fn test_layer_vram_estimate() {
        let layer_mb = qwen2_1_5b_config().estimate_layer_vram() / MIB;
        assert!(
            layer_mb > 150 && layer_mb < 250,
            "Expected 150-250 MB per layer, got {layer_mb} MB"
        );
    }

    /// Pins the exact arithmetic of the three f32 estimators (hand-computed).
    #[test]
    fn test_small_config_exact_estimates() {
        let config = small_config();
        assert_eq!(config.estimate_layer_vram(), 2_623_488);
        assert_eq!(config.estimate_full_cache_vram(), 15_713_280);
        assert_eq!(config.estimate_streaming_vram(), 7_842_816);
    }

    /// The prefill cache uses 1 byte/element for FP8 and 2 otherwise.
    #[test]
    fn test_prefill_cache_vram_fp8_vs_fp16() {
        let config = small_config();
        let fp8 = config.estimate_prefill_cache_vram(true);
        let fp16 = config.estimate_prefill_cache_vram(false);
        assert_eq!(fp8, 2_877_440);
        assert_eq!(fp16, 2 * fp8);
    }

    #[test]
    fn test_should_use_streaming_small_vram() {
        let config = qwen2_1_5b_config();
        assert!(
            should_use_streaming(2 * GIB, &config),
            "2GB VRAM should trigger streaming mode"
        );
    }

    #[test]
    fn test_should_use_streaming_large_vram() {
        let config = qwen2_1_5b_config();
        assert!(
            !should_use_streaming(12 * GIB, &config),
            "12GB VRAM should use full cache mode"
        );
    }

    /// Free VRAM exactly equal to a requirement is not enough, because only
    /// `VRAM_SAFETY_MARGIN` of it is usable.
    #[test]
    fn test_safety_margin_is_applied() {
        let config = small_config();
        let full = config.estimate_full_cache_vram();
        let streaming = config.estimate_streaming_vram();

        assert!(should_use_streaming(full, &config));
        assert_eq!(
            check_vram_sufficient(full, full, &config),
            Ok(StreamingMode::LayerStreaming)
        );

        let result = check_vram_sufficient(streaming, streaming, &config);
        assert!(result.is_err());
        assert!(!should_use_streaming(streaming, &config));
    }

    #[test]
    fn test_check_vram_sufficient_full_cache() {
        let config = small_config();
        let result = check_vram_sufficient(GIB, 2 * GIB, &config);
        assert_eq!(result, Ok(StreamingMode::FullCache));
    }

    #[test]
    fn test_check_vram_sufficient_streaming() {
        let config = qwen2_1_5b_config();
        let result = check_vram_sufficient(2 * GIB, 4 * GIB, &config);
        assert_eq!(result, Ok(StreamingMode::LayerStreaming));
    }

    #[test]
    fn test_check_vram_insufficient() {
        let config = qwen2_1_5b_config();
        let result = check_vram_sufficient(512 * MIB, GIB, &config);
        let message = result.expect_err("512 MB must be insufficient for Qwen2-1.5B");
        assert!(message.contains("Insufficient VRAM"));
    }

    #[test]
    fn test_streaming_mode_description() {
        assert!(StreamingMode::FullCache
            .description()
            .contains("pre-cached"));
        assert!(StreamingMode::LayerStreaming
            .description()
            .contains("per-layer"));
    }
}