use crate::{TokenizedInput, Tokenizer};
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Settings that control how a tokenizer is described when exported to ONNX.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxExportConfig {
    /// Model name recorded in the exported metadata.
    pub model_name: String,
    /// Model version number recorded in the metadata.
    pub model_version: i64,
    /// Name of the producing tool.
    pub producer_name: String,
    /// Version string of the producing tool.
    pub producer_version: String,
    /// ONNX domain string recorded in the metadata (e.g. "ai.onnx").
    pub domain: String,
    /// Maximum sequence length declared for the output tensors.
    pub max_sequence_length: usize,
    /// Number of entries the vocabulary tensor is padded/declared to hold.
    pub vocab_size: usize,
    /// Whether the exported graph declares an `attention_mask` output.
    pub include_attention_mask: bool,
    /// Whether the exported graph declares a `token_type_ids` output.
    pub include_token_type_ids: bool,
    /// Id of the padding token.
    pub pad_token_id: i64,
    /// Id of the unknown token.
    pub unk_token_id: i64,
    /// Optional beginning-of-sequence token id.
    pub bos_token_id: Option<i64>,
    /// Optional end-of-sequence token id.
    pub eos_token_id: Option<i64>,
    /// ONNX operator-set version recorded in the metadata.
    pub opset_version: i64,
}
impl Default for OnnxExportConfig {
    /// BERT-like defaults: 512-token sequences, a 50k-entry vocabulary slot,
    /// pad id 0 / unk id 1, no BOS/EOS, opset 15.
    fn default() -> Self {
        Self {
            model_name: "tokenizer".to_string(),
            model_version: 1,
            producer_name: "TrustformeRS".to_string(),
            producer_version: "1.0.0".to_string(),
            domain: "ai.onnx".to_string(),
            max_sequence_length: 512,
            vocab_size: 50000,
            include_attention_mask: true,
            include_token_type_ids: false,
            pad_token_id: 0,
            unk_token_id: 1,
            bos_token_id: None,
            eos_token_id: None,
            opset_version: 15,
        }
    }
}
/// Subset of ONNX tensor element types used by this exporter.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum OnnxDataType {
    Int32,
    Int64,
    Float32,
    Float64,
    String,
    Bool,
}
impl OnnxDataType {
pub fn to_onnx_enum(&self) -> i32 {
match self {
OnnxDataType::Int32 => 6,
OnnxDataType::Int64 => 7,
OnnxDataType::Float32 => 1,
OnnxDataType::Float64 => 11,
OnnxDataType::String => 8,
OnnxDataType::Bool => 9,
}
}
}
/// Describes one graph input/output/value tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxTensorInfo {
    /// Tensor name referenced by graph nodes.
    pub name: String,
    /// Element type of the tensor.
    pub data_type: OnnxDataType,
    /// Declared shape; -1 marks a dynamic dimension.
    pub shape: Vec<i64>,
    /// Optional human-readable description.
    pub doc_string: Option<String>,
}
impl OnnxTensorInfo {
    /// Builds a tensor descriptor with no documentation string attached.
    pub fn new(name: String, data_type: OnnxDataType, shape: Vec<i64>) -> Self {
        Self { name, data_type, shape, doc_string: None }
    }

