dspy_rs/core/lm/mod.rs

pub mod chat;
pub mod client_registry;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use usage::*;

use anyhow::Result;
use rig::completion::AssistantContent;

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};

#[derive(Clone, Debug)]
pub struct LMResponse {
    /// Assistant message chosen by the provider.
    pub output: Message,
    /// Token usage reported by the provider for this call.
    pub usage: LmUsage,
    /// Chat history including the freshly appended assistant response.
    pub chat: Chat,
}

#[derive(Builder)]
#[builder(finish_fn(vis = "", name = __internal_build))]
pub struct LM {
    pub base_url: Option<String>,
    pub api_key: Option<String>,
    /// Model identifier, usually in "provider:model" format.
    #[builder(default = "openai:gpt-4o-mini".to_string())]
    pub model: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    /// Whether calls are recorded in the response cache.
    #[builder(default = true)]
    pub cache: bool,
    /// Shared response cache; created during `build()` when `cache` is true.
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
    /// Provider client; populated during `build()`.
    #[builder(skip)]
    client: Option<Arc<LMClient>>,
}

impl Default for LM {
    fn default() -> Self {
        // NOTE: requires an active Tokio runtime and panics if called from
        // within an async context; prefer `LM::builder().build().await`.
        tokio::runtime::Handle::current()
            .block_on(async { Self::builder().build().await.unwrap() })
    }
}

impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            base_url: self.base_url.clone(),
            api_key: self.api_key.clone(),
            model: self.model.clone(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            cache: self.cache,
            cache_handler: self.cache_handler.clone(),
            client: self.client.clone(),
        }
    }
}

impl LM {
    /// Finalizes construction of an [`LM`], initializing the HTTP client and
    /// optional response cache based on the provided parameters.
    ///
    /// Supports 3 build cases:
    /// 1. OpenAI-compatible with auth: `base_url` + `api_key` provided
    ///    → Uses OpenAI client with custom base URL
    /// 2. Local OpenAI-compatible: `base_url` only (no `api_key`)
    ///    → Uses OpenAI client for vLLM/local servers (dummy key)
    /// 3. Provider via model string: no `base_url`, model in "provider:model" format
    ///    → Uses provider-specific client (openai, anthropic, gemini, etc.)
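    ///
    /// A minimal sketch of case 2, assuming the setters generated by the
    /// derived builder (not compiled here):
    ///
    /// ```rust,ignore
    /// // Local vLLM-style server: `base_url` only, no API key.
    /// let lm = LM::builder()
    ///     .base_url("http://localhost:8000/v1".to_string())
    ///     .model("my-local-model".to_string())
    ///     .build()
    ///     .await?;
    /// ```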
    async fn initialize_client(mut self) -> Result<Self> {
        // Determine which build case based on what's provided
        let client = match (&self.base_url, &self.api_key, &self.model) {
            // Case 1: OpenAI-compatible with authentication (base_url + api_key)
            // For custom OpenAI-compatible APIs that require API keys
            (Some(base_url), Some(api_key), _) => Arc::new(LMClient::from_openai_compatible(
                base_url,
                api_key,
                &self.model,
            )?),
            // Case 2: Local OpenAI-compatible server (base_url only, no api_key)
            // For vLLM, text-generation-inference, and other local OpenAI-compatible servers
            (Some(base_url), None, _) => Arc::new(LMClient::from_local(base_url, &self.model)?),
            // Case 3: Provider via model string (no base_url, model in "provider:model" format)
            // Uses provider-specific clients
            (None, api_key, model) if model.contains(':') => {
                Arc::new(LMClient::from_model_string(model, api_key.as_deref())?)
            }
            // Default case: no "provider:" prefix in the model name (the guarded
            // arm above handled that), so assume the OpenAI provider
            (None, api_key, model) => {
                let model_str = format!("openai:{}", model);
                Arc::new(LMClient::from_model_string(&model_str, api_key.as_deref())?)
            }
        };

        self.client = Some(client);

        // Initialize cache if enabled
        if self.cache && self.cache_handler.is_none() {
            self.cache_handler = Some(Arc::new(Mutex::new(ResponseCache::new().await)));
        }

        Ok(self)
    }
}

// build() is implemented for every builder state, since all fields are either
// optional or have defaults.
impl<S: l_m_builder::State> LMBuilder<S> {
    /// Builds the LM instance with proper client initialization.
    ///
    /// Supports 3 build cases:
    /// 1. OpenAI-compatible with auth: `base_url` + `api_key` provided
    /// 2. Local OpenAI-compatible: `base_url` only (for vLLM, etc.)
    /// 3. Provider via model string: model in "provider:model" format
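    ///
    /// A minimal sketch of case 3, assuming the setters generated by the
    /// derived builder and an `OPENAI_API_KEY` environment variable
    /// (not compiled here):
    ///
    /// ```rust,ignore
    /// // Provider selected through the "provider:model" string.
    /// let lm = LM::builder()
    ///     .model("openai:gpt-4o-mini".to_string())
    ///     .api_key(std::env::var("OPENAI_API_KEY")?)
    ///     .build()
    ///     .await?;
    /// ```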
    pub async fn build(self) -> Result<LM> {
        let lm = self.__internal_build();
        lm.initialize_client().await
    }
}

impl LM {
    /// Executes a chat completion against the configured provider.
    ///
    /// `messages` must already be formatted as OpenAI-compatible chat turns.
    /// The call returns an [`LMResponse`] containing the assistant output,
    /// token usage, and chat history including the new response.
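    ///
    /// A minimal sketch, assuming `lm` was built via `LM::builder()` and the
    /// `Chat` value was assembled elsewhere (not compiled here):
    ///
    /// ```rust,ignore
    /// async fn ask(lm: &LM, chat: Chat) -> anyhow::Result<()> {
    ///     let response = lm.call(chat).await?;
    ///     println!("assistant: {:?}", response.output);
    ///     println!("usage: {:?}", response.usage);
    ///     Ok(())
    /// }
    /// ```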
    pub async fn call(&self, messages: Chat) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        // Build the completion request manually
        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: Some(self.temperature as f64),
            max_tokens: Some(self.max_tokens as u64),
            tool_choice: None,
            additional_params: None,
        };

        // Execute the completion through the enum-dispatched provider client
        let response = self
            .client
            .as_ref()
            .ok_or_else(|| {
                anyhow::anyhow!("LM client not initialized. Call build() on LMBuilder.")
            })?
            .completion(request)
            .await?;

        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            AssistantContent::ToolCall(_tool_call) => {
                // Tool calls are not supported yet; surface an explicit error
                // instead of panicking.
                anyhow::bail!("tool calls are not yet supported by LM::call")
            }
        };

        let usage = LmUsage::from(response.usage);

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage,
            chat: full_chat,
        })
    }

    /// Returns the `n` most recent cached calls.
    ///
    /// Panics if caching is disabled for this `LM` or if reading the cache fails.
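    ///
    /// A minimal sketch (not compiled here):
    ///
    /// ```rust,ignore
    /// // Print the prompts of the five most recent cached calls.
    /// for entry in lm.inspect_history(5).await {
    ///     println!("{}", entry.prompt);
    /// }
    /// ```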
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}

/// In-memory LM used for deterministic tests and examples.
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = true)]
    pub cache: bool,
    /// Cache backing storage shared with the real implementation.
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl DummyLM {
    /// Creates a new [`DummyLM`] with an enabled in-memory cache.
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 512,
            cache: true,
            cache_handler: Some(cache_handler),
        }
    }

    /// Mimics [`LM::call`] without hitting a remote provider.
    ///
    /// The provided `prediction` becomes the assistant output and is inserted
    /// into the shared cache when caching is enabled.
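    ///
    /// A minimal sketch, assuming the `Example` and `Chat` values are built
    /// elsewhere (not compiled here):
    ///
    /// ```rust,ignore
    /// async fn fake_call(example: Example, chat: Chat) -> anyhow::Result<()> {
    ///     let dummy = DummyLM::new().await;
    ///     let response = dummy
    ///         .call(example, chat, "canned answer".to_string())
    ///         .await?;
    ///     assert!(matches!(response.output, Message::Assistant { .. }));
    ///     Ok(())
    /// }
    /// ```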
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            // Spawn the cache insert operation to avoid deadlock
            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            // Send the result to the cache
            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction.clone(),
            },
            usage: LmUsage::default(),
            chat: full_chat,
        })
    }

    /// Returns cached entries just like [`LM::inspect_history`].
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}