mockforge_data/rag/
utils.rs

1//! Utility functions and helpers for RAG operations
2//!
3//! This module provides utility functions for text processing,
4//! similarity calculations, data validation, and other common RAG operations.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11/// Text chunking utilities
12pub struct TextChunker;
13
14impl TextChunker {
15    /// Split text into chunks of specified size with overlap
16    pub fn split_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
17        if text.is_empty() || chunk_size == 0 {
18            return Vec::new();
19        }
20
21        let words: Vec<&str> = text.split_whitespace().collect();
22        if words.is_empty() {
23            return Vec::new();
24        }
25
26        let mut chunks = Vec::new();
27        let step = chunk_size.saturating_sub(overlap);
28
29        for start in (0..words.len()).step_by(step) {
30            let end = (start + chunk_size).min(words.len());
31            let chunk: Vec<&str> = words[start..end].to_vec();
32            if !chunk.is_empty() {
33                chunks.push(chunk.join(" "));
34            }
35        }
36
37        chunks
38    }
39
40    /// Split text by sentences
41    pub fn split_by_sentences(text: &str) -> Vec<String> {
42        // Simple sentence splitting - in practice, you might want to use a proper NLP library
43        let mut sentences = Vec::new();
44        let mut current_sentence = String::new();
45
46        for ch in text.chars() {
47            current_sentence.push(ch);
48            if ch == '.' || ch == '!' || ch == '?' {
49                if !current_sentence.trim().is_empty() {
50                    sentences.push(current_sentence.trim().to_string());
51                }
52                current_sentence.clear();
53            }
54        }
55
56        if !current_sentence.trim().is_empty() {
57            sentences.push(current_sentence.trim().to_string());
58        }
59
60        sentences
61    }
62
63    /// Split text by paragraphs
64    pub fn split_by_paragraphs(text: &str) -> Vec<String> {
65        text.split("\n\n")
66            .map(|s| s.trim().to_string())
67            .filter(|s| !s.is_empty())
68            .collect()
69    }
70
71    /// Create overlapping chunks for better context preservation
72    pub fn create_overlapping_chunks(
73        text: &str,
74        chunk_size: usize,
75        overlap_ratio: f32,
76    ) -> Vec<String> {
77        let overlap = ((chunk_size as f32) * overlap_ratio).round() as usize;
78        Self::split_text(text, chunk_size, overlap)
79    }
80
81    /// Chunk text with metadata preservation
82    pub fn chunk_with_metadata(
83        text: &str,
84        chunk_size: usize,
85        overlap: usize,
86        metadata: HashMap<String, String>,
87    ) -> Vec<(String, HashMap<String, String>)> {
88        let chunks = Self::split_text(text, chunk_size, overlap);
89        chunks.into_iter().map(|chunk| (chunk, metadata.clone())).collect()
90    }
91}
92
/// Similarity calculation utilities.
///
/// Stateless helpers over `f32` slices. Pairwise functions treat a length mismatch as "not
/// comparable" and return a neutral value (`0.0` for similarity/dot product,
/// `f32::INFINITY` for distances) instead of panicking.
pub struct SimilarityCalculator;