    /// Attaches a human-readable description, builder style.
    pub fn with_doc(self, doc: String) -> Self {
        Self { doc_string: Some(doc), ..self }
    }
}
/// One operation node in the exported graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxNode {
    /// Node name, unique within the graph.
    pub name: String,
    /// Operator type identifier (may be a custom op).
    pub op_type: String,
    /// Names of tensors this node consumes.
    pub inputs: Vec<String>,
    /// Names of tensors this node produces.
    pub outputs: Vec<String>,
    /// Operator attributes keyed by attribute name.
    pub attributes: HashMap<String, OnnxAttribute>,
    /// Optional human-readable description.
    pub doc_string: Option<String>,
}
/// Value of a node attribute, mirroring ONNX's AttributeProto payload kinds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OnnxAttribute {
    /// Single integer attribute.
    Int(i64),
    /// Single float attribute.
    Float(f32),
    /// Single string attribute.
    String(String),
    /// List of integers.
    Ints(Vec<i64>),
    /// List of floats.
    Floats(Vec<f32>),
    /// List of strings.
    Strings(Vec<String>),
    /// Embedded tensor payload.
    Tensor(OnnxTensorData),
}
/// Concrete tensor contents (used for initializers and tensor attributes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxTensorData {
    /// Tensor name referenced by graph nodes.
    pub name: String,
    /// Element type of the stored data.
    pub data_type: OnnxDataType,
    /// Declared shape of the data.
    pub shape: Vec<i64>,
    /// Serialized element bytes; layout depends on `data_type`.
    pub raw_data: Vec<u8>,
}
/// In-memory representation of an exported ONNX model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxModel {
    /// Model-level metadata (name, producer, opset, ...).
    pub metadata: OnnxModelMetadata,
    /// Declared graph inputs.
    pub inputs: Vec<OnnxTensorInfo>,
    /// Declared graph outputs.
    pub outputs: Vec<OnnxTensorInfo>,
    /// Computation graph nodes.
    pub nodes: Vec<OnnxNode>,
    /// Constant tensors baked into the model (e.g. the vocabulary).
    pub initializers: Vec<OnnxTensorData>,
    /// Optional intermediate value descriptions.
    pub value_infos: Vec<OnnxTensorInfo>,
}
/// Model-level metadata mirrored from [`OnnxExportConfig`] at export time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxModelMetadata {
    /// Model name.
    pub name: String,
    /// Model version number.
    pub version: i64,
    /// Producing tool name.
    pub producer_name: String,
    /// Producing tool version.
    pub producer_version: String,
    /// ONNX domain string.
    pub domain: String,
    /// ONNX operator-set version.
    pub opset_version: i64,
    /// Optional human-readable description.
    pub doc_string: Option<String>,
    /// Free-form string properties (token ids, vocab size, ...).
    pub metadata_props: HashMap<String, String>,
}
/// Exports a tokenizer implementing [`Tokenizer`] into the [`OnnxModel`]
/// representation defined in this module.
pub struct OnnxTokenizerExporter<T: Tokenizer> {
    /// Tokenizer being exported; `Arc` allows cheap sharing.
    tokenizer: Arc<T>,
    /// Export settings controlling metadata and graph shape.
    config: OnnxExportConfig,
}
impl<T: Tokenizer> OnnxTokenizerExporter<T> {
    /// Creates an exporter for `tokenizer` with an explicit configuration.
    pub fn new(tokenizer: T, config: OnnxExportConfig) -> Self {
        Self {
            tokenizer: Arc::new(tokenizer),
            config,
        }
    }

    /// Creates an exporter using [`OnnxExportConfig::default`].
    pub fn from_tokenizer(tokenizer: T) -> Self {
        Self::new(tokenizer, OnnxExportConfig::default())
    }

    /// Builds the complete in-memory model: metadata, I/O tensor specs, the
    /// computation graph, and initializer data.
    ///
    /// # Errors
    /// Propagates failures from graph or initializer construction.
    pub fn export(&self) -> Result<OnnxModel> {
        let metadata = self.create_metadata();
        let (inputs, outputs) = self.create_io_tensors();
        let nodes = self.create_computation_graph()?;
        let initializers = self.create_initializers()?;
        Ok(OnnxModel {
            metadata,
            inputs,
            outputs,
            nodes,
            initializers,
            value_infos: Vec::new(),
        })
    }

    /// Assembles model metadata, mirroring the key tokenizer settings as
    /// string-valued metadata properties so consumers can read them without
    /// walking the graph.
    fn create_metadata(&self) -> OnnxModelMetadata {
        let mut metadata_props = HashMap::new();
        metadata_props.insert(
            "max_sequence_length".to_string(),
            self.config.max_sequence_length.to_string(),
        );
        metadata_props.insert("vocab_size".to_string(), self.config.vocab_size.to_string());
        metadata_props.insert(
            "pad_token_id".to_string(),
            self.config.pad_token_id.to_string(),
        );
        metadata_props.insert(
            "unk_token_id".to_string(),
            self.config.unk_token_id.to_string(),
        );
        // BOS/EOS ids are optional; only advertise them when configured.
        if let Some(bos_id) = self.config.bos_token_id {
            metadata_props.insert("bos_token_id".to_string(), bos_id.to_string());
        }
        if let Some(eos_id) = self.config.eos_token_id {
            metadata_props.insert("eos_token_id".to_string(), eos_id.to_string());
        }
        OnnxModelMetadata {
            name: self.config.model_name.clone(),
            version: self.config.model_version,
            producer_name: self.config.producer_name.clone(),
            producer_version: self.config.producer_version.clone(),
            domain: self.config.domain.clone(),
            opset_version: self.config.opset_version,
            doc_string: Some("ONNX tokenizer model exported from TrustformeRS".to_string()),
            metadata_props,
        }
    }

