use std::fmt;
/// Number of entries in a codebook.
///
/// Thin newtype over `usize`; no validation is performed on construction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CodebookSize(pub usize);

impl CodebookSize {
    /// Wraps `size` as-is.
    pub fn new(size: usize) -> Self {
        CodebookSize(size)
    }

    /// Returns the wrapped entry count.
    pub fn get(self) -> usize {
        let CodebookSize(count) = self;
        count
    }
}

impl fmt::Display for CodebookSize {
    /// Renders as `"<n> codes"`, e.g. `"256 codes"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let CodebookSize(count) = *self;
        write!(f, "{} codes", count)
    }
}
/// Dimensionality of an embedding vector.
///
/// Thin newtype over `usize`; any value is accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EmbedDim(pub usize);

impl EmbedDim {
    /// Wraps `dim` without checking it.
    pub fn new(dim: usize) -> Self {
        EmbedDim(dim)
    }

    /// Returns the wrapped dimension.
    pub fn get(self) -> usize {
        match self {
            EmbedDim(d) => d,
        }
    }
}

impl fmt::Display for EmbedDim {
    /// Renders as `"<n>D"`, e.g. `"128D"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}D", self.get()))
    }
}
/// Length of a signal, counted in samples.
///
/// Thin newtype over `usize`; no validation is performed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SignalLength(pub usize);

impl SignalLength {
    /// Wraps a sample count.
    pub fn new(len: usize) -> Self {
        SignalLength(len)
    }

    /// Returns the wrapped sample count.
    pub fn get(self) -> usize {
        let SignalLength(samples) = self;
        samples
    }
}

impl fmt::Display for SignalLength {
    /// Renders as `"<n> samples"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let samples = self.get();
        write!(f, "{} samples", samples)
    }
}
/// Number of items processed together in one batch.
///
/// Thin newtype over `usize`; zero is not rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BatchSize(pub usize);

impl BatchSize {
    /// Wraps `size` unchanged.
    pub fn new(size: usize) -> Self {
        BatchSize(size)
    }

    /// Returns the wrapped batch size.
    pub fn get(self) -> usize {
        let BatchSize(items) = self;
        items
    }
}

impl fmt::Display for BatchSize {
    /// Renders as `"batch(<n>)"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("batch({})", self.get()))
    }
}
/// Position of an entry within a codebook.
///
/// Construction is unchecked; use [`CodebookIndex::is_valid_for`] to test the
/// index against a particular [`CodebookSize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CodebookIndex(pub usize);

impl CodebookIndex {
    /// Wraps `idx` without bounds checking.
    pub fn new(idx: usize) -> Self {
        CodebookIndex(idx)
    }

    /// Returns the raw index.
    pub fn get(self) -> usize {
        let CodebookIndex(position) = self;
        position
    }

    /// Returns `true` when this index lies in `0..codebook_size`.
    ///
    /// The size itself is out of range (indices are zero-based).
    pub fn is_valid_for(&self, codebook_size: CodebookSize) -> bool {
        let limit = codebook_size.get();
        limit > self.0
    }
}

impl fmt::Display for CodebookIndex {
    /// Renders as `"code[<i>]"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let CodebookIndex(position) = *self;
        write!(f, "code[{}]", position)
    }
}
/// Quantization bit depth, restricted to the range 1..=16 bits.
///
/// The field is private, and every constructor in this impl (`new`,
/// `bits_8`, `bits_16`, `TryFrom<u8>`) only produces values in that range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BitDepth(u8);

impl BitDepth {
    /// Smallest accepted bit depth.
    pub const MIN_BITS: u8 = 1;
    /// Largest accepted bit depth.
    pub const MAX_BITS: u8 = 16;

    /// Creates a bit depth after checking that it is in range.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a descriptive message when `bits` is 0 or greater than 16.
    pub fn new(bits: u8) -> Result<Self, String> {
        if !(Self::MIN_BITS..=Self::MAX_BITS).contains(&bits) {
            return Err(format!("Bit depth must be 1-16, got {}", bits));
        }
        Ok(Self(bits))
    }

    /// 8-bit depth (256 levels).
    pub fn bits_8() -> Self {
        Self(8)
    }

    /// 16-bit depth (65536 levels).
    pub fn bits_16() -> Self {
        Self(16)
    }

    /// Returns the number of bits.
    pub fn get(self) -> u8 {
        self.0
    }

    /// Returns the number of distinct quantization levels, `2^bits`.
    ///
    /// The result is 2 for 1 bit and 65536 for 16 bits.
    pub fn num_levels(self) -> usize {
        1usize << self.0
    }
}

/// Fallible conversion equivalent to [`BitDepth::new`].
impl std::convert::TryFrom<u8> for BitDepth {
    type Error = String;

    fn try_from(bits: u8) -> Result<Self, Self::Error> {
        Self::new(bits)
    }
}

impl fmt::Display for BitDepth {
    /// Renders as `"<n>-bit"`, e.g. `"8-bit"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}-bit", self.0)
    }
}
/// Optimizer learning rate; `new` accepts only strictly positive, finite values.
///
/// Note: the field is `pub`, so `LearningRate(x)` can bypass validation.
/// Prefer [`LearningRate::new`] or `TryFrom<f32>`.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct LearningRate(pub f32);

