Skip to main content

piper_plus/
input.rs

//! JSONL input parser (compatible with the Python `infer_onnx.py` script).
2
3use serde::Deserialize;
4
5use crate::engine::SynthesisRequest;
6use crate::error::PiperError;
7
8/// JSONL の1行を表す構造体
9#[derive(Debug, Deserialize)]
10pub struct JsonlUtterance {
11    pub phoneme_ids: Vec<i64>,
12
13    #[serde(default)]
14    pub speaker_id: Option<i64>,
15
16    #[serde(default)]
17    pub language_id: Option<i64>,
18
19    #[serde(default)]
20    pub prosody_features: Option<Vec<Option<ProsodyFeatureJson>>>,
21
22    /// 出力ファイル名のヒント
23    #[serde(default)]
24    pub output_file: Option<String>,
25}
26
27#[derive(Debug, Deserialize)]
28pub struct ProsodyFeatureJson {
29    pub a1: i32,
30    pub a2: i32,
31    pub a3: i32,
32}
33
34impl JsonlUtterance {
35    /// JSONL 行をパース
36    pub fn parse(line: &str) -> Result<Self, PiperError> {
37        serde_json::from_str(line).map_err(PiperError::from)
38    }
39
40    /// SynthesisRequest に変換 (move semantics — self を消費して clone を回避)
41    pub fn to_request(self, noise_scale: f32, length_scale: f32, noise_w: f32) -> SynthesisRequest {
42        let prosody_features = self.prosody_features.map(|features| {
43            features
44                .iter()
45                .map(|f| match f {
46                    Some(pf) => [pf.a1, pf.a2, pf.a3],
47                    None => [0, 0, 0],
48                })
49                .collect()
50        });
51
52        SynthesisRequest {
53            phoneme_ids: self.phoneme_ids,
54            prosody_features,
55            speaker_id: self.speaker_id,
56            language_id: self.language_id,
57            noise_scale,
58            length_scale,
59            noise_w,
60        }
61    }
62}
63
/// Reads JSONL records line by line from a buffered reader (for example stdin).
///
/// The struct itself places no trait bound on `R`; the `BufRead` requirement
/// lives only on the `impl` blocks that actually read from the source.
#[derive(Debug)]
pub struct JsonlReader<R> {
    /// Underlying line source.
    reader: R,
    /// Line buffer owned by the reader; created empty by `new`.
    line_buf: String,
}

