dspy_rs/predictors/
mod.rs1pub mod predict;
2
3pub use predict::*;
4
5use crate::{Example, LM, LmUsage, Prediction};
6use anyhow::Result;
7use futures::future::join_all;
8use std::sync::Arc;
9
10#[allow(async_fn_in_trait)]
11pub trait Predictor: Send + Sync {
12 async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction>;
13 async fn forward_with_config(&self, inputs: Example, lm: Arc<LM>)
14 -> anyhow::Result<Prediction>;
15
16 async fn batch(&self, inputs: Vec<Example>) -> Result<Vec<Prediction>> {
17 let futures: Vec<_> = inputs
18 .iter()
19 .map(|input| self.forward(input.clone()))
20 .collect();
21 let predictions = join_all(futures)
22 .await
23 .into_iter()
24 .collect::<Result<Vec<Prediction>>>()?;
25 Ok(predictions)
26 }
27
28 async fn batch_with_config(
29 &self,
30 inputs: Vec<Example>,
31 lm: Arc<LM>,
32 ) -> Result<Vec<Prediction>> {
33 let futures: Vec<_> = inputs
34 .iter()
35 .map(|input| self.forward_with_config(input.clone(), lm.clone()))
36 .collect();
37 let predictions = join_all(futures)
38 .await
39 .into_iter()
40 .collect::<Result<Vec<Prediction>>>()?;
41 Ok(predictions)
42 }
43}
44
/// A no-op [`Predictor`] whose predictions echo the input example's data unchanged.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyPredict;
46
47impl Predictor for DummyPredict {
48 async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
49 Ok(Prediction::new(inputs.data, LmUsage::default()))
50 }
51
52 #[allow(unused_variables)]
53 async fn forward_with_config(
54 &self,
55 inputs: Example,
56 lm: Arc<LM>,
57 ) -> anyhow::Result<Prediction> {
58 Ok(Prediction::new(inputs.data, LmUsage::default()))
59 }
60}