1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/// Selects the formula `score` uses to weight a term occurrence in a document.
#[derive(Debug, PartialEq)]
pub enum SimilarityModel {
    /// Product of the term-frequency and inverse-document-frequency weights.
    /// Document length and total token count are ignored by this model.
    TF_IDF,
    /// BM25-style weighting with term-frequency saturation and length
    /// normalisation.
    BM25 {
        /// Scales the normalisation term added to the term-frequency weight
        /// in the denominator (and contributes the `k1 + 1.0` factor).
        k1: f64,
        /// Weight given to the document-length ratio versus the constant
        /// `1.0 - b`; `0.0` removes the length ratio from the score.
        b: f64,
    },
}


/// Sublinearly scaled term-frequency weight.
///
/// tf(term_frequency) = log(term_frequency) + 1.0   if term_frequency > 0
/// tf(0)              = 0.0
///
/// A term that does not occur carries no weight. Without the explicit zero
/// case `ln(0)` evaluates to negative infinity, which would propagate into
/// every score derived from this value as an infinity or NaN.
#[inline]
fn tf(term_frequency: u32) -> f64 {
    if term_frequency == 0 {
        return 0.0;
    }

    // `u32 -> f64` is lossless, so `From` is preferred over an `as` cast.
    f64::from(term_frequency).ln() + 1.0
}


/// Smoothed inverse-document-frequency weight.
///
/// idf(term_docs, total_docs) = log((total_docs + 1.0) / (term_docs + 1.0)) + 1.0
///
/// Both counts are incremented by one before dividing, so the ratio is
/// defined even when either count is zero.
#[inline]
fn idf(term_docs: u64, total_docs: u64) -> f64 {
    let smoothed_total = total_docs as f64 + 1.0;
    let smoothed_matches = term_docs as f64 + 1.0;
    let ratio = smoothed_total / smoothed_matches;

    1.0 + ratio.ln()
}


impl SimilarityModel {
    /// Computes the weight of a single term occurrence under this model.
    ///
    /// * `term_frequency` - occurrences of the term in the document being scored.
    /// * `length` - token count of the document being scored (BM25 only).
    /// * `total_tokens` - token count across all documents; divided by
    ///   `total_docs` to obtain the average length (BM25 only).
    /// * `total_docs` - number of documents in the collection.
    /// * `total_docs_with_term` - number of documents that contain the term.
    ///
    /// `TF_IDF` returns `tf * idf`. `BM25` returns
    /// `idf * (k1 + 1) * tf / (tf + k1 * ((1 - b) + b * sqrt(length) / sqrt(average_length)))`.
    ///
    /// NOTE(review): for `BM25`, `total_docs == 0` makes the average length a
    /// division by zero (infinite or NaN); callers are presumably expected to
    /// score only against a non-empty collection - confirm.
    pub fn score(&self, term_frequency: u32, length: u32, total_tokens: u64, total_docs: u64, total_docs_with_term: u64) -> f64 {
        // Both models are built from the same two weights, so compute them once.
        let term_weight = tf(term_frequency);
        let rarity_weight = idf(total_docs_with_term, total_docs);

        match *self {
            SimilarityModel::TF_IDF => term_weight * rarity_weight,
            SimilarityModel::BM25 { k1, b } => {
                let average_length = (total_tokens as f64) / (total_docs as f64);

                // Length of this document relative to the average, on a
                // square-root scale, weighted by `b`.
                let length_ratio = b * (length as f64).sqrt() / average_length.sqrt();
                let normaliser = k1 * ((1.0 - b) + length_ratio);
                let saturated = term_weight / (term_weight + normaliser);

                rarity_weight * (k1 + 1.0) * saturated
            }
        }
    }
}


#[cfg(test)]
mod tests {
    use super::SimilarityModel;

    /// Builds the BM25 model with the parameters shared by every BM25 test.
    fn bm25() -> SimilarityModel {
        SimilarityModel::BM25 { k1: 1.2, b: 0.75 }
    }

    // Argument order for `score`:
    // (term_frequency, length, total_tokens, total_docs, total_docs_with_term)

    #[test]
    fn test_tf_idf_higher_term_freq_increases_score() {
        let model = SimilarityModel::TF_IDF;
        let two_occurrences = model.score(2, 40, 100, 10, 5);
        let one_occurrence = model.score(1, 40, 100, 10, 5);

        assert!(two_occurrences > one_occurrence);
    }

    #[test]
    fn test_tf_idf_lower_term_docs_increases_score() {
        let model = SimilarityModel::TF_IDF;
        let rarer_term = model.score(1, 40, 100, 10, 5);
        let common_term = model.score(1, 40, 100, 10, 10);

        assert!(rarer_term > common_term);
    }

    #[test]
    fn test_tf_idf_field_length_doesnt_affect_score() {
        let model = SimilarityModel::TF_IDF;
        let long_field = model.score(1, 100, 100, 20, 5);
        let short_field = model.score(1, 40, 100, 20, 5);

        assert_eq!(long_field, short_field);
    }

    #[test]
    fn test_tf_idf_total_tokens_doesnt_affect_score() {
        let model = SimilarityModel::TF_IDF;
        let many_tokens = model.score(1, 40, 1000, 20, 5);
        let few_tokens = model.score(1, 40, 100, 20, 5);

        assert_eq!(many_tokens, few_tokens);
    }

    #[test]
    fn test_bm25_higher_term_freq_increases_score() {
        let model = bm25();
        let two_occurrences = model.score(2, 40, 100, 10, 5);
        let one_occurrence = model.score(1, 40, 100, 10, 5);

        assert!(two_occurrences > one_occurrence);
    }

    #[test]
    fn test_bm25_lower_term_docs_increases_score() {
        let model = bm25();
        let rarer_term = model.score(1, 40, 100, 10, 5);
        let common_term = model.score(1, 40, 100, 10, 10);

        assert!(rarer_term > common_term);
    }

    #[test]
    fn test_bm25_lower_field_length_increases_score() {
        let model = bm25();
        let short_field = model.score(1, 40, 100, 20, 5);
        let long_field = model.score(1, 100, 100, 20, 5);

        assert!(short_field > long_field);
    }

    #[test]
    fn test_bm25_higher_total_tokens_increases_score() {
        let model = bm25();
        let many_tokens = model.score(1, 40, 1000, 20, 5);
        let few_tokens = model.score(1, 40, 100, 20, 5);

        assert!(many_tokens > few_tokens);
    }
}