use crate::{
    Dataset, HuggingfaceDatasetLoader, SqliteDataset,
    transform::{Mapper, MapperDataset},
};
use hound::{SampleFormat, WavReader};
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, FromRepr};
/// Internal storage type: SQLite-backed `SpeechItemRaw` rows wrapped in a `MapperDataset`
/// that converts each row with `ConvertSamples`.
type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;
/// Class labels of the Speech Commands (v0.02) dataset.
///
/// Each discriminant is the integer label stored in the raw dataset rows; raw labels are
/// converted back into variants with `FromRepr::from_repr`. `Other` (36) sits past the last
/// word class — NOTE(review): presumably reserved for grouping non-target words; confirm
/// whether any raw label ever maps to it.
///
/// `PartialEq`/`Eq`/`Hash` are derived so labels can be compared and used as map keys.
// Variant names are the spoken words themselves, so per-variant docs would add nothing.
#[allow(missing_docs)]
#[derive(
    Debug, Display, Clone, Copy, PartialEq, Eq, Hash, FromRepr, Serialize, Deserialize, EnumCount,
)]
pub enum SpeechCommandClass {
    Yes = 0,
    No = 1,
    Up = 2,
    Down = 3,
    Left = 4,
    Right = 5,
    On = 6,
    Off = 7,
    Stop = 8,
    Go = 9,
    Zero = 10,
    One = 11,
    Two = 12,
    Three = 13,
    Four = 14,
    Five = 15,
    Six = 16,
    Seven = 17,
    Eight = 18,
    Nine = 19,
    Bed = 20,
    Bird = 21,
    Cat = 22,
    Dog = 23,
    Happy = 24,
    House = 25,
    Marvin = 26,
    Sheila = 27,
    Tree = 28,
    Wow = 29,
    Backward = 30,
    Forward = 31,
    Follow = 32,
    Learn = 33,
    Visual = 34,
    Silence = 35,
    Other = 36,
}
/// One undecoded row of the Speech Commands dataset as stored in SQLite.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SpeechItemRaw {
    /// Bytes of a complete WAV file (header included); decoded with `hound::WavReader`.
    pub audio_bytes: Vec<u8>,
    /// Integer class label; corresponds to a `SpeechCommandClass` discriminant.
    pub label: usize,
    /// Flag carried over from the source dataset. It is not consulted when converting rows
    /// in this module — NOTE(review): presumably marks auxiliary (non-target) words; confirm.
    pub is_unknown: bool,
}
/// A decoded Speech Commands example, ready for feature extraction.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SpeechItem {
    /// Decoded audio samples converted to `f32`.
    pub audio_samples: Vec<f32>,
    /// Sample rate as reported by the WAV header.
    pub sample_rate: usize,
    /// Class of the spoken word.
    pub label: SpeechCommandClass,
}
/// Speech Commands (v0.02) dataset whose items are decoded into [`SpeechItem`]s on access.
pub struct SpeechCommandsDataset {
    // Raw SQLite rows paired with the `ConvertSamples` mapper.
    dataset: MappedDataset,
}
impl SpeechCommandsDataset {
    /// Loads the requested split of the `speech_commands` dataset (subset `v0.02`) through
    /// `HuggingfaceDatasetLoader` and wraps it so every item is decoded into a [`SpeechItem`].
    ///
    /// `split` is forwarded unchanged to the loader; [`Self::train`], [`Self::test`] and
    /// [`Self::validation`] cover the usual split names.
    ///
    /// # Panics
    ///
    /// Panics if the loader returns an error for this split. The panic message names the
    /// split and includes the loader error.
    pub fn new(split: &str) -> Self {
        let dataset: SqliteDataset<SpeechItemRaw> =
            HuggingfaceDatasetLoader::new("speech_commands")
                .with_subset("v0.02")
                .dataset(split)
                // A bare `unwrap` would hide which split failed; keep that context.
                .unwrap_or_else(|err| {
                    panic!("failed to load split `{split}` of speech_commands v0.02: {err:?}")
                });
        let dataset = MapperDataset::new(dataset, ConvertSamples);
        Self { dataset }
    }

    /// Loads the `"train"` split. See [`Self::new`] for panics.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Loads the `"test"` split. See [`Self::new`] for panics.
    pub fn test() -> Self {
        Self::new("test")
    }

    /// Loads the `"validation"` split. See [`Self::new`] for panics.
    pub fn validation() -> Self {
        Self::new("validation")
    }

    /// Number of variants in [`SpeechCommandClass`], including `Other`.
    pub fn num_classes() -> usize {
        SpeechCommandClass::COUNT
    }
}
impl Dataset<SpeechItem> for SpeechCommandsDataset {
fn get(&self, index: usize) -> Option<SpeechItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
/// Stateless mapper that turns a [`SpeechItemRaw`] into a decoded [`SpeechItem`].
struct ConvertSamples;
impl ConvertSamples {
fn to_speechcommandclass(label: usize) -> SpeechCommandClass {
SpeechCommandClass::from_repr(label).unwrap()
}
fn to_audiosamples(bytes: &Vec<u8>) -> (Vec<f32>, usize) {
let reader = WavReader::new(bytes.as_slice()).unwrap();
let spec = reader.spec();
let max_value = (1 << (spec.bits_per_sample - 1)) as f32;
let sample_rate = spec.sample_rate as usize;
let audio_samples: Vec<f32> = reader
.into_samples::<i32>()
.filter_map(Result::ok)
.map(|sample| sample as f32 / max_value)
.collect();
(audio_samples, sample_rate)
}
}
impl Mapper<SpeechItemRaw, SpeechItem> for ConvertSamples {
    /// Decodes the WAV bytes of `item`, then resolves its numeric label.
    ///
    /// `item.is_unknown` is not consulted.
    fn map(&self, item: &SpeechItemRaw) -> SpeechItem {
        let (samples, rate) = Self::to_audiosamples(&item.audio_bytes);
        // Field initializers run in source order, so the label is still resolved after decoding.
        SpeechItem {
            audio_samples: samples,
            sample_rate: rate,
            label: Self::to_speechcommandclass(item.label),
        }
    }
}