pub mod chat;
pub mod client_registry;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use usage::*;

use anyhow::Result;
use rig::{
    completion::AssistantContent,
    message::{ToolCall, ToolChoice},
    tool::ToolDyn,
};

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};

#[derive(Clone, Debug)]
pub struct LMResponse {
    pub output: Message,
    pub usage: LmUsage,
    pub chat: Chat,
    pub tool_calls: Vec<ToolCall>,
    pub tool_executions: Vec<String>,
}

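/// A configurable handle to a chat-completion model.
///
/// Built with the generated builder; `build()` is async because it initializes
/// the underlying client and, when enabled, the response cache. A minimal
/// usage sketch (illustrative only; the crate path and error handling are
/// assumptions, not taken from this module):
///
/// ```ignore
/// let lm = LM::builder()
///     .model("openai:gpt-4o-mini".to_string())
///     .temperature(0.2)
///     .max_tokens(256)
///     .build()
///     .await?;
/// ```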
#[derive(Builder)]
#[builder(finish_fn(vis = "", name = __internal_build))]
pub struct LM {
    pub base_url: Option<String>,
    pub api_key: Option<String>,
    #[builder(default = "openai:gpt-4o-mini".to_string())]
    pub model: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = 10)]
    pub max_tool_iterations: u32,
    #[builder(default = false)]
    pub cache: bool,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
    #[builder(skip)]
    client: Option<Arc<LMClient>>,
}

impl Default for LM {
    fn default() -> Self {
        // NOTE: this blocks on the ambient Tokio runtime. It must be called
        // from a thread that has a runtime available but is not itself inside
        // an async context, otherwise `block_on` panics; prefer
        // `LM::builder().build().await` in async code.
        tokio::runtime::Handle::current()
            .block_on(async { Self::builder().build().await.unwrap() })
    }
}

impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            base_url: self.base_url.clone(),
            api_key: self.api_key.clone(),
            model: self.model.clone(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            max_tool_iterations: self.max_tool_iterations,
            cache: self.cache,
            cache_handler: self.cache_handler.clone(),
            client: self.client.clone(),
        }
    }
}

impl LM {
    async fn initialize_client(mut self) -> Result<Self> {
        // Pick a client from what the caller supplied:
        // - base URL + API key   -> OpenAI-compatible remote endpoint
        // - base URL only        -> local endpoint, no auth
        // - "provider:model"     -> resolved from the model string
        // - bare model name      -> assume the OpenAI provider
        let client = match (&self.base_url, &self.api_key, &self.model) {
            (Some(base_url), Some(api_key), _) => Arc::new(LMClient::from_openai_compatible(
                base_url,
                api_key,
                &self.model,
            )?),
            (Some(base_url), None, _) => Arc::new(LMClient::from_local(base_url, &self.model)?),
            (None, api_key, model) if model.contains(':') => {
                Arc::new(LMClient::from_model_string(model, api_key.as_deref())?)
            }
            (None, api_key, model) => {
                // The guard above already handled "provider:model" strings,
                // so this arm only ever sees bare model names.
                let model_str = format!("openai:{}", model);
                Arc::new(LMClient::from_model_string(&model_str, api_key.as_deref())?)
            }
        };

        self.client = Some(client);

        // Lazily create a response cache when caching is enabled but no
        // handler was provided.
        if self.cache && self.cache_handler.is_none() {
            self.cache_handler = Some(Arc::new(Mutex::new(ResponseCache::new().await)));
        }

        Ok(self)
    }

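    /// Swaps in a preconstructed client, keeping all other settings.
    ///
    /// A sketch of pointing an existing `LM` at a local endpoint (the URL and
    /// model name are placeholders):
    ///
    /// ```ignore
    /// let client = LMClient::from_local("http://localhost:11434", "llama3")?;
    /// let lm = lm.with_client(client).await?;
    /// ```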
    pub async fn with_client(self, client: LMClient) -> Result<Self> {
        Ok(LM {
            client: Some(Arc::new(client)),
            ..self
        })
    }
}

impl<S: l_m_builder::State> LMBuilder<S> {
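    /// Finishes the builder and initializes the client.
    ///
    /// This wraps the hidden `__internal_build` finishing function so that the
    /// async client setup always runs before an `LM` is handed out.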
    pub async fn build(self) -> Result<LM> {
        let lm = self.__internal_build();
        lm.initialize_client().await
    }
}

/// Outcome of a completed tool-call loop.
struct ToolLoopResult {
    message: Message,
    #[allow(unused)]
    chat_history: Vec<rig::message::Message>,
    tool_calls: Vec<ToolCall>,
    tool_executions: Vec<String>,
}

impl LM {
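    /// Drives the tool-call loop: records each tool call the model makes,
    /// feeds a textual acknowledgment back as the tool result, and re-prompts
    /// until the model returns plain text or `max_tool_iterations` is spent.
    ///
    /// NOTE: tools are matched by name against their definitions but are not
    /// actually invoked; only a formatted summary of each call is sent back
    /// to the model.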
    async fn execute_tool_loop(
        &self,
        initial_tool_call: &rig::message::ToolCall,
        mut tools: Vec<Arc<dyn ToolDyn>>,
        tool_definitions: Vec<rig::completion::ToolDefinition>,
        mut chat_history: Vec<rig::message::Message>,
        system_prompt: String,
        accumulated_usage: &mut LmUsage,
    ) -> Result<ToolLoopResult> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;
        use rig::message::UserContent;

        let max_iterations = self.max_tool_iterations as usize;

        let mut tool_calls = Vec::new();
        let mut tool_executions = Vec::new();

        let tool_name = &initial_tool_call.function.name;
        let args_str = initial_tool_call.function.arguments.to_string();

        // Look the tool up by name; if none matches, the "not found" message
        // below is what gets reported back to the model.
        let mut tool_result = format!("Tool '{}' not found", tool_name);
        for tool in &mut tools {
            let def = tool.definition("".to_string()).await;
            if def.name == *tool_name {
                let args_json: serde_json::Value =
                    serde_json::from_str(&args_str).unwrap_or_default();
                tool_result = format!("Called tool {} with args: {}", tool_name, args_json);
                tool_calls.push(initial_tool_call.clone());
                tool_executions.push(tool_result.clone());
                break;
            }
        }

        // Record the assistant's tool call and the synthesized tool result in
        // the chat history so the follow-up completion sees both.
        chat_history.push(rig::message::Message::Assistant {
            id: None,
            content: OneOrMany::one(rig::message::AssistantContent::ToolCall(
                initial_tool_call.clone(),
            )),
        });

        let tool_result_content = if let Some(call_id) = &initial_tool_call.call_id {
            UserContent::tool_result_with_call_id(
                initial_tool_call.id.clone(),
                call_id.clone(),
                OneOrMany::one(tool_result.into()),
            )
        } else {
            UserContent::tool_result(
                initial_tool_call.id.clone(),
                OneOrMany::one(tool_result.into()),
            )
        };

        chat_history.push(rig::message::Message::User {
            content: OneOrMany::one(tool_result_content),
        });

        // The initial tool call counts as the first iteration, so the loop
        // performs at most `max_iterations - 1` further rounds.
        for _iteration in 1..max_iterations {
            let request = CompletionRequest {
                preamble: Some(system_prompt.clone()),
                chat_history: if chat_history.len() == 1 {
                    OneOrMany::one(chat_history.clone().into_iter().next().unwrap())
                } else {
                    OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty")
                },
                documents: Vec::new(),
                tools: tool_definitions.clone(),
                temperature: Some(self.temperature as f64),
                max_tokens: Some(self.max_tokens as u64),
                tool_choice: Some(ToolChoice::Auto),
                additional_params: None,
            };

            let response = self
                .client
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("LM client not initialized"))?
                .completion(request)
                .await?;

            accumulated_usage.prompt_tokens += response.usage.input_tokens;
            accumulated_usage.completion_tokens += response.usage.output_tokens;
            accumulated_usage.total_tokens += response.usage.total_tokens;

            match response.choice.first() {
                // Plain text (or reasoning) ends the loop.
                AssistantContent::Text(text) => {
                    return Ok(ToolLoopResult {
                        message: Message::assistant(&text.text),
                        chat_history,
                        tool_calls,
                        tool_executions,
                    });
                }
                AssistantContent::Reasoning(reasoning) => {
                    return Ok(ToolLoopResult {
                        message: Message::assistant(reasoning.reasoning.join("\n")),
                        chat_history,
                        tool_calls,
                        tool_executions,
                    });
                }
                // Another tool call: record it and loop again.
                AssistantContent::ToolCall(tool_call) => {
                    let tool_name = &tool_call.function.name;
                    let args_str = tool_call.function.arguments.to_string();

                    let mut tool_result = format!("Tool '{}' not found", tool_name);
                    for tool in &mut tools {
                        let def = tool.definition("".to_string()).await;
                        if def.name == *tool_name {
                            let args_json: serde_json::Value =
                                serde_json::from_str(&args_str).unwrap_or_default();
                            tool_result =
                                format!("Called tool {} with args: {}", tool_name, args_json);
                            tool_calls.push(tool_call.clone());
                            tool_executions.push(tool_result.clone());
                            break;
                        }
                    }

                    chat_history.push(rig::message::Message::Assistant {
                        id: None,
                        content: OneOrMany::one(rig::message::AssistantContent::ToolCall(
                            tool_call.clone(),
                        )),
                    });

                    let tool_result_content = if let Some(call_id) = &tool_call.call_id {
                        UserContent::tool_result_with_call_id(
                            tool_call.id.clone(),
                            call_id.clone(),
                            OneOrMany::one(tool_result.into()),
                        )
                    } else {
                        UserContent::tool_result(
                            tool_call.id.clone(),
                            OneOrMany::one(tool_result.into()),
                        )
                    };

                    chat_history.push(rig::message::Message::User {
                        content: OneOrMany::one(tool_result_content),
                    });
                }
            }
        }

        Err(anyhow::anyhow!("Max tool iterations reached"))
    }

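    /// Sends a chat to the model, transparently driving any tool-call loop.
    ///
    /// A minimal sketch without tools (illustrative; how `chat` is built is
    /// elided because it depends on the `Chat` API):
    ///
    /// ```ignore
    /// // `chat` is a `Chat` holding the system prompt and user messages.
    /// let response = lm.call(chat, Vec::new()).await?;
    /// println!("{:?}", response.output);
    /// ```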
    pub async fn call(&self, messages: Chat, tools: Vec<Arc<dyn ToolDyn>>) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        let mut tool_definitions = Vec::new();
        for tool in &tools {
            tool_definitions.push(tool.definition("".to_string()).await);
        }

        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system.clone()),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.clone().into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: tool_definitions.clone(),
            temperature: Some(self.temperature as f64),
            max_tokens: Some(self.max_tokens as u64),
            // Only ask the model to pick tools when some are registered.
            tool_choice: if !tool_definitions.is_empty() {
                Some(ToolChoice::Auto)
            } else {
                None
            },
            additional_params: None,
        };

        let response = self
            .client
            .as_ref()
            .ok_or_else(|| {
                anyhow::anyhow!("LM client not initialized. Call build() on LMBuilder.")
            })?
            .completion(request)
            .await?;

        let mut accumulated_usage = LmUsage::from(response.usage);

        let mut tool_loop_result = None;
        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            AssistantContent::ToolCall(tool_call) if !tools.is_empty() => {
                let result = self
                    .execute_tool_loop(
                        &tool_call,
                        tools,
                        tool_definitions,
                        chat_history,
                        request_messages.system,
                        &mut accumulated_usage,
                    )
                    .await?;
                let message = result.message.clone();
                tool_loop_result = Some(result);
                message
            }
            AssistantContent::ToolCall(tool_call) => {
                let msg = format!(
                    "Tool call requested: {} with args: {}, but no tools available",
                    tool_call.function.name, tool_call.function.arguments
                );
                Message::assistant(&msg)
            }
        };

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage: accumulated_usage,
            chat: full_chat,
            tool_calls: tool_loop_result
                .as_ref()
                .map(|result| result.tool_calls.clone())
                .unwrap_or_default(),
            tool_executions: tool_loop_result
                .map(|result| result.tool_executions)
                .unwrap_or_default(),
        })
    }

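    /// Returns the `n` most recent cached calls.
    ///
    /// # Panics
    ///
    /// Panics if no cache handler is configured (the LM was built with
    /// `cache == false` and no explicit `cache_handler`) or if reading the
    /// history fails.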
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires a configured cache_handler")
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}

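/// A stand-in LM for tests: it echoes a caller-supplied prediction instead of
/// contacting a provider, while still exercising the response-cache path.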
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = 0.7)]
    pub temperature: f32,
    #[builder(default = 512)]
    pub max_tokens: u32,
    #[builder(default = true)]
    pub cache: bool,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}

impl DummyLM {
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 512,
            cache: true,
            cache_handler: Some(cache_handler),
        }
    }

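    /// Records `prediction` as if the model had produced it.
    ///
    /// When caching is enabled, the result is forwarded to the cache through a
    /// small mpsc channel so the insert runs on a spawned task. A sketch
    /// (illustrative; `example` and `chat` construction elided):
    ///
    /// ```ignore
    /// let dummy = DummyLM::new().await;
    /// let response = dummy.call(example, chat, "ok".to_string()).await?;
    /// assert!(response.tool_calls.is_empty());
    /// ```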
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            // Hand the result to the cache via a channel: the spawned task
            // owns the cache insert, and the send below completes it.
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction.clone(),
            },
            usage: LmUsage::default(),
            chat: full_chat,
            tool_calls: Vec::new(),
            tool_executions: Vec::new(),
        })
    }

    /// Returns the `n` most recent cached calls.
    ///
    /// # Panics
    ///
    /// Panics if no cache handler is configured or if reading the history
    /// fails.
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .expect("inspect_history requires a configured cache_handler")
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}