04_optimize_hotpotqa/
04-optimize-hotpotqa.rs

/*
Script to optimize the instruction of the `answerer` predictor in the `QARater` module with COPRO,
using a tiny sample (up to the first 10 examples of the validation split) of the HotpotQA dataset.

Run with:
```
cargo run --example 04-optimize-hotpotqa --features dataloaders
```

Note: The `dataloaders` feature is required for loading datasets with `DataLoader::load_hf`.
*/
11
12use anyhow::Result;
13use bon::Builder;
14use dspy_rs::{
15    COPRO, ChatAdapter, DataLoader, Evaluator, Example, LM, Module, Optimizable, Optimizer,
16    Predict, Prediction, Predictor, Signature, configure,
17};
18
19#[Signature(cot)]
20struct QASignature {
21    /// Concisely answer the question but be accurate. If it's a yes no question, answer with yes or no.
22
23    #[input]
24    pub question: String,
25
26    #[output(desc = "Answer in less than 5 words.")]
27    pub answer: String,
28}
29
30#[derive(Builder, Optimizable)]
31pub struct QARater {
32    #[parameter]
33    #[builder(default = Predict::new(QASignature::new()))]
34    pub answerer: Predict,
35}
36
37impl Module for QARater {
38    async fn forward(&self, inputs: Example) -> Result<Prediction> {
39        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
40
41        Ok(answerer_prediction)
42    }
43}
44
45impl Evaluator for QARater {
46    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
47        let answer = example.data.get("answer").unwrap().clone();
48        let prediction = prediction.data.get("answer").unwrap().clone();
49        println!("Answer: {answer}");
50        println!("Prediction: {prediction}");
51        if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
52            1.0
53        } else {
54            0.0
55        }
56    }
57}
58
59#[tokio::main]
60async fn main() -> anyhow::Result<()> {
61    configure(
62        LM::builder()
63            .model("openai:gpt-4o-mini".to_string())
64            .build()
65            .await
66            .unwrap(),
67        ChatAdapter {},
68    );
69
70    let examples = DataLoader::load_hf(
71        "hotpotqa/hotpot_qa",
72        vec!["question".to_string()],
73        vec!["answer".to_string()],
74        "fullwiki",
75        "validation",
76        true,
77    )?[..10]
78        .to_vec();
79
80    let mut rater = QARater::builder().build();
81    let optimizer = COPRO::builder().breadth(10).depth(1).build();
82
83    println!("Rater: {:?}", rater.answerer.get_signature().instruction());
84
85    optimizer.compile(&mut rater, examples.clone()).await?;
86
87    println!("Rater: {:?}", rater.answerer.get_signature().instruction());
88
89    Ok(())
90}