pub mod chat;
pub mod client_registry;
pub mod config;
pub mod usage;

pub use chat::*;
pub use client_registry::*;
pub use config::*;
pub use usage::*;

use anyhow::Result;
use rig::completion::AssistantContent;

use bon::Builder;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{Cache, CallResult, Example, Prediction, ResponseCache};
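/// The result of a single language-model call: the assistant's reply, the token
/// usage it incurred, and the full chat (the input messages plus the new reply).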
#[derive(Clone, Debug)]
pub struct LMResponse {
    pub output: Message,
    pub usage: LmUsage,
    pub chat: Chat,
}
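/// A language-model handle: the resolved provider client, its configuration, and
/// an optional response cache shared behind an `Arc<Mutex<_>>`.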
pub struct LM {
    pub config: LMConfig,
    client: Arc<LMClient>,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}
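/// Builds an `LM` with `LMConfig::default()` by creating a throwaway Tokio runtime
/// and blocking on [`LM::new`]. This panics when called from inside an async
/// context; prefer `LM::new(LMConfig::default()).await` there.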
impl Default for LM {
    fn default() -> Self {
        tokio::runtime::Runtime::new()
            .expect("Failed to create tokio runtime")
            .block_on(Self::new(LMConfig::default()))
    }
}
impl Clone for LM {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            client: self.client.clone(),
            cache_handler: self.cache_handler.clone(),
        }
    }
}
impl LM {
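    /// Creates an `LM` from `config`, resolving the provider client from the model
    /// string and initializing a response cache when `config.cache` is enabled.
    ///
    /// Panics if the model string cannot be resolved to a client.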
    pub async fn new(config: LMConfig) -> Self {
        let client = LMClient::from_model_string(&config.model)
            .expect("Failed to create client from model string");

        let cache_handler = if config.cache {
            Some(Arc::new(Mutex::new(ResponseCache::new().await)))
        } else {
            None
        };

        Self {
            config,
            client: Arc::new(client),
            cache_handler,
        }
    }
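    /// Sends `messages` to the underlying client and returns the assistant reply,
    /// the token usage, and the chat extended with that reply.
    ///
    /// A minimal usage sketch; how a `Chat` is constructed and populated is assumed
    /// here (`Chat::default()` and `Message::user` are not defined in this module):
    ///
    /// ```rust,ignore
    /// let lm = LM::new(LMConfig::default()).await;
    /// let mut chat = Chat::default();            // assumed constructor
    /// chat.push_message(Message::user("Hello")); // `Message::user` is assumed
    /// let response = lm.call(chat).await?;
    /// println!("{:?}", response.output);
    /// ```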
    pub async fn call(&self, messages: Chat) -> Result<LMResponse> {
        use rig::OneOrMany;
        use rig::completion::CompletionRequest;

        let request_messages = messages.get_rig_messages();

        // Fold the prompt into the conversation so everything is sent as a
        // single chat history.
        let mut chat_history = request_messages.conversation;
        chat_history.push(request_messages.prompt);

        let request = CompletionRequest {
            preamble: Some(request_messages.system),
            chat_history: if chat_history.len() == 1 {
                OneOrMany::one(chat_history.into_iter().next().unwrap())
            } else {
                OneOrMany::many(chat_history).expect("chat_history should not be empty")
            },
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: Some(self.config.temperature as f64),
            max_tokens: Some(self.config.max_tokens as u64),
            tool_choice: None,
            additional_params: None,
        };

        let response = self.client.completion(request).await?;

        let first_choice = match response.choice.first() {
            AssistantContent::Text(text) => Message::assistant(&text.text),
            AssistantContent::Reasoning(reasoning) => {
                Message::assistant(reasoning.reasoning.join("\n"))
            }
            // Tool-call responses are not handled yet.
            AssistantContent::ToolCall(_tool_call) => {
                todo!()
            }
        };

        let usage = LmUsage::from(response.usage);

        let mut full_chat = messages.clone();
        full_chat.push_message(first_choice.clone());

        Ok(LMResponse {
            output: first_choice,
            usage,
            chat: full_chat,
        })
    }
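    /// Returns up to `n` entries from the response cache's call history.
    ///
    /// Panics if caching is disabled (`cache_handler` is `None`) or if the
    /// history lookup fails.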
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}
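/// A stub language model for tests: instead of calling a provider it echoes a
/// caller-supplied prediction, optionally recording the call in the response cache.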
#[derive(Clone, Builder, Default)]
pub struct DummyLM {
    pub api_key: String,
    #[builder(default = "https://api.openai.com/v1".to_string())]
    pub base_url: String,
    #[builder(default = LMConfig::default())]
    pub config: LMConfig,
    pub cache_handler: Option<Arc<Mutex<ResponseCache>>>,
}
impl DummyLM {
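    /// Creates a `DummyLM` with the default config, an empty API key, the OpenAI
    /// base URL, and a fresh response cache.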
    pub async fn new() -> Self {
        let cache_handler = Arc::new(Mutex::new(ResponseCache::new().await));
        Self {
            api_key: "".into(),
            base_url: "https://api.openai.com/v1".to_string(),
            config: LMConfig::default(),
            cache_handler: Some(cache_handler),
        }
    }
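    /// Returns `prediction` verbatim as the assistant reply to `messages`; when
    /// caching is enabled, spawns a task that records the call against `example`.
    ///
    /// A usage sketch, assuming an `example: Example` and a `chat: Chat` built
    /// elsewhere:
    ///
    /// ```rust,ignore
    /// let lm = DummyLM::new().await;
    /// let response = lm.call(example, chat, "canned answer".to_string()).await?;
    /// assert!(matches!(response.output, Message::Assistant { .. }));
    /// ```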
    pub async fn call(
        &self,
        example: Example,
        messages: Chat,
        prediction: String,
    ) -> Result<LMResponse> {
        let mut full_chat = messages.clone();
        full_chat.push_message(Message::Assistant {
            content: prediction.clone(),
        });

        if self.config.cache
            && let Some(cache) = self.cache_handler.as_ref()
        {
            // Hand the result to the cache through a channel so the insert can
            // run on a background task.
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let cache_clone = cache.clone();
            let example_clone = example.clone();

            tokio::spawn(async move {
                let _ = cache_clone.lock().await.insert(example_clone, rx).await;
            });

            tx.send(CallResult {
                prompt: messages.to_json().to_string(),
                prediction: Prediction::new(
                    HashMap::from([("prediction".to_string(), prediction.clone().into())]),
                    LmUsage::default(),
                ),
            })
            .await
            .map_err(|_| anyhow::anyhow!("Failed to send to cache"))?;
        }

        Ok(LMResponse {
            output: Message::Assistant {
                content: prediction,
            },
            usage: LmUsage::default(),
            chat: full_chat,
        })
    }
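    /// Returns up to `n` entries from the response cache's call history.
    ///
    /// Panics if `cache_handler` is `None` or if the history lookup fails.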
    pub async fn inspect_history(&self, n: usize) -> Vec<CallResult> {
        self.cache_handler
            .as_ref()
            .unwrap()
            .lock()
            .await
            .get_history(n)
            .await
            .unwrap()
    }
}