mockforge_data/rag/
utils.rs

1//! Utility functions and helpers for RAG operations
2//!
3//! This module provides utility functions for text processing,
4//! similarity calculations, data validation, and other common RAG operations.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11/// Text chunking utilities
12pub struct TextChunker;
13
14impl TextChunker {
15    /// Split text into chunks of specified size with overlap
16    pub fn split_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
17        if text.is_empty() || chunk_size == 0 {
18            return Vec::new();
19        }
20
21        let words: Vec<&str> = text.split_whitespace().collect();
22        if words.is_empty() {
23            return Vec::new();
24        }
25
26        let mut chunks = Vec::new();
27        let step = chunk_size.saturating_sub(overlap);
28
29        for start in (0..words.len()).step_by(step) {
30            let end = (start + chunk_size).min(words.len());
31            let chunk: Vec<&str> = words[start..end].to_vec();
32            if !chunk.is_empty() {
33                chunks.push(chunk.join(" "));
34            }
35        }
36
37        chunks
38    }
39
40    /// Split text by sentences
41    pub fn split_by_sentences(text: &str) -> Vec<String> {
42        // Simple sentence splitting - in practice, you might want to use a proper NLP library
43        let mut sentences = Vec::new();
44        let mut current_sentence = String::new();
45
46        for ch in text.chars() {
47            current_sentence.push(ch);
48            if ch == '.' || ch == '!' || ch == '?' {
49                if !current_sentence.trim().is_empty() {
50                    sentences.push(current_sentence.trim().to_string());
51                }
52                current_sentence.clear();
53            }
54        }
55
56        if !current_sentence.trim().is_empty() {
57            sentences.push(current_sentence.trim().to_string());
58        }
59
60        sentences
61    }
62
63    /// Split text by paragraphs
64    pub fn split_by_paragraphs(text: &str) -> Vec<String> {
65        text.split("\n\n")
66            .map(|s| s.trim().to_string())
67            .filter(|s| !s.is_empty())
68            .collect()
69    }
70
71    /// Create overlapping chunks for better context preservation
72    pub fn create_overlapping_chunks(
73        text: &str,
74        chunk_size: usize,
75        overlap_ratio: f32,
76    ) -> Vec<String> {
77        let overlap = ((chunk_size as f32) * overlap_ratio).round() as usize;
78        Self::split_text(text, chunk_size, overlap)
79    }
80
81    /// Chunk text with metadata preservation
82    pub fn chunk_with_metadata(
83        text: &str,
84        chunk_size: usize,
85        overlap: usize,
86        metadata: HashMap<String, String>,
87    ) -> Vec<(String, HashMap<String, String>)> {
88        let chunks = Self::split_text(text, chunk_size, overlap);
89        chunks.into_iter().map(|chunk| (chunk, metadata.clone())).collect()
90    }
91}
92
/// Similarity calculation utilities.
///
/// Stateless helpers operating on `f32` slices.
pub struct SimilarityCalculator;

