dspy_rs/predictors/mod.rs

1pub mod predict;
2
3pub use predict::*;
4
5use crate::{Example, LM, LmUsage, Prediction};
6use anyhow::Result;
7use futures::future::join_all;
8use std::sync::Arc;
9
10#[allow(async_fn_in_trait)]
11pub trait Predictor: Send + Sync {
12    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction>;
13    async fn forward_with_config(&self, inputs: Example, lm: Arc<LM>)
14    -> anyhow::Result<Prediction>;
15
16    async fn batch(&self, inputs: Vec<Example>) -> Result<Vec<Prediction>> {
17        let futures: Vec<_> = inputs
18            .iter()
19            .map(|input| self.forward(input.clone()))
20            .collect();
21        let predictions = join_all(futures)
22            .await
23            .into_iter()
24            .collect::<Result<Vec<Prediction>>>()?;
25        Ok(predictions)
26    }
27
28    async fn batch_with_config(
29        &self,
30        inputs: Vec<Example>,
31        lm: Arc<LM>,
32    ) -> Result<Vec<Prediction>> {
33        let futures: Vec<_> = inputs
34            .iter()
35            .map(|input| self.forward_with_config(input.clone(), lm.clone()))
36            .collect();
37        let predictions = join_all(futures)
38            .await
39            .into_iter()
40            .collect::<Result<Vec<Prediction>>>()?;
41        Ok(predictions)
42    }
43}
44
45pub struct DummyPredict;
46
47impl Predictor for DummyPredict {
48    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
49        Ok(Prediction::new(inputs.data, LmUsage::default()))
50    }
51
52    #[allow(unused_variables)]
53    async fn forward_with_config(
54        &self,
55        inputs: Example,
56        lm: Arc<LM>,
57    ) -> anyhow::Result<Prediction> {
58        Ok(Prediction::new(inputs.data, LmUsage::default()))
59    }
60}