impl LearningRate {
    /// Creates a learning rate after validating it.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `rate` is zero, negative (including `-0.0`),
    /// infinite, or NaN.
    pub fn new(rate: f32) -> Result<Self, String> {
        // `rate <= 0.0` is false for NaN, so NaN is caught by `is_finite`.
        if rate <= 0.0 || !rate.is_finite() {
            return Err(format!(
                "Learning rate must be positive and finite, got {}",
                rate
            ));
        }
        Ok(Self(rate))
    }

    /// The default learning rate, `0.001`.
    pub fn default_rate() -> Self {
        Self(0.001)
    }

    /// Returns the raw rate.
    pub fn get(self) -> f32 {
        self.0
    }
}

impl Default for LearningRate {
    /// Same as [`LearningRate::default_rate`].
    fn default() -> Self {
        Self::default_rate()
    }
}

/// Fallible conversion equivalent to [`LearningRate::new`].
impl std::convert::TryFrom<f32> for LearningRate {
    type Error = String;

    fn try_from(rate: f32) -> Result<Self, Self::Error> {
        Self::new(rate)
    }
}

impl fmt::Display for LearningRate {
    /// Renders as `"lr=<rate>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "lr={}", self.0)
    }
}
/// Number of full passes over the training data.
///
/// Thin newtype over `usize`; any value is accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Epochs(pub usize);

impl Epochs {
    /// Wraps an epoch count.
    pub fn new(epochs: usize) -> Self {
        Epochs(epochs)
    }

    /// Returns the wrapped epoch count.
    pub fn get(self) -> usize {
        let Epochs(passes) = self;
        passes
    }
}

impl fmt::Display for Epochs {
    /// Renders as `"<n> epochs"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let passes = self.get();
        f.write_fmt(format_args!("{} epochs", passes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_codebook_size() {
        let size = CodebookSize::new(256);
        assert_eq!(size.get(), 256);
        assert_eq!(format!("{}", size), "256 codes");
    }

    #[test]
    fn test_embed_dim() {
        let dim = EmbedDim::new(128);
        assert_eq!(dim.get(), 128);
        assert_eq!(format!("{}", dim), "128D");
    }

    #[test]
    fn test_signal_length() {
        let len = SignalLength::new(16_000);
        assert_eq!(len.get(), 16_000);
        assert_eq!(format!("{}", len), "16000 samples");
    }

    #[test]
    fn test_batch_size() {
        let batch = BatchSize::new(32);
        assert_eq!(batch.get(), 32);
        assert_eq!(format!("{}", batch), "batch(32)");
    }

    #[test]
    fn test_codebook_index_validation() {
        let size = CodebookSize::new(256);
        assert!(CodebookIndex::new(10).is_valid_for(size));
        assert!(CodebookIndex::new(0).is_valid_for(size));
        // Indices are zero-based: the last valid index is size - 1.
        assert!(CodebookIndex::new(255).is_valid_for(size));
        assert!(!CodebookIndex::new(256).is_valid_for(size));
        assert!(!CodebookIndex::new(300).is_valid_for(size));
        // An empty codebook admits no index at all.
        assert!(!CodebookIndex::new(0).is_valid_for(CodebookSize::new(0)));
        assert_eq!(format!("{}", CodebookIndex::new(7)), "code[7]");
    }

    #[test]
    fn test_bit_depth() {
        let bd8 = BitDepth::new(8).unwrap();
        assert_eq!(bd8.get(), 8);
        assert_eq!(bd8.num_levels(), 256);
        assert_eq!(bd8, BitDepth::bits_8());
        assert_eq!(format!("{}", bd8), "8-bit");
        let bd16 = BitDepth::new(16).unwrap();
        assert_eq!(bd16.num_levels(), 65536);
        assert_eq!(bd16, BitDepth::bits_16());
        assert_eq!(BitDepth::new(1).unwrap().num_levels(), 2);
        assert!(BitDepth::new(0).is_err());
        assert!(BitDepth::new(17).is_err());
    }

    #[test]
    fn test_learning_rate() {
        let lr = LearningRate::new(0.01).unwrap();
        assert_eq!(lr.get(), 0.01);
        assert_eq!(format!("{}", lr), "lr=0.01");
        assert_eq!(LearningRate::default_rate().get(), 0.001);
        assert!(LearningRate::new(0.0).is_err());
        assert!(LearningRate::new(-0.0).is_err());
        assert!(LearningRate::new(-0.1).is_err());
        assert!(LearningRate::new(f32::NAN).is_err());
        assert!(LearningRate::new(f32::INFINITY).is_err());
        assert!(LearningRate::new(f32::NEG_INFINITY).is_err());
    }

    #[test]
    fn test_epochs() {
        let epochs = Epochs::new(100);
        assert_eq!(epochs.get(), 100);
        assert_eq!(format!("{}", epochs), "100 epochs");
    }
}