#![allow(unused_variables)]
use crate::export::{ExportConfig, ExportFormat, ExportPrecision, ModelExporter};
use crate::traits::Model;
use anyhow::{anyhow, Result};
use serde_json::{json, Value as JsonValue};
use std::fs::{create_dir_all, File};
use std::io::Write;
use std::path::Path;
/// Exporter that serializes models to the Khronos NNEF package format:
/// a `<name>.nnef` directory containing `graph.nnef` (textual graph),
/// `graph.json` (metadata sidecar), and a `weights/` subdirectory.
// Added `Debug` so the exporter can appear in logs and error context;
// public types should be debug-printable.
#[derive(Clone, Debug)]
pub struct NNEFExporter {
    /// NNEF specification version written to the graph header (e.g. "1.0").
    version: String,
    /// NNEF extension names declared at the top of the graph file.
    extensions: Vec<String>,
}
impl NNEFExporter {
/// Creates an exporter targeting NNEF 1.0 with the default extension set.
pub fn new() -> Self {
    let version = String::from("1.0");
    let extensions = vec![String::from("KHR_enable_fragment_definitions")];
    Self { version, extensions }
}
pub fn with_config(version: String, extensions: Vec<String>) -> Self {
Self {
version,
extensions,
}
}
/// Assembles the on-disk NNEF package for `model` at `config.output_path`.
///
/// Writes `graph.nnef`, the `graph.json` metadata sidecar, and the weight
/// files, creating any missing parent directories along the way.
fn export_to_nnef<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
    let output_path = Path::new(&config.output_path);
    if let Some(parent) = output_path.parent() {
        create_dir_all(parent)?;
    }

    // Render the graph text before creating the package directory, so a
    // generation failure does not leave an empty package behind.
    let graph_text = self.build_nnef_graph(model, config)?;

    // NNEF models are packaged as a directory named `<output>.nnef`.
    let package_dir = output_path.with_extension("nnef");
    create_dir_all(&package_dir)?;

    File::create(package_dir.join("graph.nnef"))?.write_all(graph_text.as_bytes())?;

    let metadata = self.build_metadata(model, config)?;
    let metadata_json = serde_json::to_string_pretty(&metadata)?;
    File::create(package_dir.join("graph.json"))?.write_all(metadata_json.as_bytes())?;