    /// Declares the graph input (a dynamic batch of strings) and outputs
    /// (`input_ids`, plus optional `attention_mask` / `token_type_ids`).
    /// A -1 dimension marks a dynamic (batch) axis.
    fn create_io_tensors(&self) -> (Vec<OnnxTensorInfo>, Vec<OnnxTensorInfo>) {
        let inputs = vec![OnnxTensorInfo::new(
            "input_text".to_string(),
            OnnxDataType::String,
            vec![-1],
        )
        .with_doc("Input text strings to tokenize".to_string())];
        let mut outputs = vec![OnnxTensorInfo::new(
            "input_ids".to_string(),
            OnnxDataType::Int64,
            vec![-1, self.config.max_sequence_length as i64],
        )
        .with_doc("Token IDs for input sequences".to_string())];
        if self.config.include_attention_mask {
            outputs.push(
                OnnxTensorInfo::new(
                    "attention_mask".to_string(),
                    OnnxDataType::Int64,
                    vec![-1, self.config.max_sequence_length as i64],
                )
                .with_doc("Attention mask indicating real vs padding tokens".to_string()),
            );
        }
        if self.config.include_token_type_ids {
            outputs.push(
                OnnxTensorInfo::new(
                    "token_type_ids".to_string(),
                    OnnxDataType::Int64,
                    vec![-1, self.config.max_sequence_length as i64],
                )
                .with_doc("Token type IDs for sequence pair tasks".to_string()),
            );
        }
        (inputs, outputs)
    }

    /// Builds the single-node graph: one custom `TrustformeRSTokenizer` op
    /// consuming the text input and the vocabulary initializer.
    fn create_computation_graph(&self) -> Result<Vec<OnnxNode>> {
        let mut nodes = Vec::new();
        let mut tokenize_attrs = HashMap::new();
        tokenize_attrs.insert(
            "max_length".to_string(),
            OnnxAttribute::Int(self.config.max_sequence_length as i64),
        );
        tokenize_attrs.insert(
            "pad_token_id".to_string(),
            OnnxAttribute::Int(self.config.pad_token_id),
        );
        tokenize_attrs.insert(
            "unk_token_id".to_string(),
            OnnxAttribute::Int(self.config.unk_token_id),
        );
        if let Some(bos_id) = self.config.bos_token_id {
            tokenize_attrs.insert("bos_token_id".to_string(), OnnxAttribute::Int(bos_id));
        }
        if let Some(eos_id) = self.config.eos_token_id {
            tokenize_attrs.insert("eos_token_id".to_string(), OnnxAttribute::Int(eos_id));
        }
        // The output list must stay consistent with `create_io_tensors`.
        let mut tokenize_outputs = vec!["input_ids".to_string()];
        if self.config.include_attention_mask {
            tokenize_outputs.push("attention_mask".to_string());
        }
        if self.config.include_token_type_ids {
            tokenize_outputs.push("token_type_ids".to_string());
        }
        nodes.push(OnnxNode {
            name: "tokenize".to_string(),
            op_type: "TrustformeRSTokenizer".to_string(),
            inputs: vec!["input_text".to_string(), "vocab_tensor".to_string()],
            outputs: tokenize_outputs,
            attributes: tokenize_attrs,
            doc_string: Some("Main tokenization operation".to_string()),
        });
        Ok(nodes)
    }

