06_oai_compatible_models_batch/
06-oai-compatible-models-batch.rs

/*
Script to run a simple QA-and-rating pipeline in batch mode against
OpenAI-compatible model providers (Anthropic and Gemini).

Requires the ANTHROPIC_API_KEY and GEMINI_API_KEY environment variables.

Run with:
```
cargo run --example 06-oai-compatible-models-batch
```
*/

use anyhow::Result;
use bon::Builder;
use dspy_rs::{
    ChatAdapter, Example, LM, LMConfig, Module, Predict, Prediction, Predictor, Signature,
    configure, example, hashmap, prediction,
};
use secrecy::SecretString;

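// QA signature; `cot` asks the model to reason step by step (chain of
// thought) before producing `answer`.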
#[Signature(cot)]
struct QASignature {
    #[input]
    pub question: String,

    #[output]
    pub answer: String,
}

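// Judge signature: given a question and an answer, emit a 1-10 quality rating.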
#[Signature]
struct RateSignature {
    /// Rate the answer on a scale of 1 (very bad) to 10 (very good)

    #[input]
    pub question: String,

    #[input]
    pub answer: String,

    #[output]
    pub rating: i8,
}

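// Two-stage module: `answerer` generates an answer, then `rater` scores it.
// The builder defaults mean `QARater::builder().build()` wires up both predictors.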
#[derive(Builder)]
pub struct QARater {
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,
    #[builder(default = Predict::new(RateSignature::new()))]
    pub rater: Predict,
}

impl Module for QARater {
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        // Stage 1: answer the question.
        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;

        let question = inputs.data.get("question").unwrap().clone();
        let answer = answerer_prediction.data.get("answer").unwrap().clone();
        let answer_lm_usage = answerer_prediction.lm_usage;

        // Stage 2: feed the question/answer pair to the rater.
        let inputs = Example::new(
            hashmap! {
                "answer".to_string() => answer.clone(),
                "question".to_string() => question.clone()
            },
            vec!["answer".to_string(), "question".to_string()],
            vec![],
        );
        let rating_prediction = self.rater.forward(inputs).await?;
        let rating_lm_usage = rating_prediction.lm_usage;

        // Combine outputs and accumulate token usage from both LM calls.
        Ok(prediction! {
            "answer" => answer,
            "question" => question,
            "rating" => rating_prediction.data.get("rating").unwrap().clone(),
        }
        .set_lm_usage(answer_lm_usage + rating_lm_usage))
    }
}

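// The same module runs against two providers below; only the API key and the
// model string handed to the shared OpenAI-compatible `LM` client change.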
#[tokio::main]
async fn main() {
    // Anthropic
    configure(
        LM::builder()
            .api_key(SecretString::from(
                std::env::var("ANTHROPIC_API_KEY").unwrap(),
            ))
            .config(LMConfig {
                model: "anthropic/claude-sonnet-4-20250514".to_string(),
                ..LMConfig::default()
            })
            .build(),
        ChatAdapter,
    );

    let examples = vec![
        example! {
            "question": "input" => "What is the capital of France?",
        },
        example! {
            "question": "input" => "What is the capital of Germany?",
        },
        example! {
            "question": "input" => "What is the capital of Italy?",
        },
    ];

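    // `batch` runs `forward` over every example concurrently (2 at a time here,
    // with progress display enabled) and collects the predictions.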
    let qa_rater = QARater::builder().build();
    let prediction = qa_rater.batch(examples.clone(), 2, true).await.unwrap();
    println!("Anthropic: {prediction:?}");

    // Gemini
    configure(
        LM::builder()
            .api_key(SecretString::from(std::env::var("GEMINI_API_KEY").unwrap()))
            .config(LMConfig {
                model: "google/gemini-2.0-flash".to_string(),
                ..LMConfig::default()
            })
            .build(),
        ChatAdapter,
    );

    let prediction = qa_rater.batch(examples, 2, true).await.unwrap();
    println!("Gemini: {prediction:?}");
}