1use crate::{DetectedEntity, EntityCategory, PseudoToken};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Configuration for the per-session context layer.
///
/// Every field has a serde default, so a partially specified config document
/// deserializes into a complete value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Master switch for session tracking (reported by `SessionManager::is_enabled`).
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Where the session id is taken from; defaults to `"header:x-session-id"`.
    /// NOTE(review): not consumed in this file — presumably parsed by the request layer.
    #[serde(default = "default_id_from")]
    pub id_from: String,
    /// Idle timeout in seconds, measured from `SessionContext::last_activity`
    /// (see `SessionContext::is_expired`). Defaults to 1800.
    #[serde(default = "default_ttl")]
    pub ttl_seconds: u64,
    /// Enables coreference tracking and resolution. Defaults to `true`.
    #[serde(default = "default_true")]
    pub coreference: bool,
    /// Enables keyword-based sensitivity escalation. Defaults to `true`.
    #[serde(default = "default_true")]
    pub sensitivity_escalation: bool,
    /// Threshold exposed via `SessionContext::resolver_threshold`; defaults to 0.80.
    /// NOTE(review): its consumer lives outside this file — confirm exact semantics there.
    #[serde(default = "default_session_threshold")]
    pub session_threshold: f64,
}
36
/// Serde default for `SessionConfig::id_from`: the `header:x-session-id` source spec.
fn default_id_from() -> String {
    String::from("header:x-session-id")
}

/// Serde default for `SessionConfig::ttl_seconds`: thirty minutes, in seconds.
fn default_ttl() -> u64 {
    30 * 60
}

/// Shared `true` default for boolean switches that are on unless disabled.
fn default_true() -> bool {
    !false
}