    /// Collects initializer tensors: the vocabulary is mandatory; merge rules
    /// are best-effort (tokenizers without merges simply omit that tensor).
    fn create_initializers(&self) -> Result<Vec<OnnxTensorData>> {
        let mut initializers = Vec::new();
        let vocab_data = self.create_vocab_tensor()?;
        initializers.push(vocab_data);
        if let Ok(merge_data) = self.create_merge_tensor() {
            initializers.push(merge_data);
        }
        Ok(initializers)
    }

    /// Serializes the vocabulary as NUL-terminated UTF-8 strings ordered by
    /// token id, padding with `[PAD_i]` placeholders up to the configured
    /// `vocab_size`.
    fn create_vocab_tensor(&self) -> Result<OnnxTensorData> {
        let vocab = self.tokenizer.get_vocab();
        let mut sorted_vocab: Vec<(String, u32)> = vocab.into_iter().collect();
        // Entry order must match token ids so position i holds token i.
        sorted_vocab.sort_unstable_by_key(|(_, id)| *id);
        let vocab_size = sorted_vocab.len();
        let mut vocab_data = Vec::new();
        for (token, _) in sorted_vocab {
            vocab_data.extend(token.as_bytes());
            vocab_data.push(0); // NUL terminator separates entries
        }
        // Pad up to the configured size; this range is empty when the real
        // vocabulary is already at least that large.
        for i in vocab_size..self.config.vocab_size {
            let padding_token = format!("[PAD_{}]", i);
            vocab_data.extend(padding_token.as_bytes());
            vocab_data.push(0);
        }
        // Bug fix: the shape previously always claimed `config.vocab_size`
        // entries even when the actual vocabulary was larger, so the declared
        // shape disagreed with `raw_data`. Report the real entry count.
        let entry_count = vocab_size.max(self.config.vocab_size);
        Ok(OnnxTensorData {
            name: "vocab_tensor".to_string(),
            data_type: OnnxDataType::String,
            shape: vec![entry_count as i64],
            raw_data: vocab_data,
        })
    }

    /// Placeholder for BPE merge rules: an empty `[0, 2]` string tensor.
    /// NOTE(review): real merge-rule extraction is not implemented yet.
    fn create_merge_tensor(&self) -> Result<OnnxTensorData> {
        Ok(OnnxTensorData {
            name: "merge_tensor".to_string(),
            data_type: OnnxDataType::String,
            shape: vec![0, 2],
            raw_data: Vec::new(),
        })
    }

    /// Serializes the exported model. NOTE(review): this emits pretty-printed
    /// JSON of the internal representation, not ONNX protobuf wire format.
    ///
    /// # Errors
    /// Returns an error if export or serialization fails.
    pub fn export_to_bytes(&self) -> Result<Vec<u8>> {
        let model = self.export()?;
        serde_json::to_vec_pretty(&model)
            .map_err(|e| anyhow!("Failed to serialize ONNX model: {}", e))
    }

    /// Exports the model and writes the serialized bytes to `path`.
    ///
    /// # Errors
    /// Returns an error if export or the filesystem write fails.
    pub fn save_to_file(&self, path: &str) -> Result<()> {
        let model_bytes = self.export_to_bytes()?;
        std::fs::write(path, model_bytes)
            .map_err(|e| anyhow!("Failed to write ONNX model to file: {}", e))
    }

    /// Borrows the wrapped tokenizer.
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Borrows the export configuration.
    pub fn config(&self) -> &OnnxExportConfig {
        &self.config
    }
}
/// Loads and runs an exported tokenizer model.
///
/// NOTE(review): this type currently *simulates* inference in pure Rust
/// rather than binding to a real ONNX Runtime session — see `tokenize`.
pub struct OnnxTokenizerRuntime {
    /// Path to the serialized model file.
    model_path: String,
    /// Session configuration; retained for a future real-runtime backend.
    #[allow(dead_code)]
    session_options: OnnxSessionOptions,
}
/// Options for configuring an inference session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxSessionOptions {
    /// Thread count; `None` lets the runtime decide.
    pub num_threads: Option<usize>,
    /// Whether to request GPU execution.
    pub use_gpu: bool,
    /// GPU device index when `use_gpu` is set.
    pub gpu_device_id: Option<i32>,
    /// Graph optimization level to request.
    pub optimization_level: OnnxOptimizationLevel,
    /// Whether to enable the memory-pattern optimization.
    pub enable_mem_pattern: bool,
}
impl Default for OnnxSessionOptions {
    /// CPU-only defaults with full graph optimization and memory patterns on.
    fn default() -> Self {
        Self {
            num_threads: None,
            use_gpu: false,
            gpu_device_id: None,
            optimization_level: OnnxOptimizationLevel::All,
            enable_mem_pattern: true,
        }
    }
}
/// Graph optimization levels, from none to all available passes.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum OnnxOptimizationLevel {
    None,
    Basic,
    Extended,
    All,
}
impl OnnxTokenizerRuntime {
    /// Creates a runtime wrapper for a serialized model with explicit
    /// session options.
    pub fn new(model_path: String, options: OnnxSessionOptions) -> Self {
        Self {
            model_path,
            session_options: options,
        }
    }