impl SimilarityCalculator {
    /// Cosine similarity between `a` and `b`.
    ///
    /// Returns 0.0 when the slices differ in length, are empty, or either has a
    /// zero norm.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Euclidean (L2) distance; `f32::INFINITY` when lengths differ or the slices
    /// are empty.
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return f32::INFINITY;
        }

        let sum_squares: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        sum_squares.sqrt()
    }

    /// Manhattan (L1) distance; `f32::INFINITY` when lengths differ or the slices
    /// are empty.
    pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return f32::INFINITY;
        }

        a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
    }

    /// Dot product of `a` and `b`; returns 0.0 when the lengths differ.
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Scale `vector` to unit length; a zero vector yields a vector of zeros.
    pub fn normalize_vector(vector: &[f32]) -> Vec<f32> {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            return vec![0.0; vector.len()];
        }

        vector.iter().map(|x| x / norm).collect()
    }

    /// Symmetric `n x n` matrix of pairwise cosine similarities.
    ///
    /// Only the upper triangle (including the diagonal) is computed; each value is
    /// mirrored into the lower triangle.
    pub fn similarity_matrix(vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let n = vectors.len();
        let mut matrix = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in i..n {
                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                matrix[i][j] = similarity;
                matrix[j][i] = similarity;
            }
        }

        matrix
    }

    /// Return up to `top_k` `(index, cosine similarity)` pairs for the candidates
    /// most similar to `query`, best first.
    ///
    /// Equal scores keep candidate order (stable sort). NaN scores (from NaN/inf
    /// components) are ranked after every real score.
    pub fn find_most_similar(
        query: &[f32],
        candidates: &[Vec<f32>],
        top_k: usize,
    ) -> Vec<(usize, f32)> {
        // Mapping NaN to -inf keeps the comparator a total order; `sort_by` may
        // panic (Rust >= 1.81) or misorder when given a non-transitive comparator.
        fn rank_key(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        }

        let mut similarities: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, vec)| (i, Self::cosine_similarity(query, vec)))
            .collect();

        similarities.sort_by(|a, b| {
            rank_key(b.1)
                .partial_cmp(&rank_key(a.1))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        similarities.truncate(top_k);

        similarities
    }
}
187
188/// Text preprocessing utilities
189pub struct TextPreprocessor;
190
191impl TextPreprocessor {
192    /// Clean text by removing extra whitespace and normalizing
193    pub fn clean_text(text: &str) -> String {
194        text.split_whitespace().collect::<Vec<&str>>().join(" ")
195    }
196
197    /// Remove HTML tags from text
198    pub fn remove_html_tags(text: &str) -> String {
199        // Simple HTML tag removal - in practice, you might want to use a proper HTML parser
200        let mut result = String::new();
201        let mut in_tag = false;
202
203        for ch in text.chars() {
204            match ch {
205                '<' => in_tag = true,
206                '>' => in_tag = false,
207                _ if !in_tag => result.push(ch),
208                _ => {}
209            }
210        }
211
212        result
213    }
214
215    /// Normalize whitespace
216    pub fn normalize_whitespace(text: &str) -> String {
217        text.chars()
218            .fold((String::new(), false), |(mut acc, mut was_space), ch| {
219                if ch.is_whitespace() {
220                    if !was_space {
221                        acc.push(' ');
222                        was_space = true;
223                    }
224                } else {
225                    acc.push(ch);
226                    was_space = false;
227                }
228                (acc, was_space)
229            })
230            .0
231    }
232
233    /// Extract keywords from text
234    pub fn extract_keywords(text: &str, max_keywords: usize) -> Vec<String> {
235        let words: Vec<String> = text
236            .to_lowercase()
237            .split_whitespace()
238            .map(|word| word.trim_matches(|c: char| !c.is_alphabetic()).to_string())
239            .filter(|trimmed_word| {
240                // Filter out common stop words and short words
241                trimmed_word.len() > 2 && !is_stop_word(trimmed_word)
242            })
243            .collect();
244
245        // Count word frequencies
246        let mut word_counts: HashMap<String, usize> = HashMap::new();
247        for word in words {
248            *word_counts.entry(word).or_insert(0) += 1;
249        }
250
251        // Sort by frequency and take top keywords
252        let mut sorted_words: Vec<(String, usize)> = word_counts.into_iter().collect();
253        sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
254
255        sorted_words.into_iter().take(max_keywords).map(|(word, _)| word).collect()
256    }
257
258    /// Truncate text to maximum length while preserving word boundaries
259    pub fn truncate_text(text: &str, max_length: usize) -> String {
260        if text.len() <= max_length {
261            return text.to_string();
262        }
263
264        let truncated = &text[..max_length];
265        let last_space = truncated.rfind(' ').unwrap_or(max_length);
266        truncated[..last_space].trim().to_string()
267    }
268
269    /// Expand contractions in text
270    pub fn expand_contractions(text: &str) -> String {
271        text.replace("don't", "do not")
272            .replace("can't", "cannot")
273            .replace("won't", "will not")
274            .replace("i'm", "i am")
275            .replace("you're", "you are")
276            .replace("it's", "it is")
277            .replace("that's", "that is")
278            .replace("there's", "there is")
279            .replace("here's", "here is")
280            .replace("what's", "what is")
281            .replace("where's", "where is")
282            .replace("when's", "when is")
283            .replace("why's", "why is")
284            .replace("how's", "how is")
285    }
286}
287
/// Returns `true` when `word` is in a small, hard-coded English stop-word list.
///
/// The comparison is exact and case-sensitive, so callers should lowercase input
/// before asking.
fn is_stop_word(word: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        // articles and conjunctions
        "the", "a", "an", "and", "or", "but",
        // prepositions
        "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "about", "into",
        "through", "during", "before", "after", "above", "below", "between", "among",
        // forms of "be", "have", "do" and modal verbs
        "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
        "does", "did", "will", "would", "could", "should", "may", "might", "must", "can",
        // demonstratives and pronouns
        "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "me", "him", "her", "us", "them", "my", "your", "his", "its", "our", "their",
        "mine", "yours", "hers", "ours", "theirs", "am",
        // negation, affirmation and common adverbs
        "not", "no", "yes", "here", "there", "now", "then", "so", "very", "too", "also",
        "only", "just", "even", "still", "yet", "again", "once", "never", "always",
        "often", "sometimes", "usually",
    ];

    STOP_WORDS.contains(&word)
}
392
/// Rate limiting utilities.
///
/// Sliding-window limiter over the trailing 60 seconds. A request is admitted while
/// fewer than `max(burst_size, requests_per_minute)` admitted requests fall inside
/// the window, and every admitted request is recorded.
pub struct RateLimiter {
    requests_per_minute: u32,
    burst_size: u32,
    /// Admission times, oldest first (pushed in `Instant::now()` order).
    request_times: Vec<std::time::Instant>,
}