impl SimilarityCalculator {
    /// Calculate cosine similarity between two vectors.
    ///
    /// Returns `0.0` when the lengths differ, the vectors are empty, or either vector has zero
    /// magnitude. NaN components propagate into a NaN result.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Calculate Euclidean (L2) distance between two vectors.
    ///
    /// Returns `f32::INFINITY` when the lengths differ or the vectors are empty.
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return f32::INFINITY;
        }

        let sum_squares: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();

        sum_squares.sqrt()
    }

    /// Calculate Manhattan (L1) distance between two vectors.
    ///
    /// Returns `f32::INFINITY` when the lengths differ or the vectors are empty.
    pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return f32::INFINITY;
        }

        a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
    }

    /// Calculate the dot product of two vectors (`0.0` when the lengths differ).
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Normalize a vector to unit length; a zero vector maps to an all-zero vector.
    pub fn normalize_vector(vector: &[f32]) -> Vec<f32> {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            return vec![0.0; vector.len()];
        }

        vector.iter().map(|x| x / norm).collect()
    }

    /// Calculate the symmetric `n x n` cosine-similarity matrix for `vectors`.
    pub fn similarity_matrix(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let n = vectors.len();
        let mut matrix = vec![vec![0.0; n]; n];

        // Only the upper triangle is computed; the result is mirrored.
        for i in 0..n {
            for j in i..n {
                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                matrix[i][j] = similarity;
                matrix[j][i] = similarity;
            }
        }

        matrix
    }

    /// Find the `top_k` candidates most similar (by cosine similarity) to `query`.
    ///
    /// Returns `(candidate_index, similarity)` pairs in descending order of similarity; equal
    /// scores keep their original candidate order. NaN scores (e.g. from vectors containing
    /// NaN) are ranked last so they never displace a real match.
    pub fn find_most_similar(
        query: &[f32],
        candidates: &[Vec<f32>],
        top_k: usize,
    ) -> Vec<(usize, f32)> {
        let mut similarities: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, vec)| (i, Self::cosine_similarity(query, vec)))
            .collect();

        // `partial_cmp` treating NaN as "equal to everything" is not a total order and can
        // scramble the ranking. Order by "is NaN" first, then by descending score.
        similarities.sort_by(|a, b| {
            a.1.is_nan()
                .cmp(&b.1.is_nan())
                .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
        });
        similarities.truncate(top_k);

        similarities
    }
}
187
188/// Text preprocessing utilities
189pub struct TextPreprocessor;
190
191impl TextPreprocessor {
192    /// Clean text by removing extra whitespace and normalizing
193    pub fn clean_text(text: &str) -> String {
194        text.split_whitespace().collect::<Vec<&str>>().join(" ")
195    }
196
197    /// Remove HTML tags from text
198    pub fn remove_html_tags(text: &str) -> String {
199        // Simple HTML tag removal - in practice, you might want to use a proper HTML parser
200        let mut result = String::new();
201        let mut in_tag = false;
202
203        for ch in text.chars() {
204            match ch {
205                '<' => in_tag = true,
206                '>' => in_tag = false,
207                _ if !in_tag => result.push(ch),
208                _ => {}
209            }
210        }
211
212        result
213    }
214
215    /// Normalize whitespace
216    pub fn normalize_whitespace(text: &str) -> String {
217        text.chars()
218            .fold((String::new(), false), |(mut acc, mut was_space), ch| {
219                if ch.is_whitespace() {
220                    if !was_space {
221                        acc.push(' ');
222                        was_space = true;
223                    }
224                } else {
225                    acc.push(ch);
226                    was_space = false;
227                }
228                (acc, was_space)
229            })
230            .0
231    }
232
233    /// Extract keywords from text
234    pub fn extract_keywords(text: &str, max_keywords: usize) -> Vec<String> {
235        let words: Vec<String> = text
236            .to_lowercase()
237            .split_whitespace()
238            .map(|word| word.trim_matches(|c: char| !c.is_alphabetic()).to_string())
239            .filter(|trimmed_word| {
240                // Filter out common stop words and short words
241                trimmed_word.len() > 2 && !is_stop_word(trimmed_word)
242            })
243            .collect();
244
245        // Count word frequencies
246        let mut word_counts: HashMap<String, usize> = HashMap::new();
247        for word in words {
248            *word_counts.entry(word).or_insert(0) += 1;
249        }
250
251        // Sort by frequency and take top keywords
252        let mut sorted_words: Vec<(String, usize)> = word_counts.into_iter().collect();
253        sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
254
255        sorted_words.into_iter().take(max_keywords).map(|(word, _)| word).collect()
256    }
257
258    /// Truncate text to maximum length while preserving word boundaries
259    pub fn truncate_text(text: &str, max_length: usize) -> String {
260        if text.len() <= max_length {
261            return text.to_string();
262        }
263
264        let truncated = &text[..max_length];
265        let last_space = truncated.rfind(' ').unwrap_or(max_length);
266        truncated[..last_space].trim().to_string()
267    }
268
269    /// Expand contractions in text
270    pub fn expand_contractions(text: &str) -> String {
271        text.replace("don't", "do not")
272            .replace("can't", "cannot")
273            .replace("won't", "will not")
274            .replace("i'm", "i am")
275            .replace("you're", "you are")
276            .replace("it's", "it is")
277            .replace("that's", "that is")
278            .replace("there's", "there is")
279            .replace("here's", "here is")
280            .replace("what's", "what is")
281            .replace("where's", "where is")
282            .replace("when's", "when is")
283            .replace("why's", "why is")
284            .replace("how's", "how is")
285    }
286}
287
/// Report whether `word` appears in a small, fixed list of English stop words.
///
/// The comparison is exact and case-sensitive, so callers should lowercase the word first.
fn is_stop_word(word: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "up", "about", "into", "through", "during", "before", "after", "above", "below",
        "between", "among", "is", "are", "was", "were", "be", "been", "being", "have", "has",
        "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
        "can", "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "me", "him", "her", "us", "them", "my", "your", "his", "its", "our", "their", "mine",
        "yours", "hers", "ours", "theirs", "am", "not", "no", "yes", "here", "there", "now",
        "then", "so", "very", "too", "also", "only", "just", "even", "still", "yet", "again",
        "once", "never", "always", "often", "sometimes", "usually",
    ];

    STOP_WORDS.iter().any(|&candidate| candidate == word)
}
392
/// Rate limiting utilities.
///
/// A sliding-window limiter: every admitted request is timestamped, and a new request is
/// admitted only while fewer than `max(burst_size, requests_per_minute)` admitted requests
/// fall within the last 60 seconds.
pub struct RateLimiter {
    requests_per_minute: u32,
    burst_size: u32,
    request_times: Vec<std::time::Instant>,
}

