// 02_module_iteration_and_updation/
// 02-module-iteration-and-updation.rs

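/*
Script that iterates over a module's parameters and updates the signature instruction of
each one, first for a single module and then for a nested (module-in-module) one.

Run with:
```
cargo run --example 02-module-iteration-and-updation
```
*/
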
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    Example, Module, Optimizable, Predict, Prediction, Predictor, Signature, hashmap, prediction,
};

// Question-answering signature: one input field, `question`, and one output field, `answer`.
//
// NOTE(review): `cot` presumably enables chain-of-thought reasoning in the generated prompt;
// confirm against the dspy_rs `Signature` macro docs. Only plain `//` comments are used here
// because the macro may treat `///` doc comments as prompt text.
#[Signature(cot)]
struct QASignature {
    #[input]
    pub question: String,

    #[output]
    pub answer: String,
}

// Rating signature: takes the original `question` and a candidate `answer` as inputs and
// produces an integer `rating`.
//
// NOTE(review): the `///` line inside the struct reads like this signature's instruction and
// is presumably consumed by the `Signature` macro. Treat it as prompt text, not plain
// documentation, when editing it. For the same reason, only `//` comments are added here.
#[Signature]
struct RateSignature {
    /// Rate the answer on a scale of 1(very bad) to 10(very good)

    #[input]
    pub question: String,

    #[input]
    pub answer: String,

    // `i8` comfortably holds the 1-10 scale requested by the instruction above.
    #[output]
    pub rating: i8,
}

/// Two-stage module: `answerer` answers a question, then `rater` scores that answer.
///
/// Both predictors are marked `#[parameter]`. They are presumably the entries yielded by
/// `parameters()` from the `Optimizable` derive, which lets their instructions be updated
/// in place. When not set on the builder, each field defaults to a fresh `Predict` over its
/// signature.
#[derive(Builder, Optimizable)]
pub struct QARater {
    /// Predictor that answers the incoming question (built from `QASignature`).
    #[parameter]
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,

    /// Predictor that rates a question/answer pair (built from `RateSignature`).
    #[parameter]
    #[builder(default = Predict::new(RateSignature::new()))]
    pub rater: Predict,
}

/// Module-in-module example: two `QARater`s plus a standalone `Predict`.
///
/// Every field is marked `#[parameter]`. Iterating `parameters()` on this type is therefore
/// expected to reach the predictors inside `qa_outer` and `qa_inner` as well as `extra`.
/// NOTE(review): that relies on the `Optimizable` derive recursing into nested modules.
/// All fields default to freshly built instances when not set on the builder.
#[derive(Builder, Optimizable)]
pub struct NestedModule {
    /// First nested `QARater`.
    #[parameter]
    #[builder(default = QARater::builder().build())]
    pub qa_outer: QARater,

    /// Second nested `QARater`.
    #[parameter]
    #[builder(default = QARater::builder().build())]
    pub qa_inner: QARater,

    /// Standalone predictor built from `QASignature`.
    #[parameter]
    #[builder(default = Predict::new(QASignature::new()))]
    pub extra: Predict,
}

65impl Module for QARater {
66    async fn forward(&self, inputs: Example) -> Result<Prediction> {
67        let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
68
69        let question = inputs.data.get("question").unwrap().clone();
70        let answer = answerer_prediction.data.get("answer").unwrap().clone();
71
72        let inputs = Example::new(
73            hashmap! {
74                "answer".to_string() => answer.clone(),
75                "question".to_string() => question.clone()
76            },
77            vec!["answer".to_string(), "question".to_string()],
78            vec![],
79        );
80        let rating_prediction = self.rater.forward(inputs).await?;
81        Ok(prediction! {
82            "answer"=> answer,
83            "question"=> question,
84            "rating"=> rating_prediction.data.get("rating").unwrap().clone(),
85        }
86        .set_lm_usage(rating_prediction.lm_usage))
87    }
88}
89
90#[tokio::main]
91async fn main() {
92    // Single module test
93    let mut qa_rater = QARater::builder().build();
94    for (name, param) in qa_rater.parameters() {
95        param
96            .update_signature_instruction("Updated instruction for ".to_string() + &name)
97            .unwrap();
98    }
99    println!(
100        "single.answerer -> {}",
101        qa_rater.answerer.signature.instruction()
102    );
103    println!(
104        "single.rater    -> {}",
105        qa_rater.rater.signature.instruction()
106    );
107
108    // Nested module test
109    let mut nested = NestedModule::builder().build();
110    for (name, param) in nested.parameters() {
111        param
112            .update_signature_instruction("Deep updated: ".to_string() + &name)
113            .unwrap();
114    }
115
116    // Show nested updates (module-in-module)
117    println!(
118        "nested.qa_outer.answerer -> {}",
119        nested.qa_outer.answerer.signature.instruction()
120    );
121    println!(
122        "nested.qa_outer.rater    -> {}",
123        nested.qa_outer.rater.signature.instruction()
124    );
125    println!(
126        "nested.qa_inner.answerer -> {}",
127        nested.qa_inner.answerer.signature.instruction()
128    );
129    println!(
130        "nested.qa_inner.rater    -> {}",
131        nested.qa_inner.rater.signature.instruction()
132    );
133    println!(
134        "nested.extra    -> {}",
135        nested.extra.signature.instruction()
136    );
137}