// 04_optimize_hotpotqa/04-optimize-hotpotqa.rs

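/*
Script that uses the COPRO optimizer to improve the instruction of the `answerer`
predictor in the `QARater` module, using a tiny (10-example) sample of the HotpotQA
validation split.

Requires the OPENAI_API_KEY environment variable to be set.

Run with:
```
cargo run --example 04-optimize-hotpotqa
```
*/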
9
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    COPRO, ChatAdapter, DataLoader, Evaluator, Example, LM, Module, Optimizable, Optimizer,
    Predict, Prediction, Predictor, Signature, configure,
};
use secrecy::SecretString;
17
18#[Signature(cot)]
19struct QASignature {
20    /// Concisely answer the question but be accurate. If it's a yes no question, answer with yes or no.
21
22    #[input]
23    pub question: String,
24
25    #[output(desc = "Answer in less than 5 words.")]
26    pub answer: String,
27}
28
29#[derive(Builder, Optimizable)]
30pub struct QARater {
31    #[parameter]
32    #[builder(default = Predict::new(QASignature::new()))]
33    pub answerer: Predict,
34}
35
36impl Module for QARater {
37    async fn forward(&self, inputs: Example) -> Result<Prediction> {
38        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
39
40        Ok(answerer_prediction)
41    }
42}
43
44impl Evaluator for QARater {
45    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
46        let answer = example.data.get("answer").unwrap().clone();
47        let prediction = prediction.data.get("answer").unwrap().clone();
48        println!("Answer: {answer}");
49        println!("Prediction: {prediction}");
50        if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
51            1.0
52        } else {
53            0.0
54        }
55    }
56}
57
58#[tokio::main]
59async fn main() -> anyhow::Result<()> {
60    configure(
61        LM::builder()
62            .api_key(SecretString::from(std::env::var("OPENAI_API_KEY")?))
63            .build(),
64        ChatAdapter {},
65    );
66
67    let examples = DataLoader::load_hf(
68        "hotpotqa/hotpot_qa",
69        vec!["question".to_string()],
70        vec!["answer".to_string()],
71        "fullwiki",
72        "validation",
73        true,
74    )?[..10]
75        .to_vec();
76
77    let mut rater = QARater::builder().build();
78    let optimizer = COPRO::builder().breadth(10).depth(1).build();
79
80    println!("Rater: {:?}", rater.answerer.get_signature().instruction());
81
82    optimizer.compile(&mut rater, examples.clone()).await?;
83
84    println!("Rater: {:?}", rater.answerer.get_signature().instruction());
85
86    Ok(())
87}