#[allow(clippy::wildcard_imports)]
use super::*;
/// Optional extra payload that can accompany a model.
///
/// Every component is optional. Start from [`ModelExtra::new`] and chain the
/// `with_*` builder methods to attach the pieces a model needs.
#[derive(Debug, Clone, Default)]
pub struct ModelExtra {
    /// Array-encoded tree structure, if the model carries one.
    pub tree_data: Option<TreeData>,
    /// Ordered layer descriptions, if the model carries any.
    pub layer_data: Option<Vec<LayerData>>,
    /// Centroid values, if any (presumably for clustering-style models — confirm).
    pub centroids: Option<AlignedVec<f32>>,
    /// Free-form byte blobs keyed by name.
    pub metadata: std::collections::HashMap<String, Vec<u8>>,
}

impl ModelExtra {
    /// Returns an empty `ModelExtra`: no tree, layers, centroids or metadata.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Attaches `tree`, replacing any previously attached tree.
    #[must_use]
    pub fn with_tree(self, tree: TreeData) -> Self {
        Self {
            tree_data: Some(tree),
            ..self
        }
    }

    /// Attaches `layers`, replacing any previously attached layer list.
    #[must_use]
    pub fn with_layers(self, layers: Vec<LayerData>) -> Self {
        Self {
            layer_data: Some(layers),
            ..self
        }
    }

    /// Attaches `centroids`, replacing any previously attached centroids.
    #[must_use]
    pub fn with_centroids(self, centroids: AlignedVec<f32>) -> Self {
        Self {
            centroids: Some(centroids),
            ..self
        }
    }

    /// Stores `value` under `key`, overwriting any existing entry for that key.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: Vec<u8>) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns the combined payload size in bytes.
    ///
    /// This is the `size_bytes` of every present component plus the length of
    /// every metadata value. Metadata keys and container overhead are not
    /// counted.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        let mut total: usize = self.metadata.values().map(Vec::len).sum();
        if let Some(centroids) = &self.centroids {
            total += AlignedVec::size_bytes(centroids);
        }
        if let Some(layers) = &self.layer_data {
            total += layers.iter().map(LayerData::size_bytes).sum::<usize>();
        }
        if let Some(tree) = &self.tree_data {
            total += tree.size_bytes();
        }
        total
    }
}
/// Flat, array-based storage for a binary tree of split nodes.
///
/// Each field is a separate `Vec`. NOTE(review): these are presumably parallel
/// arrays indexed by node id, with the signed child indices using a sentinel
/// (e.g. negative) for "no child" — confirm against the code that builds and
/// walks the tree.
#[derive(Debug, Clone, Default)]
pub struct TreeData {
    /// Feature index tested by each node.
    pub feature_indices: Vec<u16>,
    /// Split threshold per node. Its length is what [`TreeData::n_nodes`] reports.
    pub thresholds: Vec<f32>,
    /// Left-child index per node (signed).
    pub left_children: Vec<i32>,
    /// Right-child index per node (signed).
    pub right_children: Vec<i32>,
    /// Output values stored for leaves.
    pub leaf_values: Vec<f32>,
}

impl TreeData {
    /// Returns an empty tree with no nodes.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of nodes, taken as the length of `thresholds`.
    #[must_use]
    pub fn n_nodes(&self) -> usize {
        self.thresholds.len()
    }

    /// Bytes occupied by the stored elements of all arrays.
    ///
    /// Only live elements count (`len * size_of::<element>()` per array).
    /// `Vec` headers and spare capacity are excluded. `size_of_val` keeps the
    /// count in step with the element types if a field's type ever changes.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of_val(self.feature_indices.as_slice())
            + std::mem::size_of_val(self.thresholds.as_slice())
            + std::mem::size_of_val(self.left_children.as_slice())
            + std::mem::size_of_val(self.right_children.as_slice())
            + std::mem::size_of_val(self.leaf_values.as_slice())
    }
}
/// Description of one model layer: its kind, dimensions, and optional
/// parameter buffers.
#[derive(Debug, Clone)]
pub struct LayerData {
    /// Which kind of layer this is.
    pub layer_type: LayerType,
    /// Declared input dimension.
    pub input_dim: u32,
    /// Declared output dimension.
    pub output_dim: u32,
    /// Weight buffer, if attached. Its length is not checked against the
    /// dimensions here.
    pub weights: Option<AlignedVec<f32>>,
    /// Bias buffer, if attached.
    pub biases: Option<AlignedVec<f32>>,
}

