dspy_rs/core/lm/
mod.rs

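//! OpenAI-compatible language-model clients.
//!
//! `LM` talks to a real endpoint (OpenAI, Anthropic, Groq, OpenRouter, ...)
//! through `async-openai`; `DummyLM` returns canned predictions for tests.
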
pub mod chat;
pub mod config;
pub mod usage;

pub use chat::*;
pub use config::*;
pub use usage::*;

use anyhow::{Context, Result};
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::{Client, config::OpenAIConfig};

use bon::Builder;
use secrecy::{ExposeSecret, SecretString};
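
/// One recorded exchange: the prompt chat, the model output, and the
/// configuration and signature that produced it.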
#[derive(Clone, Debug)]
pub struct LMResponse {
    pub chat: Chat,
    pub config: LMConfig,
    pub output: Message,
    pub signature: String,
}
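
/// Maps a provider prefix (the `groq` in `groq/<model>`) to its
/// OpenAI-compatible base URL, falling back to OpenRouter for unknown providers.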
fn get_base_url(provider: &str) -> String {
    match provider {
        "openai" => "https://api.openai.com/v1".to_string(),
        "anthropic" => "https://api.anthropic.com/v1".to_string(),
        "google" => "https://generativelanguage.googleapis.com/v1beta/openai".to_string(),
        "cohere" => "https://api.cohere.ai/compatibility/v1".to_string(),
        "groq" => "https://api.groq.com/openai/v1".to_string(),
        "openrouter" => "https://openrouter.ai/api/v1".to_string(),
        "qwen" => "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".to_string(),
        "together" => "https://api.together.xyz/v1".to_string(),
        "xai" => "https://api.x.ai/v1".to_string(),
        _ => "https://openrouter.ai/api/v1".to_string(),
    }
}
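
/// A language-model client for any OpenAI-compatible endpoint.
///
/// The underlying `async-openai` client is created lazily on the first
/// `call`, after an optional `provider/model` prefix is resolved to a base URL.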
#[derive(Clone, Builder)]
pub struct LM {
    pub api_key: SecretString,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    #[builder(default = Vec::new())]
    pub history: Vec<LMResponse>,
    client: Option<Client<OpenAIConfig>>,
}

impl LM {
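    /// Builds the `async-openai` client from the current API key and base URL.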
    fn setup_client(&mut self) {
        let config = OpenAIConfig::new()
            .with_api_key(self.api_key.expose_secret().to_string())
            .with_api_base(self.base_url.clone());

        self.client = Some(Client::with_config(config));
    }
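
    /// Sends `messages` to the configured model, records the exchange in
    /// `history`, and returns the first choice together with token usage.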
    pub async fn call(&mut self, messages: Chat, signature: &str) -> Result<(Message, LmUsage)> {
        if self.client.is_none() {
            // A "provider/model" id picks that provider's base URL; a bare
            // model id keeps the configured default.
            let model_str = self.config.model.clone();
            if let Some((provider, model_id)) = model_str.split_once('/') {
                self.config.model = model_id.to_string();
                self.base_url = get_base_url(provider);
            }
            self.setup_client();
        }

        let request_messages = messages.get_async_openai_messages();

        // Check if we're using a Gemini model
        let is_gemini = self.config.model.starts_with("gemini-");

        // Build the base request
        let mut builder = CreateChatCompletionRequestArgs::default();

        builder
            .model(self.config.model.clone())
            .messages(request_messages)
            .temperature(self.config.temperature)
            .top_p(self.config.top_p)
            .n(self.config.n)
            .max_tokens(self.config.max_tokens)
            .presence_penalty(self.config.presence_penalty);

        // Only add frequency_penalty, seed, and logit_bias for non-Gemini models
        if !is_gemini {
            builder
                .frequency_penalty(self.config.frequency_penalty)
                .seed(self.config.seed)
                .logit_bias(self.config.logit_bias.clone().unwrap_or_default());
        }

        let request = builder.build()?;

        let client = self.client.as_ref().expect("client initialized above");
        let response = client.chat().create(request).await?;
        let first_choice = Message::from(
            response.choices.first().context("LM returned no choices")?.message.clone(),
        );
        let usage = LmUsage::from(response.usage.context("LM returned no usage data")?);

        self.history.push(LMResponse {
            chat: messages.clone(),
            output: first_choice.clone(),
            config: self.config.clone(),
            signature: signature.to_string(),
        });

        Ok((first_choice, usage))
    }
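
    /// Returns the `n` most recent exchanges, newest first.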
    pub fn inspect_history(&self, n: usize) -> Vec<LMResponse> {
        self.history.iter().rev().take(n).cloned().collect()
    }
}
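
/// A stand-in for `LM` that returns a caller-supplied prediction instead of
/// calling a real endpoint; useful in tests.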
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: SecretString,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    #[builder(default = Vec::new())]
    pub history: Vec<LMResponse>,
}

impl DummyLM {
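    /// Records the exchange and echoes `prediction` back as the assistant
    /// message, with default (zero) usage.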
    pub async fn call(
        &mut self,
        messages: Chat,
        signature: &str,
        prediction: String,
    ) -> Result<(Message, LmUsage)> {
        let output = Message::Assistant {
            content: prediction,
        };

        self.history.push(LMResponse {
            chat: messages,
            output: output.clone(),
            config: self.config.clone(),
            signature: signature.to_string(),
        });

        Ok((output, LmUsage::default()))
    }
    pub fn inspect_history(&self, n: usize) -> Vec<LMResponse> {
        self.history.iter().rev().take(n).cloned().collect()
    }
}
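
// A minimal usage sketch, not part of the original module: it exercises the
// DummyLM canned-response path end to end. `Chat::default()` and the tokio
// test runtime are assumptions; substitute the real constructors from
// `chat.rs` if they differ.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn dummy_lm_echoes_prediction_and_records_history() {
        let mut lm = DummyLM::builder()
            .api_key(SecretString::from("test-key".to_string()))
            .build();

        let (output, _usage) = lm
            .call(Chat::default(), "question -> answer", "42".to_string())
            .await
            .unwrap();

        // The canned prediction comes back as an assistant message and the
        // exchange is visible through inspect_history.
        assert!(matches!(output, Message::Assistant { .. }));
        assert_eq!(lm.inspect_history(1).len(), 1);
    }
}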