1use std::path::PathBuf;
9use std::sync::atomic::Ordering;
10use std::thread;
11use std::time::Duration;
12
13use anyhow::{bail, Context, Result};
14use chrono::Utc;
15use clap::Parser;
16
17use crate::voice::capture::{
18 install_ctrl_c_handler, run_capture, CaptureOpts, CaptureSummary, TerminationReason,
19};
20use crate::voice::models::SPEAKER_WESPEAKER_EN;
21use crate::voice::{
22 captures_dir, speaker_file, CpalAudioSource, EnrolledSpeaker, WespeakerEmbedder,
23};
24
/// Default value of the `--idle-after` flag, in seconds.
pub const DEFAULT_IDLE_AFTER_SECS: u32 = 2;

/// Default value of the `--max-secs` flag, in seconds.
pub const DEFAULT_MAX_SECS: u32 = 30;
30
31#[derive(Parser)]
38pub struct EnrollCommand {
39 #[arg(long, default_value = "default")]
42 pub name: String,
43
44 #[arg(long, default_value_t = DEFAULT_IDLE_AFTER_SECS)]
46 pub idle_after: u32,
47
48 #[arg(long, default_value_t = DEFAULT_MAX_SECS)]
52 pub max_secs: u32,
53
54 #[arg(long)]
56 pub device: Option<String>,
57
58 #[arg(long)]
62 pub speaker_model: Option<PathBuf>,
63
64 #[arg(long)]
66 pub force: bool,
67}
68
69impl EnrollCommand {
70 pub fn execute(self) -> Result<()> {
72 let speaker_model = resolve_speaker_model_path(self.speaker_model.as_deref())?;
73 let dest = speaker_file(&self.name)?;
74 if dest.is_file() && !self.force {
75 bail!(
76 "speaker {} already enrolled at {}; pass --force to overwrite",
77 self.name,
78 dest.display()
79 );
80 }
81
82 let captures = captures_dir()?;
85 std::fs::create_dir_all(&captures)
86 .with_context(|| format!("create captures dir {}", captures.display()))?;
87 let timestamp = Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
88 let tmp_wav = captures.join(format!(".enroll-{timestamp}.wav"));
89
90 let stop = install_ctrl_c_handler()?;
91 if self.max_secs > 0 {
95 let stop = stop.clone();
96 let deadline_secs = u64::from(self.max_secs);
97 thread::spawn(move || {
98 thread::sleep(Duration::from_secs(deadline_secs));
99 stop.store(true, Ordering::Relaxed);
100 });
101 }
102
103 let source = CpalAudioSource::new(self.device.as_deref())?;
104 eprintln!(
105 "Recording enrolment for {} to {} (idle-after: {}s, max: {}s, Ctrl-C to stop)…",
106 self.name,
107 tmp_wav.display(),
108 self.idle_after,
109 self.max_secs,
110 );
111 let opts = CaptureOpts::new(&tmp_wav, self.idle_after);
112 let summary = run_capture(source, opts, stop)?;
113 print_capture_summary(&summary);
114
115 let result = embed_and_save(&self.name, &speaker_model, &tmp_wav, &dest);
117
118 let _ = std::fs::remove_file(&tmp_wav);
121
122 result?;
123 eprintln!(
124 "Enrolled speaker {} ({} dim) -> {}",
125 self.name,
126 embed_dim_hint(),
127 dest.display()
128 );
129 Ok(())
130 }
131}
132
133fn embed_and_save(
134 name: &str,
135 speaker_model: &std::path::Path,
136 tmp_wav: &std::path::Path,
137 dest: &std::path::Path,
138) -> Result<()> {
139 let pcm = read_wav_16k_mono_i16(tmp_wav)?;
140 let embedder = WespeakerEmbedder::new(speaker_model)?;
141 let vector = embedder.embed(&pcm)?;
142 let enrolled = EnrolledSpeaker {
143 name: name.to_string(),
144 model: SPEAKER_WESPEAKER_EN.variant.to_string(),
145 dim: vector.len(),
146 vector,
147 samples_used: 1,
148 enrolled_at: Utc::now(),
149 };
150 enrolled.save(dest)
151}
152
153fn read_wav_16k_mono_i16(path: &std::path::Path) -> Result<Vec<i16>> {
154 let mut reader = hound::WavReader::open(path)
155 .with_context(|| format!("open enrolment WAV at {}", path.display()))?;
156 let spec = reader.spec();
157 if spec.sample_rate != 16_000 || spec.channels != 1 {
158 bail!(
159 "enrolment WAV at {} must be 16 kHz mono (got {} Hz, {} channels)",
160 path.display(),
161 spec.sample_rate,
162 spec.channels
163 );
164 }
165 let samples: Vec<i16> = reader
166 .samples::<i16>()
167 .collect::<Result<Vec<_>, _>>()
168 .context("decode enrolment WAV samples")?;
169 Ok(samples)
170}
171
172fn resolve_speaker_model_path(override_path: Option<&std::path::Path>) -> Result<PathBuf> {
173 let dir = SPEAKER_WESPEAKER_EN.resolve_dir(override_path)?;
174 SPEAKER_WESPEAKER_EN.ensure_present(&dir)?;
175 Ok(dir.join(SPEAKER_WESPEAKER_EN.required_files[0]))
176}
177
178fn print_capture_summary(summary: &CaptureSummary) {
179 eprintln!("{}", format_capture_summary(summary));
180}
181
182fn format_capture_summary(summary: &CaptureSummary) -> String {
183 let reason = match summary.terminated_by {
184 TerminationReason::Idle => "silence threshold reached",
185 TerminationReason::SourceExhausted => "audio source ended",
186 TerminationReason::Signal => "Ctrl-C or max-secs deadline",
187 };
188 let seconds = samples_to_seconds(summary.samples_written);
189 format!(
190 "Captured {seconds:.2}s ({} samples; {} trimmed; stopped: {reason})",
191 summary.samples_written, summary.trimmed_samples,
192 )
193}
194
195fn samples_to_seconds(samples: u64) -> f64 {
196 samples as f64 / f64::from(crate::voice::wav::TARGET_SAMPLE_RATE)
197}
198
/// Embedding dimensionality shown in the "Enrolled speaker" success message.
///
/// This is a fixed hint; it is not read back from the vector that was saved.
const fn embed_dim_hint() -> usize {
    const EMBED_DIM_HINT: usize = 256;
    EMBED_DIM_HINT
}
204
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use clap::Parser;

    /// Minimal host CLI so `EnrollCommand` can be parsed in isolation.
    #[derive(Parser)]
    struct TestCli {
        #[command(flatten)]
        enroll: EnrollCommand,
    }

    /// Parses `args` (program name excluded) into an `EnrollCommand`.
    fn parse(args: &[&str]) -> Result<EnrollCommand, clap::Error> {
        let argv = std::iter::once("test").chain(args.iter().copied());
        TestCli::try_parse_from(argv).map(|cli| cli.enroll)
    }

    /// Writes a 16-bit integer PCM WAV with the given layout and samples.
    fn write_wav(path: &std::path::Path, channels: u16, sample_rate: u32, samples: &[i16]) {
        let spec = hound::WavSpec {
            channels,
            sample_rate,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(path, spec).unwrap();
        for &sample in samples {
            writer.write_sample(sample).unwrap();
        }
        writer.finalize().unwrap();
    }

    /// Builds a `CaptureSummary` with a fixed output path.
    fn make_summary(reason: TerminationReason, written: u64, trimmed: u64) -> CaptureSummary {
        CaptureSummary {
            output: std::path::PathBuf::from("/tmp/out.wav"),
            samples_written: written,
            trimmed_samples: trimmed,
            terminated_by: reason,
        }
    }

    // --- CLI parsing -----------------------------------------------------

    #[test]
    fn parses_defaults() {
        let cmd = parse(&[]).unwrap();
        assert_eq!(cmd.name, "default");
        assert_eq!(cmd.idle_after, DEFAULT_IDLE_AFTER_SECS);
        assert_eq!(cmd.max_secs, DEFAULT_MAX_SECS);
        assert!(cmd.device.is_none());
        assert!(cmd.speaker_model.is_none());
        assert!(!cmd.force);
    }

    #[test]
    fn parses_all_flags() {
        let cmd = parse(&[
            "--name",
            "jky",
            "--idle-after",
            "3",
            "--max-secs",
            "20",
            "--device",
            "Built-in Mic",
            "--speaker-model",
            "/opt/wespeaker.onnx",
            "--force",
        ])
        .unwrap();
        assert_eq!(cmd.name, "jky");
        assert_eq!(cmd.idle_after, 3);
        assert_eq!(cmd.max_secs, 20);
        assert_eq!(cmd.device.as_deref(), Some("Built-in Mic"));
        let model_override = cmd.speaker_model.as_deref().and_then(|p| p.to_str());
        assert_eq!(model_override, Some("/opt/wespeaker.onnx"));
        assert!(cmd.force);
    }

    #[test]
    fn parses_max_secs_zero_disables_cap() {
        let cmd = parse(&["--max-secs", "0"]).unwrap();
        assert_eq!(cmd.max_secs, 0);
    }

    #[test]
    fn rejects_negative_idle_after() {
        assert!(parse(&["--idle-after", "-1"]).is_err());
    }

    #[test]
    fn rejects_negative_max_secs() {
        assert!(parse(&["--max-secs", "-1"]).is_err());
    }

    // --- model resolution ------------------------------------------------

    #[test]
    fn resolve_speaker_model_path_errors_with_install_hint_when_dir_empty() {
        let model_dir = tempfile::TempDir::new().unwrap();
        let msg = match resolve_speaker_model_path(Some(model_dir.path())) {
            Ok(_) => panic!("empty model dir should fail the ensure_present check"),
            Err(err) => format!("{err:#}"),
        };
        assert!(msg.contains("no Speaker model found"), "got: {msg}");
        assert!(msg.contains("--variant speaker-wespeaker-en"), "got: {msg}");
    }

    #[test]
    fn resolve_speaker_model_path_returns_onnx_file_when_present() {
        let model_dir = tempfile::TempDir::new().unwrap();
        let expected = model_dir.path().join(SPEAKER_WESPEAKER_EN.required_files[0]);
        std::fs::write(&expected, b"placeholder").unwrap();
        let actual = resolve_speaker_model_path(Some(model_dir.path())).unwrap();
        assert_eq!(actual, expected);
    }

    // --- WAV reading -----------------------------------------------------

    #[test]
    fn read_wav_16k_mono_i16_rejects_wrong_format() {
        let dir = tempfile::TempDir::new().unwrap();
        let wav = dir.path().join("bad.wav");
        // One stereo frame at 44.1 kHz.
        write_wav(&wav, 2, 44_100, &[0, 0]);
        let msg = match read_wav_16k_mono_i16(&wav) {
            Ok(_) => panic!("stereo @ 44.1k should fail"),
            Err(err) => format!("{err:#}"),
        };
        assert!(msg.contains("must be 16 kHz mono"), "got: {msg}");
    }

    #[test]
    fn read_wav_16k_mono_i16_decodes_ok_when_format_matches() {
        let dir = tempfile::TempDir::new().unwrap();
        let wav = dir.path().join("ok.wav");
        write_wav(&wav, 1, 16_000, &[100, 200, 300, 400]);
        let decoded = read_wav_16k_mono_i16(&wav).unwrap();
        assert_eq!(decoded, vec![100, 200, 300, 400]);
    }

    // --- summary formatting ----------------------------------------------

    #[test]
    fn format_capture_summary_idle_termination_mentions_silence() {
        let s = format_capture_summary(&make_summary(TerminationReason::Idle, 16_000, 3_200));
        assert!(s.contains("silence threshold reached"));
        assert!(
            s.contains("1.00s"),
            "16000 samples @ 16kHz = 1.00s; got: {s}"
        );
        assert!(s.contains("16000 samples"));
        assert!(s.contains("3200 trimmed"));
    }

    #[test]
    fn format_capture_summary_signal_termination_mentions_ctrl_c_or_deadline() {
        let s = format_capture_summary(&make_summary(TerminationReason::Signal, 48_000, 0));
        assert!(s.contains("Ctrl-C or max-secs deadline"));
        assert!(
            s.contains("3.00s"),
            "48000 samples @ 16kHz = 3.00s; got: {s}"
        );
    }

    #[test]
    fn format_capture_summary_source_exhausted_mentions_source() {
        let s =
            format_capture_summary(&make_summary(TerminationReason::SourceExhausted, 8_000, 0));
        assert!(s.contains("audio source ended"));
        assert!(s.contains("0.50s"));
    }

    #[test]
    fn samples_to_seconds_round_trips_at_16k() {
        let cases: [(u64, f64); 3] = [(0, 0.0), (16_000, 1.0), (8_000, 0.5)];
        for (samples, expected) in cases {
            assert!((samples_to_seconds(samples) - expected).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn embed_dim_hint_is_256() {
        assert_eq!(embed_dim_hint(), 256);
    }
}
383}