#![allow(clippy::needless_range_loop)]
use std::collections::HashMap;
/// An n-dimensional array of `f64` values stored as a flat `data` vector.
///
/// `data.len()` is expected to equal the product of `shape` (an empty shape describes a
/// single-element scalar). [`Tensor::new`] enforces this; because the fields are public, code
/// that mutates them directly must maintain it.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Flattened element storage.
    pub data: Vec<f64>,
}

impl Tensor {
    /// Creates a tensor from a shape and its flattened data.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> Self {
        let expected: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected,
            "data length {} does not match shape {:?} (product {})",
            data.len(),
            shape,
            expected
        );
        Tensor { shape, data }
    }

    /// Creates a tensor of the given shape with every element set to `0.0`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let n: usize = shape.iter().product();
        Tensor {
            shape,
            data: vec![0.0; n],
        }
    }

    /// Number of stored elements.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (length of `shape`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Serializes to a little-endian binary layout:
    /// `ndim: u64`, then `ndim` dimension sizes as `u64`, then each element's IEEE-754 bit
    /// pattern as `u64`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + 8 * self.shape.len() + 8 * self.data.len());
        buf.extend_from_slice(&(self.shape.len() as u64).to_le_bytes());
        for &d in &self.shape {
            buf.extend_from_slice(&(d as u64).to_le_bytes());
        }
        for &v in &self.data {
            buf.extend_from_slice(&v.to_bits().to_le_bytes());
        }
        buf
    }

    /// Parses the layout produced by [`Tensor::to_bytes`].
    ///
    /// Returns `None` when the buffer is truncated, or when the header describes sizes whose
    /// byte length cannot be represented in `usize`. Trailing bytes after the element data are
    /// ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        // Little-endian u64 at byte offset `off`, or `None` if it would run past the buffer.
        let word_at = |off: usize| -> Option<u64> {
            let chunk = bytes.get(off..off.checked_add(8)?)?;
            Some(u64::from_le_bytes(chunk.try_into().ok()?))
        };

        let ndim = usize::try_from(word_at(0)?).ok()?;
        // The header is untrusted: use checked arithmetic so a huge `ndim` or dimension is
        // rejected instead of overflowing (a panic in debug builds, or a wrapped size check
        // followed by an enormous allocation in release builds).
        let header_len = ndim.checked_mul(8)?.checked_add(8)?;
        if bytes.len() < header_len {
            return None;
        }
        // Bounded: the check above guarantees `ndim` words are present.
        let mut shape = Vec::with_capacity(ndim);
        for i in 0..ndim {
            shape.push(usize::try_from(word_at(8 + 8 * i)?).ok()?);
        }

        let n = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
        let data_len = n.checked_mul(8)?;
        if bytes.len() - header_len < data_len {
            return None;
        }
        // Bounded: the check above guarantees `n` words follow the header.
        let mut data = Vec::with_capacity(n);
        for i in 0..n {
            data.push(f64::from_bits(word_at(header_len + 8 * i)?));
        }
        Some(Tensor { shape, data })
    }

    /// Element-wise sum. Returns `None` if the shapes differ.
    pub fn add(&self, other: &Tensor) -> Option<Tensor> {
        if self.shape != other.shape {
            return None;
        }
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(a, b)| a + b)
            .collect();
        Some(Tensor {
            shape: self.shape.clone(),
            data,
        })
    }

    /// Returns a copy with every element multiplied by `s`.
    pub fn scale(&self, s: f64) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|v| v * s).collect(),
        }
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Arithmetic mean of all elements, or `0.0` for an empty tensor.
    pub fn mean(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.sum() / self.data.len() as f64
    }
}
/// A fully connected layer computing `activation(weights · input + bias)`.
///
/// [`DenseLayer::new`] shapes `weights` as `[out_features, in_features]` and `bias` as
/// `[out_features]`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DenseLayer {
    pub name: String,
    pub weights: Tensor,
    pub bias: Tensor,
    /// Activation name passed to `apply_activation`.
    pub activation: String,
}

impl DenseLayer {
    /// Creates a layer whose weights and biases are all zero.
    pub fn new(
        name: impl Into<String>,
        in_features: usize,
        out_features: usize,
        activation: impl Into<String>,
    ) -> Self {
        DenseLayer {
            name: name.into(),
            weights: Tensor::zeros(vec![out_features, in_features]),
            bias: Tensor::zeros(vec![out_features]),
            activation: activation.into(),
        }
    }

