01_simple/
01-simple.rs

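/*
Script to run a simple two-stage pipeline: one predictor answers a question,
then a second predictor rates that answer on a scale of 1 to 10.

Run with:
```
cargo run --example 01-simple
```
*/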
9
10use anyhow::Result;
11use bon::Builder;
12use dspy_rs::{
13    ChatAdapter, Example, LM, Module, Predict, Prediction, Predictor, Signature, configure,
14    example, prediction,
15};
16
// Signature for the answering step: takes a `question` and produces an `answer`.
// NOTE(review): the `cot` flag presumably adds a chain-of-thought / reasoning
// output to this signature — confirm against the dspy-rs `Signature` macro docs.
// Plain `//` comments are used deliberately: `///` lines become `#[doc]`
// attributes, which the `#[Signature]` proc macro receives as input.
#[Signature(cot)]
struct QASignature {
    // Question text supplied by the caller.
    #[input]
    pub question: String,

    // Answer text produced by the predictor.
    #[output]
    pub answer: String,
}
25
// Signature for the rating step: given a question and a candidate answer,
// produce a numeric `rating`.
// The `///` line at the top of the body is intentionally a doc attribute.
// NOTE(review): dspy-rs appears to use it as the instruction text sent to the
// LM for this signature — confirm before moving or rewording it.
// Everything else uses plain `//` because `///` lines become `#[doc]`
// attributes that the `#[Signature]` proc macro receives.
#[Signature]
struct RateSignature {
    /// Rate the answer on a scale of 1(very bad) to 10(very good)

    // The question that was asked.
    #[input]
    pub question: String,

    // The candidate answer to be rated.
    #[input]
    pub answer: String,

    // Declared as `i8`; the 1-10 range is only requested in the instruction
    // text above, not enforced by the type.
    #[output]
    pub rating: i8,
}
39
/// Two-stage pipeline: `answerer` produces an answer to a question, and
/// `rater` scores that answer.
///
/// Both predictors have builder defaults, so `QARater::builder().build()`
/// yields a ready-to-use instance.
#[derive(Builder)]
pub struct QARater {
    /// Answers the question; defaults to a `Predict` over `QASignature`.
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,
    /// Rates the answer; defaults to a `Predict` over `RateSignature`.
    #[builder(default = Predict::new(RateSignature::new()))]
    pub rater: Predict,
}
47
48impl Module for QARater {
49    async fn forward(&self, inputs: Example) -> Result<Prediction> {
50        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
51
52        let question = inputs.data.get("question").unwrap().clone();
53        let answer = answerer_prediction.data.get("answer").unwrap().clone();
54
55        let inputs = example! {
56            "question": "input" => question.clone(),
57            "answer": "output" => answer.clone()
58        };
59
60        let rating_prediction = self.rater.forward(inputs).await?;
61        Ok(prediction! {
62            "answer"=> answer,
63            "question"=> question,
64            "rating"=> rating_prediction.data.get("rating").unwrap().clone(),
65        }
66        .set_lm_usage(rating_prediction.lm_usage))
67    }
68}
69
70#[tokio::main]
71async fn main() -> Result<()> {
72    configure(
73        LM::builder()
74            .model("openai:gpt-4o-mini".to_string())
75            .build()
76            .await
77            .unwrap(),
78        ChatAdapter,
79    );
80
81    let example = example! {
82        "question": "input" => "What is the capital of France?",
83    };
84
85    let qa_rater = QARater::builder().build();
86    let prediction = qa_rater.forward(example).await.unwrap();
87    println!("{prediction:?}");
88
89    Ok(())
90}