1use crate as dspy_rs;
2use crate::{
16 Evaluator, Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
17 example, get_lm,
18};
19use anyhow::{Context, Result};
20use bon::Builder;
21use dsrs_macros::Signature;
22use std::sync::Arc;
23
// Meta-signature for Stage 1 of MIPROv2: prompts an LM to describe what the
// program under optimization does, given its field layout and sample traces.
// Plain `//` comments are used here (not `///`) so no doc attributes are
// attached that the #[Signature] macro could interpret as an instruction —
// NOTE(review): confirm how the macro treats doc comments.
#[Signature]
struct GenerateProgramDescription {
    #[input(desc = "The task signature showing input and output fields")]
    pub signature_fields: String,

    #[input(desc = "Example input-output traces from the program")]
    pub example_traces: String,

    #[output(desc = "A clear description of what the program does")]
    pub program_description: String,
}
41
// Meta-signature for Stage 2 of MIPROv2: prompts an LM to write an optimized
// instruction, given the program description, example traces, and a list of
// prompting best practices. Plain `//` comments are used (not `///`) so the
// #[Signature] macro's view of the struct is unchanged — NOTE(review):
// confirm how the macro treats doc comments.
#[Signature]
struct GenerateInstructionFromTips {
    #[input(desc = "Description of what the program should do")]
    pub program_description: String,

    #[input(desc = "Example input-output traces showing desired behavior")]
    pub example_traces: String,

    #[input(desc = "Best practices and tips for writing effective prompts")]
    pub prompting_tips: String,

    #[output(desc = "An optimized instruction for the language model")]
    pub instruction: String,
}
60
/// One recorded execution of the module: the example fed in, the prediction
/// produced, and (optionally) the metric score for that pair.
#[derive(Clone, Debug)]
pub struct Trace {
    /// The example that was supplied to `Module::forward`.
    pub inputs: Example,
    /// The prediction the module returned for `inputs`.
    pub outputs: Prediction,
    /// Metric score for this input/output pair, if it was evaluated.
    pub score: Option<f32>,
}
75
76impl Trace {
77 pub fn new(inputs: Example, outputs: Prediction, score: Option<f32>) -> Self {
79 Self {
80 inputs,
81 outputs,
82 score,
83 }
84 }
85
86 pub fn format_for_prompt(&self) -> String {
88 let mut result = String::new();
89 result.push_str("Input:\n");
90
91 for (key, value) in &self.inputs.data {
92 result.push_str(&format!(" {}: {}\n", key, value));
93 }
94
95 result.push_str("Output:\n");
96 for (key, value) in &self.outputs.data {
97 result.push_str(&format!(" {}: {}\n", key, value));
98 }
99
100 if let Some(score) = self.score {
101 result.push_str(&format!("Score: {:.3}\n", score));
102 }
103
104 result
105 }
106}
107
/// A candidate prompt configuration: an instruction paired with demo
/// examples, plus the evaluation score it earned (0.0 until evaluated).
#[derive(Clone, Debug)]
pub struct PromptCandidate {
    /// The instruction text proposed for the predictor's signature.
    pub instruction: String,
    // NOTE(review): demos are selected but never installed on a predictor
    // anywhere in this file — hence the dead_code allowance. Confirm whether
    // demo installation is implemented elsewhere.
    #[allow(dead_code)]
    pub demos: Vec<Example>,
    /// Minibatch evaluation score; attached via `with_score`.
    pub score: f32,
}
119
120impl PromptCandidate {
121 pub fn new(instruction: String, demos: Vec<Example>) -> Self {
123 Self {
124 instruction,
125 demos,
126 score: 0.0,
127 }
128 }
129
130 pub fn with_score(mut self, score: f32) -> Self {
132 self.score = score;
133 self
134 }
135}
136
/// A curated list of prompt-engineering best practices that is injected into
/// the instruction-generation prompt (Stage 2 of MIPROv2).
pub struct PromptingTips {
    /// Individual tips, rendered as a 1-indexed numbered list by
    /// `format_for_prompt`.
    pub tips: Vec<String>,
}
141
142impl PromptingTips {
143 pub fn default_tips() -> Self {
145 Self {
146 tips: vec![
147 "Use clear and specific language".to_string(),
148 "Provide context about the task domain".to_string(),
149 "Specify the desired output format".to_string(),
150 "Use chain-of-thought reasoning for complex tasks".to_string(),
151 "Include few-shot examples when helpful".to_string(),
152 "Break down complex instructions into steps".to_string(),
153 "Use role-playing (e.g., 'You are an expert...') when appropriate".to_string(),
154 "Specify constraints and edge cases".to_string(),
155 "Request explanations or reasoning when needed".to_string(),
156 "Use structured output formats (JSON, lists, etc.) when applicable".to_string(),
157 "Consider the model's strengths and limitations".to_string(),
158 "Be explicit about what to avoid or exclude".to_string(),
159 "Use positive framing (what to do vs. what not to do)".to_string(),
160 "Provide examples of both correct and incorrect outputs when useful".to_string(),
161 "Use delimiters or markers to separate different sections".to_string(),
162 ],
163 }
164 }
165
166 pub fn format_for_prompt(&self) -> String {
168 self.tips
169 .iter()
170 .enumerate()
171 .map(|(i, tip)| format!("{}. {}", i + 1, tip))
172 .collect::<Vec<_>>()
173 .join("\n")
174 }
175}
176
/// Configuration for the MIPROv2 optimizer (Multi-prompt Instruction
/// PRoposal Optimizer v2): generates execution traces, proposes candidate
/// instructions via an LM, and installs the best-scoring one per predictor.
#[derive(Builder)]
pub struct MIPROv2 {
    /// Number of candidate instructions to generate per predictor.
    #[builder(default = 10)]
    pub num_candidates: usize,

