use std::str;
use crate::string_strategy::UnicodeStringStrategy;
use crate::symspell::{SymSpell, SymSpellBuilder, Verbosity};
use wasm_bindgen::prelude::*;
/// One spelling suggestion in the shape that is serialized to a JS value by
/// `lookup` and `lookup_compound`.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct JSSuggestion {
/// The suggested term, moved out of the library's suggestion.
term: String,
/// The library suggestion's `distance`, converted to `i32`.
distance: i32,
/// The library suggestion's `count`, converted to `i32`.
count: i32,
}
/// Result of `word_segmentation` in the shape that is serialized to a JS value.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct JSComposition {
/// The library result's `segmented_string`, moved as-is.
segmented_string: String,
/// The library result's `distance_sum`, converted to `i32`.
distance_sum: i32,
/// The library result's `prob_log_sum`, converted to `f32`.
prob_log_sum: f32,
}
/// Constructor options, deserialized from the JS value passed to `new`.
///
/// Each field is widened to `i64` and forwarded to the `SymSpellBuilder`
/// setter named in its comment.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct InitParams {
/// Forwarded to `max_dictionary_edit_distance`.
max_edit_distance: i32,
/// Forwarded to `prefix_length`.
prefix_length: i32,
/// Forwarded to `count_threshold`.
count_threshold: i32,
}
/// Dictionary-loading options, deserialized from the JS value passed to
/// `load_dictionary` and `load_bigram_dictionary`.
///
/// All three fields are handed unchanged (indices widened to `i64`) to the
/// library's per-line loader.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct DictParams {
/// Presumably the position of the term within a separated line; the test in
/// this file uses `0` for lines such as `where 360468339`.
term_index: i32,
/// Presumably the position of the count within a separated line; the test in
/// this file uses `1` for unigram lines and `2` for bigram lines.
count_index: i32,
/// Separator string passed to the per-line loader (the test uses `" "`).
separator: String,
}
/// Wrapper around `SymSpell<UnicodeStringStrategy>` that is exported to
/// JavaScript under the class name `SymSpell`.
#[wasm_bindgen(js_name = SymSpell)]
pub struct JSSymSpell {
/// The wrapped spell checker that every exported method delegates to.
symspell: SymSpell<UnicodeStringStrategy>,
}
#[wasm_bindgen(js_class = SymSpell)]
impl JSSymSpell {
    /// Creates a spell checker from a JS options object.
    ///
    /// `parameters` must deserialize into `InitParams`
    /// (`max_edit_distance`, `prefix_length`, `count_threshold`).
    ///
    /// # Errors
    /// Returns `"Unable to parse arguments"` when `parameters` cannot be
    /// deserialized, or the builder's error text when the build step fails.
    #[wasm_bindgen(constructor)]
    pub fn new(parameters: &JsValue) -> Result<JSSymSpell, JsValue> {
        let params: InitParams = serde_wasm_bindgen::from_value(parameters.clone())
            .map_err(|_| JsValue::from("Unable to parse arguments"))?;
        let symspell = SymSpellBuilder::default()
            .max_dictionary_edit_distance(i64::from(params.max_edit_distance))
            .prefix_length(i64::from(params.prefix_length))
            .count_threshold(i64::from(params.count_threshold))
            .build()
            .map_err(|e| JsValue::from(e.to_string()))?;
        Ok(JSSymSpell { symspell })
    }

    /// Loads a dictionary from raw bytes, one entry per line.
    ///
    /// `input` must be valid UTF-8. `args` must deserialize into `DictParams`
    /// (`term_index`, `count_index`, `separator`). Each line is handed to the
    /// library's per-line loader; this wrapper does not report per-line
    /// failures.
    ///
    /// # Errors
    /// Returns `"Unable to parse arguments"` when `args` cannot be
    /// deserialized, or `"Invalid UTF-8"` when `input` is not valid UTF-8.
    pub fn load_dictionary(&mut self, input: &[u8], args: &JsValue) -> Result<(), JsValue> {
        let params: DictParams = serde_wasm_bindgen::from_value(args.clone())
            .map_err(|_| JsValue::from("Unable to parse arguments"))?;
        let corpus = str::from_utf8(input).map_err(|_| JsValue::from("Invalid UTF-8"))?;
        for line in corpus.lines() {
            self.symspell.load_dictionary_line(
                line,
                i64::from(params.term_index),
                i64::from(params.count_index),
                &params.separator,
            );
        }
        Ok(())
    }

    /// Loads a bigram dictionary from raw bytes, one entry per line.
    ///
    /// Same contract as `load_dictionary`, except that each line is handed to
    /// the library's bigram per-line loader.
    ///
    /// # Errors
    /// Returns `"Unable to parse arguments"` when `args` cannot be
    /// deserialized, or `"Invalid UTF-8"` when `input` is not valid UTF-8.
    pub fn load_bigram_dictionary(&mut self, input: &[u8], args: &JsValue) -> Result<(), JsValue> {
        let params: DictParams = serde_wasm_bindgen::from_value(args.clone())
            .map_err(|_| JsValue::from("Unable to parse arguments"))?;
        let corpus = str::from_utf8(input).map_err(|_| JsValue::from("Invalid UTF-8"))?;
        for line in corpus.lines() {
            self.symspell.load_bigram_dictionary_line(
                line,
                i64::from(params.term_index),
                i64::from(params.count_index),
                &params.separator,
            );
        }
        Ok(())
    }

