use super::reader::{GgufReader, GgufTensorMeta};
use super::types::{
padding_for_alignment, write_metadata_kv, GgufHeader, GgufValue, GGUF_DEFAULT_ALIGNMENT,
GGUF_VERSION,
};
use crate::error::{AprenderError, Result};
use std::collections::HashSet;
use std::fs::File;
use std::io::{self, BufWriter, Write};
use std::path::{Path, PathBuf};
/// Builds an [`AprenderError::Io`] that carries `msg` as an
/// [`io::ErrorKind::InvalidData`] error.
///
/// Used for every "the shard set is malformed" condition in this module.
fn invalid(msg: String) -> AprenderError {
    let inner = io::Error::new(io::ErrorKind::InvalidData, msg);
    AprenderError::Io(inner)
}
/// Wraps an I/O failure in [`AprenderError::Io`].
///
/// The original error value is stored as-is rather than being rebuilt from
/// its kind and display string, so the OS error code (`raw_os_error`) and any
/// inner error payload remain available to callers. The kind and the
/// displayed message are the same as before.
fn io_err(e: io::Error) -> AprenderError {
    AprenderError::Io(e)
}
/// Appends `s` to `buf` as a length-prefixed string: the byte length as a
/// little-endian `u64`, followed by the raw UTF-8 bytes (no terminator).
fn write_string(buf: &mut Vec<u8>, s: &str) {
    let payload = s.as_bytes();
    let len_prefix = (payload.len() as u64).to_le_bytes();
    buf.extend(len_prefix.iter().chain(payload.iter()));
}
/// One tensor scheduled for the merged file: its descriptor fields plus the
/// location of its bytes inside the shard it came from.
struct TensorPlan {
    /// Tensor name, copied from the shard's tensor table.
    name: String,
    /// Dimensions, copied verbatim from the shard's tensor table.
    dims: Vec<u64>,
    /// Raw dtype code, copied verbatim from the shard's tensor table.
    dtype: u32,
    /// Index into the `parts` slice of the shard that holds the bytes.
    part: usize,
    /// Start (inclusive) of the tensor's bytes within that shard's `data`.
    abs_start: usize,
    /// End (exclusive) of the tensor's bytes within that shard's `data`.
    abs_end: usize,
}
/// Everything gathered from the shards before any output byte is written.
struct ShardMergePlan {
    /// Key/value pairs carried over from the first shard.
    metadata: Vec<(String, GgufValue)>,
    /// Tensors in output order (shard order, then ascending source offset).
    tensors: Vec<TensorPlan>,
}

/// Merges split GGUF shards into a single GGUF file at `output`.
///
/// The work is done in two passes so that at most one shard is held in memory
/// at a time:
///
/// 1. every shard is opened once to collect metadata (first shard only, with
///    `split.*` and `general.alignment` keys dropped) and the byte range of
///    each tensor;
/// 2. the merged header is written, then each shard is re-opened and its
///    tensor bytes are copied, each block padded to `GGUF_DEFAULT_ALIGNMENT`.
///
/// # Errors
///
/// Returns an error when:
/// - fewer than two parts are given;
/// - `output` refers to one of the input shards (creating it would truncate
///   that shard before it is copied);
/// - a shard cannot be read, or its tensor offsets are inconsistent;
/// - the same tensor name appears more than once across the shards;
/// - a shard is shorter on the second read than on the first;
/// - writing the output fails.
pub fn merge_gguf_shards(parts: &[PathBuf], output: &Path) -> Result<()> {
    if parts.len() < 2 {
        return Err(invalid(format!(
            "merge_gguf_shards needs >= 2 parts, got {}",
            parts.len()
        )));
    }
    ensure_output_is_not_a_shard(parts, output)?;

    let plan = plan_shard_merge(parts)?;
    let head = encode_merged_header(&plan)?;

    let file = File::create(output).map_err(io_err)?;
    let mut w = BufWriter::new(file);
    w.write_all(&head).map_err(io_err)?;
    write_zero_padding(
        &mut w,
        padding_for_alignment(head.len(), GGUF_DEFAULT_ALIGNMENT),
    )?;
    copy_tensor_blocks(parts, &plan.tensors, &mut w)?;
    // Flush explicitly: dropping a `BufWriter` swallows flush errors.
    w.flush().map_err(io_err)?;
    Ok(())
}