impl RateLimiter {
    /// Length of the sliding window over which admitted requests are counted.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Create a new rate limiter with no recorded requests.
    pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
        Self {
            requests_per_minute,
            burst_size,
            request_times: Vec::new(),
        }
    }

    /// Maximum number of requests admitted within one window.
    fn window_limit(&self) -> usize {
        let limit = self.burst_size.max(self.requests_per_minute);
        usize::try_from(limit).unwrap_or(usize::MAX)
    }

    /// Check whether a request made now is allowed, recording it if so.
    pub fn is_allowed(&mut self) -> bool {
        let now = std::time::Instant::now();

        // Drop requests that left the window. Measuring age with `saturating_duration_since`
        // avoids `now - WINDOW`, which can panic if it would precede the Instant origin.
        self.request_times
            .retain(|&time| now.saturating_duration_since(time) < Self::WINDOW);

        if self.request_times.len() < self.window_limit() {
            // Every admitted request must be recorded, otherwise requests admitted past the
            // burst threshold would never count against the limit.
            self.request_times.push(now);
            true
        } else {
            false
        }
    }

    /// Get the time until the next request would be allowed (zero if one is allowed now).
    pub fn time_until_next(&self) -> std::time::Duration {
        let now = std::time::Instant::now();

        // Timestamps are pushed in chronological order, so the still-active ones form a
        // suffix. Expired entries may linger here until the next `is_allowed` call.
        let first_active = match self
            .request_times
            .iter()
            .position(|&time| now.saturating_duration_since(time) < Self::WINDOW)
        {
            Some(index) => index,
            None => return std::time::Duration::ZERO,
        };

        let active = self.request_times.len() - first_active;
        let limit = self.window_limit();
        if active < limit {
            return std::time::Duration::ZERO;
        }

        // A slot frees once enough of the oldest active requests expire to drop below `limit`.
        match self.request_times.get(first_active + (active - limit)) {
            Some(&time) => Self::WINDOW.saturating_sub(now.saturating_duration_since(time)),
            None => std::time::Duration::ZERO,
        }
    }
}
449
/// Caching utilities.
///
/// A size-bounded map whose entries expire `ttl` after they were last inserted. Hit/miss
/// counters back [`Cache::hit_rate`].
pub struct Cache<K, V> {
    data: HashMap<K, (V, std::time::Instant)>,
    ttl: std::time::Duration,
    max_size: usize,
    hits: u64,
    misses: u64,
}