    /// Creates a runtime with default session options.
    pub fn from_file(model_path: String) -> Self {
        Self::new(model_path, OnnxSessionOptions::default())
    }

    /// Tokenizes each input text, returning one result per text.
    ///
    /// # Errors
    /// Propagates the first failure from the underlying tokenization.
    pub fn tokenize(&self, texts: &[String]) -> Result<Vec<TokenizedInput>> {
        texts
            .iter()
            .map(|text| self.simulate_onnx_tokenization(text))
            .collect()
    }

    /// Simulates what an ONNX tokenizer op would produce for `text`:
    /// BERT-style ids wrapped in [CLS] (101) / [SEP] (102), attention and
    /// special-token masks, and byte-offset mapping into the original text.
    ///
    /// NOTE(review): words come from the whitespace-normalized copy while
    /// offsets index the *original* text, so offsets can drift when
    /// preprocessing rewrites characters — acceptable for a simulation.
    fn simulate_onnx_tokenization(&self, text: &str) -> Result<TokenizedInput> {
        let cleaned_text = self.preprocess_text(text);
        let mut input_ids = Vec::new();
        let mut offset_mapping = Vec::new();
        // Byte offset into the original text; advanced UTF-8-safely below.
        let mut current_offset = 0;
        for word in cleaned_text.split_whitespace() {
            // Skip leading whitespace in the original text. The previous
            // implementation mixed char indices with byte lengths and paid
            // O(n) per step via `chars().nth`; advancing by `len_utf8` keeps
            // `current_offset` a valid byte index in O(1) per character.
            while let Some(c) = text.get(current_offset..).and_then(|rest| rest.chars().next()) {
                if c.is_whitespace() {
                    current_offset += c.len_utf8();
                } else {
                    break;
                }
            }
            let word_start = current_offset;
            for subword in self.simulate_subword_tokenization(word) {
                input_ids.push(self.simulate_vocab_lookup(&subword));
                // The "##" continuation marker is not part of the source
                // text, so exclude it from the covered byte span (the old
                // code counted it, overshooting by two bytes per piece).
                let piece_len = subword.strip_prefix("##").unwrap_or(&subword).len();
                let piece_end = current_offset + piece_len;
                offset_mapping.push(Some((current_offset, piece_end)));
                current_offset = piece_end;
            }
            // Re-anchor at the end of the word in case pieces drifted.
            current_offset = word_start + word.len();
        }
        // Wrap with special tokens; specials map to the (0, 0) offset.
        let mut final_ids = Vec::new();
        let mut final_offsets = Vec::new();
        final_ids.push(101);
        final_offsets.push((0, 0));
        final_ids.extend(input_ids);
        final_offsets.extend(offset_mapping.into_iter().map(|opt| opt.unwrap_or((0, 0))));
        final_ids.push(102);
        final_offsets.push((0, 0));
        let seq_len = final_ids.len();
        let attention_mask = vec![1u8; seq_len];
        // First and last positions are [CLS]/[SEP]; seq_len is always >= 2.
        let mut special_tokens_mask = vec![0u8; seq_len];
        special_tokens_mask[0] = 1;
        special_tokens_mask[seq_len - 1] = 1;
        Ok(TokenizedInput {
            input_ids: final_ids,
            attention_mask,
            token_type_ids: Some(vec![0u32; seq_len]),
            special_tokens_mask: Some(special_tokens_mask),
            offset_mapping: Some(final_offsets),
            overflowing_tokens: None,
        })
    }

