Skip to main content

zk_audio/
builder.rs

1use crate::contracts::{AudioSink, MetricsCollector, ProcessorBuildRequest, ProcessorFactory};
2use crate::core::{AudioProfile, AudioResult, AudioSpec, CaptureDiagnostics, DelayEffectConfig};
3use crate::factory::ProfileProcessorFactory;
4use crate::metrics::LevelMetrics;
5use crate::mic_sim::{MicrophoneSimConfig, MicrophoneSimulatorFactory};
6use crate::pipeline::{AudioPipeline, WavSink};
7use std::collections::BTreeMap;
8use std::path::PathBuf;
9
10#[derive(Debug, Clone)]
11pub struct PipelineBuildRequest {
12    pub output_path: PathBuf,
13    pub profile: AudioProfile,
14    pub sample_rate: u32,
15    pub device_channels: u16,
16    pub device_name: Option<String>,
17    pub gain_db: f32,
18    pub limiter_threshold: f32,
19    pub high_pass_hz: f32,
20    pub noise_suppression_amount: f32,
21    pub noise_calibration_ms: u32,
22    pub delay_effect: Option<DelayEffectConfig>,
23    pub stage_overrides: BTreeMap<String, crate::core::ProcessorOverrideMode>,
24    pub microphone_sim: MicrophoneSimConfig,
25}
26
27pub struct NativePipelineBuilder<F = ProfileProcessorFactory> {
28    processor_factory: F,
29}
30
31impl NativePipelineBuilder<ProfileProcessorFactory> {
32    pub fn new() -> Self {
33        Self {
34            processor_factory: ProfileProcessorFactory::new(),
35        }
36    }
37}
38
39impl Default for NativePipelineBuilder<ProfileProcessorFactory> {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<F> NativePipelineBuilder<F>
46where
47    F: ProcessorFactory,
48{
49    #[cfg(test)]
50    pub fn with_factory(processor_factory: F) -> Self {
51        Self { processor_factory }
52    }
53
54    pub fn build(&self, request: PipelineBuildRequest) -> AudioResult<AudioPipeline> {
55        let output_spec = AudioSpec {
56            sample_rate: request.sample_rate,
57            channels: 1,
58        };
59        let processors = self
60            .processor_factory
61            .build_processors(ProcessorBuildRequest {
62                profile: request.profile,
63                gain_db: request.gain_db,
64                limiter_threshold: request.limiter_threshold,
65                high_pass_hz: request.high_pass_hz,
66                noise_suppression_amount: request.noise_suppression_amount,
67                noise_calibration_ms: request.noise_calibration_ms,
68                delay_effect: request.delay_effect,
69                stage_overrides: request.stage_overrides.clone(),
70            })?;
71        let processor_names = processors
72            .iter()
73            .map(|processor| processor.name().to_string())
74            .collect::<Vec<_>>();
75        let microphone_sim_processors =
76            MicrophoneSimulatorFactory::new().build_processors(request.microphone_sim)?;
77        let microphone_sim_processor_names = microphone_sim_processors
78            .iter()
79            .map(|processor| processor.name().to_string())
80            .collect::<Vec<_>>();
81        let microphone_sim_model = request
82            .microphone_sim
83            .active_model()
84            .map(|model| model.as_str().to_string());
85        let mut notes = request
86            .delay_effect
87            .map(|effect| vec![format!("delay_effect={}", effect.preset.as_str())])
88            .unwrap_or_default();
89        if let Some(model) = &microphone_sim_model {
90            notes.push(format!("microphone_sim={}", model));
91        }
92        notes.push("pipeline_order=microphone_sim->voice_processing->sink".to_string());
93
94        self.build_with_parts(
95            output_spec,
96            CaptureDiagnostics {
97                backend: "native".to_string(),
98                profile: request.profile.as_str().to_string(),
99                profile_base: Some(request.profile.as_str().to_string()),
100                device_name: request.device_name,
101                sample_rate: Some(request.sample_rate),
102                channels: Some(request.device_channels),
103                processor_names,
104                processor_stage_overrides: request
105                    .stage_overrides
106                    .iter()
107                    .map(|(stage, mode)| format!("{}={}", stage, mode.as_str()))
108                    .collect(),
109                resolved_delay_preset: request
110                    .delay_effect
111                    .map(|effect| effect.preset.as_str().to_string()),
112                microphone_sim_model,
113                microphone_sim_processor_names,
114                notes,
115                ..CaptureDiagnostics::default()
116            },
117            microphone_sim_processors,
118            processors,
119            Box::new(WavSink::create(&request.output_path, output_spec)?),
120            Box::new(LevelMetrics::default()),
121        )
122    }
123
124    fn build_with_parts(
125        &self,
126        spec: AudioSpec,
127        diagnostics: CaptureDiagnostics,
128        microphone_sim_processors: Vec<Box<dyn crate::mic_sim::contracts::MicrophoneSimProcessor>>,
129        processors: Vec<Box<dyn crate::contracts::AudioProcessor>>,
130        sink: Box<dyn AudioSink>,
131        metrics: Box<dyn MetricsCollector>,
132    ) -> AudioResult<AudioPipeline> {
133        AudioPipeline::new(
134            spec,
135            diagnostics,
136            microphone_sim_processors,
137            processors,
138            sink,
139            metrics,
140        )
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::AudioProcessor;
    use crate::core::AudioFrame;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};

    /// Processor factory that records the profile name of every request it receives and
    /// always returns a single [`NoopProcessor`].
    struct RecordingFactory {
        names: Arc<Mutex<Vec<String>>>,
    }

    impl ProcessorFactory for RecordingFactory {
        fn build_processors(
            &self,
            request: ProcessorBuildRequest,
        ) -> AudioResult<Vec<Box<dyn AudioProcessor>>> {
            self.names
                .lock()
                .unwrap()
                .push(request.profile.as_str().to_string());
            Ok(vec![Box::new(NoopProcessor)])
        }
    }

    /// Processor named `"noop"` that leaves frames untouched.
    struct NoopProcessor;

    impl AudioProcessor for NoopProcessor {
        fn name(&self) -> &'static str {
            "noop"
        }

        fn prepare(&mut self, _spec: AudioSpec) -> AudioResult<()> {
            Ok(())
        }

        fn process(&mut self, _frame: &mut AudioFrame) -> AudioResult<()> {
            Ok(())
        }
    }

    /// Scratch output path in the system temp directory that is deleted on drop, so
    /// cleanup also happens when an assertion fails.
    struct TempOutput(PathBuf);

    impl TempOutput {
        /// Appends the process id so concurrent test runs do not share (and clobber) a file.
        fn new(stem: &str) -> Self {
            let file_name = format!("{}-{}.wav", stem, std::process::id());
            Self(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempOutput {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the build failed before creating it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn builder_uses_factory_and_populates_diagnostics() {
        // Declared first so it is dropped last, i.e. after the pipeline holding the sink.
        let output = TempOutput::new("zk-listen-builder-test");
        let names = Arc::new(Mutex::new(Vec::new()));
        let builder = NativePipelineBuilder::with_factory(RecordingFactory {
            names: Arc::clone(&names),
        });

        let mut pipeline = builder
            .build(PipelineBuildRequest {
                output_path: output.path().to_path_buf(),
                profile: AudioProfile::VoiceHvac,
                sample_rate: 44_100,
                device_channels: 2,
                device_name: Some("Test Mic".to_string()),
                gain_db: 2.0,
                limiter_threshold: 0.92,
                high_pass_hz: 100.0,
                noise_suppression_amount: 0.5,
                noise_calibration_ms: 350,
                delay_effect: None,
                stage_overrides: std::collections::BTreeMap::new(),
                microphone_sim: MicrophoneSimConfig::default(),
            })
            .unwrap_or_else(|err| panic!("builder failed: {}", err));

        let diagnostics = pipeline.finalize(Some(10)).unwrap();
        assert_eq!(diagnostics.profile, "voice_hvac");
        assert_eq!(diagnostics.profile_base.as_deref(), Some("voice_hvac"));
        assert_eq!(diagnostics.device_name.as_deref(), Some("Test Mic"));
        assert_eq!(diagnostics.processor_names, vec!["noop"]);
        assert!(diagnostics.processor_stage_overrides.is_empty());
        assert!(diagnostics.resolved_delay_preset.is_none());
        assert!(diagnostics.microphone_sim_model.is_none());
        assert!(diagnostics.microphone_sim_processor_names.is_empty());
        assert!(diagnostics
            .notes
            .iter()
            .any(|note| note == "pipeline_order=microphone_sim->voice_processing->sink"));
        assert_eq!(names.lock().unwrap().as_slice(), ["voice_hvac"]);
    }
}