// Example: 03_evaluate_hotpotqa/03-evaluate-hotpotqa.rs
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    ChatAdapter, DataLoader, Evaluator, Example, LM, Module, Optimizable, Predict, Prediction,
    Predictor, Signature, configure,
};
use secrecy::SecretString;

18#[Signature(cot)]
19struct QASignature {
20 #[input]
23 pub question: String,
24
25 #[output(desc = "Answer in less than 5 words.")]
26 pub answer: String,
27}
28
29#[derive(Builder, Optimizable)]
30pub struct QARater {
31 #[parameter]
32 #[builder(default = Predict::new(QASignature::new()))]
33 pub answerer: Predict,
34}
35
36impl Module for QARater {
37 async fn forward(&self, inputs: Example) -> Result<Prediction> {
38 let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
39
40 Ok(answerer_prediction)
41 }
42}
43
44impl Evaluator for QARater {
45 const MAX_CONCURRENCY: usize = 16;
46 const DISPLAY_PROGRESS: bool = true;
47
48 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
49 let answer = example.data.get("answer").unwrap().clone();
50 let prediction = prediction.data.get("answer").unwrap().clone();
51
52 if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
53 1.0
54 } else {
55 0.0
56 }
57 }
58}
59
60#[tokio::main]
61async fn main() -> anyhow::Result<()> {
62 configure(
63 LM::builder()
64 .api_key(SecretString::from(std::env::var("OPENAI_API_KEY")?))
65 .build(),
66 ChatAdapter {},
67 );
68
69 let examples = DataLoader::load_hf(
70 "hotpotqa/hotpot_qa",
71 vec!["question".to_string()],
72 vec!["answer".to_string()],
73 "fullwiki",
74 "validation",
75 true,
76 )?[..128]
77 .to_vec();
78
79 let evaluator = QARater::builder().build();
80 let metric = evaluator.evaluate(examples).await;
81
82 println!("Metric: {metric}");
83 Ok(())
84}