/// Serde default for `SessionConfig::session_threshold`.
fn default_session_threshold() -> f64 {
    0.8
}
49
50impl Default for SessionConfig {
51 fn default() -> Self {
52 Self {
53 enabled: false,
54 id_from: default_id_from(),
55 ttl_seconds: default_ttl(),
56 coreference: true,
57 sensitivity_escalation: true,
58 session_threshold: default_session_threshold(),
59 }
60 }
61}
62
/// Sensitivity state of a session; serialized as `"normal"` / `"elevated"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SensitivityLevel {
    /// Default state for a new session.
    Normal,
    /// Set once a message trips the decision-keyword check (`check_sensitivity`).
    Elevated,
}
72
73impl Default for SensitivityLevel {
74 fn default() -> Self {
75 Self::Normal
76 }
77}
78
/// An entity remembered across the messages of one session.
#[derive(Debug, Clone, Serialize)]
pub struct SessionEntity {
    /// Pseudonym token assigned when the entity was first recorded; later sightings keep it.
    pub token: PseudoToken,
    /// Category from the first sighting.
    pub category: EntityCategory,
    /// The original surface text (also the key in `SessionContext::entities`).
    pub original: String,
    /// 1-based ordinal of the message in which the entity first appeared.
    pub first_seen: u32,
    /// 1-based ordinal of the most recent message mentioning the entity.
    pub last_seen: u32,
    /// Token strings of other entities that co-occurred with this one in a message.
    pub related_to: Vec<String>,
}
95
/// A surface form (pronoun, article phrase, abbreviation) that refers back to a tracked entity.
#[derive(Debug, Clone, Serialize)]
pub struct Coreference {
    /// The referring text, e.g. `"he"`, `"the company"`, `"TM"`.
    pub surface: String,
    /// Token string of the entity the surface form resolves to.
    pub target_token: String,
    /// Fixed heuristic confidence for the resolution method.
    pub confidence: f64,
    /// Heuristic that produced this coreference.
    pub method: CorefMethod,
}
108
/// Heuristic used to link a surface form to an entity; serialized in snake_case.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum CorefMethod {
    /// Personal pronoun (`he`, `her`, `they`, ...) resolved to the sole recent person.
    Pronoun,
    /// Initialism of a multi-word name, e.g. `"Tata Motors"` -> `"TM"`.
    Abbreviation,
    /// Generic noun phrase such as `"the company"` resolved by category.
    DefiniteArticle,
    /// Possessive of a person's first name, e.g. `"Rahul's"`.
    Possessive,
}
122
/// Mutable state for one conversation session: tracked entities, coreferences and sensitivity.
#[derive(Debug)]
pub struct SessionContext {
    /// Identifier of this session.
    pub session_id: String,
    /// When the context was created.
    pub created_at: DateTime<Utc>,
    /// Time of the last `record_entities` call (or creation); drives `is_expired`.
    pub last_activity: DateTime<Utc>,
    /// Number of `record_entities` calls so far; doubles as the 1-based message ordinal.
    pub message_count: u32,
    /// Current sensitivity level.
    pub sensitivity: SensitivityLevel,
    /// Decision keywords that triggered escalation (empty until escalated).
    pub escalation_keywords: Vec<String>,
    /// Tracked entities keyed by their original surface text.
    pub entities: HashMap<String, SessionEntity>,
    /// Coreference map keyed by surface form; rebuilt after each `record_entities`
    /// call when coreference is enabled.
    pub coreferences: HashMap<String, Coreference>,
    /// Configuration this context was created with.
    config: SessionConfig,
}
145
/// Decision/compliance vocabulary.
///
/// When at least two of these words appear in one message, `check_sensitivity`
/// escalates the session. Matching is a case-insensitive substring test, so
/// inflections such as "violations" also count.
const DECISION_KEYWORDS: &[&str] = &[
    "approved",
    "rejected",
    "exception",
    "policy",
    "override",
    "escalated",
    "waiver",
    "authorized",
    "sanctioned",
    "compliance",
    "violation",
    "audit",
    "decision",
    "ruling",
    "verdict",
    "settlement",
    "terminated",
    "suspended",
];
167
/// Lowercase personal pronouns.
///
/// These resolve only when exactly one person entity is active in the recent
/// message window.
const PERSON_PRONOUNS: &[&str] = &[
    "he", "him", "his", "she", "her", "hers", "they", "them", "their",
];
172
/// Generic noun phrases paired with the entity category they refer back to.
///
/// Each phrase resolves to the most recently seen entity of that category.
const DEFINITE_ARTICLES: &[(&str, EntityCategory)] = &[
    ("the company", EntityCategory::Organization),
    ("the firm", EntityCategory::Organization),
    ("the client", EntityCategory::Organization),
    ("the organization", EntityCategory::Organization),
    ("the bank", EntityCategory::Organization),
    ("the hospital", EntityCategory::Organization),
    ("the deal", EntityCategory::Amount),
    ("the amount", EntityCategory::Amount),
    ("the transaction", EntityCategory::Amount),
    ("the payment", EntityCategory::Amount),
    ("the city", EntityCategory::Location),
    ("the office", EntityCategory::Location),
    ("the employee", EntityCategory::Person),
    ("the manager", EntityCategory::Person),
    ("the patient", EntityCategory::Person),
    ("the customer", EntityCategory::Person),
    ("the applicant", EntityCategory::Person),
];
193
194impl SessionContext {
195 pub fn new(session_id: String, config: SessionConfig) -> Self {
197 let now = Utc::now();
198 Self {
199 session_id,
200 created_at: now,
201 last_activity: now,
202 message_count: 0,
203 sensitivity: SensitivityLevel::Normal,
204 escalation_keywords: Vec::new(),
205 entities: HashMap::new(),
206 coreferences: HashMap::new(),
207 config,
208 }
209 }
210
211 pub fn is_expired(&self) -> bool {
213 let elapsed = Utc::now()
214 .signed_duration_since(self.last_activity)
215 .num_seconds();
216 elapsed > self.config.ttl_seconds as i64
217 }
218
219 pub fn record_entities(
222 &mut self,
223 entities: &[DetectedEntity],
224 tokens: &[PseudoToken],
225 ) {
226 self.message_count += 1;
227 self.last_activity = Utc::now();
228
229 let msg_tokens: Vec<String> = tokens.iter().map(|t| t.token.clone()).collect();
231
232 for (entity, token) in entities.iter().zip(tokens.iter()) {
233 let entry = self
234 .entities
235 .entry(entity.original.clone())
236 .or_insert_with(|| SessionEntity {
237 token: token.clone(),
238 category: entity.category.clone(),
239 original: entity.original.clone(),
240 first_seen: self.message_count,
241 last_seen: self.message_count,
242 related_to: Vec::new(),
243 });
244
245 entry.last_seen = self.message_count;
246
247 for t in &msg_tokens {
249 if t != &token.token && !entry.related_to.contains(t) {
250 entry.related_to.push(t.clone());
251 }
252 }
253 }
254
255 if self.config.coreference {
257 self.update_coreferences();
258 }
259 }
260
261 pub fn check_sensitivity(&mut self, text: &str) -> bool {
264 if !self.config.sensitivity_escalation {
265 return false;
266 }
267 if self.sensitivity == SensitivityLevel::Elevated {
268 return false; }
270
271 let text_lower = text.to_lowercase();
272 let mut found = Vec::new();
273 for &keyword in DECISION_KEYWORDS {
274 if text_lower.contains(keyword) {
275 found.push(keyword.to_string());
276 }
277 }
278
279 if found.len() >= 2 {
280 self.sensitivity = SensitivityLevel::Elevated;
282 self.escalation_keywords = found;
283 true
284 } else {
285 false
286 }
287 }
288
289 pub fn resolve_coreferences(&self, text: &str) -> Vec<(DetectedEntity, PseudoToken)> {
295 if !self.config.coreference || self.entities.is_empty() {
296 return Vec::new();
297 }
298
299 let mut results = Vec::new();
300 let text_lower = text.to_lowercase();
301
302 for &pronoun in PERSON_PRONOUNS {
304 if let Some(pos) = find_word_boundary(&text_lower, pronoun) {
305 if let Some(person) = self.most_recent_entity(&EntityCategory::Person) {
307 let recent_persons = self.recent_entities_of_category(&EntityCategory::Person, 3);
309 if recent_persons.len() == 1 {
310 results.push((
311 DetectedEntity {
312 original: text[pos..pos + pronoun.len()].to_string(),
313 start: pos,
314 end: pos + pronoun.len(),
315 category: EntityCategory::Person,
316 confidence: 0.7,
317 source: crate::DetectionSource::Pattern,
318 },
319 person.token.clone(),
320 ));
321 }
322 }
323 }
324 }
325
326 for &(article, ref category) in DEFINITE_ARTICLES {
328 if let Some(pos) = find_word_boundary(&text_lower, article) {
329 if let Some(entity) = self.most_recent_entity(category) {
330 results.push((
331 DetectedEntity {
332 original: text[pos..pos + article.len()].to_string(),
333 start: pos,
334 end: pos + article.len(),
335 category: category.clone(),
336 confidence: 0.6,
337 source: crate::DetectionSource::Pattern,
338 },
339 entity.token.clone(),
340 ));
341 }
342 }
343 }
344
345 results.extend(self.resolve_abbreviations(text));
347
348 results.extend(self.resolve_possessives(text));
350
351 results.sort_by(|a, b| a.0.start.cmp(&b.0.start));
353 results.dedup_by(|a, b| {
354 if a.0.start == b.0.start {
355 if b.0.confidence > a.0.confidence {
356 std::mem::swap(a, b);
357 }
358 true
359 } else {
360 false
361 }
362 });
363
364 results
365 }
366
367 fn most_recent_entity(&self, category: &EntityCategory) -> Option<&SessionEntity> {
369 self.entities
370 .values()
371 .filter(|e| &e.category == category)
372 .max_by_key(|e| e.last_seen)
373 }
374
375 fn recent_entities_of_category(
377 &self,
378 category: &EntityCategory,
379 within_messages: u32,
380 ) -> Vec<&SessionEntity> {
381 let cutoff = self.message_count.saturating_sub(within_messages);
382 self.entities
383 .values()
384 .filter(|e| &e.category == category && e.last_seen >= cutoff)
385 .collect()
386 }
387
388 fn resolve_abbreviations(&self, text: &str) -> Vec<(DetectedEntity, PseudoToken)> {
390 let mut results = Vec::new();
391
392 for entity in self.entities.values() {
393 if !matches!(
395 entity.category,
396 EntityCategory::Organization | EntityCategory::Person
397 ) {
398 continue;
399 }
400
401 let words: Vec<&str> = entity.original.split_whitespace().collect();
402 if words.len() < 2 {
403 continue;
404 }
405
406 let abbrev: String = words.iter().map(|w| {
408 w.chars().next().unwrap_or_default().to_uppercase().to_string()
409 }).collect();
410
411 if abbrev.len() < 2 {
412 continue;
413 }
414
415 if let Some(pos) = find_word_boundary(text, &abbrev) {
417 results.push((
418 DetectedEntity {
419 original: text[pos..pos + abbrev.len()].to_string(),
420 start: pos,
421 end: pos + abbrev.len(),
422 category: entity.category.clone(),
423 confidence: 0.65,
424 source: crate::DetectionSource::Pattern,
425 },
426 entity.token.clone(),
427 ));
428 }
429 }
430
431 results
432 }
433
434 fn resolve_possessives(&self, text: &str) -> Vec<(DetectedEntity, PseudoToken)> {
436 let mut results = Vec::new();
437
438 for entity in self.entities.values() {
439 if entity.category != EntityCategory::Person {
440 continue;
441 }
442
443 let first_name = entity
445 .original
446 .split_whitespace()
447 .next()
448 .unwrap_or(&entity.original);
449
450 let possessive = format!("{}'s", first_name);
451 if let Some(pos) = find_word_boundary(text, &possessive) {
452 results.push((
453 DetectedEntity {
454 original: text[pos..pos + possessive.len()].to_string(),
455 start: pos,
456 end: pos + possessive.len(),
457 category: EntityCategory::Person,
458 confidence: 0.75,
459 source: crate::DetectionSource::Pattern,
460 },
461 entity.token.clone(),
462 ));
463 }
464 }
465
466 results
467 }
468
469 fn update_coreferences(&mut self) {
471 let mut new_corefs: Vec<(String, Coreference)> = Vec::new();
473
474 {
476 let recent_persons = self.recent_entities_of_category(&EntityCategory::Person, 3);
477 if recent_persons.len() == 1 {
478 let token_str = recent_persons[0].token.token.clone();
479 for &pronoun in PERSON_PRONOUNS {
480 new_corefs.push((
481 pronoun.to_string(),
482 Coreference {
483 surface: pronoun.to_string(),
484 target_token: token_str.clone(),
485 confidence: 0.7,
486 method: CorefMethod::Pronoun,
487 },
488 ));
489 }
490 }
491 }
492
493 for &(article, ref category) in DEFINITE_ARTICLES {
495 if let Some(entity) = self.most_recent_entity(category) {
496 let token_str = entity.token.token.clone();
497 new_corefs.push((
498 article.to_string(),
499 Coreference {
500 surface: article.to_string(),
501 target_token: token_str,
502 confidence: 0.6,
503 method: CorefMethod::DefiniteArticle,
504 },
505 ));
506 }
507 }
508
509 for entity in self.entities.values() {
511 if !matches!(
512 entity.category,
513 EntityCategory::Organization | EntityCategory::Person
514 ) {
515 continue;
516 }
517 let words: Vec<&str> = entity.original.split_whitespace().collect();
518 if words.len() < 2 {
519 continue;
520 }
521 let abbrev: String = words
522 .iter()
523 .map(|w| w.chars().next().unwrap_or_default().to_uppercase().to_string())
524 .collect();
525 if abbrev.len() >= 2 {
526 new_corefs.push((
527 abbrev.clone(),
528 Coreference {
529 surface: abbrev,
530 target_token: entity.token.token.clone(),
531 confidence: 0.65,
532 method: CorefMethod::Abbreviation,
533 },
534 ));
535 }
536 }
537
538 self.coreferences.clear();
540 for (key, coref) in new_corefs {
541 self.coreferences.insert(key, coref);
542 }
543 }
544
545 pub fn stats(&self) -> SessionStats {
547 let mut categories = HashMap::new();
548 for entity in self.entities.values() {
549 *categories.entry(format!("{:?}", entity.category)).or_insert(0u32) += 1;
550 }
551
552 SessionStats {
553 session_id: self.session_id.clone(),
554 message_count: self.message_count,
555 entity_count: self.entities.len(),
556 coreference_count: self.coreferences.len(),
557 sensitivity: self.sensitivity,
558 escalation_keywords: self.escalation_keywords.clone(),
559 categories,
560 created_at: self.created_at.to_rfc3339(),
561 last_activity: self.last_activity.to_rfc3339(),
562 }
563 }
564
565 pub fn coreference_map(&self) -> &HashMap<String, Coreference> {
567 &self.coreferences
568 }
569
570 pub fn resolver_threshold(&self) -> f64 {
572 self.config.session_threshold
573 }
574}
575
/// Serializable snapshot of a session, produced by `SessionContext::stats`.
#[derive(Debug, Clone, Serialize)]
pub struct SessionStats {
    pub session_id: String,
    pub message_count: u32,
    pub entity_count: usize,
    pub coreference_count: usize,
    pub sensitivity: SensitivityLevel,
    pub escalation_keywords: Vec<String>,
    /// Entity counts keyed by the `Debug` rendering of `EntityCategory`.
    pub categories: HashMap<String, u32>,
    /// RFC 3339 timestamp.
    pub created_at: String,
    /// RFC 3339 timestamp.
    pub last_activity: String,
}
589
/// Registry of session contexts keyed by session id, guarded by an `RwLock`.
pub struct SessionManager {
    /// Live (and not-yet-evicted expired) sessions.
    sessions: std::sync::RwLock<HashMap<String, SessionContext>>,
    /// Config cloned into every newly created session.
    config: SessionConfig,
}
595
596impl SessionManager {
597 pub fn new(config: SessionConfig) -> Self {
599 Self {
600 sessions: std::sync::RwLock::new(HashMap::new()),
601 config,
602 }
603 }
604
605 pub fn get_or_create(&self, session_id: &str) -> String {
607 let mut sessions = self.sessions.write().unwrap();
608
609 if let Some(session) = sessions.get(session_id) {
611 if !session.is_expired() {
612 return session_id.to_string();
613 }
614 sessions.remove(session_id);
616 }
617
618 sessions.insert(
620 session_id.to_string(),
621 SessionContext::new(session_id.to_string(), self.config.clone()),
622 );
623 session_id.to_string()
624 }
625
626 pub fn with_session<F, R>(&self, session_id: &str, f: F) -> Option<R>
628 where
629 F: FnOnce(&mut SessionContext) -> R,
630 {
631 let mut sessions = self.sessions.write().unwrap();
632 sessions.get_mut(session_id).map(f)
633 }
634
635 pub fn with_session_ref<F, R>(&self, session_id: &str, f: F) -> Option<R>
637 where
638 F: FnOnce(&SessionContext) -> R,
639 {
640 let sessions = self.sessions.read().unwrap();
641 sessions.get(session_id).map(f)
642 }
643
644 pub fn list_sessions(&self) -> Vec<SessionStats> {
646 let sessions = self.sessions.read().unwrap();
647 sessions.values().map(|s| s.stats()).collect()
648 }
649
650 pub fn inspect(&self, session_id: &str) -> Option<SessionStats> {
652 let sessions = self.sessions.read().unwrap();
653 sessions.get(session_id).map(|s| s.stats())
654 }
655
656 pub fn flush_session(&self, session_id: &str) -> bool {
658 let mut sessions = self.sessions.write().unwrap();
659 sessions.remove(session_id).is_some()
660 }
661
662 pub fn flush_all(&self) -> usize {
664 let mut sessions = self.sessions.write().unwrap();
665 let count = sessions.len();
666 sessions.clear();
667 count
668 }
669
670 pub fn evict_expired(&self) -> usize {
672 let mut sessions = self.sessions.write().unwrap();
673 let before = sessions.len();
674 sessions.retain(|_, s| !s.is_expired());
675 before - sessions.len()
676 }
677
678 pub fn is_enabled(&self) -> bool {
680 self.config.enabled
681 }
682}
683
/// Finds the byte offset of the first whole-word occurrence of `word` in `text`,
/// ignoring ASCII case.
///
/// A match is a whole word when the characters immediately before and after it
/// (if any) are not alphanumeric. The returned offset indexes `text` itself and
/// lies on a char boundary. The matched span is exactly `word.len()` bytes, so
/// `&text[pos..pos + word.len()]` is always valid.
///
/// Only ASCII letters are compared case-insensitively; non-ASCII characters must
/// match exactly. Full Unicode lowercasing can change byte lengths (e.g. `'İ'`
/// lowercases to three bytes), which would desynchronise offsets from `text`.
///
/// Returns `None` for an empty `word` or when no whole-word match exists.
fn find_word_boundary(text: &str, word: &str) -> Option<usize> {
    let needle = word.as_bytes();
    let haystack = text.as_bytes();
    if needle.is_empty() || needle.len() > haystack.len() {
        return None;
    }

    (0..=haystack.len() - needle.len()).find(|&start| {
        let end = start + needle.len();
        // Boundary checks first so the string slices below can never panic.
        text.is_char_boundary(start)
            && text.is_char_boundary(end)
            && haystack[start..end].eq_ignore_ascii_case(needle)
            && text[..start]
                .chars()
                .next_back()
                .map_or(true, |c| !c.is_alphanumeric())
            && text[end..].chars().next().map_or(true, |c| !c.is_alphanumeric())
    })
}
704
705#[cfg(test)]
706mod tests {
707 use super::*;
708 use crate::{DetectedEntity, DetectionSource, EntityCategory, PseudoToken};
709
710 fn test_config() -> SessionConfig {
711 SessionConfig {
712 enabled: true,
713 id_from: "header:x-session-id".into(),
714 ttl_seconds: 1800,
715 coreference: true,
716 sensitivity_escalation: true,
717 session_threshold: 0.80,
718 }
719 }
720
721 fn make_entity(original: &str, category: EntityCategory) -> DetectedEntity {
722 DetectedEntity {
723 original: original.to_string(),
724 start: 0,
725 end: original.len(),
726 category,
727 confidence: 1.0,
728 source: DetectionSource::Pattern,
729 }
730 }
731
732 fn make_token(token: &str, category: EntityCategory, id: u32) -> PseudoToken {
733 PseudoToken {
734 token: token.to_string(),
735 category,
736 id,
737 }
738 }
739
740 #[test]
741 fn test_session_creation() {
742 let ctx = SessionContext::new("sess-1".into(), test_config());
743 assert_eq!(ctx.session_id, "sess-1");
744 assert_eq!(ctx.message_count, 0);
745 assert_eq!(ctx.sensitivity, SensitivityLevel::Normal);
746 assert!(ctx.entities.is_empty());
747 }
748
749 #[test]
750 fn test_record_entities() {
751 let mut ctx = SessionContext::new("sess-1".into(), test_config());
752
753 let entities = vec![
754 make_entity("Rahul Sharma", EntityCategory::Person),
755 make_entity("Tata Motors", EntityCategory::Organization),
756 ];
757 let tokens = vec![
758 make_token("PERSON_1", EntityCategory::Person, 1),
759 make_token("ORG_1", EntityCategory::Organization, 1),
760 ];
761
762 ctx.record_entities(&entities, &tokens);
763
764 assert_eq!(ctx.message_count, 1);
765 assert_eq!(ctx.entities.len(), 2);
766
767 let rahul = ctx.entities.get("Rahul Sharma").unwrap();
768 assert_eq!(rahul.token.token, "PERSON_1");
769 assert_eq!(rahul.first_seen, 1);
770 assert_eq!(rahul.related_to, vec!["ORG_1"]);
771 }
772
773 #[test]
774 fn test_pronoun_resolution() {
775 let mut ctx = SessionContext::new("sess-1".into(), test_config());
776
777 ctx.record_entities(
779 &[make_entity("Rahul Sharma", EntityCategory::Person)],
780 &[make_token("PERSON_1", EntityCategory::Person, 1)],
781 );
782
783 let corefs = ctx.resolve_coreferences("He approved the deal");
785 assert!(!corefs.is_empty());
786
787 let he_coref = corefs.iter().find(|(e, _)| e.original.to_lowercase() == "he");
788 assert!(he_coref.is_some());
789 assert_eq!(he_coref.unwrap().1.token, "PERSON_1");
790 }
791
792 #[test]
793 fn test_pronoun_ambiguity_blocks_resolution() {
794 let mut ctx = SessionContext::new("sess-1".into(), test_config());
795
796 ctx.record_entities(
798 &[
799 make_entity("Rahul Sharma", EntityCategory::Person),
800 make_entity("Priya Singh", EntityCategory::Person),
801 ],
802 &[
803 make_token("PERSON_1", EntityCategory::Person, 1),
804 make_token("PERSON_2", EntityCategory::Person, 2),
805 ],
806 );
807
808 let corefs = ctx.resolve_coreferences("He approved the deal");
809 let he_coref = corefs.iter().find(|(e, _)| e.original.to_lowercase() == "he");
810 assert!(he_coref.is_none());
811 }
812
813 #[test]
814 fn test_abbreviation_resolution() {
815 let mut ctx = SessionContext::new("sess-1".into(), test_config());
816
817 ctx.record_entities(
818 &[make_entity("Tata Motors", EntityCategory::Organization)],
819 &[make_token("ORG_1", EntityCategory::Organization, 1)],
820 );
821
822 let corefs = ctx.resolve_coreferences("TM reported strong quarterly earnings");
823 let tm_coref = corefs.iter().find(|(e, _)| e.original == "TM");
824 assert!(tm_coref.is_some());
825 assert_eq!(tm_coref.unwrap().1.token, "ORG_1");
826 }
827
828 #[test]
829 fn test_definite_article_resolution() {
830 let mut ctx = SessionContext::new("sess-1".into(), test_config());
831
832 ctx.record_entities(
833 &[make_entity("Infosys Ltd", EntityCategory::Organization)],
834 &[make_token("ORG_1", EntityCategory::Organization, 1)],
835 );
836
837 let corefs = ctx.resolve_coreferences("the company posted record revenue");
838 let co_coref = corefs
839 .iter()
840 .find(|(e, _)| e.original.to_lowercase() == "the company");
841 assert!(co_coref.is_some());
842 assert_eq!(co_coref.unwrap().1.token, "ORG_1");
843 }
844
845 #[test]
846 fn test_possessive_resolution() {
847 let mut ctx = SessionContext::new("sess-1".into(), test_config());
848
849 ctx.record_entities(
850 &[make_entity("Rahul Sharma", EntityCategory::Person)],
851 &[make_token("PERSON_1", EntityCategory::Person, 1)],
852 );
853
854 let corefs = ctx.resolve_coreferences("Rahul's decision was final");
855 let poss = corefs
856 .iter()
857 .find(|(e, _)| e.original.contains("Rahul's"));
858 assert!(poss.is_some());
859 assert_eq!(poss.unwrap().1.token, "PERSON_1");
860 }
861
862 #[test]
863 fn test_sensitivity_escalation() {
864 let mut ctx = SessionContext::new("sess-1".into(), test_config());
865
866 assert!(!ctx.check_sensitivity("The request was approved"));
868 assert_eq!(ctx.sensitivity, SensitivityLevel::Normal);
869
870 assert!(ctx.check_sensitivity("The exception was approved per policy override"));
872 assert_eq!(ctx.sensitivity, SensitivityLevel::Elevated);
873 assert!(ctx.escalation_keywords.len() >= 2);
874 }
875
876 #[test]
877 fn test_sensitivity_no_double_escalation() {
878 let mut ctx = SessionContext::new("sess-1".into(), test_config());
879
880 ctx.check_sensitivity("approved policy override exception");
881 assert_eq!(ctx.sensitivity, SensitivityLevel::Elevated);
882
883 assert!(!ctx.check_sensitivity("another decision violation"));
885 }
886
887 #[test]
888 fn test_session_ttl_expiry() {
889 let mut ctx = SessionContext::new("sess-1".into(), SessionConfig {
890 ttl_seconds: 0, ..test_config()
892 });
893 ctx.last_activity = Utc::now() - chrono::Duration::seconds(1);
894 assert!(ctx.is_expired());
895 }
896
897 #[test]
898 fn test_session_manager_create_and_list() {
899 let mgr = SessionManager::new(test_config());
900 mgr.get_or_create("sess-1");
901 mgr.get_or_create("sess-2");
902
903 let sessions = mgr.list_sessions();
904 assert_eq!(sessions.len(), 2);
905 }
906
907 #[test]
908 fn test_session_manager_flush() {
909 let mgr = SessionManager::new(test_config());
910 mgr.get_or_create("sess-1");
911 mgr.get_or_create("sess-2");
912
913 assert!(mgr.flush_session("sess-1"));
914 assert!(!mgr.flush_session("nonexistent"));
915
916 let sessions = mgr.list_sessions();
917 assert_eq!(sessions.len(), 1);
918 }
919
920 #[test]
921 fn test_session_manager_flush_all() {
922 let mgr = SessionManager::new(test_config());
923 mgr.get_or_create("sess-1");
924 mgr.get_or_create("sess-2");
925
926 assert_eq!(mgr.flush_all(), 2);
927 assert!(mgr.list_sessions().is_empty());
928 }
929
930 #[test]
931 fn test_session_stats() {
932 let mut ctx = SessionContext::new("sess-1".into(), test_config());
933 ctx.record_entities(
934 &[
935 make_entity("Alice", EntityCategory::Person),
936 make_entity("Acme Corp", EntityCategory::Organization),
937 ],
938 &[
939 make_token("PERSON_1", EntityCategory::Person, 1),
940 make_token("ORG_1", EntityCategory::Organization, 1),
941 ],
942 );
943
944 let stats = ctx.stats();
945 assert_eq!(stats.session_id, "sess-1");
946 assert_eq!(stats.message_count, 1);
947 assert_eq!(stats.entity_count, 2);
948 assert_eq!(stats.sensitivity, SensitivityLevel::Normal);
949 }
950
951 #[test]
952 fn test_find_word_boundary() {
953 assert_eq!(find_word_boundary("He went home", "he"), Some(0));
954 assert_eq!(find_word_boundary("Then he went", "he"), Some(5));
955 assert!(find_word_boundary("The cat", "he").is_none()); assert_eq!(find_word_boundary("Say TM earnings", "TM"), Some(4));
957 assert!(find_word_boundary("ATMS are here", "TM").is_none()); }
959
960 #[test]
961 fn test_cross_message_entity_tracking() {
962 let mut ctx = SessionContext::new("sess-1".into(), test_config());
963
964 ctx.record_entities(
966 &[make_entity("Rahul", EntityCategory::Person)],
967 &[make_token("PERSON_1", EntityCategory::Person, 1)],
968 );
969
970 ctx.record_entities(
972 &[make_entity("Rahul", EntityCategory::Person)],
973 &[make_token("PERSON_1", EntityCategory::Person, 1)],
974 );
975
976 let rahul = ctx.entities.get("Rahul").unwrap();
977 assert_eq!(rahul.first_seen, 1);
978 assert_eq!(rahul.last_seen, 2);
979 }
980
981 #[test]
982 fn test_coreference_map_updates() {
983 let mut ctx = SessionContext::new("sess-1".into(), test_config());
984
985 ctx.record_entities(
986 &[make_entity("Tata Motors", EntityCategory::Organization)],
987 &[make_token("ORG_1", EntityCategory::Organization, 1)],
988 );
989
990 let corefs = ctx.coreference_map();
991 assert!(corefs.contains_key("TM")); assert!(corefs.contains_key("the company")); }
994
995 #[test]
998 fn test_multi_message_conversation_flow() {
999 let mut ctx = SessionContext::new("conv-1".into(), test_config());
1001
1002 ctx.record_entities(
1004 &[
1005 make_entity("Rahul Sharma", EntityCategory::Person),
1006 make_entity("Tata Motors", EntityCategory::Organization),
1007 ],
1008 &[
1009 make_token("PERSON_1", EntityCategory::Person, 1),
1010 make_token("ORG_1", EntityCategory::Organization, 1),
1011 ],
1012 );
1013 assert_eq!(ctx.message_count, 1);
1014
1015 ctx.record_entities(
1017 &[make_entity("Mumbai", EntityCategory::Location)],
1018 &[make_token("LOC_1", EntityCategory::Location, 1)],
1019 );
1020 assert_eq!(ctx.message_count, 2);
1021 assert_eq!(ctx.entities.len(), 3);
1022
1023 let corefs = ctx.resolve_coreferences("He was transferred to the office");
1025 let he = corefs.iter().find(|(e, _)| e.original.to_lowercase() == "he");
1026 assert!(he.is_some());
1027 assert_eq!(he.unwrap().1.token, "PERSON_1");
1028
1029 let office = corefs.iter().find(|(e, _)| e.original.to_lowercase() == "the office");
1031 assert!(office.is_some());
1032 assert_eq!(office.unwrap().1.token, "LOC_1");
1033
1034 ctx.record_entities(
1036 &[make_entity("Priya Patel", EntityCategory::Person)],
1037 &[make_token("PERSON_2", EntityCategory::Person, 2)],
1038 );
1039
1040 let corefs2 = ctx.resolve_coreferences("She approved the transfer");
1041 let she = corefs2.iter().find(|(e, _)| e.original.to_lowercase() == "she");
1042 assert!(she.is_none()); }
1044
1045 #[test]
1046 fn test_sensitivity_requires_two_keywords() {
1047 let mut ctx = SessionContext::new("s-1".into(), test_config());
1048
1049 assert!(!ctx.check_sensitivity("This was approved."));
1051 assert!(!ctx.check_sensitivity("Check compliance"));
1052 assert!(!ctx.check_sensitivity("The verdict is in"));
1053 assert_eq!(ctx.sensitivity, SensitivityLevel::Normal);
1054 }
1055
1056 #[test]
1057 fn test_sensitivity_various_keyword_pairs() {
1058 let pairs = [
1060 "The verdict was approved",
1061 "compliance violation detected",
1062 "authorized the exception",
1063 "settlement was sanctioned",
1064 "decision to terminate suspended",
1065 ];
1066 for text in pairs {
1067 let mut ctx = SessionContext::new("s".into(), test_config());
1068 assert!(ctx.check_sensitivity(text), "Should escalate for: {}", text);
1069 }
1070 }
1071
1072 #[test]
1073 fn test_sensitivity_disabled() {
1074 let mut ctx = SessionContext::new("s-1".into(), SessionConfig {
1075 sensitivity_escalation: false,
1076 ..test_config()
1077 });
1078 assert!(!ctx.check_sensitivity("approved the policy exception override waiver"));
1079 assert_eq!(ctx.sensitivity, SensitivityLevel::Normal);
1080 }
1081
1082 #[test]
1083 fn test_coreference_disabled() {
1084 let mut ctx = SessionContext::new("s-1".into(), SessionConfig {
1085 coreference: false,
1086 ..test_config()
1087 });
1088 ctx.record_entities(
1089 &[make_entity("Rahul Sharma", EntityCategory::Person)],
1090 &[make_token("PERSON_1", EntityCategory::Person, 1)],
1091 );
1092
1093 let corefs = ctx.resolve_coreferences("He approved it");
1094 assert!(corefs.is_empty());
1095 assert!(ctx.coreference_map().is_empty());
1096 }
1097
1098 #[test]
1099 fn test_abbreviation_needs_multi_word() {
1100 let mut ctx = SessionContext::new("s-1".into(), test_config());
1101
1102 ctx.record_entities(
1104 &[make_entity("Google", EntityCategory::Organization)],
1105 &[make_token("ORG_1", EntityCategory::Organization, 1)],
1106 );
1107 let corefs = ctx.resolve_coreferences("G is great");
1108 let g = corefs.iter().find(|(e, _)| e.original == "G");
1109 assert!(g.is_none()); }
1111
1112 #[test]
1113 fn test_abbreviation_case_insensitive_source() {
1114 let mut ctx = SessionContext::new("s-1".into(), test_config());
1115
1116 ctx.record_entities(
1117 &[make_entity("New York City", EntityCategory::Location)],
1118 &[make_token("LOC_1", EntityCategory::Location, 1)],
1119 );
1120 let corefs = ctx.resolve_coreferences("NYC is busy");
1122 let nyc = corefs.iter().find(|(e, _)| e.original == "NYC");
1123 assert!(nyc.is_none()); }
1125
1126 #[test]
1127 fn test_possessive_first_name_only() {
1128 let mut ctx = SessionContext::new("s-1".into(), test_config());
1129
1130 ctx.record_entities(
1131 &[make_entity("Alice Johnson", EntityCategory::Person)],
1132 &[make_token("PERSON_1", EntityCategory::Person, 1)],
1133 );
1134
1135 let corefs = ctx.resolve_coreferences("Alice's report was thorough");
1137 let poss = corefs.iter().find(|(e, _)| e.original.contains("Alice's"));
1138 assert!(poss.is_some());
1139 assert_eq!(poss.unwrap().1.token, "PERSON_1");
1140
1141 let corefs2 = ctx.resolve_coreferences("Johnson's report was thorough");
1143 let poss2 = corefs2.iter().find(|(e, _)| e.original.contains("Johnson's"));
1144 assert!(poss2.is_none());
1145 }
1146
1147 #[test]
1148 fn test_co_occurrence_tracking() {
1149 let mut ctx = SessionContext::new("s-1".into(), test_config());
1150
1151 ctx.record_entities(
1152 &[
1153 make_entity("Alice", EntityCategory::Person),
1154 make_entity("$50,000", EntityCategory::Amount),
1155 make_entity("Acme Corp", EntityCategory::Organization),
1156 ],
1157 &[
1158 make_token("PERSON_1", EntityCategory::Person, 1),
1159 make_token("AMOUNT_1", EntityCategory::Amount, 1),
1160 make_token("ORG_1", EntityCategory::Organization, 1),
1161 ],
1162 );
1163
1164 let alice = ctx.entities.get("Alice").unwrap();
1165 assert!(alice.related_to.contains(&"AMOUNT_1".to_string()));
1166 assert!(alice.related_to.contains(&"ORG_1".to_string()));
1167 assert!(!alice.related_to.contains(&"PERSON_1".to_string())); let amount = ctx.entities.get("$50,000").unwrap();
1170 assert!(amount.related_to.contains(&"PERSON_1".to_string()));
1171 assert!(amount.related_to.contains(&"ORG_1".to_string()));
1172 }
1173
1174 #[test]
1175 fn test_definite_article_all_categories() {
1176 let mut ctx = SessionContext::new("s-1".into(), test_config());
1177
1178 ctx.record_entities(
1179 &[
1180 make_entity("Acme Corp", EntityCategory::Organization),
1181 make_entity("$1M", EntityCategory::Amount),
1182 make_entity("Chicago", EntityCategory::Location),
1183 make_entity("Jane Doe", EntityCategory::Person),
1184 ],
1185 &[
1186 make_token("ORG_1", EntityCategory::Organization, 1),
1187 make_token("AMOUNT_1", EntityCategory::Amount, 1),
1188 make_token("LOC_1", EntityCategory::Location, 1),
1189 make_token("PERSON_1", EntityCategory::Person, 1),
1190 ],
1191 );
1192
1193 let org = ctx.resolve_coreferences("the firm announced profits");
1195 assert!(org.iter().any(|(e, t)| e.original.to_lowercase() == "the firm" && t.token == "ORG_1"));
1196
1197 let amt = ctx.resolve_coreferences("the deal was finalized");
1198 assert!(amt.iter().any(|(e, t)| e.original.to_lowercase() == "the deal" && t.token == "AMOUNT_1"));
1199
1200 let loc = ctx.resolve_coreferences("the city experienced growth");
1201 assert!(loc.iter().any(|(e, t)| e.original.to_lowercase() == "the city" && t.token == "LOC_1"));
1202
1203 let person = ctx.resolve_coreferences("the employee was promoted");
1205 assert!(person.iter().any(|(e, t)| e.original.to_lowercase() == "the employee" && t.token == "PERSON_1"));
1206 }
1207
1208 #[test]
1209 fn test_session_manager_with_session() {
1210 let mgr = SessionManager::new(test_config());
1211 mgr.get_or_create("sess-1");
1212
1213 let result = mgr.with_session("sess-1", |ctx| {
1215 ctx.record_entities(
1216 &[make_entity("Bob", EntityCategory::Person)],
1217 &[make_token("PERSON_1", EntityCategory::Person, 1)],
1218 );
1219 ctx.message_count
1220 });
1221 assert_eq!(result, Some(1));
1222
1223 let stats = mgr.inspect("sess-1").unwrap();
1225 assert_eq!(stats.entity_count, 1);
1226 assert_eq!(stats.message_count, 1);
1227 }
1228
#[test]
fn test_session_manager_with_session_ref() {
    let manager = SessionManager::new(test_config());
    manager.get_or_create("sess-1");

    // A freshly created session has processed no messages yet.
    assert_eq!(
        manager.with_session_ref("sess-1", |session| session.message_count),
        Some(0)
    );

    // Unknown session ids yield None rather than creating a session.
    assert!(manager
        .with_session_ref("nonexistent", |session| session.message_count)
        .is_none());
}
1242
#[test]
fn test_session_manager_evict_expired() {
    let session_ids = ["sess-1", "sess-2"];
    let mgr = SessionManager::new(SessionConfig {
        ttl_seconds: 0,
        ..test_config()
    });

    // Create every session before touching any of them, then push each
    // session's last activity two seconds into the past so it is older than the
    // zero-second TTL.
    for id in session_ids {
        mgr.get_or_create(id);
    }
    for id in session_ids {
        mgr.with_session(id, |ctx| {
            ctx.last_activity = Utc::now() - chrono::Duration::seconds(2);
        });
    }

    assert_eq!(mgr.evict_expired(), 2);
    assert!(mgr.list_sessions().is_empty());
}
1264
#[test]
fn test_session_manager_get_or_create_replaces_expired() {
    let mgr = SessionManager::new(SessionConfig {
        ttl_seconds: 0,
        ..test_config()
    });
    mgr.get_or_create("sess-1");

    // Populate the session, then backdate it past the zero-second TTL.
    // Returning the message count proves the closure actually ran against a
    // live session. Without this precondition, the "fresh session" assertions
    // below would also pass if nothing had ever been recorded.
    let populated = mgr.with_session("sess-1", |ctx| {
        ctx.record_entities(
            &[make_entity("Alice", EntityCategory::Person)],
            &[make_token("PERSON_1", EntityCategory::Person, 1)],
        );
        ctx.last_activity = Utc::now() - chrono::Duration::seconds(2);
        ctx.message_count
    });
    assert_eq!(populated, Some(1));

    // Requesting the expired id again should yield a brand-new, empty context.
    mgr.get_or_create("sess-1");
    let stats = mgr.inspect("sess-1").unwrap();
    assert_eq!(stats.entity_count, 0);
    assert_eq!(stats.message_count, 0);
}
1288
#[test]
fn test_session_manager_is_enabled() {
    // The shared test configuration turns sessions on...
    let enabled_mgr = SessionManager::new(test_config());
    assert!(enabled_mgr.is_enabled());

    // ...and overriding just the `enabled` flag turns them off.
    let disabled_cfg = SessionConfig {
        enabled: false,
        ..test_config()
    };
    let disabled_mgr = SessionManager::new(disabled_cfg);
    assert!(!disabled_mgr.is_enabled());
}
1300
#[test]
fn test_stats_categories_counted() {
    let mut ctx = SessionContext::new("s-1".into(), test_config());

    // Two people, one organization and one secret.
    let entities = [
        make_entity("Alice", EntityCategory::Person),
        make_entity("Bob", EntityCategory::Person),
        make_entity("Acme", EntityCategory::Organization),
        make_entity("secret-key-123", EntityCategory::Secret),
    ];
    let tokens = [
        make_token("PERSON_1", EntityCategory::Person, 1),
        make_token("PERSON_2", EntityCategory::Person, 2),
        make_token("ORG_1", EntityCategory::Organization, 1),
        make_token("SECRET_1", EntityCategory::Secret, 1),
    ];
    ctx.record_entities(&entities, &tokens);

    // Per-category tallies are keyed by the category's name.
    let stats = ctx.stats();
    for (category, expected) in [("Person", 2), ("Organization", 1), ("Secret", 1)] {
        assert_eq!(*stats.categories.get(category).unwrap(), expected);
    }
}
1324
#[test]
fn test_resolve_coreferences_empty_session() {
    // With no recorded entities there is nothing for pronouns or definite
    // articles to point at.
    let fresh = SessionContext::new("s-1".into(), test_config());
    let resolved = fresh.resolve_coreferences("He went to the company");
    assert!(resolved.is_empty());
}
1331
#[test]
fn test_word_boundary_edge_cases() {
    // (haystack, needle, expected byte offset of the whole-word match)
    // - case-insensitive match at the start of the text
    // - match at the very end of the text
    // - comma- and parenthesis-delimited words
    // - "he" embedded inside a longer word must not match
    let cases = [
        ("He is here", "he", Some(0)),
        ("it was he", "he", Some(7)),
        ("he, she, they", "she", Some(4)),
        ("(he) was there", "he", Some(1)),
        ("sheet", "he", None),
        ("ether", "he", None),
    ];
    for (haystack, needle, expected) in cases {
        assert_eq!(find_word_boundary(haystack, needle), expected);
    }
}
1345
#[test]
fn test_resolver_threshold() {
    // Floating-point equality within machine epsilon.
    let approx_eq = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    // The shared test configuration yields a threshold of 0.80.
    let default_ctx = SessionContext::new("s-1".into(), test_config());
    assert!(approx_eq(default_ctx.resolver_threshold(), 0.80));

    // An explicit `session_threshold` override is reported back unchanged.
    let tuned_cfg = SessionConfig {
        session_threshold: 0.70,
        ..test_config()
    };
    let tuned_ctx = SessionContext::new("s-2".into(), tuned_cfg);
    assert!(approx_eq(tuned_ctx.resolver_threshold(), 0.70));
}
1357
#[test]
fn test_default_session_config() {
    // Destructure every field so the compiler flags this test whenever a new
    // field is added to `SessionConfig` without a matching default check.
    let SessionConfig {
        enabled,
        id_from,
        ttl_seconds,
        coreference,
        sensitivity_escalation,
        session_threshold,
    } = SessionConfig::default();

    assert!(!enabled);
    assert_eq!(id_from, "header:x-session-id");
    assert_eq!(ttl_seconds, 1800);
    assert!(coreference);
    assert!(sensitivity_escalation);
    assert!((session_threshold - 0.80).abs() < f64::EPSILON);
}
1368
#[test]
fn test_dedup_coreferences_by_position() {
    let mut ctx = SessionContext::new("s-1".into(), test_config());

    // An organization whose name itself begins with "The Company" gives the
    // resolver more than one way to match "the company" at offset 0.
    let orgs = [make_entity("The Company LLC", EntityCategory::Organization)];
    let org_tokens = [make_token("ORG_1", EntityCategory::Organization, 1)];
    ctx.record_entities(&orgs, &org_tokens);

    // At most one coreference may survive at any given start offset.
    let resolved = ctx.resolve_coreferences("the company is doing well");
    let hits_at_start = resolved.iter().filter(|(e, _)| e.start == 0).count();
    assert!(hits_at_start <= 1);
}
1385}