    /// Computes the layer output for one input vector.
    ///
    /// Output width is `bias.data.len()`. The weight row length is
    /// `weights.data.len() / out_features` (the layer's input width). Only the first
    /// `min(input.len(), row length)` inputs contribute: a short input behaves as if
    /// zero-padded and extra inputs are ignored.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let out_feat = self.bias.data.len();
        if out_feat == 0 {
            return Vec::new();
        }
        // Stride between weight rows comes from the weight matrix, never from `input.len()`,
        // so an input of the wrong width cannot select the wrong weights or index past the end.
        let row_len = self.weights.data.len() / out_feat;
        let used = input.len().min(row_len);
        let inputs = &input[..used];
        self.bias
            .data
            .iter()
            .enumerate()
            .map(|(i, &b)| {
                let row = &self.weights.data[i * row_len..i * row_len + used];
                // Accumulate onto the bias in input order (same summation order as a plain loop).
                let acc = row.iter().zip(inputs).fold(b, |acc, (w, x)| acc + w * x);
                apply_activation(acc, &self.activation)
            })
            .collect()
    }

    /// Total number of weights plus biases.
    pub fn param_count(&self) -> usize {
        self.weights.numel() + self.bias.numel()
    }

    /// Serializes as four length-prefixed sections (lengths as little-endian `u64`):
    /// name UTF-8, activation UTF-8, `weights.to_bytes()`, `bias.to_bytes()`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let name_bytes = self.name.as_bytes();
        let act_bytes = self.activation.as_bytes();
        let mut buf = Vec::new();
        buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(name_bytes);
        buf.extend_from_slice(&(act_bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(act_bytes);
        let wb = self.weights.to_bytes();
        buf.extend_from_slice(&(wb.len() as u64).to_le_bytes());
        buf.extend_from_slice(&wb);
        let bb = self.bias.to_bytes();
        buf.extend_from_slice(&(bb.len() as u64).to_le_bytes());
        buf.extend_from_slice(&bb);
        buf
    }
}
/// Applies the named element-wise activation function to `x`.
///
/// Supported names:
/// - `"relu"`, `"sigmoid"`, `"tanh"`, `"softplus"`;
/// - `"elu"` (alpha = 1);
/// - `"leaky_relu"` (negative slope 0.01).
///
/// Any other name (e.g. `"linear"`) returns `x` unchanged.
#[allow(dead_code)]
pub fn apply_activation(x: f64, activation: &str) -> f64 {
    match activation {
        "relu" => x.max(0.0),
        "sigmoid" => 1.0 / (1.0 + (-x).exp()),
        "tanh" => x.tanh(),
        // Stable form of ln(1 + e^x): no overflow to +inf for large x, and no precision loss
        // from `1.0 + tiny` for very negative x.
        "softplus" => x.max(0.0) + (-x.abs()).exp().ln_1p(),
        "elu" => {
            if x >= 0.0 {
                x
            } else {
                // exp_m1 avoids cancellation in e^x - 1 when x is close to 0.
                x.exp_m1()
            }
        }
        "leaky_relu" => {
            if x >= 0.0 {
                x
            } else {
                0.01 * x
            }
        }
        // Unknown names (including "linear") act as the identity.
        _ => x,
    }
}
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct ModelWeights {
pub layers: Vec<DenseLayer>,
}
impl ModelWeights {
pub fn new() -> Self {
ModelWeights { layers: Vec::new() }
}
pub fn add_layer(&mut self, layer: DenseLayer) {
self.layers.push(layer);
}
pub fn get_layer(&self, name: &str) -> Option<&DenseLayer> {
self.layers.iter().find(|l| l.name == name)
}
pub fn total_params(&self) -> usize {
self.layers.iter().map(|l| l.param_count()).sum()
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&(self.layers.len() as u64).to_le_bytes());
for layer in &self.layers {
let lb = layer.to_bytes();
buf.extend_from_slice(&(lb.len() as u64).to_le_bytes());
buf.extend_from_slice(&lb);
}
buf
}
}
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct StateDict {
pub tensors: HashMap<String, Tensor>,
}
impl StateDict {
pub fn new() -> Self {
StateDict {
tensors: HashMap::new(),
}
}
pub fn insert(&mut self, key: impl Into<String>, tensor: Tensor) {
self.tensors.insert(key.into(), tensor);
}
pub fn get(&self, key: &str) -> Option<&Tensor> {
self.tensors.get(key)
}
pub fn len(&self) -> usize {
self.tensors.len()
}
pub fn is_empty(&self) -> bool {
self.tensors.is_empty()
}
pub fn total_params(&self) -> usize {
self.tensors.values().map(|t| t.numel()).sum()
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&(self.tensors.len() as u64).to_le_bytes());
let mut keys: Vec<&String> = self.tensors.keys().collect();
keys.sort(); for k in keys {
let kb = k.as_bytes();
buf.extend_from_slice(&(kb.len() as u64).to_le_bytes());
buf.extend_from_slice(kb);
let tb = self.tensors[k].to_bytes();
buf.extend_from_slice(&(tb.len() as u64).to_le_bytes());
buf.extend_from_slice(&tb);
}
buf
}
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
let mut pos = 0usize;
let n = read_u64(bytes, &mut pos)? as usize;
let mut dict = StateDict::new();
for _ in 0..n {
let klen = read_u64(bytes, &mut pos)? as usize;
if pos + klen > bytes.len() {
return None;
}
let key = String::from_utf8(bytes[pos..pos + klen].to_vec()).ok()?;
pos += klen;
let tlen = read_u64(bytes, &mut pos)? as usize;
if pos + tlen > bytes.len() {
return None;
}
let tensor = Tensor::from_bytes(&bytes[pos..pos + tlen])?;
pos += tlen;
dict.insert(key, tensor);
}
Some(dict)
}
}
/// Reads a little-endian `u64` starting at `*pos` and advances `*pos` by 8 on success.
///
/// Returns `None` without moving `*pos` if fewer than 8 bytes remain. It also returns `None`
/// if `*pos + 8` would overflow `usize`.
fn read_u64(bytes: &[u8], pos: &mut usize) -> Option<u64> {
    let end = (*pos).checked_add(8)?;
    let word: [u8; 8] = bytes.get(*pos..end)?.try_into().ok()?;
    *pos = end;
    Some(u64::from_le_bytes(word))
}
/// One operator in an ONNX-style graph: its name, operator type, input/output value names,
/// and numeric attributes.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OnnxNode {
    pub name: String,
    pub op_type: String,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    pub attributes: HashMap<String, f64>,
}

impl OnnxNode {
    /// Creates a node with no attributes.
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        let attributes = HashMap::new();
        Self {
            name: name.into(),
            op_type: op_type.into(),
            inputs,
            outputs,
            attributes,
        }
    }

    /// Builder-style setter: stores `value` under `key` (overwriting any previous value) and
    /// returns the node.
    pub fn with_attr(self, key: impl Into<String>, value: f64) -> Self {
        let mut node = self;
        node.attributes.insert(key.into(), value);
        node
    }
}
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct OnnxLikeGraph {
pub nodes: Vec<OnnxNode>,
pub initializers: StateDict,
pub inputs: Vec<String>,
pub outputs: Vec<String>,
pub name: String,
}
impl OnnxLikeGraph {
pub fn new(name: impl Into<String>) -> Self {
OnnxLikeGraph {
name: name.into(),
nodes: Vec::new(),
initializers: StateDict::new(),
inputs: Vec::new(),
outputs: Vec::new(),
}
}
pub fn add_node(&mut self, node: OnnxNode) {
self.nodes.push(node);
}
pub fn add_initializer(&mut self, name: impl Into<String>, tensor: Tensor) {
self.initializers.insert(name, tensor);
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn count_op(&self, op: &str) -> usize {
self.nodes.iter().filter(|n| n.op_type == op).count()
}
pub fn is_topologically_valid(&self) -> bool {
let mut available: std::collections::HashSet<&str> =
self.inputs.iter().map(|s| s.as_str()).collect();
for k in self.initializers.tensors.keys() {
available.insert(k.as_str());
}
for node in &self.nodes {
for inp in &node.inputs {
if !available.contains(inp.as_str()) {
return false;
}
}
for out in &node.outputs {
available.insert(out.as_str());
}
}
true
}
}
/// One sample: a feature vector and an optional class index.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DataRow {
    pub features: Vec<f64>,
    pub label: Option<usize>,
}

impl DataRow {
    /// Builds a row carrying class index `label`.
    pub fn labelled(features: Vec<f64>, label: usize) -> Self {
        let label = Some(label);
        Self { features, label }
    }

