Skip to main content

trueno_rag/eval/
generate.rs

1//! Synthetic ground truth generation from corpus chunks
2
3use super::client::AnthropicClient;
4use super::domain::{classify_domain, extract_course_dir};
5use super::types::GroundTruthEntry;
6use rand::seq::SliceRandom;
7use rand::SeedableRng;
8use std::collections::HashMap;
9
/// System prompt sent with every question-generation request.
///
/// Instructs the model to emit exactly one student-style question answerable
/// from the chunk alone, or the literal sentinel "SKIP" (detected in
/// `generate_question`) when the chunk is too vague or navigational.
/// NOTE: this string is part of runtime behavior — any wording change alters
/// what the model is asked to do.
const SYSTEM_PROMPT: &str = "You generate evaluation questions from video transcript chunks.
Given a transcript chunk, generate ONE specific question this text answers.
Rules:
(1) The question must be answerable only from the provided text.
(2) Write a student-style query, 8-20 words long.
(3) Do NOT reference \"the video\", \"the instructor\", \"the speaker\", or \"this lecture\".
(4) Do NOT ask yes/no questions.
(5) If the text is too vague or navigational to generate a good question, respond with exactly: SKIP";
18
/// Chunk data extracted from a PersistedIndex
#[derive(Debug, Clone)]
pub struct IndexChunk {
    /// Chunk text content (the transcript text for this segment)
    pub content: String,
    /// Source file path; the course directory is derived from this path
    /// via `extract_course_dir` during sampling
    pub source: String,
    /// Optional title
    pub title: Option<String>,
    /// Start timestamp in seconds, when known
    pub start_secs: Option<f64>,
    /// End timestamp in seconds, when known
    pub end_secs: Option<f64>,
}
33
/// Generator for synthetic ground truth
pub struct GroundTruthGenerator {
    // API client used to request question generation.
    client: AnthropicClient,
    // Model identifier forwarded to the client on every request.
    model: String,
    // Maximum number of chunks to sample across all courses.
    sample_size: usize,
    // RNG seed so stratified sampling is reproducible across runs.
    seed: u64,
}
41
42impl GroundTruthGenerator {
43    /// Create a new generator
44    pub fn new(client: AnthropicClient, model: &str, sample_size: usize, seed: u64) -> Self {
45        Self { client, model: model.to_string(), sample_size, seed }
46    }
47
48    /// Sample chunks using stratified sampling by course directory
49    pub fn sample_chunks(&self, chunks: &[IndexChunk]) -> Vec<SampledChunk> {
50        let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
51
52        // Group by course
53        let mut by_course: HashMap<String, Vec<&IndexChunk>> = HashMap::new();
54        for chunk in chunks {
55            let course = extract_course_dir(&chunk.source).to_string();
56            by_course.entry(course).or_default().push(chunk);
57        }
58
59        // Sort courses by chunk count descending, then by name for determinism
60        let mut courses: Vec<(String, Vec<&IndexChunk>)> = by_course.into_iter().collect();
61        courses.sort_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(&b.0)));
62
63        let mut sampled = Vec::new();
64
65        for (course, course_chunks) in &courses {
66            // Filter eligible chunks
67            let eligible: Vec<&&IndexChunk> =
68                course_chunks.iter().filter(|c| is_eligible(c)).collect();
69
70            if eligible.len() < 2 {
71                continue;
72            }
73
74            // Sample 2-3 chunks per course
75            let n = eligible.len().min(3);
76            let mut indices: Vec<usize> = (0..eligible.len()).collect();
77            indices.shuffle(&mut rng);
78
79            for &idx in indices.iter().take(n) {
80                let chunk = eligible[idx];
81                sampled.push(SampledChunk {
82                    content: chunk.content.clone(),
83                    source: chunk.source.clone(),
84                    start_secs: chunk.start_secs,
85                    end_secs: chunk.end_secs,
86                    course: course.clone(),
87                    domain: classify_domain(course).to_string(),
88                });
89            }
90
91            if sampled.len() >= self.sample_size {
92                break;
93            }
94        }
95
96        // Trim to exact size
97        sampled.truncate(self.sample_size);
98
99        // Report distribution
100        let mut domain_counts: HashMap<&str, usize> = HashMap::new();
101        for s in &sampled {
102            *domain_counts.entry(&s.domain).or_default() += 1;
103        }
104        eprintln!(
105            "Sampled {} chunks from {} courses",
106            sampled.len(),
107            courses.len().min(sampled.len())
108        );
109        let mut sorted_domains: Vec<_> = domain_counts.into_iter().collect();
110        sorted_domains.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
111        for (domain, count) in &sorted_domains {
112            eprintln!("  {domain}: {count}");
113        }
114
115        sampled
116    }
117
118    /// Generate a question for a single chunk
119    pub async fn generate_question(&self, content: &str) -> Result<Option<String>, String> {
120        let user_msg = format!("Transcript chunk:\n---\n{content}\n---");
121
122        let result = self.client.complete(&self.model, Some(SYSTEM_PROMPT), &user_msg, 150).await?;
123
124        let text = result.text.trim().to_string();
125        if text == "SKIP" || text.starts_with("SKIP") {
126            return Ok(None);
127        }
128
129        // Clean up
130        let mut question = text.trim_matches('"').trim_matches('\'').trim().to_string();
131        if !question.ends_with('?') {
132            question.push('?');
133        }
134
135        Ok(Some(question))
136    }
137
138    /// Generate ground truth for all sampled chunks
139    pub async fn generate(&self, chunks: &[IndexChunk]) -> Result<Vec<GroundTruthEntry>, String> {
140        let sampled = self.sample_chunks(chunks);
141        let total = sampled.len();
142        let mut results = Vec::new();
143        let mut skipped = 0usize;
144        let mut errors = 0usize;
145
146        for (i, sample) in sampled.iter().enumerate() {
147            eprint!("[{}/{}] {} ({})...", i + 1, total, sample.course, sample.domain);
148
149            match self.generate_question(&sample.content).await {
150                Ok(Some(question)) => {
151                    eprintln!(" {}", &question[..question.len().min(60)]);
152                    results.push(GroundTruthEntry {
153                        query: question,
154                        chunk_content: sample.content.clone(),
155                        chunk_source: sample.source.clone(),
156                        chunk_start_secs: sample.start_secs,
157                        chunk_end_secs: sample.end_secs,
158                        domain: sample.domain.clone(),
159                        course: sample.course.clone(),
160                    });
161                }
162                Ok(None) => {
163                    eprintln!(" SKIP");
164                    skipped += 1;
165                }
166                Err(e) => {
167                    eprintln!(" ERROR: {e}");
168                    errors += 1;
169                }
170            }
171        }
172
173        eprintln!("\nGenerated {} queries, {} skipped, {} errors", results.len(), skipped, errors);
174
175        Ok(results)
176    }
177}
178
/// A sampled chunk with its metadata
#[derive(Debug, Clone)]
pub struct SampledChunk {
    /// Chunk text (copied from the originating `IndexChunk`)
    pub content: String,
    /// Source path
    pub source: String,
    /// Start time in seconds, when known
    pub start_secs: Option<f64>,
    /// End time in seconds, when known
    pub end_secs: Option<f64>,
    /// Course directory the chunk was sampled from
    pub course: String,
    /// Domain label produced by `classify_domain` for the course
    pub domain: String,
}
195
196/// Check if a chunk is eligible for question generation
197fn is_eligible(chunk: &IndexChunk) -> bool {
198    let words: Vec<&str> = chunk.content.split_whitespace().collect();
199    if words.len() < 50 {
200        return false;
201    }
202    let lowered: Vec<String> = words.iter().map(|w| w.to_lowercase()).collect();
203    let unique: std::collections::HashSet<&str> = lowered.iter().map(|w| w.as_str()).collect();
204    if unique.len() < 15 {
205        return false;
206    }
207
208    // Skip navigational boilerplate
209    let lower = chunk.content.to_lowercase();
210    let nav_phrases = [
211        "welcome back",
212        "in this video",
213        "let's go ahead",
214        "see you in the next",
215        "don't forget to subscribe",
216        "click the link",
217        "table of contents",
218    ];
219    let nav_count = nav_phrases.iter().filter(|p| lower.contains(*p)).count();
220    nav_count < 3
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal chunk with fixed timestamps for eligibility/sampling tests.
    fn make_chunk(content: &str, source: &str) -> IndexChunk {
        IndexChunk {
            content: content.to_string(),
            source: source.to_string(),
            title: None,
            start_secs: Some(0.0),
            end_secs: Some(30.0),
        }
    }

    #[test]
    fn test_is_eligible_short() {
        // Far below the 50-word minimum: must be rejected.
        let chunk = make_chunk("too short", "/data/courses/test/build/a.srt");
        assert!(!is_eligible(&chunk));
    }

    #[test]
    fn test_is_eligible_valid() {
        // 60 distinct words clears both the length and uniqueness thresholds.
        let content = (0..60).map(|i| format!("word{i}")).collect::<Vec<_>>().join(" ");
        assert!(is_eligible(&make_chunk(&content, "/data/courses/test/build/a.srt")));
    }

    #[test]
    fn test_sampling_deterministic() {
        // 100 eligible chunks spread over 20 courses (5 chunks per course).
        let chunks: Vec<IndexChunk> = (0..100)
            .map(|i| {
                let body = (0..60).map(|j| format!("w{j}c{i}")).collect::<Vec<_>>().join(" ");
                make_chunk(&body, &format!("/data/courses/course-{}/build/vid.srt", i / 5))
            })
            .collect();

        // Identical seeds must yield identical samples in identical order.
        let first = GroundTruthGenerator::new(AnthropicClient::new("fake"), "model", 20, 42)
            .sample_chunks(&chunks);
        let second = GroundTruthGenerator::new(AnthropicClient::new("fake"), "model", 20, 42)
            .sample_chunks(&chunks);

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.source, b.source);
            assert_eq!(a.course, b.course);
        }
    }
}