rnltk/
sentiment.rs

1//! Module containing types used to get valence and arousal sentiment scores.
2
3use std::{collections::HashMap, borrow::Cow};
4use std::f64::consts::PI;
5
6use serde::{Serialize, Deserialize};
7
8use crate::stem;
9use crate::error::RnltkError;
10
11pub type CustomWords = HashMap<String, SentimentDictValue>;
12pub type CustomStems = HashMap<String, SentimentDictValue>;
13
/// Raw valence or arousal statistics for a single term: the mean rating
/// (`average`) and its spread (`standard_deviation`).
///
/// This is a small plain value, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RawSentiment {
    /// Mean rating of the term.
    pub average: f64,
    /// Standard deviation of the term's rating.
    pub standard_deviation: f64,
}

impl RawSentiment {
    /// Builds a [`RawSentiment`] from a mean and a standard deviation.
    fn new(average: f64, standard_deviation: f64) -> Self {
        Self {
            average,
            standard_deviation,
        }
    }
}
30
31/// Struct for creating the basis of the sentiment lexicon.
32#[derive(Serialize, Deserialize, Debug)]
33pub struct SentimentDictValue {
34    /// The full, unstemmed word
35    pub word: String,
36    /// The stemmed version of the word
37    pub stem: String,
38    /// The average values of valence and arousal
39    /// Expected format of vec![5.0, 5.0]
40    pub avg: Vec<f64>,
41    /// The standard deviation values of valence and arousal
42    /// Expected format of vec![5.0, 5.0]
43    pub std: Vec<f64>
44}
45
46impl SentimentDictValue {
47    pub fn new(word: String, stem: String, avg: Vec<f64>, std: Vec<f64>) -> Self {
48        SentimentDictValue {
49            word,
50            stem,
51            avg,
52            std
53        }
54    }
55}
56
57pub struct SentimentModel {
58    custom_words: CustomWords,
59    custom_stems: CustomStems,
60}
61
62impl SentimentModel {
63    /// Creates new instance of SentimentModel from `custom_words`, a [`CustomWords`]
64    /// sentiment lexicon.
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// use rnltk::sentiment::{SentimentModel, CustomWords};
70    /// 
71    /// let custom_word_dict = r#"
72    /// {
73    ///     "abduction": {
74    ///         "word": "abduction",
75    ///         "stem": "abduct",
76    ///         "avg": [2.76, 5.53],
77    ///         "std": [2.06, 2.43]
78    ///     }
79    /// }"#;
80    /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
81    /// 
82    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
83    /// if sentiment.does_term_exist("abduction") {
84    ///     println!("abduction exists");
85    /// }
86    /// ```
87    pub fn new(custom_words: CustomWords) -> Self {
88        let custom_stems_dict = SentimentDictValue::new("".to_string(), "".to_string(), vec![0.0, 0.0], vec![0.0, 0.0]);
89        let custom_stems = HashMap::from([("".to_string(), custom_stems_dict)]);
90        
91        SentimentModel {
92            custom_words,
93            custom_stems,
94        }
95    }
96
97    /// Adds new `custom_stems` lexicon of stemmed words.
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// use rnltk::sentiment::{SentimentModel, CustomWords, CustomStems};
103    /// use rnltk::sample_data;
104    /// 
105    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
106    /// 
107    /// let custom_stem_dict = r#"
108    /// {
109    ///     "abduct": {
110    ///         "word": "abduction",
111    ///         "stem": "abduct",
112    ///         "avg": [2.76, 5.53],
113    ///         "std": [2.06, 2.43]
114    ///     }
115    /// }"#;
116    /// let custom_stems_sentiment_hashmap: CustomStems = serde_json::from_str(custom_stem_dict).unwrap();
117    /// 
118    /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
119    /// sentiment.add_custom_stems(custom_stems_sentiment_hashmap);
120    /// if sentiment.does_term_exist("abduct") {
121    ///     println!("abduct exists");
122    /// }
123    /// ```
124    pub fn add_custom_stems(&mut self, custom_stems: CustomStems) {
125        self.custom_stems = custom_stems        
126    }
127
128    /// Checks if a `term` exists in the sentiment dictionaries.
129    ///
130    /// # Examples
131    ///
132    /// ```
133    /// use rnltk::sentiment::{SentimentModel, CustomWords};
134    /// use rnltk::sample_data;
135    /// 
136    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
137    /// 
138    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
139    /// if sentiment.does_term_exist("abduction") {
140    ///     println!("abduction exists");
141    /// }
142    /// ```
143    pub fn does_term_exist(&self, term: &str) -> bool {
144        self.custom_words.contains_key(term) || self.custom_stems.contains_key(term)
145    }
146
147    /// Gets the raw arousal values ([`RawSentiment`]) for a given `term` word token.
148    ///
149    /// # Examples
150    ///
151    /// ```
152    /// use rnltk::sentiment::{SentimentModel, CustomWords};
153    /// use rnltk::sample_data;
154    /// 
155    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
156    /// 
157    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
158    /// let arousal = sentiment.get_raw_arousal("abduction");
159    /// let correct_arousal = vec![5.53, 2.43];
160    /// 
161    /// assert_eq!(vec![arousal.average, arousal.standard_deviation], correct_arousal);
162    /// ```
163    pub fn get_raw_arousal(&self, term: &str) -> RawSentiment {
164        let mut average = 0.0;
165        let mut std_dev = 0.0; 
166
167        if !self.does_term_exist(term) {
168            return RawSentiment::new(average, std_dev);
169        } else if self.custom_words.contains_key(term) {
170            let sentiment_info = self.custom_words.get(term).unwrap();
171            average = sentiment_info.avg[1];
172            std_dev = sentiment_info.std[1];
173        } else if self.custom_stems.contains_key(term) {
174            let sentiment_info = self.custom_stems.get(term).unwrap();
175            average = sentiment_info.avg[1];
176            std_dev = sentiment_info.std[1];
177        }
178        RawSentiment::new(average, std_dev)
179    }
180
181    /// Gets the raw valence values ([`RawSentiment`]) for a given `term` word token.
182    ///
183    /// # Examples
184    ///
185    /// ```
186    /// use rnltk::sentiment::{SentimentModel, CustomWords};
187    /// use rnltk::sample_data;
188    /// 
189    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
190    /// 
191    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
192    /// let valence = sentiment.get_raw_valence("abduction");
193    /// let correct_valence = vec![2.76, 2.06];
194    /// 
195    /// assert_eq!(vec![valence.average, valence.standard_deviation], correct_valence);
196    /// ```
197    pub fn get_raw_valence(&self, term: &str) -> RawSentiment {
198        let mut average = 0.0;
199        let mut std_dev = 0.0; 
200
201        if !self.does_term_exist(term) {
202            return RawSentiment::new(average, std_dev);
203        } else if self.custom_words.contains_key(term) {
204            let sentiment_info = self.custom_words.get(term).unwrap();
205            average = sentiment_info.avg[0];
206            std_dev = sentiment_info.std[0];
207        } else if self.custom_stems.contains_key(term) {
208            let sentiment_info = self.custom_stems.get(term).unwrap();
209            average = sentiment_info.avg[0];
210            std_dev = sentiment_info.std[0];
211        }
212        RawSentiment::new(average, std_dev)
213    }
214
215    /// Gets the arousal value for a given `term` word token.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use rnltk::sentiment::{SentimentModel, CustomWords};
221    /// use rnltk::sample_data;
222    /// 
223    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
224    /// 
225    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
226    /// let arousal = sentiment.get_arousal_for_single_term("abduction");
227    /// let correct_arousal = 5.53;
228    /// 
229    /// assert_eq!(arousal, correct_arousal);
230    /// ```
231    pub fn get_arousal_for_single_term(&self, term: &str) -> f64 {
232        self.get_raw_arousal(term).average
233    }
234
235    /// Gets the valence value for a given `term` word token.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// use rnltk::sentiment::{SentimentModel, CustomWords};
241    /// use rnltk::sample_data;
242    /// 
243    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
244    /// 
245    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
246    /// let valence = sentiment.get_valence_for_single_term("abduction");
247    /// let correct_valence = 2.76;
248    /// 
249    /// assert_eq!(valence, correct_valence);
250    /// ```
251    pub fn get_valence_for_single_term(&self, term: &str) -> f64 {
252        self.get_raw_valence(term).average
253    }
254
255    /// Gets the arousal value for a word token vector of `terms`.
256    ///
257    /// # Examples
258    ///
259    /// ```
260    /// use rnltk::sentiment::{SentimentModel, CustomWords};
261    /// 
262    /// let custom_word_dict = r#"
263    /// {
264    ///     "betrayed": {
265    ///         "word": "betrayed",
266    ///         "stem": "betrai",
267    ///         "avg": [2.57, 7.24],
268    ///         "std": [1.83, 2.06]
269    ///     },
270    ///     "bees": {
271    ///         "word": "bees",
272    ///         "stem": "bee",
273    ///         "avg": [3.2, 6.51],
274    ///         "std": [2.07, 2.14]
275    ///     }
276    /// }"#;
277    /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
278    /// 
279    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
280    /// let arousal = sentiment.get_arousal_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
281    /// let correct_arousal = 6.881952380952381;
282    /// 
283    /// assert_eq!(arousal, correct_arousal);
284    /// ```
285    pub fn get_arousal_for_term_vector(&self, terms: &Vec<&str>) -> f64 {
286        let c = 2.0 * PI;
287        let mut prob: Vec<f64> = vec![];
288        let mut prob_sum = 0.0;
289        let mut arousal_means: Vec<f64> = vec![];
290
291        for term in terms {
292            if self.does_term_exist(term) {
293                let raw_arousal = self.get_raw_arousal(term);
294                
295                let p = 1.0 / (c * raw_arousal.standard_deviation.powi(2)).sqrt();
296                prob.push(p);
297                prob_sum += p;
298
299                arousal_means.push(raw_arousal.average);
300            }
301        }
302        let mut arousal = 0.0;
303        for index in 0..arousal_means.len() {
304            arousal += prob[index] / prob_sum * arousal_means[index];
305        }
306
307        arousal
308    }
309
310    /// Gets the valence value for a word token vector of `terms`.
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// use rnltk::sentiment::{SentimentModel, CustomWords};
316    /// 
317    /// let custom_word_dict = r#"
318    /// {
319    ///     "betrayed": {
320    ///         "word": "betrayed",
321    ///         "stem": "betrai",
322    ///         "avg": [2.57, 7.24],
323    ///         "std": [1.83, 2.06]
324    ///     },
325    ///     "bees": {
326    ///         "word": "bees",
327    ///         "stem": "bee",
328    ///         "avg": [3.2, 6.51],
329    ///         "std": [2.07, 2.14]
330    ///     }
331    /// }"#;
332    /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
333    /// 
334    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
335    /// let valence = sentiment.get_valence_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
336    /// let correct_valence = 2.865615384615385;
337    /// 
338    /// assert_eq!(valence, correct_valence);
339    /// ```
340    pub fn get_valence_for_term_vector(&self, terms: &Vec<&str>) -> f64 {
341        let c = 2.0 * PI;
342        let mut prob: Vec<f64> = vec![];
343        let mut prob_sum = 0.0;
344        let mut valence_means: Vec<f64> = vec![];
345
346        for term in terms {
347            if self.does_term_exist(term) {
348                let raw_valence = self.get_raw_valence(term);
349                
350                let p = 1.0 / (c * raw_valence.standard_deviation.powi(2)).sqrt();
351                prob.push(p);
352                prob_sum += p;
353
354                valence_means.push(raw_valence.average);
355            }
356        }
357        let mut valence = 0.0;
358        for index in 0..valence_means.len() {
359            valence += prob[index] / prob_sum * valence_means[index];
360        }
361
362        valence
363    }
364
365    /// Gets the valence, arousal sentiment for a `term` word token.
366    ///
367    /// # Examples
368    ///
369    /// ```
370    /// use std::collections::HashMap;
371    /// use rnltk::sentiment::{SentimentModel, CustomWords};
372    /// use rnltk::sample_data;
373    /// 
374    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
375    /// 
376    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
377    /// let sentiment_info = sentiment.get_sentiment_for_term("abduction");
378    /// let sentiment_map = HashMap::from([("valence", 2.76), ("arousal", 5.53)]);
379    /// 
380    /// assert_eq!(sentiment_info, sentiment_map);
381    /// ```
382    pub fn get_sentiment_for_term(&self, term: &str) -> HashMap<&str, f64> {
383        let mut sentiment: HashMap<&str, f64>  = HashMap::new();
384        sentiment.insert("valence", self.get_valence_for_single_term(term));
385        sentiment.insert("arousal", self.get_arousal_for_single_term(term));
386
387        sentiment
388    }
389
390    /// Gets the valence, arousal sentiment for a word token vector of `terms`.
391    ///
392    /// # Examples
393    ///
394    /// ```
395    /// use std::collections::HashMap;
396    /// use rnltk::sentiment::{SentimentModel, CustomWords};
397    /// 
398    /// let custom_word_dict = r#"
399    /// {
400    ///     "betrayed": {
401    ///         "word": "betrayed",
402    ///         "stem": "betrai",
403    ///         "avg": [2.57, 7.24],
404    ///         "std": [1.83, 2.06]
405    ///     },
406    ///     "bees": {
407    ///         "word": "bees",
408    ///         "stem": "bee",
409    ///         "avg": [3.2, 6.51],
410    ///         "std": [2.07, 2.14]
411    ///     }
412    /// }"#;
413    /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
414    /// 
415    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
416    /// let sentiment_info = sentiment.get_sentiment_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
417    /// let sentiment_map = HashMap::from([("valence", 2.865615384615385), ("arousal", 6.881952380952381)]);
418    /// 
419    /// assert_eq!(sentiment_info, sentiment_map);
420    /// ```
421    pub fn get_sentiment_for_term_vector(&self, terms: &Vec<&str>) -> HashMap<&str, f64> {
422        let mut sentiment: HashMap<&str, f64>  = HashMap::new();
423        sentiment.insert("valence", self.get_valence_for_term_vector(terms));
424        sentiment.insert("arousal", self.get_arousal_for_term_vector(terms));
425
426        sentiment
427    }
428
429    /// Gets the Russel-like description given `valence` and `arousal` scores.
430    ///
431    /// # Examples
432    ///
433    /// ```
434    /// use rnltk::sentiment::{SentimentModel, CustomWords};
435    /// use rnltk::sample_data;
436    /// 
437    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
438    /// 
439    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
440    /// let sentiment_description = sentiment.get_sentiment_description(&2.76, &5.53);
441    /// let description = "upset";
442    /// 
443    /// assert_eq!(sentiment_description, description);
444    /// ```
445    pub fn get_sentiment_description(&self, valence: &f64, arousal: &f64) -> Cow<'static, str> {
446        if !(1.0..=9.0).contains(valence) || !(1.0..=9.0).contains(arousal) {
447            println!("Valence and arousal must be bound between 1 and 9 (inclusive)");
448            return Cow::from("unknown");
449        } 
450
451        // Center of circumplex (5,5) will give an r=0, div by zero error, so handle explicitly
452        if *valence == 5.0 && *arousal == 5.0 {
453            return Cow::from("average");
454        }
455
456        // Angular cutoffs for different emotional states (same on top and bottom)
457        let angular_cutoffs = vec![0.0, 18.43, 45.0, 71.57, 90.0, 108.43, 135.0, 161.57, 180.0];
458
459        // Terms to return for bottom, top half of circumplex
460        let lower_term = vec![
461            "contented", "serene", "relaxed", "calm",
462            "bored", "lethargic", "depressed", "sad"
463        ];
464        let upper_term = vec![
465            "happy", "elated", "excited", "alert",
466            "tense", "nervous", "stressed", "upset"
467        ];
468
469        // Normalize valence and arousal, using polar coordinates to get angle
470        // clockwise along bottom, counterclockwise along top
471        let normalized_valence = ((valence - 1.0) - 4.0) / 4.0;
472        let normalized_arousal = ((arousal - 1.0) - 4.0) / 4.0;
473        let mut radius = (normalized_valence.powi(2).abs() + normalized_arousal.powi(2).abs()).sqrt();
474        let direction = (normalized_valence / radius).acos().to_degrees();
475
476        //  Normalize radius for "strength" of emotion
477        if direction <= 45.0 || direction >= 135.0 {
478            radius /= (normalized_arousal.powi(2).abs() + 1.0).sqrt();
479        } else {
480            radius /= (normalized_valence.powi(2).abs() + 1.0).sqrt();
481        }
482
483        let mut modify = "";
484        
485        if radius <= 0.25 {
486            modify = "slightly ";
487        } else if radius <= 0.5 {
488            modify = "moderately ";
489        } else if radius > 0.75 {
490            modify = "very ";
491        }
492
493        // Use normalized arousal to determine if we're on bottom of top of circumplex
494        let mut term = lower_term;
495        if normalized_arousal > 0.0 {
496            term = upper_term;
497        }
498
499        let description;
500
501        // Walk along angular boundaries until we determine which "slice"
502        // our valence and arousal point lies in, return corresponding term
503        for index in 0..term.len() {
504            if direction >= angular_cutoffs[index] && direction <= angular_cutoffs[index + 1] {
505                description = format!("{}{}", modify, term[index]);
506                return Cow::from(description);
507            }
508        }
509
510        println!("unexpected angle {} did not match any term", normalized_arousal);
511        Cow::from("unknown")
512    }
513
514    /// Gets the Russel-like description given a `term` word token.
515    ///
516    /// # Examples
517    ///
518    /// ```
519    /// use std::collections::HashMap;
520    /// use rnltk::sentiment::{SentimentModel, CustomWords};
521    /// use rnltk::sample_data;
522    /// 
523    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
524    /// 
525    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
526    /// let sentiment_description = sentiment.get_term_description("abduction");
527    /// let description = "upset";
528    /// 
529    /// assert_eq!(sentiment_description, description);
530    /// ```
531    pub fn get_term_description(&self, term: &str) -> Cow<'static, str> {
532        let sentiment = self.get_sentiment_for_term(term);
533        if sentiment.get("arousal").unwrap() == &0.0 {
534            return Cow::from("unknown");
535        }
536        self.get_sentiment_description(sentiment.get("valence").unwrap(), sentiment.get("arousal").unwrap())
537    }
538
539    /// Gets the Russel-like description given a word token vector of `terms`.
540    ///
541    /// # Examples
542    ///
543    /// ```
544    /// use rnltk::sentiment::{SentimentModel, CustomWords};
545    /// 
546    /// let custom_word_dict = r#"
547    /// {
548    ///     "betrayed": {
549    ///         "word": "betrayed",
550    ///         "stem": "betrai",
551    ///         "avg": [2.57, 7.24],
552    ///         "std": [1.83, 2.06]
553    ///     },
554    ///     "bees": {
555    ///         "word": "bees",
556    ///         "stem": "bee",
557    ///         "avg": [3.2, 6.51],
558    ///         "std": [2.07, 2.14]
559    ///     }
560    /// }"#;
561    /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
562    /// 
563    /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
564    /// let sentiment_description = sentiment.get_term_vector_description(&vec!["I", "betrayed", "the", "bees"]);
565    /// let description = "stressed";
566    /// 
567    /// assert_eq!(sentiment_description, description);
568    /// ```
569    pub fn get_term_vector_description(&self, terms: &Vec<&str>) -> Cow<'static, str> {
570        let sentiment = self.get_sentiment_for_term_vector(terms);
571        if sentiment.get("arousal").unwrap() == &0.0 {
572            return Cow::from("unknown");
573        }
574        self.get_sentiment_description(sentiment.get("valence").unwrap(), sentiment.get("arousal").unwrap())
575    }
576
577    /// Adds a new `term` word token with its corresponding `valence` and `arousal`
578    /// values to the sentiment lexicons. If the `term` does not already exist, it 
579    /// will be added to the custom sentiment lexicon.
580    /// 
581    /// # Errors
582    /// 
583    /// Returns [`RnltkError::SentimentTermExists`] if the term already exists
584    ///
585    /// # Examples
586    ///
587    /// ```
588    /// use std::collections::HashMap;
589    /// use rnltk::sentiment::{SentimentModel, CustomWords};
590    /// use rnltk::error::RnltkError;
591    /// use rnltk::sample_data;
592    /// 
593    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
594    /// 
595    /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
596    /// let sentiment_return_value = sentiment.add_term_without_replacement("squanch", &2.0, &8.5);
597    /// match sentiment_return_value {
598    ///     Ok(_) => {
599    ///         let sentiment_info = sentiment.get_sentiment_for_term("squanch");
600    ///         let sentiment_map = HashMap::from([("valence", 2.0), ("arousal", 8.5)]);
601    /// 
602    ///         assert_eq!(sentiment_info, sentiment_map);
603    ///     },
604    ///     Err(error_msg) => assert_eq!(error_msg, RnltkError::SentimentTermExists),
605    /// }
606    /// ```
607    pub fn add_term_without_replacement(&mut self, term: &'static str, valence: &f64, arousal: &f64) -> Result<(), RnltkError>{
608        if self.does_term_exist(term) {
609            return Err(RnltkError::SentimentTermExists);
610        } else {
611            let stemmed_word = stem::get(term)?;
612            let word = term.to_string();
613            let avg = vec![*valence, *arousal];
614            let std = vec![1.0, 1.0];
615            let word_dict_value = SentimentDictValue {
616                word: word.clone(),
617                stem: stemmed_word.clone(),
618                avg,
619                std
620            };
621            let avg = vec![*valence, *arousal];
622            let std = vec![1.0, 1.0];
623            let stem_dict_value = SentimentDictValue {
624                word,
625                stem: stemmed_word,
626                avg,
627                std
628            };
629            self.custom_words.insert(term.to_string(), word_dict_value);
630            self.custom_stems.insert(term.to_string(), stem_dict_value);
631        }
632        Ok(())
633    }
634    
635    /// Adds a new `term` word token and its corresponding `valence` and `arousal`
636    /// values to the sentiment lexicons. If this `term` already exists, the `term` will be updated
637    /// with the new `valence` and `arousal` values. If the `term` does not already exist, the `term` will be
638    /// stemmed and added to the custom sentiment lexicon.
639    ///
640    /// # Errors
641    /// 
642    /// Returns [`RnltkError::StemNonAscii`] in the event that the `term` being stemmed contains non-ASCII characters (like hopè).
643    /// 
644    /// # Examples
645    ///
646    /// ```
647    /// use std::collections::HashMap;
648    /// use rnltk::sentiment::{SentimentModel, CustomWords};
649    /// use rnltk::error::RnltkError;
650    /// use rnltk::sample_data;
651    /// 
652    /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
653    /// 
654    /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
655    /// let sentiment_return_value = sentiment.add_term_with_replacement("abduction", &8.0, &8.5);
656    /// match sentiment_return_value {
657    ///     Ok(_) => {
658    ///         let sentiment_info = sentiment.get_sentiment_for_term("abduction");
659    ///         let sentiment_map = HashMap::from([("valence", 8.0), ("arousal", 8.5)]);
660    /// 
661    ///         assert_eq!(sentiment_info, sentiment_map);
662    ///     },
663    ///     Err(error_msg) => assert_eq!(error_msg, RnltkError::StemNonAscii),
664    /// }
665    /// ```
666    pub fn add_term_with_replacement(&mut self, term: &'static str, valence: &f64, arousal: &f64) -> Result<(), RnltkError>{
667        if self.custom_words.contains_key(term) {
668            let dict_value = self.custom_words.get_mut(term).unwrap();
669            dict_value.avg[0] = *valence;
670            dict_value.avg[1] = *arousal;
671        } else if self.custom_stems.contains_key(term) {
672            let dict_value = self.custom_stems.get_mut(term).unwrap();
673            dict_value.avg[0] = *valence;
674            dict_value.avg[1] = *arousal;
675        } else {
676            let stemmed_word = stem::get(term)?;
677            let word = term.to_string();
678            let avg = vec![*valence, *arousal];
679            let std = vec![1.0, 1.0];
680            let word_dict_value = SentimentDictValue {
681                word: word.clone(),
682                stem: stemmed_word.clone(),
683                avg,
684                std
685            };
686            let avg = vec![*valence, *arousal];
687            let std = vec![1.0, 1.0];
688            let stem_dict_value = SentimentDictValue {
689                word,
690                stem: stemmed_word,
691                avg,
692                std
693            };
694            self.custom_words.insert(term.to_string(), word_dict_value);
695            self.custom_stems.insert(term.to_string(), stem_dict_value);
696        }
697        Ok(())
698    }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the JSON test lexicon and wraps it in a fresh model.
    fn model() -> SentimentModel {
        let lexicon: CustomWords =
            serde_json::from_str(include_str!("../test_data/test.json")).unwrap();
        SentimentModel::new(lexicon)
    }

    /// Token sequence shared by the multi-term tests.
    fn sentence() -> Vec<&'static str> {
        vec!["I", "betrayed", "the", "bees"]
    }

    #[test]
    fn raw_arousal() {
        let expected = RawSentiment::new(5.53, 2.43);
        assert_eq!(model().get_raw_arousal("abduction"), expected);
    }

    #[test]
    fn raw_valence() {
        let expected = RawSentiment::new(2.76, 2.06);
        assert_eq!(model().get_raw_valence("abduction"), expected);
    }

    #[test]
    fn valence() {
        assert_eq!(model().get_valence_for_single_term("abduction"), 2.76);
    }

    #[test]
    fn arousal() {
        assert_eq!(model().get_arousal_for_single_term("abduction"), 5.53);
    }

    #[test]
    fn arousal_vector() {
        assert_eq!(
            model().get_arousal_for_term_vector(&sentence()),
            6.881952380952381
        );
    }

    #[test]
    fn valence_vector() {
        assert_eq!(
            model().get_valence_for_term_vector(&sentence()),
            2.865615384615385
        );
    }

    #[test]
    fn term_sentiment() {
        let expected = HashMap::from([("valence", 2.76), ("arousal", 5.53)]);
        assert_eq!(model().get_sentiment_for_term("abduction"), expected);
    }

    #[test]
    fn term_vector_sentiment() {
        let expected = HashMap::from([
            ("valence", 2.865615384615385),
            ("arousal", 6.881952380952381),
        ]);
        assert_eq!(model().get_sentiment_for_term_vector(&sentence()), expected);
    }

    #[test]
    fn sentiment_description() {
        assert_eq!(model().get_sentiment_description(&2.76, &5.53), "upset");
    }

    #[test]
    fn term_description() {
        assert_eq!(model().get_term_description("abduction"), "upset");
    }

    #[test]
    fn term_vector_description() {
        assert_eq!(
            model().get_term_vector_description(&sentence()),
            "stressed"
        );
    }

    #[test]
    fn replace_term() {
        let mut model = model();
        model
            .add_term_with_replacement("abduction", &8.0, &8.5)
            .unwrap();
        let expected = HashMap::from([("valence", 8.0), ("arousal", 8.5)]);
        assert_eq!(model.get_sentiment_for_term("abduction"), expected);
    }

    #[test]
    fn non_ascii_error_replace_term() {
        let mut model = model();
        let error = model
            .add_term_with_replacement("hopè", &8.0, &8.5)
            .unwrap_err();
        assert_eq!(error, RnltkError::StemNonAscii);
    }

    #[test]
    fn term_exists_error() {
        let mut model = model();
        let error = model
            .add_term_without_replacement("abduction", &8.0, &8.5)
            .unwrap_err();
        assert_eq!(error, RnltkError::SentimentTermExists);
    }

    #[test]
    fn add_term() {
        let mut model = model();
        model
            .add_term_without_replacement("squanch", &2.0, &8.5)
            .unwrap();
        let expected = HashMap::from([("valence", 2.0), ("arousal", 8.5)]);
        assert_eq!(model.get_sentiment_for_term("squanch"), expected);
    }
}