    /// Normalizes input text: trims, replaces control characters (other than
    /// newline, carriage return, and tab) with spaces, and collapses
    /// whitespace runs to single spaces.
    fn preprocess_text(&self, text: &str) -> String {
        text.trim()
            .chars()
            .map(|c| {
                if c.is_control() && c != '\n' && c != '\r' && c != '\t' {
                    ' '
                } else {
                    c
                }
            })
            .collect::<String>()
            .split_whitespace()
            .collect::<Vec<&str>>()
            .join(" ")
    }

    /// Greedy longest-match-first subword split (pieces capped at 8 chars).
    /// Continuation pieces carry a "##" prefix, WordPiece style.
    fn simulate_subword_tokenization(&self, word: &str) -> Vec<String> {
        if word.is_empty() {
            return vec![];
        }
        let mut subwords = Vec::new();
        let chars: Vec<char> = word.chars().collect();
        let mut i = 0;
        while i < chars.len() {
            // Try the longest candidate first, falling back to one char.
            let max_len = (chars.len() - i).min(8);
            let mut best_len = 1;
            for len in (2..=max_len).rev() {
                let subword: String = chars[i..i + len].iter().collect();
                if self.simulate_vocab_contains(&subword) {
                    best_len = len;
                    break;
                }
            }
            let subword: String = chars[i..i + best_len].iter().collect();
            if i > 0 {
                subwords.push(format!("##{}", subword));
            } else {
                subwords.push(subword);
            }
            i += best_len;
        }
        subwords
    }

    /// Deterministically hashes a token into a pseudo vocabulary id in
    /// [1000, 30000), reserving ids below 1000 for special tokens.
    /// NOTE(review): `DefaultHasher` output is not guaranteed stable across
    /// Rust releases, so ids may differ between toolchain versions.
    fn simulate_vocab_lookup(&self, token: &str) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        token.hash(&mut hasher);
        let hash = hasher.finish();
        let vocab_size = 30000;
        (hash % (vocab_size - 1000)) as u32 + 1000
    }

    /// Pretends to check vocabulary membership: short tokens and common
    /// WordPiece suffixes always hit; otherwise a cheap char-sum hash
    /// accepts roughly 70% of candidates.
    fn simulate_vocab_contains(&self, token: &str) -> bool {
        if token.len() <= 3 {
            return true;
        }
        let common_patterns = [
            "##ing", "##ed", "##er", "##ly", "##tion", "##ness", "##able",
        ];
        if common_patterns.iter().any(|&pattern| token.contains(pattern)) {
            return true;
        }
        let hash = token.chars().map(|c| c as u32).sum::<u32>();
        hash % 10 < 7
    }

    /// Returns basic runtime metadata (model path and framework label).
    pub fn get_metadata(&self) -> Result<HashMap<String, String>> {
        let mut metadata = HashMap::new();
        metadata.insert("model_path".to_string(), self.model_path.clone());
        metadata.insert("framework".to_string(), "ONNX Runtime".to_string());
        Ok(metadata)
    }

    /// Describes the expected model input: a dynamic 1-D batch of strings.
    pub fn get_input_specs(&self) -> Result<Vec<OnnxTensorInfo>> {
        Ok(vec![OnnxTensorInfo::new(
            "input_text".to_string(),
            OnnxDataType::String,
            vec![-1],
        )])
    }

    /// Describes the model outputs: `input_ids` and `attention_mask`, both
    /// with dynamic batch and sequence dimensions.
    pub fn get_output_specs(&self) -> Result<Vec<OnnxTensorInfo>> {
        Ok(vec![
            OnnxTensorInfo::new("input_ids".to_string(), OnnxDataType::Int64, vec![-1, -1]),
            OnnxTensorInfo::new(
                "attention_mask".to_string(),
                OnnxDataType::Int64,
                vec![-1, -1],
            ),
        ])
    }
}
/// Namespace for stateless helpers that inspect [`OnnxModel`] values.
pub struct OnnxUtils;
impl OnnxUtils {
    /// Checks basic structural validity: at least one input and one output,
    /// and every node input connected to a graph input, an initializer, or
    /// some node's output.
    ///
    /// # Errors
    /// Returns a descriptive error on the first violation found.
    pub fn validate_model(model: &OnnxModel) -> Result<()> {
        if model.inputs.is_empty() {
            return Err(anyhow!("Model must have at least one input"));
        }
        if model.outputs.is_empty() {
            return Err(anyhow!("Model must have at least one output"));
        }
        for node in &model.nodes {
            for input in &node.inputs {
                let fed_by_graph_input = model.inputs.iter().any(|i| &i.name == input);
                let fed_by_initializer = model.initializers.iter().any(|i| &i.name == input);
                let fed_by_node_output = model.nodes.iter().any(|n| n.outputs.contains(input));
                if !(fed_by_graph_input || fed_by_initializer || fed_by_node_output) {
                    return Err(anyhow!(
                        "Node {} has unconnected input: {}",
                        node.name,
                        input
                    ));
                }
            }
        }
        Ok(())
    }

