dspy_rs/predictors/
predict.rs

1use indexmap::IndexMap;
2use rig::tool::ToolDyn;
3use std::sync::Arc;
4
5use crate::core::{MetaSignature, Optimizable};
6use crate::{ChatAdapter, Example, GLOBAL_SETTINGS, LM, Prediction, adapter::Adapter};
7
8pub struct Predict {
9    pub signature: Box<dyn MetaSignature>,
10    pub tools: Vec<Arc<dyn ToolDyn>>,
11}
12
13impl Predict {
14    pub fn new(signature: impl MetaSignature + 'static) -> Self {
15        Self {
16            signature: Box::new(signature),
17            tools: vec![],
18        }
19    }
20
21    pub fn new_with_tools(
22        signature: impl MetaSignature + 'static,
23        tools: Vec<Box<dyn ToolDyn>>,
24    ) -> Self {
25        Self {
26            signature: Box::new(signature),
27            tools: tools.into_iter().map(Arc::from).collect(),
28        }
29    }
30
31    pub fn with_tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
32        self.tools = tools.into_iter().map(Arc::from).collect();
33        self
34    }
35
36    pub fn add_tool(mut self, tool: Box<dyn ToolDyn>) -> Self {
37        self.tools.push(Arc::from(tool));
38        self
39    }
40}
41
42impl super::Predictor for Predict {
43    async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
44        let (adapter, lm) = {
45            let guard = GLOBAL_SETTINGS.read().unwrap();
46            let settings = guard.as_ref().unwrap();
47            (settings.adapter.clone(), Arc::clone(&settings.lm))
48        }; // guard is dropped here
49        adapter
50            .call(lm, self.signature.as_ref(), inputs, self.tools.clone())
51            .await
52    }
53
54    async fn forward_with_config(
55        &self,
56        inputs: Example,
57        lm: Arc<LM>,
58    ) -> anyhow::Result<Prediction> {
59        ChatAdapter
60            .call(lm, self.signature.as_ref(), inputs, self.tools.clone())
61            .await
62    }
63}
64
65impl Optimizable for Predict {
66    fn get_signature(&self) -> &dyn MetaSignature {
67        self.signature.as_ref()
68    }
69
70    fn parameters(&mut self) -> IndexMap<String, &mut dyn Optimizable> {
71        IndexMap::new()
72    }
73
74    fn update_signature_instruction(&mut self, instruction: String) -> anyhow::Result<()> {
75        let _ = self.signature.update_instruction(instruction);
76        Ok(())
77    }
78}