dspy_rs/core/lm/
mod.rs

pub mod chat;
pub mod client_registry;
pub mod config;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use config::*;
pub use usage::*;

use anyhow::Result;
use rig::completion::AssistantContent;

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};

#[derive(Clone, Debug)]
pub struct LMResponse {
    /// Assistant message chosen by the provider.
    pub output: Message,
    /// Token usage reported by the provider for this call.
    pub usage: LmUsage,
    /// Chat history including the freshly appended assistant response.
    pub chat: Chat,
}

pub struct LM {
    pub config: LMConfig,
    client: Arc<LMClient>,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl Default for LM {
    fn default() -> Self {
        // Run the async constructor on a temporary blocking runtime.
        // Note: this panics if called from inside an existing tokio runtime;
        // in async contexts, prefer `LM::new(LMConfig::default()).await`.
        tokio::runtime::Runtime::new()
            .expect("Failed to create tokio runtime")
            .block_on(Self::new(LMConfig::default()))
    }
}

impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            client: self.client.clone(),
            cache_handler: self.cache_handler.clone(),
        }
    }
}

impl LM {
    /// Creates a new LM with the given configuration.
    /// Uses enum dispatch for the provider client, so calls avoid dynamic dispatch.
    ///
    /// This is an async function because it initializes the cache handler when
    /// `config.cache` is `true`. For synchronous contexts where cache initialization
    /// is not needed, use `new_sync` instead.
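    ///
    /// # Example
    ///
    /// A minimal construction sketch (marked `ignore`: it assumes any credentials
    /// the default model needs are already available in the environment):
    ///
    /// ```ignore
    /// let lm = LM::new(LMConfig::default()).await;
    /// // The cache handler is created only when caching is enabled in the config.
    /// assert_eq!(lm.cache_handler.is_some(), lm.config.cache);
    /// ```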
    pub async fn new(config: LMConfig) -> Self {
        let client = LMClient::from_model_string(&config.model)
            .expect("Failed to create client from model string");

        let cache_handler = if config.cache {
            Some(Arc::new(Mutex::new(ResponseCache::new().await)))
        } else {
            None
        };

        Self {
            config,
            client: Arc::new(client),
            cache_handler,
        }
    }

    /// Executes a chat completion against the configured provider.
    ///
    /// `messages` must already be formatted as OpenAI-compatible chat turns.
    /// The call returns an [`LMResponse`] containing the assistant output,
    /// token usage, and chat history including the new response.
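    ///
    /// # Example
    ///
    /// A hedged sketch (marked `ignore`: it performs a network call, and the
    /// `Chat::default()` and `Message::user` constructors shown are assumptions
    /// about the `chat` module's API; adapt them to how chats are actually built):
    ///
    /// ```ignore
    /// let lm = LM::new(LMConfig::default()).await;
    ///
    /// let mut chat = Chat::default();
    /// chat.push_message(Message::user("Say hello in French."));
    ///
    /// let response = lm.call(chat).await?;
    /// println!("assistant: {:?}", response.output);
    /// println!("usage: {:?}", response.usage);
    /// ```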
    pub async fn call(&self, messages: Chat) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        // Build the completion request manually
        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: Some(self.config.temperature as f64),
            max_tokens: Some(self.config.max_tokens as u64),
            tool_choice: None,
            additional_params: None,
        };

        // Execute the completion using enum dispatch over the provider clients
        let response = self.client.completion(request).await?;

        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            AssistantContent::ToolCall(_tool_call) => {
                todo!("tool call responses are not yet handled")
            }
        };

        let usage = LmUsage::from(response.usage);

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage,
            chat: full_chat,
        })
    }

    /// Returns the `n` most recent cached calls.
    ///
    /// Panics if caching is disabled for this `LM`.
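    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`; it assumes an `lm` built with caching
    /// enabled, i.e. `cache: true` in its [`LMConfig`]):
    ///
    /// ```ignore
    /// let recent = lm.inspect_history(5).await;
    /// for call in &recent {
    ///     println!("prompt: {}", call.prompt);
    /// }
    /// ```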
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires caching to be enabled")
            .lock()
            .await
            .get_history(n)
            .await
            .expect("failed to read history from the response cache")
    }
}

/// In-memory LM used for deterministic tests and examples.
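///
/// # Example
///
/// A construction sketch (marked `ignore`; it assumes the field-named setters
/// and `build()` finisher that `bon`'s `Builder` derive generates by default):
///
/// ```ignore
/// let dummy = DummyLM::builder()
///     .api_key("test-key".to_string())
///     .build();
///
/// // `base_url` and `config` fall back to their builder defaults.
/// assert_eq!(dummy.base_url, "https://api.openai.com/v1");
/// ```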
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    /// Static configuration applied to stubbed responses.
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    /// Cache backing storage shared with the real implementation.
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl DummyLM {
    /// Creates a new [`DummyLM`] with an enabled in-memory cache.
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            config: LMConfig::default(),
            cache_handler: Some(cache_handler),
        }
    }

    /// Mimics [`LM::call`] without hitting a remote provider.
    ///
    /// The provided `prediction` becomes the assistant output and is inserted
    /// into the shared cache when caching is enabled.
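    ///
    /// # Example
    ///
    /// A test-style sketch (marked `ignore`: `example` and `chat` stand in for
    /// values built with the crate's `Example` and `Chat` APIs, which are not
    /// shown here):
    ///
    /// ```ignore
    /// let dummy = DummyLM::new().await;
    /// let response = dummy
    ///     .call(example, chat, "stubbed answer".to_string())
    ///     .await?;
    ///
    /// match response.output {
    ///     Message::Assistant { content } => assert_eq!(content, "stubbed answer"),
    ///     _ => unreachable!("DummyLM always returns an assistant message"),
    /// }
    /// ```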
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.config.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            // Spawn the cache insert operation to avoid deadlock
            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            // Send the result to the cache
            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction.clone(),
            },
            usage: LmUsage::default(),
            chat: full_chat,
        })
    }

    /// Returns cached entries just like [`LM::inspect_history`].
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires caching to be enabled")
            .lock()
            .await
            .get_history(n)
            .await
            .expect("failed to read history from the response cache")
    }
}