pub mod chat;
pub mod client_registry;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use usage::*;

use anyhow::Result;
use rig::completion::AssistantContent;

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};

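/// A single model completion: the assistant [`Message`], the token usage
/// reported by the provider, and the full [`Chat`] including the new reply.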
#[derive(Clone, Debug)]
pub struct LMResponse {
    pub output: Message,
    pub usage: LmUsage,
    pub chat: Chat,
}

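/// A configured language-model handle.
///
/// Construct it with the builder; the generated finishing function is hidden,
/// and the async [`LMBuilder::build`] wrapper initializes the underlying
/// client and, when caching is enabled, the response cache:
///
/// ```ignore
/// let lm = LM::builder()
///     .model("openai:gpt-4o-mini".to_string())
///     .temperature(0.2)
///     .build()
///     .await?;
/// ```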
#[derive(Builder)]
#[builder(finish_fn(vis = "", name = __internal_build))]
pub struct LM {
    pub base_url: Option<String>,
    pub api_key: Option<String>,
    #[builder(default = "openai:gpt-4o-mini".to_string())]
    pub model: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = true)]
    pub cache: bool,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
    #[builder(skip)]
    client: Option<Arc<LMClient>>,
}

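// Blocking `Default` shim for contexts that cannot `await`. Note that
// `Handle::current()` panics outside a Tokio runtime and `block_on` panics
// when called from within an async context, so this is only usable from
// synchronous code on a runtime thread (e.g. inside `spawn_blocking`);
// prefer `LM::builder().build().await` where possible.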
impl Default for LM {
    fn default() -> Self {
        tokio::runtime::Handle::current().block_on(async { Self::builder().build().await.unwrap() })
    }
}

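// Manual `Clone`: the client and cache handles are `Arc`s, so clones share
// the same underlying client and cache rather than deep-copying them.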
impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            base_url: self.base_url.clone(),
            api_key: self.api_key.clone(),
            model: self.model.clone(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            cache: self.cache,
            cache_handler: self.cache_handler.clone(),
            client: self.client.clone(),
        }
    }
}

impl LM {
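    /// Resolves the [`LMClient`] from the configured fields: an explicit
    /// `base_url` selects an OpenAI-compatible endpoint (with a key) or a
    /// local one (without), otherwise the model string is parsed as
    /// `"provider:model"`, defaulting the provider to OpenAI. Also creates
    /// the response cache when caching is enabled and no handler was given.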
    async fn initialize_client(mut self) -> Result<Self> {
        let client = match (&self.base_url, &self.api_key, &self.model) {
            // Explicit endpoint with credentials: OpenAI-compatible server.
            (Some(base_url), Some(api_key), _) => Arc::new(LMClient::from_openai_compatible(
                base_url,
                api_key,
                &self.model,
            )?),
            // Explicit endpoint without credentials: local server.
            (Some(base_url), None, _) => Arc::new(LMClient::from_local(base_url, &self.model)?),
            // "provider:model" string: resolve the provider from the prefix.
            (None, api_key, model) if model.contains(':') => {
                Arc::new(LMClient::from_model_string(model, api_key.as_deref())?)
            }
            // Bare model name: the guarded arm above already handled strings
            // containing ':', so default to the OpenAI provider here.
            (None, api_key, model) => {
                let model_str = format!("openai:{}", model);
                Arc::new(LMClient::from_model_string(&model_str, api_key.as_deref())?)
            }
        };

        self.client = Some(client);

        if self.cache && self.cache_handler.is_none() {
            self.cache_handler = Some(Arc::new(Mutex::new(ResponseCache::new().await)));
        }

        Ok(self)
    }
}

impl<S: l_m_builder::IsComplete> LMBuilder<S> {
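    /// Finishes the builder. The generated finishing function is hidden
    /// (`vis = ""`), so every `LM` is forced through this async wrapper,
    /// which initializes the client and cache before handing it out.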
    pub async fn build(self) -> Result<LM> {
        let lm = self.__internal_build();
        lm.initialize_client().await
    }
}

impl LM {
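    /// Sends `messages` to the configured model and returns the assistant
    /// reply together with token usage and the extended chat. Errors if the
    /// client was never initialized, i.e. the `LM` did not come from
    /// [`LMBuilder::build`].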
    pub async fn call(&self, messages: Chat) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        // The prompt is appended to the prior conversation so the whole
        // exchange travels as one chat history.
        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: Some(self.temperature as f64),
            max_tokens: Some(self.max_tokens as u64),
            tool_choice: None,
            additional_params: None,
        };

        let response = self
            .client
            .as_ref()
            .ok_or_else(|| {
                anyhow::anyhow!("LM client not initialized. Call build() on LMBuilder.")
            })?
            .completion(request)
            .await?;

        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            AssistantContent::ToolCall(_tool_call) => {
                todo!("tool calls are not yet supported")
            }
        };

        let usage = LmUsage::from(response.usage);

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage,
            chat: full_chat,
        })
    }

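    /// Returns up to `n` entries from the cached call history.
    ///
    /// # Panics
    /// Panics if caching is disabled (no `cache_handler`) or if reading the
    /// history fails.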
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires caching to be enabled")
            .lock()
            .await
            .get_history(n)
            .await
            .expect("failed to read call history from cache")
    }
}

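/// A stand-in for [`LM`] used in tests: instead of contacting a provider it
/// echoes a caller-supplied prediction while still exercising the cache
/// bookkeeping. Prefer [`DummyLM::new`], which wires up a live cache handler;
/// the derived `Default` leaves `cache_handler` unset.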
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = true)]
    pub cache: bool,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl DummyLM {
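    /// Creates a `DummyLM` with caching enabled and a fresh [`ResponseCache`].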
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 512,
            cache: true,
            cache_handler: Some(cache_handler),
        }
    }

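    /// "Completes" the chat with the caller-supplied `prediction`, recording
    /// the prompt/prediction pair in the cache alongside `example` when
    /// caching is enabled.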
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            // Hand the result to a spawned task over a one-slot channel; the
            // task performs the actual cache insert.
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction.clone(),
            },
            usage: LmUsage::default(),
            chat: full_chat,
        })
    }

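    /// Same contract as [`LM::inspect_history`]: returns up to `n` cached
    /// calls and panics if caching is disabled.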
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires caching to be enabled")
            .lock()
            .await
            .get_history(n)
            .await
            .expect("failed to read call history from cache")
    }
}