oxify_connect_llm/helpers.rs

use crate::{ImageInput, ImageSourceType, LlmRequest, Tool};

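/// Fluent builder for assembling an `LlmRequest` field by field.
///
/// A minimal usage sketch (all methods shown are defined below):
///
/// ```ignore
/// let request = LlmRequestBuilder::new()
///     .prompt("Explain lifetimes in Rust")
///     .system("You are a concise tutor")
///     .temperature(0.7)
///     .max_tokens(256)
///     .build();
/// assert_eq!(request.max_tokens, Some(256));
/// ```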
pub struct LlmRequestBuilder {
    request: LlmRequest,
}

impl Default for LlmRequestBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl LlmRequestBuilder {
    pub fn new() -> Self {
        Self {
            request: LlmRequest {
                prompt: String::new(),
                system_prompt: None,
                temperature: None,
                max_tokens: None,
                tools: Vec::new(),
                images: Vec::new(),
            },
        }
    }

    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
        self.request.prompt = prompt.into();
        self
    }

    pub fn system(mut self, system_prompt: impl Into<String>) -> Self {
        self.request.system_prompt = Some(system_prompt.into());
        self
    }

    pub fn temperature(mut self, temperature: f64) -> Self {
        self.request.temperature = Some(temperature);
        self
    }

    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.request.max_tokens = Some(max_tokens);
        self
    }

    pub fn tool(mut self, tool: Tool) -> Self {
        self.request.tools.push(tool);
        self
    }

    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
        self.request.tools.extend(tools);
        self
    }

    pub fn image_url(mut self, url: impl Into<String>) -> Self {
        self.request.images.push(ImageInput {
            data: url.into(),
            source_type: ImageSourceType::Url,
            media_type: None,
        });
        self
    }

    pub fn image_base64(mut self, data: impl Into<String>, media_type: impl Into<String>) -> Self {
        self.request.images.push(ImageInput {
            data: data.into(),
            source_type: ImageSourceType::Base64,
            media_type: Some(media_type.into()),
        });
        self
    }

    pub fn build(self) -> LlmRequest {
        self.request
    }
}

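/// One-call constructors for common request shapes: plain prompts, chat
/// with a system prompt, code generation (temperature 0.2), creative
/// writing (temperature 0.9), summarization, translation, and image
/// analysis. Each simply wraps `LlmRequestBuilder` with preset values.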
pub struct QuickRequest;

impl QuickRequest {
    pub fn simple(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new().prompt(prompt).build()
    }

    pub fn chat(
        prompt: impl Into<String>,
        system: impl Into<String>,
        temperature: f64,
    ) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system(system)
            .temperature(temperature)
            .build()
    }

    pub fn code(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system("You are an expert programmer. Generate clean, efficient, and well-documented code.")
            .temperature(0.2)
            .build()
    }

    pub fn creative(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system("You are a creative writer. Generate engaging and imaginative content.")
            .temperature(0.9)
            .build()
    }

    pub fn summarize(text: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(format!("Summarize the following text:\n\n{}", text.into()))
            .system("You are a helpful assistant that creates concise summaries.")
            .temperature(0.3)
            .max_tokens(500)
            .build()
    }

    pub fn translate(text: impl Into<String>, target_lang: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(format!(
                "Translate the following text to {}:\n\n{}",
                target_lang.into(),
                text.into()
            ))
            .temperature(0.3)
            .build()
    }

    pub fn analyze_image(image_url: impl Into<String>, question: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(question)
            .image_url(image_url)
            .build()
    }
}

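/// Character-count heuristics for token budgeting. None of these call a
/// real tokenizer, so treat every result as a rough estimate.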
pub struct TokenUtils;

impl TokenUtils {
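    /// Estimates the token count assuming roughly 4 characters per token,
    /// a common rule of thumb for English text.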
    pub fn estimate_tokens(text: &str) -> u32 {
        ((text.len() as f64) / 4.0).ceil() as u32
    }

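    /// Estimates total request cost from per-1000-token input and output
    /// prices: `(prompt_tokens / 1000) * cost_per_1k_input +
    /// (completion_tokens / 1000) * cost_per_1k_output`.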
    pub fn estimate_cost(
        prompt: &str,
        estimated_completion_tokens: u32,
        cost_per_1k_input: f64,
        cost_per_1k_output: f64,
    ) -> f64 {
        let prompt_tokens = Self::estimate_tokens(prompt);
        let input_cost = (prompt_tokens as f64 / 1000.0) * cost_per_1k_input;
        let output_cost = (estimated_completion_tokens as f64 / 1000.0) * cost_per_1k_output;
        input_cost + output_cost
    }

    pub fn exceeds_limit(text: &str, limit: u32) -> bool {
        Self::estimate_tokens(text) > limit
    }

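    /// Truncates `text` to fit within an estimated token `limit`,
    /// appending "..." whenever anything was cut.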
    pub fn truncate_to_limit(text: &str, limit: u32) -> String {
        let chars_limit = (limit as usize) * 4;
        if text.len() <= chars_limit {
            text.to_string()
        } else {
            // Saturate so a tiny limit cannot underflow, and back up to a
            // char boundary so multi-byte UTF-8 input cannot panic the slice.
            let mut end = chars_limit.saturating_sub(3);
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &text[..end])
        }
    }
}

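/// Name-based model classification. These checks only inspect the model
/// name string, so unknown or renamed models are not recognized and
/// `infer_provider` returns `None` for them.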
pub struct ModelUtils;

impl ModelUtils {
    pub fn is_gpt(model: &str) -> bool {
        model.starts_with("gpt-") || model.starts_with("o1-")
    }

    pub fn is_claude(model: &str) -> bool {
        model.starts_with("claude-")
    }

    pub fn is_gemini(model: &str) -> bool {
        model.starts_with("gemini-")
    }

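    /// Heuristic substring check for common open-weight model families
    /// typically run locally (e.g. "codellama" matches via "llama").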
    pub fn is_local(model: &str) -> bool {
        model.contains("llama")
            || model.contains("mistral")
            || model.contains("mixtral")
            || model.contains("vicuna")
            || model.contains("alpaca")
    }

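    /// Maps a model name to a provider identifier, e.g.
    /// `infer_provider("gpt-4")` returns `Some("openai")`; returns `None`
    /// for unrecognized names.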
    pub fn infer_provider(model: &str) -> Option<&str> {
        if Self::is_gpt(model) {
            Some("openai")
        } else if Self::is_claude(model) {
            Some("anthropic")
        } else if Self::is_gemini(model) {
            Some("google")
        } else if model.starts_with("command") {
            Some("cohere")
        } else if Self::is_local(model) {
            Some("ollama")
        } else {
            None
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_builder() {
        let request = LlmRequestBuilder::new()
            .prompt("Hello")
            .system("You are helpful")
            .temperature(0.7)
            .max_tokens(100)
            .build();

        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(100));
    }

    #[test]
    fn test_quick_request_simple() {
        let request = QuickRequest::simple("Test prompt");
        assert_eq!(request.prompt, "Test prompt");
        assert!(request.system_prompt.is_none());
    }

    #[test]
    fn test_quick_request_chat() {
        let request = QuickRequest::chat("Hello", "You are helpful", 0.8);
        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_quick_request_code() {
        let request = QuickRequest::code("Write a function");
        assert!(request.system_prompt.is_some());
        assert_eq!(request.temperature, Some(0.2));
    }

    #[test]
    fn test_quick_request_creative() {
        let request = QuickRequest::creative("Write a story");
        assert_eq!(request.temperature, Some(0.9));
    }

    #[test]
    fn test_quick_request_summarize() {
        let request = QuickRequest::summarize("Long text here");
        assert!(request.prompt.contains("Summarize"));
        assert_eq!(request.max_tokens, Some(500));
    }

    #[test]
    fn test_token_utils_estimate() {
        let tokens = TokenUtils::estimate_tokens("Hello, world!");
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_utils_estimate_cost() {
        let cost = TokenUtils::estimate_cost("Hello", 100, 0.5, 1.5);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_token_utils_exceeds_limit() {
        let text = "a".repeat(10000);
        assert!(TokenUtils::exceeds_limit(&text, 100));
        assert!(!TokenUtils::exceeds_limit("short", 1000));
    }

    #[test]
    fn test_token_utils_truncate() {
        let text = "a".repeat(1000);
        let truncated = TokenUtils::truncate_to_limit(&text, 10);
        assert!(truncated.len() < text.len());
        assert!(truncated.ends_with("..."));
    }

    #[test]
    fn test_model_utils_is_gpt() {
        assert!(ModelUtils::is_gpt("gpt-4"));
        assert!(ModelUtils::is_gpt("gpt-3.5-turbo"));
        assert!(ModelUtils::is_gpt("o1-preview"));
        assert!(!ModelUtils::is_gpt("claude-3"));
    }

    #[test]
    fn test_model_utils_is_claude() {
        assert!(ModelUtils::is_claude("claude-3-opus"));
        assert!(!ModelUtils::is_claude("gpt-4"));
    }

    #[test]
    fn test_model_utils_is_gemini() {
        assert!(ModelUtils::is_gemini("gemini-pro"));
        assert!(!ModelUtils::is_gemini("gpt-4"));
    }

    #[test]
    fn test_model_utils_infer_provider() {
        assert_eq!(ModelUtils::infer_provider("gpt-4"), Some("openai"));
        assert_eq!(
            ModelUtils::infer_provider("claude-3-opus"),
            Some("anthropic")
        );
        assert_eq!(ModelUtils::infer_provider("gemini-pro"), Some("google"));
        assert_eq!(ModelUtils::infer_provider("command-r"), Some("cohere"));
    }

    #[test]
    fn test_request_builder_with_images() {
        let request = LlmRequestBuilder::new()
            .prompt("Analyze this image")
            .image_url("https://example.com/image.jpg")
            .build();

        assert_eq!(request.images.len(), 1);
        assert_eq!(request.images[0].source_type, ImageSourceType::Url);
    }

    #[test]
    fn test_request_builder_with_tools() {
        let tool = Tool {
            name: "test".to_string(),
            description: "test tool".to_string(),
            parameters: serde_json::json!({}),
        };

        let request = LlmRequestBuilder::new().prompt("Test").tool(tool).build();

        assert_eq!(request.tools.len(), 1);
    }
}