dspy_rs/predictors/
predict.rs1use indexmap::IndexMap;
2use rig::tool::ToolDyn;
3use std::sync::Arc;
4
5use crate::core::{MetaSignature, Optimizable};
6use crate::{ChatAdapter, Example, GLOBAL_SETTINGS, LM, Prediction, adapter::Adapter};
7
8pub struct Predict {
9 pub signature: Box<dyn MetaSignature>,
10 pub tools: Vec<Arc<dyn ToolDyn>>,
11}
12
13impl Predict {
14 pub fn new(signature: impl MetaSignature + 'static) -> Self {
15 Self {
16 signature: Box::new(signature),
17 tools: vec![],
18 }
19 }
20
21 pub fn new_with_tools(
22 signature: impl MetaSignature + 'static,
23 tools: Vec<Box<dyn ToolDyn>>,
24 ) -> Self {
25 Self {
26 signature: Box::new(signature),
27 tools: tools.into_iter().map(Arc::from).collect(),
28 }
29 }
30
31 pub fn with_tools(mut self, tools: Vec<Box<dyn ToolDyn>>) -> Self {
32 self.tools = tools.into_iter().map(Arc::from).collect();
33 self
34 }
35
36 pub fn add_tool(mut self, tool: Box<dyn ToolDyn>) -> Self {
37 self.tools.push(Arc::from(tool));
38 self
39 }
40}
41
42impl super::Predictor for Predict {
43 async fn forward(&self, inputs: Example) -> anyhow::Result<Prediction> {
44 let (adapter, lm) = {
45 let guard = GLOBAL_SETTINGS.read().unwrap();
46 let settings = guard.as_ref().unwrap();
47 (settings.adapter.clone(), Arc::clone(&settings.lm))
48 }; adapter
50 .call(lm, self.signature.as_ref(), inputs, self.tools.clone())
51 .await
52 }
53
54 async fn forward_with_config(
55 &self,
56 inputs: Example,
57 lm: Arc<LM>,
58 ) -> anyhow::Result<Prediction> {
59 ChatAdapter
60 .call(lm, self.signature.as_ref(), inputs, self.tools.clone())
61 .await
62 }
63}
64
65impl Optimizable for Predict {
66 fn get_signature(&self) -> &dyn MetaSignature {
67 self.signature.as_ref()
68 }
69
70 fn parameters(&mut self) -> IndexMap<String, &mut dyn Optimizable> {
71 IndexMap::new()
72 }
73
74 fn update_signature_instruction(&mut self, instruction: String) -> anyhow::Result<()> {
75 let _ = self.signature.update_instruction(instruction);
76 Ok(())
77 }
78}