    /// Builds a row without a label.
    pub fn unlabelled(features: Vec<f64>) -> Self {
        Self {
            label: None,
            features,
        }
    }
}
/// An in-memory collection of rows plus optional feature and class names.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    pub rows: Vec<DataRow>,
    pub feature_names: Vec<String>,
    pub class_names: Vec<String>,
}

impl Dataset {
    /// Creates an empty dataset.
    pub fn new() -> Self {
        Dataset {
            rows: Vec::new(),
            feature_names: Vec::new(),
            class_names: Vec::new(),
        }
    }

    /// Appends a row.
    pub fn push(&mut self, row: DataRow) {
        self.rows.push(row);
    }

    /// Number of rows.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Whether there are no rows.
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Feature count taken from the first row (0 when the dataset is empty).
    pub fn num_features(&self) -> usize {
        self.rows.first().map(|r| r.features.len()).unwrap_or(0)
    }

    /// Shuffles rows in place with a Fisher–Yates pass driven by a seeded `LcgRng`, so a
    /// given seed always produces the same order.
    pub fn shuffle(&mut self, seed: u64) {
        let n = self.rows.len();
        if n < 2 {
            return;
        }
        let mut rng = LcgRng::new(seed);
        for i in (1..n).rev() {
            let j = rng.next_usize_below(i + 1);
            self.rows.swap(i, j);
        }
    }

    /// Splits into `(train, val)` without reordering.
    ///
    /// The last `floor(len * val_fraction)` rows become the validation set. `val_fraction` is
    /// clamped to `[0, 1]`. Both halves get copies of `feature_names` and `class_names`.
    pub fn train_val_split(&self, val_fraction: f64) -> (Dataset, Dataset) {
        let val_count = ((self.rows.len() as f64) * val_fraction.clamp(0.0, 1.0)) as usize;
        let train_count = self.rows.len().saturating_sub(val_count);
        let (train_rows, val_rows) = self.rows.split_at(train_count);
        let with_rows = |rows: &[DataRow]| Dataset {
            rows: rows.to_vec(),
            feature_names: self.feature_names.clone(),
            class_names: self.class_names.clone(),
        };
        (with_rows(train_rows), with_rows(val_rows))
    }

    /// Per-feature mean and population standard deviation (dividing by the row count).
    ///
    /// The feature width is the first row's length. A shorter row contributes nothing for its
    /// missing features but still counts toward the divisor. Features beyond the first row's
    /// width are ignored instead of indexing out of bounds. Returns empty vectors when there
    /// are no rows or no features.
    pub fn feature_stats(&self) -> (Vec<f64>, Vec<f64>) {
        let nf = self.num_features();
        if nf == 0 || self.rows.is_empty() {
            return (vec![], vec![]);
        }
        let n = self.rows.len() as f64;
        let mut means = vec![0.0f64; nf];
        for row in &self.rows {
            // `zip` stops at `nf`, so a longer row cannot index past `means`.
            for (m, &v) in means.iter_mut().zip(&row.features) {
                *m += v;
            }
        }
        for m in &mut means {
            *m /= n;
        }
        let mut stds = vec![0.0f64; nf];
        for row in &self.rows {
            for ((s, &m), &v) in stds.iter_mut().zip(&means).zip(&row.features) {
                let d = v - m;
                *s += d * d;
            }
        }
        for s in &mut stds {
            *s = (*s / n).sqrt();
        }
        (means, stds)
    }
}
/// Minimal 64-bit linear congruential generator (modulus 2^64, with the multiplier/increment
/// pair from Knuth's MMIX), used for reproducible shuffling. Not suitable for cryptography.
#[allow(dead_code)]
struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Seeds the generator; the seed is XOR-mixed with a fixed constant.
    fn new(seed: u64) -> Self {
        LcgRng {
            state: seed ^ 0x1234_5678_9abc_def0,
        }
    }

    /// Advances the state and returns it.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Returns a value in `0..n`, or `0` (without advancing) when `n == 0`.
    ///
    /// The output is mapped onto `0..n` with a widening multiply, which keeps the high bits.
    /// The low bits of a power-of-two-modulus LCG are weak (bit 0 simply alternates), so
    /// reducing with `% n` made small ranges almost deterministic.
    fn next_usize_below(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        ((u128::from(self.next_u64()) * n as u128) >> 64) as usize
    }
}
/// Per-feature statistics used to normalize feature vectors.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NormalizationParams {
    pub means: Vec<f64>,
    pub stds: Vec<f64>,
    pub mins: Vec<f64>,
    pub maxs: Vec<f64>,
}

impl NormalizationParams {
    /// Fits the parameters to `dataset`.
    ///
    /// Means and stds come from `Dataset::feature_stats`; mins and maxs are scanned here.
    /// The fitted width is the number of means. Values past that width in longer rows are
    /// ignored instead of causing an out-of-bounds panic. An empty dataset yields empty vectors.
    pub fn from_dataset(dataset: &Dataset) -> Self {
        let (means, stds) = dataset.feature_stats();
        let nf = means.len();
        let mut mins = vec![f64::INFINITY; nf];
        let mut maxs = vec![f64::NEG_INFINITY; nf];
        for row in &dataset.rows {
            // `zip` stops at the fitted width, so a longer row cannot index past `mins`/`maxs`.
            for ((lo, hi), &v) in mins.iter_mut().zip(maxs.iter_mut()).zip(&row.features) {
                if v < *lo {
                    *lo = v;
                }
                if v > *hi {
                    *hi = v;
                }
            }
        }
        NormalizationParams {
            means,
            stds,
            mins,
            maxs,
        }
    }

    /// Z-score normalization: `(v - mean) / std` per feature.
    ///
    /// Features beyond the fitted width use mean 0 and std 1 (passed through unchanged). A
    /// std with magnitude below 1e-15 maps the feature to `0.0`.
    pub fn normalize_zscore(&self, features: &[f64]) -> Vec<f64> {
        features
            .iter()
            .enumerate()
            .map(|(k, &v)| {
                let s = self.stds.get(k).copied().unwrap_or(1.0);
                let m = self.means.get(k).copied().unwrap_or(0.0);
                if s.abs() < 1e-15 {
                    0.0
                } else {
                    (v - m) / s
                }
            })
            .collect()
    }

    /// Min-max scaling: `(v - min) / (max - min)` per feature.
    ///
    /// Features beyond the fitted width use min 0 and max 1 (passed through unchanged). A
    /// range with magnitude below 1e-15 maps the feature to `0.0`.
    pub fn normalize_minmax(&self, features: &[f64]) -> Vec<f64> {
        features
            .iter()
            .enumerate()
            .map(|(k, &v)| {
                let mn = self.mins.get(k).copied().unwrap_or(0.0);
                let mx = self.maxs.get(k).copied().unwrap_or(1.0);
                let range = mx - mn;
                if range.abs() < 1e-15 {
                    0.0
                } else {
                    (v - mn) / range
                }
            })
            .collect()
    }