    /// Runs the library's compound lookup on `input` and returns every
    /// suggestion as a serialized `JSSuggestion`.
    ///
    /// # Errors
    /// Returns the serializer's error text if a suggestion cannot be converted
    /// to a JS value.
    pub fn lookup_compound(
        &self,
        input: &str,
        edit_distance: i32,
    ) -> Result<Vec<JsValue>, JsValue> {
        self.symspell
            .lookup_compound(input, i64::from(edit_distance))
            .into_iter()
            .map(|sugg| {
                suggestion_to_js(
                    sugg.term,
                    saturating_i32(sugg.distance),
                    saturating_i32(sugg.count),
                )
            })
            .collect()
    }

    /// Looks up `input` and returns every suggestion as a serialized
    /// `JSSuggestion`.
    ///
    /// `verbosity` selects the library mode: `0` is `Verbosity::Top`, `1` is
    /// `Verbosity::All`, `2` is `Verbosity::Closest`.
    ///
    /// # Errors
    /// Returns `"Verbosity must be between 0 and 2"` for any other
    /// `verbosity`, or the serializer's error text if a suggestion cannot be
    /// converted to a JS value.
    pub fn lookup(
        &self,
        input: &str,
        verbosity: i8,
        max_edit_distance: i32,
    ) -> Result<Vec<JsValue>, JsValue> {
        let sym_verbosity = match verbosity {
            0 => Verbosity::Top,
            1 => Verbosity::All,
            2 => Verbosity::Closest,
            _ => return Err(JsValue::from("Verbosity must be between 0 and 2")),
        };
        self.symspell
            .lookup(input, sym_verbosity, i64::from(max_edit_distance))
            .into_iter()
            .map(|sugg| {
                suggestion_to_js(
                    sugg.term,
                    saturating_i32(sugg.distance),
                    saturating_i32(sugg.count),
                )
            })
            .collect()
    }

    /// Runs the library's word segmentation on `input` and returns the result
    /// as a serialized `JSComposition`.
    ///
    /// # Errors
    /// Returns the serializer's error text if the result cannot be converted
    /// to a JS value.
    pub fn word_segmentation(
        &self,
        input: &str,
        max_edit_distance: i32,
    ) -> Result<JsValue, JsValue> {
        let seg = self
            .symspell
            .word_segmentation(input, i64::from(max_edit_distance));
        let res = JSComposition {
            segmented_string: seg.segmented_string,
            distance_sum: saturating_i32(seg.distance_sum),
            // Narrowing to f32 loses precision by design of the JS-facing struct.
            prob_log_sum: seg.prob_log_sum as f32,
        };
        serde_wasm_bindgen::to_value(&res).map_err(|e| JsValue::from(e.to_string()))
    }
}

/// Serializes one suggestion into a JS value, turning a serializer failure
/// into a JS error string instead of panicking.
fn suggestion_to_js(term: String, distance: i32, count: i32) -> Result<JsValue, JsValue> {
    serde_wasm_bindgen::to_value(&JSSuggestion {
        term,
        distance,
        count,
    })
    .map_err(|e| JsValue::from(e.to_string()))
}

/// Converts an integer to `i32`, reporting `i32::MAX` for any value that does
/// not fit instead of silently wrapping the way an `as` cast does.
///
/// The trait is named by full path so no extra `use` is needed on editions
/// where `TryFrom` is not in the prelude.
fn saturating_i32<T>(value: T) -> i32
where
    i32: std::convert::TryFrom<T>,
{
    <i32 as std::convert::TryFrom<T>>::try_from(value).unwrap_or(i32::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Builds a speller from two small inline dictionaries, then checks one
    /// compound lookup and one word segmentation against known answers.
    #[wasm_bindgen_test]
    fn test_sentence() {
        let init_js = serde_wasm_bindgen::to_value(&InitParams {
            max_edit_distance: 2,
            prefix_length: 7,
            count_threshold: 1,
        })
        .unwrap();
        let mut speller = JSSymSpell::new(&init_js).unwrap();

        // Unigram entries: term in column 0, count in column 1.
        let unigrams = "where 360468339\ninfo 352363058";
        let unigram_js = serde_wasm_bindgen::to_value(&DictParams {
            term_index: 0,
            count_index: 1,
            separator: String::from(" "),
        })
        .unwrap();
        speller
            .load_dictionary(unigrams.as_bytes(), &unigram_js)
            .unwrap();

        // Bigram entries: the count sits in column 2.
        let bigrams = "this is 1111\nwhere is 1234";
        let bigram_js = serde_wasm_bindgen::to_value(&DictParams {
            term_index: 0,
            count_index: 2,
            separator: String::from(" "),
        })
        .unwrap();
        speller
            .load_bigram_dictionary(bigrams.as_bytes(), &bigram_js)
            .unwrap();

        let compound = speller.lookup_compound("wher", 1).unwrap();
        let best: JSSuggestion = serde_wasm_bindgen::from_value(compound[0].clone()).unwrap();
        assert_eq!(best.term, "where");

        let segmented_js = speller.word_segmentation("whereinfo", 2).unwrap();
        let segmented: JSComposition = serde_wasm_bindgen::from_value(segmented_js).unwrap();
        assert_eq!(segmented.segmented_string, "where info");
    }
}