/// Rejects an `output` path that names one of the input shards.
///
/// Paths are compared literally and, when `output` already exists, by their
/// canonical form. A shard that cannot be canonicalized is skipped here; the
/// subsequent read reports the real problem.
fn ensure_output_is_not_a_shard(parts: &[PathBuf], output: &Path) -> Result<()> {
    let out_real = std::fs::canonicalize(output).ok();
    for part in parts {
        let aliases_output = part.as_path() == output
            || (out_real.is_some() && std::fs::canonicalize(part).ok() == out_real);
        if aliases_output {
            return Err(invalid(format!(
                "output path {} is also an input shard; refusing to overwrite it",
                output.display()
            )));
        }
    }
    Ok(())
}

/// Pass 1: reads each shard and records metadata plus per-tensor byte ranges.
fn plan_shard_merge(parts: &[PathBuf]) -> Result<ShardMergePlan> {
    let mut tensors: Vec<TensorPlan> = Vec::new();
    let mut metadata: Vec<(String, GgufValue)> = Vec::new();
    let mut seen: HashSet<String> = HashSet::new();

    for (pi, path) in parts.iter().enumerate() {
        let reader = if pi == 0 {
            GgufReader::from_file_full(path)?
        } else {
            GgufReader::from_file(path)?
        };

        // Only the first shard contributes metadata. Split bookkeeping is
        // meaningless after merging, and the alignment key is dropped because
        // the output is always laid out with GGUF_DEFAULT_ALIGNMENT.
        if pi == 0 {
            for (k, v) in &reader.metadata {
                if !k.starts_with("split.") && k != "general.alignment" {
                    metadata.push((k.clone(), v.clone()));
                }
            }
        }

        // A tensor's bytes run from its own offset up to the next tensor's
        // offset (or the end of the data section for the last one), so the
        // table has to be walked in offset order.
        let mut metas: Vec<&GgufTensorMeta> = reader.tensors.iter().collect();
        metas.sort_by_key(|t| t.offset);
        let section_len = reader.data.len().saturating_sub(reader.data_offset);

        for (j, m) in metas.iter().enumerate() {
            let start = m.offset as usize;
            let end = metas
                .get(j + 1)
                .map_or(section_len, |next| next.offset as usize);
            let abs_start = reader.data_offset.saturating_add(start);
            let abs_end = reader.data_offset.saturating_add(end);
            if end < start || abs_end > reader.data.len() {
                return Err(invalid(format!(
                    "corrupt tensor offsets in shard {}",
                    path.display()
                )));
            }
            if !seen.insert(m.name.clone()) {
                return Err(invalid(format!(
                    "duplicate tensor '{}' across shards (corrupt or non-disjoint split)",
                    m.name
                )));
            }
            tensors.push(TensorPlan {
                name: m.name.clone(),
                dims: m.dims.clone(),
                dtype: m.dtype,
                part: pi,
                abs_start,
                abs_end,
            });
        }
    }

    Ok(ShardMergePlan { metadata, tensors })
}

/// Serializes the merged header: fixed header, metadata pairs, tensor table.
///
/// Tensor offsets are relative to the start of the data section and advance
/// by exactly what [`copy_tensor_blocks`] emits per tensor (payload plus
/// padding of the payload length), so the table and the data cannot drift
/// apart.
fn encode_merged_header(plan: &ShardMergePlan) -> Result<Vec<u8>> {
    let mut head: Vec<u8> = Vec::new();
    GgufHeader {
        version: GGUF_VERSION,
        tensor_count: plan.tensors.len() as u64,
        metadata_kv_count: plan.metadata.len() as u64,
    }
    .write_to(&mut head)?;
    for (k, v) in &plan.metadata {
        write_metadata_kv(&mut head, k, v)?;
    }

    let mut running: u64 = 0;
    for t in &plan.tensors {
        write_string(&mut head, &t.name);
        head.extend_from_slice(&(t.dims.len() as u32).to_le_bytes());
        for d in &t.dims {
            head.extend_from_slice(&d.to_le_bytes());
        }
        head.extend_from_slice(&t.dtype.to_le_bytes());
        head.extend_from_slice(&running.to_le_bytes());

        // `running` is always a multiple of the alignment, so padding the
        // block length is equivalent to padding the running total, without
        // narrowing the u64 total to usize.
        let len = t.abs_end - t.abs_start;
        let pad = padding_for_alignment(len, GGUF_DEFAULT_ALIGNMENT);
        running = running.saturating_add(len as u64).saturating_add(pad as u64);
    }
    Ok(head)
}

