1use std::error;
2use std::fmt;
3use std::fs::File;
4use std::io::prelude::*;
5use std::io::{self, BufReader};
6use std::path::Path;
7
8use jieba_rs::Jieba;
9
/// Errors produced by this crate: either an I/O failure while reading
/// model/dataset files, or an error surfaced by the crfsuite library.
#[derive(Debug)]
pub enum Error {
    /// Underlying I/O error (e.g. opening/reading the dataset file).
    Io(io::Error),
    /// Error reported by the crfsuite CRF engine (model load, tagging, training).
    Crf(crfsuite::CrfError),
}
15
16impl error::Error for Error {
17 fn description(&self) -> &str {
18 match *self {
19 Error::Io(_) => "I/O error",
20 Error::Crf(_) => "crfsuite error",
21 }
22 }
23}
24
25impl fmt::Display for Error {
26 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27 match *self {
28 Error::Io(ref err) => err.fmt(f),
29 Error::Crf(ref err) => err.fmt(f),
30 }
31 }
32}
33
34impl From<io::Error> for Error {
35 #[inline]
36 fn from(err: io::Error) -> Error {
37 Error::Io(err)
38 }
39}
40
41impl From<crfsuite::CrfError> for Error {
42 #[inline]
43 fn from(err: crfsuite::CrfError) -> Error {
44 Error::Crf(err)
45 }
46}
47
/// Chinese named-entity recognizer backed by a crfsuite CRF model,
/// using a jieba segmenter to derive per-character features.
#[derive(Debug)]
pub struct ChineseNER {
    // Trained CRF model used for per-character tagging.
    model: crfsuite::Model,
    // Word segmenter providing POS tags and word-boundary features.
    segmentor: jieba_rs::Jieba,
}
53
#[cfg(test)]
#[cfg(feature = "bundled-model")]
impl Default for ChineseNER {
    /// Equivalent to [`ChineseNER::new`]: loads the model bundled at
    /// compile time. Panics if the embedded model cannot be parsed.
    fn default() -> ChineseNER {
        ChineseNER::new()
    }
}
60
/// Result of [`ChineseNER::predict`]: the segmented sentence plus the
/// recognized entity spans. All `&str` values borrow from the input sentence.
#[derive(Debug, Clone, PartialEq)]
pub struct NamedEntity<'a> {
    /// Segmented words, in sentence order.
    pub word: Vec<&'a str>,
    /// Part-of-speech tag for each entry of `word` (parallel vector).
    pub tag: Vec<&'a str>,
    /// Entity spans as (start, end) character indices — end exclusive —
    /// plus an entity kind string such as "person_name" or "location".
    pub entity: Vec<(usize, usize, &'static str)>,
}
67
68impl ChineseNER {
69 #[cfg(feature = "bundled-model")]
70 pub fn new() -> Self {
71 let model_bytes = include_bytes!("ner.model");
72 let model = crfsuite::Model::from_memory(&model_bytes[..]).expect("open model failed");
73 Self {
74 model,
75 segmentor: Jieba::new(),
76 }
77 }
78
79 pub fn from_model(model_path: &str) -> Result<Self, Error> {
80 let model = crfsuite::Model::from_file(model_path)?;
81 Ok(Self {
82 model,
83 segmentor: Jieba::new(),
84 })
85 }
86
87 pub fn predict<'a>(&'a self, sentence: &'a str) -> Result<NamedEntity<'a>, Error> {
88 use crfsuite::Attribute;
89
90 let mut tagger = self.model.tagger()?;
91 let (split_words, tags) = split_by_words(&self.segmentor, sentence);
92 let features = sent2features(&split_words);
93 let attributes: Vec<crfsuite::Item> = features
94 .into_iter()
95 .map(|x| {
96 x.into_iter()
97 .map(|f| Attribute::new(f, 1.0))
98 .collect::<crfsuite::Item>()
99 })
100 .collect();
101 let tag_result = tagger.tag(&attributes)?;
102 let mut is_tag = false;
103 let mut start_index = 0;
104 let mut entities = Vec::new();
105 for (index, tag) in tag_result.iter().enumerate() {
106 if !is_tag && tag.starts_with('B') {
107 start_index = index;
108 is_tag = true;
109 } else if is_tag && tag == "O" {
110 entities.push((start_index, index, get_tag_name(&tag_result[start_index])));
111 is_tag = false;
112 }
113 }
114 let words = tags.iter().map(|x| x.word).collect();
115 let tags = tags.iter().map(|x| x.tag).collect();
116 Ok(NamedEntity {
117 word: words,
118 tag: tags,
119 entity: entities,
120 })
121 }
122}
123
/// Maps a CRF label (e.g. "B-PER") to a human-readable entity kind.
///
/// The first matching marker wins; unmatched labels map to "unknown".
fn get_tag_name(tag: &str) -> &'static str {
    const TAG_NAMES: [(&str, &str); 5] = [
        ("PRO", "product_name"),
        ("PER", "person_name"),
        ("TIM", "time"),
        ("ORG", "org_name"),
        ("LOC", "location"),
    ];
    TAG_NAMES
        .iter()
        .find(|(marker, _)| tag.contains(marker))
        .map_or("unknown", |&(_, name)| name)
}
139
/// Per-character view of a sentence used to build CRF features.
#[derive(Debug, PartialEq)]
struct SplitWord<'a> {
    // A single character, as a subslice of the original sentence.
    word: &'a str,
    // Position of the character within its jieba word:
    // "B"egin, "I"nside, "E"nd, or "S"ingle-char word.
    status: &'static str,
    // Part-of-speech tag of the enclosing jieba word.
    tag: String,
    // Gold entity label; only populated by the training path.
    entity_type: String,
}
147
148fn split_by_words<'a>(
149 segmentor: &'a Jieba,
150 sentence: &'a str,
151) -> (Vec<SplitWord<'a>>, Vec<jieba_rs::Tag<'a>>) {
152 let mut words = Vec::new();
153 let mut char_indices = sentence.char_indices().map(|x| x.0).peekable();
154 while let Some(pos) = char_indices.next() {
155 if let Some(next_pos) = char_indices.peek() {
156 let word = &sentence[pos..*next_pos];
157 words.push(SplitWord {
158 word: word,
159 status: "",
160 tag: String::new(),
161 entity_type: String::new(),
162 });
163 } else {
164 let word = &sentence[pos..];
165 words.push(SplitWord {
166 word: word,
167 status: "",
168 tag: String::new(),
169 entity_type: String::new(),
170 });
171 }
172 }
173 let tags = segmentor.tag(sentence, true);
174 let mut index = 0;
175 for word_tag in &tags {
176 let char_count = word_tag.word.chars().count();
177 for i in 0..char_count {
178 let status = {
179 if char_count == 1 {
180 "S"
181 } else if i == 0 {
182 "B"
183 } else if i == char_count - 1 {
184 "E"
185 } else {
186 "I"
187 }
188 };
189 words[index].status = status;
190 words[index].tag = word_tag.tag.to_string();
191 index += 1;
192 }
193 }
194 (words, tags)
195}
196
197fn sent2features(split_words: &[SplitWord]) -> Vec<Vec<String>> {
198 let mut features = Vec::with_capacity(split_words.len());
199 for i in 0..split_words.len() {
200 features.push(word2features(split_words, i));
201 }
202 features
203}
204
205fn word2features(split_words: &[SplitWord], i: usize) -> Vec<String> {
206 let split_word = &split_words[i];
207 let word = split_word.word;
208 let is_digit = word.chars().all(|c| c.is_ascii_digit());
209 let mut features = vec![
210 "bias".to_string(),
211 format!("word={}", word),
212 format!("word.isdigit={}", if is_digit { "True" } else { "False" }),
213 format!("postag={}", split_word.tag),
214 format!("cuttag={}", split_word.status),
215 ];
216 if i > 0 {
217 let split_word1 = &split_words[i - 1];
218 features.push(format!("-1:word={}", split_word1.word));
219 features.push(format!("-1:postag={}", split_word1.tag));
220 features.push(format!("-1:cuttag={}", split_word1.status));
221 } else {
222 features.push("BOS".to_string());
223 }
224 if i < split_words.len() - 1 {
225 let split_word1 = &split_words[i + 1];
226 features.push(format!("+1:word={}", split_word1.word));
227 features.push(format!("+1:postag={}", split_word1.tag));
228 features.push(format!("+1:cuttag={}", split_word1.status));
229 } else {
230 features.push("EOS".to_string());
231 }
232 features
233}
234
/// Trains a CRF-based NER model from a character-per-line dataset and
/// writes the resulting model file to `output_path`.
pub struct NERTrainer {
    // crfsuite training session (verbose mode enabled in `new`).
    trainer: crfsuite::Trainer,
    // jieba segmenter used to derive POS/segmentation features.
    segmentor: jieba_rs::Jieba,
    // Destination path for the trained model file.
    output_path: String,
}
240
241impl NERTrainer {
242 pub fn new(output_path: &str) -> Self {
243 Self {
244 trainer: crfsuite::Trainer::new(true),
245 segmentor: Jieba::new(),
246 output_path: output_path.to_string(),
247 }
248 }
249
250 pub fn train<T: AsRef<Path>>(&mut self, dataset_path: T) -> Result<(), Error> {
251 let file = File::open(dataset_path)?;
252 let reader = BufReader::new(file);
253 let lines = reader.lines().collect::<Result<Vec<String>, _>>()?;
254 let mut x_train = Vec::new();
255 let mut y_train = Vec::new();
256 let mut words: Vec<SplitWord> = Vec::new();
257 for line in &lines {
258 if line.is_empty() {
259 let sentence: String = words.iter().map(|x| x.word).collect::<Vec<_>>().join("");
260 let tags = self.segmentor.tag(&sentence, true);
261 let mut index = 0;
262 for word_tag in tags {
263 let char_count = word_tag.word.chars().count();
264 for i in 0..char_count {
265 let status = {
266 if char_count == 1 {
267 "S"
268 } else if i == 0 {
269 "B"
270 } else if i == char_count - 1 {
271 "E"
272 } else {
273 "I"
274 }
275 };
276 words[index].status = status;
277 words[index].tag = word_tag.tag.to_string();
278 index += 1;
279 }
280 }
281 x_train.push(sent2features(&words));
282 y_train.push(
283 words
284 .iter()
285 .map(|x| x.entity_type.to_string())
286 .collect::<Vec<_>>(),
287 );
288 words.clear();
289 } else {
290 let parts: Vec<&str> = line.split_ascii_whitespace().collect();
291 let word = &parts[0];
292 let entity_type = &parts[1];
293 words.push(SplitWord {
294 word: word,
295 status: "",
296 tag: String::new(),
297 entity_type: entity_type.to_string(),
298 });
299 }
300 }
301 self.trainer
302 .select(crfsuite::Algorithm::LBFGS, crfsuite::GraphicalModel::CRF1D)?;
303 for (features, yseq) in x_train.into_iter().zip(y_train) {
304 let xseq: Vec<crfsuite::Item> = features
305 .into_iter()
306 .map(|x| {
307 x.into_iter()
308 .map(|f| crfsuite::Attribute::new(f, 1.0))
309 .collect::<crfsuite::Item>()
310 })
311 .collect();
312 self.trainer.append(&xseq, &yseq, 0)?;
313 }
314 self.trainer.train(&self.output_path, -1)?;
315 Ok(())
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use jieba_rs::Jieba;

    #[test]
    fn test_split_by_words() {
        let jieba = Jieba::new();
        let sentence = "洗衣机,国内掀起了大数据、云计算的热潮。仙鹤门地区。";
        let (ret, _) = split_by_words(&jieba, sentence);
        // Expected (character, segmentation status, POS tag) for each char.
        let expected: Vec<SplitWord> = [
            ("洗", "B", "n"),
            ("衣", "I", "n"),
            ("机", "E", "n"),
            (",", "S", "x"),
            ("国", "B", "s"),
            ("内", "E", "s"),
            ("掀", "B", "v"),
            ("起", "E", "v"),
            ("了", "S", "ul"),
            ("大", "S", "a"),
            ("数", "B", "n"),
            ("据", "E", "n"),
            ("、", "S", "x"),
            ("云", "S", "ns"),
            ("计", "B", "v"),
            ("算", "E", "v"),
            ("的", "S", "uj"),
            ("热", "B", "n"),
            ("潮", "E", "n"),
            ("。", "S", "x"),
            ("仙", "B", "n"),
            ("鹤", "E", "n"),
            ("门", "S", "n"),
            ("地", "B", "n"),
            ("区", "E", "n"),
            ("。", "S", "x"),
        ]
        .iter()
        .map(|&(word, status, tag)| SplitWord {
            word,
            status,
            tag: tag.to_string(),
            entity_type: String::new(),
        })
        .collect();
        assert_eq!(ret, expected);
    }

    #[cfg(feature = "bundled-model")]
    #[test]
    fn test_ner_predict() {
        let ner = ChineseNER::new();
        let sentence = "今天纽约的天气真好啊,京华大酒店的李白经理吃了一只北京烤鸭。";
        let result = ner.predict(sentence).unwrap();
        let expected_words = vec![
            "今天", "纽约", "的", "天气", "真好", "啊", ",", "京华", "大酒店", "的", "李白",
            "经理", "吃", "了", "一只", "北京烤鸭", "。",
        ];
        assert_eq!(result.word, expected_words);
        let expected_tags = vec![
            "t", "ns", "uj", "n", "d", "zg", "x", "nz", "n", "uj", "nr", "n", "v", "ul", "m",
            "n", "x",
        ];
        assert_eq!(result.tag, expected_tags);
        // Spans are (start, end-exclusive) character indices into the sentence.
        let expected_entities = vec![
            (2, 4, "location"),
            (11, 16, "org_name"),
            (17, 19, "person_name"),
            (25, 27, "location"),
        ];
        assert_eq!(result.entity, expected_entities);
    }
}