/// Insert `value` under `key` as a JSON number, doing nothing when `value`
/// is `None`. Used to flatten optional model-config fields into the
/// `custom` metadata map.
fn insert_usize_meta(
    custom: &mut std::collections::HashMap<String, serde_json::Value>,
    key: &str,
    value: Option<usize>,
) {
    // Guard clause: absent values leave the map untouched.
    let Some(v) = value else { return };
    let number = serde_json::Number::from(v);
    custom.insert(key.to_owned(), serde_json::Value::Number(number));
}
/// Copy every present field of `cfg` into the `custom` metadata map under
/// `model.*` keys. Absent (`None`) fields are skipped entirely.
fn insert_model_config_metadata(
    cfg: &GgufModelConfig,
    custom: &mut std::collections::HashMap<String, serde_json::Value>,
) {
    // Insert an f64 as a JSON number; `from_f64` returns None for
    // NaN/infinity, in which case `fallback` is stored instead.
    fn insert_f64_meta(
        custom: &mut std::collections::HashMap<String, serde_json::Value>,
        key: &str,
        value: f64,
        fallback: u64,
    ) {
        let num = serde_json::Number::from_f64(value)
            .unwrap_or_else(|| serde_json::Number::from(fallback));
        custom.insert(key.to_string(), serde_json::Value::Number(num));
    }

    if let Some(arch) = &cfg.architecture {
        custom.insert(
            "model.architecture".to_string(),
            serde_json::Value::String(arch.clone()),
        );
    }

    // Optional integer-valued fields, all handled uniformly.
    insert_usize_meta(custom, "model.hidden_size", cfg.hidden_size);
    insert_usize_meta(custom, "model.num_layers", cfg.num_layers);
    insert_usize_meta(custom, "model.num_heads", cfg.num_heads);
    insert_usize_meta(custom, "model.num_kv_heads", cfg.num_kv_heads);
    insert_usize_meta(custom, "model.vocab_size", cfg.vocab_size);
    insert_usize_meta(custom, "model.intermediate_size", cfg.intermediate_size);
    insert_usize_meta(
        custom,
        "model.max_position_embeddings",
        cfg.max_position_embeddings,
    );

    // Float-valued fields: 10000 is the conventional RoPE theta default;
    // rms_norm_eps falls back to 0 when unrepresentable.
    if let Some(rope_theta) = cfg.rope_theta {
        insert_f64_meta(custom, "model.rope_theta", f64::from(rope_theta), 10000);
    }
    if let Some(rms_eps) = cfg.rms_norm_eps {
        insert_f64_meta(custom, "model.rms_norm_eps", f64::from(rms_eps), 0);
    }

    insert_usize_meta(custom, "model.rope_type", cfg.rope_type.map(|v| v as usize));
}
/// Map a GGUF on-disk dtype code to the APR [`TensorDType`].
///
/// Only dtypes that APR can represent exactly (F32, F16, Q4_K, Q6_K) are
/// accepted. Known-but-unsupported quantizations produce an actionable
/// error suggesting a `--quantize` target; everything else is rejected to
/// uphold the LAYOUT-002 "no raw byte passthrough" mandate.
///
/// # Errors
/// Returns `AprenderError::FormatError` for any dtype other than
/// 0 (F32), 1 (F16), 12 (Q4_K), 14 (Q6_K).
fn map_gguf_dtype(dtype: u32, tensor_name: &str) -> Result<TensorDType> {
    // Builds the "re-quantize to import" error for exactly-representable
    // alternatives. Replaces the previous nested match + unreachable!().
    let requant_err = |dtype_name: &str, suggestion: &str| AprenderError::FormatError {
        message: format!(
            "GGUF tensor '{tensor_name}' uses {dtype_name} quantization which APR cannot \
             represent exactly. Import requires exact format preservation. \
             Use `apr convert --quantize {suggestion}` to convert to a supported format."
        ),
    };
    match dtype {
        0 => Ok(TensorDType::F32),
        1 => Ok(TensorDType::F16),
        12 => Ok(TensorDType::Q4K),
        14 => Ok(TensorDType::Q6K),
        // Recognized quantizations APR cannot hold losslessly: point the
        // user at the closest supported target.
        2 => Err(requant_err("Q4_0", "q4_k")),
        3 => Err(requant_err("Q4_1", "q4_k")),
        6 => Err(requant_err("Q5_0", "q6_k")),
        8 => Err(requant_err("Q8_0", "q6_k")),
        13 => Err(requant_err("Q5_K", "q6_k")),
        7 | 9 => Err(AprenderError::FormatError {
            message: format!(
                "GGUF dtype {dtype} (Q5_1/Q8_1) for tensor '{tensor_name}' not yet supported. \
                 Cannot store raw bytes - would violate LAYOUT-002 mandate."
            ),
        }),
        _ => Err(AprenderError::FormatError {
            message: format!(
                "Unsupported GGUF dtype {dtype} for tensor '{tensor_name}'. \
                 Cannot store raw bytes - would violate LAYOUT-002 mandate."
            ),
        }),
    }
}
/// Serialize GGUF tensors and metadata into an APR v2 file at `output`,
/// preserving the tensors' raw bytes (dtype is validated, never converted).
///
/// Tied embeddings are resolved first; tokenizer and model-config metadata,
/// when present, are flattened into the `custom` map alongside a
/// `tensor_shapes` object recording every tensor's dimensions.
///
/// # Errors
/// Returns `AprenderError::FormatError` for unsupported tensor dtypes,
/// serialization failures, or any I/O failure creating/writing the file.
#[provable_contracts_macros::requires(!tensors.is_empty())]
#[provable_contracts_macros::requires(tokenizer.is_some())]
#[provable_contracts_macros::requires(
    model_config.map_or(false, |c| c.hidden_size.is_some() && c.num_layers.is_some() && c.num_heads.is_some() && c.vocab_size.is_some())
)]
pub(crate) fn write_apr_file_raw(
    tensors: &BTreeMap<String, GgufRawTensor>,
    output: &Path,
    _options: &ImportOptions,
    tokenizer: Option<&GgufTokenizer>,
    model_config: Option<&GgufModelConfig>,
) -> Result<()> {
    // May deduplicate/alias tied input/output embeddings; shadows `tensors`.
    let (tensors, has_tied_embeddings) = resolve_tied_embeddings(tensors);
    // Total parameter count = sum over tensors of the product of their dims.
    let param_count: u64 = tensors
        .values()
        .map(|t| t.shape.iter().product::<usize>() as u64)
        .sum();
    // Record each tensor's shape as a JSON array so readers can recover
    // layout without parsing the tensor section.
    let tensor_shapes: serde_json::Map<String, serde_json::Value> = tensors
        .iter()
        .map(|(name, tensor)| {
            let shape_array: Vec<serde_json::Value> = tensor
                .shape
                .iter()
                .map(|&dim| serde_json::Value::Number(serde_json::Number::from(dim as u64)))
                .collect();
            (name.clone(), serde_json::Value::Array(shape_array))
        })
        .collect();
    let mut custom = std::collections::HashMap::new();
    custom.insert(
        "tensor_shapes".to_string(),
        serde_json::Value::Object(tensor_shapes),
    );
    // Optional metadata sections are merged into the same `custom` map.
    if let Some(tok) = tokenizer {
        insert_tokenizer_metadata(tok, &mut custom);
    }
    if let Some(cfg) = model_config {
        insert_model_config_metadata(cfg, &mut custom);
    }
    if has_tied_embeddings {
        custom.insert("tied_embeddings".to_string(), serde_json::Value::Bool(true));
    }
    let metadata = AprV2Metadata {
        // Falls back to "qwen2" when no architecture is present — presumably
        // the dominant import target; NOTE(review): confirm this default is
        // still wanted for non-Qwen models.
        model_type: model_config
            .and_then(|c| c.architecture.clone())
            .unwrap_or_else(|| "qwen2".to_string()),
        name: model_config.and_then(|c| c.architecture.clone()),
        description: Some("GGUF Q4_K model imported with native quantization".to_string()),
        author: None,
        license: None,
        data_source: None,
        data_license: None,
        version: Some("1.0.0".to_string()),
        source: None,
        original_format: Some("gguf".to_string()),
        created_at: None,
        // total_size is 0 here — presumably filled in by AprV2Writer during
        // serialization; TODO confirm against the writer implementation.
        total_size: 0, param_count,
        quantization: None, sharding: None,
        chat_template: tokenizer.and_then(|t| t.chat_template.clone()),
        chat_format: None,
        special_tokens: None,
        // Architecture/config fields mirror GgufModelConfig one-to-one.
        architecture: model_config.and_then(|c| c.architecture.clone()),
        hidden_size: model_config.and_then(|c| c.hidden_size),
        num_layers: model_config.and_then(|c| c.num_layers),
        num_heads: model_config.and_then(|c| c.num_heads),
        num_kv_heads: model_config.and_then(|c| c.num_kv_heads),
        vocab_size: model_config.and_then(|c| c.vocab_size),
        intermediate_size: model_config.and_then(|c| c.intermediate_size),
        max_position_embeddings: model_config.and_then(|c| c.max_position_embeddings),
        rope_theta: model_config.and_then(|c| c.rope_theta),
        rope_type: model_config.and_then(|c| c.rope_type),
        rms_norm_eps: model_config.and_then(|c| c.rms_norm_eps),
        head_dim: model_config.and_then(|c| c.head_dim),
        num_experts: model_config.and_then(|c| c.num_experts),
        num_experts_per_tok: model_config.and_then(|c| c.num_experts_per_tok),
        moe_intermediate_size: model_config.and_then(|c| c.moe_intermediate_size),
        custom,
    };
    let mut writer = AprV2Writer::new(metadata);
    for (name, tensor) in tensors.iter() {
        // Reject any dtype APR cannot represent before writing anything.
        let apr_dtype = map_gguf_dtype(tensor.dtype, name)?;
        // NOTE(review): tensor.data.clone() copies the full tensor payload;
        // if add_tensor could take a borrow or Bytes, this copy is avoidable.
        writer.add_tensor(name, apr_dtype, tensor.shape.clone(), tensor.data.clone());
    }
    let bytes = writer.write().map_err(|e| AprenderError::FormatError {
        message: format!("Failed to serialize APR format: {e}"),
    })?;
    let mut file = fs::File::create(output).map_err(|e| AprenderError::FormatError {
        message: format!("Failed to create output file: {e}"),
    })?;
    file.write_all(&bytes)
        .map_err(|e| AprenderError::FormatError {
            message: format!("Failed to write APR file: {e}"),
        })?;
    Ok(())
}