03_evaluate_hotpotqa/
03-evaluate-hotpotqa.rs

/*
Script to evaluate the `answerer` predictor of the `QARater` module on a tiny sample
(128 questions) of the HotpotQA `fullwiki` validation split.

Run with:
```
cargo run --example 03-evaluate-hotpotqa --features dataloaders
```

Note: The `dataloaders` feature is required for loading datasets.
*/
11
12use anyhow::Result;
13use bon::Builder;
14use dspy_rs::{
15    ChatAdapter, Evaluator, Example, LM, Module, Optimizable, Predict, Prediction, Predictor,
16    Signature, configure,
17};
18
19use dspy_rs::DataLoader;
20
// Signature for the HotpotQA answerer: one `question` input, one short `answer` output.
// `cot` is an argument to the `Signature` attribute macro; presumably it enables a
// chain-of-thought (reasoning) step before the answer — TODO confirm against dspy_rs docs.
// NOTE(review): the `///` line inside the struct is presumably read by the macro as the
// task instruction, so keep it a doc comment (not `//`) — verify before editing it.
#[Signature(cot)]
struct QASignature {
    /// Concisely answer the question but be accurate. If it's a yes no question, answer with yes or no.

    // Input field: the HotpotQA question text.
    #[input]
    pub question: String,

    // Output field; `desc` supplies an extra per-field constraint on the answer length.
    #[output(desc = "Answer in less than 5 words.")]
    pub answer: String,
}
31
/// Evaluation module wrapping a single `answerer` predictor built from `QASignature`.
///
/// Construct it with `QARater::builder().build()`; when `answerer` is not set on the
/// builder it defaults to `Predict::new(QASignature::new())`.
#[derive(Builder, Optimizable)]
pub struct QARater {
    // `#[parameter]` marks this field for the `Optimizable` derive — presumably so that
    // optimizers can discover and tune this predictor; TODO confirm in dspy_rs docs.
    #[parameter]
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,
}
38
39impl Module for QARater {
40    async fn forward(&self, inputs: Example) -> Result<Prediction> {
41        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
42
43        Ok(answerer_prediction)
44    }
45}
46
47impl Evaluator for QARater {
48    const MAX_CONCURRENCY: usize = 16;
49    const DISPLAY_PROGRESS: bool = true;
50
51    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
52        let answer = example.data.get("answer").unwrap().clone();
53        let prediction = prediction.data.get("answer").unwrap().clone();
54
55        if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
56            1.0
57        } else {
58            0.0
59        }
60    }
61}
62
63#[tokio::main]
64async fn main() -> anyhow::Result<()> {
65    configure(
66        LM::builder()
67            .model("openai:gpt-4o-mini".to_string())
68            .build()
69            .await
70            .unwrap(),
71        ChatAdapter {},
72    );
73
74    let examples = DataLoader::load_hf(
75        "hotpotqa/hotpot_qa",
76        vec!["question".to_string()],
77        vec!["answer".to_string()],
78        "fullwiki",
79        "validation",
80        true,
81    )?[..128]
82        .to_vec();
83
84    let evaluator = QARater::builder().build();
85    let metric = evaluator.evaluate(examples).await;
86
87    println!("Metric: {metric}");
88    Ok(())
89}