    /// Serializes means, stds, mins and maxs, in that order.
    ///
    /// Each vector is written as `len: u64` followed by its values' IEEE-754 bits as `u64`
    /// (little-endian).
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        let write_vec = |buf: &mut Vec<u8>, v: &[f64]| {
            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
            for &x in v {
                buf.extend_from_slice(&x.to_bits().to_le_bytes());
            }
        };
        write_vec(&mut buf, &self.means);
        write_vec(&mut buf, &self.stds);
        write_vec(&mut buf, &self.mins);
        write_vec(&mut buf, &self.maxs);
        buf
    }
}
/// Bidirectional mapping between class-name strings and dense indices `0..num_classes`.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct LabelEncoder {
    pub classes: Vec<String>,
    index: HashMap<String, usize>,
}

impl LabelEncoder {
    /// Creates an encoder with no classes.
    pub fn new() -> Self {
        Self {
            classes: Vec::new(),
            index: HashMap::new(),
        }
    }

    /// Builds an encoder from `class_names`.
    ///
    /// Names are sorted and deduplicated first, so indices follow lexicographic order.
    pub fn fit(class_names: Vec<String>) -> Self {
        let mut classes = class_names;
        classes.sort();
        classes.dedup();
        let index: HashMap<String, usize> = classes.iter().cloned().zip(0..).collect();
        Self { classes, index }
    }

    /// Index of `name`, or `None` if it was not fitted.
    pub fn encode(&self, name: &str) -> Option<usize> {
        self.index.get(name).cloned()
    }

    /// Class name at `idx`, or `None` if out of range.
    pub fn decode(&self, idx: usize) -> Option<&str> {
        self.classes.get(idx).map(String::as_str)
    }

    /// Number of distinct classes.
    pub fn num_classes(&self) -> usize {
        self.classes.len()
    }

    /// One-hot vector of length `num_classes`; all zeros if `idx` is out of range.
    pub fn one_hot(&self, idx: usize) -> Vec<f64> {
        (0..self.num_classes())
            .map(|i| if i == idx { 1.0 } else { 0.0 })
            .collect()
    }
}
/// Square confusion matrix for `num_classes` classes.
///
/// Counts are stored row-major: row = true label, column = predicted label.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub num_classes: usize,
    pub counts: Vec<u64>,
}

impl ConfusionMatrix {
    /// Creates an all-zero matrix.
    pub fn new(num_classes: usize) -> Self {
        let cells = num_classes * num_classes;
        Self {
            num_classes,
            counts: vec![0; cells],
        }
    }

    /// Flat index of cell (`row`, `col`).
    fn cell(&self, row: usize, col: usize) -> usize {
        row * self.num_classes + col
    }

    /// Records one prediction; pairs with an out-of-range label are silently ignored.
    pub fn record(&mut self, true_label: usize, predicted: usize) {
        let k = self.num_classes;
        if true_label >= k || predicted >= k {
            return;
        }
        let idx = self.cell(true_label, predicted);
        self.counts[idx] += 1;
    }

    /// Fraction of recorded samples on the diagonal (0.0 when nothing has been recorded).
    pub fn accuracy(&self) -> f64 {
        let total: u64 = self.counts.iter().sum();
        if total == 0 {
            return 0.0;
        }
        let diagonal: u64 = (0..self.num_classes)
            .map(|c| self.counts[self.cell(c, c)])
            .sum();
        diagonal as f64 / total as f64
    }

    /// TP / (TP + FP) for `class`; 0.0 if the class is out of range or never predicted.
    pub fn precision(&self, class: usize) -> f64 {
        if class >= self.num_classes {
            return 0.0;
        }
        let hits = self.counts[self.cell(class, class)] as f64;
        let false_pos: f64 = (0..self.num_classes)
            .filter(|&row| row != class)
            .map(|row| self.counts[self.cell(row, class)] as f64)
            .sum();
        let denom = hits + false_pos;
        if denom < 1e-15 {
            0.0
        } else {
            hits / denom
        }
    }

    /// TP / (TP + FN) for `class`; 0.0 if the class is out of range or never occurs.
    pub fn recall(&self, class: usize) -> f64 {
        if class >= self.num_classes {
            return 0.0;
        }
        let hits = self.counts[self.cell(class, class)] as f64;
        let false_neg: f64 = (0..self.num_classes)
            .filter(|&col| col != class)
            .map(|col| self.counts[self.cell(class, col)] as f64)
            .sum();
        let denom = hits + false_neg;
        if denom < 1e-15 {
            0.0
        } else {
            hits / denom
        }
    }

    /// Harmonic mean of precision and recall (0.0 when both are 0).
    pub fn f1(&self, class: usize) -> f64 {
        let (p, r) = (self.precision(class), self.recall(class));
        let denom = p + r;
        if denom < 1e-15 {
            0.0
        } else {
            2.0 * p * r / denom
        }
    }

    /// Renders the matrix as CSV.
    ///
    /// The header is `true\pred,class_0,...`; each following row starts with `class_i`.
    pub fn to_csv(&self) -> String {
        let k = self.num_classes;
        let header: String = (0..k).map(|j| format!(",class_{j}")).collect();
        let mut csv = format!("true\\pred{header}\n");
        for i in 0..k {
            let cells: String = self.counts[i * k..(i + 1) * k]
                .iter()
                .map(|c| format!(",{c}"))
                .collect();
            csv.push_str(&format!("class_{i}{cells}\n"));
        }
        csv
    }
}
/// Metrics logged at the end of one training epoch.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct EpochRecord {
    pub epoch: usize,
    pub train_loss: f64,
    pub val_loss: f64,
    pub train_acc: f64,
    pub val_acc: f64,
    pub learning_rate: f64,
}

/// Per-epoch metrics of a training run, in push order.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct TrainingHistory {
    pub records: Vec<EpochRecord>,
}

impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        TrainingHistory {
            records: Vec::new(),
        }
    }

    /// Appends a record.
    pub fn push(&mut self, record: EpochRecord) {
        self.records.push(record);
    }

    /// Number of stored records.
    pub fn num_epochs(&self) -> usize {
        self.records.len()
    }

    /// Position in `records` and value of the highest `val_acc`.
    ///
    /// The position is an index into `records`, not `EpochRecord::epoch`. NaN accuracies
    /// (e.g. from a diverged epoch) are skipped; ties resolve to the latest record. Returns
    /// `None` when no record has a non-NaN accuracy.
    pub fn best_val_acc(&self) -> Option<(usize, f64)> {
        self.records
            .iter()
            .map(|r| r.val_acc)
            .enumerate()
            // A NaN compares as "Equal" under partial_cmp and would derail max_by.
            .filter(|(_, v)| !v.is_nan())
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
    }

    /// Position in `records` and value of the lowest `val_loss`.
    ///
    /// NaN losses are skipped; ties resolve to the earliest record. Returns `None` when no
    /// record has a non-NaN loss.
    pub fn best_val_loss(&self) -> Option<(usize, f64)> {
        self.records
            .iter()
            .map(|r| r.val_loss)
            .enumerate()
            .filter(|(_, v)| !v.is_nan())
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
    }

    /// Renders the history as CSV.
    ///
    /// Losses and accuracies use 6 decimals; the learning rate uses 8.
    pub fn to_csv(&self) -> String {
        let mut s = String::from("epoch,train_loss,val_loss,train_acc,val_acc,lr\n");
        for r in &self.records {
            s.push_str(&format!(
                "{},{:.6},{:.6},{:.6},{:.6},{:.8}\n",
                r.epoch, r.train_loss, r.val_loss, r.train_acc, r.val_acc, r.learning_rate
            ));
        }
        s
    }
}
/// A single hyperparameter value.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum HpValue {
    Float(f64),
    Bool(bool),
    Str(String),
}

impl HpValue {
    /// The value if this is [`HpValue::Float`].
    pub fn as_float(&self) -> Option<f64> {
        match self {
            HpValue::Float(v) => Some(*v),
            HpValue::Bool(_) | HpValue::Str(_) => None,
        }
    }

    /// The value if this is [`HpValue::Bool`].
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            HpValue::Bool(v) => Some(*v),
            HpValue::Float(_) | HpValue::Str(_) => None,
        }
    }

    /// The value if this is [`HpValue::Str`].
    pub fn as_str(&self) -> Option<&str> {
        match self {
            HpValue::Str(s) => Some(s.as_str()),
            HpValue::Float(_) | HpValue::Bool(_) => None,
        }
    }
}

/// A name → value map of hyperparameters.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct HyperparamConfig {
    pub params: HashMap<String, HpValue>,
}

impl HyperparamConfig {
    /// Creates an empty configuration.
    pub fn new() -> Self {
        HyperparamConfig {
            params: HashMap::new(),
        }
    }

    /// Stores a float under `key`, replacing any previous value.
    pub fn set_float(&mut self, key: impl Into<String>, value: f64) {
        self.params.insert(key.into(), HpValue::Float(value));
    }

    /// Stores a bool under `key`, replacing any previous value.
    pub fn set_bool(&mut self, key: impl Into<String>, value: bool) {
        self.params.insert(key.into(), HpValue::Bool(value));
    }

    /// Stores a string under `key`, replacing any previous value.
    pub fn set_str(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.params.insert(key.into(), HpValue::Str(value.into()));
    }

    /// The float under `key`, or `None` if absent or of another type.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.params.get(key)?.as_float()
    }

    /// The bool under `key`, or `None` if absent or of another type.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.params.get(key)?.as_bool()
    }

    /// The string under `key`, or `None` if absent or of another type.
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.params.get(key)?.as_str()
    }

    /// Renders the configuration as a single-line JSON object with keys in sorted order.
    ///
    /// Strings and keys are escaped per RFC 8259. Non-finite floats (NaN, ±inf), which
    /// JSON cannot represent, are written as `null`.
    pub fn to_json(&self) -> String {
        // Escapes `s` for use between JSON double quotes.
        fn escape_json(s: &str) -> String {
            let mut out = String::with_capacity(s.len());
            for c in s.chars() {
                match c {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    // Remaining control characters must use the \uXXXX form.
                    c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                    c => out.push(c),
                }
            }
            out
        }

        let mut entries: Vec<(&String, &HpValue)> = self.params.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let parts: Vec<String> = entries
            .into_iter()
            .map(|(key, value)| {
                let rendered = match value {
                    HpValue::Float(f) if f.is_finite() => format!("{f}"),
                    HpValue::Float(_) => "null".to_string(),
                    HpValue::Bool(b) => b.to_string(),
                    HpValue::Str(s) => format!("\"{}\"", escape_json(s)),
                };
                format!("\"{}\":{}", escape_json(key), rendered)
            })
            .collect();
        format!("{{{}}}", parts.join(","))
    }
}
/// Summary information stored with a checkpoint.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CheckpointMeta {
    pub epoch: usize,
    pub val_loss: f64,
    pub val_acc: f64,
    pub train_time_secs: f64,
    pub architecture: String,
    pub framework_version: String,
}

impl CheckpointMeta {
    /// Renders the metadata as `key=value` lines, each terminated by `\n`.
    ///
    /// Loss and accuracy use 8 decimals and training time uses 3. String fields are written
    /// verbatim, so an embedded newline would break the line structure.
    pub fn to_text(&self) -> String {
        let CheckpointMeta {
            epoch,
            val_loss,
            val_acc,
            train_time_secs,
            architecture,
            framework_version,
        } = self;
        let lines = [
            format!("epoch={epoch}"),
            format!("val_loss={val_loss:.8}"),
            format!("val_acc={val_acc:.8}"),
            format!("train_time_secs={train_time_secs:.3}"),
            format!("architecture={architecture}"),
            format!("framework_version={framework_version}"),
        ];
        lines.iter().map(|line| format!("{line}\n")).collect()
    }
}
/// A serializable checkpoint bundling parameters, metadata and hyperparameters.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ModelCheckpoint {
    pub state: StateDict,
    pub meta: CheckpointMeta,
    pub hparams: HyperparamConfig,
}

impl ModelCheckpoint {
    /// Bundles the three parts into a checkpoint.
    pub fn new(state: StateDict, meta: CheckpointMeta, hparams: HyperparamConfig) -> Self {
        Self {
            state,
            meta,
            hparams,
        }
    }

