autd3_modulation_audio_file/
csv.rs

1use autd3_core::derive::*;
2
3use std::{fmt::Debug, fs::File, path::Path};
4
5use crate::error::AudioFileError;
6
/// The option of [`Csv`].
///
/// The same option is used both when reading ([`Csv::new`]) and when writing
/// ([`Csv::write`]) so that a file written with a given option can be read back with it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvOption {
    /// The delimiter of CSV file (a single byte, e.g. `b','`).
    pub delimiter: u8,
    /// Whether the CSV file has headers.
    ///
    /// When reading, the first record is treated as a header and skipped.
    /// When writing, a single header line describing the sampling rate is emitted.
    pub has_headers: bool,
}

impl Default for CsvOption {
    /// Comma-delimited, without a header row.
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_headers: false,
        }
    }
}
24
/// [`Modulation`] from CSV data.
///
/// The whole file is read and parsed when the value is created with [`Csv::new`].
/// Afterwards, [`Modulation::calc`] simply returns the stored buffer.
#[derive(Modulation, Debug, Clone)]
pub struct Csv {
    /// Sampling configuration given to [`Csv::new`]; returned by `sampling_config()`.
    sampling_config: SamplingConfig,
    /// Modulation data: every CSV field, trimmed and parsed as `u8`, in file order.
    buffer: Vec<u8>,
}
31
32impl Csv {
33    /// Create a new [`Csv`].
34    pub fn new<P, Config>(
35        path: P,
36        sampling_config: Config,
37        option: CsvOption,
38    ) -> Result<Self, AudioFileError>
39    where
40        P: AsRef<Path> + Clone + Debug,
41        Config: Into<SamplingConfig> + Debug + Copy,
42    {
43        let f = File::open(&path)?;
44        let mut rdr = csv::ReaderBuilder::new()
45            .has_headers(option.has_headers)
46            .delimiter(option.delimiter)
47            .from_reader(f);
48        let buffer = rdr
49            .records()
50            .map(|r| {
51                let record = r?;
52                csv::Result::Ok(
53                    record
54                        .iter()
55                        .map(|x| x.trim().to_owned())
56                        .collect::<Vec<_>>(),
57                )
58            })
59            .collect::<csv::Result<Vec<_>>>()?
60            .into_iter()
61            .flatten()
62            .map(|s| s.parse::<u8>())
63            .collect::<Result<Vec<u8>, _>>()?;
64        Ok(Self {
65            sampling_config: sampling_config.into(),
66            buffer,
67        })
68    }
69
70    /// Write a [`Modulation`] into a writer as CSV format.
71    pub fn write<Writer: std::io::Write, M: Modulation>(
72        m: M,
73        writer: Writer,
74        option: CsvOption,
75    ) -> Result<(), AudioFileError> {
76        let sample_rate = m.sampling_config().freq()?.hz();
77        let buffer = m.calc(&FirmwareLimits {
78            mod_buf_size_max: u32::MAX,
79            ..FirmwareLimits::unused()
80        })?;
81        let mut writer = csv::WriterBuilder::new()
82            .delimiter(option.delimiter)
83            .from_writer(writer);
84        if option.has_headers {
85            writer.write_record(&[format!("Buffer (sampling rate = {sample_rate} Hz)")])?;
86        }
87        buffer
88            .into_iter()
89            .try_for_each(|b| writer.write_record(&[b.to_string()]))?;
90        Ok(())
91    }
92}
93
94impl Modulation for Csv {
95    fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
96        Ok(self.buffer)
97    }
98
99    fn sampling_config(&self) -> SamplingConfig {
100        self.sampling_config
101    }
102}
103
#[cfg(test)]
mod tests {
    use autd3_core::common::{Freq, Hz};

    use super::*;
    use std::io::Write;

    /// Write `data` to a new file at `path`, one value per line.
    fn create_csv(path: impl AsRef<Path>, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
        let mut f = File::create(path)?;
        data.iter().try_for_each(|d| writeln!(f, "{d}"))?;
        Ok(())
    }

    #[rstest::rstest]
    #[case(vec![0xFF, 0x7F, 0x00], 4000. * Hz)]
    fn new(
        #[case] data: Vec<u8>,
        #[case] sample_rate: Freq<f32>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("tmp.csv");
        create_csv(&path, &data)?;

        let m = Csv::new(path, sample_rate, CsvOption::default())?;
        assert_eq!(sample_rate.hz(), m.sampling_config().freq()?.hz());
        assert_eq!(data, *m.calc(&FirmwareLimits::unused())?);

        Ok(())
    }

    /// Header skipping, custom delimiters, multi-column rows and whitespace trimming.
    #[rstest::rstest]
    #[case("a;b\n1;2\n 3 ; 4\n", b';', true, vec![1, 2, 3, 4])]
    #[case("1,2\n3,4\n", b',', false, vec![1, 2, 3, 4])]
    #[case("5\t6\n", b'\t', false, vec![5, 6])]
    fn new_with_option(
        #[case] content: &str,
        #[case] delimiter: u8,
        #[case] has_headers: bool,
        #[case] expect: Vec<u8>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("tmp.csv");
        std::fs::write(&path, content)?;

        let m = Csv::new(
            path,
            4000. * Hz,
            CsvOption {
                delimiter,
                has_headers,
            },
        )?;
        assert_eq!(expect, *m.calc(&FirmwareLimits::unused())?);

        Ok(())
    }

    /// Values that do not fit in a `u8`, or are not numbers at all, are rejected.
    #[rstest::rstest]
    #[case("1\n256\n")]
    #[case("1\n-1\n")]
    #[case("1\nabc\n")]
    fn new_invalid_value(#[case] content: &str) -> Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("tmp.csv");
        std::fs::write(&path, content)?;

        assert!(Csv::new(path, 4000. * Hz, CsvOption::default()).is_err());

        Ok(())
    }

    #[rstest::rstest]
    #[case("Buffer (sampling rate = 4000 Hz)\n0\n128\n255\n", true)]
    #[case("0\n128\n255\n", false)]
    fn write(
        #[case] expect: &str,
        #[case] has_headers: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        #[derive(Clone)]
        struct TestMod {
            data: Vec<u8>,
            rate: f32,
        }
        impl Modulation for TestMod {
            fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
                Ok(self.data)
            }
            fn sampling_config(&self) -> SamplingConfig {
                SamplingConfig::new(self.rate * Hz)
            }
        }

        let m = TestMod {
            data: vec![0u8, 128u8, 255u8],
            rate: 4000.0,
        };
        let mut wtr = Vec::new();
        Csv::write(
            m,
            &mut wtr,
            CsvOption {
                delimiter: b',',
                has_headers,
            },
        )?;

        assert_eq!(expect, String::from_utf8(wtr)?);

        Ok(())
    }

    #[test]
    fn not_exists() -> Result<(), Box<dyn std::error::Error>> {
        assert!(
            Csv::new(
                Path::new("not_exists.csv"),
                4000. * Hz,
                CsvOption::default(),
            )
            .is_err()
        );
        Ok(())
    }
}