impl RateLimiter {
    /// Length of the sliding window.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Create a new rate limiter with no recorded requests.
    pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
        Self {
            requests_per_minute,
            burst_size,
            request_times: Vec::new(),
        }
    }

    /// Maximum number of admitted requests allowed inside one window.
    fn capacity(&self) -> usize {
        // "below burst_size OR below requests_per_minute" == "below their maximum".
        self.burst_size.max(self.requests_per_minute) as usize
    }

    /// Check whether a request may proceed now, recording it if so.
    ///
    /// Entries older than the window are discarded first. Returns `false`, without
    /// recording anything, when the window is already full.
    pub fn is_allowed(&mut self) -> bool {
        let now = std::time::Instant::now();
        // `duration_since` instead of `now - WINDOW`: `Instant - Duration` may panic
        // when the result cannot be represented by the platform clock.
        self.request_times
            .retain(|&time| now.duration_since(time) < Self::WINDOW);

        if self.request_times.len() < self.capacity() {
            self.request_times.push(now);
            true
        } else {
            false
        }
    }

    /// Time until the next request would be admitted.
    ///
    /// Returns `Duration::ZERO` when a request would be admitted right now. Otherwise
    /// returns how long until enough recorded requests leave the window to free a
    /// slot. With a capacity of zero no slot ever frees up; `Duration::ZERO` is
    /// returned in that degenerate case.
    pub fn time_until_next(&self) -> std::time::Duration {
        let now = std::time::Instant::now();
        let live: Vec<std::time::Instant> = self
            .request_times
            .iter()
            .copied()
            .filter(|&time| now.duration_since(time) < Self::WINDOW)
            .collect();

        let capacity = self.capacity();
        if live.len() < capacity {
            return std::time::Duration::ZERO;
        }

        // `live` is oldest-first; the count drops below `capacity` once the entry at
        // index `len - capacity` expires.
        live.get(live.len() - capacity)
            .map_or(std::time::Duration::ZERO, |&time| {
                Self::WINDOW.saturating_sub(now.duration_since(time))
            })
    }
}
449
/// Caching utilities.
///
/// A size-bounded map whose entries expire `ttl` after they were inserted. Lookups
/// do not refresh an entry's timestamp, so eviction removes the entry that was
/// inserted (or last overwritten) longest ago.
pub struct Cache<K, V> {
    data: HashMap<K, (V, std::time::Instant)>,
    ttl: std::time::Duration,
    max_size: usize,
    hits: u64,
    misses: u64,
}

