use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::str::CharIndices;
use icu_locid::{locale, Locale};
use icu_provider::prelude::*;
use crate::complex::*;
use crate::indices::{Latin1Indices, Utf16Indices};
use crate::provider::*;
use crate::rule_segmenter::*;
pub type WordBreakIteratorUtf8<'l, 's> = RuleBreakIterator<'l, 's, WordBreakTypeUtf8>;
pub type WordBreakIteratorLatin1<'l, 's> = RuleBreakIterator<'l, 's, WordBreakTypeLatin1>;
pub type WordBreakIteratorUtf16<'l, 's> = RuleBreakIterator<'l, 's, WordBreakTypeUtf16>;
pub struct WordBreakSegmenter {
payload: DataPayload<WordBreakDataV1Marker>,
dictionary: Dictionary,
lstm: LstmPayloads,
}
impl WordBreakSegmenter {
#[cfg(feature = "lstm")]
pub fn try_new<D>(provider: &D) -> Result<Self, DataError>
where
D: DataProvider<WordBreakDataV1Marker>
+ DataProvider<UCharDictionaryBreakDataV1Marker>
+ DataProvider<LstmDataV1Marker>
+ ?Sized,
{
let payload = provider.load(Default::default())?.take_payload()?;
let cj = Self::load_dictionary(provider, locale!("ja")).ok();
let lstm = if cfg!(feature = "lstm") {
let burmese = Self::load_lstm(provider, locale!("my")).ok();
let khmer = Self::load_lstm(provider, locale!("lo")).ok();
let lao = Self::load_lstm(provider, locale!("lo")).ok();
let thai = Self::load_lstm(provider, locale!("th")).ok();
LstmPayloads {
burmese,
khmer,
lao,
thai,
}
} else {
LstmPayloads::default()
};
Ok(Self {
payload,
dictionary: Dictionary {
burmese: None,
khmer: None,
lao: None,
thai: None,
cj,
},
lstm,
})
}
#[cfg(not(feature = "lstm"))]
pub fn try_new<D>(provider: &D) -> Result<Self, DataError>
where
D: DataProvider<WordBreakDataV1Marker>
+ DataProvider<UCharDictionaryBreakDataV1Marker>
+ ?Sized,
{
let payload = provider.load(Default::default())?.take_payload()?;
let dictionary = if cfg!(feature = "lstm") {
let cj = Self::load_dictionary(provider, locale!("ja")).ok();
Dictionary {
burmese: None,
khmer: None,
lao: None,
thai: None,
cj,
}
} else {
let cj = Self::load_dictionary(provider, locale!("ja")).ok();
let burmese = Self::load_dictionary(provider, locale!("my")).ok();
let khmer = Self::load_dictionary(provider, locale!("km")).ok();
let lao = Self::load_dictionary(provider, locale!("lo")).ok();
let thai = Self::load_dictionary(provider, locale!("th")).ok();
Dictionary {
burmese,
khmer,
lao,
thai,
cj,
}
};
Ok(Self {
payload,
dictionary,
lstm: LstmPayloads::default(),
})
}
fn load_dictionary<D: DataProvider<UCharDictionaryBreakDataV1Marker> + ?Sized>(
provider: &D,
locale: Locale,
) -> Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataError> {
provider
.load(DataRequest {
locale: &DataLocale::from(locale),
metadata: Default::default(),
})?
.take_payload()
}
#[cfg(feature = "lstm")]
fn load_lstm<D: DataProvider<LstmDataV1Marker> + ?Sized>(
provider: &D,
locale: Locale,
) -> Result<DataPayload<LstmDataV1Marker>, DataError> {
provider
.load(DataRequest {
locale: &DataLocale::from(locale),
metadata: Default::default(),
})?
.take_payload()
}
pub fn segment_str<'l, 's>(&'l self, input: &'s str) -> WordBreakIteratorUtf8<'l, 's> {
WordBreakIteratorUtf8 {
iter: input.char_indices(),
len: input.len(),
current_pos_data: None,
result_cache: Vec::new(),
data: self.payload.get(),
dictionary: &self.dictionary,
lstm: &self.lstm,
}
}
pub fn segment_latin1<'l, 's>(&'l self, input: &'s [u8]) -> WordBreakIteratorLatin1<'l, 's> {
WordBreakIteratorLatin1 {
iter: Latin1Indices::new(input),
len: input.len(),
current_pos_data: None,
result_cache: Vec::new(),
data: self.payload.get(),
dictionary: &self.dictionary,
lstm: &self.lstm,
}
}
pub fn segment_utf16<'l, 's>(&'l self, input: &'s [u16]) -> WordBreakIteratorUtf16<'l, 's> {
WordBreakIteratorUtf16 {
iter: Utf16Indices::new(input),
len: input.len(),
current_pos_data: None,
result_cache: Vec::new(),
data: self.payload.get(),
dictionary: &self.dictionary,
lstm: &self.lstm,
}
}
}
pub struct WordBreakTypeUtf8;
impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf8 {
type IterAttr = CharIndices<'s>;
type CharType = char;
fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
iter.current_pos_data.unwrap().1.len_utf8()
}
fn handle_complex_language(
iter: &mut RuleBreakIterator<'l, 's, Self>,
left_codepoint: Self::CharType,
) -> Option<usize> {
let start_iter = iter.iter.clone();
let start_point = iter.current_pos_data;
let mut s = String::new();
s.push(left_codepoint);
loop {
s.push(iter.current_pos_data.unwrap().1);
iter.current_pos_data = iter.iter.next();
if iter.current_pos_data.is_none() {
break;
}
if iter.get_current_break_property() != iter.data.complex_property {
break;
}
}
iter.iter = start_iter;
iter.current_pos_data = start_point;
let breaks = complex_language_segment_str(iter.dictionary, iter.lstm, &s);
iter.result_cache = breaks;
let mut i = iter.current_pos_data.unwrap().1.len_utf8();
loop {
if i == *iter.result_cache.first().unwrap() {
iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
return Some(iter.current_pos_data.unwrap().0);
}
iter.current_pos_data = iter.iter.next();
if iter.current_pos_data.is_none() {
iter.result_cache.clear();
return Some(iter.len);
}
i += Self::get_current_position_character_len(iter);
}
}
}
pub struct WordBreakTypeLatin1;
impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeLatin1 {
type IterAttr = Latin1Indices<'s>;
type CharType = u8;
fn get_current_position_character_len(_: &RuleBreakIterator<Self>) -> usize {
panic!("not reachable")
}
fn handle_complex_language(
_: &mut RuleBreakIterator<'l, 's, Self>,
_: Self::CharType,
) -> Option<usize> {
panic!("not reachable")
}
}
pub struct WordBreakTypeUtf16;
impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf16 {
type IterAttr = Utf16Indices<'s>;
type CharType = u32;
fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
let ch = iter.current_pos_data.unwrap().1;
if ch >= 0x10000 {
2
} else {
1
}
}
fn handle_complex_language(
iter: &mut RuleBreakIterator<Self>,
left_codepoint: Self::CharType,
) -> Option<usize> {
let start_iter = iter.iter.clone();
let start_point = iter.current_pos_data;
let mut s = vec![left_codepoint as u16];
loop {
s.push(iter.current_pos_data.unwrap().1 as u16);
iter.current_pos_data = iter.iter.next();
if iter.current_pos_data.is_none() {
break;
}
if iter.get_current_break_property() != iter.data.complex_property {
break;
}
}
iter.iter = start_iter;
iter.current_pos_data = start_point;
let breaks = complex_language_segment_utf16(iter.dictionary, iter.lstm, &s);
let mut i = 1;
iter.result_cache = breaks;
loop {
if i == *iter.result_cache.first().unwrap() {
iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
return Some(iter.current_pos_data.unwrap().0);
}
iter.current_pos_data = iter.iter.next();
if iter.current_pos_data.is_none() {
iter.result_cache.clear();
return Some(iter.len);
}
i += 1;
}
}
}