use crate::domain::DomainId;
use crate::object::{Shape, Tensor};
/// A weight tensor stored as raw `u16` (fp16) bit patterns, paired with two
/// per-element `f32` buffers (`m`, `v`) and a `step` counter.
///
/// NOTE(review): `m` / `v` / `step` look like Adam-style optimizer state
/// (first/second moment estimates and update count) — confirm against the
/// optimizer that consumes this type.
#[derive(Debug, Clone)]
pub struct Parameter {
/// Weight storage as raw fp16 bit patterns. `from_fp16_bits` / `from_f32`
/// tag the tensor with the `"fp16"` domain; `new` accepts any `Tensor<u16>`.
pub weight: Tensor<u16>,
/// Per-element `f32` buffer; `Parameter::new` requires
/// `m.len() == weight.data.len()`. The field is public, so direct mutation
/// can break that invariant.
pub m: Vec<f32>,
/// Per-element `f32` buffer; `Parameter::new` requires
/// `v.len() == weight.data.len()`. The field is public, so direct mutation
/// can break that invariant.
pub v: Vec<f32>,
/// Counter initialised to 0 by `Parameter::new`; nothing in this type's impl
/// advances it (presumably incremented by an optimizer — TODO confirm).
pub step: u32,
}
impl Parameter {
pub fn new(weight: Tensor<u16>, m: Vec<f32>, v: Vec<f32>) -> Self {
let n = weight.data.len();
assert_eq!(m.len(), n, "Parameter m length mismatch");
assert_eq!(v.len(), n, "Parameter v length mismatch");
Self {
weight,
m,
v,
step: 0,
}
}
pub fn from_fp16_bits(weight_bits: Vec<u16>, shape: Shape) -> Self {
let n = weight_bits.len();
let tensor = Tensor::dense_cpu(DomainId::new("fp16"), shape, weight_bits);
Self::new(tensor, vec![0.0f32; n], vec![0.0f32; n])
}
pub fn from_f32(values: &[f32], shape: Shape) -> Self {
let bits: Vec<u16> = values.iter().map(|&v| f32_to_fp16_bits(v)).collect();
Self::from_fp16_bits(bits, shape)
}
pub fn len(&self) -> usize {
self.weight.data.len()
}
pub fn is_empty(&self) -> bool {
self.weight.data.is_empty()
}
pub fn weights_f32(&self) -> Vec<f32> {
self.weight
.data
.iter()
.map(|&b| fp16_bits_to_f32(b))
.collect()
}
}
pub fn f32_to_fp16_bits(value: f32) -> u16 {
crate::backend::f16_convert::f32_to_f16(value)
}
pub fn fp16_bits_to_f32(bits: u16) -> f32 {
crate::backend::f16_convert::f16_to_f32(bits)
}