impl LayerData {
    /// Fixed per-layer byte overhead that [`LayerData::size_bytes`] adds on top
    /// of the parameter buffers.
    ///
    /// NOTE(review): presumably 4 bytes each for a layer-type tag, `input_dim`
    /// and `output_dim` — confirm against the serialization format.
    const HEADER_BYTES: usize = 12;

    /// Creates a layer of the given kind and dimensions with no weights or biases.
    #[must_use]
    pub fn new(layer_type: LayerType, input_dim: u32, output_dim: u32) -> Self {
        Self {
            layer_type,
            input_dim,
            output_dim,
            weights: None,
            biases: None,
        }
    }

    /// Creates a [`LayerType::Dense`] layer with no weights or biases attached.
    #[must_use]
    pub fn dense(input_dim: u32, output_dim: u32) -> Self {
        Self::new(LayerType::Dense, input_dim, output_dim)
    }

    /// Attaches `weights`, replacing any previously attached weights.
    #[must_use]
    pub fn with_weights(mut self, weights: AlignedVec<f32>) -> Self {
        self.weights = Some(weights);
        self
    }

    /// Attaches `biases`, replacing any previously attached biases.
    #[must_use]
    pub fn with_biases(mut self, biases: AlignedVec<f32>) -> Self {
        self.biases = Some(biases);
        self
    }

    /// Returns the size of the attached buffers plus the fixed
    /// [`HEADER_BYTES`](Self::HEADER_BYTES) overhead.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        let weights_size = self.weights.as_ref().map_or(0, AlignedVec::size_bytes);
        let biases_size = self.biases.as_ref().map_or(0, AlignedVec::size_bytes);
        weights_size + biases_size + Self::HEADER_BYTES
    }
}
/// Kind of computation a [`LayerData`] entry represents.
///
/// Variant order determines the implicit discriminants, so do not reorder
/// existing variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    /// Dense layer. This is the kind built by `LayerData::dense`.
    Dense,
    /// ReLU activation.
    ReLU,
    /// Sigmoid activation.
    Sigmoid,
    /// Tanh activation.
    Tanh,
    /// Softmax activation.
    Softmax,
    /// Dropout layer.
    Dropout,
    /// Batch-normalization layer.
    BatchNorm,
}
/// Errors reported when working with native model parameters.
#[derive(Debug, Clone)]
pub enum NativeModelError {
    /// The declared parameter count disagrees with the actual count.
    ParamCountMismatch { declared: usize, actual: usize },
    /// The parameter at `index` holds a value deemed invalid by the caller's check.
    InvalidParameter { index: usize, value: f32 },
    /// The bias at `index` holds a value deemed invalid by the caller's check.
    InvalidBias { index: usize, value: f32 },
    /// The number of input features differs from what was expected.
    FeatureMismatch { expected: usize, got: usize },
    /// Model parameters were required but absent.
    MissingParams,
    /// Address `ptr` is not a multiple of `required` (presumably bytes).
    AlignmentError { ptr: usize, required: usize },
}

impl std::fmt::Display for NativeModelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payload fields are `Copy`, so matching on `*self` binds by value.
        match *self {
            Self::ParamCountMismatch { declared, actual } => write!(
                f,
                "Parameter count mismatch: declared {declared}, actual {actual}"
            ),
            Self::InvalidParameter { index, value } => {
                write!(f, "Invalid parameter at index {index}: {value}")
            }
            Self::InvalidBias { index, value } => {
                write!(f, "Invalid bias at index {index}: {value}")
            }
            Self::FeatureMismatch { expected, got } => {
                write!(f, "Feature mismatch: expected {expected}, got {got}")
            }
            Self::MissingParams => f.write_str("Missing model parameters"),
            // `{:#x}` is alternate lowercase hex, which writes the `0x` prefix itself.
            Self::AlignmentError { ptr, required } => {
                write!(f, "Alignment error: ptr {ptr:#x} not aligned to {required}")
            }
        }
    }
}

/// Marker impl: no error source is exposed. The message comes from `Display`.
impl std::error::Error for NativeModelError {}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;