1pub mod chat;
2pub mod config;
3pub mod usage;
4
5pub use chat::*;
6pub use config::*;
7pub use usage::*;
8
use anyhow::Result;
use async_openai::types::CreateChatCompletionRequestArgs;
use async_openai::{Client, config::OpenAIConfig};
use bon::Builder;
use secrecy::{ExposeSecret, ExposeSecretMut, SecretString};
15
/// One recorded LM exchange: the prompt sent, the settings it was sent
/// with, and the reply that came back. Stored in `LM::history`.
#[derive(Clone, Debug)]
pub struct LMResponse {
    /// The full chat (all prompt messages) that was sent to the model.
    pub chat: Chat,
    /// Snapshot of the model configuration used for this call.
    pub config: LMConfig,
    /// The assistant message produced by the model.
    pub output: Message,
    /// Caller-supplied tag identifying which signature issued the call.
    pub signature: String,
}
23
/// Maps a provider short-name to its OpenAI-compatible API base URL.
///
/// Unknown providers fall back to OpenRouter, which proxies many models.
fn get_base_url(provider: &str) -> String {
    // Match on static strings and convert once, instead of allocating in
    // every arm.
    let base = match provider {
        "openai" => "https://api.openai.com/v1",
        "anthropic" => "https://api.anthropic.com/v1",
        "google" => "https://generativelanguage.googleapis.com/v1beta/openai",
        "cohere" => "https://api.cohere.ai/compatibility/v1",
        "groq" => "https://api.groq.com/openai/v1",
        "openrouter" => "https://openrouter.ai/api/v1",
        "qwen" => "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        "together" => "https://api.together.xyz/v1",
        "xai" => "https://api.x.ai/v1",
        _ => "https://openrouter.ai/api/v1",
    };
    base.to_string()
}
38
/// An OpenAI-compatible chat language-model client with call history.
///
/// Construct via the generated builder (from `bon`); the HTTP client is
/// created lazily on the first `call`.
#[derive(Clone, Builder)]
pub struct LM {
    /// Provider API key; kept in a `SecretString` so it is not logged.
    pub api_key: SecretString,
    /// OpenAI-compatible endpoint. Overwritten on first call when the model
    /// is given as `"provider/model-id"` (see `get_base_url`).
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    /// Generation parameters (model name, temperature, etc.).
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    /// Record of every completed call, oldest first.
    #[builder(default = Vec::new())]
    pub history: Vec<LMResponse>,
    /// Lazily-initialized HTTP client; `None` until the first `call`.
    client: Option<Client<OpenAIConfig>>,
}
50
51impl LM {
52 fn setup_client(&mut self) {
53 let config = OpenAIConfig::new()
54 .with_api_key(self.api_key.expose_secret_mut().to_string())
55 .with_api_base(self.base_url.clone());
56
57 self.client = Some(Client::with_config(config));
58 }
59
60 pub async fn call(&mut self, messages: Chat, signature: &str) -> Result<(Message, LmUsage)> {
61 if self.client.is_none() {
62 if self.config.model.contains("/") {
63 let model_str = self.config.model.clone();
64 let (provider, model_id) = model_str.split_once("/").unwrap();
65 self.config.model = model_id.to_string();
66 self.base_url = get_base_url(provider);
67 }
68 self.setup_client();
69 }
70
71 let request_messages = messages.get_async_openai_messages();
72
73 let is_gemini = self.config.model.starts_with("gemini-");
75
76 let mut builder = CreateChatCompletionRequestArgs::default();
78
79 builder
80 .model(self.config.model.clone())
81 .messages(request_messages)
82 .temperature(self.config.temperature)
83 .top_p(self.config.top_p)
84 .n(self.config.n)
85 .max_tokens(self.config.max_tokens)
86 .presence_penalty(self.config.presence_penalty);
87
88 if !is_gemini {
90 builder
91 .frequency_penalty(self.config.frequency_penalty)
92 .seed(self.config.seed)
93 .logit_bias(self.config.logit_bias.clone().unwrap_or_default());
94 }
95
96 let request = builder.build()?;
97
98 let response = self.client.as_ref().unwrap().chat().create(request).await?;
99 let first_choice = Message::from(response.choices.first().unwrap().message.clone());
100 let usage = LmUsage::from(response.usage.unwrap());
101
102 self.history.push(LMResponse {
103 chat: messages.clone(),
104 output: first_choice.clone(),
105 config: self.config.clone(),
106 signature: signature.to_string(),
107 });
108
109 Ok((first_choice, usage))
110 }
111
112 pub fn inspect_history(&self, n: usize) -> Vec<LMResponse> {
113 self.history.iter().rev().take(n).cloned().collect()
114 }
115}
116
/// A stand-in LM for tests: records calls like `LM` but returns a
/// caller-supplied prediction instead of hitting any API.
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    /// Unused for network calls; present to mirror `LM`'s shape.
    pub api_key: SecretString,
    /// Mirrors `LM::base_url`; never contacted.
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    /// Configuration snapshot copied into each recorded response.
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    /// Record of every simulated call, oldest first.
    #[builder(default = Vec::new())]
    pub history: Vec<LMResponse>,
}
127
128impl DummyLM {
129 pub async fn call(
130 &mut self,
131 messages: Chat,
132 signature: &str,
133 prediction: String,
134 ) -> Result<(Message, LmUsage)> {
135 self.history.push(LMResponse {
136 chat: messages.clone(),
137 output: Message::Assistant {
138 content: prediction.clone(),
139 },
140 config: self.config.clone(),
141 signature: signature.to_string(),
142 });
143
144 Ok((
145 Message::Assistant {
146 content: prediction.clone(),
147 },
148 LmUsage::default(),
149 ))
150 }
151
152 pub fn inspect_history(&self, n: usize) -> Vec<LMResponse> {
153 self.history.iter().rev().take(n).cloned().collect()
154 }
155}