use crate::format::v2::{AprV2ReaderRef, AprV2StreamingWriter};
/// Default file-size threshold (4 GiB) at or above which APR files take the
/// streaming quantization path; overridable via `APR_STREAMING_THRESHOLD`
/// (see `effective_streaming_threshold`).
pub(crate) const STREAMING_THRESHOLD_BYTES: u64 = 4 * 1024 * 1024 * 1024;
/// Quantize an APR v2 file to Q4_K, processing one tensor per iteration so
/// that at most one tensor's f32 and quantized buffers are live at a time.
///
/// Tensors rejected by `should_quantize_tensor` are written through as f32.
/// Returns the number of tensors written; every failure is wrapped in
/// `AprenderError::FormatError` with a message naming the failing stage.
pub(crate) fn streaming_quantize_apr_to_q4k(input: &Path, output: &Path) -> Result<usize> {
    use crate::bundle::MappedFile;

    // Memory-map the source so tensor payloads are paged in on demand
    // rather than read into RAM up front.
    let mapped = MappedFile::open(input).map_err(|e| AprenderError::FormatError {
        message: format!("mmap '{}' failed: {e}", input.display()),
    })?;
    // Best-effort hint: we walk the file front to back exactly once.
    #[cfg(unix)]
    {
        let _ = mapped.advise_sequential();
    }

    let reader =
        AprV2ReaderRef::from_bytes(mapped.as_slice()).map_err(|e| AprenderError::FormatError {
            message: format!("APR v2 parse of '{}' failed: {e:?}", input.display()),
        })?;

    // Own the tensor names so iteration below doesn't hold a borrow of the
    // reader's name slice.
    let tensor_names: Vec<String> = reader
        .tensor_names()
        .iter()
        .map(|s| (*s).to_string())
        .collect();

    // Total element count across all tensors, recorded in output metadata.
    let total_params: u64 = tensor_names
        .iter()
        .filter_map(|n| reader.get_tensor(n))
        .map(|e| e.element_count() as u64)
        .sum();

    let mut writer =
        AprV2StreamingWriter::new(build_streaming_q4k_metadata(reader.metadata(), total_params))
            .map_err(|e| AprenderError::FormatError {
                message: format!("streaming writer init failed: {e:?}"),
            })?;

    for name in &tensor_names {
        let entry = reader
            .get_tensor(name)
            .ok_or_else(|| AprenderError::FormatError {
                message: format!("tensor '{name}' missing from index"),
            })?;
        let shape = entry.shape.clone();
        let f32_data =
            reader
                .get_tensor_as_f32(name)
                .ok_or_else(|| AprenderError::FormatError {
                    message: format!(
                        "failed to dequantize tensor '{name}' (dtype {:?})",
                        entry.dtype
                    ),
                })?;
        validate_tensor_values(name, &f32_data)?;

        if !should_quantize_tensor(name, &shape, f32_data.len()) {
            // Pass-through path: tensor stays f32 in the output.
            writer
                .add_f32_tensor(name.clone(), shape, &f32_data)
                .map_err(|e| AprenderError::FormatError {
                    message: format!("write f32 '{name}' failed: {e:?}"),
                })?;
        } else {
            let q4k_bytes = quantize_q4_k_matrix(&f32_data, &shape);
            // Free the f32 buffer before writing to cap peak memory usage.
            drop(f32_data);
            writer
                .add_q4k_raw_tensor(name.clone(), shape, &q4k_bytes)
                .map_err(|e| AprenderError::FormatError {
                    message: format!("write q4k '{name}' failed: {e:?}"),
                })?;
        }
    }

    writer
        .finalize(output)
        .map_err(|e| AprenderError::FormatError {
            message: format!("finalize '{}' failed: {e:?}", output.display()),
        })?;
    Ok(tensor_names.len())
}
/// Clone `source` metadata and stamp it for a Q4_K-quantized output:
/// sets the quantization descriptor (4-bit, 256-element blocks, asymmetric),
/// records `param_count`, and zeroes `total_size` (presumably recomputed by
/// the streaming writer at finalize — confirm in `AprV2StreamingWriter`).
fn build_streaming_q4k_metadata(
    source: &crate::format::v2::AprV2Metadata,
    param_count: u64,
) -> crate::format::v2::AprV2Metadata {
    let mut out = source.clone();
    out.quantization = Some(QuantizationMetadata {
        quant_type: "q4_k".to_string(),
        bits: 4,
        block_size: Some(256),
        symmetric: false,
    });
    out.param_count = param_count;
    out.total_size = 0;
    // Preserve an existing original_format; default to "apr" when absent.
    out.original_format.get_or_insert_with(|| "apr".to_string());
    out
}
/// True when `path` is large enough for the streaming path AND carries an
/// APR magic number. Unreadable file metadata is treated as size 0, so a
/// stat failure simply disqualifies the file rather than erroring.
pub(crate) fn qualifies_for_streaming_q4k(path: &Path) -> bool {
    use crate::format::rosetta::FormatType;
    let file_len = fs::metadata(path).map_or(0, |m| m.len());
    if file_len < effective_streaming_threshold() {
        return false;
    }
    matches!(FormatType::from_magic(path), Ok(FormatType::Apr))
}
fn effective_streaming_threshold() -> u64 {
#[cfg(test)]
{
let t = STREAMING_THRESHOLD_TEST_OVERRIDE.load(std::sync::atomic::Ordering::Relaxed);
if t != u64::MAX {
return t;
}
}
std::env::var("APR_STREAMING_THRESHOLD")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(STREAMING_THRESHOLD_BYTES)
}
// Test-only knob for `effective_streaming_threshold`: `u64::MAX` means
// "no override"; any other value is returned as the threshold directly.
#[cfg(test)]
pub(crate) static STREAMING_THRESHOLD_TEST_OVERRIDE: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(u64::MAX);
// NOTE(review): presumably held by tests while they set/restore the override
// so parallel tests don't race on it — confirm at the call sites.
#[cfg(test)]
pub(crate) static STREAMING_THRESHOLD_TEST_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Estimate the peak per-tensor buffer footprint (in bytes) of streaming
/// quantization for the APR file at `path`: the largest tensor's f32 buffer
/// plus its Q4_K output buffer, since the two coexist briefly.
///
/// Returns `None` when the file cannot be mapped or parsed, or when it
/// contains no resolvable tensors.
pub fn streaming_quantize_peak_estimate(path: &Path) -> Option<u64> {
    use crate::bundle::MappedFile;
    use crate::format::v2::AprV2ReaderRef;

    let mapped = MappedFile::open(path).ok()?;
    let reader = AprV2ReaderRef::from_bytes(mapped.as_slice()).ok()?;

    let mut peak: Option<u64> = None;
    for name in reader.tensor_names().iter() {
        let Some(entry) = reader.get_tensor(name) else {
            continue;
        };
        let elems = entry.element_count() as u64;
        // f32 source buffer: 4 bytes per element.
        let f32_bytes = elems.saturating_mul(4);
        // Q4_K output: 9/16 byte per element, rounded up.
        let q4k_bytes = elems.saturating_mul(9).div_ceil(16);
        let total = f32_bytes.saturating_add(q4k_bytes);
        peak = Some(peak.map_or(total, |p| p.max(total)));
    }
    peak
}