1use crate::error::{IoError, IoResult};
44use scirs2_core::ndarray::Array1;
45use serde::{Deserialize, Serialize};
46use std::path::Path;
47use std::time::{Duration, SystemTime};
48use tokio::fs::File;
49use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
50use tracing::{debug, info};
51
52#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
54pub enum RecorderFormat {
55 #[default]
57 Binary,
58 Json,
60 Csv,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct RecorderConfig {
67 pub path: String,
69
70 #[serde(default)]
72 pub format: RecorderFormat,
73
74 #[serde(default = "default_sample_rate")]
76 pub sample_rate: f32,
77
78 #[serde(default = "default_channels")]
80 pub channels: usize,
81
82 #[serde(default = "default_buffer_size")]
84 pub buffer_size: usize,
85
86 #[serde(default = "default_true")]
88 pub record_timestamps: bool,
89
90 #[serde(default)]
92 pub metadata: std::collections::HashMap<String, String>,
93}
94
/// Serde fallback for [`RecorderConfig::sample_rate`]: 44 100 (presumably Hz).
fn default_sample_rate() -> f32 {
    44_100.0
}
98
/// Serde fallback for [`RecorderConfig::channels`]: a single channel.
fn default_channels() -> usize {
    1_usize
}
102
/// Serde fallback for [`RecorderConfig::buffer_size`]: 1024.
fn default_buffer_size() -> usize {
    1_024
}
106
/// Serde fallback that always yields `true`
/// (used for [`RecorderConfig::record_timestamps`]).
fn default_true() -> bool {
    !false
}
110
111impl Default for RecorderConfig {
112 fn default() -> Self {
113 Self {
114 path: String::new(),
115 format: RecorderFormat::Binary,
116 sample_rate: 44100.0,
117 channels: 1,
118 buffer_size: 1024,
119 record_timestamps: true,
120 metadata: std::collections::HashMap::new(),
121 }
122 }
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct RecordedFrame {
128 pub samples: Vec<f32>,
130
131 pub timestamp: Option<f64>,
133
134 pub frame_number: usize,
136}
137
138pub struct StreamRecorder {
140 config: RecorderConfig,
141 writer: BufWriter<File>,
142 frame_count: usize,
143 start_time: SystemTime,
144 total_samples: usize,
145}
146
147impl StreamRecorder {
148 pub async fn new(config: RecorderConfig) -> IoResult<Self> {
150 let file = File::create(&config.path)
151 .await
152 .map_err(|e| IoError::WriteFailed(format!("Failed to create recording: {}", e)))?;
153
154 let mut writer = BufWriter::new(file);
155
156 match config.format {
158 RecorderFormat::Binary => {
159 writer
161 .write_all(b"ZHREC001")
162 .await
163 .map_err(|e| IoError::WriteFailed(format!("Failed to write header: {}", e)))?;
164
165 let config_json = serde_json::to_string(&config).map_err(|e| {
167 IoError::WriteFailed(format!("Failed to serialize config: {}", e))
168 })?;
169 let config_len = config_json.len() as u32;
170 writer
171 .write_all(&config_len.to_le_bytes())
172 .await
173 .map_err(|e| {
174 IoError::WriteFailed(format!("Failed to write config length: {}", e))
175 })?;
176 writer
177 .write_all(config_json.as_bytes())
178 .await
179 .map_err(|e| IoError::WriteFailed(format!("Failed to write config: {}", e)))?;
180 }
181 RecorderFormat::Json => {
182 let header = serde_json::json!({
184 "format": "kizzasi-io-recording",
185 "version": "1.0",
186 "config": config,
187 "frames": []
188 });
189 let header_str = serde_json::to_string_pretty(&header).map_err(|e| {
190 IoError::WriteFailed(format!("Failed to write JSON header: {}", e))
191 })?;
192 writer
193 .write_all(header_str.as_bytes())
194 .await
195 .map_err(|e| IoError::WriteFailed(format!("Failed to write header: {}", e)))?;
196 }
197 RecorderFormat::Csv => {
198 let header = if config.record_timestamps {
200 "frame,timestamp,samples\n"
201 } else {
202 "frame,samples\n"
203 };
204 writer.write_all(header.as_bytes()).await.map_err(|e| {
205 IoError::WriteFailed(format!("Failed to write CSV header: {}", e))
206 })?;
207 }
208 }
209
210 info!("Stream recorder created: {:?}", config.path);
211
212 Ok(Self {
213 config,
214 writer,
215 frame_count: 0,
216 start_time: SystemTime::now(),
217 total_samples: 0,
218 })
219 }
220
221 pub async fn record_samples(
223 &mut self,
224 samples: &[f32],
225 timestamp: Option<f64>,
226 ) -> IoResult<()> {
227 let frame = RecordedFrame {
228 samples: samples.to_vec(),
229 timestamp: timestamp.or_else(|| {
230 if self.config.record_timestamps {
231 Some(
232 self.start_time
233 .elapsed()
234 .unwrap_or(Duration::ZERO)
235 .as_secs_f64(),
236 )
237 } else {
238 None
239 }
240 }),
241 frame_number: self.frame_count,
242 };
243
244 self.write_frame(&frame).await?;
245 self.frame_count += 1;
246 self.total_samples += samples.len();
247
248 Ok(())
249 }
250
251 pub async fn record_array(&mut self, samples: &Array1<f32>) -> IoResult<()> {
253 let vec: Vec<f32> = samples.to_vec();
254 self.record_samples(&vec, None).await
255 }
256
257 async fn write_frame(&mut self, frame: &RecordedFrame) -> IoResult<()> {
259 match self.config.format {
260 RecorderFormat::Binary => {
261 let sample_count = frame.samples.len() as u32;
263 self.writer
264 .write_all(&sample_count.to_le_bytes())
265 .await
266 .map_err(|e| IoError::WriteFailed(format!("Failed to write frame: {}", e)))?;
267
268 if let Some(ts) = frame.timestamp {
269 self.writer
270 .write_all(&ts.to_le_bytes())
271 .await
272 .map_err(|e| {
273 IoError::WriteFailed(format!("Failed to write timestamp: {}", e))
274 })?;
275 }
276
277 for &sample in &frame.samples {
279 self.writer
280 .write_all(&sample.to_le_bytes())
281 .await
282 .map_err(|e| {
283 IoError::WriteFailed(format!("Failed to write sample: {}", e))
284 })?;
285 }
286 }
287 RecorderFormat::Json => {
288 let json = serde_json::to_string(&frame).map_err(|e| {
289 IoError::WriteFailed(format!("Failed to serialize frame: {}", e))
290 })?;
291 self.writer.write_all(json.as_bytes()).await.map_err(|e| {
292 IoError::WriteFailed(format!("Failed to write JSON frame: {}", e))
293 })?;
294 self.writer
295 .write_all(b"\n")
296 .await
297 .map_err(|e| IoError::WriteFailed(format!("Failed to write newline: {}", e)))?;
298 }
299 RecorderFormat::Csv => {
300 let samples_str = frame
301 .samples
302 .iter()
303 .map(|s| s.to_string())
304 .collect::<Vec<_>>()
305 .join(";");
306
307 let line = if let Some(ts) = frame.timestamp {
308 format!("{},{},{}\n", frame.frame_number, ts, samples_str)
309 } else {
310 format!("{},{}\n", frame.frame_number, samples_str)
311 };
312
313 self.writer.write_all(line.as_bytes()).await.map_err(|e| {
314 IoError::WriteFailed(format!("Failed to write CSV line: {}", e))
315 })?;
316 }
317 }
318
319 debug!(
320 "Recorded frame {}: {} samples",
321 frame.frame_number,
322 frame.samples.len()
323 );
324
325 Ok(())
326 }
327
328 pub async fn finalize(mut self) -> IoResult<()> {
330 self.writer
331 .flush()
332 .await
333 .map_err(|e| IoError::WriteFailed(format!("Failed to flush recording: {}", e)))?;
334
335 info!(
336 "Recording finalized: {} frames, {} samples",
337 self.frame_count, self.total_samples
338 );
339
340 Ok(())
341 }
342
343 pub fn frame_count(&self) -> usize {
345 self.frame_count
346 }
347
348 pub fn total_samples(&self) -> usize {
350 self.total_samples
351 }
352
353 pub async fn create_player(&self) -> IoResult<StreamPlayer> {
355 StreamPlayer::new(&self.config.path).await
356 }
357}
358
359pub struct StreamPlayer {
361 config: RecorderConfig,
362 reader: BufReader<File>,
363 frame_count: usize,
364 format: RecorderFormat,
365}
366
367impl StreamPlayer {
368 pub async fn new<P: AsRef<Path>>(path: P) -> IoResult<Self> {
370 let file = File::open(path)
371 .await
372 .map_err(|e| IoError::ReadFailed(format!("Failed to open recording: {}", e)))?;
373
374 let mut reader = BufReader::new(file);
375
376 let mut magic = [0u8; 8];
378 reader
379 .read_exact(&mut magic)
380 .await
381 .map_err(|e| IoError::ReadFailed(format!("Failed to read magic: {}", e)))?;
382
383 let (format, config) = if &magic == b"ZHREC001" {
384 let mut len_bytes = [0u8; 4];
386 reader
387 .read_exact(&mut len_bytes)
388 .await
389 .map_err(|e| IoError::ReadFailed(format!("Failed to read config length: {}", e)))?;
390 let config_len = u32::from_le_bytes(len_bytes) as usize;
391
392 let mut config_bytes = vec![0u8; config_len];
393 reader
394 .read_exact(&mut config_bytes)
395 .await
396 .map_err(|e| IoError::ReadFailed(format!("Failed to read config: {}", e)))?;
397
398 let config: RecorderConfig = serde_json::from_slice(&config_bytes)
399 .map_err(|e| IoError::ReadFailed(format!("Failed to parse config: {}", e)))?;
400
401 (RecorderFormat::Binary, config)
402 } else {
403 return Err(IoError::ReadFailed(
405 "Non-binary format playback not yet implemented".into(),
406 ));
407 };
408
409 info!("Stream player opened: {:?}", format);
410
411 Ok(Self {
412 config,
413 reader,
414 frame_count: 0,
415 format,
416 })
417 }
418
419 pub async fn next_frame(&mut self) -> IoResult<Option<RecordedFrame>> {
421 match self.format {
422 RecorderFormat::Binary => {
423 let mut count_bytes = [0u8; 4];
425 match self.reader.read_exact(&mut count_bytes).await {
426 Ok(_) => {}
427 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
428 Err(e) => {
429 return Err(IoError::ReadFailed(format!(
430 "Failed to read frame count: {}",
431 e
432 )))
433 }
434 }
435
436 let sample_count = u32::from_le_bytes(count_bytes) as usize;
437
438 let timestamp = if self.config.record_timestamps {
440 let mut ts_bytes = [0u8; 8];
441 self.reader.read_exact(&mut ts_bytes).await.map_err(|e| {
442 IoError::ReadFailed(format!("Failed to read timestamp: {}", e))
443 })?;
444 Some(f64::from_le_bytes(ts_bytes))
445 } else {
446 None
447 };
448
449 let mut samples = Vec::with_capacity(sample_count);
451 for _ in 0..sample_count {
452 let mut sample_bytes = [0u8; 4];
453 self.reader
454 .read_exact(&mut sample_bytes)
455 .await
456 .map_err(|e| {
457 IoError::ReadFailed(format!("Failed to read sample: {}", e))
458 })?;
459 samples.push(f32::from_le_bytes(sample_bytes));
460 }
461
462 let frame = RecordedFrame {
463 samples,
464 timestamp,
465 frame_number: self.frame_count,
466 };
467
468 self.frame_count += 1;
469 debug!("Read frame {}", frame.frame_number);
470
471 Ok(Some(frame))
472 }
473 _ => Err(IoError::ReadFailed("Format not supported yet".into())),
474 }
475 }
476
477 pub async fn seek_to_frame(&mut self, _frame_number: usize) -> IoResult<()> {
479 Err(IoError::ReadFailed("Seeking not yet implemented".into()))
481 }
482
483 pub fn config(&self) -> &RecorderConfig {
485 &self.config
486 }
487
488 pub fn frame_number(&self) -> usize {
490 self.frame_count
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Temp-file path unique to this test process, so concurrently running test
    /// binaries do not clobber each other's recordings.
    fn temp_path(name: &str) -> PathBuf {
        env::temp_dir().join(format!("{}_{}", std::process::id(), name))
    }

    #[tokio::test]
    async fn test_recorder_binary() {
        let path = temp_path("test_recording.bin");

        let config = RecorderConfig {
            path: path.to_string_lossy().to_string(),
            format: RecorderFormat::Binary,
            sample_rate: 44100.0,
            channels: 1,
            buffer_size: 1024,
            record_timestamps: true,
            metadata: std::collections::HashMap::new(),
        };

        let mut recorder = StreamRecorder::new(config).await.unwrap();
        recorder
            .record_samples(&[1.0, 2.0, 3.0], None)
            .await
            .unwrap();
        recorder
            .record_samples(&[4.0, 5.0, 6.0], None)
            .await
            .unwrap();
        assert_eq!(recorder.frame_count(), 2);
        assert_eq!(recorder.total_samples(), 6);
        recorder.finalize().await.unwrap();

        let mut player = StreamPlayer::new(&path).await.unwrap();

        let frame1 = player.next_frame().await.unwrap().unwrap();
        assert_eq!(frame1.samples, vec![1.0, 2.0, 3.0]);
        assert_eq!(frame1.frame_number, 0);
        assert!(frame1.timestamp.is_some());

        let frame2 = player.next_frame().await.unwrap().unwrap();
        assert_eq!(frame2.samples, vec![4.0, 5.0, 6.0]);
        assert_eq!(frame2.frame_number, 1);

        assert!(player.next_frame().await.unwrap().is_none());

        std::fs::remove_file(&path).ok();
    }

    #[tokio::test]
    async fn test_recorder_array() {
        let path = temp_path("test_recording_array.bin");

        let config = RecorderConfig {
            path: path.to_string_lossy().to_string(),
            format: RecorderFormat::Binary,
            sample_rate: 48000.0,
            channels: 2,
            buffer_size: 512,
            record_timestamps: false,
            metadata: std::collections::HashMap::new(),
        };

        let mut recorder = StreamRecorder::new(config).await.unwrap();

        let samples = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        recorder.record_array(&samples).await.unwrap();
        assert_eq!(recorder.frame_count(), 1);
        assert_eq!(recorder.total_samples(), 5);
        recorder.finalize().await.unwrap();

        // Without timestamps the frame must still decode cleanly, and the header
        // must carry the configuration through.
        let mut player = StreamPlayer::new(&path).await.unwrap();
        assert_eq!(player.config().channels, 2);
        assert!(!player.config().record_timestamps);

        let frame = player.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.samples, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(frame.timestamp, None);
        assert!(player.next_frame().await.unwrap().is_none());

        std::fs::remove_file(&path).ok();
    }

    #[tokio::test]
    async fn test_explicit_timestamp_round_trip() {
        let path = temp_path("test_recording_timestamps.bin");

        let config = RecorderConfig {
            path: path.to_string_lossy().to_string(),
            record_timestamps: true,
            ..RecorderConfig::default()
        };

        let mut recorder = StreamRecorder::new(config).await.unwrap();
        recorder.record_samples(&[0.5], Some(1.25)).await.unwrap();
        recorder.finalize().await.unwrap();

        let mut player = StreamPlayer::new(&path).await.unwrap();
        let frame = player.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.samples, vec![0.5]);
        assert_eq!(frame.timestamp, Some(1.25));
        assert!(player.next_frame().await.unwrap().is_none());

        std::fs::remove_file(&path).ok();
    }

    #[tokio::test]
    async fn test_csv_recording_is_not_playable() {
        let path = temp_path("test_recording.csv");

        let config = RecorderConfig {
            path: path.to_string_lossy().to_string(),
            format: RecorderFormat::Csv,
            ..RecorderConfig::default()
        };

        let mut recorder = StreamRecorder::new(config).await.unwrap();
        recorder.record_samples(&[1.0, 2.0], None).await.unwrap();
        recorder.finalize().await.unwrap();

        let contents = std::fs::read_to_string(&path).unwrap();
        assert!(contents.starts_with("frame,timestamp,samples\n"));
        let row = contents.lines().nth(1).unwrap();
        assert!(row.starts_with("0,"));
        assert!(row.ends_with(",1;2"));

        assert!(StreamPlayer::new(&path).await.is_err());

        std::fs::remove_file(&path).ok();
    }
}