use std::{
fmt::Debug,
ops::Deref,
};
use burn::module::{
Param,
ParamId,
Parameter,
};
/// Pairs a parameter's [`ParamId`] with a descriptor value of type `T`.
///
/// The descriptor payload is reachable through the `Deref` and `AsRef<T>`
/// impls, and the id through [`ParamDesc::param_id`].
#[derive(Debug, Clone, PartialEq)]
pub struct ParamDesc<T: Debug + Clone + Send + PartialEq> {
    /// Identifier of the parameter this descriptor was built for.
    param_id: ParamId,
    /// The wrapped descriptor payload.
    data: T,
}
impl<T> Deref for ParamDesc<T>
where
T: Debug + Clone + Send + PartialEq,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl<T> AsRef<T> for ParamDesc<T>
where
T: Debug + Clone + Send + PartialEq,
{
fn as_ref(&self) -> &T {
&self.data
}
}
impl<T: Debug + Clone + Send + PartialEq> ParamDesc<T> {
    /// Creates a descriptor for the parameter identified by `param_id`,
    /// storing `param` as its payload.
    pub fn new(param_id: ParamId, param: T) -> Self {
        Self { data: param, param_id }
    }

    /// Returns the identifier of the described parameter (by value).
    pub fn param_id(&self) -> ParamId {
        let Self { param_id, .. } = *self;
        param_id
    }
}
/// Builds a [`ParamDesc`] from a module [`Param`].
///
/// The resulting descriptor carries the parameter's `id`. Its payload is
/// produced by converting a borrow of the parameter's value (obtained via
/// `Param::val()`) into `D` through `D: From<&T>`.
impl<T, D> From<&Param<T>> for ParamDesc<D>
where
    T: Parameter,
    D: for<'a> From<&'a T> + Debug + Clone + Send + PartialEq + 'static,
{
    fn from(param: &Param<T>) -> Self {
        // `val()` yields a `T`; `D::from` only needs to borrow that temporary.
        Self::new(param.id, D::from(&param.val()))
    }
}
#[cfg(test)]
// NOTE(review): every import below appears to be used by this module; this
// allow may be removable — confirm before dropping it.
#[allow(unused_imports)]
mod tests {
    use burn::{
        nn::LinearConfig,
        prelude::Shape,
        tensor::DType,
    };

    use super::{
        super::{
            TensorDesc,
            TensorKindDesc,
        },
        *,
    };
    use crate::support::testing::SetupTestBackend;

    type B = SetupTestBackend;

    /// Converting a `Linear` weight `Param` into `ParamDesc<TensorDesc>` keeps
    /// the parameter id and describes the tensor's kind, dtype and shape.
    #[test]
    fn test_from_param() {
        let device = Default::default();
        let linear = LinearConfig::new(2, 3).init::<B>(&device);
        let param = linear.weight;

        let param_desc: ParamDesc<TensorDesc> = (&param).into();

        assert_eq!(param_desc.param_id(), param.id);
        assert_eq!(param_desc.kind(), TensorKindDesc::Float);
        assert_eq!(param_desc.dtype(), DType::F32);
        assert_eq!(param_desc.shape(), &Shape::new([2, 3]));
        assert_eq!(param_desc.rank(), 2);
        assert_eq!(param_desc.size_estimate(), DType::F32.size() * 2 * 3);
    }
}