sif_embedding/sif.rs
//! SIF: Smooth Inverse Frequency + Common Component Removal.
use anyhow::{anyhow, Result};
use ndarray::Array1;
use ndarray::Array2;

use crate::util;
use crate::Float;
use crate::SentenceEmbedder;
use crate::WordEmbeddings;
use crate::WordProbabilities;
use crate::DEFAULT_N_SAMPLES_TO_FIT;
use crate::DEFAULT_SEPARATOR;

/// Default value of the SIF-weighting parameter `a`,
/// following the original setting.
pub const DEFAULT_PARAM_A: Float = 1e-3;

/// Default value of the number of principal components to remove,
/// following the original setting.
pub const DEFAULT_N_COMPONENTS: usize = 1;

const MODEL_MAGIC: &[u8] = b"sif_embedding::Sif 0.6\n";

/// An implementation of SIF.
///
/// SIF is *Smooth Inverse Frequency* and *Common Component Removal*,
/// simple but powerful techniques for sentence embeddings described in the paper:
/// Sanjeev Arora, Yingyu Liang, and Tengyu Ma,
/// [A Simple but Tough-to-Beat Baseline for Sentence Embeddings](https://openreview.net/forum?id=SyK00v5xx),
/// ICLR 2017.
///
/// # Brief description of API
///
/// The algorithm consists of two steps:
///
/// 1. Compute sentence embeddings with the SIF weighting.
/// 2. Remove the common components from the sentence embeddings.
///
/// The common components are computed from input sentences.
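///
/// In step 1, each word `w` in a sentence is weighted by `a / (a + p(w))`, where `p(w)` is
/// the unigram probability of `w`, and the sentence embedding is the average of the weighted
/// word embeddings.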
///
/// Our API is designed to allow reuse of the common components once they are computed,
/// because a sufficient number of sentences is not always available at query time to compute them.
///
/// [`Sif::fit`] computes the common components from input sentences and returns a fitted instance of [`Sif`].
/// [`Sif::embeddings`] computes sentence embeddings with the fitted components.
///
/// # Examples
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::io::BufReader;
///
/// use finalfusion::compat::text::ReadText;
/// use finalfusion::embeddings::Embeddings;
/// use wordfreq::WordFreq;
///
/// use sif_embedding::{Sif, SentenceEmbedder};
///
/// // Loads word embeddings from a pretrained model.
/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
/// let word_embeddings = Embeddings::read_text(&mut reader)?;
///
/// // Loads word probabilities from a pretrained model.
/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
///
/// // Prepares input sentences.
/// let sentences = ["las vegas", "mega vegas"];
///
/// // Fits the model with input sentences.
/// let model = Sif::new(&word_embeddings, &word_probs);
/// let model = model.fit(&sentences)?;
///
/// // Computes sentence embeddings in shape (n, m),
/// // where n is the number of sentences and m is the number of dimensions.
/// let sent_embeddings = model.embeddings(sentences)?;
/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
/// # Ok(())
/// # }
/// ```
///
/// ## Only SIF weighting
///
/// If you want to apply only the SIF weighting and avoid the computation of common components,
/// use [`Sif::with_parameters`] and set `n_components` to `0`.
/// In this case, you can skip [`Sif::fit`] and call [`Sif::embeddings`] directly
/// because there are no parameters to fit
/// (although the quality of the embeddings may be worse).
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::io::BufReader;
///
/// use finalfusion::compat::text::ReadText;
/// use finalfusion::embeddings::Embeddings;
/// use wordfreq::WordFreq;
///
/// use sif_embedding::{Sif, SentenceEmbedder};
///
/// // Loads word embeddings from a pretrained model.
/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
/// let word_embeddings = Embeddings::read_text(&mut reader)?;
///
/// // Loads word probabilities from a pretrained model.
/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
///
/// // When setting `n_components` to `0`, no common components are removed, and
/// // the sentence embeddings can be computed without `fit`.
/// let model = Sif::with_parameters(&word_embeddings, &word_probs, 1e-3, 0)?;
/// let sent_embeddings = model.embeddings(["las vegas", "mega vegas"])?;
/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
/// # Ok(())
/// # }
/// ```
///
/// ## Serialization of fitted parameters
///
/// If you want to serialize and deserialize the fitted parameters,
/// use [`Sif::serialize`] and [`Sif::deserialize`].
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::io::BufReader;
///
/// use approx::assert_relative_eq;
/// use finalfusion::compat::text::ReadText;
/// use finalfusion::embeddings::Embeddings;
/// use wordfreq::WordFreq;
///
/// use sif_embedding::{Sif, SentenceEmbedder};
///
/// // Loads word embeddings from a pretrained model.
/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
/// let word_embeddings = Embeddings::read_text(&mut reader)?;
///
/// // Loads word probabilities from a pretrained model.
/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
///
/// // Prepares input sentences.
/// let sentences = ["las vegas", "mega vegas"];
///
/// // Fits the model and computes sentence embeddings.
/// let model = Sif::new(&word_embeddings, &word_probs);
/// let model = model.fit(&sentences)?;
/// let sent_embeddings = model.embeddings(&sentences)?;
///
/// // Serializes and deserializes the fitted parameters.
/// let bytes = model.serialize()?;
/// let other = Sif::deserialize(&bytes, &word_embeddings, &word_probs)?;
/// let other_embeddings = other.embeddings(&sentences)?;
/// assert_relative_eq!(sent_embeddings, other_embeddings);
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct Sif<'w, 'p, W, P> {
    word_embeddings: &'w W,
    word_probs: &'p P,
    param_a: Float,
    n_components: usize,
    common_components: Option<Array2<Float>>,
    separator: char,
    n_samples_to_fit: usize,
}

impl<'w, 'p, W, P> Sif<'w, 'p, W, P>
where
    W: WordEmbeddings,
    P: WordProbabilities,
{
    /// Creates a new instance with default parameters defined by
    /// [`DEFAULT_PARAM_A`] and [`DEFAULT_N_COMPONENTS`].
    ///
    /// # Arguments
    ///
    /// * `word_embeddings` - Word embeddings.
    /// * `word_probs` - Word probabilities.
    pub const fn new(word_embeddings: &'w W, word_probs: &'p P) -> Self {
        Self {
            word_embeddings,
            word_probs,
            param_a: DEFAULT_PARAM_A,
            n_components: DEFAULT_N_COMPONENTS,
            common_components: None,
            separator: DEFAULT_SEPARATOR,
            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
        }
    }

    /// Creates a new instance with manually specified parameters.
    ///
    /// # Arguments
    ///
    /// * `word_embeddings` - Word embeddings.
    /// * `word_probs` - Word probabilities.
    /// * `param_a` - A parameter `a` for SIF-weighting that should be positive.
    /// * `n_components` - The number of principal components to remove.
    ///
    /// When setting `n_components` to `0`, no principal components are removed.
    ///
    /// # Errors
    ///
    /// Returns an error if `param_a` is not positive.
    pub fn with_parameters(
        word_embeddings: &'w W,
        word_probs: &'p P,
        param_a: Float,
        n_components: usize,
    ) -> Result<Self> {
        if param_a <= 0. {
            return Err(anyhow!("param_a must be positive."));
        }
        Ok(Self {
            word_embeddings,
            word_probs,
            param_a,
            n_components,
            common_components: None,
            separator: DEFAULT_SEPARATOR,
            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
        })
    }

    /// Sets a separator for sentence segmentation (default: [`DEFAULT_SEPARATOR`]).
    pub const fn separator(mut self, separator: char) -> Self {
        self.separator = separator;
        self
    }

    /// Sets the number of samples to fit the model (default: [`DEFAULT_N_SAMPLES_TO_FIT`]).
    ///
    /// # Errors
    ///
    /// Returns an error if `n_samples_to_fit` is 0.
    pub fn n_samples_to_fit(mut self, n_samples_to_fit: usize) -> Result<Self> {
        if n_samples_to_fit == 0 {
            return Err(anyhow!("n_samples_to_fit must not be 0."));
        }
        self.n_samples_to_fit = n_samples_to_fit;
        Ok(self)
    }

    /// Applies SIF-weighting.
    /// (Lines 1--3 in Algorithm 1)
    ///
    /// # Complexities
    ///
    /// * Time complexity: `O(avg_num_words * embedding_size * num_sentences)`
    /// * Space complexity: `O(embedding_size * num_sentences)`
    fn weighted_embeddings<I, S>(&self, sentences: I) -> Array2<Float>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut sent_embeddings = vec![];
        let mut n_sentences = 0;
        // O(num_words * embedding_size * num_sentences)
        for sent in sentences {
            let sent = sent.as_ref();
            let mut n_words = 0;
            let mut sent_embedding = Array1::zeros(self.embedding_size());
            // O(avg_num_words * embedding_size)
            for word in sent.split(self.separator) {
                if let Some(word_embedding) = self.word_embeddings.embedding(word) {
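                    // SIF weight: a / (a + p(word)); rarer words receive weights closer to 1.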
                    let weight = self.param_a / (self.param_a + self.word_probs.probability(word));
                    sent_embedding += &(word_embedding.to_owned() * weight);
                    n_words += 1;
                }
            }
            if n_words != 0 {
                sent_embedding /= n_words as Float;
            } else {
                // If there are no recognized words, fall back to a vector filled with `a`.
                sent_embedding += self.param_a;
            }
            sent_embeddings.extend(sent_embedding.iter());
            n_sentences += 1;
        }
        Array2::from_shape_vec((n_sentences, self.embedding_size()), sent_embeddings).unwrap()
    }

    /// Serializes the model.
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(MODEL_MAGIC);
        bincode::serialize_into(&mut bytes, &self.param_a)?;
        bincode::serialize_into(&mut bytes, &self.n_components)?;
        bincode::serialize_into(&mut bytes, &self.common_components)?;
        bincode::serialize_into(&mut bytes, &self.separator)?;
        bincode::serialize_into(&mut bytes, &self.n_samples_to_fit)?;
        Ok(bytes)
    }

    /// Deserializes the model.
    ///
    /// # Arguments
    ///
    /// * `bytes` - Byte sequence exported by [`Self::serialize`].
    /// * `word_embeddings` - Word embeddings.
    /// * `word_probs` - Word probabilities.
    ///
    /// `word_embeddings` and `word_probs` must be the same as those used in serialization.
    pub fn deserialize(bytes: &[u8], word_embeddings: &'w W, word_probs: &'p P) -> Result<Self> {
        if !bytes.starts_with(MODEL_MAGIC) {
            return Err(anyhow!("The magic number of the input model mismatches."));
        }
        let mut bytes = &bytes[MODEL_MAGIC.len()..];
        let param_a = bincode::deserialize_from(&mut bytes)?;
        let n_components = bincode::deserialize_from(&mut bytes)?;
        let common_components = bincode::deserialize_from(&mut bytes)?;
        let separator = bincode::deserialize_from(&mut bytes)?;
        let n_samples_to_fit = bincode::deserialize_from(&mut bytes)?;
        Ok(Self {
            word_embeddings,
            word_probs,
            param_a,
            n_components,
            common_components,
            separator,
            n_samples_to_fit,
        })
    }
}

impl<W, P> SentenceEmbedder for Sif<'_, '_, W, P>
where
    W: WordEmbeddings,
    P: WordProbabilities,
{
    /// Returns the number of dimensions for sentence embeddings,
    /// which is the same as the number of dimensions for word embeddings.
    fn embedding_size(&self) -> usize {
        self.word_embeddings.embedding_size()
    }

    /// Fits the model with input sentences.
    ///
    /// Sentences used for fitting are randomly sampled from `sentences` according to [`Self::n_samples_to_fit`].
    ///
    /// If `n_components` is 0, does nothing and returns `self`.
    ///
    /// # Errors
    ///
    /// Returns an error if `sentences` is empty.
    ///
    /// # Complexities
    ///
    /// * Time complexity: `O(L*D*S + max(D,S)^3)`
    /// * Space complexity: `O(D*S + max(D,S)^2)`
    ///
    /// where
    ///
    /// * `L` is the average number of words in a sentence.
    /// * `D` is the number of dimensions for word embeddings (`embedding_size`).
    /// * `S` is the number of sentences used to fit (`n_samples_to_fit`).
    fn fit<S>(mut self, sentences: &[S]) -> Result<Self>
    where
        S: AsRef<str>,
    {
        if sentences.is_empty() {
            return Err(anyhow!("Input sentences must not be empty."));
        }
        if self.n_components == 0 {
            eprintln!("Warning: Nothing to fit since n_components is 0.");
            return Ok(self);
        }

        // Time: O(n_samples_to_fit)
        let sentences = util::sample_sentences(sentences, self.n_samples_to_fit);

        // SIF-weighting.
        //
        // Time: O(avg_num_words * embedding_size * n_samples_to_fit)
        // Space: O(embedding_size * n_samples_to_fit)
        let sent_embeddings = self.weighted_embeddings(sentences);

        // Common component removal.
        //
        // Time: O(max(embedding_size, n_samples_to_fit)^3)
        // Space: O(max(embedding_size, n_samples_to_fit)^2)
        let (_, common_components) =
            util::principal_components(&sent_embeddings, self.n_components);
        self.common_components = Some(common_components);

        Ok(self)
    }

    /// Computes embeddings for input sentences using the fitted model.
    ///
    /// If `n_components` is 0, the fitting is not required.
    ///
    /// # Errors
    ///
    /// Returns an error if the model is not fitted.
    ///
    /// # Complexities
    ///
    /// * Time complexity: `O(L*D*N + C*D*N)`
    /// * Space complexity: `O(D*N)`
    ///
    /// where
    ///
    /// * `L` is the average number of words in a sentence.
    /// * `D` is the number of dimensions for word embeddings (`embedding_size`).
    /// * `N` is the number of sentences (`sentences.len()`).
    /// * `C` is the number of components to remove (`n_components`).
    fn embeddings<I, S>(&self, sentences: I) -> Result<Array2<Float>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        if self.n_components != 0 && self.common_components.is_none() {
            return Err(anyhow!("The model is not fitted."));
        }

        // SIF-weighting.
        //
        // Time: O(avg_num_words * embedding_size * n_sentences)
        // Space: O(embedding_size * n_sentences)
        let sent_embeddings = self.weighted_embeddings(sentences);
        if sent_embeddings.is_empty() {
            return Ok(sent_embeddings);
        }
        if self.n_components == 0 {
            return Ok(sent_embeddings);
        }

        // Common component removal.
        //
        // Time: O(embedding_size * n_sentences * n_components)
        // Space: O(embedding_size * n_sentences)
        let common_components = self.common_components.as_ref().unwrap();
        let sent_embeddings =
            util::remove_principal_components(&sent_embeddings, common_components, None);
        Ok(sent_embeddings)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use ndarray::{arr1, CowArray, Ix1};

    struct SimpleWordEmbeddings {}

    impl WordEmbeddings for SimpleWordEmbeddings {
        fn embedding(&self, word: &str) -> Option<CowArray<Float, Ix1>> {
            match word {
                "A" => Some(arr1(&[1., 2., 3.]).into()),
                "BB" => Some(arr1(&[4., 5., 6.]).into()),
                "CCC" => Some(arr1(&[7., 8., 9.]).into()),
                "DDDD" => Some(arr1(&[10., 11., 12.]).into()),
                _ => None,
            }
        }

        fn embedding_size(&self) -> usize {
            3
        }
    }

    struct SimpleWordProbabilities {}

    impl WordProbabilities for SimpleWordProbabilities {
        fn probability(&self, word: &str) -> Float {
            match word {
                "A" => 0.6,
                "BB" => 0.2,
                "CCC" => 0.1,
                "DDDD" => 0.1,
                _ => 0.,
            }
        }

        fn n_words(&self) -> usize {
            4
        }

        fn entries(&self) -> Box<dyn Iterator<Item = (String, Float)> + '_> {
            Box::new(
                [("A", 0.6), ("BB", 0.2), ("CCC", 0.1), ("DDDD", 0.1)]
                    .iter()
                    .map(|&(word, prob)| (word.to_string(), prob)),
            )
        }
    }

    #[test]
    fn test_basic() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = Sif::new(&word_embeddings, &word_probs)
            .fit(&["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();

        let sent_embeddings = sif
            .embeddings(["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((5, 3)));

        let sent_embeddings = sif.embeddings(Vec::<&str>::new()).unwrap();
        assert_eq!(sent_embeddings.shape(), &[0, 3]);

        let sent_embeddings = sif.embeddings([""]).unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((1, 3)));
    }
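
    // A small sanity check of the SIF weighting: with `n_components` set to 0, `embeddings`
    // applies only the weighting `a / (a + p(w))` and requires no fitting, so the expected
    // values can be derived by hand. The test name and epsilon are illustrative choices.
    #[test]
    fn test_weighting_only_without_fit() {
        use ndarray::arr2;

        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = Sif::with_parameters(&word_embeddings, &word_probs, 1e-3, 0).unwrap();
        let sent_embeddings = sif.embeddings(["A"]).unwrap();

        // "A" is a single-word sentence, so its embedding is the word embedding of "A"
        // scaled by a / (a + p("A")).
        let w: Float = 1e-3 / (1e-3 + 0.6);
        let expected = arr2(&[[1. * w, 2. * w, 3. * w]]);
        assert_relative_eq!(sent_embeddings, expected, epsilon = 1e-6);
    }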

    #[test]
    fn test_separator() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences_1 = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let sentences_2 = &["A,BB,CCC,DDDD", "BB,CCC", "A,B,C", "Z", ""];

        let sif = Sif::new(&word_embeddings, &word_probs);

        let sif = sif.fit(sentences_1).unwrap();
        let embeddings_1 = sif.embeddings(sentences_1).unwrap();

        let sif = sif.separator(',');
        let embeddings_2 = sif.embeddings(sentences_2).unwrap();

        assert_relative_eq!(embeddings_1, embeddings_2);
    }

    #[test]
    fn test_invalid_param_a() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = Sif::with_parameters(&word_embeddings, &word_probs, 0., DEFAULT_N_COMPONENTS);
        assert!(sif.is_err());
    }
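
    // An extra guard mirroring `test_invalid_param_a`: `n_samples_to_fit` documents that 0 is
    // rejected, so this just exercises that error path.
    #[test]
    fn test_invalid_n_samples_to_fit() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = Sif::new(&word_embeddings, &word_probs).n_samples_to_fit(0);
        assert!(sif.is_err());
    }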

    #[test]
    fn test_no_fitted() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];

        let sif = Sif::new(&word_embeddings, &word_probs);
        let embeddings = sif.embeddings(sentences);
        assert!(embeddings.is_err());
    }

    #[test]
    fn test_empty_fit() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = Sif::new(&word_embeddings, &word_probs);
        let sif = sif.fit(&Vec::<&str>::new());
        assert!(sif.is_err());
    }

    #[test]
    fn test_io() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = ["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let model_a = Sif::new(&word_embeddings, &word_probs)
            .fit(&sentences)
            .unwrap();
        let bytes = model_a.serialize().unwrap();
        let model_b = Sif::deserialize(&bytes, &word_embeddings, &word_probs).unwrap();

        let embeddings_a = model_a.embeddings(sentences).unwrap();
        let embeddings_b = model_b.embeddings(sentences).unwrap();

        assert_relative_eq!(embeddings_a, embeddings_b);
    }
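
    // A companion to `test_io`: `deserialize` checks the leading magic number, so bytes that
    // do not start with it should be rejected.
    #[test]
    fn test_deserialize_invalid_magic() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let model = Sif::deserialize(b"not a sif model", &word_embeddings, &word_probs);
        assert!(model.is_err());
    }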
}