1use crate::error::{Result, TextError};
19use std::collections::{HashMap, HashSet};
20
/// Strategy used to score how well each context sentence matches a question.
///
/// The type is a plain, fieldless selector, so it is `Copy` and can be used
/// as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QAMethod {
    /// Cosine-style similarity between TF-IDF weights of the question and
    /// the sentence; stop words are removed from the question first.
    TfIdf,
    /// Jaccard overlap of adjacent-token bigrams, falling back to unigram
    /// overlap when either side has fewer than two tokens.
    BigramOverlap,
    /// Cosine similarity between the averaged word vectors of the question
    /// and the sentence. When the context has no embeddings attached, the
    /// bigram-overlap score is used instead.
    WordEmbeddingMatch,
}
37
/// Coarse category of a question, derived from its first interrogative word.
///
/// The type is a plain, fieldless tag, so it is `Copy` and can be used as a
/// set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuestionType {
    /// Triggered by "who", "whose" or "whom".
    Who,
    /// Triggered by "what" or "which".
    What,
    /// Triggered by "when".
    When,
    /// Triggered by "where".
    Where,
    /// Triggered by "why".
    Why,
    /// Triggered by "how".
    How,
    /// No interrogative word was recognised.
    Unknown,
}
56
/// A candidate answer located inside the context document.
///
/// `PartialEq` is derived so results can be compared directly (for example in
/// tests); `Eq` is not available because `confidence` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct AnswerSpan {
    /// Byte offset in the document text where the answer begins.
    pub start: usize,
    /// Byte offset in the document text one past the end of the answer.
    pub end: usize,
    /// The answer text: either a sub-span picked out of the best sentence
    /// or, when no sub-span was found, the whole sentence.
    pub text: String,
    /// Score of the answer. It is the sentence-level match score, increased
    /// by a span bonus and capped at `1.0` when a sub-span was extracted.
    pub confidence: f64,
    /// Index of the sentence (in document order) the answer came from.
    pub sentence_index: usize,
}
71
72pub struct QAContext {
74 pub text: String,
76 sentences: Vec<SentenceRecord>,
78 embeddings: Option<HashMap<String, Vec<f64>>>,
80}
81
/// One sentence of a document together with the data needed to score it.
#[derive(Debug, Clone)]
struct SentenceRecord {
    /// Sentence text with surrounding whitespace trimmed.
    text: String,
    /// Byte offset associated with the sentence inside the source document.
    start: usize,
    /// Lowercased alphanumeric tokens of `text`.
    tokens: Vec<String>,
}
93
/// Splits `text` into lowercase tokens.
///
/// Every non-alphanumeric character acts as a separator; empty pieces between
/// consecutive separators are dropped.
fn simple_tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
101
/// Builds the set of English stop words ignored when weighting query terms.
///
/// All entries are lowercase; the trailing "s" and "t" cover the fragments
/// left behind when contractions and possessives are split on apostrophes.
fn stop_words() -> HashSet<&'static str> {
    const WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "shall", "should", "may", "might", "must", "can",
        "could", "to", "of", "in", "on", "at", "by", "for", "with", "about", "against", "between",
        "into", "through", "during", "before", "after", "above", "below", "from", "up", "down",
        "out", "off", "over", "under", "again", "further", "then", "once", "and", "but", "or",
        "nor", "so", "yet", "both", "either", "neither", "not", "only", "own", "same", "than",
        "too", "very", "just", "i", "you", "he", "she", "it", "we", "they", "me", "him", "her",
        "us", "them", "my", "your", "his", "its", "our", "their", "what", "which", "who", "whom",
        "this", "that", "these", "those", "am", "s", "t",
    ];

    let mut set = HashSet::with_capacity(WORDS.len());
    set.extend(WORDS.iter().copied());
    set
}
119
120pub fn classify_question(question: &str) -> QuestionType {
122 let lower = question.to_lowercase();
123 for word in lower.split_whitespace() {
125 let w = word.trim_matches(|c: char| !c.is_alphabetic());
126 match w {
127 "who" | "whose" | "whom" => return QuestionType::Who,
128 "when" => return QuestionType::When,
129 "where" => return QuestionType::Where,
130 "why" => return QuestionType::Why,
131 "how" => return QuestionType::How,
132 "what" | "which" => return QuestionType::What,
133 _ => {}
134 }
135 }
136 QuestionType::Unknown
137}
138
139impl QAContext {
144 pub fn new(text: &str) -> Self {
146 let sentences = Self::split_sentences(text);
147 Self {
148 text: text.to_string(),
149 sentences,
150 embeddings: None,
151 }
152 }
153
154 pub fn with_embeddings(mut self, embeddings: HashMap<String, Vec<f64>>) -> Self {
156 self.embeddings = Some(embeddings);
157 self
158 }
159
160 fn split_sentences(text: &str) -> Vec<SentenceRecord> {
165 let mut records = Vec::new();
166 let mut start = 0usize;
167 let bytes = text.as_bytes();
168 let len = bytes.len();
169
170 while start < len {
171 let mut end = start;
174 while end < len {
175 let b = bytes[end];
176 if b == b'.' || b == b'?' || b == b'!' {
177 end += 1;
179 while end < len && bytes[end] == b' ' {
182 end += 1;
183 }
184 break;
185 }
186 end += 1;
187 }
188
189 let raw = text[start..end].trim();
190 if !raw.is_empty() {
191 records.push(SentenceRecord {
192 text: raw.to_string(),
193 start,
194 tokens: simple_tokenize(raw),
195 });
196 }
197 start = end;
198 }
199
200 records
201 }
202
203 fn build_idf(&self) -> HashMap<String, f64> {
209 let n = self.sentences.len() as f64;
210 let mut df: HashMap<String, usize> = HashMap::new();
211 for sent in &self.sentences {
212 let unique: HashSet<&String> = sent.tokens.iter().collect();
213 for tok in unique {
214 *df.entry(tok.clone()).or_insert(0) += 1;
215 }
216 }
217 df.into_iter()
218 .map(|(t, d)| (t, (1.0 + n / (1.0 + d as f64)).ln()))
219 .collect()
220 }
221
222 fn tfidf_score(
224 query_tokens: &[String],
225 sentence: &SentenceRecord,
226 idf: &HashMap<String, f64>,
227 stops: &HashSet<&'static str>,
228 ) -> f64 {
229 let mut sent_tf: HashMap<String, f64> = HashMap::new();
231 for tok in &sentence.tokens {
232 *sent_tf.entry(tok.clone()).or_insert(0.0) += 1.0;
233 }
234 let sent_len = sentence.tokens.len().max(1) as f64;
235
236 let mut dot = 0.0f64;
237 let mut q_norm = 0.0f64;
238 let mut s_norm = 0.0f64;
239
240 let query_freq: HashMap<&String, f64> = {
242 let mut m = HashMap::new();
243 for t in query_tokens {
244 if !stops.contains(t.as_str()) {
245 *m.entry(t).or_insert(0.0) += 1.0;
246 }
247 }
248 m
249 };
250
251 for (tok, &qf) in &query_freq {
252 let idf_val = idf.get(*tok).copied().unwrap_or(0.0);
253 let q_tfidf = (qf / query_tokens.len().max(1) as f64) * idf_val;
254 let s_tfidf = sent_tf.get(*tok).copied().unwrap_or(0.0) / sent_len * idf_val;
255 dot += q_tfidf * s_tfidf;
256 q_norm += q_tfidf * q_tfidf;
257 s_norm += s_tfidf * s_tfidf;
258 }
259
260 if q_norm > 0.0 && s_norm > 0.0 {
261 dot / (q_norm.sqrt() * s_norm.sqrt())
262 } else {
263 0.0
264 }
265 }
266
267 fn bigram_overlap_score(query_tokens: &[String], sentence: &SentenceRecord) -> f64 {
272 if query_tokens.len() < 2 || sentence.tokens.len() < 2 {
273 let q_set: HashSet<&String> = query_tokens.iter().collect();
275 let s_set: HashSet<&String> = sentence.tokens.iter().collect();
276 let inter = q_set.intersection(&s_set).count();
277 return inter as f64 / q_set.len().max(1) as f64;
278 }
279
280 let q_bigrams: HashSet<(&String, &String)> =
281 query_tokens.windows(2).map(|w| (&w[0], &w[1])).collect();
282 let s_bigrams: HashSet<(&String, &String)> =
283 sentence.tokens.windows(2).map(|w| (&w[0], &w[1])).collect();
284
285 let inter = q_bigrams.intersection(&s_bigrams).count();
286 let union = q_bigrams.union(&s_bigrams).count();
287 if union == 0 {
288 0.0
289 } else {
290 inter as f64 / union as f64
291 }
292 }
293
294 fn embedding_score(
299 query_tokens: &[String],
300 sentence: &SentenceRecord,
301 embeddings: &HashMap<String, Vec<f64>>,
302 ) -> f64 {
303 let q_vec = Self::average_embedding(query_tokens, embeddings);
304 let s_vec = Self::average_embedding(&sentence.tokens, embeddings);
305 match (q_vec, s_vec) {
306 (Some(q), Some(s)) => cosine_sim(&q, &s),
307 _ => 0.0,
308 }
309 }
310
311 fn average_embedding(
312 tokens: &[String],
313 embeddings: &HashMap<String, Vec<f64>>,
314 ) -> Option<Vec<f64>> {
315 let vecs: Vec<&Vec<f64>> = tokens.iter().filter_map(|t| embeddings.get(t)).collect();
316 if vecs.is_empty() {
317 return None;
318 }
319 let dim = vecs[0].len();
320 let mut sum = vec![0.0f64; dim];
321 for v in &vecs {
322 for (s, &x) in sum.iter_mut().zip(v.iter()) {
323 *s += x;
324 }
325 }
326 let n = vecs.len() as f64;
327 Some(sum.into_iter().map(|x| x / n).collect())
328 }
329
330 fn extract_best_span(
339 sentence: &SentenceRecord,
340 q_type: &QuestionType,
341 doc_text: &str,
342 ) -> Option<(usize, usize, f64)> {
343 let patterns: &[(&str, &[QuestionType], f64)] = &[
345 (
347 r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b",
348 &[QuestionType::When],
349 0.3,
350 ),
351 (
352 r"\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b",
353 &[QuestionType::When],
354 0.3,
355 ),
356 (r"\b\d{4}\b", &[QuestionType::When], 0.15),
357 (
359 r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b",
360 &[QuestionType::Who, QuestionType::Where, QuestionType::What],
361 0.2,
362 ),
363 (
365 r"\bin\s+[A-Z][a-zA-Z]+(?:,\s*[A-Z][a-zA-Z]+)*\b",
366 &[QuestionType::Where],
367 0.25,
368 ),
369 ];
370
371 let sent_start_in_doc = sentence.start;
373 let sent_end_in_doc = sent_start_in_doc + sentence.text.len();
374 let sent_end_in_doc = sent_end_in_doc.min(doc_text.len());
376
377 let mut best: Option<(usize, usize, f64)> = None;
378
379 for (pattern_str, qtypes, bonus) in patterns {
380 let applies = qtypes.iter().any(|qt| qt == q_type);
381 if !applies && *q_type != QuestionType::Unknown && *q_type != QuestionType::What {
382 continue;
383 }
384
385 if let Ok(re) = regex::Regex::new(pattern_str) {
389 for m in re.find_iter(&sentence.text) {
390 let abs_start = sent_start_in_doc + m.start();
391 let abs_end = sent_start_in_doc + m.end();
392 if abs_end > sent_end_in_doc {
394 continue;
395 }
396 let score = 0.5 + bonus;
397 if best.is_none_or(|(_, _, s)| score > s) {
398 best = Some((abs_start, abs_end, score));
399 }
400 }
401 }
402 }
403
404 best
405 }
406
407 pub fn find_answer_span(&self, question: &str, method: QAMethod) -> Result<Option<AnswerSpan>> {
413 if self.sentences.is_empty() {
414 return Ok(None);
415 }
416
417 let q_tokens = simple_tokenize(question);
418 if q_tokens.is_empty() {
419 return Err(TextError::InvalidInput(
420 "Question must not be empty".to_string(),
421 ));
422 }
423
424 let q_type = classify_question(question);
425 let stops = stop_words();
426
427 let mut scored: Vec<(usize, f64)> = self
429 .sentences
430 .iter()
431 .enumerate()
432 .map(|(i, sent)| {
433 let base = match &method {
434 QAMethod::TfIdf => {
435 let idf = self.build_idf();
436 Self::tfidf_score(&q_tokens, sent, &idf, &stops)
437 }
438 QAMethod::BigramOverlap => Self::bigram_overlap_score(&q_tokens, sent),
439 QAMethod::WordEmbeddingMatch => {
440 if let Some(emb) = &self.embeddings {
441 Self::embedding_score(&q_tokens, sent, emb)
442 } else {
443 Self::bigram_overlap_score(&q_tokens, sent)
445 }
446 }
447 };
448 (i, base)
449 })
450 .collect();
451
452 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
453
454 let (best_idx, base_score) = scored[0];
455 if base_score <= 0.0 {
456 return Ok(None);
457 }
458
459 let best_sent = &self.sentences[best_idx];
460
461 let span = Self::extract_best_span(best_sent, &q_type, &self.text);
463
464 let answer = if let Some((start, end, ne_bonus)) = span {
465 if start < end && end <= self.text.len() {
466 AnswerSpan {
467 start,
468 end,
469 text: self.text[start..end].to_string(),
470 confidence: (base_score + ne_bonus).min(1.0),
471 sentence_index: best_idx,
472 }
473 } else {
474 let start = best_sent.start;
476 let end = (best_sent.start + best_sent.text.len()).min(self.text.len());
477 AnswerSpan {
478 start,
479 end,
480 text: best_sent.text.clone(),
481 confidence: base_score,
482 sentence_index: best_idx,
483 }
484 }
485 } else {
486 let start = best_sent.start;
488 let end = (best_sent.start + best_sent.text.len()).min(self.text.len());
489 AnswerSpan {
490 start,
491 end,
492 text: best_sent.text.clone(),
493 confidence: base_score,
494 sentence_index: best_idx,
495 }
496 };
497
498 Ok(Some(answer))
499 }
500
501 pub fn find_top_k(
507 &self,
508 question: &str,
509 method: QAMethod,
510 k: usize,
511 ) -> Result<Vec<AnswerSpan>> {
512 if self.sentences.is_empty() || k == 0 {
513 return Ok(Vec::new());
514 }
515
516 let q_tokens = simple_tokenize(question);
517 if q_tokens.is_empty() {
518 return Err(TextError::InvalidInput(
519 "Question must not be empty".to_string(),
520 ));
521 }
522
523 let q_type = classify_question(question);
524 let stops = stop_words();
525 let idf = self.build_idf();
526
527 let mut scored: Vec<(usize, f64)> = self
528 .sentences
529 .iter()
530 .enumerate()
531 .map(|(i, sent)| {
532 let base = match &method {
533 QAMethod::TfIdf => Self::tfidf_score(&q_tokens, sent, &idf, &stops),
534 QAMethod::BigramOverlap => Self::bigram_overlap_score(&q_tokens, sent),
535 QAMethod::WordEmbeddingMatch => {
536 if let Some(emb) = &self.embeddings {
537 Self::embedding_score(&q_tokens, sent, emb)
538 } else {
539 Self::bigram_overlap_score(&q_tokens, sent)
540 }
541 }
542 };
543 (i, base)
544 })
545 .collect();
546
547 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
548
549 let mut answers = Vec::new();
550 for (idx, base_score) in scored.into_iter().take(k) {
551 if base_score <= 0.0 {
552 break;
553 }
554 let sent = &self.sentences[idx];
555 let span = Self::extract_best_span(sent, &q_type, &self.text);
556
557 let answer = if let Some((start, end, bonus)) = span {
558 if start < end && end <= self.text.len() {
559 AnswerSpan {
560 start,
561 end,
562 text: self.text[start..end].to_string(),
563 confidence: (base_score + bonus).min(1.0),
564 sentence_index: idx,
565 }
566 } else {
567 let s = sent.start;
568 let e = (sent.start + sent.text.len()).min(self.text.len());
569 AnswerSpan {
570 start: s,
571 end: e,
572 text: sent.text.clone(),
573 confidence: base_score,
574 sentence_index: idx,
575 }
576 }
577 } else {
578 let s = sent.start;
579 let e = (sent.start + sent.text.len()).min(self.text.len());
580 AnswerSpan {
581 start: s,
582 end: e,
583 text: sent.text.clone(),
584 confidence: base_score,
585 sentence_index: idx,
586 }
587 };
588
589 answers.push(answer);
590 }
591
592 Ok(answers)
593 }
594}
595
596pub fn tf_idf_similarity(query_tokens: &[String], context_sentences: &[Vec<String>]) -> Vec<f64> {
604 if context_sentences.is_empty() || query_tokens.is_empty() {
605 return vec![0.0; context_sentences.len()];
606 }
607
608 let n = context_sentences.len() as f64;
609 let stops = stop_words();
610
611 let mut df: HashMap<String, usize> = HashMap::new();
613 for sent in context_sentences {
614 let unique: HashSet<&String> = sent.iter().collect();
615 for tok in unique {
616 *df.entry(tok.clone()).or_insert(0) += 1;
617 }
618 }
619 let idf: HashMap<String, f64> = df
620 .into_iter()
621 .map(|(t, d)| (t, (1.0 + n / (1.0 + d as f64)).ln()))
622 .collect();
623
624 context_sentences
625 .iter()
626 .map(|sent| {
627 let record = SentenceRecord {
628 text: sent.join(" "),
629 start: 0,
630 tokens: sent.clone(),
631 };
632 QAContext::tfidf_score(query_tokens, &record, &idf, &stops)
633 })
634 .collect()
635}
636
637pub fn extract_answer(question: &str, document: &str, top_k: usize) -> Vec<AnswerSpan> {
640 let ctx = QAContext::new(document);
641 ctx.find_top_k(question, QAMethod::TfIdf, top_k)
642 .unwrap_or_default()
643}
644
/// Cosine similarity of two equally long vectors.
///
/// Returns `0.0` when the vectors are empty, differ in length, or when either
/// has a norm that is not strictly positive.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a > 0.0 && norm_b > 0.0 {
        let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-sentence biography shared by the tests below.
    const DOC: &str = "Marie Curie was born in Warsaw on 7 November 1867. \
                       She conducted pioneering research on radioactivity. \
                       In 1903 she won the Nobel Prize in Physics. \
                       She also won the Nobel Prize in Chemistry in 1911. \
                       Paris became her home after she moved from Poland.";

    /// Each interrogative word maps to its question type.
    #[test]
    fn test_classify_question() {
        let cases = [
            ("Who discovered radium?", QuestionType::Who),
            ("When was she born?", QuestionType::When),
            ("Where did she live?", QuestionType::Where),
            ("How did she win?", QuestionType::How),
            ("What is radioactivity?", QuestionType::What),
            ("Why is science important?", QuestionType::Why),
        ];
        for (question, expected) in cases {
            assert_eq!(classify_question(question), expected);
        }
    }

    /// The top TF-IDF answer to a birth question comes from the birth sentence.
    #[test]
    fn test_extract_answer_tfidf() {
        let answers = extract_answer("When was Marie Curie born?", DOC, 3);
        assert!(!answers.is_empty());

        let top = &answers[0];
        let lowered = top.text.to_lowercase();
        assert!(
            lowered.contains("born") || top.text.contains("1867") || lowered.contains("november")
        );
        assert!(top.confidence > 0.0);
    }

    /// Bigram overlap finds the Physics-prize sentence.
    #[test]
    fn test_find_answer_span_bigram() {
        let ctx = QAContext::new(DOC);
        let found = ctx
            .find_answer_span(
                "What prize did she win in Physics?",
                QAMethod::BigramOverlap,
            )
            .expect("QA failed");
        assert!(found.is_some());

        let span = found.expect("should have a span");
        let lowered = span.text.to_lowercase();
        assert!(
            lowered.contains("physics") || span.text.contains("1903") || lowered.contains("prize")
        );
    }

    /// `find_top_k` never returns more than `k` answers.
    #[test]
    fn test_find_top_k() {
        let ctx = QAContext::new(DOC);
        let answers = ctx
            .find_top_k("Nobel Prize", QAMethod::TfIdf, 2)
            .expect("top-k failed");
        assert!(answers.len() <= 2);
    }

    /// The embedding method must not fail when no embeddings are attached.
    #[test]
    fn test_embedding_fallback_without_embeddings() {
        let ctx = QAContext::new(DOC);
        let outcome = ctx
            .find_answer_span("Where did she live?", QAMethod::WordEmbeddingMatch)
            .expect("QA failed");
        // Only the absence of an error is checked; the answer may be `None`.
        let _ = outcome;
    }

    /// At least one of the two Nobel sentences outranks the unrelated one.
    #[test]
    fn test_tf_idf_similarity_standalone() {
        let query = simple_tokenize("Nobel Prize winner");
        let sentences: Vec<Vec<String>> = [
            "She won the Nobel Prize in Physics",
            "Marie Curie was born in Warsaw",
            "Nobel Prize in Chemistry was awarded",
        ]
        .iter()
        .map(|sentence| simple_tokenize(sentence))
        .collect();

        let scores = tf_idf_similarity(&query, &sentences);
        assert_eq!(scores.len(), 3);
        assert!(scores[0] > scores[1] || scores[2] > scores[1]);
    }

    /// Every returned span lies inside the document and matches its text.
    #[test]
    fn test_answer_span_bounds() {
        let ctx = QAContext::new(DOC);
        let answers = ctx
            .find_top_k("Marie Curie radioactivity", QAMethod::TfIdf, 5)
            .expect("failed");
        for span in &answers {
            assert!(span.start <= span.end);
            assert!(span.end <= DOC.len());
            assert_eq!(span.text, DOC[span.start..span.end]);
        }
    }

    /// An empty document yields no answer rather than an error.
    #[test]
    fn test_empty_document() {
        let ctx = QAContext::new("");
        let outcome = ctx
            .find_answer_span("Who is here?", QAMethod::TfIdf)
            .expect("QA failed");
        assert!(outcome.is_none());
    }
}