dspy_rs/predictors/predict.rs

1use indexmap::IndexMap;
2use std::sync::Arc;
3use tokio::sync::Mutex;
4
5use crate::core::{MetaSignature, Optimizable};
6use crate::{ChatAdapter, Example, GLOBAL_SETTINGS, LM, Prediction, adapter::Adapter};
7
8pub struct Predict {
9    pub signature: Box<dyn MetaSignature>,
10}
11
12impl Predict {
13    pub fn new(signature: impl MetaSignature + 'static) -> Self {
14        Self {
15            signature: Box::new(signature),
16        }
17    }
18}
19
20impl super::Predictor for Predict {
21    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
22        let (adapter, lm) = {
23            let guard = GLOBAL_SETTINGS.read().unwrap();
24            let settings = guard.as_ref().unwrap();
25            (settings.adapter.clone(), Arc::clone(&settings.lm))
26        }; // guard is dropped here
27        adapter.call(lm, self.signature.as_ref(), inputs).await
28    }
29
30    async fn forward_with_config(
31        &self,
32        inputs: Example,
33        lm: Arc<Mutex<LM>>,
34    ) -> anyhow::Result<Prediction> {
35        ChatAdapter.call(lm, self.signature.as_ref(), inputs).await
36    }
37}
38
39impl Optimizable for Predict {
40    fn get_signature(&self) -> &dyn MetaSignature {
41        self.signature.as_ref()
42    }
43
44    fn parameters(&mut self) -> IndexMap<String, &mut (dyn Optimizable)> {
45        IndexMap::new()
46    }
47
48    fn update_signature_instruction(&mut self, instruction: String) -> anyhow::Result<()> {
49        let _ = self.signature.update_instruction(instruction);
50        Ok(())
51    }
52}