impl<K, V> Cache<K, V>
where
    K: std::cmp::Eq + std::hash::Hash + Clone,
    V: Clone,
{
    /// Create a new cache whose entries live `ttl_secs` seconds, holding at most `max_size`
    /// entries.
    pub fn new(ttl_secs: u64, max_size: usize) -> Self {
        Self {
            data: HashMap::new(),
            ttl: std::time::Duration::from_secs(ttl_secs),
            max_size,
            hits: 0,
            misses: 0,
        }
    }

    /// Get a clone of the value for `key` if present and not expired.
    ///
    /// An expired entry is removed. Every call counts as exactly one hit or one miss.
    pub fn get(&mut self, key: &K) -> Option<V> {
        if let Some((value, timestamp)) = self.data.get(key) {
            let now = std::time::Instant::now();
            if now.duration_since(*timestamp) < self.ttl {
                self.hits += 1;
                return Some(value.clone());
            } else {
                // Expired, remove it
                self.data.remove(key);
            }
        }
        self.misses += 1;
        None
    }

    /// Insert or replace the value for `key`, restarting its expiry.
    ///
    /// Expired entries are purged first. If the cache is still full and `key` is not already
    /// present, the entry with the oldest insertion time is evicted. Replacing an existing
    /// key never evicts anything. With `max_size == 0` nothing is stored.
    pub fn put(&mut self, key: K, value: V) {
        if self.max_size == 0 {
            return;
        }

        let now = std::time::Instant::now();
        let ttl = self.ttl;

        // Remove expired entries
        self.data.retain(|_, (_, timestamp)| now.duration_since(*timestamp) < ttl);

        if !self.data.contains_key(&key) && self.data.len() >= self.max_size {
            // HashMap order is arbitrary, so find the genuinely oldest entry by timestamp.
            let oldest_key = self
                .data
                .iter()
                .min_by_key(|(_, (_, timestamp))| *timestamp)
                .map(|(k, _)| k.clone());
            if let Some(oldest_key) = oldest_key {
                self.data.remove(&oldest_key);
            }
        }

        self.data.insert(key, (value, now));
    }

    /// Remove every entry; hit/miss statistics are left untouched.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Get the number of stored entries (expired ones may still be counted until purged).
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Get the fraction of `get` calls that were hits (`0.0` before any lookup).
    pub fn hit_rate(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }
}
529
530/// Performance monitoring utilities
531pub struct PerformanceMonitor {
532    start_time: std::time::Instant,
533    metrics: HashMap<String, MetricValue>,
534}
535
536#[derive(Debug, Clone, Serialize, Deserialize)]
537pub enum MetricValue {
538    Counter(u64),
539    Gauge(f64),
540    Histogram(Vec<f64>),
541    Duration(std::time::Duration),
542}
543
544impl Default for PerformanceMonitor {
545    fn default() -> Self {
546        Self::new()
547    }
548}
549
550impl PerformanceMonitor {
551    /// Create a new performance monitor
552    pub fn new() -> Self {
553        Self {
554            start_time: std::time::Instant::now(),
555            metrics: HashMap::new(),
556        }
557    }
558
559    /// Start timing an operation
560    pub fn start_timer(&mut self, operation: &str) -> TimerGuard<'_> {
561        let start = std::time::Instant::now();
562        TimerGuard {
563            monitor: self,
564            operation: operation.to_string(),
565            start,
566        }
567    }
568
569    /// Record a metric
570    pub fn record_metric(&mut self, name: String, value: MetricValue) {
571        self.metrics.insert(name, value);
572    }
573
574    /// Increment counter
575    pub fn increment_counter(&mut self, name: &str) {
576        let counter = match self.metrics.get(name) {
577            Some(MetricValue::Counter(count)) => *count + 1,
578            _ => 1,
579        };
580        self.metrics.insert(name.to_string(), MetricValue::Counter(counter));
581    }
582
583    /// Record gauge value
584    pub fn record_gauge(&mut self, name: &str, value: f64) {
585        self.metrics.insert(name.to_string(), MetricValue::Gauge(value));
586    }
587
588    /// Get elapsed time since start
589    pub fn elapsed(&self) -> std::time::Duration {
590        self.start_time.elapsed()
591    }
592
593    /// Get all metrics
594    pub fn metrics(&self) -> &HashMap<String, MetricValue> {
595        &self.metrics
596    }
597
598    /// Get uptime
599    pub fn uptime(&self) -> std::time::Duration {
600        self.start_time.elapsed()
601    }
602}
603
604/// Timer guard for automatic timing
605pub struct TimerGuard<'a> {
606    monitor: &'a mut PerformanceMonitor,
607    operation: String,
608    start: std::time::Instant,
609}
610
611impl<'a> Drop for TimerGuard<'a> {
612    fn drop(&mut self) {
613        let duration = self.start.elapsed();
614        let operation_duration = format!("{}_duration", self.operation);
615        self.monitor.metrics.insert(operation_duration, MetricValue::Duration(duration));
616    }
617}
618
619/// File utilities for RAG operations
620pub struct FileUtils;
621
622impl FileUtils {
623    /// Read text file
624    pub async fn read_text_file<P: AsRef<Path>>(path: P) -> Result<String> {
625        let content = tokio::fs::read_to_string(path).await?;
626        Ok(content)
627    }
628
629    /// Write text file
630    pub async fn write_text_file<P: AsRef<Path>>(path: P, content: &str) -> Result<()> {
631        tokio::fs::write(path, content).await?;
632        Ok(())
633    }
634
635    /// Check if file exists
636    pub async fn file_exists<P: AsRef<Path>>(path: P) -> bool {
637        tokio::fs::try_exists(path).await.unwrap_or(false)
638    }
639
640    /// Get file size
641    pub async fn file_size<P: AsRef<Path>>(path: P) -> Result<u64> {
642        let metadata = tokio::fs::metadata(path).await?;
643        Ok(metadata.len())
644    }
645
646    /// List files in directory
647    pub async fn list_files<P: AsRef<Path>>(dir: P) -> Result<Vec<std::path::PathBuf>> {
648        let mut files = Vec::new();
649        let mut entries = tokio::fs::read_dir(dir).await?;
650
651        while let Some(entry) = entries.next_entry().await? {
652            files.push(entry.path());
653        }
654
655        Ok(files)
656    }
657
658    /// Read JSON file
659    pub async fn read_json_file<T: for<'de> Deserialize<'de>, P: AsRef<Path>>(
660        path: P,
661    ) -> Result<T> {
662        let content = Self::read_text_file(path).await?;
663        let data: T = serde_json::from_str(&content)?;
664        Ok(data)
665    }
666
667    /// Write JSON file
668    pub async fn write_json_file<T: Serialize, P: AsRef<Path>>(path: P, data: &T) -> Result<()> {
669        let content = serde_json::to_string_pretty(data)?;
670        Self::write_text_file(path, &content).await
671    }
672}
673
674/// Error handling utilities
675pub struct ErrorUtils;
676
677impl ErrorUtils {
678    /// Create a generic error
679    pub fn generic_error(message: &str) -> mockforge_core::Error {
680        mockforge_core::Error::generic(message.to_string())
681    }
682
683    /// Create an error with context
684    pub fn context_error(message: &str, context: &str) -> mockforge_core::Error {
685        mockforge_core::Error::generic(format!("{}: {}", message, context))
686    }
687
688    /// Wrap an error with additional context
689    pub fn wrap_error<E: std::fmt::Display>(error: E, context: &str) -> mockforge_core::Error {
690        mockforge_core::Error::generic(format!("{}: {}", context, error))
691    }
692
693    /// Check if error is retryable
694    pub fn is_retryable_error(error: &mockforge_core::Error) -> bool {
695        // Simple heuristic - in practice, you might want to categorize errors
696        error.to_string().contains("timeout")
697            || error.to_string().contains("rate limit")
698            || error.to_string().contains("503")
699            || error.to_string().contains("502")
700            || error.to_string().contains("504")
701    }
702}
703
#[cfg(test)]
mod tests {
    //! Unit tests for the pure (non-I/O) helpers in this module.

