oxify_connect_llm/
helpers.rs

//! Helper utilities for common LLM operations
//!
//! This module provides builders and helper functions that simplify common
//! LLM-related tasks: constructing requests, estimating token counts and
//! costs, and inferring providers from model names.

use crate::{ImageInput, ImageSourceType, LlmRequest, Tool};

/// Builder for constructing LLM requests easily
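///
/// Illustrative usage (a minimal sketch; the `oxify_connect_llm::helpers`
/// path is assumed, so the block is marked `ignore` rather than run as a
/// doctest):
///
/// ```ignore
/// use oxify_connect_llm::helpers::LlmRequestBuilder;
///
/// let request = LlmRequestBuilder::new()
///     .prompt("Explain ownership in Rust")
///     .system("You are a concise tutor")
///     .temperature(0.7)
///     .max_tokens(256)
///     .build();
///
/// assert_eq!(request.prompt, "Explain ownership in Rust");
/// assert_eq!(request.max_tokens, Some(256));
/// ```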
pub struct LlmRequestBuilder {
    request: LlmRequest,
}

impl Default for LlmRequestBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl LlmRequestBuilder {
    /// Create a new request builder
    pub fn new() -> Self {
        Self {
            request: LlmRequest {
                prompt: String::new(),
                system_prompt: None,
                temperature: None,
                max_tokens: None,
                tools: Vec::new(),
                images: Vec::new(),
            },
        }
    }

    /// Set the prompt
    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
        self.request.prompt = prompt.into();
        self
    }

    /// Set the system prompt
    pub fn system(mut self, system_prompt: impl Into<String>) -> Self {
        self.request.system_prompt = Some(system_prompt.into());
        self
    }

    /// Set the temperature
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.request.temperature = Some(temperature);
        self
    }

    /// Set max tokens
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.request.max_tokens = Some(max_tokens);
        self
    }

    /// Add a tool/function
    pub fn tool(mut self, tool: Tool) -> Self {
        self.request.tools.push(tool);
        self
    }

    /// Add multiple tools
    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
        self.request.tools.extend(tools);
        self
    }

    /// Add an image from URL
    pub fn image_url(mut self, url: impl Into<String>) -> Self {
        self.request.images.push(ImageInput {
            data: url.into(),
            source_type: ImageSourceType::Url,
            media_type: None,
        });
        self
    }

    /// Add an image from base64 data
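    ///
    /// Illustrative sketch (the payload below is a truncated placeholder,
    /// not real image data, so the block is marked `ignore`):
    ///
    /// ```ignore
    /// let request = LlmRequestBuilder::new()
    ///     .prompt("Describe this image")
    ///     .image_base64("iVBORw0KGgoAAAANSUhEUg...", "image/png")
    ///     .build();
    /// assert_eq!(request.images.len(), 1);
    /// ```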
    pub fn image_base64(mut self, data: impl Into<String>, media_type: impl Into<String>) -> Self {
        self.request.images.push(ImageInput {
            data: data.into(),
            source_type: ImageSourceType::Base64,
            media_type: Some(media_type.into()),
        });
        self
    }

    /// Build the request
    pub fn build(self) -> LlmRequest {
        self.request
    }
}

/// Quick helpers for common request patterns
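///
/// A minimal sketch of the intended usage (module path assumed; marked
/// `ignore` so it is not run as a doctest):
///
/// ```ignore
/// use oxify_connect_llm::helpers::QuickRequest;
///
/// let req = QuickRequest::code("Write a binary search over a sorted slice");
/// assert_eq!(req.temperature, Some(0.2));
///
/// let req = QuickRequest::translate("Bonjour le monde", "English");
/// assert!(req.prompt.contains("Bonjour"));
/// ```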
pub struct QuickRequest;

impl QuickRequest {
    /// Create a simple text generation request
    pub fn simple(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new().prompt(prompt).build()
    }

    /// Create a request with system prompt and temperature
    pub fn chat(
        prompt: impl Into<String>,
        system: impl Into<String>,
        temperature: f64,
    ) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system(system)
            .temperature(temperature)
            .build()
    }

    /// Create a code generation request (low temperature)
    pub fn code(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system("You are an expert programmer. Generate clean, efficient, and well-documented code.")
            .temperature(0.2)
            .build()
    }

    /// Create a creative writing request (high temperature)
    pub fn creative(prompt: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(prompt)
            .system("You are a creative writer. Generate engaging and imaginative content.")
            .temperature(0.9)
            .build()
    }

    /// Create a summarization request
    pub fn summarize(text: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(format!("Summarize the following text:\n\n{}", text.into()))
            .system("You are a helpful assistant that creates concise summaries.")
            .temperature(0.3)
            .max_tokens(500)
            .build()
    }

    /// Create a translation request
    pub fn translate(text: impl Into<String>, target_lang: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(format!(
                "Translate the following text to {}:\n\n{}",
                target_lang.into(),
                text.into()
            ))
            .temperature(0.3)
            .build()
    }

    /// Create a vision request (image analysis)
    pub fn analyze_image(image_url: impl Into<String>, question: impl Into<String>) -> LlmRequest {
        LlmRequestBuilder::new()
            .prompt(question)
            .image_url(image_url)
            .build()
    }
}

/// Token estimation utilities
pub struct TokenUtils;

impl TokenUtils {
    /// Estimate tokens for text (simple heuristic: ~4 bytes per token, which
    /// matches the common ~4 chars/token rule for ASCII text)
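    ///
    /// For example, `"Hello, world!"` is 13 bytes, so the estimate is
    /// `ceil(13 / 4) = 4` tokens:
    ///
    /// ```ignore
    /// assert_eq!(TokenUtils::estimate_tokens("Hello, world!"), 4);
    /// ```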
    pub fn estimate_tokens(text: &str) -> u32 {
        ((text.len() as f64) / 4.0).ceil() as u32
    }

    /// Estimate the cost of a request given per-1k-token input/output pricing
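    ///
    /// Worked example: a 5-byte prompt estimates to 2 tokens, so with rates
    /// of 0.5 per 1k input tokens, 1.5 per 1k output tokens, and 100
    /// expected completion tokens, the estimate is
    /// `2/1000 * 0.5 + 100/1000 * 1.5 = 0.151`:
    ///
    /// ```ignore
    /// let cost = TokenUtils::estimate_cost("Hello", 100, 0.5, 1.5);
    /// assert!((cost - 0.151).abs() < 1e-9);
    /// ```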
    pub fn estimate_cost(
        prompt: &str,
        estimated_completion_tokens: u32,
        cost_per_1k_input: f64,
        cost_per_1k_output: f64,
    ) -> f64 {
        let prompt_tokens = Self::estimate_tokens(prompt);
        let input_cost = (prompt_tokens as f64 / 1000.0) * cost_per_1k_input;
        let output_cost = (estimated_completion_tokens as f64 / 1000.0) * cost_per_1k_output;
        input_cost + output_cost
    }

    /// Check whether text is likely to exceed a token limit
    pub fn exceeds_limit(text: &str, limit: u32) -> bool {
        Self::estimate_tokens(text) > limit
    }

    /// Truncate text to fit within a token limit (rough approximation)
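    ///
    /// Sketch of the expected behavior (a limit of 10 tokens maps to roughly
    /// 40 bytes under the ~4-bytes/token heuristic):
    ///
    /// ```ignore
    /// let long = "a".repeat(100);
    /// let short = TokenUtils::truncate_to_limit(&long, 10);
    /// assert!(short.len() <= 40);
    /// assert!(short.ends_with("..."));
    /// ```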
    pub fn truncate_to_limit(text: &str, limit: u32) -> String {
        let chars_limit = (limit as usize).saturating_mul(4); // Rough estimate
        if text.len() <= chars_limit {
            text.to_string()
        } else {
            // Reserve room for "..." and back the cut off to a char boundary
            // so slicing cannot panic on multi-byte UTF-8 input.
            let mut end = chars_limit.saturating_sub(3);
            while end > 0 && !text.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &text[..end])
        }
    }
}

/// Model name utilities
pub struct ModelUtils;

impl ModelUtils {
    /// Check if a model name is an OpenAI GPT-family model (including o1)
    pub fn is_gpt(model: &str) -> bool {
        model.starts_with("gpt-") || model.starts_with("o1-")
    }

    /// Check if a model name is a Claude model
    pub fn is_claude(model: &str) -> bool {
        model.starts_with("claude-")
    }

    /// Check if a model name is a Gemini model
    pub fn is_gemini(model: &str) -> bool {
        model.starts_with("gemini-")
    }

    /// Check if a model name looks like a locally hosted open-weight model
    pub fn is_local(model: &str) -> bool {
        // Common local model patterns
        model.contains("llama")
            || model.contains("mistral")
            || model.contains("mixtral")
            || model.contains("vicuna")
            || model.contains("alpaca")
    }

    /// Infer the provider from a model name
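    ///
    /// A quick sketch of the mapping (a heuristic, not an authoritative
    /// registry):
    ///
    /// ```ignore
    /// assert_eq!(ModelUtils::infer_provider("gpt-4"), Some("openai"));
    /// assert_eq!(ModelUtils::infer_provider("llama3"), Some("ollama"));
    /// assert_eq!(ModelUtils::infer_provider("unknown-model"), None);
    /// ```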
    pub fn infer_provider(model: &str) -> Option<&str> {
        if Self::is_gpt(model) {
            Some("openai")
        } else if Self::is_claude(model) {
            Some("anthropic")
        } else if Self::is_gemini(model) {
            Some("google")
        } else if model.starts_with("command") {
            Some("cohere")
        } else if Self::is_local(model) {
            Some("ollama")
        } else {
            None
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_builder() {
        let request = LlmRequestBuilder::new()
            .prompt("Hello")
            .system("You are helpful")
            .temperature(0.7)
            .max_tokens(100)
            .build();

        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(100));
    }

    #[test]
    fn test_quick_request_simple() {
        let request = QuickRequest::simple("Test prompt");
        assert_eq!(request.prompt, "Test prompt");
        assert!(request.system_prompt.is_none());
    }

    #[test]
    fn test_quick_request_chat() {
        let request = QuickRequest::chat("Hello", "You are helpful", 0.8);
        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_quick_request_code() {
        let request = QuickRequest::code("Write a function");
        assert!(request.system_prompt.is_some());
        assert_eq!(request.temperature, Some(0.2));
    }

    #[test]
    fn test_quick_request_creative() {
        let request = QuickRequest::creative("Write a story");
        assert_eq!(request.temperature, Some(0.9));
    }

    #[test]
    fn test_quick_request_summarize() {
        let request = QuickRequest::summarize("Long text here");
        assert!(request.prompt.contains("Summarize"));
        assert_eq!(request.max_tokens, Some(500));
    }

    #[test]
    fn test_token_utils_estimate() {
        let tokens = TokenUtils::estimate_tokens("Hello, world!");
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_utils_estimate_cost() {
        let cost = TokenUtils::estimate_cost("Hello", 100, 0.5, 1.5);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_token_utils_exceeds_limit() {
        let text = "a".repeat(10000);
        assert!(TokenUtils::exceeds_limit(&text, 100));
        assert!(!TokenUtils::exceeds_limit("short", 1000));
    }

    #[test]
    fn test_token_utils_truncate() {
        let text = "a".repeat(1000);
        let truncated = TokenUtils::truncate_to_limit(&text, 10);
        assert!(truncated.len() < text.len());
        assert!(truncated.ends_with("..."));
    }

    #[test]
    fn test_model_utils_is_gpt() {
        assert!(ModelUtils::is_gpt("gpt-4"));
        assert!(ModelUtils::is_gpt("gpt-3.5-turbo"));
        assert!(ModelUtils::is_gpt("o1-preview"));
        assert!(!ModelUtils::is_gpt("claude-3"));
    }

    #[test]
    fn test_model_utils_is_claude() {
        assert!(ModelUtils::is_claude("claude-3-opus"));
        assert!(!ModelUtils::is_claude("gpt-4"));
    }

    #[test]
    fn test_model_utils_is_gemini() {
        assert!(ModelUtils::is_gemini("gemini-pro"));
        assert!(!ModelUtils::is_gemini("gpt-4"));
    }

    #[test]
    fn test_model_utils_infer_provider() {
        assert_eq!(ModelUtils::infer_provider("gpt-4"), Some("openai"));
        assert_eq!(
            ModelUtils::infer_provider("claude-3-opus"),
            Some("anthropic")
        );
        assert_eq!(ModelUtils::infer_provider("gemini-pro"), Some("google"));
        assert_eq!(ModelUtils::infer_provider("command-r"), Some("cohere"));
    }

    #[test]
    fn test_request_builder_with_images() {
        let request = LlmRequestBuilder::new()
            .prompt("Analyze this image")
            .image_url("https://example.com/image.jpg")
            .build();

        assert_eq!(request.images.len(), 1);
        assert_eq!(request.images[0].source_type, ImageSourceType::Url);
    }

    #[test]
    fn test_request_builder_with_tools() {
        let tool = Tool {
            name: "test".to_string(),
            description: "test tool".to_string(),
            parameters: serde_json::json!({}),
        };

        let request = LlmRequestBuilder::new().prompt("Test").tool(tool).build();

        assert_eq!(request.tools.len(), 1);
    }
}