dspy_rs/predictors/
mod.rs1pub mod predict;
2
3pub use predict::*;
4
5use crate::{Example, LM, LmUsage, Prediction};
6use anyhow::Result;
7use futures::stream::{self, StreamExt};
8use std::sync::Arc;
9
10#[allow(async_fn_in_trait)]
11pub trait Predictor: Send + Sync {
12 async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction>;
13 async fn forward_with_config(&self, inputs: Example, lm: Arc<LM>)
14 -> anyhow::Result<Prediction>;
15
16 async fn batch(&self, inputs: Vec<Example>) -> Result<Vec<Prediction>> {
17 let indexed_results: Vec<(usize, Result<Prediction>)> =
18 stream::iter(inputs.into_iter().enumerate())
19 .map(|(idx, input)| async move {
20 let result = self.forward(input).await;
21 (idx, result)
22 })
23 .buffer_unordered(32) .collect()
25 .await;
26
27 let mut indexed_results = indexed_results;
29 indexed_results.sort_by_key(|(idx, _)| *idx);
30
31 let mut predictions = Vec::with_capacity(indexed_results.len());
33 for (_, result) in indexed_results {
34 predictions.push(result?);
35 }
36 Ok(predictions)
37 }
38
39 async fn batch_with_config(
40 &self,
41 inputs: Vec<Example>,
42 lm: Arc<LM>,
43 ) -> Result<Vec<Prediction>> {
44 let lm_ref = lm.clone();
45 let indexed_results: Vec<(usize, Result<Prediction>)> =
46 stream::iter(inputs.into_iter().enumerate())
47 .map(|(idx, input)| {
48 let lm_clone = lm_ref.clone();
49 async move {
50 let result = self.forward_with_config(input, lm_clone).await;
51 (idx, result)
52 }
53 })
54 .buffer_unordered(32) .collect()
56 .await;
57
58 let mut indexed_results = indexed_results;
60 indexed_results.sort_by_key(|(idx, _)| *idx);
61
62 let mut predictions = Vec::with_capacity(indexed_results.len());
64 for (_, result) in indexed_results {
65 predictions.push(result?);
66 }
67 Ok(predictions)
68 }
69}
70
/// A no-op predictor that echoes its input back as the prediction.
///
/// Each call returns a `Prediction` built from the input example's `data` with a
/// default `LmUsage`. The supplied LM, if any, is never consulted.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyPredict;
72
73impl Predictor for DummyPredict {
74 async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
75 Ok(Prediction::new(inputs.data, LmUsage::default()))
76 }
77
78 #[allow(unused_variables)]
79 async fn forward_with_config(
80 &self,
81 inputs: Example,
82 lm: Arc<LM>,
83 ) -> anyhow::Result<Prediction> {
84 Ok(Prediction::new(inputs.data, LmUsage::default()))
85 }
86}