// 02_module_iteration_and_updation/02-module-iteration-and-updation.rs
use anyhow::{Context, Result};
use bon::Builder;
use dspy_rs::{
    Example, Module, Optimizable, Predict, Prediction, Predictor, Signature, hashmap, prediction,
};

// Question-answering signature: one input field `question`, one output field `answer`.
// NOTE(review): the `cot` argument presumably enables chain-of-thought output in the
// generated signature — confirm against the dspy_rs `Signature` macro docs.
// NOTE(review): plain `//` comments are used here because `///` doc comments may be
// read by the `Signature` macro (e.g. as instruction text) — confirm before converting.
#[Signature(cot)]
struct QASignature {
    // Input field; `QARater::forward` reads the same `"question"` key from its input example.
    #[input]
    pub question: String,

    // Output field; `QARater::forward` reads `"answer"` from the answerer's prediction.
    #[output]
    pub answer: String,
}

// Rating signature: takes a `question` and an `answer` as inputs and produces a
// single integer `rating` (stored as `i8`).
// NOTE(review): no instruction text or rating scale is visible in this source, so the
// expected range of `rating` is unspecified here — confirm the intended scale.
// NOTE(review): plain `//` comments are used because `///` doc comments may be read by
// the `Signature` macro — confirm before converting.
#[Signature]
struct RateSignature {
    // The question that was asked; filled from the `"question"` key by `QARater::forward`.
    #[input]
    pub question: String,

    // The answer to be rated; filled from the answerer's `"answer"` output.
    #[input]
    pub answer: String,

    // Read back by `QARater::forward` under the `"rating"` key.
    #[output]
    pub rating: i8,
}

/// Two-stage module: `answerer` answers a question, then `rater` scores the
/// resulting (question, answer) pair — see its `Module` impl.
///
/// Both predictors are `#[parameter]` fields; presumably that marker is what makes
/// them appear in `parameters()` from the `Optimizable` derive (as iterated in `main`).
#[derive(Builder, Optimizable)]
pub struct QARater {
    /// Answering predictor; the builder defaults it to `Predict::new(QASignature::new())`.
    #[parameter]
    #[builder(default = Predict::new(QASignature::new()))]
    pub answerer: Predict,

    /// Rating predictor; the builder defaults it to `Predict::new(RateSignature::new())`.
    #[parameter]
    #[builder(default = Predict::new(RateSignature::new()))]
    pub rater: Predict,
}

/// Module that nests two `QARater`s (`qa_outer`, `qa_inner`) next to a plain
/// `Predict` (`extra`).
///
/// `main` updates every entry returned by `parameters()` on this type and then prints
/// the instruction of each inner `Predict` (`qa_outer.answerer`, `qa_outer.rater`,
/// `qa_inner.answerer`, `qa_inner.rater`, `extra`).
///
/// NOTE(review): whether `parameters()` descends into the nested `QARater`s (one entry
/// per inner `Predict`) is decided by the `Optimizable` derive — confirm.
#[derive(Builder, Optimizable)]
pub struct NestedModule {
    /// First nested rater; the builder defaults it to `QARater::builder().build()`.
    #[parameter]
    #[builder(default = QARater::builder().build())]
    pub qa_outer: QARater,

    /// Second nested rater; the builder defaults it to `QARater::builder().build()`.
    #[parameter]
    #[builder(default = QARater::builder().build())]
    pub qa_inner: QARater,

    /// Standalone predictor; the builder defaults it to `Predict::new(QASignature::new())`.
    #[parameter]
    #[builder(default = Predict::new(QASignature::new()))]
    pub extra: Predict,
}

65impl Module for QARater {
66 async fn forward(&self, inputs: Example) -> Result<Prediction> {
67 let answerer_prediction = self.answerer.forward(inputs.clone()).await?;
68
69 let question = inputs.data.get("question").unwrap().clone();
70 let answer = answerer_prediction.data.get("answer").unwrap().clone();
71
72 let inputs = Example::new(
73 hashmap! {
74 "answer".to_string() => answer.clone(),
75 "question".to_string() => question.clone()
76 },
77 vec!["answer".to_string(), "question".to_string()],
78 vec![],
79 );
80 let rating_prediction = self.rater.forward(inputs).await?;
81 Ok(prediction! {
82 "answer"=> answer,
83 "question"=> question,
84 "rating"=> rating_prediction.data.get("rating").unwrap().clone(),
85 }
86 .set_lm_usage(rating_prediction.lm_usage))
87 }
88}
89
90#[tokio::main]
91async fn main() {
92 let mut qa_rater = QARater::builder().build();
94 for (name, param) in qa_rater.parameters() {
95 param
96 .update_signature_instruction("Updated instruction for ".to_string() + &name)
97 .unwrap();
98 }
99 println!(
100 "single.answerer -> {}",
101 qa_rater.answerer.signature.instruction()
102 );
103 println!(
104 "single.rater -> {}",
105 qa_rater.rater.signature.instruction()
106 );
107
108 let mut nested = NestedModule::builder().build();
110 for (name, param) in nested.parameters() {
111 param
112 .update_signature_instruction("Deep updated: ".to_string() + &name)
113 .unwrap();
114 }
115
116 println!(
118 "nested.qa_outer.answerer -> {}",
119 nested.qa_outer.answerer.signature.instruction()
120 );
121 println!(
122 "nested.qa_outer.rater -> {}",
123 nested.qa_outer.rater.signature.instruction()
124 );
125 println!(
126 "nested.qa_inner.answerer -> {}",
127 nested.qa_inner.answerer.signature.instruction()
128 );
129 println!(
130 "nested.qa_inner.rater -> {}",
131 nested.qa_inner.rater.signature.instruction()
132 );
133 println!(
134 "nested.extra -> {}",
135 nested.extra.signature.instruction()
136 );
137}