// 03_evaluate_hotpotqa/03-evaluate-hotpotqa.rs

/*
Script to evaluate the answerer of the QARater module on a small sample (128 examples from the
validation split) of the HotpotQA dataset.

Run with:
```
cargo run --example 03-evaluate-hotpotqa --features dataloaders
```

Note: The `dataloaders` feature is required for loading datasets.
*/
11
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    ChatAdapter, Evaluator, Example, LM, Module, Optimizable, Predict, Prediction, Predictor,
    Signature, configure,
};

use dspy_rs::DataLoader;
20
// Prompt signature for the answerer: a single `question` input and a single `answer` output.
// NOTE(review): `#[Signature(cot)]` presumably adds a reasoning (chain-of-thought) step, and the
// `///` text inside the struct is presumably consumed by the macro as the task instruction sent
// to the LM. Treat that text as prompt content, not just documentation; confirm against the
// `Signature` macro docs before editing it.
#[Signature(cot)]
struct QASignature {
    /// Concisely answer the question but be accurate. If it's a yes no question, answer with yes or no.

    // The HotpotQA question text (loaded from the dataset's `question` column in `main`).
    #[input]
    pub question: String,

    // The model's short answer; the evaluator compares it against the gold `answer` field.
    // NOTE(review): the `desc` string is presumably shown to the LM as the field description.
    #[output(desc = "Answer in less than 5 words.")]
    pub answer: String,
}
31
/// Module under evaluation: wraps a single `answerer` predictor built from `QASignature`.
///
/// Construct it with `QARater::builder().build()`; the builder falls back to a fresh
/// `Predict` over `QASignature` when no answerer is supplied.
#[derive(Builder, Optimizable)]
pub struct QARater {
    // NOTE(review): `#[parameter]` presumably exposes this predictor to optimizers through the
    // `Optimizable` derive; confirm against the derive's documentation.
    #[parameter]
    // Default used by the `bon` builder when `.answerer(...)` is not called.
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,
}
38
39impl Module for QARater {
40    async fn forward(&self, inputs: Example) -> Result<Prediction> {
41        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
42
43        Ok(answerer_prediction)
44    }
45}
46
47impl Evaluator for QARater {
48    const MAX_CONCURRENCY: usize = 16;
49    const DISPLAY_PROGRESS: bool = true;
50
51    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
52        let answer = example.data.get("answer").unwrap().clone();
53        let prediction = prediction.data.get("answer").unwrap().clone();
54
55        if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
56            1.0
57        } else {
58            0.0
59        }
60    }
61}
62
63#[tokio::main]
64async fn main() -> anyhow::Result<()> {
65    configure(
66        LM::builder()
67            .model("openai:gpt-4o-mini".to_string())
68            .build()
69            .await?,
70        ChatAdapter {},
71    );
72
73    let examples = DataLoader::load_hf(
74        "hotpotqa/hotpot_qa",
75        vec!["question".to_string()],
76        vec!["answer".to_string()],
77        "fullwiki",
78        "validation",
79        true,
80    )?[..128]
81        .to_vec();
82
83    let evaluator = QARater::builder().build();
84    let metric = evaluator.evaluate(examples).await;
85
86    println!("Metric: {metric}");
87    Ok(())
88}