1use crate::banned::BANNED;
2use crate::buffer_proxy_iterator::BufferProxyIterator;
3use crate::mtch::*;
4use crate::replacements::REPLACEMENTS;
5use crate::trie::*;
6use crate::Set;
7use crate::{is_whitespace, Replacements, Type};
8use std::iter::Filter;
9use std::mem;
10use std::ops::Deref;
11use std::ops::RangeInclusive;
12use std::str::Chars;
13use unicode_normalization::{Decompositions, Recompositions, UnicodeNormalization};
14
15pub struct Censor<I: Iterator<Item = char>> {
20    buffer: BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>>,
23    options: Options,
24    inline: InlineState,
25    allocated: AllocatedState,
26}
27
28struct Options {
29    trie: &'static Trie,
30    replacements: &'static Replacements,
31    ignore_false_positives: bool,
33    ignore_self_censoring: bool,
34    censor_first_character_threshold: Type,
35    censor_replacement: char,
37    censor_threshold: Type,
38}
39
40impl Default for Options {
41    fn default() -> Self {
42        Self {
43            trie: &*TRIE,
44            replacements: &*REPLACEMENTS,
45            ignore_false_positives: false,
47            ignore_self_censoring: false,
48            censor_first_character_threshold: Type::OFFENSIVE & Type::SEVERE,
49            censor_replacement: '*',
51            censor_threshold: Default::default(),
52        }
53    }
54}
55
56struct InlineState {
57    separate: bool,
59    last_pos: usize,
61    typ: Type,
63    uppercase: u8,
65    repetitions: u8,
66    last: Option<char>,
67    gibberish: u8,
68    replacements: u8,
69    self_censoring: u8,
71    safe: bool,
73    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
74    match_ptrs: usize,
75    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
76    total_matches: usize,
77    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
78    total_match_characters: usize,
79    space_appended: bool,
81    done: bool,
83}
84
85impl Default for InlineState {
86    fn default() -> Self {
87        Self {
88            separate: true,
90            typ: Type::NONE,
92            uppercase: 0,
93            repetitions: 0,
94            last: None,
95            gibberish: 0,
96            replacements: 0,
97            self_censoring: 0,
98            safe: false,
99            space_appended: false,
100            done: false,
101            last_pos: usize::MAX,
102            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
103            match_ptrs: 0,
104            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
105            total_matches: 0,
106            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
107            total_match_characters: 0,
108        }
109    }
110}
111
112#[derive(Default)]
113struct AllocatedState {
114    matches: Set<Match>,
116    matches_tmp: Set<Match>,
118    pending_commit: Vec<Match>,
120    #[cfg(feature = "trace_full")]
121    detections: crate::Map<String, usize>,
122}
123
124impl AllocatedState {
125    fn clear(&mut self) {
126        let Self {
127            matches,
128            matches_tmp,
129            pending_commit,
130            #[cfg(feature = "trace_full")]
131            detections,
132        } = self;
133        matches.clear();
134        matches_tmp.clear();
135        pending_commit.clear();
136        #[cfg(feature = "trace_full")]
137        detections.clear();
138    }
139}
140
141impl<'a> Censor<Chars<'a>> {
142    pub fn from_str(s: &'a str) -> Self {
144        Self::new(s.chars())
145    }
146}
147
148impl<I: Iterator<Item = char>> Censor<I> {
149    pub fn new(text: I) -> Self {
151        Self {
152            buffer: Self::buffer_from(text),
153            options: Default::default(),
154            inline: Default::default(),
155            allocated: Default::default(),
156        }
157    }
158
159    fn buffer_from(
160        text: I,
161    ) -> BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>> {
162        fn filter_char(c: &char) -> bool {
165            use finl_unicode::categories::{CharacterCategories, MinorCategory};
166            let category = c.get_minor_category();
167            let nok = matches!(
168                category,
169                MinorCategory::Cn | MinorCategory::Co | MinorCategory::Mn
170            );
171
172            !(nok || BANNED.deref().deref().contains(*c))
173        }
174
175        BufferProxyIterator::new(
176            text
177                .nfd()
179                .filter(filter_char as fn(&char) -> bool)
180                .nfc(),
181        )
182    }
183
184    pub fn reset(&mut self, text: I) {
187        self.inline = Default::default();
188        self.allocated.clear();
189        self.buffer = Self::buffer_from(text);
190    }
191
192    pub fn with_trie(&mut self, trie: &'static Trie) -> &mut Self {
194        self.options.trie = trie;
195        self
196    }
197
198    pub fn with_replacements(&mut self, replacements: &'static Replacements) -> &mut Self {
200        self.options.replacements = replacements;
201        self
202    }
203
204    pub fn with_censor_threshold(&mut self, censor_threshold: Type) -> &mut Self {
211        self.options.censor_threshold = censor_threshold;
212        self
213    }
214
215    pub fn with_ignore_false_positives(&mut self, ignore_false_positives: bool) -> &mut Self {
220        self.options.ignore_false_positives = ignore_false_positives;
221        self
222    }
223
224    pub fn with_ignore_self_censoring(&mut self, ignore_self_censoring: bool) -> &mut Self {
234        self.options.ignore_self_censoring = ignore_self_censoring;
235        self
236    }
237
238    pub fn with_censor_first_character_threshold(
243        &mut self,
244        censor_first_character_threshold: Type,
245    ) -> &mut Self {
246        self.options.censor_first_character_threshold = censor_first_character_threshold;
247        self
248    }
249
250    pub fn with_censor_replacement(&mut self, censor_replacement: char) -> &mut Self {
263        self.options.censor_replacement = censor_replacement;
264        self
265    }
266
267    #[cfg(feature = "find_false_positives")]
269    pub fn with_separate(&mut self, separate: bool) -> &mut Self {
270        self.inline.separate = separate;
271        self
272    }
273
274    pub fn censor(&mut self) -> String {
286        assert!(
287            !self.buffer.index().is_some(),
288            "censor must be called before any other form of processing"
289        );
290        self.collect()
291    }
292
293    pub fn analyze(&mut self) -> Type {
297        self.ensure_done();
298        self.analysis()
299    }
300
301    pub fn censor_and_analyze(&mut self) -> (String, Type) {
303        let censored = self.censor();
305        (censored, self.analysis())
307    }
308
309    fn analysis(&self) -> Type {
311        self.inline.typ | self.safe_self_censoring_and_spam_detection()
312    }
313
314    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
315    pub fn match_ptrs(&self) -> usize {
316        self.inline.match_ptrs
317    }
318
319    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
320    pub fn total_matches(&self) -> usize {
321        self.inline.total_matches
322    }
323
324    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
325    pub fn total_match_characters(&self) -> usize {
326        self.inline.total_match_characters
327    }
328
329    #[cfg(feature = "trace_full")]
330    pub fn detections(&self) -> &crate::Map<String, usize> {
331        &self.allocated.detections
332    }
333
334    fn ensure_done(&mut self) {
335        if !self.inline.done {
336            for _ in self {}
337        }
338    }
339
340    fn safe_self_censoring_and_spam_detection(&self) -> Type {
341        let safe = if self.inline.safe && self.inline.repetitions < 4 {
342            Type::SAFE
343        } else {
344            Type::NONE
345        };
346
347        if self.inline.last_pos < 6 {
348            return safe;
350        }
351
352        let total = self
355            .inline
356            .last_pos
357            .saturating_add(6)
358            .min(u16::MAX as usize) as u16;
359
360        let spam = self
362            .inline
363            .uppercase
364            .max(self.inline.repetitions)
365            .max(self.inline.gibberish / 2)
366            .max(self.inline.replacements) as u16;
367
368        let percent_spam = 100 * spam / total;
370        let percent_self_censoring = 100 * self.inline.self_censoring as u16 / total;
371
372        let spam = if percent_spam >= 70 && self.inline.last_pos >= 20 {
374            Type::SPAM & Type::SEVERE
375        } else if percent_spam >= 50 && self.inline.last_pos >= 10 {
376            Type::SPAM & Type::MODERATE
377        } else if percent_spam >= 30 {
378            Type::SPAM & Type::MILD
379        } else {
380            Type::NONE
381        };
382
383        let self_censoring = if !self.options.ignore_self_censoring && percent_self_censoring > 20 {
385            Type::PROFANE & Type::MILD
386        } else {
387            Type::NONE
388        };
389
390        safe | spam | self_censoring
391    }
392}
393
394impl<I: Iterator<Item = char>> Iterator for Censor<I> {
395    type Item = char;
396
397    fn next(&mut self) -> Option<Self::Item> {
399        while let Some(raw_c) = self.buffer.next().or_else(|| {
400            if self.inline.space_appended {
401                None
402            } else {
403                self.inline.space_appended = true;
404                Some(' ')
405            }
406        }) {
407            if !self.inline.space_appended && raw_c != '!' && raw_c != '.' && raw_c != '?' {
408                self.inline.safe = false;
410            }
411
412            let pos = self.buffer.index();
413
414            self.inline.uppercase = self
415                .inline
416                .uppercase
417                .saturating_add(raw_c.is_uppercase() as u8);
418
419            let skippable = !raw_c.is_alphabetic() || is_whitespace(raw_c);
420            let replacement = self.options.replacements.get(raw_c);
421
422            #[cfg(feature = "trace")]
423            println!(
424                "Read '{}', skippable={}, replacing with={:?}",
425                raw_c, skippable, replacement
426            );
427
428            const BLOCK_ELEMENTS: RangeInclusive<char> = '\u{2580}'..='\u{259F}';
429
430            if (!self.inline.separate || self.inline.last == Some(self.options.censor_replacement))
431                && (raw_c == self.options.censor_replacement || BLOCK_ELEMENTS.contains(&raw_c))
432            {
433                self.inline.self_censoring = self.inline.self_censoring.saturating_add(1);
435            }
436
437            if let Some(last) = self.inline.last {
438                if raw_c == last {
439                    self.inline.repetitions = self.inline.repetitions.saturating_add(1);
440                }
441
442                fn is_gibberish(c: char) -> bool {
444                    matches!(c, 'a' | 's' | 'd' | 'f' | 'j' | 'k' | 'l' | ';')
445                }
446
447                if is_gibberish(raw_c) && is_gibberish(last) {
449                    self.inline.gibberish = self.inline.gibberish.saturating_add(1);
450                }
451            }
452
453            if let Some(pos) = pos {
454                if !(skippable
459                    && replacement.is_none()
460                    && !self.options.trie.root.children.contains_key(&raw_c))
461                {
462                    let begin_camel_case_word = raw_c.is_ascii_uppercase()
463                        && self
464                            .inline
465                            .last
466                            .map(|c| !c.is_ascii_uppercase())
467                            .unwrap_or(false);
468
469                    self.allocated.matches.insert(Match {
471                        node: &self.options.trie.root,
472                        start: pos, end: usize::MAX, last: 0 as char, begin_separate: self.inline.separate || begin_camel_case_word,
476                        end_separate: false, spaces: 0,
478                        skipped: 0,
479                        replacements: 0,
480                        repetitions: 0,
481                        low_confidence_replacements: 0,
482                    });
483                }
484            }
485
486            self.inline.separate = skippable;
487
488            if self.inline.separate {
489                for pending in self.allocated.pending_commit.iter_mut() {
490                    if pending.end == self.inline.last_pos {
491                        pending.end_separate = true;
492                    }
493                }
494            }
495
496            let mut drain_start: Option<usize> = None;
497            let mut safety_end = usize::MAX;
498            let mut replacement_counted = false;
499            let raw_c_lower = raw_c.to_lowercase().next().unwrap();
500
501            mem::swap(&mut self.allocated.matches, &mut self.allocated.matches_tmp);
502            for c in replacement
503                .map(|a| a.as_str())
504                .unwrap_or(&&*raw_c.encode_utf8(&mut [0; 4]))
505                .chars()
506            {
507                let benign_replacement = c == raw_c || c == raw_c_lower;
509
510                let countable_replacement = !(replacement_counted
512                    || benign_replacement
513                    || raw_c.is_ascii_alphabetic()
514                    || (raw_c.is_ascii_digit()
515                        && self
516                            .inline
517                            .last
518                            .map(|l| l.is_ascii_digit())
519                            .unwrap_or(false)));
520
521                if countable_replacement {
522                    self.inline.replacements = self.inline.replacements.saturating_add(1);
523                    replacement_counted = true;
524                }
525
526                #[cfg(feature = "trace")]
527                println!(
528                    " - Replacement '{}', benign={}, countable={}",
529                    c, benign_replacement, countable_replacement
530                );
531
532                let ignore_sep = matches!(c, '-' | '\'' | '\n' | '\r');
541
542                for m in self.allocated.matches_tmp.iter() {
543                    let m = m.clone();
544
545                    if m.low_confidence_replacements > 5
546                        || m.skipped > 5
547                        || (m.node.word && m.repetitions > 20)
548                    {
549                        #[cfg(feature = "trace")]
550                        println!("throwing out low confidence match: \"{}\"", m.node.trace);
551                        }
553
554                    safety_end = safety_end.min(m.start);
555
556                    #[cfg(feature = "trace")]
557                    println!(
558                        "  - Consider match \"{}\" with spaces={}, replacements={}",
559                        m.node.trace, m.spaces, m.replacements
560                    );
561
562                    if (skippable || c == m.last || Some(c) == m.node.last)
563                        && m.start != pos.unwrap_or(0)
564                    {
565                        let new_space = matches!(c, ' ' | '.' | ',' | ':' | ';' | '…' | '(' | ')' | '_' | '-')
570                            && m.node.last != Some(' ');
571                        let new_repetition: bool = !new_space && c == m.last;
572                        let new_skip = !new_space && skippable && !ignore_sep && !new_repetition;
573                        let new_replacement = c == m.last && raw_c != c && !new_repetition;
575                        let new_low_confidence_replacement =
576                            new_replacement && raw_c.is_ascii_digit();
577
578                        let undo_m = Match {
579                            spaces: m.spaces.saturating_add(new_space as u8),
580                            skipped: m.skipped.saturating_add(new_skip as u8),
581                            replacements: m.replacements.saturating_add(new_replacement as u8),
582                            low_confidence_replacements: m
583                                .low_confidence_replacements
584                                .saturating_add(new_low_confidence_replacement as u8),
585                            repetitions: m.repetitions.saturating_add(new_repetition as u8),
586                            last: c,
587                            ..m
588                        };
589                        #[cfg(feature = "trace")]
590                        println!("    (keep with last={}, node last={:?}, spaces={}, skip={}, repl={}, repet={})", undo_m.last, undo_m.node.last, undo_m.spaces, undo_m.skipped, undo_m.replacements, undo_m.repetitions);
591
592                        if let Some(existing) = self.allocated.matches.get(&undo_m) {
593                            let replacement = existing.combine(&undo_m);
594                            self.allocated.matches.replace(replacement);
595                        } else {
596                            self.allocated.matches.insert(undo_m);
597                        }
598                    }
599
600                    if let Some(next) = m.node.children.get(&c) {
601                        let new_replacement = !benign_replacement && (c != raw_c) && c != ' ';
602                        let new_low_confidence_replacement =
603                            new_replacement && raw_c.is_ascii_digit();
604                        let new_space =
605                            !new_replacement && (raw_c != c && self.inline.separate && c != '\'');
606
607                        let next_m = Match {
608                            node: next,
609                            spaces: m.spaces.saturating_add(new_space as u8),
610                            replacements: m.replacements.saturating_add(new_replacement as u8),
611                            low_confidence_replacements: m
612                                .low_confidence_replacements
613                                .saturating_add(new_low_confidence_replacement as u8),
614                            last: c,
615                            ..m
616                        };
617
618                        #[cfg(feature = "trace")]
619                        println!(
620                            "     - Next is \"{}\", with spaces={}, replacements={}",
621                            next.trace, next_m.spaces, next_m.replacements
622                        );
623
624                        if next.word {
625                            if next_m.node.typ.is(Type::SAFE)
626                                && next_m.start == 0
627                                && next_m.spaces == 0
628                                && next_m.skipped == 0
629                                && next_m.replacements == 0
630                                && !self.options.ignore_false_positives
631                            {
632                                #[cfg(feature = "trace")]
634                                println!("found safe word: {}", next_m.node.trace);
635                                self.inline.safe = true;
636                            }
637
638                            if next_m.node.typ.is(Type::ANY) {
654                                self.allocated.pending_commit.push(Match {
655                                    end: pos.unwrap(),
656                                    ..next_m
657                                });
658                            } else if next_m.spaces == 0
659                                && next_m.skipped == 0
660                                && next_m.replacements == 0
661                                && next_m.repetitions == 0 && !self.options.ignore_false_positives
663                            {
664                                #[cfg(feature = "trace")]
666                                println!("Found false positive {}", next_m.node.trace);
667                                drain_start = Some(
668                                    drain_start
669                                        .map(|start| start.min(next_m.start))
670                                        .unwrap_or(next_m.start),
671                                );
672                            }
673                        }
674
675                        if let Some(existing) = self.allocated.matches.get(&next_m) {
676                            let replacement = existing.combine(&next_m);
677                            self.allocated.matches.replace(replacement);
678                        } else {
679                            self.allocated.matches.insert(next_m);
680                        }
681                    }
682                }
683            }
684            self.allocated.matches_tmp.clear();
685            self.inline.last = Some(raw_c);
686            if let Some(pos) = pos {
687                self.inline.last_pos = pos;
688            }
689
690            let spy = &mut self.buffer;
691            let options = &self.options;
692            let inline = &mut self.inline;
693            let pending_commit = &mut self.allocated.pending_commit;
694            #[cfg(feature = "trace_full")]
695            let detections = &mut self.allocated.detections;
696
697            pending_commit.retain(|pending| {
698                #[cfg(feature = "trace")]
699                println!("Consider whether to cancel pending commit {} with start={} against drain_start={:?}", pending.node.trace, pending.start, drain_start);
700
701                if let Some(start) = drain_start {
703                    if pending.start >= start {
704                        #[cfg(feature = "trace")]
705                        println!("Cancelled {}", pending.node.trace);
706                        return false;
707                    }
708                }
709
710                if pending.end < safety_end {
712                    if pending.commit(
713                        &mut inline.typ,
714                        spy,
715                        options.censor_threshold,
716                        options.censor_first_character_threshold,
717                        options.censor_replacement,
718                    ) {
719                        #[cfg(any(feature = "find_false_positives", feature = "trace"))]
720                        {
721                            inline.match_ptrs ^= pending.node as *const _ as usize;
722                            inline.total_matches += 1;
723                            inline.total_match_characters += pending.end - pending.start;
724                            #[cfg(feature = "trace_full")]
725                            {
726                                *detections.entry(pending.node.trace.clone()).or_default() += 1;
727                            }
728                        }
729                    }
730                    return false;
731                }
732
733                true
736            });
737
738            if let Some(spy_next_index) = self.buffer.spy_next_index() {
740                let mut safe_until = spy_next_index < safety_end;
742
743                for pending in &self.allocated.pending_commit {
745                    if pending.start <= spy_next_index {
746                        safe_until = false;
747                        break;
748                    }
749                }
750                if safe_until {
751                    return self.buffer.spy_next();
752                }
753            }
754        }
755
756        let residual = mem::take(&mut self.allocated.pending_commit);
757        #[cfg(feature = "trace")]
758        if !residual.is_empty() {
759            println!("{} residuals", residual.len());
760        }
761        for pending in residual {
762            if pending.commit(
763                &mut self.inline.typ,
764                &mut self.buffer,
765                self.options.censor_threshold,
766                self.options.censor_first_character_threshold,
767                self.options.censor_replacement,
768            ) {
769                #[cfg(any(feature = "find_false_positives", feature = "trace"))]
770                {
771                    self.inline.match_ptrs ^= pending.node as *const _ as usize;
772                    self.inline.total_matches += 1;
773                    self.inline.total_match_characters += pending.end - pending.start;
774                    #[cfg(feature = "trace_full")]
775                    {
776                        *self
777                            .allocated
778                            .detections
779                            .entry(pending.node.trace.clone())
780                            .or_default() += 1;
781                    }
782                }
783            }
784        }
785
786        if let Some(c) = self.buffer.spy_next() {
787            return Some(c);
788        }
789
790        self.inline.done = true;
791
792        None
793    }
794}
795
796pub trait CensorStr: Sized {
798    fn censor(self) -> String;
800
801    fn is_inappropriate(self) -> bool {
803        self.is(Type::INAPPROPRIATE)
804    }
805
806    fn is(self, threshold: Type) -> bool;
808
809    fn isnt(self, threshold: Type) -> bool {
811        !self.is(threshold)
812    }
813}
814
815impl CensorStr for &str {
816    fn censor(self) -> String {
817        if should_skip_censor(self) {
818            self.to_owned()
819        } else {
820            Censor::new(self.chars()).censor()
821        }
822    }
823
824    fn is(self, threshold: Type) -> bool {
825        Censor::from_str(self).analyze().is(threshold)
826    }
827}
828
/// Adapter trait that turns a character iterator into a censoring iterator.
pub trait CensorIter {
    /// The censoring iterator produced by [`CensorIter::censor`].
    type Iterator: Iterator<Item = char>;