    /// Renders a human-readable summary of the model: header, then inputs,
    /// outputs, and nodes, one per line.
    pub fn model_to_string(model: &OnnxModel) -> String {
        let mut summary = format!("ONNX Model: {}\n", model.metadata.name);
        summary.push_str(&format!("Version: {}\n", model.metadata.version));
        summary.push_str(&format!(
            "Producer: {} {}\n",
            model.metadata.producer_name, model.metadata.producer_version
        ));
        summary.push_str("\nInputs:\n");
        for tensor in &model.inputs {
            summary.push_str(&format!(
                " {} [{:?}] {:?}\n",
                tensor.name, tensor.shape, tensor.data_type
            ));
        }
        summary.push_str("\nOutputs:\n");
        for tensor in &model.outputs {
            summary.push_str(&format!(
                " {} [{:?}] {:?}\n",
                tensor.name, tensor.shape, tensor.data_type
            ));
        }
        summary.push_str("\nNodes:\n");
        for node in &model.nodes {
            summary.push_str(&format!(
                " {} ({}): {:?} -> {:?}\n",
                node.name, node.op_type, node.inputs, node.outputs
            ));
        }
        summary
    }

    /// Rough size estimate in bytes: raw initializer data plus a flat
    /// 1 KiB allowance per node for graph structure.
    pub fn estimate_model_size(model: &OnnxModel) -> usize {
        let initializer_bytes: usize = model
            .initializers
            .iter()
            .map(|init| init.raw_data.len())
            .sum();
        initializer_bytes + model.nodes.len() * 1024
    }

