use super::{ExportConfig, ExportFormat, ExportPrecision, ModelExporter};
use crate::traits::Model;
use anyhow::{anyhow, Result};
use byteorder::{LittleEndian, WriteBytesExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
const GGML_MAGIC: u32 = 0x67676d6c; const GGML_VERSION: u32 = 1;
#[derive(Debug, Clone, Copy)]
pub enum GGMLType {
F32 = 0,
F16 = 1,
Q4_0 = 2,
Q4_1 = 3,
Q5_0 = 6,
Q5_1 = 7,
Q8_0 = 8,
Q8_1 = 9,
Q2K = 10,
Q3K = 11,
Q4K = 12,
Q5K = 13,
Q6K = 14,
Q8K = 15,
}
impl GGMLType {
fn from_precision(precision: ExportPrecision) -> Self {
match precision {
ExportPrecision::FP32 => GGMLType::F32,
ExportPrecision::FP16 => GGMLType::F16,
ExportPrecision::INT8 => GGMLType::Q8_0,
ExportPrecision::INT4 => GGMLType::Q4_0,
}
}
#[allow(dead_code)]
fn element_size(&self) -> usize {
match self {
GGMLType::F32 => 4,
GGMLType::F16 => 2,
GGMLType::Q4_0 => 2, GGMLType::Q4_1 => 2,
GGMLType::Q5_0 => 3,
GGMLType::Q5_1 => 3,
GGMLType::Q8_0 => 1,
GGMLType::Q8_1 => 1,
GGMLType::Q2K => 1,
GGMLType::Q3K => 1,
GGMLType::Q4K => 1,
GGMLType::Q5K => 1,
GGMLType::Q6K => 1,
GGMLType::Q8K => 1,
}
}
}
#[derive(Debug)]
pub struct GGMLTensor {
pub name: String,
pub tensor_type: GGMLType,
pub dimensions: Vec<usize>,
pub data: Vec<u8>,
}
#[derive(Debug)]
pub struct GGMLModel {
pub magic: u32,
pub version: u32,
pub vocab_size: usize,
pub context_length: usize,
pub embedding_length: usize,
pub head_count: usize,
pub head_count_kv: usize,
pub layer_count: usize,
pub rope_dimension_count: usize,
pub file_type: u32,
pub tensors: Vec<GGMLTensor>,
pub vocab: HashMap<String, u32>,
}
#[derive(Clone)]
pub struct GGMLExporter {
quantization_enabled: bool,
}
impl Default for GGMLExporter {
fn default() -> Self {
Self::new()
}
}
impl GGMLExporter {
pub fn new() -> Self {
Self {
quantization_enabled: false,
}
}
pub fn with_quantization(mut self, enabled: bool) -> Self {
self.quantization_enabled = enabled;
self
}
fn create_ggml_model<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<GGMLModel> {
let vocab_size = 50257; let context_length = config.sequence_length.unwrap_or(2048);
let embedding_length = 768; let head_count = 12; let head_count_kv = head_count; let layer_count = 12; let rope_dimension_count = embedding_length / head_count;
let file_type = match config.precision {
ExportPrecision::FP32 => 0,
ExportPrecision::FP16 => 1,
ExportPrecision::INT8 => 8,
ExportPrecision::INT4 => 2,
};
let mut tensors = Vec::new();
self.convert_model_weights(model, &mut tensors, config)?;
let mut vocab = HashMap::new();
for i in 0..vocab_size {
vocab.insert(format!("token_{}", i), i as u32);
}
Ok(GGMLModel {
magic: GGML_MAGIC,
version: GGML_VERSION,
vocab_size,
context_length,
embedding_length,
head_count,
head_count_kv,
layer_count,
rope_dimension_count,
file_type,
tensors,
vocab,
})
}
fn convert_model_weights<M: Model>(
&self,
_model: &M,
tensors: &mut Vec<GGMLTensor>,
config: &ExportConfig,
) -> Result<()> {
let tensor_type = GGMLType::from_precision(config.precision);
self.add_tensor(
tensors,
"token_embd.weight",
tensor_type,
vec![50257, 768], &self.generate_dummy_weights(50257 * 768, tensor_type)?,
);
for layer_idx in 0..12 {
let layer_prefix = format!("blk.{}", layer_idx);
self.add_tensor(
tensors,
&format!("{}.attn_q.weight", layer_prefix),
tensor_type,
vec![768, 768],
&self.generate_dummy_weights(768 * 768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.attn_k.weight", layer_prefix),
tensor_type,
vec![768, 768],
&self.generate_dummy_weights(768 * 768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.attn_v.weight", layer_prefix),
tensor_type,
vec![768, 768],
&self.generate_dummy_weights(768 * 768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.attn_output.weight", layer_prefix),
tensor_type,
vec![768, 768],
&self.generate_dummy_weights(768 * 768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.ffn_up.weight", layer_prefix),
tensor_type,
vec![768, 3072],
&self.generate_dummy_weights(768 * 3072, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.ffn_down.weight", layer_prefix),
tensor_type,
vec![3072, 768],
&self.generate_dummy_weights(3072 * 768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.attn_norm.weight", layer_prefix),
tensor_type,
vec![768],
&self.generate_dummy_weights(768, tensor_type)?,
);
self.add_tensor(
tensors,
&format!("{}.ffn_norm.weight", layer_prefix),
tensor_type,
vec![768],
&self.generate_dummy_weights(768, tensor_type)?,
);
}
self.add_tensor(
tensors,
"norm.weight",
tensor_type,
vec![768],
&self.generate_dummy_weights(768, tensor_type)?,
);
self.add_tensor(
tensors,
"output.weight",
tensor_type,
vec![768, 50257],
&self.generate_dummy_weights(768 * 50257, tensor_type)?,
);
Ok(())
}
fn add_tensor(
&self,
tensors: &mut Vec<GGMLTensor>,
name: &str,
tensor_type: GGMLType,
dimensions: Vec<usize>,
data: &[u8],
) {
tensors.push(GGMLTensor {
name: name.to_string(),
tensor_type,
dimensions,
data: data.to_vec(),
});
}
fn generate_dummy_weights(&self, size: usize, tensor_type: GGMLType) -> Result<Vec<u8>> {
let mut data = Vec::new();
use scirs2_core::random::*;
let mut rng = thread_rng();
match tensor_type {
GGMLType::F32 => {
for _ in 0..size {
let val = if rng.random::<f32>() < 0.5 {
let std_dev = (2.0 / (size as f32).sqrt()).sqrt();
rng.random_range(-3.0 * std_dev..3.0 * std_dev)
} else {
let std_dev = (2.0 / size as f32).sqrt();
rng.random_range(-3.0 * std_dev..3.0 * std_dev)
};
data.extend_from_slice(&val.to_le_bytes());
}
},
GGMLType::F16 => {
for _ in 0..size {
let std_dev = (2.0 / (size as f32).sqrt()).sqrt();
let val = rng.random_range(-2.0 * std_dev..2.0 * std_dev);
let f16_val = half::f16::from_f32(val.clamp(-65504.0, 65504.0)); data.extend_from_slice(&f16_val.to_le_bytes());
}
},
GGMLType::Q8_0 => {
let block_size = 32;
let num_blocks = size.div_ceil(block_size);
for _ in 0..num_blocks {
let scale = rng.random_range(0.001..0.1f32);
data.extend_from_slice(&scale.to_le_bytes());
for _ in 0..block_size {
let normalized_val = rng.random_range(-1.0..1.0f32);
let quantized =
(normalized_val / scale * 127.0).round().clamp(-128.0, 127.0) as i8;
data.push(quantized as u8);
}
}
},
GGMLType::Q4_0 => {
for i in (0..size).step_by(2) {
let val1 = (i as f32 * 0.001).sin();
let val2 = ((i + 1) as f32 * 0.001).sin();
let q1 = (val1 * 7.0).round().clamp(-8.0, 7.0) as i8;
let q2 = (val2 * 7.0).round().clamp(-8.0, 7.0) as i8;
let packed = ((q1 & 0xF) | ((q2 & 0xF) << 4)) as u8;
data.push(packed);
}
},
_ => {
return Err(anyhow!("Unsupported tensor type: {:?}", tensor_type));
},
}
Ok(data)
}
fn serialize_ggml_model(&self, model: &GGMLModel, output_path: &str) -> Result<()> {
let file = File::create(format!("{}.ggml", output_path))?;
let mut writer = BufWriter::new(file);
writer.write_u32::<LittleEndian>(model.magic)?;
writer.write_u32::<LittleEndian>(model.version)?;
writer.write_u32::<LittleEndian>(model.vocab_size as u32)?;
writer.write_u32::<LittleEndian>(model.context_length as u32)?;
writer.write_u32::<LittleEndian>(model.embedding_length as u32)?;
writer.write_u32::<LittleEndian>(model.head_count as u32)?;
writer.write_u32::<LittleEndian>(model.head_count_kv as u32)?;
writer.write_u32::<LittleEndian>(model.layer_count as u32)?;
writer.write_u32::<LittleEndian>(model.rope_dimension_count as u32)?;
writer.write_u32::<LittleEndian>(model.file_type)?;
writer.write_u32::<LittleEndian>(model.vocab.len() as u32)?;
for (token, id) in &model.vocab {
writer.write_u32::<LittleEndian>(*id)?;
writer.write_u32::<LittleEndian>(token.len() as u32)?;
writer.write_all(token.as_bytes())?;
}
writer.write_u32::<LittleEndian>(model.tensors.len() as u32)?;
for tensor in &model.tensors {
writer.write_u32::<LittleEndian>(tensor.name.len() as u32)?;
writer.write_all(tensor.name.as_bytes())?;
writer.write_u32::<LittleEndian>(tensor.dimensions.len() as u32)?;
for &dim in &tensor.dimensions {
writer.write_u32::<LittleEndian>(dim as u32)?;
}
writer.write_u32::<LittleEndian>(tensor.tensor_type as u32)?;
writer.write_u32::<LittleEndian>(tensor.data.len() as u32)?;
writer.write_all(&tensor.data)?;
}
writer.flush()?;
Ok(())
}
}
impl ModelExporter for GGMLExporter {
fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
if config.format != ExportFormat::GGML {
return Err(anyhow!("GGMLExporter only supports GGML format"));
}
let ggml_model = self.create_ggml_model(model, config)?;
self.serialize_ggml_model(&ggml_model, &config.output_path)?;
println!("Model exported to {}.ggml", config.output_path);
Ok(())
}
fn supported_formats(&self) -> Vec<ExportFormat> {
vec![ExportFormat::GGML]
}
fn validate_model<M: Model>(&self, _model: &M, format: ExportFormat) -> Result<()> {
if format != ExportFormat::GGML {
return Err(anyhow!("GGMLExporter only supports GGML format"));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ggml_exporter_creation() {
let exporter = GGMLExporter::new();
assert!(!exporter.quantization_enabled);
let exporter_quant = exporter.with_quantization(true);
assert!(exporter_quant.quantization_enabled);
}
#[test]
fn test_ggml_type_conversion() {
assert_eq!(GGMLType::from_precision(ExportPrecision::FP32) as u32, 0);
assert_eq!(GGMLType::from_precision(ExportPrecision::FP16) as u32, 1);
assert_eq!(GGMLType::from_precision(ExportPrecision::INT8) as u32, 8);
assert_eq!(GGMLType::from_precision(ExportPrecision::INT4) as u32, 2);
}
#[test]
fn test_ggml_type_element_size() {
assert_eq!(GGMLType::F32.element_size(), 4);
assert_eq!(GGMLType::F16.element_size(), 2);
assert_eq!(GGMLType::Q8_0.element_size(), 1);
assert_eq!(GGMLType::Q4_0.element_size(), 2);
}
#[test]
fn test_supported_formats() {
let exporter = GGMLExporter::new();
let formats = exporter.supported_formats();
assert_eq!(formats.len(), 1);
assert_eq!(formats[0], ExportFormat::GGML);
}
#[test]
fn test_ggml_constants() {
assert_eq!(GGML_MAGIC, 0x67676d6c);
assert_eq!(GGML_VERSION, 1);
}
#[test]
fn test_dummy_weight_generation() {
let exporter = GGMLExporter::new();
let f32_weights = exporter
.generate_dummy_weights(10, GGMLType::F32)
.expect("operation failed in test");
assert_eq!(f32_weights.len(), 10 * 4);
let f16_weights = exporter
.generate_dummy_weights(10, GGMLType::F16)
.expect("operation failed in test");
assert_eq!(f16_weights.len(), 10 * 2);
let q8_weights = exporter
.generate_dummy_weights(10, GGMLType::Q8_0)
.expect("operation failed in test");
assert_eq!(q8_weights.len(), 36);
let q4_weights = exporter
.generate_dummy_weights(10, GGMLType::Q4_0)
.expect("operation failed in test");
assert_eq!(q4_weights.len(), 5); }
}