1use anyhow::{Context, Result, anyhow, bail};
2use sentencepiece::SentencePieceProcessor;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use std::ffi::OsString;
5use std::io::{Read, Write};
6use std::path::{Path, PathBuf};
7
8pub mod heartcodec;
9pub mod heartmula_runtime;
10
11pub const DEFAULT_MAX_PROMPT_TOKENS: usize = 128;
12pub const DEFAULT_CFG_SCALE: f32 = 1.5;
13pub const IPC_MODE_ENV: &str = "MAOLAN_BURN_SOCKETPAIR";
14
15pub fn stderr_logging_enabled() -> bool {
16 std::env::var_os(IPC_MODE_ENV).is_none()
17}
18
19#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
20#[serde(rename_all = "lowercase")]
21pub enum BackendChoice {
22 Cpu,
23 #[default]
24 Vulkan,
25}
26
27#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
28pub enum ModelChoice {
29 #[serde(rename = "happy-new-year")]
30 #[default]
31 HappyNewYear,
32 #[serde(rename = "RL")]
33 Rl,
34}
35
36#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
37pub struct GenerateRequest {
38 #[serde(default)]
39 pub model: ModelChoice,
40 pub prompt: String,
41 #[serde(default)]
42 pub model_dir: Option<PathBuf>,
43 #[serde(default = "default_output_path")]
44 pub output_path: PathBuf,
45 #[serde(default)]
46 pub inspect_only: bool,
47 pub backend: BackendChoice,
48 pub cfg_scale: f32,
49 #[serde(alias = "seconds_total", alias = "max_audio_length_ms")]
50 pub length: usize,
51 #[serde(default = "default_ode_steps")]
53 pub ode_steps: usize,
54 #[serde(default)]
56 pub lyrics: Option<String>,
57 #[serde(default)]
59 pub tags: Option<String>,
60 #[serde(default = "default_topk")]
62 pub topk: usize,
63 #[serde(default = "default_temperature")]
65 pub temperature: f32,
66 #[serde(default)]
68 pub decode_only: bool,
69 #[serde(default)]
71 pub frames_json: Option<PathBuf>,
72 #[serde(default)]
74 pub decode_threads: Option<usize>,
75 #[serde(default)]
77 pub decoder_seed: u64,
78}
79
80fn default_ode_steps() -> usize {
81 10
82}
83
84fn default_topk() -> usize {
85 50
86}
87
88fn default_temperature() -> f32 {
89 1.0
90}
91
92pub type CliOptions = GenerateRequest;
93
94#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
95pub struct GenerateResponseHeader {
96 pub backend: BackendChoice,
97 pub channels: usize,
98 pub frames: usize,
99 pub guidance_scale: f32,
100 pub prompt_tokens: i64,
101 pub sample_rate_hz: u32,
102 pub length: usize,
103 pub steps: usize,
104}
105
106#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
107pub struct GenerateError {
108 pub error: String,
109}
110
111#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
113pub struct GenerateProgress {
114 pub phase: String,
115 pub progress: f32,
116 pub operation: String,
117}
118
119fn default_output_path() -> PathBuf {
120 PathBuf::from("output.wav")
121}
122
123pub fn help_text() -> &'static str {
124 "\
125maolan-generate
126
127Usage:
128 maolan-generate [options] <prompt-or-lyrics>
129
130Options:
131 --model <happy-new-year|RL>
132 --model-dir <path>
133 --output <path>
134 --inspect
135 --backend <cpu|vulkan> Select the runtime backend
136 --lyrics <text> Prompt / lyrics (positional argument also accepted)
137 --tags <text> Style tags for HeartMula
138 --cfg-scale <float> CFG scale (1.0=no guidance, 2.0=weak, 6.0=strong)
139 --length <int> HeartMula: output length in milliseconds
140 --topk <int> HeartMula: top-k sampling (default: 50)
141 --temperature <float> HeartMula: sampling temperature (default: 1.0)
142 --ode-steps <int> HeartMula: flow matching steps (5=fast, 10=default, 20=best)
143 --decoder-seed <int> Seed for deterministic HeartCodec decoder latents
144 --decode-only Decode an existing frames JSON instead of generating tokens
145 --frames-json <path> Frames JSON input for --decode-only
146 --decode-threads <int> Number of worker threads for decode-only CPU decoding
147 -h, --help
148"
149}
150
151pub fn parse_options(args: impl IntoIterator<Item = OsString>) -> Result<CliOptions> {
152 let mut args = args.into_iter();
153 let _program = args.next();
154 let mut prompt = None;
155 let mut model_dir = None;
156 let mut output_path = default_output_path();
157 let mut inspect_only = false;
158 let mut model = ModelChoice::HappyNewYear;
159 let mut backend = BackendChoice::Vulkan;
160 let mut cfg_scale = DEFAULT_CFG_SCALE;
161 let mut length = 6_000_usize;
162 let mut ode_steps = 10_usize;
163 let mut lyrics = None;
164 let mut tags = None;
165 let mut topk = default_topk();
166 let mut temperature = default_temperature();
167 let mut decode_only = false;
168 let mut frames_json = None;
169 let mut decode_threads = None;
170 let mut decoder_seed = 0_u64;
171
172 while let Some(arg) = args.next() {
173 let arg = arg
174 .into_string()
175 .map_err(|_| anyhow!("arguments must be valid UTF-8"))?;
176
177 if matches!(arg.as_str(), "-h" | "--help") {
178 bail!(help_text());
179 }
180
181 if arg == "--backend" {
182 let value = args
183 .next()
184 .ok_or_else(|| anyhow!("missing value after --backend"))?
185 .into_string()
186 .map_err(|_| anyhow!("backend value must be valid UTF-8"))?;
187 backend = match value.as_str() {
188 "cpu" => BackendChoice::Cpu,
189 "vulkan" => BackendChoice::Vulkan,
190 _ => bail!("unsupported backend '{value}', expected one of: cpu, vulkan"),
191 };
192 continue;
193 }
194
195 if arg == "--model-dir" {
196 model_dir = Some(PathBuf::from(
197 args.next()
198 .ok_or_else(|| anyhow!("missing value after --model-dir"))?,
199 ));
200 continue;
201 }
202
203 if arg == "--output" {
204 output_path = PathBuf::from(
205 args.next()
206 .ok_or_else(|| anyhow!("missing value after --output"))?,
207 );
208 continue;
209 }
210
211 if arg == "--lyrics" {
212 let value = args
213 .next()
214 .ok_or_else(|| anyhow!("missing value after --lyrics"))?
215 .into_string()
216 .map_err(|_| anyhow!("lyrics value must be valid UTF-8"))?;
217 lyrics = Some(value);
218 continue;
219 }
220
221 if arg == "--tags" {
222 let value = args
223 .next()
224 .ok_or_else(|| anyhow!("missing value after --tags"))?
225 .into_string()
226 .map_err(|_| anyhow!("tags value must be valid UTF-8"))?;
227 tags = Some(value);
228 continue;
229 }
230
231 if arg == "--length" {
232 let value = args
233 .next()
234 .ok_or_else(|| anyhow!("missing value after --length"))?
235 .into_string()
236 .map_err(|_| anyhow!("length value must be valid UTF-8"))?;
237 length = value
238 .parse::<usize>()
239 .map_err(|_| anyhow!("length must be a whole number"))?;
240 continue;
241 }
242
243 if arg == "--topk" {
244 let value = args
245 .next()
246 .ok_or_else(|| anyhow!("missing value after --topk"))?
247 .into_string()
248 .map_err(|_| anyhow!("topk value must be valid UTF-8"))?;
249 topk = value
250 .parse::<usize>()
251 .map_err(|_| anyhow!("topk must be a whole number"))?;
252 if topk == 0 {
253 bail!("topk must be greater than zero");
254 }
255 continue;
256 }
257
258 if arg == "--temperature" {
259 let value = args
260 .next()
261 .ok_or_else(|| anyhow!("missing value after --temperature"))?
262 .into_string()
263 .map_err(|_| anyhow!("temperature value must be valid UTF-8"))?;
264 temperature = value
265 .parse::<f32>()
266 .map_err(|_| anyhow!("temperature must be a number"))?;
267 if !temperature.is_finite() || temperature < 0.0 {
268 bail!("temperature must be a finite non-negative number");
269 }
270 continue;
271 }
272
273 if arg == "--inspect" {
274 inspect_only = true;
275 continue;
276 }
277
278 if arg == "--decode-only" {
279 decode_only = true;
280 continue;
281 }
282
283 if arg == "--frames-json" {
284 frames_json =
285 Some(PathBuf::from(args.next().ok_or_else(|| {
286 anyhow!("missing value after --frames-json")
287 })?));
288 continue;
289 }
290
291 if arg == "--decode-threads" {
292 let value = args
293 .next()
294 .ok_or_else(|| anyhow!("missing value after --decode-threads"))?
295 .into_string()
296 .map_err(|_| anyhow!("decode-threads value must be valid UTF-8"))?;
297 decode_threads = Some(
298 value
299 .parse::<usize>()
300 .map_err(|_| anyhow!("decode-threads must be a whole number"))?,
301 );
302 continue;
303 }
304
305 if arg == "--decoder-seed" {
306 let value = args
307 .next()
308 .ok_or_else(|| anyhow!("missing value after --decoder-seed"))?
309 .into_string()
310 .map_err(|_| anyhow!("decoder-seed value must be valid UTF-8"))?;
311 decoder_seed = value
312 .parse::<u64>()
313 .map_err(|_| anyhow!("decoder-seed must be a whole number"))?;
314 continue;
315 }
316
317 if arg == "--model" {
318 let value = args
319 .next()
320 .ok_or_else(|| anyhow!("missing value after --model"))?
321 .into_string()
322 .map_err(|_| anyhow!("model value must be valid UTF-8"))?;
323 model = match value.as_str() {
324 "happy-new-year" => ModelChoice::HappyNewYear,
325 "RL" => ModelChoice::Rl,
326 _ => {
327 bail!("unsupported model '{value}', expected one of: happy-new-year, RL")
328 }
329 };
330 continue;
331 }
332
333 if arg == "--cfg-scale" {
334 let value = args
335 .next()
336 .ok_or_else(|| anyhow!("missing value after --cfg-scale"))?
337 .into_string()
338 .map_err(|_| anyhow!("cfg-scale value must be valid UTF-8"))?;
339 cfg_scale = value
340 .parse::<f32>()
341 .map_err(|_| anyhow!("cfg-scale must be a number"))?;
342 if !cfg_scale.is_finite() || cfg_scale < 0.0 {
343 bail!("cfg-scale must be a finite non-negative number");
344 }
345 continue;
346 }
347
348 if arg == "--ode-steps" {
349 let value = args
350 .next()
351 .ok_or_else(|| anyhow!("missing value after --ode-steps"))?
352 .into_string()
353 .map_err(|_| anyhow!("ode-steps value must be valid UTF-8"))?;
354 ode_steps = value
355 .parse::<usize>()
356 .map_err(|_| anyhow!("ode-steps must be a whole number"))?;
357 if ode_steps == 0 || ode_steps > 50 {
358 bail!("ode-steps must be between 1 and 50");
359 }
360 continue;
361 }
362
363 if prompt.is_some() {
364 bail!("expected exactly one positional argument: the prompt");
365 }
366 prompt = Some(arg);
367 }
368
369 let prompt = if decode_only {
370 prompt.unwrap_or_default()
371 } else if let Some(lyrics) = lyrics {
372 lyrics
373 } else {
374 prompt.ok_or_else(|| {
375 anyhow!("missing prompt argument; provide a positional argument or --lyrics")
376 })?
377 };
378 let trimmed = prompt.trim();
379
380 if !decode_only && trimmed.is_empty() {
381 bail!("prompt argument cannot be empty");
382 }
383
384 validate_options(CliOptions {
385 model,
386 prompt: trimmed.to_owned(),
387 model_dir,
388 output_path,
389 inspect_only,
390 backend,
391 cfg_scale,
392 length,
393 ode_steps,
394 lyrics: None,
395 tags,
396 topk,
397 temperature,
398 decode_only,
399 frames_json,
400 decode_threads,
401 decoder_seed,
402 })
403}
404
405pub fn validate_options(mut options: CliOptions) -> Result<CliOptions> {
406 let prompt = options.prompt.trim();
407 if prompt.is_empty() && !options.decode_only {
408 bail!("prompt argument cannot be empty");
409 }
410 options.prompt = prompt.to_owned();
411
412 options.tags = options
413 .tags
414 .as_deref()
415 .map(str::trim)
416 .filter(|value| !value.is_empty())
417 .map(str::to_owned);
418
419 options.model_dir = options
420 .model_dir
421 .as_deref()
422 .map(Path::new)
423 .map(Path::to_path_buf);
424
425 if !options.cfg_scale.is_finite() || options.cfg_scale < 0.0 {
426 bail!("cfg-scale must be a finite non-negative number");
427 }
428 if options.length == 0 {
429 bail!("length must be greater than zero");
430 }
431 if options.output_path.as_os_str().is_empty() {
432 bail!("output path cannot be empty");
433 }
434 if options.decode_only && options.frames_json.is_none() {
435 bail!("--decode-only requires --frames-json");
436 }
437 if options.frames_json.is_some() && !options.decode_only {
438 bail!("--frames-json can only be used with --decode-only");
439 }
440 if let Some(threads) = options.decode_threads
441 && threads == 0
442 {
443 bail!("--decode-threads must be greater than zero");
444 }
445
446 Ok(options)
447}
448
449pub fn read_ipc_message<T: DeserializeOwned>(reader: &mut impl Read) -> Result<T> {
450 let mut len_bytes = [0_u8; 8];
451 reader
452 .read_exact(&mut len_bytes)
453 .context("failed to read IPC message length")?;
454 let len = u64::from_le_bytes(len_bytes);
455 let len = usize::try_from(len).context("IPC message length is too large")?;
456 let mut payload = vec![0_u8; len];
457 reader
458 .read_exact(&mut payload)
459 .context("failed to read IPC message payload")?;
460 serde_json::from_slice(&payload).context("failed to decode IPC JSON message")
461}
462
463pub fn write_ipc_message<T: Serialize>(writer: &mut impl Write, value: &T) -> Result<()> {
464 let payload = serde_json::to_vec(value).context("failed to encode IPC JSON message")?;
465 let len = u64::try_from(payload.len()).context("IPC payload is too large")?;
466 writer
467 .write_all(&len.to_le_bytes())
468 .context("failed to write IPC message length")?;
469 writer
470 .write_all(&payload)
471 .context("failed to write IPC message payload")?;
472 writer.flush().context("failed to flush IPC JSON message")?;
473 Ok(())
474}
475
476pub fn write_ipc_bytes(writer: &mut impl Write, bytes: &[u8]) -> Result<()> {
477 let len = u64::try_from(bytes.len()).context("IPC byte payload is too large")?;
478 writer
479 .write_all(&len.to_le_bytes())
480 .context("failed to write IPC byte length")?;
481 writer
482 .write_all(bytes)
483 .context("failed to write IPC byte payload")?;
484 writer.flush().context("failed to flush IPC byte payload")?;
485 Ok(())
486}
487
488pub fn tokenizer_path() -> PathBuf {
489 Path::new(env!("CARGO_MANIFEST_DIR"))
490 .join("assets")
491 .join("t5-base-spiece.model")
492}
493
494pub fn load_tokenizer() -> Result<SentencePieceProcessor> {
495 SentencePieceProcessor::open(tokenizer_path())
496 .context("failed to open the bundled T5 sentencepiece model")
497}
498
499pub fn encode_prompt(
500 tokenizer: &SentencePieceProcessor,
501 prompt: &str,
502 max_tokens: usize,
503) -> Result<(Vec<i64>, Vec<i64>)> {
504 let mut token_ids = Vec::with_capacity(max_tokens);
505
506 if let Some(bos_id) = tokenizer.bos_id() {
507 token_ids.push(i64::from(bos_id));
508 }
509
510 for piece in tokenizer
511 .encode(prompt)
512 .context("failed to tokenize prompt")?
513 {
514 if token_ids.len() >= max_tokens {
515 break;
516 }
517 token_ids.push(i64::from(piece.id));
518 }
519
520 if token_ids.len() < max_tokens
521 && let Some(eos_id) = tokenizer.eos_id()
522 {
523 token_ids.push(i64::from(eos_id));
524 }
525
526 if token_ids.len() > max_tokens {
527 token_ids.truncate(max_tokens);
528 }
529
530 let attention_len = token_ids.len();
531 let mut attention_mask = vec![1_i64; attention_len];
532 token_ids.resize(max_tokens, 0);
533 attention_mask.resize(max_tokens, 0);
534
535 Ok((token_ids, attention_mask))
536}
537
538#[cfg(test)]
539mod tests {
540 use super::{BackendChoice, DEFAULT_MAX_PROMPT_TOKENS, ModelChoice, parse_options};
541 use std::ffi::OsString;
542
543 #[test]
544 fn parses_single_prompt_argument() {
545 let args = [OsString::from("generate"), OsString::from("warm tape hiss")];
546 let options = parse_options(args).expect("options should parse");
547 assert_eq!(options.prompt, "warm tape hiss");
548 assert_eq!(options.model, ModelChoice::HappyNewYear);
549 assert_eq!(options.backend, BackendChoice::Vulkan);
550 assert_eq!(options.cfg_scale, 1.5);
551 assert_eq!(options.length, 6_000);
552 }
553
554 #[test]
555 fn trims_surrounding_whitespace() {
556 let args = [
557 OsString::from("generate"),
558 OsString::from(" foley footsteps "),
559 ];
560 let options = parse_options(args).expect("options should parse");
561 assert_eq!(options.prompt, "foley footsteps");
562 }
563
564 #[test]
565 fn rejects_missing_prompt() {
566 let args = [OsString::from("generate")];
567 assert!(parse_options(args).is_err());
568 }
569
570 #[test]
571 fn parses_backend_flag_after_prompt() {
572 let args = [
573 OsString::from("generate"),
574 OsString::from("warm tape hiss"),
575 OsString::from("--backend"),
576 OsString::from("vulkan"),
577 ];
578 let options = parse_options(args).expect("options should parse");
579 assert_eq!(options.backend, BackendChoice::Vulkan);
580 }
581
582 #[test]
583 fn parses_model_flag() {
584 let args = [
585 OsString::from("generate"),
586 OsString::from("--model"),
587 OsString::from("happy-new-year"),
588 OsString::from("verse and chorus"),
589 ];
590 let options = parse_options(args).expect("options should parse");
591 assert_eq!(options.model, ModelChoice::HappyNewYear);
592 }
593
594 #[test]
595 fn parses_rl_model_flag() {
596 let args = [
597 OsString::from("generate"),
598 OsString::from("--model"),
599 OsString::from("RL"),
600 OsString::from("verse and chorus"),
601 ];
602 let options = parse_options(args).expect("options should parse");
603 assert_eq!(options.model, ModelChoice::Rl);
604 }
605
606 #[test]
607 fn parses_tags_cfg_and_length() {
608 let args = [
609 OsString::from("generate"),
610 OsString::from("--tags"),
611 OsString::from("warm tape hiss"),
612 OsString::from("--cfg-scale"),
613 OsString::from("4.5"),
614 OsString::from("--ode-steps"),
615 OsString::from("20"),
616 OsString::from("--length"),
617 OsString::from("8000"),
618 OsString::from("verse and chorus"),
619 ];
620 let options = parse_options(args).expect("options should parse");
621 assert_eq!(options.cfg_scale, 4.5);
622 assert_eq!(options.ode_steps, 20);
623 assert_eq!(options.length, 8_000);
624 }
625
626 #[test]
627 fn parses_decode_only_without_prompt() {
628 let args = [
629 OsString::from("generate"),
630 OsString::from("--decode-only"),
631 OsString::from("--frames-json"),
632 OsString::from("/tmp/frames.json"),
633 ];
634 let options = parse_options(args).expect("options should parse");
635 assert!(options.decode_only);
636 assert_eq!(
637 options.frames_json.as_deref(),
638 Some(std::path::Path::new("/tmp/frames.json"))
639 );
640 assert!(options.prompt.is_empty());
641 }
642
643 #[test]
644 fn parses_decode_threads() {
645 let args = [
646 OsString::from("generate"),
647 OsString::from("--decode-only"),
648 OsString::from("--frames-json"),
649 OsString::from("/tmp/frames.json"),
650 OsString::from("--decode-threads"),
651 OsString::from("8"),
652 ];
653 let options = parse_options(args).expect("options should parse");
654 assert_eq!(options.decode_threads, Some(8));
655 }
656
657 const _: () = assert!(DEFAULT_MAX_PROMPT_TOKENS == 128);
658
659 #[test]
660 fn parses_cpu_backend_flag() {
661 let args = [
662 OsString::from("generate"),
663 OsString::from("--backend"),
664 OsString::from("cpu"),
665 OsString::from("test prompt"),
666 ];
667 let options = parse_options(args).expect("options should parse");
668 assert_eq!(options.backend, BackendChoice::Cpu);
669 }
670
671 #[test]
672 fn rejects_invalid_backend() {
673 let args = [
674 OsString::from("generate"),
675 OsString::from("--backend"),
676 OsString::from("invalid"),
677 OsString::from("test prompt"),
678 ];
679 assert!(parse_options(args).is_err());
680 }
681
682 #[test]
683 fn parses_cfg_scale_validation() {
684 let args = [
685 OsString::from("generate"),
686 OsString::from("--cfg-scale"),
687 OsString::from("2.5"),
688 OsString::from("test prompt"),
689 ];
690 let options = parse_options(args).expect("options should parse");
691 assert_eq!(options.cfg_scale, 2.5);
692 }
693
694 #[test]
695 fn rejects_negative_cfg_scale() {
696 let args = [
697 OsString::from("generate"),
698 OsString::from("--cfg-scale"),
699 OsString::from("-1.0"),
700 OsString::from("test prompt"),
701 ];
702 assert!(parse_options(args).is_err());
703 }
704
705 #[test]
706 fn rejects_invalid_cfg_scale() {
707 let args = [
708 OsString::from("generate"),
709 OsString::from("--cfg-scale"),
710 OsString::from("not-a-number"),
711 OsString::from("test prompt"),
712 ];
713 assert!(parse_options(args).is_err());
714 }
715
716 #[test]
717 fn parses_temperature() {
718 let args = [
719 OsString::from("generate"),
720 OsString::from("--temperature"),
721 OsString::from("0.8"),
722 OsString::from("test prompt"),
723 ];
724 let options = parse_options(args).expect("options should parse");
725 assert_eq!(options.temperature, 0.8);
726 }
727
728 #[test]
729 fn rejects_negative_temperature() {
730 let args = [
731 OsString::from("generate"),
732 OsString::from("--temperature"),
733 OsString::from("-0.5"),
734 OsString::from("test prompt"),
735 ];
736 assert!(parse_options(args).is_err());
737 }
738
739 #[test]
740 fn parses_topk() {
741 let args = [
742 OsString::from("generate"),
743 OsString::from("--topk"),
744 OsString::from("25"),
745 OsString::from("test prompt"),
746 ];
747 let options = parse_options(args).expect("options should parse");
748 assert_eq!(options.topk, 25);
749 }
750
751 #[test]
752 fn rejects_zero_topk() {
753 let args = [
754 OsString::from("generate"),
755 OsString::from("--topk"),
756 OsString::from("0"),
757 OsString::from("test prompt"),
758 ];
759 assert!(parse_options(args).is_err());
760 }
761
762 #[test]
763 fn parses_ode_steps() {
764 let args = [
765 OsString::from("generate"),
766 OsString::from("--ode-steps"),
767 OsString::from("15"),
768 OsString::from("test prompt"),
769 ];
770 let options = parse_options(args).expect("options should parse");
771 assert_eq!(options.ode_steps, 15);
772 }
773
774 #[test]
775 fn rejects_zero_ode_steps() {
776 let args = [
777 OsString::from("generate"),
778 OsString::from("--ode-steps"),
779 OsString::from("0"),
780 OsString::from("test prompt"),
781 ];
782 assert!(parse_options(args).is_err());
783 }
784
785 #[test]
786 fn rejects_too_many_ode_steps() {
787 let args = [
788 OsString::from("generate"),
789 OsString::from("--ode-steps"),
790 OsString::from("51"),
791 OsString::from("test prompt"),
792 ];
793 assert!(parse_options(args).is_err());
794 }
795
796 #[test]
797 fn parses_output_path() {
798 let args = [
799 OsString::from("generate"),
800 OsString::from("--output"),
801 OsString::from("/tmp/output.wav"),
802 OsString::from("test prompt"),
803 ];
804 let options = parse_options(args).expect("options should parse");
805 assert_eq!(
806 options.output_path,
807 std::path::PathBuf::from("/tmp/output.wav")
808 );
809 }
810
811 #[test]
812 fn parses_model_dir() {
813 let args = [
814 OsString::from("generate"),
815 OsString::from("--model-dir"),
816 OsString::from("/tmp/models"),
817 OsString::from("test prompt"),
818 ];
819 let options = parse_options(args).expect("options should parse");
820 assert_eq!(
821 options.model_dir,
822 Some(std::path::PathBuf::from("/tmp/models"))
823 );
824 }
825
826 #[test]
827 fn parses_decoder_seed() {
828 let args = [
829 OsString::from("generate"),
830 OsString::from("--decoder-seed"),
831 OsString::from("42"),
832 OsString::from("test prompt"),
833 ];
834 let options = parse_options(args).expect("options should parse");
835 assert_eq!(options.decoder_seed, 42);
836 }
837
838 #[test]
839 fn parses_lyrics_alias() {
840 let args = [
841 OsString::from("generate"),
842 OsString::from("--lyrics"),
843 OsString::from("custom lyrics text"),
844 ];
845 let options = parse_options(args).expect("options should parse");
846 assert_eq!(options.prompt, "custom lyrics text");
847 }
848
849 #[test]
850 fn parses_inspect_flag() {
851 let args = [
852 OsString::from("generate"),
853 OsString::from("--inspect"),
854 OsString::from("test prompt"),
855 ];
856 let options = parse_options(args).expect("options should parse");
857 assert!(options.inspect_only);
858 }
859
860 #[test]
861 fn rejects_multiple_positional_args() {
862 let args = [
863 OsString::from("generate"),
864 OsString::from("first prompt"),
865 OsString::from("second prompt"),
866 ];
867 assert!(parse_options(args).is_err());
868 }
869
870 #[test]
871 fn rejects_empty_prompt() {
872 let args = [OsString::from("generate"), OsString::from(" ")];
873 assert!(parse_options(args).is_err());
874 }
875
876 #[test]
877 fn validate_options_trims_prompt() {
878 let options = super::CliOptions {
879 model: ModelChoice::HappyNewYear,
880 prompt: " test prompt ".to_owned(),
881 model_dir: None,
882 output_path: std::path::PathBuf::from("output.wav"),
883 inspect_only: false,
884 backend: BackendChoice::Vulkan,
885 cfg_scale: 1.5,
886 length: 6000,
887 ode_steps: 10,
888 lyrics: None,
889 tags: None,
890 topk: 50,
891 temperature: 1.0,
892 decode_only: false,
893 frames_json: None,
894 decode_threads: None,
895 decoder_seed: 0,
896 };
897 let validated = super::validate_options(options).expect("validation should pass");
898 assert_eq!(validated.prompt, "test prompt");
899 }
900
901 #[test]
902 fn validate_options_rejects_empty_output_path() {
903 let options = super::CliOptions {
904 model: ModelChoice::HappyNewYear,
905 prompt: "test".to_owned(),
906 model_dir: None,
907 output_path: std::path::PathBuf::from(""),
908 inspect_only: false,
909 backend: BackendChoice::Vulkan,
910 cfg_scale: 1.5,
911 length: 6000,
912 ode_steps: 10,
913 lyrics: None,
914 tags: None,
915 topk: 50,
916 temperature: 1.0,
917 decode_only: false,
918 frames_json: None,
919 decode_threads: None,
920 decoder_seed: 0,
921 };
922 assert!(super::validate_options(options).is_err());
923 }
924
925 #[test]
926 fn validate_options_rejects_zero_length() {
927 let options = super::CliOptions {
928 model: ModelChoice::HappyNewYear,
929 prompt: "test".to_owned(),
930 model_dir: None,
931 output_path: std::path::PathBuf::from("output.wav"),
932 inspect_only: false,
933 backend: BackendChoice::Vulkan,
934 cfg_scale: 1.5,
935 length: 0,
936 ode_steps: 10,
937 lyrics: None,
938 tags: None,
939 topk: 50,
940 temperature: 1.0,
941 decode_only: false,
942 frames_json: None,
943 decode_threads: None,
944 decoder_seed: 0,
945 };
946 assert!(super::validate_options(options).is_err());
947 }
948
949 #[test]
950 fn validate_options_rejects_decode_only_without_frames() {
951 let options = super::CliOptions {
952 model: ModelChoice::HappyNewYear,
953 prompt: "".to_owned(),
954 model_dir: None,
955 output_path: std::path::PathBuf::from("output.wav"),
956 inspect_only: false,
957 backend: BackendChoice::Vulkan,
958 cfg_scale: 1.5,
959 length: 6000,
960 ode_steps: 10,
961 lyrics: None,
962 tags: None,
963 topk: 50,
964 temperature: 1.0,
965 decode_only: true,
966 frames_json: None,
967 decode_threads: None,
968 decoder_seed: 0,
969 };
970 assert!(super::validate_options(options).is_err());
971 }
972
973 #[test]
974 fn validate_options_rejects_zero_decode_threads() {
975 let options = super::CliOptions {
976 model: ModelChoice::HappyNewYear,
977 prompt: "test".to_owned(),
978 model_dir: None,
979 output_path: std::path::PathBuf::from("output.wav"),
980 inspect_only: false,
981 backend: BackendChoice::Vulkan,
982 cfg_scale: 1.5,
983 length: 6000,
984 ode_steps: 10,
985 lyrics: None,
986 tags: None,
987 topk: 50,
988 temperature: 1.0,
989 decode_only: false,
990 frames_json: None,
991 decode_threads: Some(0),
992 decoder_seed: 0,
993 };
994 assert!(super::validate_options(options).is_err());
995 }
996
997 #[test]
998 fn validate_options_trims_tags() {
999 let options = super::CliOptions {
1000 model: ModelChoice::HappyNewYear,
1001 prompt: "test".to_owned(),
1002 model_dir: None,
1003 output_path: std::path::PathBuf::from("output.wav"),
1004 inspect_only: false,
1005 backend: BackendChoice::Vulkan,
1006 cfg_scale: 1.5,
1007 length: 6000,
1008 ode_steps: 10,
1009 lyrics: None,
1010 tags: Some(" tag1, tag2 ".to_owned()),
1011 topk: 50,
1012 temperature: 1.0,
1013 decode_only: false,
1014 frames_json: None,
1015 decode_threads: None,
1016 decoder_seed: 0,
1017 };
1018 let validated = super::validate_options(options).expect("validation should pass");
1019 assert_eq!(validated.tags, Some("tag1, tag2".to_owned()));
1020 }
1021
1022 #[test]
1023 fn validate_options_filters_empty_tags() {
1024 let options = super::CliOptions {
1025 model: ModelChoice::HappyNewYear,
1026 prompt: "test".to_owned(),
1027 model_dir: None,
1028 output_path: std::path::PathBuf::from("output.wav"),
1029 inspect_only: false,
1030 backend: BackendChoice::Vulkan,
1031 cfg_scale: 1.5,
1032 length: 6000,
1033 ode_steps: 10,
1034 lyrics: None,
1035 tags: Some(" ".to_owned()),
1036 topk: 50,
1037 temperature: 1.0,
1038 decode_only: false,
1039 frames_json: None,
1040 decode_threads: None,
1041 decoder_seed: 0,
1042 };
1043 let validated = super::validate_options(options).expect("validation should pass");
1044 assert_eq!(validated.tags, None);
1045 }
1046
1047 #[test]
1048 fn default_output_path_is_output_wav() {
1049 let args = [OsString::from("generate"), OsString::from("test prompt")];
1050 let options = parse_options(args).expect("options should parse");
1051 assert_eq!(options.output_path, std::path::PathBuf::from("output.wav"));
1052 }
1053
1054 #[test]
1055 fn default_length_is_6000() {
1056 let args = [OsString::from("generate"), OsString::from("test prompt")];
1057 let options = parse_options(args).expect("options should parse");
1058 assert_eq!(options.length, 6000);
1059 }
1060
1061 #[test]
1062 fn default_ode_steps_is_10() {
1063 let args = [OsString::from("generate"), OsString::from("test prompt")];
1064 let options = parse_options(args).expect("options should parse");
1065 assert_eq!(options.ode_steps, 10);
1066 }
1067
1068 #[test]
1069 fn default_topk_is_50() {
1070 let args = [OsString::from("generate"), OsString::from("test prompt")];
1071 let options = parse_options(args).expect("options should parse");
1072 assert_eq!(options.topk, 50);
1073 }
1074
1075 #[test]
1076 fn default_temperature_is_1() {
1077 let args = [OsString::from("generate"), OsString::from("test prompt")];
1078 let options = parse_options(args).expect("options should parse");
1079 assert_eq!(options.temperature, 1.0);
1080 }
1081
1082 #[test]
1083 fn default_cfg_scale_is_1_5() {
1084 let args = [OsString::from("generate"), OsString::from("test prompt")];
1085 let options = parse_options(args).expect("options should parse");
1086 assert_eq!(options.cfg_scale, 1.5);
1087 }
1088
1089 #[test]
1090 fn default_decoder_seed_is_0() {
1091 let args = [OsString::from("generate"), OsString::from("test prompt")];
1092 let options = parse_options(args).expect("options should parse");
1093 assert_eq!(options.decoder_seed, 0);
1094 }
1095
1096 #[test]
1097 fn help_text_contains_usage() {
1098 let help = super::help_text();
1099 assert!(help.contains("maolan-generate"));
1100 assert!(help.contains("Usage:"));
1101 assert!(help.contains("Options:"));
1102 }
1103
1104 #[test]
1105 fn stderr_logging_disabled_in_ipc_mode() {
1106 let _ = super::stderr_logging_enabled();
1107 }
1108
1109 #[test]
1110 fn write_and_read_ipc_message_roundtrip() {
1111 use super::{read_ipc_message, write_ipc_message};
1112 use std::io::Cursor;
1113
1114 let original = super::GenerateResponseHeader {
1115 backend: BackendChoice::Cpu,
1116 channels: 2,
1117 frames: 48000,
1118 guidance_scale: 2.0,
1119 prompt_tokens: 10,
1120 sample_rate_hz: 48000,
1121 length: 6000,
1122 steps: 10,
1123 };
1124
1125 let mut buffer = Vec::new();
1126 write_ipc_message(&mut buffer, &original).expect("write should succeed");
1127
1128 let mut cursor = Cursor::new(buffer);
1129 let decoded: super::GenerateResponseHeader =
1130 read_ipc_message(&mut cursor).expect("read should succeed");
1131
1132 assert_eq!(decoded.backend, original.backend);
1133 assert_eq!(decoded.channels, original.channels);
1134 assert_eq!(decoded.frames, original.frames);
1135 assert_eq!(decoded.guidance_scale, original.guidance_scale);
1136 assert_eq!(decoded.prompt_tokens, original.prompt_tokens);
1137 assert_eq!(decoded.sample_rate_hz, original.sample_rate_hz);
1138 assert_eq!(decoded.length, original.length);
1139 assert_eq!(decoded.steps, original.steps);
1140 }
1141
1142 #[test]
1143 fn write_and_read_ipc_progress_roundtrip() {
1144 use super::{read_ipc_message, write_ipc_message};
1145 use std::io::Cursor;
1146
1147 let original = super::GenerateProgress {
1148 phase: "generator".to_owned(),
1149 progress: 0.5,
1150 operation: "Processing".to_owned(),
1151 };
1152
1153 let mut buffer = Vec::new();
1154 write_ipc_message(&mut buffer, &original).expect("write should succeed");
1155
1156 let mut cursor = Cursor::new(buffer);
1157 let decoded: super::GenerateProgress =
1158 read_ipc_message(&mut cursor).expect("read should succeed");
1159
1160 assert_eq!(decoded.phase, original.phase);
1161 assert_eq!(decoded.progress, original.progress);
1162 assert_eq!(decoded.operation, original.operation);
1163 }
1164
1165 #[test]
1166 fn write_and_read_ipc_error_roundtrip() {
1167 use super::{read_ipc_message, write_ipc_message};
1168 use std::io::Cursor;
1169
1170 let original = super::GenerateError {
1171 error: "Test error message".to_owned(),
1172 };
1173
1174 let mut buffer = Vec::new();
1175 write_ipc_message(&mut buffer, &original).expect("write should succeed");
1176
1177 let mut cursor = Cursor::new(buffer);
1178 let decoded: super::GenerateError =
1179 read_ipc_message(&mut cursor).expect("read should succeed");
1180
1181 assert_eq!(decoded.error, original.error);
1182 }
1183
1184 #[test]
1185 fn write_ipc_bytes_roundtrip() {
1186 use super::write_ipc_bytes;
1187 use std::io::Cursor;
1188
1189 let original = b"Hello, World!";
1190
1191 let mut buffer = Vec::new();
1192 write_ipc_bytes(&mut buffer, original).expect("write should succeed");
1193
1194 let mut cursor = Cursor::new(buffer);
1195 let mut len_bytes = [0_u8; 8];
1196 std::io::Read::read_exact(&mut cursor, &mut len_bytes).expect("read length should succeed");
1197 let len = u64::from_le_bytes(len_bytes) as usize;
1198 assert_eq!(len, original.len());
1199
1200 let mut payload = vec![0_u8; len];
1201 std::io::Read::read_exact(&mut cursor, &mut payload).expect("read payload should succeed");
1202 assert_eq!(&payload[..], &original[..]);
1203 }
1204
1205 #[test]
1206 fn read_ipc_message_fails_on_truncated_data() {
1207 use super::read_ipc_message;
1208 use std::io::Cursor;
1209
1210 let len_bytes = 100_u64.to_le_bytes();
1211 let buffer = len_bytes.to_vec();
1212
1213 let mut cursor = Cursor::new(buffer);
1214 let result: Result<super::GenerateResponseHeader, _> = read_ipc_message(&mut cursor);
1215 assert!(result.is_err());
1216 }
1217
1218 #[test]
1219 fn read_ipc_message_fails_on_invalid_json() {
1220 use super::read_ipc_message;
1221 use std::io::Cursor;
1222
1223 let payload = b"not valid json";
1224 let len_bytes = (payload.len() as u64).to_le_bytes();
1225 let mut buffer = Vec::new();
1226 buffer.extend_from_slice(&len_bytes);
1227 buffer.extend_from_slice(payload);
1228
1229 let mut cursor = Cursor::new(buffer);
1230 let result: Result<super::GenerateResponseHeader, _> = read_ipc_message(&mut cursor);
1231 assert!(result.is_err());
1232 }
1233
1234 #[test]
1235 fn serialize_generate_request() {
1236 let request = super::GenerateRequest {
1237 model: ModelChoice::Rl,
1238 prompt: "test prompt".to_owned(),
1239 model_dir: Some(std::path::PathBuf::from("/tmp/models")),
1240 output_path: std::path::PathBuf::from("/tmp/output.wav"),
1241 inspect_only: true,
1242 backend: BackendChoice::Cpu,
1243 cfg_scale: 2.5,
1244 length: 8000,
1245 ode_steps: 15,
1246 lyrics: Some("lyrics text".to_owned()),
1247 tags: Some("tag1,tag2".to_owned()),
1248 topk: 25,
1249 temperature: 0.8,
1250 decode_only: false,
1251 frames_json: None,
1252 decode_threads: Some(4),
1253 decoder_seed: 42,
1254 };
1255
1256 let json = serde_json::to_string(&request).expect("serialization should succeed");
1257 assert!(json.contains("test prompt"));
1258 assert!(json.contains("cpu"));
1259 assert!(json.contains("RL"));
1260 }
1261
1262 #[test]
1263 fn deserialize_generate_request() {
1264 let json = r#"{
1265 "model": "RL",
1266 "prompt": "test prompt",
1267 "output_path": "/tmp/output.wav",
1268 "backend": "cpu",
1269 "cfg_scale": 2.5,
1270 "length": 8000,
1271 "ode_steps": 15,
1272 "topk": 25,
1273 "temperature": 0.8,
1274 "decoder_seed": 42
1275 }"#;
1276
1277 let request: super::GenerateRequest =
1278 serde_json::from_str(json).expect("deserialization should succeed");
1279 assert_eq!(request.model, ModelChoice::Rl);
1280 assert_eq!(request.prompt, "test prompt");
1281 assert_eq!(request.backend, BackendChoice::Cpu);
1282 assert_eq!(request.cfg_scale, 2.5);
1283 assert_eq!(request.length, 8000);
1284 assert_eq!(request.ode_steps, 15);
1285 assert_eq!(request.topk, 25);
1286 assert_eq!(request.temperature, 0.8);
1287 assert_eq!(request.decoder_seed, 42);
1288 }
1289
1290 #[test]
1291 fn deserialize_generate_request_with_aliases() {
1292 let json1 =
1293 r#"{"prompt": "test", "backend": "cpu", "cfg_scale": 1.5, "seconds_total": 5000}"#;
1294 let request1: super::GenerateRequest =
1295 serde_json::from_str(json1).expect("deserialization should succeed");
1296 assert_eq!(request1.length, 5000);
1297
1298 let json2 = r#"{"prompt": "test", "backend": "cpu", "cfg_scale": 1.5, "max_audio_length_ms": 7000}"#;
1299 let request2: super::GenerateRequest =
1300 serde_json::from_str(json2).expect("deserialization should succeed");
1301 assert_eq!(request2.length, 7000);
1302 }
1303
1304 #[test]
1305 fn backend_choice_default_is_vulkan() {
1306 let default: BackendChoice = Default::default();
1307 assert_eq!(default, BackendChoice::Vulkan);
1308 }
1309
1310 #[test]
1311 fn model_choice_default_is_happy_new_year() {
1312 let default: ModelChoice = Default::default();
1313 assert_eq!(default, ModelChoice::HappyNewYear);
1314 }
1315
1316 #[test]
1317 fn default_output_path_function() {
1318 let path = super::default_output_path();
1319 assert_eq!(path, std::path::PathBuf::from("output.wav"));
1320 }
1321
1322 #[test]
1323 fn tokenizer_path_returns_valid_path() {
1324 let path = super::tokenizer_path();
1325 assert!(path.to_string_lossy().contains("t5-base-spiece.model"));
1326 }
1327}