// 06_oai_compatible_models_batch/06-oai-compatible-models-batch.rs
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    ChatAdapter, Example, LM, LMConfig, Module, Predict, Prediction, Predictor, Signature,
    configure, example, hashmap, prediction,
};
use secrecy::SecretString;

// Signature for the answering stage: maps an input `question` to an output `answer`.
// `QARater` uses it for its default `answerer` predictor.
//
// NOTE(review): `cot` presumably enables chain-of-thought (an extra reasoning step
// before `answer`) — confirm against the dspy_rs `Signature` macro docs.
// Plain `//` comments are used on purpose: the macro may read `///` doc comments
// into the generated prompt — confirm before converting these.
#[Signature(cot)]
struct QASignature {
    #[input]
    pub question: String,

    #[output]
    pub answer: String,
}

// Signature for the rating stage: given the original `question` and the `answer`
// produced for it, the model returns a numeric `rating`.
// `QARater` uses it for its default `rater` predictor.
//
// Plain `//` comments are used on purpose: the `#[Signature]` macro may read `///`
// doc comments into the generated prompt — confirm before converting these.
// NOTE(review): nothing visible here tells the model what scale `rating` is on;
// consider adding an explicit instruction so scores are comparable across providers.
#[Signature]
struct RateSignature {
    #[input]
    pub question: String,

    #[input]
    pub answer: String,

    // Declared as `i8`; presumably the adapter parses the model's reply into this
    // type — confirm how non-numeric or out-of-range replies are handled.
    #[output]
    pub rating: i8,
}

/// Two-stage module: one predictor answers a question, a second rates that answer.
///
/// Both predictors have builder defaults, so `QARater::builder().build()` yields a
/// ready-to-use instance; either can be overridden through the builder.
#[derive(Builder)]
pub struct QARater {
    /// Produces an `answer` from a `question`; defaults to a `QASignature` predictor.
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,
    /// Scores a question/answer pair; defaults to a `RateSignature` predictor.
    #[builder(default = Predict::new(RateSignature::new()))]
    pub rater: Predict,
}

49impl Module for QARater {
50 async fn forward(&self, inputs: Example) -> Result<Prediction> {
51 let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
52
53 let question = inputs.data.get("question").unwrap().clone();
54 let answer = answerer_prediction.data.get("answer").unwrap().clone();
55 let answer_lm_usage = answerer_prediction.lm_usage;
56
57 let inputs = Example::new(
58 hashmap! {
59 "answer".to_string() => answer.clone(),
60 "question".to_string() => question.clone()
61 },
62 vec!["answer".to_string(), "question".to_string()],
63 vec![],
64 );
65 let rating_prediction = self.rater.forward(inputs).await?;
66 let rating_lm_usage = rating_prediction.lm_usage;
67
68 Ok(prediction! {
69 "answer"=> answer,
70 "question"=> question,
71 "rating"=> rating_prediction.data.get("rating").unwrap().clone(),
72 }
73 .set_lm_usage(answer_lm_usage + rating_lm_usage))
74 }
75}
76
77#[tokio::main]
78async fn main() {
79 configure(
81 LM::builder()
82 .api_key(SecretString::from(
83 std::env::var("ANTHROPIC_API_KEY").unwrap(),
84 ))
85 .config(LMConfig {
86 model: "anthropic/claude-sonnet-4-20250514".to_string(),
87 ..LMConfig::default()
88 })
89 .build(),
90 ChatAdapter,
91 );
92
93 let example = vec![
94 example! {
95 "question": "input" => "What is the capital of France?",
96 },
97 example! {
98 "question": "input" => "What is the capital of Germany?",
99 },
100 example! {
101 "question": "input" => "What is the capital of Italy?",
102 },
103 ];
104
105 let qa_rater = QARater::builder().build();
106 let prediction = qa_rater.batch(example.clone(), 2, true).await.unwrap();
107 println!("Anthropic: {prediction:?}");
108
109 configure(
111 LM::builder()
112 .api_key(SecretString::from(std::env::var("GEMINI_API_KEY").unwrap()))
113 .config(LMConfig {
114 model: "google/gemini-2.0-flash".to_string(),
115 ..LMConfig::default()
116 })
117 .build(),
118 ChatAdapter,
119 );
120
121 let prediction = qa_rater.batch(example, 2, true).await.unwrap();
122 println!("Gemini: {prediction:?}");
123}