    /// Wraps `self` so that iterating it yields censored characters.
    fn censor(self) -> Self::Iterator;
}
837
838impl<I: Iterator<Item = char> + Clone> CensorIter for I {
839    type Iterator = Censor<I>;
840
841    fn censor(self) -> Self::Iterator {
844        Censor::new(self)
845    }
846}
847
848pub(crate) fn should_skip_censor(string: &str) -> bool {
851    let mut some_special = false;
852    for c in string.chars() {
853        use finl_unicode::categories::CharacterCategories;
854        if ('\u{0900}'..='\u{097F}').contains(&c) {
856            some_special = true;
857        } else if !(c.is_whitespace() || c.is_separator()) {
858            return false;
859        }
860    }
861    some_special
862}
863
/// Sets the type of `word` in the default trie.
///
/// # Safety
///
/// This forwards to `Trie::customize_default()`.
// NOTE(review): the exact safety contract is that of `Trie::customize_default`, which is not
// visible here; presumably no other thread may be using the default trie. Confirm in `trie`.
#[cfg(feature = "customize")]
#[deprecated = "Use the equivalent Trie::customize_default().set(word, typ) or the safe API Censor::with_trie"]
pub unsafe fn add_word(word: &str, typ: Type) {
    Trie::customize_default().set(word, typ);
}
887
888#[cfg(test)]
889mod tests {
890    #![allow(unused_imports)]
891
892    extern crate test;
893    use crate::censor::should_skip_censor;
894    use crate::{Censor, CensorIter, CensorStr, Trie, Type};
895    use bitflags::_core::ops::Not;
896    use rand::prelude::ThreadRng;
897    use rand::{thread_rng, Rng};
898    use serial_test::serial;
899    use std::fs::File;
900    use std::io::BufReader;
901    use std::time::{Duration, Instant};
902    use test::Bencher;
903
904    #[test]
905    #[serial]
906    fn short_replacement() {
907        "99".isnt(Type::PROFANE);
908        "900".isnt(Type::PROFANE);
909        "kkk".is(Type::OFFENSIVE);
910    }
911
912    #[test]
913    #[serial]
914    fn unicode_whitespace() {
915        assert!("fu\u{1160}ck".is(Type::PROFANE));
916        assert!(!"fu\u{1161}ck".is(Type::PROFANE));
917    }
918
919    #[test]
920    #[serial]
921    fn unicode_abuse() {
922        let mut rng = thread_rng();
923
924        fn random_string(rng: &mut ThreadRng, len: usize) -> String {
925            rng.sample_iter::<char, _>(rand::distributions::Standard)
926                .take(len)
927                .collect()
928        }
929
930        for _ in 0..10 {
931            let input = random_string(&mut rng, 100);
932            let censored = input.censor();
933
934            assert!(censored.len() < input.len() / 2);
936
937            println!("{} -> {}", input, censored);
938        }
939    }
940
941    #[allow(dead_code)]
942    fn find_detection(text: &str) {
943        let holistic = Censor::from_str(text).analyze();
944
945        if holistic & Type::SPAM.not() != Type::NONE {
946            println!("{}", text);
947
948            let mut start = 0;
950            let mut end = text.chars().count();
951
952            while start < end
953                && Censor::new(text.chars().skip(start).take(end - start))
954                    .analyze()
955                    .is(Type::ANY)
956            {
957                start += 1;
958            }
959            start = start.saturating_sub(1);
960            while start < end
961                && Censor::new(text.chars().skip(start).take(end - start))
962                    .analyze()
963                    .is(Type::ANY)
964            {
965                end -= 1;
966            }
967            end += 1;
968            for _ in 0..start {
969                print!("-");
970            }
971            for _ in start..end {
972                print!("^");
973            }
974            print!(" ");
975            println!(
976                "(\"{}\" is {:?})",
977                text.chars()
978                    .skip(start)
979                    .take(end - start)
980                    .collect::<String>(),
981                holistic
982            );
983        } else {
984            println!("{} ({:?})", text, holistic);
985        }
986    }
987
988    #[test]
989    #[serial]
990    fn curated() {
991        let mut cases: Vec<(&str, bool, Option<bool>)> = vec![("", false, Some(false))];
992        cases.extend(
993            include_str!("test_positive.txt")
994                .split('\n')
995                .filter(|l| !l.is_empty())
996                .map(|l| (l, true, Some(false))),
997        );
998        cases.extend(
999            include_str!("test_negative.txt")
1000                .split('\n')
1001                .filter(|l| !l.is_empty())
1002                .map(|l| (l, false, None)),
1003        );
1004        cases.extend(
1005            include_str!("safe.txt")
1006                .split('\n')
1007                .filter(|l| !l.is_empty() && !l.starts_with('#'))
1008                .map(|l| (l, false, Some(true))),
1009        );
1010        cases.extend(
1011            include_str!("test_safe.txt")
1012                .split('\n')
1013                .filter(|l| !l.is_empty())
1014                .map(|l| (l, false, Some(true))),
1015        );
1016
1017        let mut failures = Vec::new();
1018
1019        for (case, any_truth, safe_truth) in cases {
1020            let typ = Censor::from_str(case).analyze();
1026            let any = typ.is(Type::ANY);
1027            let safe = typ.is(Type::SAFE);
1028
1029            if any != any_truth {
1033                find_detection(case);
1034                failures.push(format!("FAIL: Predicted {:?} for: \"{}\"", typ, case));
1035            } else if !any_truth {
1036                let censored = case.censor();
1038                if case != censored {
1039                    failures.push(format!("Censored: : \"{case}\" -> {censored}"))
1040                }
1041            }
1042            if let Some(safe_truth) = safe_truth {
1043                if safe != safe_truth {
1044                    failures.push(format!("FAIL: Predicted safe={} for: \"{}\"", safe, case));
1045                }
1046            }
1047        }
1048
1049        if !failures.is_empty() {
1050            for failure in failures {
1051                println!("{failure}");
1052            }
1053            panic!();
1054        }
1055    }
1056
1057    #[test]
1058    #[serial]
1059    fn censor() {
1060        let censored = Censor::from_str("HELLO fučk Shit nudes WORLD!")
1061            .with_censor_replacement('#')
1062            .with_censor_first_character_threshold(Type::SEXUAL & Type::SEVERE)
1063            .censor();
1064
1065        assert_eq!(censored, "HELLO f### S### ##### WORLD!");
1066
1067        assert_eq!("fcking coward".censor(), "f***** coward");
1069
1070        let censored = Censor::from_str("卍")
1071            .with_censor_first_character_threshold(Type::NONE)
1072            .censor();
1073
1074        assert_eq!(censored, "*");
1075    }
1076
1077    #[test]
1078    #[serial]
1079    fn bidirectional() {
1080        assert_eq!("an toidi", "an \u{202e}toidi".censor());
1082    }
1083
1084    #[test]
1085    #[serial]
1086    fn analyze() {
1087        let analysis = Censor::from_str("HELLO fuck shit WORLD!").analyze();
1088
1089        assert_ne!(analysis, Type::NONE);
1090        assert!(analysis.is(Type::INAPPROPRIATE));
1091        assert!(analysis.is(Type::PROFANE));
1092        assert!(analysis.isnt(Type::SEXUAL & Type::SEVERE));
1093        assert!(analysis.isnt(Type::OFFENSIVE));
1094        assert!(analysis.isnt(Type::MEAN));
1095    }
1096
1097    #[test]
1099    #[serial]
1100    fn apis() {
1101        "abcd".censor();
1102        String::from("abcd").censor();
1103        let _ = "abcd".chars().censor().collect::<String>();
1104        let (_, _) = Censor::new("abcd".chars())
1105            .with_censor_replacement('?')
1106            .censor_and_analyze();
1107        let mut censor = Censor::from_str("abcd");
1108        let _ = censor.censor();
1109        let _ = censor.analyze();
1110        let (_, _) = Censor::from_str("HELLO crap WORLD!").censor_and_analyze();
1111    }
1112
1113    #[test]
1114    #[serial]
1115    fn levels() {
1116        assert!("poo".is(Type::PROFANE & Type::MILD));
1117        assert!("poo".is(Type::PROFANE & Type::MILD_OR_HIGHER));
1118        assert!("poo".isnt(Type::PROFANE & Type::MODERATE));
1119        assert!("poo".isnt(Type::PROFANE & Type::MODERATE_OR_HIGHER));
1120        assert!("poo".isnt(Type::PROFANE & Type::SEVERE));
1121        assert!("arse".is(Type::PROFANE & Type::MODERATE));
1122        assert!("arse".is(Type::PROFANE & Type::MILD_OR_HIGHER));
1123        assert!("arse".is(Type::PROFANE & Type::MODERATE_OR_HIGHER));
1124        assert!("arse".isnt(Type::PROFANE & Type::MILD));
1125        assert!("arse".isnt(Type::PROFANE & Type::SEVERE));
1126        assert!("i hope you die".is(Type::MEAN & Type::SEVERE));
1127        assert!("i hope you die".is(Type::MEAN & Type::MILD_OR_HIGHER));
1128        assert!("i hope you die".is(Type::MEAN & Type::MODERATE_OR_HIGHER));
1129        assert!("i hope you die".isnt(Type::MEAN & Type::MILD));
1130        assert!("i hope you die".isnt(Type::MEAN & Type::MODERATE));
1131        assert!("You said your mother only smiled on her TV show".isnt(
1132            Type::PROFANE
1133                | Type::OFFENSIVE
1134                | Type::SEXUAL & Type::MODERATE_OR_HIGHER
1135                | Type::MEAN & Type::SEVERE
1136        ));
1137    }
1138
1139    #[test]
1140    #[serial]
1141    fn repetitions_non_safe() {
1142        assert!("hello".is(Type::SAFE));
1143        assert!("helllo".is(Type::SAFE));
1144        assert!("hellllllllo".isnt(Type::SAFE));
1145    }
1146
1147    #[test]
1148    #[serial]
1149    #[cfg(not(debug_assertions))]
1150    fn accuracy() {
1151        fn rustrict(s: &str) -> bool {
1152            s.is(Type::ANY)
1153        }
1154
1155        #[allow(dead_code)]
1156        fn rustrict_old(s: &str) -> bool {
1157            rustrict_old::CensorStr::is(s, rustrict_old::Type::ANY)
1158        }
1159
1160        fn censor(s: &str) -> bool {
1161            use censor_crate::*;
1162            let filter = Standard + Sex + Zealous;
1163            filter.check(s)
1164        }
1165
1166        let mut stfu_filter = stfu_crate::types::OwnedFilter::default();
1167        use stfu_crate::word_lists::severity::{MILD, SEVERE, STRONG};
1168        stfu_filter.add_slice(&MILD);
1169        stfu_filter.add_slice(&STRONG);
1170        stfu_filter.add_slice(&SEVERE);
1171
1172        let stfu = |s: &str| -> bool { stfu_filter.filter_string(s).is_some() };
1173
1174        fn profane_rs(s: &str) -> bool {
1175            profane_rs_crate::contains_profanity(s, false)
1176        }
1177
1178        println!("| Crate | Accuracy | Positive Accuracy | Negative Accuracy | Time |");
1179        println!("|-------|----------|-------------------|-------------------|------|");
1180        print_accuracy(
1181            "https://crates.io/crates/rustrict",
1182            rustrict,
1183            false, Some(rustrict_old as fn(&str) -> bool).filter(|_| std::env::var("COMPARE").is_ok()),
1185        );
1186        print_accuracy("https://crates.io/crates/censor", censor, false, None);
1187        print_accuracy("https://crates.io/crates/stfu", stfu, false, None);
1188        print_accuracy(
1189            "https://crates.io/crates/profane-rs",
1190            profane_rs,
1191            false,
1192            None,
1193        );
1194    }
1195
1196    #[allow(dead_code)]
1197    fn print_accuracy(
1198        link: &str,
1199        checker: impl Fn(&str) -> bool,
1200        find_detections: bool,
1201        compare_to: Option<fn(&str) -> bool>,
1202    ) {
1203        let start = Instant::now();
1204        let (total, positive, negative) = accuracy_of(checker, find_detections, compare_to);
1205        println!(
1206            "| [{}]({}) | {:.2}% | {:.2}% | {:.2}% | {:.2}s |",
1207            link.split('/').last().unwrap(),
1208            link,
1209            total * 100.0,
1210            positive * 100.0,
1211            negative * 100.0,
1212            start.elapsed().as_secs()
1213        );
1214    }
1215
1216    #[allow(dead_code)]
1217    fn accuracy_of(
1218        checker: impl Fn(&str) -> bool,
1219        find_detections: bool,
1220        compare_to: Option<fn(&str) -> bool>,
1221    ) -> (f32, f32, f32) {
1222        let file = File::open("test.csv").unwrap();
1223        let reader = BufReader::new(file);
1224        let mut csv = csv::Reader::from_reader(reader);
1225
1226        let mut correct_positive = 0;
1227        let mut correct_negative = 0;
1228        let mut total_positive = 0;
1229        let mut total_negative = 0;
1230
1231        for line in csv.records().take(100000) {
1232            let record = line.unwrap();
1233            let truth = record[0].parse::<i8>().unwrap() == 1;
1234            let text = &record[1];
1235            let prediction = checker(text);
1236            if prediction == truth {
1238                if truth {
1239                    correct_positive += 1;
1240                } else {
1241                    correct_negative += 1;
1242                }
1243            } else if find_detections && text.len() < 100 {
1244                println!("{}: {}", truth, text);
1245                if prediction {
1246                    find_detection(text);
1247                }
1248            }
1249            if let Some(checker) = compare_to {
1250                let compare_prediction = checker(text);
1251                if prediction != compare_prediction && text.len() < 100 {
1252                    println!("COMPARISON: On \"{}\", output {} instead", text, prediction);
1253                }
1254            }
1255            if truth {
1256                total_positive += 1;
1257            } else {
1258                total_negative += 1;
1259            }
1260        }
1261
1262        (
1263            (correct_positive + correct_negative) as f32 / (total_positive + total_negative) as f32,
1264            correct_positive as f32 / total_positive as f32,
1265            correct_negative as f32 / total_negative as f32,
1266        )
1267    }
1268
1269    #[test]
1270    #[serial]
1271    fn devanagari() {
1272        println!("f\u{0900}u\u{0900}c\u{0900}k");
1273        const TEST: &'static str = "हत्यारा मकसहूद भाई तुम बड़ा मस्त काम करती।";
1274        assert!(should_skip_censor(TEST));
1275        assert_eq!(TEST, TEST.censor());
1276    }
1277
1278    #[test]
1279    #[serial]
1280    fn pancakes() {
1281        assert_eq!(
1282            "🥞",
1283            std::str::from_utf8(&[240, 159, 165, 158]).unwrap().censor()
1284        );
1285    }
1286
1287    #[test]
1288    #[serial]
1289    fn bandwidth() {
1290        let file = File::open("test.csv").unwrap();
1291        let total_len = file.metadata().unwrap().len() as usize;
1292        let reader = BufReader::new(file);
1293        let mut csv = csv::Reader::from_reader(reader);
1294
1295        let mut text = String::with_capacity(total_len);
1296
1297        for line in csv.records().take(100000) {
1298            let record = line.unwrap();
1299            text.push_str(&record[1]);
1300        }
1301
1302        for power in 1..16 {
1303            let len = 2usize.pow(power);
1304
1305            if len > text.len() {
1306                break;
1307            }
1308
1309            let now = Instant::now();
1310
1311            let (_, _) = Censor::from_str(&text[0..len]).censor_and_analyze();
1312
1313            let elapsed = now.elapsed();
1314
1315            println!(
1316                "{}, {}, {}",
1317                len,
1318                elapsed.as_secs_f32(),
1319                len as f32 / elapsed.as_secs_f32() / 1000.0 / 1000.0
1320            );
1321        }
1322    }
1323
1324    #[cfg(feature = "customize")]
1325    #[test]
1326    #[serial]
1327    #[allow(deprecated)]
1328    fn customize() {
1329        use crate::add_word;
1330
1331        let test_profanity = "thisisafakeprofanityfortesting";
1332        let test_profanity_issue_7 = "плохоеслово";
1333        let test_safe = "thisisafakesafewordfortesting";
1334
1335        unsafe {
1337            add_word(test_profanity, Type::PROFANE & Type::SEVERE);
1338            add_word(test_profanity_issue_7, Type::PROFANE & Type::SEVERE);
1339            add_word(test_safe, Type::SAFE);
1340        }
1341
1342        assert!(test_profanity.is(Type::PROFANE & Type::SEVERE));
1343        assert!(test_profanity_issue_7.is(Type::PROFANE & Type::SEVERE));
1344        assert!(test_safe.is(Type::SAFE));
1345
1346        unsafe {
1347            add_word(test_profanity, Type::NONE);
1348        }
1349
1350        assert!(test_profanity.isnt(Type::PROFANE));
1351    }
1352
1353    #[cfg(feature = "serde")]
1354    #[test]
1355    #[serial]
1356    fn serde() {
1357        let large = Trie::default();
1358        let bc = bincode::serialize(&large).unwrap();
1359        let json = serde_json::to_string(&large).unwrap();
1360        println!("large bincode {}, large json {}", bc.len(), json.len());
1361
1362        let mut trie = Trie::new();
1363        trie.set("squeak", Type::SPAM & Type::MILD);
1364        trie.set("squirrel", Type::SAFE);
1365
1366        let bc = bincode::serialize(&trie).unwrap();
1367        println!("smol bincode (len {}): {bc:?}", bc.len());
1368        let json = serde_json::to_string(&trie).unwrap();
1369        println!("smol json (len {}): {json}", json.len());
1370    }
1371
1372    #[allow(soft_unstable)]
1373    #[bench]
1374    fn bench_is_inappropriate(b: &mut Bencher) {
1375        b.iter(|| test::black_box("hello fuck world shit").is_inappropriate());
1376    }
1377
1378    #[allow(soft_unstable)]
1379    #[bench]
1380    fn bench_is_inappropriate_long(b: &mut Bencher) {
1381        b.iter(|| test::black_box("hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit").is_inappropriate());
1382    }
1383
1384    #[allow(soft_unstable)]
1385    #[bench]
1386    fn bench_censor(b: &mut Bencher) {
1387        b.iter(|| test::black_box("hello fuck world shit").censor());
1388    }
1389}