    /// Serializes three sections in order, each prefixed by its byte length as a
    /// little-endian `u64`:
    /// 1. `StateDict::to_bytes`;
    /// 2. the UTF-8 of `CheckpointMeta::to_text`;
    /// 3. the UTF-8 of `HyperparamConfig::to_json`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let sections = [
            self.state.to_bytes(),
            self.meta.to_text().into_bytes(),
            self.hparams.to_json().into_bytes(),
        ];
        let mut out = Vec::new();
        for section in &sections {
            out.extend_from_slice(&(section.len() as u64).to_le_bytes());
            out.extend_from_slice(section);
        }
        out
    }

    /// Length of [`ModelCheckpoint::to_bytes`]; serializes the whole checkpoint to measure it.
    pub fn byte_size(&self) -> usize {
        let encoded = self.to_bytes();
        encoded.len()
    }
}
/// Numerically stable softmax: `exp(x_i - max) / Σ_j exp(x_j - max)`.
///
/// Edge cases:
/// - empty input gives an empty output;
/// - any NaN logit makes every output NaN;
/// - one or more `+inf` logits split the probability mass evenly among them;
/// - all logits `-inf` give a uniform distribution.
#[allow(dead_code)]
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return vec![];
    }
    if logits.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; logits.len()];
    }
    let max_v = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_v == f64::INFINITY {
        // `x - max` would be inf - inf = NaN; use the limiting distribution instead.
        let hits = logits.iter().filter(|&&x| x == f64::INFINITY).count() as f64;
        return logits
            .iter()
            .map(|&x| if x == f64::INFINITY { 1.0 / hits } else { 0.0 })
            .collect();
    }
    if max_v == f64::NEG_INFINITY {
        // No NaNs, so every logit is -inf: all classes are equally likely.
        return vec![1.0 / logits.len() as f64; logits.len()];
    }
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max_v).exp()).collect();
    // The max element contributes exp(0) = 1, so `sum >= 1` and the division is safe.
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
/// Cross-entropy `-Σ t_i · ln(p_i)` over paired probabilities and targets.
///
/// Probabilities are floored at 1e-15 so a zero probability cannot produce `-inf`. Only the
/// overlapping prefix of the two slices is used.
#[allow(dead_code)]
pub fn cross_entropy_loss(probs: &[f64], targets: &[f64]) -> f64 {
    const FLOOR: f64 = 1e-15;
    let per_class = probs.iter().zip(targets.iter()).map(|(p, t)| {
        let clamped = p.max(FLOOR);
        -t * clamped.ln()
    });
    per_class.sum()
}
/// Index of the largest value, ignoring NaNs.
///
/// Ties resolve to the last occurrence. Returns 0 for an empty slice or when every value is
/// NaN.
#[allow(dead_code)]
pub fn argmax(values: &[f64]) -> usize {
    values
        .iter()
        .enumerate()
        // NaN has no ordering; treating it as "Equal" made max_by drift to arbitrary elements
        // (e.g. [0.9, NaN, 0.1] picked index 2), so exclude NaNs up front.
        .filter(|(_, v)| !v.is_nan())
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
/// Mean squared error over the overlapping prefix of `predictions` and `targets`.
///
/// Returns `0.0` when there is no overlap (either slice empty), instead of computing 0/0 = NaN.
#[allow(dead_code)]
pub fn mse(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            let d = p - t;
            d * d
        })
        .sum();
    total / n as f64
}
/// Mean absolute error over the overlapping prefix of `predictions` and `targets`.
///
/// Returns `0.0` when there is no overlap (either slice empty), instead of computing 0/0 = NaN.
#[allow(dead_code)]
pub fn mae(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| (p - t).abs())
        .sum();
    total / n as f64
}
// Unit tests for every public type and free function in this file.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_tensor_new_shape_mismatch_panics() {
        // Shape [2, 3] needs 6 elements; 5 must trip the assertion in Tensor::new.
        let result = std::panic::catch_unwind(|| Tensor::new(vec![2, 3], vec![0.0; 5]));
        assert!(result.is_err());
    }
    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(vec![3, 4]);
        assert_eq!(t.numel(), 12);
        assert!(t.data.iter().all(|&v| v == 0.0));
    }
    #[test]
    fn test_tensor_numel() {
        let t = Tensor::new(vec![2, 3], vec![1.0; 6]);
        assert_eq!(t.numel(), 6);
        assert_eq!(t.ndim(), 2);
    }
    #[test]
    fn test_tensor_sum_mean() {
        let t = Tensor::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!((t.sum() - 10.0).abs() < 1e-12);
        assert!((t.mean() - 2.5).abs() < 1e-12);
    }
    #[test]
    fn test_tensor_scale() {
        let t = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        let t2 = t.scale(2.0);
        assert!((t2.data[1] - 4.0).abs() < 1e-12);
    }
    #[test]
    fn test_tensor_add() {
        let a = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        let b = Tensor::new(vec![3], vec![4.0, 5.0, 6.0]);
        let c = a.add(&b).unwrap();
        assert!((c.data[2] - 9.0).abs() < 1e-12);
    }
    #[test]
    fn test_tensor_add_shape_mismatch() {
        let a = Tensor::new(vec![2], vec![1.0, 2.0]);
        let b = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        assert!(a.add(&b).is_none());
    }
    #[test]
    fn test_tensor_roundtrip_bytes() {
        // Mix of signs, zero, and extreme magnitudes to exercise the bit-exact f64 encoding.
        let t = Tensor::new(vec![2, 3], vec![1.0, -2.5, 0.0, 3.125, 1e10, -1e-5]);
        let bytes = t.to_bytes();
        let t2 = Tensor::from_bytes(&bytes).unwrap();
        assert_eq!(t2.shape, t.shape);
        for (a, b) in t.data.iter().zip(&t2.data) {
            assert!((a - b).abs() < 1e-15);
        }
    }
    #[test]
    fn test_tensor_from_bytes_empty_is_none() {
        assert!(Tensor::from_bytes(&[]).is_none());
    }
    #[test]
    fn test_dense_layer_param_count() {
        // 4 * 3 weights + 3 biases = 15.
        let layer = DenseLayer::new("fc1", 4, 3, "relu");
        assert_eq!(layer.param_count(), 15);
    }
    #[test]
    fn test_dense_layer_forward_zero_weights() {
        // Zero weights and biases with an identity ("linear") activation must output zeros.
        let layer = DenseLayer::new("fc", 3, 2, "linear");
        let input = vec![1.0, 2.0, 3.0];
        let out = layer.forward(&input);
        assert_eq!(out.len(), 2);
        for v in &out {
            assert!(v.abs() < 1e-12);
        }
    }
    #[test]
    fn test_dense_layer_activation_relu() {
        assert!((apply_activation(-5.0, "relu")).abs() < 1e-12);
        assert!((apply_activation(3.0, "relu") - 3.0).abs() < 1e-12);
    }
    #[test]
    fn test_dense_layer_activation_sigmoid() {
        let v = apply_activation(0.0, "sigmoid");
        assert!((v - 0.5).abs() < 1e-12);
    }
    #[test]
    fn test_dense_layer_activation_tanh() {
        let v = apply_activation(0.0, "tanh");
        assert!(v.abs() < 1e-12);
    }
    #[test]
    fn test_model_weights_add_and_get() {
        let mut model = ModelWeights::new();
        model.add_layer(DenseLayer::new("l1", 4, 8, "relu"));
        model.add_layer(DenseLayer::new("l2", 8, 2, "sigmoid"));
        assert_eq!(model.layers.len(), 2);
        assert!(model.get_layer("l1").is_some());
        assert!(model.get_layer("l3").is_none());
    }
    #[test]
    fn test_model_weights_total_params() {
        let mut model = ModelWeights::new();
        // (4 * 3 + 3) + (3 * 2 + 2) = 15 + 8 = 23.
        model.add_layer(DenseLayer::new("l1", 4, 3, "relu")); model.add_layer(DenseLayer::new("l2", 3, 2, "linear")); assert_eq!(model.total_params(), 23);
    }
    #[test]
    fn test_model_weights_to_bytes_nonempty() {
        let mut model = ModelWeights::new();
        model.add_layer(DenseLayer::new("l1", 2, 2, "relu"));
        let bytes = model.to_bytes();
        assert!(!bytes.is_empty());
    }
    #[test]
    fn test_state_dict_insert_and_get() {
        let mut sd = StateDict::new();
        sd.insert("w1", Tensor::zeros(vec![4, 4]));
        assert_eq!(sd.len(), 1);
        assert_eq!(sd.get("w1").unwrap().numel(), 16);
    }
    #[test]
    fn test_state_dict_total_params() {
        // 3 * 3 + 3 = 12 elements across both tensors.
        let mut sd = StateDict::new();
        sd.insert("a", Tensor::zeros(vec![3, 3]));
        sd.insert("b", Tensor::zeros(vec![3]));
        assert_eq!(sd.total_params(), 12);
    }
    #[test]
    fn test_state_dict_roundtrip() {
        let mut sd = StateDict::new();
        sd.insert("w", Tensor::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]));
        sd.insert("b", Tensor::new(vec![2], vec![0.5, -0.5]));
        let bytes = sd.to_bytes();
        let sd2 = StateDict::from_bytes(&bytes).unwrap();
        assert_eq!(sd2.len(), 2);
        let w = sd2.get("w").unwrap();
        assert!((w.data[3] - 4.0).abs() < 1e-12);
    }
    #[test]
    fn test_onnx_graph_node_count() {
        let mut g = OnnxLikeGraph::new("test_model");
        g.add_node(OnnxNode::new(
            "n0",
            "MatMul",
            vec!["x".into(), "w0".into()],
            vec!["h0".into()],
        ));
        g.add_node(OnnxNode::new(
            "n1",
            "Relu",
            vec!["h0".into()],
            vec!["h1".into()],
        ));
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.count_op("Relu"), 1);
    }
    #[test]
    fn test_onnx_graph_topological_valid() {
        // "x" is a graph input, "w0" an initializer, and "y" is produced before "act" reads it.
        let mut g = OnnxLikeGraph::new("model");
        g.inputs.push("x".into());
        g.add_initializer("w0", Tensor::zeros(vec![4, 4]));
        g.add_node(OnnxNode::new(
            "mm",
            "MatMul",
            vec!["x".into(), "w0".into()],
            vec!["y".into()],
        ));
        g.add_node(OnnxNode::new(
            "act",
            "Relu",
            vec!["y".into()],
            vec!["z".into()],
        ));
        assert!(g.is_topologically_valid());
    }
    #[test]
    fn test_onnx_graph_topological_invalid() {
        // "undefined" is neither a graph input, an initializer, nor an earlier node output.
        let mut g = OnnxLikeGraph::new("model");
        g.inputs.push("x".into());
        g.add_node(OnnxNode::new(
            "act",
            "Relu",
            vec!["undefined".into()],
            vec!["z".into()],
        ));
        assert!(!g.is_topologically_valid());
    }
    #[test]
    fn test_dataset_len_and_features() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![1.0, 2.0], 0));
        ds.push(DataRow::labelled(vec![3.0, 4.0], 1));
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.num_features(), 2);
    }
    #[test]
    fn test_dataset_shuffle_changes_order() {
        // Relies on seed 42 not producing the identity permutation of 20 rows.
        let mut ds = Dataset::new();
        for i in 0..20 {
            ds.push(DataRow::labelled(vec![i as f64], 0));
        }
        let original: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
        ds.shuffle(42);
        let shuffled: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
        assert_ne!(original, shuffled);
    }
    #[test]
    fn test_dataset_train_val_split() {
        // 20% of 100 rows go to validation.
        let mut ds = Dataset::new();
        for i in 0..100 {
            ds.push(DataRow::labelled(vec![i as f64], 0));
        }
        let (train, val) = ds.train_val_split(0.2);
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }
    #[test]
    fn test_dataset_feature_stats() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0, 10.0], 0));
        ds.push(DataRow::labelled(vec![2.0, 10.0], 1));
        let (means, _stds) = ds.feature_stats();
        assert!((means[0] - 1.0).abs() < 1e-12);
        assert!((means[1] - 10.0).abs() < 1e-12);
    }
    #[test]
    fn test_normalization_zscore() {
        // The fitted mean is 1.0, so normalizing 1.0 must give a z-score of 0.
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0], 0));
        ds.push(DataRow::labelled(vec![2.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        let z = norm.normalize_zscore(&[1.0]);
        assert!(z[0].abs() < 1e-10);
    }
    #[test]
    fn test_normalization_minmax() {
        // Range [0, 10]: 5.0 sits exactly halfway.
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0], 0));
        ds.push(DataRow::labelled(vec![10.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        let v = norm.normalize_minmax(&[5.0]);
        assert!((v[0] - 0.5).abs() < 1e-12);
    }
    #[test]
    fn test_normalization_bytes_nonempty() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![1.0, 2.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        assert!(!norm.to_bytes().is_empty());
    }
    #[test]
    fn test_label_encoder_fit_and_encode() {
        let enc = LabelEncoder::fit(vec!["cat".into(), "dog".into(), "bird".into()]);
        assert_eq!(enc.num_classes(), 3);
        let i = enc.encode("dog").unwrap();
        assert_eq!(enc.decode(i), Some("dog"));
    }
    #[test]
    fn test_label_encoder_one_hot() {
        let enc = LabelEncoder::fit(vec!["a".into(), "b".into(), "c".into()]);
        let oh = enc.one_hot(enc.encode("b").unwrap());
        assert_eq!(oh.iter().filter(|&&v| v == 1.0).count(), 1);
        assert!((oh.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    }
    #[test]
    fn test_label_encoder_unknown_returns_none() {
        let enc = LabelEncoder::fit(vec!["a".into()]);
        assert!(enc.encode("z").is_none());
    }
    #[test]
    fn test_confusion_matrix_accuracy() {
        // Three of the four recorded predictions are correct (the last one is a miss).
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0);
        cm.record(0, 0);
        cm.record(1, 1);
        cm.record(1, 0); assert!((cm.accuracy() - 0.75).abs() < 1e-12);
    }
    #[test]
    fn test_confusion_matrix_precision_recall() {
        // Every cell is hit once, so class 0 has TP = FP = FN = 1: precision = recall = 0.5.
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0); cm.record(0, 1); cm.record(1, 0); cm.record(1, 1); let p = cm.precision(0);
        let r = cm.recall(0);
        assert!((p - 0.5).abs() < 1e-12);
        assert!((r - 0.5).abs() < 1e-12);
    }
    #[test]
    fn test_confusion_matrix_to_csv() {
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0);
        cm.record(1, 1);
        let csv = cm.to_csv();
        assert!(csv.contains("class_0"));
        assert!(csv.contains("class_1"));
    }
    #[test]
    fn test_training_history_best_val_acc() {
        // val_acc grows as 0.18 * e, so the best is the last record (index 4, 0.72).
        let mut hist = TrainingHistory::new();
        for e in 0..5 {
            hist.push(EpochRecord {
                epoch: e,
                train_loss: 1.0 - e as f64 * 0.1,
                val_loss: 1.0 - e as f64 * 0.08,
                train_acc: e as f64 * 0.2,
                val_acc: e as f64 * 0.18,
                learning_rate: 0.001,
            });
        }
        let (best_epoch, best_acc) = hist.best_val_acc().unwrap();
        assert_eq!(best_epoch, 4);
        assert!((best_acc - 0.72).abs() < 1e-10);
    }
    #[test]
    fn test_training_history_to_csv() {
        let mut hist = TrainingHistory::new();
        hist.push(EpochRecord {
            epoch: 0,
            train_loss: 0.9,
            val_loss: 0.85,
            train_acc: 0.6,
            val_acc: 0.62,
            learning_rate: 0.01,
        });
        let csv = hist.to_csv();
        assert!(csv.starts_with("epoch,"));
        assert!(csv.contains("0,"));
    }
    #[test]
    fn test_hyperparam_config_get_set() {
        let mut cfg = HyperparamConfig::new();
        cfg.set_float("lr", 0.001);
        cfg.set_bool("dropout", true);
        cfg.set_str("optimizer", "adam");
        assert!((cfg.get_float("lr").unwrap() - 0.001).abs() < 1e-15);
        assert!(cfg.get_bool("dropout").unwrap());
        assert_eq!(cfg.get_str("optimizer").unwrap(), "adam");
    }
    #[test]
    fn test_hyperparam_config_to_json() {
        let mut cfg = HyperparamConfig::new();
        cfg.set_float("lr", 0.01);
        let json = cfg.to_json();
        assert!(json.contains("lr"));
        assert!(json.starts_with('{'));
        assert!(json.ends_with('}'));
    }
    #[test]
    fn test_checkpoint_byte_size_nonzero() {
        // Even with an empty state dict, the length prefixes and metadata text take bytes.
        let state = StateDict::new();
        let meta = CheckpointMeta {
            epoch: 10,
            val_loss: 0.1,
            val_acc: 0.95,
            train_time_secs: 3600.0,
            architecture: "MLP".into(),
            framework_version: "0.1.0".into(),
        };
        let hparams = HyperparamConfig::new();
        let ck = ModelCheckpoint::new(state, meta, hparams);
        assert!(ck.byte_size() > 0);
    }
    #[test]
    fn test_checkpoint_meta_to_text_contains_epoch() {
        let meta = CheckpointMeta {
            epoch: 42,
            val_loss: 0.05,
            val_acc: 0.98,
            train_time_secs: 100.0,
            architecture: "CNN".into(),
            framework_version: "0.1.0".into(),
        };
        let text = meta.to_text();
        assert!(text.contains("epoch=42"));
    }
    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax(&logits);
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12);
    }
    #[test]
    fn test_softmax_max_has_highest_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let probs = softmax(&logits);
        assert!(probs[1] > probs[0] && probs[1] > probs[2]);
    }
    #[test]
    fn test_cross_entropy_perfect_prediction() {
        // Zero probabilities are floored at 1e-15 but multiplied by zero targets, so the loss is ~0.
        let probs = vec![0.0, 1.0, 0.0];
        let targets = vec![0.0, 1.0, 0.0];
        let loss = cross_entropy_loss(&probs, &targets);
        assert!(loss < 1e-10);
    }
    #[test]
    fn test_argmax_basic() {
        let v = vec![0.1, 0.7, 0.2];
        assert_eq!(argmax(&v), 1);
    }
    #[test]
    fn test_mse_zero() {
        let p = vec![1.0, 2.0, 3.0];
        let t = vec![1.0, 2.0, 3.0];
        assert!(mse(&p, &t).abs() < 1e-12);
    }
    #[test]
    fn test_mse_known() {
        let p = vec![0.0, 0.0];
        let t = vec![1.0, 1.0];
        assert!((mse(&p, &t) - 1.0).abs() < 1e-12);
    }
    #[test]
    fn test_mae_basic() {
        // Absolute errors are 1, 0 and 1, so the mean is 2/3.
        let p = vec![0.0, 1.0, 2.0];
        let t = vec![1.0, 1.0, 3.0];
        let m = mae(&p, &t);
        assert!((m - 2.0 / 3.0).abs() < 1e-12);
    }
    #[test]
    fn test_apply_activation_leaky_relu() {
        assert!((apply_activation(-1.0, "leaky_relu") - (-0.01)).abs() < 1e-12);
        assert!((apply_activation(2.0, "leaky_relu") - 2.0).abs() < 1e-12);
    }
    #[test]
    fn test_apply_activation_elu() {
        // ELU(-1) = e^-1 - 1 ≈ -0.632, strictly between -1 and 0.
        let v = apply_activation(-1.0, "elu");
        assert!(v < 0.0 && v > -1.0);
    }
    #[test]
    fn test_lcg_rng_produces_different_values() {
        let mut rng = LcgRng::new(1234);
        let a = rng.next_u64();
        let b = rng.next_u64();
        assert_ne!(a, b);
    }
}