impl<K, V> Cache<K, V>
where
    K: std::cmp::Eq + std::hash::Hash + Clone,
    V: Clone,
{
    /// Create an empty cache whose entries live `ttl_secs` seconds and which holds at
    /// most `max_size` entries (`0` disables storage entirely).
    pub fn new(ttl_secs: u64, max_size: usize) -> Self {
        Self {
            data: HashMap::new(),
            ttl: std::time::Duration::from_secs(ttl_secs),
            max_size,
            hits: 0,
            misses: 0,
        }
    }

    /// Look up `key`, counting a hit or a miss.
    ///
    /// An expired entry is removed and counted as a miss.
    pub fn get(&mut self, key: &K) -> Option<V> {
        if let Some((value, timestamp)) = self.data.get(key) {
            let now = std::time::Instant::now();
            if now.duration_since(*timestamp) < self.ttl {
                self.hits += 1;
                return Some(value.clone());
            } else {
                // Expired, remove it
                self.data.remove(key);
            }
        }
        self.misses += 1;
        None
    }

    /// Insert or overwrite `key`.
    ///
    /// Expired entries are purged first. If `key` is new and the cache is full, the
    /// entry with the oldest insertion timestamp is evicted. Overwriting an existing
    /// key never evicts a different entry. With `max_size == 0` nothing is stored.
    pub fn put(&mut self, key: K, value: V) {
        if self.max_size == 0 {
            return;
        }

        let now = std::time::Instant::now();
        let ttl = self.ttl;
        self.data
            .retain(|_, (_, timestamp)| now.duration_since(*timestamp) < ttl);

        if !self.data.contains_key(&key) && self.data.len() >= self.max_size {
            // HashMap order is arbitrary, so find the oldest entry explicitly.
            let oldest_key = self
                .data
                .iter()
                .min_by_key(|(_, (_, timestamp))| *timestamp)
                .map(|(k, _)| k.clone());
            if let Some(oldest_key) = oldest_key {
                self.data.remove(&oldest_key);
            }
        }

        self.data.insert(key, (value, now));
    }

    /// Remove all entries; the hit/miss counters are kept.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Number of stored entries (possibly including not-yet-purged expired ones).
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Fraction of `get` calls that were hits; 0.0 before any lookup.
    pub fn hit_rate(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }
}
529
530/// Performance monitoring utilities
531pub struct PerformanceMonitor {
532    start_time: std::time::Instant,
533    metrics: HashMap<String, MetricValue>,
534}
535
536/// Performance metric value types for monitoring RAG operations
537#[derive(Debug, Clone, Serialize, Deserialize)]
538pub enum MetricValue {
539    /// Counter metric (monotonically increasing value)
540    Counter(u64),
541    /// Gauge metric (can increase or decrease)
542    Gauge(f64),
543    /// Histogram metric (distribution of values)
544    Histogram(Vec<f64>),
545    /// Duration metric (time measurement)
546    Duration(std::time::Duration),
547}
548
549impl Default for PerformanceMonitor {
550    fn default() -> Self {
551        Self::new()
552    }
553}
554
555impl PerformanceMonitor {
556    /// Create a new performance monitor
557    pub fn new() -> Self {
558        Self {
559            start_time: std::time::Instant::now(),
560            metrics: HashMap::new(),
561        }
562    }
563
564    /// Start timing an operation
565    pub fn start_timer(&mut self, operation: &str) -> TimerGuard<'_> {
566        let start = std::time::Instant::now();
567        TimerGuard {
568            monitor: self,
569            operation: operation.to_string(),
570            start,
571        }
572    }
573
574    /// Record a metric
575    pub fn record_metric(&mut self, name: String, value: MetricValue) {
576        self.metrics.insert(name, value);
577    }
578
579    /// Increment counter
580    pub fn increment_counter(&mut self, name: &str) {
581        let counter = match self.metrics.get(name) {
582            Some(MetricValue::Counter(count)) => *count + 1,
583            _ => 1,
584        };
585        self.metrics.insert(name.to_string(), MetricValue::Counter(counter));
586    }
587
588    /// Record gauge value
589    pub fn record_gauge(&mut self, name: &str, value: f64) {
590        self.metrics.insert(name.to_string(), MetricValue::Gauge(value));
591    }
592
593    /// Get elapsed time since start
594    pub fn elapsed(&self) -> std::time::Duration {
595        self.start_time.elapsed()
596    }
597
598    /// Get all metrics
599    pub fn metrics(&self) -> &HashMap<String, MetricValue> {
600        &self.metrics
601    }
602
603    /// Get uptime
604    pub fn uptime(&self) -> std::time::Duration {
605        self.start_time.elapsed()
606    }
607}
608
609/// Timer guard for automatic timing
610pub struct TimerGuard<'a> {
611    monitor: &'a mut PerformanceMonitor,
612    operation: String,
613    start: std::time::Instant,
614}
615
616impl<'a> Drop for TimerGuard<'a> {
617    fn drop(&mut self) {
618        let duration = self.start.elapsed();
619        let operation_duration = format!("{}_duration", self.operation);
620        self.monitor.metrics.insert(operation_duration, MetricValue::Duration(duration));
621    }
622}
623
624/// File utilities for RAG operations
625pub struct FileUtils;
626
627impl FileUtils {
628    /// Read text file
629    pub async fn read_text_file<P: AsRef<Path>>(path: P) -> Result<String> {
630        let content = tokio::fs::read_to_string(path).await?;
631        Ok(content)
632    }
633
634    /// Write text file
635    pub async fn write_text_file<P: AsRef<Path>>(path: P, content: &str) -> Result<()> {
636        tokio::fs::write(path, content).await?;
637        Ok(())
638    }
639
640    /// Check if file exists
641    pub async fn file_exists<P: AsRef<Path>>(path: P) -> bool {
642        tokio::fs::try_exists(path).await.unwrap_or(false)
643    }
644
645    /// Get file size
646    pub async fn file_size<P: AsRef<Path>>(path: P) -> Result<u64> {
647        let metadata = tokio::fs::metadata(path).await?;
648        Ok(metadata.len())
649    }
650
651    /// List files in directory
652    pub async fn list_files<P: AsRef<Path>>(dir: P) -> Result<Vec<std::path::PathBuf>> {
653        let mut files = Vec::new();
654        let mut entries = tokio::fs::read_dir(dir).await?;
655
656        while let Some(entry) = entries.next_entry().await? {
657            files.push(entry.path());
658        }
659
660        Ok(files)
661    }
662
663    /// Read JSON file
664    pub async fn read_json_file<T: for<'de> Deserialize<'de>, P: AsRef<Path>>(
665        path: P,
666    ) -> Result<T> {
667        let content = Self::read_text_file(path).await?;
668        let data: T = serde_json::from_str(&content)?;
669        Ok(data)
670    }
671
672    /// Write JSON file
673    pub async fn write_json_file<T: Serialize, P: AsRef<Path>>(path: P, data: &T) -> Result<()> {
674        let content = serde_json::to_string_pretty(data)?;
675        Self::write_text_file(path, &content).await
676    }
677}
678
679/// Error handling utilities
680pub struct ErrorUtils;
681
682impl ErrorUtils {
683    /// Create a generic error
684    pub fn generic_error(message: &str) -> crate::Error {
685        crate::Error::generic(message.to_string())
686    }
687
688    /// Create an error with context
689    pub fn context_error(message: &str, context: &str) -> crate::Error {
690        crate::Error::generic(format!("{}: {}", message, context))
691    }
692
693    /// Wrap an error with additional context
694    pub fn wrap_error<E: std::fmt::Display>(error: E, context: &str) -> crate::Error {
695        crate::Error::generic(format!("{}: {}", context, error))
696    }
697
698    /// Check if error is retryable
699    pub fn is_retryable_error(error: &crate::Error) -> bool {
700        // Simple heuristic - in practice, you might want to categorize errors
701        error.to_string().contains("timeout")
702            || error.to_string().contains("rate limit")
703            || error.to_string().contains("503")
704            || error.to_string().contains("502")
705            || error.to_string().contains("504")
706    }
707}
708
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunking_basics() {
        assert_eq!(TextChunker::split_text("a b c d", 2, 0), vec!["a b", "c d"]);
        assert!(TextChunker::split_text("", 2, 0).is_empty());
        assert!(TextChunker::split_text("a b", 0, 0).is_empty());
        assert_eq!(TextChunker::split_by_paragraphs("p1\n\np2\n\n\n\n"), vec!["p1", "p2"]);
        assert_eq!(TextChunker::split_by_sentences("Hi. Yo!"), vec!["Hi.", "Yo!"]);
    }

    #[test]
    fn similarity_basics() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!((SimilarityCalculator::cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
        assert_eq!(SimilarityCalculator::cosine_similarity(&a, &b), 0.0);
        assert_eq!(SimilarityCalculator::cosine_similarity(&a, &[1.0]), 0.0);
        assert!(
            (SimilarityCalculator::euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs()
                < 1e-6
        );
        let top = SimilarityCalculator::find_most_similar(&a, &[b.to_vec(), a.to_vec()], 1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].0, 1);
    }

    #[test]
    fn preprocessing_basics() {
        assert_eq!(TextPreprocessor::clean_text("  a   b \n c "), "a b c");
        assert_eq!(TextPreprocessor::normalize_whitespace(" a \n\t b "), " a b ");
        assert_eq!(TextPreprocessor::remove_html_tags("<b>hi</b> there"), "hi there");
        assert_eq!(TextPreprocessor::truncate_text("hello world", 8), "hello");
        assert_eq!(TextPreprocessor::expand_contractions("don't"), "do not");
        assert!(is_stop_word("the"));
        assert!(!is_stop_word("rust"));
    }

    #[test]
    fn cache_and_rate_limiter_basics() {
        let mut cache: Cache<String, i32> = Cache::new(60, 10);
        cache.put("k".to_string(), 1);
        assert_eq!(cache.get(&"k".to_string()), Some(1));
        assert_eq!(cache.size(), 1);
        assert!((cache.hit_rate() - 1.0).abs() < 1e-6);

        let mut limiter = RateLimiter::new(5, 5);
        assert_eq!(limiter.time_until_next(), std::time::Duration::from_secs(0));
        assert!(limiter.is_allowed());
    }

    #[test]
    fn performance_monitor_basics() {
        let mut monitor = PerformanceMonitor::new();
        monitor.increment_counter("calls");
        monitor.increment_counter("calls");
        assert!(matches!(
            monitor.metrics().get("calls"),
            Some(MetricValue::Counter(2))
        ));
        {
            let _timer = monitor.start_timer("op");
        }
        assert!(matches!(
            monitor.metrics().get("op_duration"),
            Some(MetricValue::Duration(_))
        ));
    }
}