    self.export_weights(model, &package_dir, config)?;
    println!("✅ NNEF export completed: {}", package_dir.display());
    Ok(())
}
/// Renders the textual `graph.nnef` source for the model.
///
/// Emits the version header, one `extension` declaration per configured
/// extension, the `graph network(...)` signature with input/output tensor
/// shapes derived from `config`, and the fixed transformer body.
fn build_nnef_graph<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<String> {
    let mut graph = String::new();
    graph.push_str(&format!("version {};\n", self.version));
    // Fix: emit each configured extension by name. The previous code looped
    // over `self.extensions` but pushed a hardcoded
    // "KHR_enable_fragment_definitions" line for every entry, so custom
    // extensions were silently dropped.
    for ext in &self.extensions {
        graph.push_str(&format!("extension {};\n", ext));
    }
    graph.push('\n');

    let input_shape = self.get_input_shape(config);
    let output_shape = self.get_output_shape(config);
    // Renders a shape as a comma-separated dimension list, e.g. "2, 128".
    let fmt_shape =
        |dims: &[i64]| dims.iter().map(|d| d.to_string()).collect::<Vec<_>>().join(", ");

    graph.push_str("graph network(\n");
    graph.push_str(&format!(
        " input: tensor<scalar=real, shape=[{}]>\n",
        fmt_shape(&input_shape)
    ));
    graph.push_str(") -> (\n");
    graph.push_str(&format!(
        " output: tensor<scalar=real, shape=[{}]>\n",
        fmt_shape(&output_shape)
    ));
    graph.push_str(")\n{\n");
    self.add_transformer_layers(&mut graph, config)?;
    graph.push_str("}\n");
    Ok(graph)
}
/// Appends the transformer operator body to `graph`.
///
/// NOTE(review): the emitted topology is fixed (embedding 512→768, 12
/// attention heads of 64 dims, FFN width 3072, layer norm + residual);
/// `config` is currently not consulted here.
fn add_transformer_layers(&self, graph: &mut String, config: &ExportConfig) -> Result<()> {
    // Operator lines are emitted verbatim, in declaration order.
    const BODY: [&str; 21] = [
        " # Input embedding\n",
        " embedded = linear(input, weight=variable<scalar=real, shape=[512, 768]>, bias=variable<scalar=real, shape=[768]>);\n",
        "\n # Multi-head attention\n",
        " query = linear(embedded, weight=variable<scalar=real, shape=[768, 768]>);\n",
        " key = linear(embedded, weight=variable<scalar=real, shape=[768, 768]>);\n",
        " value = linear(embedded, weight=variable<scalar=real, shape=[768, 768]>);\n",
        " query_heads = reshape(query, shape=[?, 12, 64]);\n",
        " key_heads = reshape(key, shape=[?, 12, 64]);\n",
        " value_heads = reshape(value, shape=[?, 12, 64]);\n",
        " scores = matmul(query_heads, transpose(key_heads, axes=[0, 1, 3, 2]));\n",
        " scaled_scores = mul(scores, scalar=0.125); # 1/sqrt(64)\n",
        " attention_weights = softmax(scaled_scores, axes=[3]);\n",
        " attention_output = matmul(attention_weights, value_heads);\n",
        " attention_reshaped = reshape(attention_output, shape=[?, 768]);\n",
        "\n # Feed forward network\n",
        " ff_intermediate = linear(attention_reshaped, weight=variable<scalar=real, shape=[768, 3072]>, bias=variable<scalar=real, shape=[3072]>);\n",
        " ff_activated = gelu(ff_intermediate);\n",
        " ff_output = linear(ff_activated, weight=variable<scalar=real, shape=[3072, 768]>, bias=variable<scalar=real, shape=[768]>);\n",
        "\n # Layer normalization and residual connection\n",
        " residual = add(embedded, ff_output);\n",
        " output = layer_normalization(residual, epsilon=1e-12);\n",
    ];
    for line in BODY.iter() {
        graph.push_str(line);
    }
    Ok(())
}
/// Builds the `graph.json` metadata sidecar for the exported package.
///
/// Captures format/version/producer identification, the precision,
/// optimization and quantization flags from `config`, and the declared
/// input/output tensor descriptors (dtype + shape).
fn build_metadata<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<JsonValue> {
    Ok(json!({
        "format": "NNEF",
        "version": self.version,
        "producer": "TrustformeRS",
        "producer_version": "0.1.0",
        "extensions": self.extensions,
        "properties": {
            // Debug-formatting the enum yields e.g. "FP16" for display.
            "precision": format!("{:?}", config.precision),
            "optimized": config.optimize,
            "quantized": config.quantization.is_some()
        },
        // Shapes mirror the signature emitted in graph.nnef.
        "inputs": [{
            "name": "input",
            "dtype": self.precision_to_dtype(config.precision),
            "shape": self.get_input_shape(config)
        }],
        "outputs": [{
            "name": "output",
            "dtype": self.precision_to_dtype(config.precision),
            "shape": self.get_output_shape(config)
        }]
    }))
}
/// Writes placeholder weight tensors into `<package>/weights/`.
///
/// NOTE(review): real model parameters are not serialized yet; each file is
/// filled with a deterministic repeating byte pattern of the expected tensor
/// size (element count × 4 bytes, matching the fixed FP32 graph shapes).
fn export_weights<M: Model>(
    &self,
    model: &M,
    package_dir: &Path,
    config: &ExportConfig,
) -> Result<()> {
    let weights_dir = package_dir.join("weights");
    create_dir_all(&weights_dir)?;
    // (filename, size in bytes) — shapes match the fixed transformer graph.
    let weight_files: [(&str, usize); 6] = [
        ("embedding_weight.dat", 768 * 512 * 4),
        ("attention_query_weight.dat", 768 * 768 * 4),
        ("attention_key_weight.dat", 768 * 768 * 4),
        ("attention_value_weight.dat", 768 * 768 * 4),
        ("ff_intermediate_weight.dat", 768 * 3072 * 4),
        ("ff_output_weight.dat", 3072 * 768 * 4),
    ];
    for &(filename, size) in &weight_files {
        // Fix: the old code wrapped the byte count in a one-element Vec and
        // mapped over that Vec, so every weight file contained exactly one
        // byte instead of `size` bytes of pattern data.
        let dummy_data: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
        let mut file = File::create(weights_dir.join(filename))?;
        file.write_all(&dummy_data)?;
    }
    Ok(())
}
fn get_input_shape(&self, config: &ExportConfig) -> Vec<i64> {
let batch_size = config.batch_size.unwrap_or(1) as i64;
if let Some(ref input_shape) = config.input_shape {
if input_shape.len() == 4 {
return input_shape.iter().map(|&x| x as i64).collect();
} else if input_shape.len() == 3 && input_shape[2] > 50 {
return vec![
batch_size,
input_shape[2] as i64,
input_shape[0] as i64,
input_shape[1] as i64,
];
}
}
if config.sequence_length.unwrap_or(512) > 8192 {
return vec![batch_size, config.sequence_length.unwrap_or(16000) as i64];
}
let sequence_length = config.sequence_length.unwrap_or(512) as i64;
if let Some(ref task_type) = config.task_type {
if task_type.to_lowercase().contains("multimodal")
|| task_type.to_lowercase().contains("vision")
{
return vec![batch_size, sequence_length, 3, 224, 224]; }
}
vec![batch_size, sequence_length]
}
fn get_output_shape(&self, config: &ExportConfig) -> Vec<i64> {
let batch_size = config.batch_size.unwrap_or(1) as i64;
let sequence_length = config.sequence_length.unwrap_or(512) as i64;
if let Some(ref task_type) = config.task_type {
match task_type.to_lowercase().as_str() {
"classification" | "text-classification" => {
let num_classes = config.vocab_size.unwrap_or(2) as i64; vec![batch_size, num_classes]
},
"token-classification" | "ner" => {
let num_labels = config.vocab_size.unwrap_or(9) as i64; vec![batch_size, sequence_length, num_labels]
},
"question-answering" | "qa" => {
vec![batch_size, sequence_length, 2]
},
"image-classification" => {
let num_classes = config.vocab_size.unwrap_or(1000) as i64; vec![batch_size, num_classes]
},
"object-detection" => {
vec![batch_size, 100, 6] },
"generation" | "text-generation" | "causal-lm" => {
let vocab_size = config.vocab_size.unwrap_or(50257) as i64; vec![batch_size, sequence_length, vocab_size]
},
"masked-lm" | "mlm" => {
let vocab_size = config.vocab_size.unwrap_or(30522) as i64; vec![batch_size, sequence_length, vocab_size]
},
"embedding" | "feature-extraction" => {
let hidden_size = 768; vec![batch_size, sequence_length, hidden_size]
},
"similarity" | "sentence-similarity" => {
let hidden_size = 768;
vec![batch_size, hidden_size]
},
_ => {
vec![batch_size, sequence_length, 768]
},
}
} else {
let input_shape = self.get_input_shape(config);
match input_shape.len() {
2 => {
vec![batch_size, sequence_length, 768]
},
3 => {
vec![batch_size, sequence_length, 768]
},
4 => {
vec![batch_size, 1000] },
_ => {
vec![batch_size, sequence_length, 768]
},
}
}
}
fn precision_to_dtype(&self, precision: ExportPrecision) -> &'static str {
match precision {
ExportPrecision::FP32 => "real32",
ExportPrecision::FP16 => "real16",
ExportPrecision::INT8 => "integer8",
ExportPrecision::INT4 => "integer4",
}
}
/// Checks that `config` is usable by this exporter.
///
/// # Errors
/// Fails when the target format is not NNEF, or when an integer precision
/// (INT8/INT4) is requested without an accompanying quantization config.
fn validate_config(&self, config: &ExportConfig) -> Result<()> {
    if config.format != ExportFormat::NNEF {
        return Err(anyhow!(
            "Invalid format for NNEF exporter: {:?}",
            config.format
        ));
    }
    // Integer precisions need an explicit quantization scheme; floats don't.
    let needs_quantization =
        matches!(config.precision, ExportPrecision::INT8 | ExportPrecision::INT4);
    if needs_quantization && config.quantization.is_none() {
        return Err(anyhow!(
            "Quantization config required for integer precision"
        ));
    }
    Ok(())
}
}
impl ModelExporter for NNEFExporter {
    /// Validates the config, then writes the full NNEF package to disk.
    fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
        self.validate_config(config)?;
        self.export_to_nnef(model, config)
    }

    /// This exporter handles exactly one format: NNEF.
    fn supported_formats(&self) -> Vec<ExportFormat> {
        vec![ExportFormat::NNEF]
    }

    /// Accepts any model as long as the requested format is NNEF.
    fn validate_model<M: Model>(&self, _model: &M, format: ExportFormat) -> Result<()> {
        match format {
            ExportFormat::NNEF => Ok(()),
            _ => Err(anyhow!("NNEF exporter only supports NNEF format")),
        }
    }
}
impl Default for NNEFExporter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in model used to drive the exporter in tests.
    #[derive(Clone)]
    struct MockModel {
        config: MockConfig,
    }

    /// Minimal config carrying only a hidden size.
    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    struct MockConfig {
        hidden_size: usize,
    }

    impl crate::traits::Config for MockConfig {
        fn architecture(&self) -> &'static str {
            "mock"
        }
    }

    impl Model for MockModel {
        type Config = MockConfig;
        type Input = crate::tensor::Tensor;
        type Output = crate::tensor::Tensor;

        // Identity forward pass: the mock returns its input unchanged.
        fn forward(&self, input: Self::Input) -> crate::errors::Result<Self::Output> {
            Ok(input)
        }

        // Loading pretrained weights is a no-op for the mock.
        fn load_pretrained(
            &mut self,
            _reader: &mut dyn std::io::Read,
        ) -> crate::errors::Result<()> {
            Ok(())
        }

        fn get_config(&self) -> &Self::Config {
            &self.config
        }

        // Arbitrary fixed parameter count.
        fn num_parameters(&self) -> usize {
            600_000
        }
    }

    // A freshly constructed exporter advertises exactly the NNEF format.
    #[test]
    fn test_nnef_exporter_creation() {
        let exporter = NNEFExporter::new();
        let formats = exporter.supported_formats();
        assert_eq!(formats, vec![ExportFormat::NNEF]);
    }

    // `with_config` stores the provided version and extension list verbatim.
    #[test]
    fn test_nnef_exporter_with_config() {
        let exporter = NNEFExporter::with_config(
            "1.0".to_string(),
            vec!["KHR_enable_fragment_definitions".to_string()],
        );
        assert_eq!(exporter.version, "1.0");
        assert_eq!(exporter.extensions.len(), 1);
    }

    // Each precision maps to its NNEF scalar dtype name.
    #[test]
    fn test_precision_to_dtype() {
        let exporter = NNEFExporter::new();
        assert_eq!(exporter.precision_to_dtype(ExportPrecision::FP32), "real32");
        assert_eq!(exporter.precision_to_dtype(ExportPrecision::FP16), "real16");
        assert_eq!(
            exporter.precision_to_dtype(ExportPrecision::INT8),
            "integer8"
        );
        assert_eq!(
            exporter.precision_to_dtype(ExportPrecision::INT4),
            "integer4"
        );
    }

    // Default (no task type) shapes: [batch, seq] in, [batch, seq, 768] out.
    #[test]
    fn test_input_output_shapes() {
        let exporter = NNEFExporter::new();
        let config = ExportConfig {
            format: ExportFormat::NNEF,
            batch_size: Some(2),
            sequence_length: Some(128),
            ..Default::default()
        };
        let input_shape = exporter.get_input_shape(&config);
        assert_eq!(input_shape, vec![2, 128]);
        let output_shape = exporter.get_output_shape(&config);
        assert_eq!(output_shape, vec![2, 128, 768]);
    }

    // The generated graph text contains the expected header and operators.
    #[test]
    fn test_nnef_graph_generation() {
        let exporter = NNEFExporter::new();
        let model = MockModel {
            config: MockConfig { hidden_size: 768 },
        };
        let config = ExportConfig {
            format: ExportFormat::NNEF,
            ..Default::default()
        };
        let graph = exporter.build_nnef_graph(&model, &config).expect("operation failed in test");
        assert!(graph.contains("version 1.0"));
        assert!(graph.contains("graph network"));
        assert!(graph.contains("linear"));
        assert!(graph.contains("softmax"));
        assert!(graph.contains("layer_normalization"));
    }

    // Metadata JSON reflects the exporter identity and config flags.
    #[test]
    fn test_metadata_generation() {
        let exporter = NNEFExporter::new();
        let model = MockModel {
            config: MockConfig { hidden_size: 768 },
        };
        let config = ExportConfig {
            format: ExportFormat::NNEF,
            precision: ExportPrecision::FP16,
            optimize: true,
            ..Default::default()
        };
        let metadata = exporter.build_metadata(&model, &config).expect("operation failed in test");
        assert_eq!(metadata["format"], "NNEF");
        assert_eq!(metadata["version"], "1.0");
        assert_eq!(metadata["producer"], "TrustformeRS");
        assert_eq!(metadata["properties"]["precision"], "FP16");
        assert_eq!(metadata["properties"]["optimized"], true);
    }

    // NNEF format + FP32 precision passes validation.
    #[test]
    fn test_validate_config_success() {
        let exporter = NNEFExporter::new();
        let config = ExportConfig {
            format: ExportFormat::NNEF,
            precision: ExportPrecision::FP32,
            ..Default::default()
        };
        assert!(exporter.validate_config(&config).is_ok());
    }

    // A non-NNEF format is rejected by validate_config.
    #[test]
    fn test_validate_config_wrong_format() {
        let exporter = NNEFExporter::new();
        let config = ExportConfig {
            format: ExportFormat::ONNX,
            ..Default::default()
        };
        assert!(exporter.validate_config(&config).is_err());
    }

    // validate_model accepts the NNEF format for any model.
    #[test]
    fn test_validate_model_success() {
        let exporter = NNEFExporter::new();
        let model = MockModel {
            config: MockConfig { hidden_size: 768 },
        };
        assert!(exporter.validate_model(&model, ExportFormat::NNEF).is_ok());
    }

    // validate_model rejects any format other than NNEF.
    #[test]
    fn test_validate_model_wrong_format() {
        let exporter = NNEFExporter::new();
        let model = MockModel {
            config: MockConfig { hidden_size: 768 },
        };
        assert!(exporter.validate_model(&model, ExportFormat::ONNX).is_err());
    }
}