Skip to main content

maolan_generate/
lib.rs

1use anyhow::{Context, Result, anyhow, bail};
2use sentencepiece::SentencePieceProcessor;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use std::ffi::OsString;
5use std::io::{Read, Write};
6use std::path::{Path, PathBuf};
7
8pub mod heartcodec;
9pub mod heartmula_runtime;
10
/// Default prompt-token budget.
///
/// NOTE(review): presumably the `max_tokens` value callers pass to `encode_prompt`;
/// confirm against the call sites.
pub const DEFAULT_MAX_PROMPT_TOKENS: usize = 128_usize;
/// CFG scale that `parse_options` starts from when `--cfg-scale` is not supplied.
pub const DEFAULT_CFG_SCALE: f32 = 1.5_f32;
/// Name of the environment variable inspected by `stderr_logging_enabled`.
pub const IPC_MODE_ENV: &'static str = "MAOLAN_BURN_SOCKETPAIR";
14
15pub fn stderr_logging_enabled() -> bool {
16    std::env::var_os(IPC_MODE_ENV).is_none()
17}
18
19#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
20#[serde(rename_all = "lowercase")]
21pub enum BackendChoice {
22    Cpu,
23    #[default]
24    Vulkan,
25}
26
27#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
28pub enum ModelChoice {
29    #[serde(rename = "happy-new-year")]
30    #[default]
31    HappyNewYear,
32    #[serde(rename = "RL")]
33    Rl,
34}
35
36#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
37pub struct GenerateRequest {
38    #[serde(default)]
39    pub model: ModelChoice,
40    pub prompt: String,
41    #[serde(default)]
42    pub model_dir: Option<PathBuf>,
43    #[serde(default = "default_output_path")]
44    pub output_path: PathBuf,
45    #[serde(default)]
46    pub inspect_only: bool,
47    pub backend: BackendChoice,
48    pub cfg_scale: f32,
49    #[serde(alias = "seconds_total", alias = "max_audio_length_ms")]
50    pub length: usize,
51    /// ODE steps for HeartMula flow matching (lower = faster, 10 = default)
52    #[serde(default = "default_ode_steps")]
53    pub ode_steps: usize,
54    /// Lyrics prompt (alias for prompt)
55    #[serde(default)]
56    pub lyrics: Option<String>,
57    /// Tags / style prompt
58    #[serde(default)]
59    pub tags: Option<String>,
60    /// Top-k sampling for HeartMula token generation
61    #[serde(default = "default_topk")]
62    pub topk: usize,
63    /// Sampling temperature for HeartMula token generation
64    #[serde(default = "default_temperature")]
65    pub temperature: f32,
66    /// Decode an existing frames JSON instead of generating tokens
67    #[serde(default)]
68    pub decode_only: bool,
69    /// Input frames JSON for decode-only mode
70    #[serde(default)]
71    pub frames_json: Option<PathBuf>,
72    /// Number of worker threads to use for decode-only CPU decoding
73    #[serde(default)]
74    pub decode_threads: Option<usize>,
75    /// Seed for deterministic HeartCodec decoder latent initialization
76    #[serde(default)]
77    pub decoder_seed: u64,
78}
79
/// Serde default for `GenerateRequest::ode_steps`: 10 steps.
fn default_ode_steps() -> usize {
    10_usize
}
83
/// Default top-k value (50), used both by serde and by `parse_options`.
fn default_topk() -> usize {
    50_usize
}
87
/// Default sampling temperature (1.0), used both by serde and by `parse_options`.
fn default_temperature() -> f32 {
    1.0_f32
}
91
92pub type CliOptions = GenerateRequest;
93
94#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
95pub struct GenerateResponseHeader {
96    pub backend: BackendChoice,
97    pub channels: usize,
98    pub frames: usize,
99    pub guidance_scale: f32,
100    pub prompt_tokens: i64,
101    pub sample_rate_hz: u32,
102    pub length: usize,
103    pub steps: usize,
104}
105
106#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
107pub struct GenerateError {
108    pub error: String,
109}
110
111/// Progress update message sent during generation
112#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
113pub struct GenerateProgress {
114    pub phase: String,
115    pub progress: f32,
116    pub operation: String,
117}
118
/// Default output location: the relative path `output.wav`.
fn default_output_path() -> PathBuf {
    "output.wav".into()
}
122
/// Returns the static usage/help text shown for `-h` / `--help`.
///
/// The text is assembled at compile time, one source line per output line, and
/// ends with a trailing newline.
pub fn help_text() -> &'static str {
    concat!(
        "maolan-generate\n",
        "\n",
        "Usage:\n",
        "  maolan-generate [options] <prompt-or-lyrics>\n",
        "\n",
        "Options:\n",
        "  --model <happy-new-year|RL>\n",
        "  --model-dir <path>\n",
        "  --output <path>\n",
        "  --inspect\n",
        "  --backend <cpu|vulkan>       Select the runtime backend\n",
        "  --lyrics <text>          Prompt / lyrics (positional argument also accepted)\n",
        "  --tags <text>            Style tags for HeartMula\n",
        "  --cfg-scale <float>      CFG scale (1.0=no guidance, 2.0=weak, 6.0=strong)\n",
        "  --length <int>           HeartMula: output length in milliseconds\n",
        "  --topk <int>             HeartMula: top-k sampling (default: 50)\n",
        "  --temperature <float>    HeartMula: sampling temperature (default: 1.0)\n",
        "  --ode-steps <int>        HeartMula: flow matching steps (5=fast, 10=default, 20=best)\n",
        "  --decoder-seed <int>     Seed for deterministic HeartCodec decoder latents\n",
        "  --decode-only            Decode an existing frames JSON instead of generating tokens\n",
        "  --frames-json <path>     Frames JSON input for --decode-only\n",
        "  --decode-threads <int>    Number of worker threads for decode-only CPU decoding\n",
        "  -h, --help\n",
    )
}
150
151pub fn parse_options(args: impl IntoIterator<Item = OsString>) -> Result<CliOptions> {
152    let mut args = args.into_iter();
153    let _program = args.next();
154    let mut prompt = None;
155    let mut model_dir = None;
156    let mut output_path = default_output_path();
157    let mut inspect_only = false;
158    let mut model = ModelChoice::HappyNewYear;
159    let mut backend = BackendChoice::Vulkan;
160    let mut cfg_scale = DEFAULT_CFG_SCALE;
161    let mut length = 6_000_usize;
162    let mut ode_steps = 10_usize;
163    let mut lyrics = None;
164    let mut tags = None;
165    let mut topk = default_topk();
166    let mut temperature = default_temperature();
167    let mut decode_only = false;
168    let mut frames_json = None;
169    let mut decode_threads = None;
170    let mut decoder_seed = 0_u64;
171
172    while let Some(arg) = args.next() {
173        let arg = arg
174            .into_string()
175            .map_err(|_| anyhow!("arguments must be valid UTF-8"))?;
176
177        if matches!(arg.as_str(), "-h" | "--help") {
178            bail!(help_text());
179        }
180
181        if arg == "--backend" {
182            let value = args
183                .next()
184                .ok_or_else(|| anyhow!("missing value after --backend"))?
185                .into_string()
186                .map_err(|_| anyhow!("backend value must be valid UTF-8"))?;
187            backend = match value.as_str() {
188                "cpu" => BackendChoice::Cpu,
189                "vulkan" => BackendChoice::Vulkan,
190                _ => bail!("unsupported backend '{value}', expected one of: cpu, vulkan"),
191            };
192            continue;
193        }
194
195        if arg == "--model-dir" {
196            model_dir = Some(PathBuf::from(
197                args.next()
198                    .ok_or_else(|| anyhow!("missing value after --model-dir"))?,
199            ));
200            continue;
201        }
202
203        if arg == "--output" {
204            output_path = PathBuf::from(
205                args.next()
206                    .ok_or_else(|| anyhow!("missing value after --output"))?,
207            );
208            continue;
209        }
210
211        if arg == "--lyrics" {
212            let value = args
213                .next()
214                .ok_or_else(|| anyhow!("missing value after --lyrics"))?
215                .into_string()
216                .map_err(|_| anyhow!("lyrics value must be valid UTF-8"))?;
217            lyrics = Some(value);
218            continue;
219        }
220
221        if arg == "--tags" {
222            let value = args
223                .next()
224                .ok_or_else(|| anyhow!("missing value after --tags"))?
225                .into_string()
226                .map_err(|_| anyhow!("tags value must be valid UTF-8"))?;
227            tags = Some(value);
228            continue;
229        }
230
231        if arg == "--length" {
232            let value = args
233                .next()
234                .ok_or_else(|| anyhow!("missing value after --length"))?
235                .into_string()
236                .map_err(|_| anyhow!("length value must be valid UTF-8"))?;
237            length = value
238                .parse::<usize>()
239                .map_err(|_| anyhow!("length must be a whole number"))?;
240            continue;
241        }
242
243        if arg == "--topk" {
244            let value = args
245                .next()
246                .ok_or_else(|| anyhow!("missing value after --topk"))?
247                .into_string()
248                .map_err(|_| anyhow!("topk value must be valid UTF-8"))?;
249            topk = value
250                .parse::<usize>()
251                .map_err(|_| anyhow!("topk must be a whole number"))?;
252            if topk == 0 {
253                bail!("topk must be greater than zero");
254            }
255            continue;
256        }
257
258        if arg == "--temperature" {
259            let value = args
260                .next()
261                .ok_or_else(|| anyhow!("missing value after --temperature"))?
262                .into_string()
263                .map_err(|_| anyhow!("temperature value must be valid UTF-8"))?;
264            temperature = value
265                .parse::<f32>()
266                .map_err(|_| anyhow!("temperature must be a number"))?;
267            if !temperature.is_finite() || temperature < 0.0 {
268                bail!("temperature must be a finite non-negative number");
269            }
270            continue;
271        }
272
273        if arg == "--inspect" {
274            inspect_only = true;
275            continue;
276        }
277
278        if arg == "--decode-only" {
279            decode_only = true;
280            continue;
281        }
282
283        if arg == "--frames-json" {
284            frames_json =
285                Some(PathBuf::from(args.next().ok_or_else(|| {
286                    anyhow!("missing value after --frames-json")
287                })?));
288            continue;
289        }
290
291        if arg == "--decode-threads" {
292            let value = args
293                .next()
294                .ok_or_else(|| anyhow!("missing value after --decode-threads"))?
295                .into_string()
296                .map_err(|_| anyhow!("decode-threads value must be valid UTF-8"))?;
297            decode_threads = Some(
298                value
299                    .parse::<usize>()
300                    .map_err(|_| anyhow!("decode-threads must be a whole number"))?,
301            );
302            continue;
303        }
304
305        if arg == "--decoder-seed" {
306            let value = args
307                .next()
308                .ok_or_else(|| anyhow!("missing value after --decoder-seed"))?
309                .into_string()
310                .map_err(|_| anyhow!("decoder-seed value must be valid UTF-8"))?;
311            decoder_seed = value
312                .parse::<u64>()
313                .map_err(|_| anyhow!("decoder-seed must be a whole number"))?;
314            continue;
315        }
316
317        if arg == "--model" {
318            let value = args
319                .next()
320                .ok_or_else(|| anyhow!("missing value after --model"))?
321                .into_string()
322                .map_err(|_| anyhow!("model value must be valid UTF-8"))?;
323            model = match value.as_str() {
324                "happy-new-year" => ModelChoice::HappyNewYear,
325                "RL" => ModelChoice::Rl,
326                _ => {
327                    bail!("unsupported model '{value}', expected one of: happy-new-year, RL")
328                }
329            };
330            continue;
331        }
332
333        if arg == "--cfg-scale" {
334            let value = args
335                .next()
336                .ok_or_else(|| anyhow!("missing value after --cfg-scale"))?
337                .into_string()
338                .map_err(|_| anyhow!("cfg-scale value must be valid UTF-8"))?;
339            cfg_scale = value
340                .parse::<f32>()
341                .map_err(|_| anyhow!("cfg-scale must be a number"))?;
342            if !cfg_scale.is_finite() || cfg_scale < 0.0 {
343                bail!("cfg-scale must be a finite non-negative number");
344            }
345            continue;
346        }
347
348        if arg == "--ode-steps" {
349            let value = args
350                .next()
351                .ok_or_else(|| anyhow!("missing value after --ode-steps"))?
352                .into_string()
353                .map_err(|_| anyhow!("ode-steps value must be valid UTF-8"))?;
354            ode_steps = value
355                .parse::<usize>()
356                .map_err(|_| anyhow!("ode-steps must be a whole number"))?;
357            if ode_steps == 0 || ode_steps > 50 {
358                bail!("ode-steps must be between 1 and 50");
359            }
360            continue;
361        }
362
363        if prompt.is_some() {
364            bail!("expected exactly one positional argument: the prompt");
365        }
366        prompt = Some(arg);
367    }
368
369    let prompt = if decode_only {
370        prompt.unwrap_or_default()
371    } else if let Some(lyrics) = lyrics {
372        lyrics
373    } else {
374        prompt.ok_or_else(|| {
375            anyhow!("missing prompt argument; provide a positional argument or --lyrics")
376        })?
377    };
378    let trimmed = prompt.trim();
379
380    if !decode_only && trimmed.is_empty() {
381        bail!("prompt argument cannot be empty");
382    }
383
384    validate_options(CliOptions {
385        model,
386        prompt: trimmed.to_owned(),
387        model_dir,
388        output_path,
389        inspect_only,
390        backend,
391        cfg_scale,
392        length,
393        ode_steps,
394        lyrics: None,
395        tags,
396        topk,
397        temperature,
398        decode_only,
399        frames_json,
400        decode_threads,
401        decoder_seed,
402    })
403}
404
405pub fn validate_options(mut options: CliOptions) -> Result<CliOptions> {
406    let prompt = options.prompt.trim();
407    if prompt.is_empty() && !options.decode_only {
408        bail!("prompt argument cannot be empty");
409    }
410    options.prompt = prompt.to_owned();
411
412    options.tags = options
413        .tags
414        .as_deref()
415        .map(str::trim)
416        .filter(|value| !value.is_empty())
417        .map(str::to_owned);
418
419    options.model_dir = options
420        .model_dir
421        .as_deref()
422        .map(Path::new)
423        .map(Path::to_path_buf);
424
425    if !options.cfg_scale.is_finite() || options.cfg_scale < 0.0 {
426        bail!("cfg-scale must be a finite non-negative number");
427    }
428    if options.length == 0 {
429        bail!("length must be greater than zero");
430    }
431    if options.output_path.as_os_str().is_empty() {
432        bail!("output path cannot be empty");
433    }
434    if options.decode_only && options.frames_json.is_none() {
435        bail!("--decode-only requires --frames-json");
436    }
437    if options.frames_json.is_some() && !options.decode_only {
438        bail!("--frames-json can only be used with --decode-only");
439    }
440    if let Some(threads) = options.decode_threads
441        && threads == 0
442    {
443        bail!("--decode-threads must be greater than zero");
444    }
445
446    Ok(options)
447}
448
449pub fn read_ipc_message<T: DeserializeOwned>(reader: &mut impl Read) -> Result<T> {
450    let mut len_bytes = [0_u8; 8];
451    reader
452        .read_exact(&mut len_bytes)
453        .context("failed to read IPC message length")?;
454    let len = u64::from_le_bytes(len_bytes);
455    let len = usize::try_from(len).context("IPC message length is too large")?;
456    let mut payload = vec![0_u8; len];
457    reader
458        .read_exact(&mut payload)
459        .context("failed to read IPC message payload")?;
460    serde_json::from_slice(&payload).context("failed to decode IPC JSON message")
461}
462
463pub fn write_ipc_message<T: Serialize>(writer: &mut impl Write, value: &T) -> Result<()> {
464    let payload = serde_json::to_vec(value).context("failed to encode IPC JSON message")?;
465    let len = u64::try_from(payload.len()).context("IPC payload is too large")?;
466    writer
467        .write_all(&len.to_le_bytes())
468        .context("failed to write IPC message length")?;
469    writer
470        .write_all(&payload)
471        .context("failed to write IPC message payload")?;
472    writer.flush().context("failed to flush IPC JSON message")?;
473    Ok(())
474}
475
476pub fn write_ipc_bytes(writer: &mut impl Write, bytes: &[u8]) -> Result<()> {
477    let len = u64::try_from(bytes.len()).context("IPC byte payload is too large")?;
478    writer
479        .write_all(&len.to_le_bytes())
480        .context("failed to write IPC byte length")?;
481    writer
482        .write_all(bytes)
483        .context("failed to write IPC byte payload")?;
484    writer.flush().context("failed to flush IPC byte payload")?;
485    Ok(())
486}
487
488pub fn tokenizer_path() -> PathBuf {
489    Path::new(env!("CARGO_MANIFEST_DIR"))
490        .join("assets")
491        .join("t5-base-spiece.model")
492}
493
494pub fn load_tokenizer() -> Result<SentencePieceProcessor> {
495    SentencePieceProcessor::open(tokenizer_path())
496        .context("failed to open the bundled T5 sentencepiece model")
497}
498
499pub fn encode_prompt(
500    tokenizer: &SentencePieceProcessor,
501    prompt: &str,
502    max_tokens: usize,
503) -> Result<(Vec<i64>, Vec<i64>)> {
504    let mut token_ids = Vec::with_capacity(max_tokens);
505
506    if let Some(bos_id) = tokenizer.bos_id() {
507        token_ids.push(i64::from(bos_id));
508    }
509
510    for piece in tokenizer
511        .encode(prompt)
512        .context("failed to tokenize prompt")?
513    {
514        if token_ids.len() >= max_tokens {
515            break;
516        }
517        token_ids.push(i64::from(piece.id));
518    }
519
520    if token_ids.len() < max_tokens
521        && let Some(eos_id) = tokenizer.eos_id()
522    {
523        token_ids.push(i64::from(eos_id));
524    }
525
526    if token_ids.len() > max_tokens {
527        token_ids.truncate(max_tokens);
528    }
529
530    let attention_len = token_ids.len();
531    let mut attention_mask = vec![1_i64; attention_len];
532    token_ids.resize(max_tokens, 0);
533    attention_mask.resize(max_tokens, 0);
534
535    Ok((token_ids, attention_mask))
536}
537
538#[cfg(test)]
539mod tests {
540    use super::{BackendChoice, DEFAULT_MAX_PROMPT_TOKENS, ModelChoice, parse_options};
541    use std::ffi::OsString;
542
543    #[test]
544    fn parses_single_prompt_argument() {
545        let args = [OsString::from("generate"), OsString::from("warm tape hiss")];
546        let options = parse_options(args).expect("options should parse");
547        assert_eq!(options.prompt, "warm tape hiss");
548        assert_eq!(options.model, ModelChoice::HappyNewYear);
549        assert_eq!(options.backend, BackendChoice::Vulkan);
550        assert_eq!(options.cfg_scale, 1.5);
551        assert_eq!(options.length, 6_000);
552    }
553
554    #[test]
555    fn trims_surrounding_whitespace() {
556        let args = [
557            OsString::from("generate"),
558            OsString::from("  foley footsteps  "),
559        ];
560        let options = parse_options(args).expect("options should parse");
561        assert_eq!(options.prompt, "foley footsteps");
562    }
563
564    #[test]
565    fn rejects_missing_prompt() {
566        let args = [OsString::from("generate")];
567        assert!(parse_options(args).is_err());
568    }
569
570    #[test]
571    fn parses_backend_flag_after_prompt() {
572        let args = [
573            OsString::from("generate"),
574            OsString::from("warm tape hiss"),
575            OsString::from("--backend"),
576            OsString::from("vulkan"),
577        ];
578        let options = parse_options(args).expect("options should parse");
579        assert_eq!(options.backend, BackendChoice::Vulkan);
580    }
581
582    #[test]
583    fn parses_model_flag() {
584        let args = [
585            OsString::from("generate"),
586            OsString::from("--model"),
587            OsString::from("happy-new-year"),
588            OsString::from("verse and chorus"),
589        ];
590        let options = parse_options(args).expect("options should parse");
591        assert_eq!(options.model, ModelChoice::HappyNewYear);
592    }
593
594    #[test]
595    fn parses_rl_model_flag() {
596        let args = [
597            OsString::from("generate"),
598            OsString::from("--model"),
599            OsString::from("RL"),
600            OsString::from("verse and chorus"),
601        ];
602        let options = parse_options(args).expect("options should parse");
603        assert_eq!(options.model, ModelChoice::Rl);
604    }
605
606    #[test]
607    fn parses_tags_cfg_and_length() {
608        let args = [
609            OsString::from("generate"),
610            OsString::from("--tags"),
611            OsString::from("warm tape hiss"),
612            OsString::from("--cfg-scale"),
613            OsString::from("4.5"),
614            OsString::from("--ode-steps"),
615            OsString::from("20"),
616            OsString::from("--length"),
617            OsString::from("8000"),
618            OsString::from("verse and chorus"),
619        ];
620        let options = parse_options(args).expect("options should parse");
621        assert_eq!(options.cfg_scale, 4.5);
622        assert_eq!(options.ode_steps, 20);
623        assert_eq!(options.length, 8_000);
624    }
625
626    #[test]
627    fn parses_decode_only_without_prompt() {
628        let args = [
629            OsString::from("generate"),
630            OsString::from("--decode-only"),
631            OsString::from("--frames-json"),
632            OsString::from("/tmp/frames.json"),
633        ];
634        let options = parse_options(args).expect("options should parse");
635        assert!(options.decode_only);
636        assert_eq!(
637            options.frames_json.as_deref(),
638            Some(std::path::Path::new("/tmp/frames.json"))
639        );
640        assert!(options.prompt.is_empty());
641    }
642
643    #[test]
644    fn parses_decode_threads() {
645        let args = [
646            OsString::from("generate"),
647            OsString::from("--decode-only"),
648            OsString::from("--frames-json"),
649            OsString::from("/tmp/frames.json"),
650            OsString::from("--decode-threads"),
651            OsString::from("8"),
652        ];
653        let options = parse_options(args).expect("options should parse");
654        assert_eq!(options.decode_threads, Some(8));
655    }
656
657    const _: () = assert!(DEFAULT_MAX_PROMPT_TOKENS == 128);
658
659    #[test]
660    fn parses_cpu_backend_flag() {
661        let args = [
662            OsString::from("generate"),
663            OsString::from("--backend"),
664            OsString::from("cpu"),
665            OsString::from("test prompt"),
666        ];
667        let options = parse_options(args).expect("options should parse");
668        assert_eq!(options.backend, BackendChoice::Cpu);
669    }
670
671    #[test]
672    fn rejects_invalid_backend() {
673        let args = [
674            OsString::from("generate"),
675            OsString::from("--backend"),
676            OsString::from("invalid"),
677            OsString::from("test prompt"),
678        ];
679        assert!(parse_options(args).is_err());
680    }
681
682    #[test]
683    fn parses_cfg_scale_validation() {
684        let args = [
685            OsString::from("generate"),
686            OsString::from("--cfg-scale"),
687            OsString::from("2.5"),
688            OsString::from("test prompt"),
689        ];
690        let options = parse_options(args).expect("options should parse");
691        assert_eq!(options.cfg_scale, 2.5);
692    }
693
694    #[test]
695    fn rejects_negative_cfg_scale() {
696        let args = [
697            OsString::from("generate"),
698            OsString::from("--cfg-scale"),
699            OsString::from("-1.0"),
700            OsString::from("test prompt"),
701        ];
702        assert!(parse_options(args).is_err());
703    }
704
705    #[test]
706    fn rejects_invalid_cfg_scale() {
707        let args = [
708            OsString::from("generate"),
709            OsString::from("--cfg-scale"),
710            OsString::from("not-a-number"),
711            OsString::from("test prompt"),
712        ];
713        assert!(parse_options(args).is_err());
714    }
715
716    #[test]
717    fn parses_temperature() {
718        let args = [
719            OsString::from("generate"),
720            OsString::from("--temperature"),
721            OsString::from("0.8"),
722            OsString::from("test prompt"),
723        ];
724        let options = parse_options(args).expect("options should parse");
725        assert_eq!(options.temperature, 0.8);
726    }
727
728    #[test]
729    fn rejects_negative_temperature() {
730        let args = [
731            OsString::from("generate"),
732            OsString::from("--temperature"),
733            OsString::from("-0.5"),
734            OsString::from("test prompt"),
735        ];
736        assert!(parse_options(args).is_err());
737    }
738
739    #[test]
740    fn parses_topk() {
741        let args = [
742            OsString::from("generate"),
743            OsString::from("--topk"),
744            OsString::from("25"),
745            OsString::from("test prompt"),
746        ];
747        let options = parse_options(args).expect("options should parse");
748        assert_eq!(options.topk, 25);
749    }
750
751    #[test]
752    fn rejects_zero_topk() {
753        let args = [
754            OsString::from("generate"),
755            OsString::from("--topk"),
756            OsString::from("0"),
757            OsString::from("test prompt"),
758        ];
759        assert!(parse_options(args).is_err());
760    }
761
762    #[test]
763    fn parses_ode_steps() {
764        let args = [
765            OsString::from("generate"),
766            OsString::from("--ode-steps"),
767            OsString::from("15"),
768            OsString::from("test prompt"),
769        ];
770        let options = parse_options(args).expect("options should parse");
771        assert_eq!(options.ode_steps, 15);
772    }
773
774    #[test]
775    fn rejects_zero_ode_steps() {
776        let args = [
777            OsString::from("generate"),
778            OsString::from("--ode-steps"),
779            OsString::from("0"),
780            OsString::from("test prompt"),
781        ];
782        assert!(parse_options(args).is_err());
783    }
784
785    #[test]
786    fn rejects_too_many_ode_steps() {
787        let args = [
788            OsString::from("generate"),
789            OsString::from("--ode-steps"),
790            OsString::from("51"),
791            OsString::from("test prompt"),
792        ];
793        assert!(parse_options(args).is_err());
794    }
795
796    #[test]
797    fn parses_output_path() {
798        let args = [
799            OsString::from("generate"),
800            OsString::from("--output"),
801            OsString::from("/tmp/output.wav"),
802            OsString::from("test prompt"),
803        ];
804        let options = parse_options(args).expect("options should parse");
805        assert_eq!(
806            options.output_path,
807            std::path::PathBuf::from("/tmp/output.wav")
808        );
809    }
810
811    #[test]
812    fn parses_model_dir() {
813        let args = [
814            OsString::from("generate"),
815            OsString::from("--model-dir"),
816            OsString::from("/tmp/models"),
817            OsString::from("test prompt"),
818        ];
819        let options = parse_options(args).expect("options should parse");
820        assert_eq!(
821            options.model_dir,
822            Some(std::path::PathBuf::from("/tmp/models"))
823        );
824    }
825
826    #[test]
827    fn parses_decoder_seed() {
828        let args = [
829            OsString::from("generate"),
830            OsString::from("--decoder-seed"),
831            OsString::from("42"),
832            OsString::from("test prompt"),
833        ];
834        let options = parse_options(args).expect("options should parse");
835        assert_eq!(options.decoder_seed, 42);
836    }
837
838    #[test]
839    fn parses_lyrics_alias() {
840        let args = [
841            OsString::from("generate"),
842            OsString::from("--lyrics"),
843            OsString::from("custom lyrics text"),
844        ];
845        let options = parse_options(args).expect("options should parse");
846        assert_eq!(options.prompt, "custom lyrics text");
847    }
848
849    #[test]
850    fn parses_inspect_flag() {
851        let args = [
852            OsString::from("generate"),
853            OsString::from("--inspect"),
854            OsString::from("test prompt"),
855        ];
856        let options = parse_options(args).expect("options should parse");
857        assert!(options.inspect_only);
858    }
859
860    #[test]
861    fn rejects_multiple_positional_args() {
862        let args = [
863            OsString::from("generate"),
864            OsString::from("first prompt"),
865            OsString::from("second prompt"),
866        ];
867        assert!(parse_options(args).is_err());
868    }
869
870    #[test]
871    fn rejects_empty_prompt() {
872        let args = [OsString::from("generate"), OsString::from("   ")];
873        assert!(parse_options(args).is_err());
874    }
875
876    #[test]
877    fn validate_options_trims_prompt() {
878        let options = super::CliOptions {
879            model: ModelChoice::HappyNewYear,
880            prompt: "  test prompt  ".to_owned(),
881            model_dir: None,
882            output_path: std::path::PathBuf::from("output.wav"),
883            inspect_only: false,
884            backend: BackendChoice::Vulkan,
885            cfg_scale: 1.5,
886            length: 6000,
887            ode_steps: 10,
888            lyrics: None,
889            tags: None,
890            topk: 50,
891            temperature: 1.0,
892            decode_only: false,
893            frames_json: None,
894            decode_threads: None,
895            decoder_seed: 0,
896        };
897        let validated = super::validate_options(options).expect("validation should pass");
898        assert_eq!(validated.prompt, "test prompt");
899    }
900
901    #[test]
902    fn validate_options_rejects_empty_output_path() {
903        let options = super::CliOptions {
904            model: ModelChoice::HappyNewYear,
905            prompt: "test".to_owned(),
906            model_dir: None,
907            output_path: std::path::PathBuf::from(""),
908            inspect_only: false,
909            backend: BackendChoice::Vulkan,
910            cfg_scale: 1.5,
911            length: 6000,
912            ode_steps: 10,
913            lyrics: None,
914            tags: None,
915            topk: 50,
916            temperature: 1.0,
917            decode_only: false,
918            frames_json: None,
919            decode_threads: None,
920            decoder_seed: 0,
921        };
922        assert!(super::validate_options(options).is_err());
923    }
924
925    #[test]
926    fn validate_options_rejects_zero_length() {
927        let options = super::CliOptions {
928            model: ModelChoice::HappyNewYear,
929            prompt: "test".to_owned(),
930            model_dir: None,
931            output_path: std::path::PathBuf::from("output.wav"),
932            inspect_only: false,
933            backend: BackendChoice::Vulkan,
934            cfg_scale: 1.5,
935            length: 0,
936            ode_steps: 10,
937            lyrics: None,
938            tags: None,
939            topk: 50,
940            temperature: 1.0,
941            decode_only: false,
942            frames_json: None,
943            decode_threads: None,
944            decoder_seed: 0,
945        };
946        assert!(super::validate_options(options).is_err());
947    }
948
949    #[test]
950    fn validate_options_rejects_decode_only_without_frames() {
951        let options = super::CliOptions {
952            model: ModelChoice::HappyNewYear,
953            prompt: "".to_owned(),
954            model_dir: None,
955            output_path: std::path::PathBuf::from("output.wav"),
956            inspect_only: false,
957            backend: BackendChoice::Vulkan,
958            cfg_scale: 1.5,
959            length: 6000,
960            ode_steps: 10,
961            lyrics: None,
962            tags: None,
963            topk: 50,
964            temperature: 1.0,
965            decode_only: true,
966            frames_json: None,
967            decode_threads: None,
968            decoder_seed: 0,
969        };
970        assert!(super::validate_options(options).is_err());
971    }
972
973    #[test]
974    fn validate_options_rejects_zero_decode_threads() {
975        let options = super::CliOptions {
976            model: ModelChoice::HappyNewYear,
977            prompt: "test".to_owned(),
978            model_dir: None,
979            output_path: std::path::PathBuf::from("output.wav"),
980            inspect_only: false,
981            backend: BackendChoice::Vulkan,
982            cfg_scale: 1.5,
983            length: 6000,
984            ode_steps: 10,
985            lyrics: None,
986            tags: None,
987            topk: 50,
988            temperature: 1.0,
989            decode_only: false,
990            frames_json: None,
991            decode_threads: Some(0),
992            decoder_seed: 0,
993        };
994        assert!(super::validate_options(options).is_err());
995    }
996
997    #[test]
998    fn validate_options_trims_tags() {
999        let options = super::CliOptions {
1000            model: ModelChoice::HappyNewYear,
1001            prompt: "test".to_owned(),
1002            model_dir: None,
1003            output_path: std::path::PathBuf::from("output.wav"),
1004            inspect_only: false,
1005            backend: BackendChoice::Vulkan,
1006            cfg_scale: 1.5,
1007            length: 6000,
1008            ode_steps: 10,
1009            lyrics: None,
1010            tags: Some("  tag1, tag2  ".to_owned()),
1011            topk: 50,
1012            temperature: 1.0,
1013            decode_only: false,
1014            frames_json: None,
1015            decode_threads: None,
1016            decoder_seed: 0,
1017        };
1018        let validated = super::validate_options(options).expect("validation should pass");
1019        assert_eq!(validated.tags, Some("tag1, tag2".to_owned()));
1020    }
1021
1022    #[test]
1023    fn validate_options_filters_empty_tags() {
1024        let options = super::CliOptions {
1025            model: ModelChoice::HappyNewYear,
1026            prompt: "test".to_owned(),
1027            model_dir: None,
1028            output_path: std::path::PathBuf::from("output.wav"),
1029            inspect_only: false,
1030            backend: BackendChoice::Vulkan,
1031            cfg_scale: 1.5,
1032            length: 6000,
1033            ode_steps: 10,
1034            lyrics: None,
1035            tags: Some("   ".to_owned()),
1036            topk: 50,
1037            temperature: 1.0,
1038            decode_only: false,
1039            frames_json: None,
1040            decode_threads: None,
1041            decoder_seed: 0,
1042        };
1043        let validated = super::validate_options(options).expect("validation should pass");
1044        assert_eq!(validated.tags, None);
1045    }
1046
1047    #[test]
1048    fn default_output_path_is_output_wav() {
1049        let args = [OsString::from("generate"), OsString::from("test prompt")];
1050        let options = parse_options(args).expect("options should parse");
1051        assert_eq!(options.output_path, std::path::PathBuf::from("output.wav"));
1052    }
1053
1054    #[test]
1055    fn default_length_is_6000() {
1056        let args = [OsString::from("generate"), OsString::from("test prompt")];
1057        let options = parse_options(args).expect("options should parse");
1058        assert_eq!(options.length, 6000);
1059    }
1060
1061    #[test]
1062    fn default_ode_steps_is_10() {
1063        let args = [OsString::from("generate"), OsString::from("test prompt")];
1064        let options = parse_options(args).expect("options should parse");
1065        assert_eq!(options.ode_steps, 10);
1066    }
1067
1068    #[test]
1069    fn default_topk_is_50() {
1070        let args = [OsString::from("generate"), OsString::from("test prompt")];
1071        let options = parse_options(args).expect("options should parse");
1072        assert_eq!(options.topk, 50);
1073    }
1074
1075    #[test]
1076    fn default_temperature_is_1() {
1077        let args = [OsString::from("generate"), OsString::from("test prompt")];
1078        let options = parse_options(args).expect("options should parse");
1079        assert_eq!(options.temperature, 1.0);
1080    }
1081
1082    #[test]
1083    fn default_cfg_scale_is_1_5() {
1084        let args = [OsString::from("generate"), OsString::from("test prompt")];
1085        let options = parse_options(args).expect("options should parse");
1086        assert_eq!(options.cfg_scale, 1.5);
1087    }
1088
1089    #[test]
1090    fn default_decoder_seed_is_0() {
1091        let args = [OsString::from("generate"), OsString::from("test prompt")];
1092        let options = parse_options(args).expect("options should parse");
1093        assert_eq!(options.decoder_seed, 0);
1094    }
1095
1096    #[test]
1097    fn help_text_contains_usage() {
1098        let help = super::help_text();
1099        assert!(help.contains("maolan-generate"));
1100        assert!(help.contains("Usage:"));
1101        assert!(help.contains("Options:"));
1102    }
1103
1104    #[test]
1105    fn stderr_logging_disabled_in_ipc_mode() {
1106        let _ = super::stderr_logging_enabled();
1107    }
1108
1109    #[test]
1110    fn write_and_read_ipc_message_roundtrip() {
1111        use super::{read_ipc_message, write_ipc_message};
1112        use std::io::Cursor;
1113
1114        let original = super::GenerateResponseHeader {
1115            backend: BackendChoice::Cpu,
1116            channels: 2,
1117            frames: 48000,
1118            guidance_scale: 2.0,
1119            prompt_tokens: 10,
1120            sample_rate_hz: 48000,
1121            length: 6000,
1122            steps: 10,
1123        };
1124
1125        let mut buffer = Vec::new();
1126        write_ipc_message(&mut buffer, &original).expect("write should succeed");
1127
1128        let mut cursor = Cursor::new(buffer);
1129        let decoded: super::GenerateResponseHeader =
1130            read_ipc_message(&mut cursor).expect("read should succeed");
1131
1132        assert_eq!(decoded.backend, original.backend);
1133        assert_eq!(decoded.channels, original.channels);
1134        assert_eq!(decoded.frames, original.frames);
1135        assert_eq!(decoded.guidance_scale, original.guidance_scale);
1136        assert_eq!(decoded.prompt_tokens, original.prompt_tokens);
1137        assert_eq!(decoded.sample_rate_hz, original.sample_rate_hz);
1138        assert_eq!(decoded.length, original.length);
1139        assert_eq!(decoded.steps, original.steps);
1140    }
1141
1142    #[test]
1143    fn write_and_read_ipc_progress_roundtrip() {
1144        use super::{read_ipc_message, write_ipc_message};
1145        use std::io::Cursor;
1146
1147        let original = super::GenerateProgress {
1148            phase: "generator".to_owned(),
1149            progress: 0.5,
1150            operation: "Processing".to_owned(),
1151        };
1152
1153        let mut buffer = Vec::new();
1154        write_ipc_message(&mut buffer, &original).expect("write should succeed");
1155
1156        let mut cursor = Cursor::new(buffer);
1157        let decoded: super::GenerateProgress =
1158            read_ipc_message(&mut cursor).expect("read should succeed");
1159
1160        assert_eq!(decoded.phase, original.phase);
1161        assert_eq!(decoded.progress, original.progress);
1162        assert_eq!(decoded.operation, original.operation);
1163    }
1164
1165    #[test]
1166    fn write_and_read_ipc_error_roundtrip() {
1167        use super::{read_ipc_message, write_ipc_message};
1168        use std::io::Cursor;
1169
1170        let original = super::GenerateError {
1171            error: "Test error message".to_owned(),
1172        };
1173
1174        let mut buffer = Vec::new();
1175        write_ipc_message(&mut buffer, &original).expect("write should succeed");
1176
1177        let mut cursor = Cursor::new(buffer);
1178        let decoded: super::GenerateError =
1179            read_ipc_message(&mut cursor).expect("read should succeed");
1180
1181        assert_eq!(decoded.error, original.error);
1182    }
1183
1184    #[test]
1185    fn write_ipc_bytes_roundtrip() {
1186        use super::write_ipc_bytes;
1187        use std::io::Cursor;
1188
1189        let original = b"Hello, World!";
1190
1191        let mut buffer = Vec::new();
1192        write_ipc_bytes(&mut buffer, original).expect("write should succeed");
1193
1194        let mut cursor = Cursor::new(buffer);
1195        let mut len_bytes = [0_u8; 8];
1196        std::io::Read::read_exact(&mut cursor, &mut len_bytes).expect("read length should succeed");
1197        let len = u64::from_le_bytes(len_bytes) as usize;
1198        assert_eq!(len, original.len());
1199
1200        let mut payload = vec![0_u8; len];
1201        std::io::Read::read_exact(&mut cursor, &mut payload).expect("read payload should succeed");
1202        assert_eq!(&payload[..], &original[..]);
1203    }
1204
1205    #[test]
1206    fn read_ipc_message_fails_on_truncated_data() {
1207        use super::read_ipc_message;
1208        use std::io::Cursor;
1209
1210        let len_bytes = 100_u64.to_le_bytes();
1211        let buffer = len_bytes.to_vec();
1212
1213        let mut cursor = Cursor::new(buffer);
1214        let result: Result<super::GenerateResponseHeader, _> = read_ipc_message(&mut cursor);
1215        assert!(result.is_err());
1216    }
1217
1218    #[test]
1219    fn read_ipc_message_fails_on_invalid_json() {
1220        use super::read_ipc_message;
1221        use std::io::Cursor;
1222
1223        let payload = b"not valid json";
1224        let len_bytes = (payload.len() as u64).to_le_bytes();
1225        let mut buffer = Vec::new();
1226        buffer.extend_from_slice(&len_bytes);
1227        buffer.extend_from_slice(payload);
1228
1229        let mut cursor = Cursor::new(buffer);
1230        let result: Result<super::GenerateResponseHeader, _> = read_ipc_message(&mut cursor);
1231        assert!(result.is_err());
1232    }
1233
1234    #[test]
1235    fn serialize_generate_request() {
1236        let request = super::GenerateRequest {
1237            model: ModelChoice::Rl,
1238            prompt: "test prompt".to_owned(),
1239            model_dir: Some(std::path::PathBuf::from("/tmp/models")),
1240            output_path: std::path::PathBuf::from("/tmp/output.wav"),
1241            inspect_only: true,
1242            backend: BackendChoice::Cpu,
1243            cfg_scale: 2.5,
1244            length: 8000,
1245            ode_steps: 15,
1246            lyrics: Some("lyrics text".to_owned()),
1247            tags: Some("tag1,tag2".to_owned()),
1248            topk: 25,
1249            temperature: 0.8,
1250            decode_only: false,
1251            frames_json: None,
1252            decode_threads: Some(4),
1253            decoder_seed: 42,
1254        };
1255
1256        let json = serde_json::to_string(&request).expect("serialization should succeed");
1257        assert!(json.contains("test prompt"));
1258        assert!(json.contains("cpu"));
1259        assert!(json.contains("RL"));
1260    }
1261
1262    #[test]
1263    fn deserialize_generate_request() {
1264        let json = r#"{
1265            "model": "RL",
1266            "prompt": "test prompt",
1267            "output_path": "/tmp/output.wav",
1268            "backend": "cpu",
1269            "cfg_scale": 2.5,
1270            "length": 8000,
1271            "ode_steps": 15,
1272            "topk": 25,
1273            "temperature": 0.8,
1274            "decoder_seed": 42
1275        }"#;
1276
1277        let request: super::GenerateRequest =
1278            serde_json::from_str(json).expect("deserialization should succeed");
1279        assert_eq!(request.model, ModelChoice::Rl);
1280        assert_eq!(request.prompt, "test prompt");
1281        assert_eq!(request.backend, BackendChoice::Cpu);
1282        assert_eq!(request.cfg_scale, 2.5);
1283        assert_eq!(request.length, 8000);
1284        assert_eq!(request.ode_steps, 15);
1285        assert_eq!(request.topk, 25);
1286        assert_eq!(request.temperature, 0.8);
1287        assert_eq!(request.decoder_seed, 42);
1288    }
1289
1290    #[test]
1291    fn deserialize_generate_request_with_aliases() {
1292        let json1 =
1293            r#"{"prompt": "test", "backend": "cpu", "cfg_scale": 1.5, "seconds_total": 5000}"#;
1294        let request1: super::GenerateRequest =
1295            serde_json::from_str(json1).expect("deserialization should succeed");
1296        assert_eq!(request1.length, 5000);
1297
1298        let json2 = r#"{"prompt": "test", "backend": "cpu", "cfg_scale": 1.5, "max_audio_length_ms": 7000}"#;
1299        let request2: super::GenerateRequest =
1300            serde_json::from_str(json2).expect("deserialization should succeed");
1301        assert_eq!(request2.length, 7000);
1302    }
1303
1304    #[test]
1305    fn backend_choice_default_is_vulkan() {
1306        let default: BackendChoice = Default::default();
1307        assert_eq!(default, BackendChoice::Vulkan);
1308    }
1309
1310    #[test]
1311    fn model_choice_default_is_happy_new_year() {
1312        let default: ModelChoice = Default::default();
1313        assert_eq!(default, ModelChoice::HappyNewYear);
1314    }
1315
1316    #[test]
1317    fn default_output_path_function() {
1318        let path = super::default_output_path();
1319        assert_eq!(path, std::path::PathBuf::from("output.wav"));
1320    }
1321
1322    #[test]
1323    fn tokenizer_path_returns_valid_path() {
1324        let path = super::tokenizer_path();
1325        assert!(path.to_string_lossy().contains("t5-base-spiece.model"));
1326    }
1327}