    /// Cap on bootstrapped demos. NOTE(review): not read anywhere in this
    /// file — confirm whether demo bootstrapping lives elsewhere.
    #[builder(default = 3)]
    pub max_bootstrapped_demos: usize,

    /// Number of top-scoring traces attached to each candidate as demos.
    #[builder(default = 3)]
    pub max_labeled_demos: usize,

    /// Number of optimization trials. NOTE(review): printed in `compile`
    /// but never used to drive a trial loop — confirm intended behavior.
    #[builder(default = 20)]
    pub num_trials: usize,

    /// Number of eval examples used when scoring each candidate.
    #[builder(default = 25)]
    pub minibatch_size: usize,

    /// Sampling temperature applied to the prompt model when generating
    /// candidate instructions (description generation uses a fixed 0.7).
    #[builder(default = 1.0)]
    pub temperature: f32,

    /// Optional dedicated LM for meta-prompting; when `None`, the global
    /// LM from `get_lm()` is used.
    pub prompt_model: Option<LM>,

    /// When true, print per-candidate scores during evaluation.
    #[builder(default = true)]
    pub track_stats: bool,

    /// RNG seed. NOTE(review): currently unused — minibatches are taken as
    /// the first N examples rather than sampled; confirm.
    pub seed: Option<u64>,
}
221
222impl MIPROv2 {
223 async fn generate_traces<M>(&self, module: &M, examples: &[Example]) -> Result<Vec<Trace>>
229 where
230 M: Module + Evaluator,
231 {
232 let mut traces = Vec::with_capacity(examples.len());
233
234 println!(
235 "Stage 1: Generating traces from {} examples",
236 examples.len()
237 );
238
239 for (idx, example) in examples.iter().enumerate() {
240 if idx % 10 == 0 {
241 println!(" Processing example {}/{}", idx + 1, examples.len());
242 }
243
244 let prediction = module
246 .forward(example.clone())
247 .await
248 .context("Failed to generate prediction for trace")?;
249
250 let score = module.metric(example, &prediction).await;
252
253 traces.push(Trace::new(example.clone(), prediction, Some(score)));
254 }
255
256 println!("Generated {} traces", traces.len());
257 Ok(traces)
258 }
259
260 pub fn select_best_traces(&self, traces: &[Trace], num_select: usize) -> Vec<Trace> {
262 let mut scored_traces: Vec<_> = traces
263 .iter()
264 .filter(|t| t.score.is_some())
265 .cloned()
266 .collect();
267
268 scored_traces.sort_by(|a, b| {
270 b.score
271 .partial_cmp(&a.score)
272 .unwrap_or(std::cmp::Ordering::Equal)
273 });
274
275 scored_traces.into_iter().take(num_select).collect()
276 }
277
278 async fn generate_program_description(
284 &self,
285 signature_desc: &str,
286 traces: &[Trace],
287 ) -> Result<String> {
288 let description_generator = Predict::new(GenerateProgramDescription::new());
289
290 let traces_str = traces
292 .iter()
293 .take(5) .map(|t| t.format_for_prompt())
295 .collect::<Vec<_>>()
296 .join("\n---\n");
297
298 let input = example! {
299 "signature_fields": "input" => signature_desc.to_string(),
300 "example_traces": "input" => traces_str,
301 };
302
303 let prediction = if let Some(mut pm) = self.prompt_model.clone() {
304 pm.temperature = 0.7;
305 description_generator
306 .forward_with_config(input, Arc::new(pm))
307 .await?
308 } else {
309 let lm = get_lm();
310 description_generator.forward_with_config(input, lm).await?
311 };
312
313 Ok(prediction
314 .data
315 .get("program_description")
316 .and_then(|v| v.as_str())
317 .unwrap_or("Generate accurate outputs for the given inputs.")
318 .to_string())
319 }
320
321 async fn generate_candidate_instructions(
323 &self,
324 program_description: &str,
325 traces: &[Trace],
326 num_candidates: usize,
327 ) -> Result<Vec<String>> {
328 let instruction_generator = Predict::new(GenerateInstructionFromTips::new());
329 let tips = PromptingTips::default_tips();
330
331 let traces_str = traces
333 .iter()
334 .take(8)
335 .map(|t| t.format_for_prompt())
336 .collect::<Vec<_>>()
337 .join("\n---\n");
338
339 println!(
340 "Stage 2: Generating {} candidate instructions",
341 num_candidates
342 );
343
344 let mut candidates = Vec::new();
345
346 for i in 0..num_candidates {
348 let input = example! {
349 "program_description": "input" => program_description.to_string(),
350 "example_traces": "input" => traces_str.clone(),
351 "prompting_tips": "input" => tips.format_for_prompt(),
352 };
353
354 let result = if let Some(mut pm) = self.prompt_model.clone() {
355 pm.temperature = self.temperature;
356 instruction_generator
357 .forward_with_config(input, Arc::new(pm))
358 .await
359 } else {
360 let lm = get_lm();
361 instruction_generator.forward_with_config(input, lm).await
362 };
363
364 if let Ok(pred) = result
365 && let Some(instruction) = pred.data.get("instruction").and_then(|v| v.as_str())
366 {
367 candidates.push(instruction.to_string());
368 }
369
370 if (i + 1) % 3 == 0 || i == num_candidates - 1 {
371 println!(
372 " Generated {}/{} candidates",
373 candidates.len(),
374 num_candidates
375 );
376 }
377 }
378
379 println!(
380 "Generated {} total candidate instructions",
381 candidates.len()
382 );
383 Ok(candidates)
384 }
385
386 pub fn create_prompt_candidates(
388 &self,
389 instructions: Vec<String>,
390 traces: &[Trace],
391 ) -> Vec<PromptCandidate> {
392 let best_traces = self.select_best_traces(traces, self.max_labeled_demos);
393 let demo_examples: Vec<Example> = best_traces.into_iter().map(|t| t.inputs).collect();
394
395 instructions
396 .into_iter()
397 .map(|inst| PromptCandidate::new(inst, demo_examples.clone()))
398 .collect()
399 }
400
401 async fn evaluate_candidate<M>(
407 &self,
408 module: &mut M,
409 candidate: &PromptCandidate,
410 eval_examples: &[Example],
411 predictor_name: &str,
412 ) -> Result<f32>
413 where
414 M: Module + Optimizable + Evaluator,
415 {
416 {
418 let mut params = module.parameters();
419 if let Some(predictor) = params.get_mut(predictor_name) {
420 predictor.update_signature_instruction(candidate.instruction.clone())?;
421
422 }
425 }
426
427 let minibatch: Vec<Example> = eval_examples
429 .iter()
430 .take(self.minibatch_size)
431 .cloned()
432 .collect();
433
434 let score = module.evaluate(minibatch).await;
435 Ok(score)
436 }
437
438 async fn evaluate_and_select_best<M>(
440 &self,
441 module: &mut M,
442 candidates: Vec<PromptCandidate>,
443 eval_examples: &[Example],
444 predictor_name: &str,
445 ) -> Result<PromptCandidate>
446 where
447 M: Module + Optimizable + Evaluator,
448 {
449 println!(
450 "Stage 3: Evaluating {} candidates on minibatch of {} examples",
451 candidates.len(),
452 self.minibatch_size.min(eval_examples.len())
453 );
454
455 let mut evaluated_candidates = Vec::new();
456
457 for (idx, candidate) in candidates.into_iter().enumerate() {
458 println!(" Evaluating candidate {}/{}", idx + 1, self.num_candidates);
459
460 let score = self
461 .evaluate_candidate(module, &candidate, eval_examples, predictor_name)
462 .await?;
463
464 evaluated_candidates.push(candidate.with_score(score));
465
466 if self.track_stats {
467 println!(" Score: {:.3}", score);
468 }
469 }
470
471 let best = evaluated_candidates
473 .into_iter()
474 .max_by(|a, b| {
475 a.score
476 .partial_cmp(&b.score)
477 .unwrap_or(std::cmp::Ordering::Equal)
478 })
479 .context("No candidates to evaluate")?;
480
481 println!("Best candidate score: {:.3}", best.score);
482 Ok(best)
483 }
484
485 pub fn format_signature_fields(&self, signature: &dyn crate::core::MetaSignature) -> String {
491 let mut result = String::new();
492
493 result.push_str("Input Fields:\n");
494 if let Some(obj) = signature.input_fields().as_object() {
495 for (name, field) in obj {
496 let desc = field
497 .get("desc")
498 .and_then(|v| v.as_str())
499 .unwrap_or("No description");
500 result.push_str(&format!(" - {}: {}\n", name, desc));
501 }
502 }
503
504 result.push_str("\nOutput Fields:\n");
505 if let Some(obj) = signature.output_fields().as_object() {
506 for (name, field) in obj {
507 let desc = field
508 .get("desc")
509 .and_then(|v| v.as_str())
510 .unwrap_or("No description");
511 result.push_str(&format!(" - {}: {}\n", name, desc));
512 }
513 }
514
515 result
516 }
517}
518
519impl Optimizer for MIPROv2 {
524 async fn compile<M>(&self, module: &mut M, trainset: Vec<Example>) -> Result<()>
525 where
526 M: Module + Optimizable + Evaluator,
527 {
528 println!("\n=== MIPROv2 Optimization Started ===");
529 println!("Configuration:");
530 println!(" Candidates: {}", self.num_candidates);
531 println!(" Trials: {}", self.num_trials);
532 println!(" Minibatch size: {}", self.minibatch_size);
533 println!(" Training examples: {}", trainset.len());
534
535 let predictor_names: Vec<String> = module.parameters().keys().cloned().collect();
537
538 if predictor_names.is_empty() {
539 return Err(anyhow::anyhow!("No optimizable parameters found in module"));
540 }
541
542 println!(
543 " Optimizing {} predictor(s): {:?}\n",
544 predictor_names.len(),
545 predictor_names
546 );
547
548 for predictor_name in predictor_names {
550 println!("--- Optimizing predictor: {} ---", predictor_name);
551
552 let signature_desc = {
554 let params = module.parameters();
555 if let Some(predictor) = params.get(&predictor_name) {
556 self.format_signature_fields(predictor.get_signature())
557 } else {
558 continue;
559 }
560 };
561
562 let traces = self.generate_traces(module, &trainset).await?;
564
565 let program_description = self
567 .generate_program_description(&signature_desc, &traces)
568 .await?;
569
570 println!("Generated program description: {}", program_description);
571
572 let instructions = self
573 .generate_candidate_instructions(&program_description, &traces, self.num_candidates)
574 .await?;
575
576 let candidates = self.create_prompt_candidates(instructions, &traces);
577
578 let best_candidate = self
580 .evaluate_and_select_best(module, candidates, &trainset, &predictor_name)
581 .await?;
582
583 {
585 let mut params = module.parameters();
586 if let Some(predictor) = params.get_mut(&predictor_name) {
587 predictor.update_signature_instruction(best_candidate.instruction.clone())?;
588 }
591 }
592
593 println!(
594 "✓ Optimized {} with score {:.3}",
595 predictor_name, best_candidate.score
596 );
597 println!(" Instruction: {}\n", best_candidate.instruction);
598 }
599
600 println!("=== MIPROv2 Optimization Complete ===\n");
601 Ok(())
602 }
603}