rnltk/sentiment.rs
1//! Module containing types used to get valence and arousal sentiment scores.
2
3use std::{collections::HashMap, borrow::Cow};
4use std::f64::consts::PI;
5
6use serde::{Serialize, Deserialize};
7
8use crate::stem;
9use crate::error::RnltkError;
10
/// Sentiment lexicon mapping a key to its [`SentimentDictValue`]; the examples
/// in this module key it by the full, unstemmed word (e.g. `"abduction"`).
pub type CustomWords = HashMap<String, SentimentDictValue>;
/// Sentiment lexicon mapping a key to its [`SentimentDictValue`]; the examples
/// in this module key it by the stemmed word (e.g. `"abduct"`).
pub type CustomStems = HashMap<String, SentimentDictValue>;
13
/// Raw `average` / `standard_deviation` pair describing one sentiment
/// dimension (valence or arousal) of a term.
#[derive(Debug, PartialEq)]
pub struct RawSentiment {
    /// Average score of the dimension.
    pub average: f64,
    /// Standard deviation of the dimension's score.
    pub standard_deviation: f64,
}

impl RawSentiment {
    /// Bundles `average` and `standard_deviation` into a [`RawSentiment`].
    fn new(average: f64, standard_deviation: f64) -> Self {
        Self { average, standard_deviation }
    }
}
30
31/// Struct for creating the basis of the sentiment lexicon.
32#[derive(Serialize, Deserialize, Debug)]
33pub struct SentimentDictValue {
34 /// The full, unstemmed word
35 pub word: String,
36 /// The stemmed version of the word
37 pub stem: String,
38 /// The average values of valence and arousal
39 /// Expected format of vec![5.0, 5.0]
40 pub avg: Vec<f64>,
41 /// The standard deviation values of valence and arousal
42 /// Expected format of vec![5.0, 5.0]
43 pub std: Vec<f64>
44}
45
46impl SentimentDictValue {
47 pub fn new(word: String, stem: String, avg: Vec<f64>, std: Vec<f64>) -> Self {
48 SentimentDictValue {
49 word,
50 stem,
51 avg,
52 std
53 }
54 }
55}
56
/// Sentiment model that looks terms up in a word lexicon and a stem lexicon
/// to produce valence and arousal scores.
pub struct SentimentModel {
    /// Lexicon consulted first when a term is looked up.
    custom_words: CustomWords,
    /// Lexicon consulted when a term is not found in `custom_words`.
    custom_stems: CustomStems,
}
61
62impl SentimentModel {
63 /// Creates new instance of SentimentModel from `custom_words`, a [`CustomWords`]
64 /// sentiment lexicon.
65 ///
66 /// # Examples
67 ///
68 /// ```
69 /// use rnltk::sentiment::{SentimentModel, CustomWords};
70 ///
71 /// let custom_word_dict = r#"
72 /// {
73 /// "abduction": {
74 /// "word": "abduction",
75 /// "stem": "abduct",
76 /// "avg": [2.76, 5.53],
77 /// "std": [2.06, 2.43]
78 /// }
79 /// }"#;
80 /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
81 ///
82 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
83 /// if sentiment.does_term_exist("abduction") {
84 /// println!("abduction exists");
85 /// }
86 /// ```
87 pub fn new(custom_words: CustomWords) -> Self {
88 let custom_stems_dict = SentimentDictValue::new("".to_string(), "".to_string(), vec![0.0, 0.0], vec![0.0, 0.0]);
89 let custom_stems = HashMap::from([("".to_string(), custom_stems_dict)]);
90
91 SentimentModel {
92 custom_words,
93 custom_stems,
94 }
95 }
96
97 /// Adds new `custom_stems` lexicon of stemmed words.
98 ///
99 /// # Examples
100 ///
101 /// ```
102 /// use rnltk::sentiment::{SentimentModel, CustomWords, CustomStems};
103 /// use rnltk::sample_data;
104 ///
105 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
106 ///
107 /// let custom_stem_dict = r#"
108 /// {
109 /// "abduct": {
110 /// "word": "abduction",
111 /// "stem": "abduct",
112 /// "avg": [2.76, 5.53],
113 /// "std": [2.06, 2.43]
114 /// }
115 /// }"#;
116 /// let custom_stems_sentiment_hashmap: CustomStems = serde_json::from_str(custom_stem_dict).unwrap();
117 ///
118 /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
119 /// sentiment.add_custom_stems(custom_stems_sentiment_hashmap);
120 /// if sentiment.does_term_exist("abduct") {
121 /// println!("abduct exists");
122 /// }
123 /// ```
124 pub fn add_custom_stems(&mut self, custom_stems: CustomStems) {
125 self.custom_stems = custom_stems
126 }
127
128 /// Checks if a `term` exists in the sentiment dictionaries.
129 ///
130 /// # Examples
131 ///
132 /// ```
133 /// use rnltk::sentiment::{SentimentModel, CustomWords};
134 /// use rnltk::sample_data;
135 ///
136 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
137 ///
138 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
139 /// if sentiment.does_term_exist("abduction") {
140 /// println!("abduction exists");
141 /// }
142 /// ```
143 pub fn does_term_exist(&self, term: &str) -> bool {
144 self.custom_words.contains_key(term) || self.custom_stems.contains_key(term)
145 }
146
147 /// Gets the raw arousal values ([`RawSentiment`]) for a given `term` word token.
148 ///
149 /// # Examples
150 ///
151 /// ```
152 /// use rnltk::sentiment::{SentimentModel, CustomWords};
153 /// use rnltk::sample_data;
154 ///
155 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
156 ///
157 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
158 /// let arousal = sentiment.get_raw_arousal("abduction");
159 /// let correct_arousal = vec![5.53, 2.43];
160 ///
161 /// assert_eq!(vec![arousal.average, arousal.standard_deviation], correct_arousal);
162 /// ```
163 pub fn get_raw_arousal(&self, term: &str) -> RawSentiment {
164 let mut average = 0.0;
165 let mut std_dev = 0.0;
166
167 if !self.does_term_exist(term) {
168 return RawSentiment::new(average, std_dev);
169 } else if self.custom_words.contains_key(term) {
170 let sentiment_info = self.custom_words.get(term).unwrap();
171 average = sentiment_info.avg[1];
172 std_dev = sentiment_info.std[1];
173 } else if self.custom_stems.contains_key(term) {
174 let sentiment_info = self.custom_stems.get(term).unwrap();
175 average = sentiment_info.avg[1];
176 std_dev = sentiment_info.std[1];
177 }
178 RawSentiment::new(average, std_dev)
179 }
180
181 /// Gets the raw valence values ([`RawSentiment`]) for a given `term` word token.
182 ///
183 /// # Examples
184 ///
185 /// ```
186 /// use rnltk::sentiment::{SentimentModel, CustomWords};
187 /// use rnltk::sample_data;
188 ///
189 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
190 ///
191 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
192 /// let valence = sentiment.get_raw_valence("abduction");
193 /// let correct_valence = vec![2.76, 2.06];
194 ///
195 /// assert_eq!(vec![valence.average, valence.standard_deviation], correct_valence);
196 /// ```
197 pub fn get_raw_valence(&self, term: &str) -> RawSentiment {
198 let mut average = 0.0;
199 let mut std_dev = 0.0;
200
201 if !self.does_term_exist(term) {
202 return RawSentiment::new(average, std_dev);
203 } else if self.custom_words.contains_key(term) {
204 let sentiment_info = self.custom_words.get(term).unwrap();
205 average = sentiment_info.avg[0];
206 std_dev = sentiment_info.std[0];
207 } else if self.custom_stems.contains_key(term) {
208 let sentiment_info = self.custom_stems.get(term).unwrap();
209 average = sentiment_info.avg[0];
210 std_dev = sentiment_info.std[0];
211 }
212 RawSentiment::new(average, std_dev)
213 }
214
215 /// Gets the arousal value for a given `term` word token.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// use rnltk::sentiment::{SentimentModel, CustomWords};
221 /// use rnltk::sample_data;
222 ///
223 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
224 ///
225 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
226 /// let arousal = sentiment.get_arousal_for_single_term("abduction");
227 /// let correct_arousal = 5.53;
228 ///
229 /// assert_eq!(arousal, correct_arousal);
230 /// ```
231 pub fn get_arousal_for_single_term(&self, term: &str) -> f64 {
232 self.get_raw_arousal(term).average
233 }
234
235 /// Gets the valence value for a given `term` word token.
236 ///
237 /// # Examples
238 ///
239 /// ```
240 /// use rnltk::sentiment::{SentimentModel, CustomWords};
241 /// use rnltk::sample_data;
242 ///
243 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
244 ///
245 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
246 /// let valence = sentiment.get_valence_for_single_term("abduction");
247 /// let correct_valence = 2.76;
248 ///
249 /// assert_eq!(valence, correct_valence);
250 /// ```
251 pub fn get_valence_for_single_term(&self, term: &str) -> f64 {
252 self.get_raw_valence(term).average
253 }
254
255 /// Gets the arousal value for a word token vector of `terms`.
256 ///
257 /// # Examples
258 ///
259 /// ```
260 /// use rnltk::sentiment::{SentimentModel, CustomWords};
261 ///
262 /// let custom_word_dict = r#"
263 /// {
264 /// "betrayed": {
265 /// "word": "betrayed",
266 /// "stem": "betrai",
267 /// "avg": [2.57, 7.24],
268 /// "std": [1.83, 2.06]
269 /// },
270 /// "bees": {
271 /// "word": "bees",
272 /// "stem": "bee",
273 /// "avg": [3.2, 6.51],
274 /// "std": [2.07, 2.14]
275 /// }
276 /// }"#;
277 /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
278 ///
279 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
280 /// let arousal = sentiment.get_arousal_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
281 /// let correct_arousal = 6.881952380952381;
282 ///
283 /// assert_eq!(arousal, correct_arousal);
284 /// ```
285 pub fn get_arousal_for_term_vector(&self, terms: &Vec<&str>) -> f64 {
286 let c = 2.0 * PI;
287 let mut prob: Vec<f64> = vec![];
288 let mut prob_sum = 0.0;
289 let mut arousal_means: Vec<f64> = vec![];
290
291 for term in terms {
292 if self.does_term_exist(term) {
293 let raw_arousal = self.get_raw_arousal(term);
294
295 let p = 1.0 / (c * raw_arousal.standard_deviation.powi(2)).sqrt();
296 prob.push(p);
297 prob_sum += p;
298
299 arousal_means.push(raw_arousal.average);
300 }
301 }
302 let mut arousal = 0.0;
303 for index in 0..arousal_means.len() {
304 arousal += prob[index] / prob_sum * arousal_means[index];
305 }
306
307 arousal
308 }
309
310 /// Gets the valence value for a word token vector of `terms`.
311 ///
312 /// # Examples
313 ///
314 /// ```
315 /// use rnltk::sentiment::{SentimentModel, CustomWords};
316 ///
317 /// let custom_word_dict = r#"
318 /// {
319 /// "betrayed": {
320 /// "word": "betrayed",
321 /// "stem": "betrai",
322 /// "avg": [2.57, 7.24],
323 /// "std": [1.83, 2.06]
324 /// },
325 /// "bees": {
326 /// "word": "bees",
327 /// "stem": "bee",
328 /// "avg": [3.2, 6.51],
329 /// "std": [2.07, 2.14]
330 /// }
331 /// }"#;
332 /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
333 ///
334 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
335 /// let valence = sentiment.get_valence_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
336 /// let correct_valence = 2.865615384615385;
337 ///
338 /// assert_eq!(valence, correct_valence);
339 /// ```
340 pub fn get_valence_for_term_vector(&self, terms: &Vec<&str>) -> f64 {
341 let c = 2.0 * PI;
342 let mut prob: Vec<f64> = vec![];
343 let mut prob_sum = 0.0;
344 let mut valence_means: Vec<f64> = vec![];
345
346 for term in terms {
347 if self.does_term_exist(term) {
348 let raw_valence = self.get_raw_valence(term);
349
350 let p = 1.0 / (c * raw_valence.standard_deviation.powi(2)).sqrt();
351 prob.push(p);
352 prob_sum += p;
353
354 valence_means.push(raw_valence.average);
355 }
356 }
357 let mut valence = 0.0;
358 for index in 0..valence_means.len() {
359 valence += prob[index] / prob_sum * valence_means[index];
360 }
361
362 valence
363 }
364
365 /// Gets the valence, arousal sentiment for a `term` word token.
366 ///
367 /// # Examples
368 ///
369 /// ```
370 /// use std::collections::HashMap;
371 /// use rnltk::sentiment::{SentimentModel, CustomWords};
372 /// use rnltk::sample_data;
373 ///
374 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
375 ///
376 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
377 /// let sentiment_info = sentiment.get_sentiment_for_term("abduction");
378 /// let sentiment_map = HashMap::from([("valence", 2.76), ("arousal", 5.53)]);
379 ///
380 /// assert_eq!(sentiment_info, sentiment_map);
381 /// ```
382 pub fn get_sentiment_for_term(&self, term: &str) -> HashMap<&str, f64> {
383 let mut sentiment: HashMap<&str, f64> = HashMap::new();
384 sentiment.insert("valence", self.get_valence_for_single_term(term));
385 sentiment.insert("arousal", self.get_arousal_for_single_term(term));
386
387 sentiment
388 }
389
390 /// Gets the valence, arousal sentiment for a word token vector of `terms`.
391 ///
392 /// # Examples
393 ///
394 /// ```
395 /// use std::collections::HashMap;
396 /// use rnltk::sentiment::{SentimentModel, CustomWords};
397 ///
398 /// let custom_word_dict = r#"
399 /// {
400 /// "betrayed": {
401 /// "word": "betrayed",
402 /// "stem": "betrai",
403 /// "avg": [2.57, 7.24],
404 /// "std": [1.83, 2.06]
405 /// },
406 /// "bees": {
407 /// "word": "bees",
408 /// "stem": "bee",
409 /// "avg": [3.2, 6.51],
410 /// "std": [2.07, 2.14]
411 /// }
412 /// }"#;
413 /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
414 ///
415 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
416 /// let sentiment_info = sentiment.get_sentiment_for_term_vector(&vec!["I", "betrayed", "the", "bees"]);
417 /// let sentiment_map = HashMap::from([("valence", 2.865615384615385), ("arousal", 6.881952380952381)]);
418 ///
419 /// assert_eq!(sentiment_info, sentiment_map);
420 /// ```
421 pub fn get_sentiment_for_term_vector(&self, terms: &Vec<&str>) -> HashMap<&str, f64> {
422 let mut sentiment: HashMap<&str, f64> = HashMap::new();
423 sentiment.insert("valence", self.get_valence_for_term_vector(terms));
424 sentiment.insert("arousal", self.get_arousal_for_term_vector(terms));
425
426 sentiment
427 }
428
429 /// Gets the Russel-like description given `valence` and `arousal` scores.
430 ///
431 /// # Examples
432 ///
433 /// ```
434 /// use rnltk::sentiment::{SentimentModel, CustomWords};
435 /// use rnltk::sample_data;
436 ///
437 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
438 ///
439 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
440 /// let sentiment_description = sentiment.get_sentiment_description(&2.76, &5.53);
441 /// let description = "upset";
442 ///
443 /// assert_eq!(sentiment_description, description);
444 /// ```
445 pub fn get_sentiment_description(&self, valence: &f64, arousal: &f64) -> Cow<'static, str> {
446 if !(1.0..=9.0).contains(valence) || !(1.0..=9.0).contains(arousal) {
447 println!("Valence and arousal must be bound between 1 and 9 (inclusive)");
448 return Cow::from("unknown");
449 }
450
451 // Center of circumplex (5,5) will give an r=0, div by zero error, so handle explicitly
452 if *valence == 5.0 && *arousal == 5.0 {
453 return Cow::from("average");
454 }
455
456 // Angular cutoffs for different emotional states (same on top and bottom)
457 let angular_cutoffs = vec![0.0, 18.43, 45.0, 71.57, 90.0, 108.43, 135.0, 161.57, 180.0];
458
459 // Terms to return for bottom, top half of circumplex
460 let lower_term = vec![
461 "contented", "serene", "relaxed", "calm",
462 "bored", "lethargic", "depressed", "sad"
463 ];
464 let upper_term = vec![
465 "happy", "elated", "excited", "alert",
466 "tense", "nervous", "stressed", "upset"
467 ];
468
469 // Normalize valence and arousal, using polar coordinates to get angle
470 // clockwise along bottom, counterclockwise along top
471 let normalized_valence = ((valence - 1.0) - 4.0) / 4.0;
472 let normalized_arousal = ((arousal - 1.0) - 4.0) / 4.0;
473 let mut radius = (normalized_valence.powi(2).abs() + normalized_arousal.powi(2).abs()).sqrt();
474 let direction = (normalized_valence / radius).acos().to_degrees();
475
476 // Normalize radius for "strength" of emotion
477 if direction <= 45.0 || direction >= 135.0 {
478 radius /= (normalized_arousal.powi(2).abs() + 1.0).sqrt();
479 } else {
480 radius /= (normalized_valence.powi(2).abs() + 1.0).sqrt();
481 }
482
483 let mut modify = "";
484
485 if radius <= 0.25 {
486 modify = "slightly ";
487 } else if radius <= 0.5 {
488 modify = "moderately ";
489 } else if radius > 0.75 {
490 modify = "very ";
491 }
492
493 // Use normalized arousal to determine if we're on bottom of top of circumplex
494 let mut term = lower_term;
495 if normalized_arousal > 0.0 {
496 term = upper_term;
497 }
498
499 let description;
500
501 // Walk along angular boundaries until we determine which "slice"
502 // our valence and arousal point lies in, return corresponding term
503 for index in 0..term.len() {
504 if direction >= angular_cutoffs[index] && direction <= angular_cutoffs[index + 1] {
505 description = format!("{}{}", modify, term[index]);
506 return Cow::from(description);
507 }
508 }
509
510 println!("unexpected angle {} did not match any term", normalized_arousal);
511 Cow::from("unknown")
512 }
513
514 /// Gets the Russel-like description given a `term` word token.
515 ///
516 /// # Examples
517 ///
518 /// ```
519 /// use std::collections::HashMap;
520 /// use rnltk::sentiment::{SentimentModel, CustomWords};
521 /// use rnltk::sample_data;
522 ///
523 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
524 ///
525 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
526 /// let sentiment_description = sentiment.get_term_description("abduction");
527 /// let description = "upset";
528 ///
529 /// assert_eq!(sentiment_description, description);
530 /// ```
531 pub fn get_term_description(&self, term: &str) -> Cow<'static, str> {
532 let sentiment = self.get_sentiment_for_term(term);
533 if sentiment.get("arousal").unwrap() == &0.0 {
534 return Cow::from("unknown");
535 }
536 self.get_sentiment_description(sentiment.get("valence").unwrap(), sentiment.get("arousal").unwrap())
537 }
538
539 /// Gets the Russel-like description given a word token vector of `terms`.
540 ///
541 /// # Examples
542 ///
543 /// ```
544 /// use rnltk::sentiment::{SentimentModel, CustomWords};
545 ///
546 /// let custom_word_dict = r#"
547 /// {
548 /// "betrayed": {
549 /// "word": "betrayed",
550 /// "stem": "betrai",
551 /// "avg": [2.57, 7.24],
552 /// "std": [1.83, 2.06]
553 /// },
554 /// "bees": {
555 /// "word": "bees",
556 /// "stem": "bee",
557 /// "avg": [3.2, 6.51],
558 /// "std": [2.07, 2.14]
559 /// }
560 /// }"#;
561 /// let custom_words_sentiment_hashmap: CustomWords = serde_json::from_str(custom_word_dict).unwrap();
562 ///
563 /// let sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
564 /// let sentiment_description = sentiment.get_term_vector_description(&vec!["I", "betrayed", "the", "bees"]);
565 /// let description = "stressed";
566 ///
567 /// assert_eq!(sentiment_description, description);
568 /// ```
569 pub fn get_term_vector_description(&self, terms: &Vec<&str>) -> Cow<'static, str> {
570 let sentiment = self.get_sentiment_for_term_vector(terms);
571 if sentiment.get("arousal").unwrap() == &0.0 {
572 return Cow::from("unknown");
573 }
574 self.get_sentiment_description(sentiment.get("valence").unwrap(), sentiment.get("arousal").unwrap())
575 }
576
577 /// Adds a new `term` word token with its corresponding `valence` and `arousal`
578 /// values to the sentiment lexicons. If the `term` does not already exist, it
579 /// will be added to the custom sentiment lexicon.
580 ///
581 /// # Errors
582 ///
583 /// Returns [`RnltkError::SentimentTermExists`] if the term already exists
584 ///
585 /// # Examples
586 ///
587 /// ```
588 /// use std::collections::HashMap;
589 /// use rnltk::sentiment::{SentimentModel, CustomWords};
590 /// use rnltk::error::RnltkError;
591 /// use rnltk::sample_data;
592 ///
593 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
594 ///
595 /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
596 /// let sentiment_return_value = sentiment.add_term_without_replacement("squanch", &2.0, &8.5);
597 /// match sentiment_return_value {
598 /// Ok(_) => {
599 /// let sentiment_info = sentiment.get_sentiment_for_term("squanch");
600 /// let sentiment_map = HashMap::from([("valence", 2.0), ("arousal", 8.5)]);
601 ///
602 /// assert_eq!(sentiment_info, sentiment_map);
603 /// },
604 /// Err(error_msg) => assert_eq!(error_msg, RnltkError::SentimentTermExists),
605 /// }
606 /// ```
607 pub fn add_term_without_replacement(&mut self, term: &'static str, valence: &f64, arousal: &f64) -> Result<(), RnltkError>{
608 if self.does_term_exist(term) {
609 return Err(RnltkError::SentimentTermExists);
610 } else {
611 let stemmed_word = stem::get(term)?;
612 let word = term.to_string();
613 let avg = vec![*valence, *arousal];
614 let std = vec![1.0, 1.0];
615 let word_dict_value = SentimentDictValue {
616 word: word.clone(),
617 stem: stemmed_word.clone(),
618 avg,
619 std
620 };
621 let avg = vec![*valence, *arousal];
622 let std = vec![1.0, 1.0];
623 let stem_dict_value = SentimentDictValue {
624 word,
625 stem: stemmed_word,
626 avg,
627 std
628 };
629 self.custom_words.insert(term.to_string(), word_dict_value);
630 self.custom_stems.insert(term.to_string(), stem_dict_value);
631 }
632 Ok(())
633 }
634
635 /// Adds a new `term` word token and its corresponding `valence` and `arousal`
636 /// values to the sentiment lexicons. If this `term` already exists, the `term` will be updated
637 /// with the new `valence` and `arousal` values. If the `term` does not already exist, the `term` will be
638 /// stemmed and added to the custom sentiment lexicon.
639 ///
640 /// # Errors
641 ///
642 /// Returns [`RnltkError::StemNonAscii`] in the event that the `term` being stemmed contains non-ASCII characters (like hopè).
643 ///
644 /// # Examples
645 ///
646 /// ```
647 /// use std::collections::HashMap;
648 /// use rnltk::sentiment::{SentimentModel, CustomWords};
649 /// use rnltk::error::RnltkError;
650 /// use rnltk::sample_data;
651 ///
652 /// let custom_words_sentiment_hashmap: CustomWords = sample_data::get_sample_custom_word_dict();
653 ///
654 /// let mut sentiment = SentimentModel::new(custom_words_sentiment_hashmap);
655 /// let sentiment_return_value = sentiment.add_term_with_replacement("abduction", &8.0, &8.5);
656 /// match sentiment_return_value {
657 /// Ok(_) => {
658 /// let sentiment_info = sentiment.get_sentiment_for_term("abduction");
659 /// let sentiment_map = HashMap::from([("valence", 8.0), ("arousal", 8.5)]);
660 ///
661 /// assert_eq!(sentiment_info, sentiment_map);
662 /// },
663 /// Err(error_msg) => assert_eq!(error_msg, RnltkError::StemNonAscii),
664 /// }
665 /// ```
666 pub fn add_term_with_replacement(&mut self, term: &'static str, valence: &f64, arousal: &f64) -> Result<(), RnltkError>{
667 if self.custom_words.contains_key(term) {
668 let dict_value = self.custom_words.get_mut(term).unwrap();
669 dict_value.avg[0] = *valence;
670 dict_value.avg[1] = *arousal;
671 } else if self.custom_stems.contains_key(term) {
672 let dict_value = self.custom_stems.get_mut(term).unwrap();
673 dict_value.avg[0] = *valence;
674 dict_value.avg[1] = *arousal;
675 } else {
676 let stemmed_word = stem::get(term)?;
677 let word = term.to_string();
678 let avg = vec![*valence, *arousal];
679 let std = vec![1.0, 1.0];
680 let word_dict_value = SentimentDictValue {
681 word: word.clone(),
682 stem: stemmed_word.clone(),
683 avg,
684 std
685 };
686 let avg = vec![*valence, *arousal];
687 let std = vec![1.0, 1.0];
688 let stem_dict_value = SentimentDictValue {
689 word,
690 stem: stemmed_word,
691 avg,
692 std
693 };
694 self.custom_words.insert(term.to_string(), word_dict_value);
695 self.custom_stems.insert(term.to_string(), stem_dict_value);
696 }
697 Ok(())
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model from the word lexicon stored in `../test_data/test.json`.
    fn sample_model() -> SentimentModel {
        let lexicon: CustomWords =
            serde_json::from_str(include_str!("../test_data/test.json")).unwrap();
        SentimentModel::new(lexicon)
    }

    /// Token vector shared by the term-vector tests.
    fn sample_terms() -> Vec<&'static str> {
        vec!["I", "betrayed", "the", "bees"]
    }

    #[test]
    fn raw_arousal() {
        let model = sample_model();

        assert_eq!(model.get_raw_arousal("abduction"), RawSentiment::new(5.53, 2.43));
    }

    #[test]
    fn raw_valence() {
        let model = sample_model();

        assert_eq!(model.get_raw_valence("abduction"), RawSentiment::new(2.76, 2.06));
    }

    #[test]
    fn valence() {
        let model = sample_model();

        assert_eq!(model.get_valence_for_single_term("abduction"), 2.76);
    }

    #[test]
    fn arousal() {
        let model = sample_model();

        assert_eq!(model.get_arousal_for_single_term("abduction"), 5.53);
    }

    #[test]
    fn arousal_vector() {
        let model = sample_model();

        assert_eq!(model.get_arousal_for_term_vector(&sample_terms()), 6.881952380952381);
    }

    #[test]
    fn valence_vector() {
        let model = sample_model();

        assert_eq!(model.get_valence_for_term_vector(&sample_terms()), 2.865615384615385);
    }

    #[test]
    fn term_sentiment() {
        let model = sample_model();
        let expected = HashMap::from([("valence", 2.76), ("arousal", 5.53)]);

        assert_eq!(model.get_sentiment_for_term("abduction"), expected);
    }

    #[test]
    fn term_vector_sentiment() {
        let model = sample_model();
        let terms = sample_terms();
        let expected = HashMap::from([("valence", 2.865615384615385), ("arousal", 6.881952380952381)]);

        assert_eq!(model.get_sentiment_for_term_vector(&terms), expected);
    }

    #[test]
    fn sentiment_description() {
        let model = sample_model();

        assert_eq!(model.get_sentiment_description(&2.76, &5.53), "upset");
    }

    #[test]
    fn term_description() {
        let model = sample_model();

        assert_eq!(model.get_term_description("abduction"), "upset");
    }

    #[test]
    fn term_vector_description() {
        let model = sample_model();

        assert_eq!(model.get_term_vector_description(&sample_terms()), "stressed");
    }

    #[test]
    fn replace_term() {
        let mut model = sample_model();
        model.add_term_with_replacement("abduction", &8.0, &8.5).unwrap();
        let expected = HashMap::from([("valence", 8.0), ("arousal", 8.5)]);

        assert_eq!(model.get_sentiment_for_term("abduction"), expected);
    }

    #[test]
    fn non_ascii_error_replace_term() {
        let mut model = sample_model();
        let error = model.add_term_with_replacement("hopè", &8.0, &8.5).unwrap_err();

        assert_eq!(error, RnltkError::StemNonAscii);
    }

    #[test]
    fn term_exists_error() {
        let mut model = sample_model();
        let error = model.add_term_without_replacement("abduction", &8.0, &8.5).unwrap_err();

        assert_eq!(error, RnltkError::SentimentTermExists);
    }

    #[test]
    fn add_term() {
        let mut model = sample_model();
        model.add_term_without_replacement("squanch", &2.0, &8.5).unwrap();
        let expected = HashMap::from([("valence", 2.0), ("arousal", 8.5)]);

        assert_eq!(model.get_sentiment_for_term("squanch"), expected);
    }
}