/// Pass 2: re-opens each shard in turn and copies the planned tensor bytes,
/// zero-padding every block to `GGUF_DEFAULT_ALIGNMENT`.
fn copy_tensor_blocks<W: Write>(
    parts: &[PathBuf],
    tensors: &[TensorPlan],
    w: &mut W,
) -> Result<()> {
    let mut loaded: Option<(usize, GgufReader)> = None;
    for t in tensors {
        let already_loaded = matches!(&loaded, Some((pi, _)) if *pi == t.part);
        if !already_loaded {
            // Release the previous shard before reading the next one so two
            // whole shards are never resident at the same time.
            drop(loaded.take());
            loaded = Some((t.part, GgufReader::from_file(&parts[t.part])?));
        }
        let reader = &loaded.as_ref().expect("part just loaded").1;

        // The file may have changed between the two passes.
        if t.abs_end > reader.data.len() {
            return Err(invalid(format!(
                "shard {} shorter than expected on re-read",
                parts[t.part].display()
            )));
        }
        let block = &reader.data[t.abs_start..t.abs_end];
        w.write_all(block).map_err(io_err)?;
        write_zero_padding(
            w,
            padding_for_alignment(block.len(), GGUF_DEFAULT_ALIGNMENT),
        )?;
    }
    Ok(())
}

