use crate::lstm_error::Error;
use crate::math_helper;
use crate::provider::LstmDataV1Marker;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::str;
use icu_provider::DataPayload;
use ndarray::{Array1, Array2, ArrayBase, Dim, ViewRepr};
use zerovec::ule::AsULE;
#[cfg(feature = "lstm-grapheme")]
use unicode_segmentation::UnicodeSegmentation;
/// A bidirectional-LSTM word segmenter: runs a forward and a backward LSTM
/// pass over the input and classifies each element as B/I/E/S through a
/// time-distributed output layer (see `word_segmenter`).
pub struct Lstm<'l> {
    // Backing payload; also provides the model name and the id dictionary.
    data: &'l DataPayload<LstmDataV1Marker>,
    // Embedding table: rows are indexed by the ids produced by `return_id`.
    mat1: Array2<f32>,
    // Forward-pass LSTM weights: input weights (W), recurrent weights (U),
    // and bias (b), with the four gates packed along the last axis.
    mat2: Array2<f32>,
    mat3: Array2<f32>,
    mat4: Array1<f32>,
    // Backward-pass LSTM weights, same layout as mat2..mat4.
    mat5: Array2<f32>,
    mat6: Array2<f32>,
    mat7: Array1<f32>,
    // Output layer: maps the concatenated [forward, backward] hidden states
    // to the four BIES classes (weights mat8, bias mat9).
    mat8: Array2<f32>,
    mat9: Array1<f32>,
}
impl<'l> Lstm<'l> {
    /// Constructs an [`Lstm`] from a data payload, validating the model.
    ///
    /// # Errors
    ///
    /// * `Error::Limit` if the dictionary has more entries than fit in an
    ///   `i16` (ids, including the unknown-key sentinel `dic.len()`, are
    ///   stored as `i16`; see `return_id`).
    /// * `Error::Syntax` if the model name indicates an unsupported embedding
    ///   (grapheme-cluster models require the `lstm-grapheme` feature).
    /// * `Error::DimensionMismatch` if the weight matrices disagree on the
    ///   embedding dimension or hidden-unit count.
    pub fn try_new(data: &'l DataPayload<LstmDataV1Marker>) -> Result<Self, Error> {
        let payload = data.get();
        // `return_id` uses `dic.len()` itself as the unknown-key id, so even
        // the sentinel must fit in an i16.
        if payload.dic.len() > i16::MAX as usize {
            return Err(Error::Limit);
        }
        #[cfg(feature = "lstm-grapheme")]
        if !payload.model.contains("_codepoints_") && !payload.model.contains("_graphclust_") {
            return Err(Error::Syntax);
        }
        #[cfg(not(feature = "lstm-grapheme"))]
        if !payload.model.contains("_codepoints_") {
            return Err(Error::Syntax);
        }
        let mat1 = payload.mat1.as_ndarray2()?;
        let mat2 = payload.mat2.as_ndarray2()?;
        let mat3 = payload.mat3.as_ndarray2()?;
        let mat4 = payload.mat4.as_ndarray1()?;
        let mat5 = payload.mat5.as_ndarray2()?;
        let mat6 = payload.mat6.as_ndarray2()?;
        let mat7 = payload.mat7.as_ndarray1()?;
        let mat8 = payload.mat8.as_ndarray2()?;
        let mat9 = payload.mat9.as_ndarray1()?;
        let embedd_dim = mat1.shape()[1];
        let hunits = mat3.shape()[0];
        // Forward (mat2..mat4) and backward (mat5..mat7) LSTM weights pack the
        // four gates along the last axis (hence the `4 *` factors); mat8/mat9
        // map the concatenated hidden states to the four BIES classes.
        if mat2.shape() != [embedd_dim, 4 * hunits]
            || mat3.shape() != [hunits, 4 * hunits]
            || mat4.shape() != [4 * hunits]
            || mat5.shape() != [embedd_dim, 4 * hunits]
            || mat6.shape() != [hunits, 4 * hunits]
            || mat7.shape() != [4 * hunits]
            || mat8.shape() != [2 * hunits, 4]
            || mat9.shape() != [4]
        {
            return Err(Error::DimensionMismatch);
        }
        Ok(Self {
            data,
            mat1,
            mat2,
            mat3,
            mat4,
            mat5,
            mat6,
            mat7,
            mat8,
            mat9,
        })
    }

    /// Returns the model name embedded in the payload (used by tests).
    #[allow(dead_code)]
    pub fn get_model_name(&self) -> &str {
        &self.data.get().model
    }

    /// Maps a four-way BIES score array to its mark by argmax:
    /// 0 → 'b', 1 → 'i', 2 → 'e', 3 → 's'.
    ///
    /// Any other index means a malformed score array (`Error::Syntax`).
    fn compute_bies(&self, arr: Array1<f32>) -> Result<char, Error> {
        match math_helper::max_arr1(arr.view()) {
            0 => Ok('b'),
            1 => Ok('i'),
            2 => Ok('e'),
            3 => Ok('s'),
            _ => Err(Error::Syntax),
        }
    }

