// 04_optimize_hotpotqa/04-optimize-hotpotqa.rs
use anyhow::Result;
13use bon::Builder;
14use dspy_rs::{
15 COPRO, ChatAdapter, DataLoader, Evaluator, Example, LM, Module, Optimizable, Optimizer,
16 Predict, Prediction, Predictor, Signature, configure,
17};
18
19#[Signature(cot)]
20struct QASignature {
21 #[input]
24 pub question: String,
25
26 #[output(desc = "Answer in less than 5 words.")]
27 pub answer: String,
28}
29
30#[derive(Builder, Optimizable)]
31pub struct QARater {
32 #[parameter]
33 #[builder(default = Predict::new(QASignature::new()))]
34 pub answerer: Predict,
35}
36
37impl Module for QARater {
38 async fn forward(&self, inputs: Example) -> Result<Prediction> {
39 let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
40
41 Ok(answerer_prediction)
42 }
43}
44
45impl Evaluator for QARater {
46 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
47 let answer = example.data.get("answer").unwrap().clone();
48 let prediction = prediction.data.get("answer").unwrap().clone();
49 println!("Answer: {answer}");
50 println!("Prediction: {prediction}");
51 if answer.to_string().to_lowercase() == prediction.to_string().to_lowercase() {
52 1.0
53 } else {
54 0.0
55 }
56 }
57}
58
59#[tokio::main]
60async fn main() -> anyhow::Result<()> {
61 configure(
62 LM::builder()
63 .model("openai:gpt-4o-mini".to_string())
64 .build()
65 .await
66 .unwrap(),
67 ChatAdapter {},
68 );
69
70 let examples = DataLoader::load_hf(
71 "hotpotqa/hotpot_qa",
72 vec!["question".to_string()],
73 vec!["answer".to_string()],
74 "fullwiki",
75 "validation",
76 true,
77 )?[..10]
78 .to_vec();
79
80 let mut rater = QARater::builder().build();
81 let optimizer = COPRO::builder().breadth(10).depth(1).build();
82
83 println!("Rater: {:?}", rater.answerer.get_signature().instruction());
84
85 optimizer.compile(&mut rater, examples.clone()).await?;
86
87 println!("Rater: {:?}", rater.answerer.get_signature().instruction());
88
89 Ok(())
90}