dspy_rs/predictors/mod.rs

1pub mod predict;
2
3pub use predict::*;
4
5use crate::{Example, LM, LmUsage, Prediction};
6use anyhow::Result;
7use futures::stream::{self, StreamExt};
8use std::sync::Arc;
9
10#[allow(async_fn_in_trait)]
11pub trait Predictor: Send + Sync {
12    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction>;
13    async fn forward_with_config(&self, inputs: Example, lm: Arc<LM>)
14    -> anyhow::Result<Prediction>;
15
16    async fn batch(&self, inputs: Vec<Example>) -> Result<Vec<Prediction>> {
17        let indexed_results: Vec<(usize, Result<Prediction>)> =
18            stream::iter(inputs.into_iter().enumerate())
19                .map(|(idx, input)| async move {
20                    let result = self.forward(input).await;
21                    (idx, result)
22                })
23                .buffer_unordered(32) // Match MAX_CONCURRENCY from Evaluator
24                .collect()
25                .await;
26
27        // Sort results back to original order
28        let mut indexed_results = indexed_results;
29        indexed_results.sort_by_key(|(idx, _)| *idx);
30
31        // Collect predictions and handle errors
32        let mut predictions = Vec::with_capacity(indexed_results.len());
33        for (_, result) in indexed_results {
34            predictions.push(result?);
35        }
36        Ok(predictions)
37    }
38
39    async fn batch_with_config(
40        &self,
41        inputs: Vec<Example>,
42        lm: Arc<LM>,
43    ) -> Result<Vec<Prediction>> {
44        let lm_ref = lm.clone();
45        let indexed_results: Vec<(usize, Result<Prediction>)> =
46            stream::iter(inputs.into_iter().enumerate())
47                .map(|(idx, input)| {
48                    let lm_clone = lm_ref.clone();
49                    async move {
50                        let result = self.forward_with_config(input, lm_clone).await;
51                        (idx, result)
52                    }
53                })
54                .buffer_unordered(32) // Match MAX_CONCURRENCY from Evaluator
55                .collect()
56                .await;
57
58        // Sort results back to original order
59        let mut indexed_results = indexed_results;
60        indexed_results.sort_by_key(|(idx, _)| *idx);
61
62        // Collect predictions and handle errors
63        let mut predictions = Vec::with_capacity(indexed_results.len());
64        for (_, result) in indexed_results {
65            predictions.push(result?);
66        }
67        Ok(predictions)
68    }
69}
70
/// A placeholder predictor that turns its input example's data directly into
/// a prediction without calling an LM.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DummyPredict;
72
73impl Predictor for DummyPredict {
74    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
75        Ok(Prediction::new(inputs.data, LmUsage::default()))
76    }
77
78    #[allow(unused_variables)]
79    async fn forward_with_config(
80        &self,
81        inputs: Example,
82        lm: Arc<LM>,
83    ) -> anyhow::Result<Prediction> {
84        Ok(Prediction::new(inputs.data, LmUsage::default()))
85    }
86}