ali_nls_sr/
lib.rs

1use ali_nls_drive::{
2    self,
3    error::ZError,
4    futures_channel::{self},
5    tokio::{self, time::sleep},
6    tokio_tungstenite::tungstenite::{http::Uri, Message},
7    AliNlsDrive,
8};
9use log::info;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::{
13    env::{self, VarError},
14    fs::File,
15    future,
16    io::{BufReader, Read},
17    path::Path,
18    str::FromStr,
19    sync::Arc,
20    time::Duration,
21};
22use uuid::Uuid;
23
24pub use ali_nls_drive::config::AliNlsConfig;
25
26pub struct AliNlsToSr {
27    drive: AliNlsDrive,
28}
29
30#[derive(Serialize)]
31struct NlsHeader {
32    message_id: String,
33    task_id: String,
34    namespace: String,
35    name: String,
36    appkey: String,
37}
38
39#[derive(Serialize)]
40struct Payload {
41    fomrat: String,
42    sample_rate: u32,
43    enable_intermediate_result: bool,
44    enable_punctuation_prediction: bool,
45    enable_inverse_text_normalization: bool,
46    enable_words: bool,
47}
48
49#[derive(Serialize)]
50struct CmdCont {
51    header: NlsHeader,
52    payload: Payload,
53}
/// Transcription stage signalled by a single server event.
#[derive(PartialEq)]
enum TransStep {
    /// `TranscriptionStarted`: the task was accepted and audio may be uploaded.
    UploadFile,
    /// `TranscriptionResultChanged`: an intermediate result arrived.
    TransProcessing,
    /// `SentenceEnd`: one sentence was finalized.
    TransOneSentenceEnd,
    /// `TranscriptionCompleted`: the whole task finished.
    TransAllComplete,
    /// Any other event, or an event whose status is not the success code.
    Unknown,
}
62
63impl AliNlsToSr {
64    /// necessary params get from env
65    /// * ALI_TOKEN
66    pub fn from(config: AliNlsConfig) -> Self {
67        let ret = dotenv::from_filename(".env_dev");
68        match ret {
69            Ok(v) => info!("found .env_dev file in {}! load env from it!", v.display()),
70            Err(_) => {}
71        };
72        Self {
73            drive: AliNlsDrive::new(config),
74        }
75    }
76
77    fn gen_taskid() -> String {
78        return Uuid::new_v4().to_string().replace("-", "");
79    }
80
81    fn get_token(&self) -> Result<String, VarError> {
82        return env::var("ALI_TOKEN");
83    }
84
85    fn handle_sr_resp(ret: &Value) -> TransStep {
86        let header = ret["header"].as_object().unwrap();
87        if let Some(statu) = header.get("status") {
88            let s = statu.as_i64().unwrap();
89            if s == 20000000 {
90                let proce_name = header.get("name").unwrap().as_str().unwrap();
91                let _ = match proce_name {
92                    "TranscriptionResultChanged" => {
93                        return TransStep::TransProcessing;
94                    }
95                    "TranscriptionStarted" => {
96                        return TransStep::UploadFile;
97                    }
98                    "SentenceEnd" => {
99                        return TransStep::TransOneSentenceEnd;
100                    }
101                    "TranscriptionCompleted" => {
102                        return TransStep::TransAllComplete;
103                    }
104                    &_ => {}
105                };
106                return TransStep::Unknown;
107            }
108        }
109        return TransStep::Unknown;
110    }
111
112    pub async fn sr_from_slicefile(&mut self, fpath: &Path) -> Result<Option<String>, ZError> {
113        let (ch_sender, ch_receive) = futures_channel::mpsc::unbounded();
114        //url
115        let sr_path = format!("/ws/v1?token={}", self.get_token()?);
116        let _ = &self.drive.config.host.push_str(&sr_path);
117        let uri = Uri::from_str(&self.drive.config.host).unwrap();
118        //client
119        self.drive.new_wscli(uri.to_string()).await?;
120        //shake params
121        let task_id = Arc::new(Self::gen_taskid().clone());
122        let app_key = Arc::new(self.drive.config.app_key.clone());
123        let cmd = Self::gen_req_val(
124            task_id.as_ref().to_string(),
125            app_key.as_ref().to_string(),
126            "StartTranscription".to_owned(),
127        );
128        let cont = json!(cmd).to_string();
129        let _ = &ch_sender.unbounded_send(Message::Text(cont))?;
130        //listen response
131        let mut ret_sr = SrResult::default();
132        let mut ret_jsonstr: String = String::from("");
133        self.drive
134            .run(ch_receive, |_c, msg| {
135                println!("msg is -->>msg={:?}", msg);
136                let ret: Value = serde_json::from_str(msg.unwrap().to_string().as_str())
137                    .expect("[ws]return msg convert to json failed!");
138                let s = Self::handle_sr_resp(&ret);
139                let _ = match s {
140                    TransStep::UploadFile => {
141                        //chk file open succ?
142                        let r = File::open(fpath).expect("Not found test file!");
143                        //clone outer var
144                        let sender_c = ch_sender.clone();
145                        let task_idr = task_id.as_ref().clone();
146                        let app_keyr = app_key.as_ref().clone();
147                        //slice upload
148                        let _: tokio::task::JoinHandle<()> = tokio::spawn(async move {
149                            let mut reader = BufReader::new(r);
150                            const CHUNK_SIZE: usize = 1024 * 10;
151                            let mut chunk_con = [0_u8; CHUNK_SIZE];
152                            loop {
153                                let chunk: &mut [u8] = &mut chunk_con;
154                                if reader.read_exact(chunk).is_ok() {
155                                    println!("send file slice:{}", CHUNK_SIZE.to_string());
156                                    let _ = &sender_c
157                                        .unbounded_send(Message::Binary(chunk.to_vec()))
158                                        .unwrap();
159                                    sleep(Duration::from_millis(100)).await;
160                                } else {
161                                    println!("slice upload finish!");
162                                    break;
163                                }
164                            }
165                            let cmd = Self::gen_req_val(
166                                task_idr,
167                                app_keyr,
168                                "StopTranscription".to_owned(),
169                            );
170                            let cont = json!(cmd).to_string();
171                            let _ = &sender_c.unbounded_send(Message::Text(cont));
172                        });
173                    }
174                    TransStep::Unknown => {}
175                    TransStep::TransProcessing => {}
176                    TransStep::TransOneSentenceEnd => {
177                        let line_fulltxt = ret
178                            .get("payload")
179                            .unwrap()
180                            .get("result")
181                            .unwrap()
182                            .to_string();
183                        let line_words: Vec<Value> = ret
184                            .get("payload")
185                            .unwrap()
186                            .get("words")
187                            .unwrap()
188                            .as_array()
189                            .unwrap()
190                            .to_vec();
191                        let sr_time = ret
192                            .get("payload")
193                            .unwrap()
194                            .get("time")
195                            .unwrap()
196                            .as_i64()
197                            .unwrap();
198                        for word in line_words {
199                            ret_sr.words.push(serde_json::from_value(word).unwrap());
200                        }
201                        ret_sr.full_txt += &line_fulltxt.replace("\"", "");
202                        ret_sr.total_time = sr_time;
203                    }
204                    TransStep::TransAllComplete => {
205                        ret_jsonstr = serde_json::to_string(&ret_sr).unwrap();
206                        return future::ready(None);
207                    }
208                };
209                future::ready(Some("".to_string()))
210            })
211            .await;
212        Ok(Some(ret_jsonstr))
213    }
214
215    fn gen_req_val(task_id: String, app_key: String, cmd: String) -> CmdCont {
216        CmdCont {
217            header: NlsHeader {
218                message_id: Uuid::new_v4().to_string().replace("-", ""),
219                task_id: task_id,
220                namespace: "SpeechTranscriber".to_owned(),
221                name: cmd,
222                appkey: app_key,
223            },
224            payload: Payload {
225                fomrat: "opus".to_owned(),
226                sample_rate: 16000,
227                enable_intermediate_result: false,
228                enable_punctuation_prediction: true,
229                enable_inverse_text_normalization: false,
230                enable_words: true,
231            },
232        }
233    }
234}
235
236#[derive(Serialize, Deserialize)]
237struct SrWordResult {
238    #[serde(alias = "startTime")]
239    start_time: i64,
240    #[serde(alias = "endTime")]
241    end_time: i64,
242    text: String,
243}
244
245#[derive(Serialize, Default)]
246struct SrResult {
247    full_txt: String,
248    words: Vec<SrWordResult>,
249    total_time: i64,
250}
251
/// End-to-end transcription against the live Aliyun NLS gateway.
///
/// Requires:
/// - network access;
/// - a valid `ALI_TOKEN` (environment or `.env_dev`);
/// - `test/16000_2_s16le.wav` under the current directory.
///
/// It is therefore ignored by default; run it with `cargo test -- --ignored`.
#[test]
#[ignore = "requires network access and a valid ALI_TOKEN"]
fn test_sr() {
    use std::env;
    use tokio::runtime::Runtime;

    Runtime::new()
        .expect("failed to build tokio runtime")
        .block_on(async {
            let mut c = AliNlsToSr::from(AliNlsConfig {
                app_key: "FPwxKxga3cQ6B2Fs".to_owned(),
                host: "wss://nls-gateway.cn-shanghai.aliyuncs.com".to_owned(),
            });
            let f = env::current_dir()
                .expect("current directory is not accessible")
                .join("test")
                .join("16000_2_s16le.wav");
            // Fail the test instead of merely printing, so regressions are visible.
            match c.sr_from_slicefile(&f).await {
                Ok(Some(r_)) => println!("json result is :{:?}", r_),
                Ok(None) => panic!("transcription finished without a result"),
                Err(e) => panic!("[error]{}", e),
            }
        });
}