    use super::*;

    #[test]
    fn split_text_produces_word_chunks() {
        assert_eq!(TextChunker::split_text("a b c d", 2, 0), vec!["a b", "c d"]);
        assert!(TextChunker::split_text("", 4, 1).is_empty());
        assert!(TextChunker::split_text("a b", 0, 0).is_empty());
    }

    #[test]
    fn split_by_sentences_keeps_terminators() {
        assert_eq!(
            TextChunker::split_by_sentences("Hi. How are you? Fine"),
            vec!["Hi.", "How are you?", "Fine"]
        );
    }

    #[test]
    fn cosine_similarity_basics() {
        assert_eq!(SimilarityCalculator::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]), 1.0);
        assert_eq!(SimilarityCalculator::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(SimilarityCalculator::cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(SimilarityCalculator::cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn distances() {
        assert_eq!(SimilarityCalculator::euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0);
        assert_eq!(SimilarityCalculator::manhattan_distance(&[0.0, 0.0], &[3.0, 4.0]), 7.0);
        assert!(SimilarityCalculator::euclidean_distance(&[], &[]).is_infinite());
    }

    #[test]
    fn find_most_similar_orders_descending() {
        let candidates = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]];
        let top = SimilarityCalculator::find_most_similar(&[1.0, 0.0], &candidates, 2);
        let indices: Vec<usize> = top.iter().map(|(i, _)| *i).collect();
        assert_eq!(indices, vec![1, 2]);
    }

    #[test]
    fn preprocessing_helpers() {
        assert_eq!(TextPreprocessor::clean_text("  a \n b  "), "a b");
        assert_eq!(TextPreprocessor::remove_html_tags("<p>hi <b>there</b></p>"), "hi there");
        assert_eq!(TextPreprocessor::normalize_whitespace("a \t\n b"), "a b");
        assert_eq!(TextPreprocessor::truncate_text("hello world", 8), "hello");
        assert_eq!(TextPreprocessor::expand_contractions("don't"), "do not");
    }

    #[test]
    fn cache_round_trip() {
        let mut cache = Cache::new(60, 4);
        cache.put("k", 1);
        assert_eq!(cache.get(&"k"), Some(1));
        assert_eq!(cache.get(&"missing"), None);
        assert_eq!(cache.hit_rate(), 0.5);
    }

    #[test]
    fn rate_limiter_admits_first_request() {
        let mut limiter = RateLimiter::new(10, 2);
        assert!(limiter.is_allowed());
    }

    #[test]
    fn performance_monitor_counts_and_times() {
        let mut monitor = PerformanceMonitor::new();
        monitor.increment_counter("requests");
        monitor.increment_counter("requests");
        assert!(matches!(
            monitor.metrics().get("requests"),
            Some(MetricValue::Counter(2))
        ));
        {
            let _timer = monitor.start_timer("op");
        }
        assert!(matches!(
            monitor.metrics().get("op_duration"),
            Some(MetricValue::Duration(_))
        ));
    }
}