find_simdoc/cosine.rs
1//! Searcher for all pairs of similar documents in the Cosine space.
2use std::sync::Mutex;
3
4use crate::errors::{FindSimdocError, Result};
5use crate::feature::{FeatureConfig, FeatureExtractor};
6use crate::lsh::simhash::SimHasher;
7use crate::tfidf::{Idf, Tf};
8
9use all_pairs_hamming::chunked_join::ChunkedJoiner;
10use rand::{RngCore, SeedableRng};
11use rayon::prelude::*;
12
/// Searcher for all pairs of similar documents in the Cosine space.
///
/// # Approach
///
/// The search steps consist of
///
/// 1. Extracts features from documents,
/// where a feature is a tfidf-weighted vector representation of character or word ngrams.
/// 2. Convert the features into binary sketches through the [simplified simhash](https://dl.acm.org/doi/10.1145/1242572.1242592).
/// 3. Search for similar sketches in the Hamming space using [`ChunkedJoiner`].
///
/// # Examples
///
/// ```
/// use find_simdoc::tfidf::{Idf, Tf};
/// use find_simdoc::CosineSearcher;
///
/// let documents = vec![
///     "Welcome to Jimbocho, the town of books and curry!",
///     "Welcome to Jimbocho, the city of books and curry!",
///     "We welcome you to Jimbocho, the town of books and curry.",
///     "Welcome to the town of books and curry, Jimbocho!",
/// ];
///
/// // Creates a searcher for word unigrams (with random seed value 42).
/// let searcher = CosineSearcher::new(1, Some(' '), Some(42)).unwrap();
/// // Creates a term frequency (TF) weighter.
/// let tf = Tf::new();
/// // Creates a inverse document frequency (IDF) weighter.
/// let idf = Idf::new()
///     .build(documents.iter().clone(), searcher.config())
///     .unwrap();
/// // Builds the database of binary sketches converted from input documents,
/// let searcher = searcher
///     // with the TF weighter and
///     .tf(Some(tf))
///     // the IDF weighter,
///     .idf(Some(idf))
///     // where binary sketches are in the Hamming space of 10*64 dimensions.
///     .build_sketches_in_parallel(documents.iter(), 10)
///     .unwrap();
///
/// // Searches all similar pairs within radius 0.25.
/// let results = searcher.search_similar_pairs(0.25);
/// // A result consists of the left-side id, the right-side id, and their distance.
/// assert_eq!(results, vec![(0, 1, 0.1296875), (0, 3, 0.24375)]);
/// ```
pub struct CosineSearcher {
    /// Feature-extraction settings (window size, delimiter, hash seed),
    /// shared by the extractor and exposed via [`Self::config`].
    config: FeatureConfig,
    /// Simhash generator turning weighted features into 64-bit sketch chunks.
    hasher: SimHasher,
    /// Optional TF weighting applied to extracted features; `None` disables it.
    tf: Option<Tf>,
    /// Optional IDF weighting applied to extracted features; `None` disables it.
    idf: Option<Idf<u64>>,
    /// Sketch index for Hamming-space joins; `None` until `build_sketches*` runs.
    joiner: Option<ChunkedJoiner<u64>>,
    /// If true, progress is reported on the standard error output.
    shows_progress: bool,
}
68
69impl CosineSearcher {
70 /// Creates an instance.
71 ///
72 /// # Arguments
73 ///
74 /// * `window_size` - Window size for w-shingling in feature extraction (must be more than 0).
75 /// * `delimiter` - Delimiter for recognizing words as tokens in feature extraction.
76 /// If `None`, characters are used for tokens.
77 /// * `seed` - Seed value for random values.
78 pub fn new(window_size: usize, delimiter: Option<char>, seed: Option<u64>) -> Result<Self> {
79 let seed = seed.unwrap_or_else(rand::random::<u64>);
80 let mut seeder = rand_xoshiro::SplitMix64::seed_from_u64(seed);
81 let config = FeatureConfig::new(window_size, delimiter, seeder.next_u64())?;
82 let hasher = SimHasher::new(seeder.next_u64());
83 Ok(Self {
84 config,
85 hasher,
86 tf: None,
87 idf: None,
88 joiner: None,
89 shows_progress: false,
90 })
91 }
92
93 /// Shows the progress via the standard error output?
94 pub const fn shows_progress(mut self, yes: bool) -> Self {
95 self.shows_progress = yes;
96 self
97 }
98
99 /// Sets the scheme of TF weighting.
100 #[allow(clippy::missing_const_for_fn)]
101 pub fn tf(mut self, tf: Option<Tf>) -> Self {
102 self.tf = tf;
103 self
104 }
105
106 /// Sets the scheme of IDF weighting.
107 #[allow(clippy::missing_const_for_fn)]
108 pub fn idf(mut self, idf: Option<Idf<u64>>) -> Self {
109 self.idf = idf;
110 self
111 }
112
113 /// Builds the database of sketches from input documents.
114 ///
115 /// # Arguments
116 ///
117 /// * `documents` - List of documents (must not include an empty string).
118 /// * `num_chunks` - Number of chunks of sketches, indicating that
119 /// the number of dimensions in the Hamming space is `num_chunks*64`.
120 pub fn build_sketches<I, D>(mut self, documents: I, num_chunks: usize) -> Result<Self>
121 where
122 I: IntoIterator<Item = D>,
123 D: AsRef<str>,
124 {
125 let mut joiner = ChunkedJoiner::<u64>::new(num_chunks).shows_progress(self.shows_progress);
126 let extractor = FeatureExtractor::new(&self.config);
127
128 let mut feature = vec![];
129 for (i, doc) in documents.into_iter().enumerate() {
130 if self.shows_progress && (i + 1) % 10000 == 0 {
131 eprintln!("Processed {} documents...", i + 1);
132 }
133 let doc = doc.as_ref();
134 if doc.is_empty() {
135 return Err(FindSimdocError::input("Input document must not be empty."));
136 }
137 extractor.extract_with_weights(doc, &mut feature);
138 if let Some(tf) = self.tf.as_ref() {
139 tf.tf(&mut feature);
140 }
141 if let Some(idf) = self.idf.as_ref() {
142 for (term, weight) in feature.iter_mut() {
143 *weight *= idf.idf(*term);
144 }
145 }
146 joiner.add(self.hasher.iter(&feature)).unwrap();
147 }
148 self.joiner = Some(joiner);
149 Ok(self)
150 }
151
152 /// Builds the database of sketches from input documents in parallel.
153 ///
154 /// # Arguments
155 ///
156 /// * `documents` - List of documents (must not include an empty string).
157 /// * `num_chunks` - Number of chunks of sketches, indicating that
158 /// the number of dimensions in the Hamming space is `num_chunks*64`.
159 ///
160 /// # Notes
161 ///
162 /// The progress is not printed even if `shows_progress = true`.
163 pub fn build_sketches_in_parallel<I, D>(
164 mut self,
165 documents: I,
166 num_chunks: usize,
167 ) -> Result<Self>
168 where
169 I: Iterator<Item = D> + Send,
170 D: AsRef<str> + Send,
171 {
172 let extractor = FeatureExtractor::new(&self.config);
173 #[allow(clippy::mutex_atomic)]
174 let processed = Mutex::new(0usize);
175 let mut sketches: Vec<_> = documents
176 .into_iter()
177 .enumerate()
178 .par_bridge()
179 .map(|(i, doc)| {
180 #[allow(clippy::mutex_atomic)]
181 {
182 // Mutex::lock also locks eprintln.
183 let mut cnt = processed.lock().unwrap();
184 *cnt += 1;
185 if self.shows_progress && *cnt % 10000 == 0 {
186 eprintln!("Processed {} documents...", *cnt);
187 }
188 }
189 let doc = doc.as_ref();
190 // TODO: Returns the error value (but I dont know the manner).
191 assert!(!doc.is_empty(), "Input document must not be empty.");
192 let mut feature = vec![];
193 extractor.extract_with_weights(doc, &mut feature);
194 if let Some(tf) = self.tf.as_ref() {
195 tf.tf(&mut feature);
196 }
197 if let Some(idf) = self.idf.as_ref() {
198 for (term, weight) in feature.iter_mut() {
199 *weight *= idf.idf(*term);
200 }
201 }
202 let mut gen = self.hasher.iter(&feature);
203 let sketch: Vec<_> = (0..num_chunks).map(|_| gen.next().unwrap()).collect();
204 (i, sketch)
205 })
206 .collect();
207 sketches.par_sort_by_key(|&(i, _)| i);
208
209 let mut joiner = ChunkedJoiner::<u64>::new(num_chunks).shows_progress(self.shows_progress);
210 for (_, sketch) in sketches {
211 joiner.add(sketch).unwrap();
212 }
213 self.joiner = Some(joiner);
214 Ok(self)
215 }
216
217 /// Searches for all pairs of similar documents within an input radius, returning
218 /// triplets of the left-side id, the right-side id, and their distance.
219 pub fn search_similar_pairs(&self, radius: f64) -> Vec<(usize, usize, f64)> {
220 self.joiner.as_ref().unwrap().similar_pairs(radius)
221 }
222
223 /// Gets the number of input documents.
224 pub fn len(&self) -> usize {
225 self.joiner
226 .as_ref()
227 .map_or(0, |joiner| joiner.num_sketches())
228 }
229
230 /// Checks if the database is empty.
231 pub fn is_empty(&self) -> bool {
232 self.len() == 0
233 }
234
235 /// Gets the memory usage in bytes.
236 pub fn memory_in_bytes(&self) -> usize {
237 self.joiner
238 .as_ref()
239 .map_or(0, |joiner| joiner.memory_in_bytes())
240 }
241
242 /// Gets the configure of feature extraction.
243 pub const fn config(&self) -> &FeatureConfig {
244 &self.config
245 }
246}