chinese_ner/
lib.rs
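//! Chinese named entity recognition backed by a CRF model.
//!
//! Sentences are segmented and part-of-speech tagged with `jieba-rs`,
//! expanded into per-character CRF features, and labelled with a
//! `crfsuite` model. `ChineseNER` runs prediction; `NERTrainer` trains
//! a new model from a labelled corpus.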

use std::error;
use std::fmt;
use std::fs::File;
use std::io::prelude::*;
use std::io::{self, BufReader};
use std::path::Path;

use jieba_rs::Jieba;

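/// Errors returned by this crate: I/O failures or `crfsuite` failures.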
#[derive(Debug)]
pub enum Error {
    Io(io::Error),
    Crf(crfsuite::CrfError),
}

impl error::Error for Error {
    fn description(&self) -> &str {
        match *self {
            Error::Io(_) => "I/O error",
            Error::Crf(_) => "crfsuite error",
        }
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Error::Io(ref err) => err.fmt(f),
            Error::Crf(ref err) => err.fmt(f),
        }
    }
}

impl From<io::Error> for Error {
    #[inline]
    fn from(err: io::Error) -> Error {
        Error::Io(err)
    }
}

impl From<crfsuite::CrfError> for Error {
    #[inline]
    fn from(err: crfsuite::CrfError) -> Error {
        Error::Crf(err)
    }
}

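/// A named entity recognizer pairing a trained CRF model with a jieba segmenter.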
#[derive(Debug)]
pub struct ChineseNER {
    model: crfsuite::Model,
    segmentor: jieba_rs::Jieba,
}

#[cfg(feature = "bundled-model")]
impl Default for ChineseNER {
    fn default() -> ChineseNER {
        ChineseNER::new()
    }
}

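/// Prediction result for one sentence: the segmented words, their
/// part-of-speech tags, and `(start, end, entity_type)` spans where the
/// indices are character positions in the sentence (end exclusive).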
#[derive(Debug, Clone, PartialEq)]
pub struct NamedEntity<'a> {
    pub word: Vec<&'a str>,
    pub tag: Vec<&'a str>,
    pub entity: Vec<(usize, usize, &'static str)>,
}

impl ChineseNER {
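    /// Builds a recognizer from the model bundled into the crate at compile time.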
    #[cfg(feature = "bundled-model")]
    pub fn new() -> Self {
        let model_bytes = include_bytes!("ner.model");
        let model = crfsuite::Model::from_memory(&model_bytes[..]).expect("open model failed");
        Self {
            model,
            segmentor: Jieba::new(),
        }
    }

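    /// Builds a recognizer from a CRF model file on disk.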
    pub fn from_model(model_path: &str) -> Result<Self, Error> {
        let model = crfsuite::Model::from_file(model_path)?;
        Ok(Self {
            model,
            segmentor: Jieba::new(),
        })
    }

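    /// Segments `sentence`, extracts CRF features, and tags it with the model,
    /// returning the words, part-of-speech tags, and recognized entity spans.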
    pub fn predict<'a>(&'a self, sentence: &'a str) -> Result<NamedEntity<'a>, Error> {
        use crfsuite::Attribute;

        let mut tagger = self.model.tagger()?;
        let (split_words, tags) = split_by_words(&self.segmentor, sentence);
        let features = sent2features(&split_words);
        let attributes: Vec<crfsuite::Item> = features
            .into_iter()
            .map(|x| {
                x.into_iter()
                    .map(|f| Attribute::new(f, 1.0))
                    .collect::<crfsuite::Item>()
            })
            .collect();
        let tag_result = tagger.tag(&attributes)?;
        let mut is_tag = false;
        let mut start_index = 0;
        let mut entities = Vec::new();
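        // Walk the per-character tag sequence and collect entity spans:
        // a span opens at a tag starting with 'B' and closes at the next "O" tag.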
        for (index, tag) in tag_result.iter().enumerate() {
            if !is_tag && tag.starts_with('B') {
                start_index = index;
                is_tag = true;
            } else if is_tag && tag == "O" {
                entities.push((start_index, index, get_tag_name(&tag_result[start_index])));
                is_tag = false;
            }
        }
        let words = tags.iter().map(|x| x.word).collect();
        let tags = tags.iter().map(|x| x.tag).collect();
        Ok(NamedEntity {
            word: words,
            tag: tags,
            entity: entities,
        })
    }
}

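/// Maps a raw CRF label to a human-readable entity type based on the
/// substring it contains (PRO, PER, TIM, ORG or LOC).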
fn get_tag_name(tag: &str) -> &'static str {
    if tag.contains("PRO") {
        "product_name"
    } else if tag.contains("PER") {
        "person_name"
    } else if tag.contains("TIM") {
        "time"
    } else if tag.contains("ORG") {
        "org_name"
    } else if tag.contains("LOC") {
        "location"
    } else {
        "unknown"
    }
}

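/// Per-character view of a sentence: the character, its position within the
/// enclosing jieba word (`B`/`I`/`E`/`S`), that word's part-of-speech tag,
/// and (for training data) the character's entity label.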
#[derive(Debug, PartialEq)]
struct SplitWord<'a> {
    word: &'a str,
    status: &'static str,
    tag: String,
    entity_type: String,
}

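/// Splits `sentence` into single characters and annotates each one with the
/// segmentation status and part-of-speech tag of the jieba word it belongs to.
/// Also returns jieba's word/tag pairs for the whole sentence.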
fn split_by_words<'a>(
    segmentor: &'a Jieba,
    sentence: &'a str,
) -> (Vec<SplitWord<'a>>, Vec<jieba_rs::Tag<'a>>) {
    let mut words = Vec::new();
    let mut char_indices = sentence.char_indices().map(|x| x.0).peekable();
    while let Some(pos) = char_indices.next() {
        let word = match char_indices.peek() {
            Some(next_pos) => &sentence[pos..*next_pos],
            None => &sentence[pos..],
        };
        words.push(SplitWord {
            word,
            status: "",
            tag: String::new(),
            entity_type: String::new(),
        });
    }
    let tags = segmentor.tag(sentence, true);
    let mut index = 0;
    for word_tag in &tags {
        let char_count = word_tag.word.chars().count();
        for i in 0..char_count {
            let status = if char_count == 1 {
                "S"
            } else if i == 0 {
                "B"
            } else if i == char_count - 1 {
                "E"
            } else {
                "I"
            };
            words[index].status = status;
            words[index].tag = word_tag.tag.to_string();
            index += 1;
        }
    }
    (words, tags)
}

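/// Builds the CRF feature vector for every character of a sentence.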
fn sent2features(split_words: &[SplitWord]) -> Vec<Vec<String>> {
    let mut features = Vec::with_capacity(split_words.len());
    for i in 0..split_words.len() {
        features.push(word2features(split_words, i));
    }
    features
}

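/// CRF features for the character at position `i`: the character itself,
/// whether it is an ASCII digit, its part-of-speech and segmentation tags,
/// and the same attributes for its neighbours (or `BOS`/`EOS` at the edges).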
fn word2features(split_words: &[SplitWord], i: usize) -> Vec<String> {
    let split_word = &split_words[i];
    let word = split_word.word;
    let is_digit = word.chars().all(|c| c.is_ascii_digit());
    let mut features = vec![
        "bias".to_string(),
        format!("word={}", word),
        format!("word.isdigit={}", if is_digit { "True" } else { "False" }),
        format!("postag={}", split_word.tag),
        format!("cuttag={}", split_word.status),
    ];
    if i > 0 {
        let prev = &split_words[i - 1];
        features.push(format!("-1:word={}", prev.word));
        features.push(format!("-1:postag={}", prev.tag));
        features.push(format!("-1:cuttag={}", prev.status));
    } else {
        features.push("BOS".to_string());
    }
    if i < split_words.len() - 1 {
        let next = &split_words[i + 1];
        features.push(format!("+1:word={}", next.word));
        features.push(format!("+1:postag={}", next.tag));
        features.push(format!("+1:cuttag={}", next.status));
    } else {
        features.push("EOS".to_string());
    }
    features
}

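/// Trains a CRF model from a labelled corpus and writes it to an output path.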
pub struct NERTrainer {
    trainer: crfsuite::Trainer,
    segmentor: jieba_rs::Jieba,
    output_path: String,
}

impl NERTrainer {
    pub fn new(output_path: &str) -> Self {
        Self {
            trainer: crfsuite::Trainer::new(true),
            segmentor: Jieba::new(),
            output_path: output_path.to_string(),
        }
    }

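    /// Trains the model from a dataset in which each non-empty line is expected
    /// to hold a single character and its entity label separated by whitespace,
    /// with sentences separated by blank lines. The trained model is written to
    /// `output_path`.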
    pub fn train<T: AsRef<Path>>(&mut self, dataset_path: T) -> Result<(), Error> {
        let file = File::open(dataset_path)?;
        let reader = BufReader::new(file);
        let lines = reader.lines().collect::<Result<Vec<String>, _>>()?;
        let mut x_train = Vec::new();
        let mut y_train = Vec::new();
        let mut words: Vec<SplitWord> = Vec::new();
        for line in &lines {
            if line.is_empty() {
                // A blank line ends the current sentence: annotate it with jieba
                // segmentation/POS tags, convert it to a feature and label
                // sequence, then start collecting the next sentence.
                let sentence: String = words.iter().map(|x| x.word).collect();
                let tags = self.segmentor.tag(&sentence, true);
                let mut index = 0;
                for word_tag in tags {
                    let char_count = word_tag.word.chars().count();
                    for i in 0..char_count {
                        let status = if char_count == 1 {
                            "S"
                        } else if i == 0 {
                            "B"
                        } else if i == char_count - 1 {
                            "E"
                        } else {
                            "I"
                        };
                        words[index].status = status;
                        words[index].tag = word_tag.tag.to_string();
                        index += 1;
                    }
                }
                x_train.push(sent2features(&words));
                y_train.push(
                    words
                        .iter()
                        .map(|x| x.entity_type.to_string())
                        .collect::<Vec<_>>(),
                );
                words.clear();
            } else {
                let parts: Vec<&str> = line.split_ascii_whitespace().collect();
                words.push(SplitWord {
                    word: parts[0],
                    status: "",
                    tag: String::new(),
                    entity_type: parts[1].to_string(),
                });
            }
        }
        self.trainer
            .select(crfsuite::Algorithm::LBFGS, crfsuite::GraphicalModel::CRF1D)?;
        for (features, yseq) in x_train.into_iter().zip(y_train) {
            let xseq: Vec<crfsuite::Item> = features
                .into_iter()
                .map(|x| {
                    x.into_iter()
                        .map(|f| crfsuite::Attribute::new(f, 1.0))
                        .collect::<crfsuite::Item>()
                })
                .collect();
            self.trainer.append(&xseq, &yseq, 0)?;
        }
        self.trainer.train(&self.output_path, -1)?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use jieba_rs::Jieba;

    #[test]
    fn test_split_by_words() {
        let jieba = Jieba::new();
        let sentence = "洗衣机,国内掀起了大数据、云计算的热潮。仙鹤门地区。";
        let (ret, _) = split_by_words(&jieba, sentence);
        assert_eq!(
            ret,
            vec![
                SplitWord {
                    word: "洗",
                    status: "B",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "衣",
                    status: "I",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "机",
                    status: "E",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: ",",
                    status: "S",
                    tag: "x".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "国",
                    status: "B",
                    tag: "s".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "内",
                    status: "E",
                    tag: "s".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "掀",
                    status: "B",
                    tag: "v".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "起",
                    status: "E",
                    tag: "v".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "了",
                    status: "S",
                    tag: "ul".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "大",
                    status: "S",
                    tag: "a".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "数",
                    status: "B",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "据",
                    status: "E",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "、",
                    status: "S",
                    tag: "x".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "云",
                    status: "S",
                    tag: "ns".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "计",
                    status: "B",
                    tag: "v".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "算",
                    status: "E",
                    tag: "v".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "的",
                    status: "S",
                    tag: "uj".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "热",
                    status: "B",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "潮",
                    status: "E",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "。",
                    status: "S",
                    tag: "x".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "仙",
                    status: "B",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "鹤",
                    status: "E",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "门",
                    status: "S",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "地",
                    status: "B",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "区",
                    status: "E",
                    tag: "n".to_string(),
                    entity_type: String::new()
                },
                SplitWord {
                    word: "。",
                    status: "S",
                    tag: "x".to_string(),
                    entity_type: String::new()
                },
            ]
        );
    }

    #[cfg(feature = "bundled-model")]
    #[test]
    fn test_ner_predict() {
        let ner = ChineseNER::new();
        let sentence = "今天纽约的天气真好啊,京华大酒店的李白经理吃了一只北京烤鸭。";
        let result = ner.predict(sentence).unwrap();
        assert_eq!(
            result.word,
            vec![
                "今天",
                "纽约",
                "的",
                "天气",
                "真好",
                "啊",
                ",",
                "京华",
                "大酒店",
                "的",
                "李白",
                "经理",
                "吃",
                "了",
                "一只",
                "北京烤鸭",
                "。"
            ]
        );
        assert_eq!(
            result.tag,
            vec![
                "t", "ns", "uj", "n", "d", "zg", "x", "nz", "n", "uj", "nr", "n", "v", "ul", "m",
                "n", "x"
            ]
        );
        assert_eq!(
            result.entity,
            vec![
                (2, 4, "location"),
                (11, 16, "org_name"),
                (17, 19, "person_name"),
                (25, 27, "location")
            ]
        );
    }
}