    /// Suggests coarse optimizations for large models: pruning above 100
    /// nodes, quantization above 100 MiB of initializer data.
    pub fn suggest_optimizations(model: &OnnxModel) -> Vec<String> {
        let mut suggestions = Vec::new();
        if model.nodes.len() > 100 {
            suggestions.push("Consider model pruning for large models".to_string());
        }
        let weight_bytes: usize = model.initializers.iter().map(|i| i.raw_data.len()).sum();
        if weight_bytes > 100 * 1024 * 1024 {
            suggestions.push("Consider quantization to reduce model size".to_string());
        }
        suggestions
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Builds a tiny character-level tokenizer: four special tokens plus the
    /// letters needed to spell "hello world" / "test".
    fn create_test_char_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);
        vocab.insert("h".to_string(), 4);
        vocab.insert("e".to_string(), 5);
        vocab.insert("l".to_string(), 6);
        vocab.insert("o".to_string(), 7);
        vocab.insert("w".to_string(), 8);
        vocab.insert("r".to_string(), 9);
        vocab.insert("d".to_string(), 10);
        vocab.insert(" ".to_string(), 11);
        vocab.insert("t".to_string(), 12);
        vocab.insert("s".to_string(), 13);
        CharTokenizer::new(vocab)
    }

    // Default config carries the documented defaults.
    #[test]
    fn test_onnx_export_config() {
        let config = OnnxExportConfig::default();
        assert_eq!(config.model_name, "tokenizer");
        assert_eq!(config.max_sequence_length, 512);
        assert!(config.include_attention_mask);
    }

    // Builder-style construction populates all fields, including the doc string.
    #[test]
    fn test_onnx_tensor_info() {
        let tensor_info = OnnxTensorInfo::new(
            "test_tensor".to_string(),
            OnnxDataType::Int64,
            vec![-1, 512],
        )
        .with_doc("Test tensor documentation".to_string());
        assert_eq!(tensor_info.name, "test_tensor");
        // 7 is ONNX's TensorProto.DataType value for INT64.
        assert_eq!(tensor_info.data_type.to_onnx_enum(), 7);
        assert_eq!(tensor_info.shape, vec![-1, 512]);
        assert!(tensor_info.doc_string.is_some());
    }

    // `from_tokenizer` uses the default export configuration.
    #[test]
    fn test_onnx_exporter_creation() {
        let tokenizer = create_test_char_tokenizer();
        let exporter = OnnxTokenizerExporter::from_tokenizer(tokenizer);
        assert_eq!(exporter.config().model_name, "tokenizer");
        assert_eq!(exporter.config().max_sequence_length, 512);
    }

    // A full export produces metadata plus non-empty I/O declarations.
    #[test]
    fn test_onnx_model_export() {
        let tokenizer = create_test_char_tokenizer();
        let exporter = OnnxTokenizerExporter::from_tokenizer(tokenizer);
        let model = exporter.export().expect("Operation failed in test");
        assert_eq!(model.metadata.name, "tokenizer");
        assert!(!model.inputs.is_empty());
        assert!(!model.outputs.is_empty());
    }

    // Serialization of an exported model yields non-empty bytes.
    #[test]
    fn test_onnx_model_serialization() {
        let tokenizer = create_test_char_tokenizer();
        let exporter = OnnxTokenizerExporter::from_tokenizer(tokenizer);
        let model_bytes = exporter.export_to_bytes().expect("Operation failed in test");
        assert!(!model_bytes.is_empty());
    }

    // The runtime advertises the string `input_text` input spec.
    #[test]
    fn test_onnx_runtime_creation() {
        let runtime = OnnxTokenizerRuntime::from_file("test_model.onnx".to_string());
        let input_specs = runtime.get_input_specs().expect("Operation failed in test");
        assert!(!input_specs.is_empty());
        assert_eq!(input_specs[0].name, "input_text");
    }

    // A minimal model with one input/output and no nodes validates cleanly.
    #[test]
    fn test_onnx_utils_validation() {
        let metadata = OnnxModelMetadata {
            name: "test".to_string(),
            version: 1,
            producer_name: "test".to_string(),
            producer_version: "1.0".to_string(),
            domain: "test".to_string(),
            opset_version: 15,
            doc_string: None,
            metadata_props: HashMap::new(),
        };
        let model = OnnxModel {
            metadata,
            inputs: vec![OnnxTensorInfo::new(
                "input".to_string(),
                OnnxDataType::String,
                vec![-1],
            )],
            outputs: vec![OnnxTensorInfo::new(
                "output".to_string(),
                OnnxDataType::Int64,
                vec![-1, -1],
            )],
            nodes: Vec::new(),
            initializers: Vec::new(),
            value_infos: Vec::new(),
        };
        assert!(OnnxUtils::validate_model(&model).is_ok());
    }

    // The textual summary mentions model name and producer.
    #[test]
    fn test_onnx_model_to_string() {
        let metadata = OnnxModelMetadata {
            name: "test_model".to_string(),
            version: 1,
            producer_name: "TrustformeRS".to_string(),
            producer_version: "1.0.0".to_string(),
            domain: "ai.onnx".to_string(),
            opset_version: 15,
            doc_string: None,
            metadata_props: HashMap::new(),
        };
        let model = OnnxModel {
            metadata,
            inputs: vec![OnnxTensorInfo::new(
                "input".to_string(),
                OnnxDataType::String,
                vec![-1],
            )],
            outputs: vec![OnnxTensorInfo::new(
                "output".to_string(),
                OnnxDataType::Int64,
                vec![-1, -1],
            )],
            nodes: Vec::new(),
            initializers: Vec::new(),
            value_infos: Vec::new(),
        };
        let model_str = OnnxUtils::model_to_string(&model);
        assert!(model_str.contains("test_model"));
        assert!(model_str.contains("TrustformeRS"));
    }
}