dspy_rs/core/lm/mod.rs

pub mod chat;
pub mod client_registry;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use usage::*;

use anyhow::Result;
use rig::{completion::AssistantContent, message::ToolCall, message::ToolChoice, tool::ToolDyn};

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};

#[derive(Clone, Debug)]
pub struct LMResponse {
    /// Assistant message chosen by the provider.
    pub output: Message,
    /// Token usage reported by the provider for this call.
    pub usage: LmUsage,
    /// Chat history including the freshly appended assistant response.
    pub chat: Chat,
    /// Tool calls requested by the provider during the tool loop.
    pub tool_calls: Vec<ToolCall>,
    /// Stringified result of each tool execution performed during the tool loop.
    pub tool_executions: Vec<String>,
}

#[derive(Builder)]
#[builder(finish_fn(vis = "", name = __internal_build))]
pub struct LM {
    pub base_url: Option<String>,
    pub api_key: Option<String>,
    #[builder(default = "openai:gpt-4o-mini".to_string())]
    pub model: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = 10)]
    pub max_tool_iterations: u32,
    #[builder(default = false)]
    pub cache: bool,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
    #[builder(skip)]
    client: Option<Arc<LMClient>>,
}

impl Default for LM {
    /// Builds an [`LM`] with the builder defaults by blocking on the async builder.
    ///
    /// Note: this blocks on the current Tokio runtime handle, so it panics if
    /// no runtime is available or if it is called from inside an async task.
    fn default() -> Self {
        tokio::runtime::Handle::current().block_on(async { Self::builder().build().await.unwrap() })
    }
}

impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            base_url: self.base_url.clone(),
            api_key: self.api_key.clone(),
            model: self.model.clone(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            max_tool_iterations: self.max_tool_iterations,
            cache: self.cache,
            cache_handler: self.cache_handler.clone(),
            client: self.client.clone(),
        }
    }
}

impl LM {
    /// Finalizes construction of an [`LM`], initializing the HTTP client and
    /// optional response cache based on provided parameters.
    ///
    /// Supports 3 build cases:
    /// 1. OpenAI-compatible with auth: `base_url` + `api_key` provided
    ///    → Uses OpenAI client with custom base URL
    /// 2. Local OpenAI-compatible: `base_url` only (no `api_key`)
    ///    → Uses OpenAI client for vLLM/local servers (dummy key)
    /// 3. Provider via model string: no `base_url`, model in "provider:model" format
    ///    → Uses provider-specific client (openai, anthropic, gemini, etc.)
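    ///
    /// A rough sketch of the three cases (the URLs, key, and model names below
    /// are illustrative placeholders, not defaults):
    ///
    /// ```ignore
    /// // Case 1: OpenAI-compatible gateway that requires a key
    /// let lm = LM::builder()
    ///     .base_url("https://gateway.example.com/v1".to_string())
    ///     .api_key("YOUR_API_KEY".to_string())
    ///     .model("gpt-4o-mini".to_string())
    ///     .build()
    ///     .await?;
    ///
    /// // Case 2: local vLLM / OpenAI-compatible server, no key needed
    /// let lm = LM::builder()
    ///     .base_url("http://localhost:8000/v1".to_string())
    ///     .model("my-local-model".to_string())
    ///     .build()
    ///     .await?;
    ///
    /// // Case 3: provider chosen from the "provider:model" string
    /// let lm = LM::builder()
    ///     .model("anthropic:claude-3-5-sonnet".to_string())
    ///     .build()
    ///     .await?;
    /// ```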
    async fn initialize_client(mut self) -> Result<Self> {
        // Determine which build case based on what's provided
        let client = match (&self.base_url, &self.api_key, &self.model) {
            // Case 1: OpenAI-compatible with authentication (base_url + api_key)
            // For custom OpenAI-compatible APIs that require API keys
            (Some(base_url), Some(api_key), _) => Arc::new(LMClient::from_openai_compatible(
                base_url,
                api_key,
                &self.model,
            )?),
            // Case 2: Local OpenAI-compatible server (base_url only, no api_key)
            // For vLLM, text-generation-inference, and other local OpenAI-compatible servers
            (Some(base_url), None, _) => Arc::new(LMClient::from_local(base_url, &self.model)?),
            // Case 3: Provider via model string (no base_url, model in "provider:model" format)
            // Uses provider-specific clients
            (None, api_key, model) if model.contains(':') => {
                Arc::new(LMClient::from_model_string(model, api_key.as_deref())?)
            }
            // Default case: assume OpenAI provider if no colon in model name
            (None, api_key, model) => {
                let model_str = if model.contains(':') {
                    model.to_string()
                } else {
                    format!("openai:{}", model)
                };
                Arc::new(LMClient::from_model_string(&model_str, api_key.as_deref())?)
            }
        };

        self.client = Some(client);

        // Initialize cache if enabled
        if self.cache && self.cache_handler.is_none() {
            self.cache_handler = Some(Arc::new(Mutex::new(ResponseCache::new().await)));
        }

        Ok(self)
    }

    pub async fn with_client(self, client: LMClient) -> Result<Self> {
        Ok(LM {
            client: Some(Arc::new(client)),
            ..self
        })
    }
}

// Implement build() for all builder states since optional fields don't require setting
impl<S: l_m_builder::State> LMBuilder<S> {
    /// Builds the LM instance with proper client initialization
    ///
    /// Supports 3 build cases:
    /// 1. OpenAI-compatible with auth: `base_url` + `api_key` provided
    /// 2. Local OpenAI-compatible: `base_url` only (for vLLM, etc.)
    /// 3. Provider via model string: model in "provider:model" format
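    ///
    /// Minimal usage sketch for case 3, relying on the builder defaults
    /// elsewhere (the temperature shown is just an example value):
    ///
    /// ```ignore
    /// let lm = LM::builder()
    ///     .model("openai:gpt-4o-mini".to_string())
    ///     .temperature(0.2)
    ///     .build()
    ///     .await?;
    /// ```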
    pub async fn build(self) -> Result<LM> {
        let lm = self.__internal_build();
        lm.initialize_client().await
    }
}

struct ToolLoopResult {
    message: Message,
    #[allow(unused)]
    chat_history: Vec<rig::message::Message>,
    tool_calls: Vec<ToolCall>,
    tool_executions: Vec<String>,
}

impl LM {
    async fn execute_tool_loop(
        &self,
        initial_tool_call: &rig::message::ToolCall,
        mut tools: Vec<Arc<dyn ToolDyn>>,
        tool_definitions: Vec<rig::completion::ToolDefinition>,
        mut chat_history: Vec<rig::message::Message>,
        system_prompt: String,
        accumulated_usage: &mut LmUsage,
    ) -> Result<ToolLoopResult> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;
        use rig::message::UserContent;

        let max_iterations = self.max_tool_iterations as usize;

        let mut tool_calls = Vec::new();
        let mut tool_executions = Vec::new();

        // Execute the first tool call
        let tool_name = &initial_tool_call.function.name;
        let args_str = initial_tool_call.function.arguments.to_string();

        let mut tool_result = format!("Tool '{}' not found", tool_name);
        for tool in &mut tools {
            let def = tool.definition("".to_string()).await;
            if def.name == *tool_name {
                // For now, just record the call; actual tool execution would
                // require knowing the concrete Args type
                let args_json: serde_json::Value =
                    serde_json::from_str(&args_str).unwrap_or_default();
                tool_result = format!("Called tool {} with args: {}", tool_name, args_json);
                tool_calls.push(initial_tool_call.clone());
                tool_executions.push(tool_result.clone());
                break;
            }
        }

        // Add initial tool call and result to history
        chat_history.push(rig::message::Message::Assistant {
            id: None,
            content: OneOrMany::one(rig::message::AssistantContent::ToolCall(
                initial_tool_call.clone(),
            )),
        });

        let tool_result_content = if let Some(call_id) = &initial_tool_call.call_id {
            UserContent::tool_result_with_call_id(
                initial_tool_call.id.clone(),
                call_id.clone(),
                OneOrMany::one(tool_result.into()),
            )
        } else {
            UserContent::tool_result(
                initial_tool_call.id.clone(),
                OneOrMany::one(tool_result.into()),
            )
        };

        chat_history.push(rig::message::Message::User {
            content: OneOrMany::one(tool_result_content),
        });

        // Now loop until we get a text response
        for _iteration in 1..max_iterations {
            let request = CompletionRequest {
                preamble: Some(system_prompt.clone()),
                chat_history: if chat_history.len() == 1 {
                    OneOrMany::one(chat_history.clone().into_iter().next().unwrap())
                } else {
                    OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty")
                },
                documents: Vec::new(),
                tools: tool_definitions.clone(),
                temperature: Some(self.temperature as f64),
                max_tokens: Some(self.max_tokens as u64),
                tool_choice: Some(ToolChoice::Auto),
                additional_params: None,
            };

            let response = self
                .client
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("LM client not initialized"))?
                .completion(request)
                .await?;

            accumulated_usage.prompt_tokens += response.usage.input_tokens;
            accumulated_usage.completion_tokens += response.usage.output_tokens;
            accumulated_usage.total_tokens += response.usage.total_tokens;

            match response.choice.first() {
                AssistantContent::Text(text) => {
                    return Ok(ToolLoopResult {
                        message: Message::assistant(&text.text),
                        chat_history,
                        tool_calls,
                        tool_executions,
                    });
                }
                AssistantContent::Reasoning(reasoning) => {
                    return Ok(ToolLoopResult {
                        message: Message::assistant(reasoning.reasoning.join("\n")),
                        chat_history,
                        tool_calls,
                        tool_executions,
                    });
                }
                AssistantContent::ToolCall(tool_call) => {
                    // Execute tool and continue
                    let tool_name = &tool_call.function.name;
                    let args_str = tool_call.function.arguments.to_string();

                    let mut tool_result = format!("Tool '{}' not found", tool_name);
                    for tool in &mut tools {
                        let def = tool.definition("".to_string()).await;
                        if def.name == *tool_name {
                            // For now, just indicate the tool was called
                            // Actual tool execution would require knowing the concrete Args type
                            let args_json: serde_json::Value =
                                serde_json::from_str(&args_str).unwrap_or_default();
                            tool_result =
                                format!("Called tool {} with args: {}", tool_name, args_json);
                            tool_calls.push(tool_call.clone());
                            tool_executions.push(tool_result.clone());
                            break;
                        }
                    }

                    chat_history.push(rig::message::Message::Assistant {
                        id: None,
                        content: OneOrMany::one(rig::message::AssistantContent::ToolCall(
                            tool_call.clone(),
                        )),
                    });

                    let tool_result_content = if let Some(call_id) = &tool_call.call_id {
                        UserContent::tool_result_with_call_id(
                            tool_call.id.clone(),
                            call_id.clone(),
                            OneOrMany::one(tool_result.into()),
                        )
                    } else {
                        UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(tool_result.into()),
                        )
                    };

                    chat_history.push(rig::message::Message::User {
                        content: OneOrMany::one(tool_result_content),
                    });
                }
            }
        }

        Err(anyhow::anyhow!("Max tool iterations reached"))
    }

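    /// Sends `messages` to the configured provider and returns the assistant
    /// output, the accumulated token usage, the updated chat history, and any
    /// tool activity. When `tools` are supplied and the provider requests a
    /// tool call, the internal tool loop runs for up to `max_tool_iterations`.
    ///
    /// Minimal usage sketch (assumes `lm` was built via `LM::builder()` and
    /// `chat` is a `Chat` constructed elsewhere; no tools are passed here):
    ///
    /// ```ignore
    /// let response = lm.call(chat, Vec::new()).await?;
    /// println!("{:?}", response.output);
    /// ```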
    pub async fn call(&self, messages: Chat, tools: Vec<Arc<dyn ToolDyn>>) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        let mut tool_definitions = Vec::new();
        for tool in &tools {
            tool_definitions.push(tool.definition("".to_string()).await);
        }

        // Build the completion request manually
        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system.clone()),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.clone().into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: tool_definitions.clone(),
            temperature: Some(self.temperature as f64),
            max_tokens: Some(self.max_tokens as u64),
            tool_choice: if !tool_definitions.is_empty() {
                Some(ToolChoice::Auto)
            } else {
                None
            },
            additional_params: None,
        };

        // Execute the completion using enum dispatch (zero-cost abstraction)
        let response = self
            .client
            .as_ref()
            .ok_or_else(|| {
                anyhow::anyhow!("LM client not initialized. Call build() on LMBuilder.")
            })?
            .completion(request)
            .await?;

        let mut accumulated_usage = LmUsage::from(response.usage);

        // Handle the response
        let mut tool_loop_result = None;
        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            AssistantContent::ToolCall(tool_call) if !tools.is_empty() => {
                // Only execute the tool loop if we have tools available
                let result = self
                    .execute_tool_loop(
                        &tool_call,
                        tools,
                        tool_definitions,
                        chat_history,
                        request_messages.system,
                        &mut accumulated_usage,
                    )
                    .await?;
                let message = result.message.clone();
                tool_loop_result = Some(result);
                message
            }
            AssistantContent::ToolCall(tool_call) => {
                // No tools available, just return a message indicating this
                let msg = format!(
                    "Tool call requested: {} with args: {}, but no tools available",
                    tool_call.function.name, tool_call.function.arguments
                );
                Message::assistant(&msg)
            }
        };

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage: accumulated_usage,
            chat: full_chat,
            tool_calls: tool_loop_result
                .as_ref()
                .map(|result| result.tool_calls.clone())
                .unwrap_or_default(),
            tool_executions: tool_loop_result
                .map(|result| result.tool_executions)
                .unwrap_or_default(),
        })
    }

    /// Returns the `n` most recent cached calls.
    ///
    /// Panics if caching is disabled for this `LM`.
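    ///
    /// Usage sketch (assumes the LM was built with caching enabled, e.g.
    /// `.cache(true)`):
    ///
    /// ```ignore
    /// let last_five = lm.inspect_history(5).await;
    /// ```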
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}

/// In-memory LM used for deterministic tests and examples.
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = true)]
    pub cache: bool,
    /// Cache backing storage shared with the real implementation.
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl DummyLM {
    /// Creates a new [`DummyLM`] with an enabled in-memory cache.
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 512,
            cache: true,
            cache_handler: Some(cache_handler),
        }
    }

    /// Mimics [`LM::call`] without hitting a remote provider.
    ///
    /// The provided `prediction` becomes the assistant output and is inserted
    /// into the shared cache when caching is enabled.
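    ///
    /// Minimal usage sketch (assumes an `example` and a `chat` constructed
    /// elsewhere; the prediction string is arbitrary):
    ///
    /// ```ignore
    /// let dummy = DummyLM::new().await;
    /// let response = dummy.call(example, chat, "canned answer".to_string()).await?;
    /// assert!(response.tool_calls.is_empty());
    /// ```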
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            // Spawn the cache insert operation to avoid deadlock
            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            // Send the result to the cache
            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction.clone(),
            },
            usage: LmUsage::default(),
            chat: full_chat,
            tool_calls: Vec::new(),
            tool_executions: Vec::new(),
        })
    }

    /// Returns cached entries just like [`LM::inspect_history`].
    ///
    /// Panics if no cache handler is configured.
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}