    /// Looks up the embedding id of a code point or grapheme cluster.
    ///
    /// Unknown keys map to the sentinel id `dic.len()`; `try_new` guarantees
    /// that value fits in an `i16`.
    fn return_id(&self, g: &str) -> i16 {
        let dic = &self.data.get().dic;
        match dic.get(g) {
            Some(id) => i16::from_unaligned(*id),
            None => dic.len() as i16,
        }
    }

    /// One LSTM cell step: from the input embedding `x_t` and the previous
    /// hidden/cell states `h_tm1`/`c_tm1`, computes the new `(h_t, c_t)`.
    ///
    /// `warr`, `uarr`, and `barr` are the input weights, recurrent weights,
    /// and bias; the four gates (input, forget, candidate, output) are packed
    /// along the last axis in that order, `hunits` entries each.
    fn compute_hc(
        &self,
        x_t: ArrayBase<ViewRepr<&f32>, Dim<[usize; 1]>>,
        h_tm1: &Array1<f32>,
        c_tm1: &Array1<f32>,
        warr: ArrayBase<ViewRepr<&f32>, Dim<[usize; 2]>>,
        uarr: ArrayBase<ViewRepr<&f32>, Dim<[usize; 2]>>,
        barr: ArrayBase<ViewRepr<&f32>, Dim<[usize; 1]>>,
    ) -> (Array1<f32>, Array1<f32>) {
        // Pre-activations for all four gates at once: x·W + h·U + b.
        let s_t = x_t.dot(&warr) + h_tm1.dot(&uarr) + barr;
        let hunits = uarr.shape()[0];
        let i = math_helper::sigmoid_arr1(s_t.slice(ndarray::s![..hunits]));
        let f = math_helper::sigmoid_arr1(s_t.slice(ndarray::s![hunits..2 * hunits]));
        // Candidate cell state. Renamed from `_c`: the leading underscore
        // wrongly suggested an unused binding.
        let c_tilde = math_helper::tanh_arr1(s_t.slice(ndarray::s![2 * hunits..3 * hunits]));
        let o = math_helper::sigmoid_arr1(s_t.slice(ndarray::s![3 * hunits..]));
        let c_t = i * c_tilde + f * c_tm1;
        let h_t = o * math_helper::tanh_arr1(c_t.view());
        (h_t, c_t)
    }