impl<R: std::io::BufRead> JsonlReader<R> {
    /// Wraps `reader`, starting with an empty line buffer.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            line_buf: String::new(),
        }
    }
}
78
79impl<R: std::io::BufRead> Iterator for JsonlReader<R> {
80    type Item = Result<JsonlUtterance, PiperError>;
81
82    fn next(&mut self) -> Option<Self::Item> {
83        loop {
84            self.line_buf.clear();
85            match self.reader.read_line(&mut self.line_buf) {
86                Ok(0) => return None, // EOF
87                Ok(_) => {
88                    let trimmed = self.line_buf.trim();
89                    if trimmed.is_empty() {
90                        continue; // skip empty lines without recursion
91                    }
92                    return Some(JsonlUtterance::parse(trimmed));
93                }
94                Err(e) => return Some(Err(PiperError::AudioOutput(e))),
95            }
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the required key is present: every optional field must be `None`.
    #[test]
    fn test_parse_minimal_jsonl() {
        let line = r#"{"phoneme_ids": [1, 2, 3]}"#;
        let utt = JsonlUtterance::parse(line).unwrap();
        assert_eq!(utt.phoneme_ids, vec![1, 2, 3]);
        assert!(utt.speaker_id.is_none());
        assert!(utt.language_id.is_none());
        assert!(utt.prosody_features.is_none());
        assert!(utt.output_file.is_none());
    }

    /// Optional keys, including a `null` prosody entry, are decoded.
    #[test]
    fn test_parse_full_jsonl() {
        let line = r#"{"phoneme_ids": [1, 2], "speaker_id": 5, "prosody_features": [{"a1": -2, "a2": 1, "a3": 5}, null]}"#;
        let utt = JsonlUtterance::parse(line).unwrap();
        assert_eq!(utt.speaker_id, Some(5));
        let pf = utt.prosody_features.as_ref().unwrap();
        assert_eq!(pf.len(), 2);
        assert_eq!(pf[0].as_ref().unwrap().a1, -2);
        assert!(pf[1].is_none());
    }

    /// `language_id` and `output_file` are picked up when present.
    #[test]
    fn test_parse_language_and_output_file() {
        let line = r#"{"phoneme_ids": [1], "language_id": 2, "output_file": "a.wav"}"#;
        let utt = JsonlUtterance::parse(line).unwrap();
        assert_eq!(utt.language_id, Some(2));
        assert_eq!(utt.output_file.as_deref(), Some("a.wav"));
    }

    /// Malformed JSON and a missing required key are both reported as errors.
    #[test]
    fn test_parse_rejects_invalid_input() {
        assert!(JsonlUtterance::parse("not json").is_err());
        assert!(JsonlUtterance::parse(r#"{"speaker_id": 1}"#).is_err());
    }

    /// Scale arguments are passed through and absent options stay `None`.
    #[test]
    fn test_to_request_defaults() {
        let line = r#"{"phoneme_ids": [1, 2, 3]}"#;
        let utt = JsonlUtterance::parse(line).unwrap();
        let req = utt.to_request(0.667, 1.0, 0.8);
        assert_eq!(req.phoneme_ids, vec![1, 2, 3]);
        assert_eq!(req.noise_scale, 0.667);
        assert_eq!(req.length_scale, 1.0);
        assert_eq!(req.noise_w, 0.8);
        assert!(req.speaker_id.is_none());
        assert!(req.language_id.is_none());
        assert!(req.prosody_features.is_none());
    }

    /// Prosody entries become `[a1, a2, a3]`; `null` entries become zeros.
    #[test]
    fn test_to_request_prosody_conversion() {
        let line = r#"{"phoneme_ids": [1, 2], "prosody_features": [{"a1": -2, "a2": 1, "a3": 5}, null]}"#;
        let utt = JsonlUtterance::parse(line).unwrap();
        let req = utt.to_request(0.667, 1.0, 0.8);
        let pf: Vec<[i32; 3]> = req.prosody_features.unwrap().into_iter().collect();
        assert_eq!(pf, vec![[-2, 1, 5], [0, 0, 0]]);
    }

    /// Two well-formed lines yield two utterances in order.
    #[test]
    fn test_jsonl_reader() {
        let input = "{ \"phoneme_ids\": [1] }\n{ \"phoneme_ids\": [2, 3] }\n";
        let reader = JsonlReader::new(input.as_bytes());
        let results: Vec<_> = reader.collect();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap().phoneme_ids, vec![1]);
        assert_eq!(results[1].as_ref().unwrap().phoneme_ids, vec![2, 3]);
    }

    /// Empty and whitespace-only lines are skipped; a final line without a
    /// trailing newline is still read.
    #[test]
    fn test_jsonl_reader_skips_blank_lines() {
        let input = "\n   \n{\"phoneme_ids\": [1]}\n\n{\"phoneme_ids\": [2]}";
        let reader = JsonlReader::new(input.as_bytes());
        let results: Vec<_> = reader.collect();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap().phoneme_ids, vec![1]);
        assert_eq!(results[1].as_ref().unwrap().phoneme_ids, vec![2]);
    }

    /// A malformed line is yielded as `Err` and does not stop iteration.
    #[test]
    fn test_jsonl_reader_reports_bad_line_and_continues() {
        let input = "oops\n{\"phoneme_ids\": [1]}\n";
        let reader = JsonlReader::new(input.as_bytes());
        let results: Vec<_> = reader.collect();
        assert_eq!(results.len(), 2);
        assert!(results[0].is_err());
        assert_eq!(results[1].as_ref().unwrap().phoneme_ids, vec![1]);
    }

    /// Empty input yields no items at all.
    #[test]
    fn test_jsonl_reader_empty_input() {
        let mut reader = JsonlReader::new("".as_bytes());
        assert!(reader.next().is_none());
    }
}
143}