/// Writes `pad` zero bytes to `w`; does nothing when `pad` is zero.
fn write_zero_padding<W: Write>(w: &mut W, pad: usize) -> Result<()> {
    if pad > 0 {
        w.write_all(&vec![0u8; pad]).map_err(io_err)?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::gguf::types::{export_tensors_to_gguf, GgmlType, GgufTensor};

    /// Encodes `tensors` and `meta` as a GGUF image and stores it at `path`.
    fn write_part(path: &Path, tensors: &[GgufTensor], meta: &[(String, GgufValue)]) {
        let mut encoded = Vec::new();
        export_tensors_to_gguf(&mut encoded, tensors, meta).expect("export part");
        std::fs::write(path, &encoded).expect("write part");
    }

    /// Creates a scratch directory under the system temp dir, named after
    /// `tag` and the current process id.
    fn tmpdir(tag: &str) -> PathBuf {
        let leaf = format!("apr-merge-{}-{}", tag, std::process::id());
        let dir = std::env::temp_dir().join(leaf);
        std::fs::create_dir_all(&dir).expect("mkdir");
        dir
    }

    /// Paths of the two input shards and of the merged output inside `dir`.
    fn shard_layout(dir: &Path) -> (PathBuf, PathBuf, PathBuf) {
        (
            dir.join("model-00001-of-00002.gguf"),
            dir.join("model-00002-of-00002.gguf"),
            dir.join("model.gguf"),
        )
    }

    /// A 4-element F32 tensor whose 16 payload bytes are all `fill`.
    fn f32_tensor(name: &str, fill: u8) -> GgufTensor {
        GgufTensor {
            name: name.into(),
            shape: vec![4],
            dtype: GgmlType::F32,
            data: vec![fill; 16],
        }
    }

    /// The `general.architecture` metadata pair for `name`.
    fn arch(name: &str) -> (String, GgufValue) {
        (
            "general.architecture".into(),
            GgufValue::String(name.into()),
        )
    }

    /// A `Uint16` metadata pair, as used by the `split.*` keys.
    fn u16_kv(key: &str, value: u16) -> (String, GgufValue) {
        (key.into(), GgufValue::Uint16(value))
    }

    /// A `Uint32` metadata pair.
    fn u32_kv(key: &str, value: u32) -> (String, GgufValue) {
        (key.into(), GgufValue::Uint32(value))
    }

    /// Two shards with one tensor each merge into a file that holds both
    /// tensors byte-for-byte, keeps `general.*` and drops `split.*`.
    #[test]
    fn merge_two_part_roundtrip() {
        let dir = tmpdir("rt");
        let (p0, p1, merged) = shard_layout(&dir);

        let a_data = vec![1u8; 16];
        let b_data = vec![2u8; 36];
        let first = GgufTensor {
            name: "blk.0.weight".into(),
            shape: vec![4],
            dtype: GgmlType::F32,
            data: a_data.clone(),
        };
        let second = GgufTensor {
            name: "blk.1.weight".into(),
            shape: vec![64],
            dtype: GgmlType::Q4_0,
            data: b_data.clone(),
        };
        write_part(
            &p0,
            &[first],
            &[
                arch("llama"),
                u16_kv("split.no", 0),
                u16_kv("split.count", 2),
            ],
        );
        write_part(
            &p1,
            &[second],
            &[u16_kv("split.no", 1), u16_kv("split.count", 2)],
        );

        merge_gguf_shards(&[p0, p1], &merged).expect("merge");
        let r = GgufReader::from_file_full(&merged).expect("re-read merged");

        let names: Vec<&str> = r.tensors.iter().map(|t| t.name.as_str()).collect();
        let both_present = ["blk.0.weight", "blk.1.weight"]
            .iter()
            .all(|n| names.contains(n));
        assert!(
            both_present,
            "merged file must contain tensors from BOTH parts, got {names:?}"
        );

        let split_left = r.metadata.keys().any(|k| k.starts_with("split."));
        assert!(!split_left, "split.* metadata must be stripped");
        assert!(
            r.metadata.contains_key("general.architecture"),
            "general.* metadata must be preserved"
        );

        let expected = [("blk.0.weight", &a_data), ("blk.1.weight", &b_data)];
        for (name, want) in expected {
            let m = match r.tensors.iter().find(|t| t.name == name) {
                Some(found) => found,
                None => panic!("tensor {name} missing"),
            };
            let start = r.data_offset + m.offset as usize;
            let got = &r.data[start..start + want.len()];
            assert_eq!(
                got,
                want.as_slice(),
                "tensor {name} bytes must survive merge"
            );
        }

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Architecture-specific keys from the first shard survive the merge;
    /// only `split.*` keys are removed.
    #[test]
    fn merge_preserves_nonwhitelisted_arch_metadata() {
        let dir = tmpdir("gemma");
        let (p0, p1, merged) = shard_layout(&dir);

        write_part(
            &p0,
            &[f32_tensor("blk.0.weight", 7)],
            &[
                arch("gemma"),
                u32_kv("gemma.embedding_length", 2048),
                u32_kv("gemma.block_count", 18),
                u32_kv("gemma.attention.head_count", 8),
                u16_kv("split.no", 0),
                u16_kv("split.count", 2),
            ],
        );
        write_part(
            &p1,
            &[f32_tensor("blk.1.weight", 9)],
            &[u16_kv("split.no", 1)],
        );

        merge_gguf_shards(&[p0, p1], &merged).expect("merge");
        let r = GgufReader::from_file_full(&merged).expect("re-read merged");

        let required = [
            "gemma.embedding_length",
            "gemma.block_count",
            "gemma.attention.head_count",
        ];
        for key in required {
            assert!(
                r.metadata.contains_key(key),
                "merged gemma model must retain {key}; got {:?}",
                r.metadata.keys().collect::<Vec<_>>()
            );
        }
        assert!(!r.metadata.keys().any(|k| k.starts_with("split.")));

        std::fs::remove_dir_all(&dir).ok();
    }

    /// The same tensor name in two shards is an error, not a silent overwrite.
    #[test]
    fn merge_rejects_duplicate_tensor_names() {
        let dir = tmpdir("dup");
        let (p0, p1, merged) = shard_layout(&dir);

        write_part(&p0, &[f32_tensor("blk.0.weight", 1)], &[arch("llama")]);
        write_part(&p1, &[f32_tensor("blk.0.weight", 2)], &[]);

        let res = merge_gguf_shards(&[p0, p1], &merged);
        assert!(
            res.is_err(),
            "duplicate tensor name across shards must be rejected"
        );

        std::fs::remove_dir_all(&dir).ok();
    }
}