    /// Segments `input`, returning one BIES mark per element: `b`/`i`/`e`
    /// for the beginning/interior/end of a multi-element word and `s` for a
    /// single-element word.
    ///
    /// Elements are Unicode code points for `_codepoints_` models and
    /// grapheme clusters for `_graphclust_` models.
    pub fn word_segmenter(&self, input: &str) -> String {
        // Map each element to its embedding id (or the unknown-key sentinel).
        let input_seq: Vec<i16> = if self.data.get().model.contains("_codepoints_") {
            input
                .chars()
                .map(|c| self.return_id(&c.to_string()))
                .collect()
        } else {
            #[cfg(feature = "lstm-grapheme")]
            {
                UnicodeSegmentation::graphemes(input, true)
                    .map(|s| self.return_id(s))
                    .collect()
            }
            // `try_new` rejects non-codepoint models when `lstm-grapheme` is
            // disabled, so this arm can never be reached.
            #[cfg(not(feature = "lstm-grapheme"))]
            {
                unreachable!()
            }
        };
        let input_seq_len = input_seq.len();
        let hunits = self.mat3.shape()[0];

        // Forward pass: record the hidden state after every element.
        let mut c_fw = Array1::<f32>::zeros(hunits);
        let mut h_fw = Array1::<f32>::zeros(hunits);
        let mut all_h_fw = Array2::<f32>::zeros((input_seq_len, hunits));
        for (i, g_id) in input_seq.iter().enumerate() {
            let x_t = self.mat1.slice(ndarray::s![*g_id as isize, ..]);
            let (new_h, new_c) = self.compute_hc(
                x_t,
                &h_fw,
                &c_fw,
                self.mat2.view(),
                self.mat3.view(),
                self.mat4.view(),
            );
            h_fw = new_h;
            c_fw = new_c;
            all_h_fw = math_helper::change_row(all_h_fw, i, &h_fw);
        }

        // Backward pass over the reversed sequence; rows are written so that
        // row i still corresponds to input element i.
        let mut c_bw = Array1::<f32>::zeros(hunits);
        let mut h_bw = Array1::<f32>::zeros(hunits);
        let mut all_h_bw = Array2::<f32>::zeros((input_seq_len, hunits));
        for (i, g_id) in input_seq.iter().rev().enumerate() {
            let x_t = self.mat1.slice(ndarray::s![*g_id as isize, ..]);
            let (new_h, new_c) = self.compute_hc(
                x_t,
                &h_bw,
                &c_bw,
                self.mat5.view(),
                self.mat6.view(),
                self.mat7.view(),
            );
            h_bw = new_h;
            c_bw = new_c;
            all_h_bw = math_helper::change_row(all_h_bw, input_seq_len - 1 - i, &h_bw);
        }

        // Time-distributed output layer: classify each element from its
        // concatenated forward/backward hidden states.
        let timew = self.mat8.view();
        let timeb = self.mat9.view();
        // Each element yields exactly one ASCII BIES char, so the final
        // length is known up front.
        let mut bies = String::with_capacity(input_seq_len);
        for i in 0..input_seq_len {
            let curr_fw = all_h_fw.slice(ndarray::s![i, ..]);
            let curr_bw = all_h_bw.slice(ndarray::s![i, ..]);
            let concat_lstm = math_helper::concatenate_arr1(curr_fw, curr_bw);
            let curr_est = concat_lstm.dot(&timew) + timeb;
            let probs = math_helper::softmax(curr_est);
            // `curr_est` always has four entries (mat8 is [2*hunits, 4]), so
            // the argmax is in 0..4 and `compute_bies` cannot fail here.
            bies.push(
                self.compute_bies(probs)
                    .expect("softmax over four classes yields a valid BIES index"),
            );
        }
        bies
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::fs::File;
    use std::io::BufReader;

    /// One entry of the test-text JSON: the unsegmented string plus the
    /// expected (model) and true (ground-truth) BIES sequences.
    #[derive(PartialEq, Debug, Serialize, Deserialize)]
    pub struct TestCase {
        pub unseg: String,
        pub expected_bies: String,
        pub true_bies: String,
    }

    /// Deserialized form of a whole test-text JSON file.
    #[derive(PartialEq, Debug, Serialize, Deserialize)]
    pub struct TestTextData {
        pub testcases: Vec<TestCase>,
    }

    /// Thin wrapper around [`TestTextData`].
    #[derive(Debug)]
    pub struct TestText {
        pub data: TestTextData,
    }

    impl TestText {
        pub fn new(data: TestTextData) -> Self {
            Self { data }
        }
    }

    // Reads LSTM weights from a JSON file into a DataPayload.
    fn load_lstm_data(filename: &str) -> DataPayload<LstmDataV1Marker> {
        DataPayload::<LstmDataV1Marker>::try_from_rc_buffer_badly(
            std::fs::read(filename)
                .expect("File can read to end")
                .into(),
            |bytes| serde_json::from_slice(bytes),
        )
        .expect("JSON syntax error")
    }

    // Reads the test cases from a JSON file.
    fn load_test_text(filename: &str) -> TestTextData {
        let file = File::open(filename).expect("File should be present");
        let reader = BufReader::new(file);
        serde_json::from_reader(reader).expect("JSON syntax error")
    }

    // Checks that a grapheme-cluster model loads and reports its name.
    #[test]
    #[ignore = "dic entries of graphclust data aren't sorted"]
    #[cfg(feature = "lstm-grapheme")]
    fn test_model_loading() {
        let filename = "tests/testdata/Thai_graphclust_exclusive_model4_heavy/weights.json";
        let lstm_data = load_lstm_data(filename);
        let lstm = Lstm::try_new(&lstm_data).unwrap();
        assert_eq!(
            lstm.get_model_name(),
            String::from("Thai_graphclust_exclusive_model4_heavy")
        );
    }

    // End-to-end check: segment every test case with the codepoints model and
    // compare against the expected BIES sequence.
    #[test]
    fn segment_file_by_lstm() {
        let embedding: &str = "codepoints";
        let mut model_filename = "tests/testdata/Thai_".to_owned();
        model_filename.push_str(embedding);
        model_filename.push_str("_exclusive_model4_heavy/weights.json");
        let lstm_data = load_lstm_data(&model_filename);
        let lstm = Lstm::try_new(&lstm_data).unwrap();
        let mut test_text_filename = "tests/testdata/test_text_".to_owned();
        test_text_filename.push_str(embedding);
        test_text_filename.push_str(".json");
        let test_text_data = load_test_text(&test_text_filename);
        let test_text = TestText::new(test_text_data);
        for test_case in test_text.data.testcases {
            let lstm_output = lstm.word_segmenter(&test_case.unseg);
            println!("Test case : {}", test_case.unseg);
            println!("Expected bies : {}", test_case.expected_bies);
            println!("Estimated bies : {}", lstm_output);
            println!("True bies : {}", test_case.true_bies);
            println!("****************************************************");
            assert_eq!(test_case.expected_bies, lstm_output);
        }
    }
}