// Source file: forgetless/ai/llm.rs

//! Local LLM integration for intelligent context processing.
//!
//! Uses mistral.rs for fast, pure-Rust LLM inference.
//!
//! # Supported Models
//!
//! - **SmolLM2-135M** - Smallest, fastest
//! - **SmolLM2-360M** - Better quality, still fast
//! - **Qwen2.5-0.5B** - Good balance (default)
//! - **Phi-3-mini** - Best quality, slower
//! - Any HuggingFace model compatible with mistral.rs
//!
//! # Example
//!
//! ```rust,ignore
//! use forgetless::llm::{LLM, LLMConfig, generate, summarize};
//!
//! #[tokio::main]
//! async fn main() {
//!     // Default: Qwen2.5-0.5B with Q4 quantization
//!     LLM::init().await.unwrap();
//!
//!     // Or, instead of `init()`, load a specific preset. Only the first
//!     // initialization takes effect; later calls are logged and ignored.
//!     // LLM::init_with_config(LLMConfig::phi3_mini()).await.unwrap();
//!
//!     // Generate text
//!     let response = generate("What is Rust?", None).await.unwrap();
//!
//!     // Summarize content
//!     let summary = summarize("Long text...", 50).await.unwrap();
//! }
//! ```

34use serde::{Deserialize, Serialize};
35use mistralrs::{
36    IsqType, Model, RequestBuilder, TextMessageRole, TextMessages, TextModelBuilder,
37};
38use std::sync::{Arc, OnceLock};
39use tokio::sync::Mutex;
40
41use crate::core::error::{Error, Result};
42
43// ============================================================================
44// Configuration
45// ============================================================================
46
47/// LLM Configuration
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct LLMConfig {
50    /// Model ID on HuggingFace (e.g., "HuggingFaceTB/SmolLM2-135M-Instruct")
51    pub model_id: String,
52
53    /// Quantization level
54    pub quantization: Quantization,
55
56    /// Temperature for sampling (0.0 = deterministic, 1.0 = creative)
57    pub temperature: f64,
58
59    /// Top-p (nucleus) sampling
60    pub top_p: f64,
61
62    /// Maximum tokens to generate
63    pub max_tokens: usize,
64
65    /// Repetition penalty (1.0 = no penalty)
66    pub repetition_penalty: f32,
67}
68
69/// Quantization options for model compression
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71pub enum Quantization {
72    /// No quantization (full precision, largest)
73    None,
74    /// 4-bit quantization (smallest, fastest)
75    Q4,
76    /// 8-bit quantization (balanced)
77    Q8,
78}
79
80impl Default for Quantization {
81    fn default() -> Self {
82        Self::Q4
83    }
84}
85
86impl Default for LLMConfig {
87    fn default() -> Self {
88        // Qwen2.5-0.5B is the smallest model that reliably follows instructions
89        Self::qwen_0_5b()
90    }
91}
92
93impl LLMConfig {
94    /// SmolLM2-135M - Smallest and fastest (135M params)
95    pub fn smollm2() -> Self {
96        Self {
97            model_id: "HuggingFaceTB/SmolLM2-135M-Instruct".to_string(),
98            quantization: Quantization::Q4,
99            temperature: 0.3,
100            top_p: 0.9,
101            max_tokens: 256,
102            repetition_penalty: 1.1,
103        }
104    }
105
106    /// SmolLM2-360M - Small but smarter (360M params)
107    pub fn smollm2_360m() -> Self {
108        Self {
109            model_id: "HuggingFaceTB/SmolLM2-360M-Instruct".to_string(),
110            quantization: Quantization::Q4,
111            temperature: 0.3,
112            top_p: 0.9,
113            max_tokens: 256,
114            repetition_penalty: 1.1,
115        }
116    }
117
118    /// Qwen2.5-0.5B - Good quality, still small (500M params)
119    pub fn qwen_0_5b() -> Self {
120        Self {
121            model_id: "Qwen/Qwen2.5-0.5B-Instruct".to_string(),
122            quantization: Quantization::Q4,
123            temperature: 0.3,
124            top_p: 0.9,
125            max_tokens: 256,
126            repetition_penalty: 1.1,
127        }
128    }
129
130    /// Phi-3-mini - Microsoft's efficient model (3.8B params)
131    pub fn phi3_mini() -> Self {
132        Self {
133            model_id: "microsoft/Phi-3-mini-4k-instruct".to_string(),
134            quantization: Quantization::Q4,
135            temperature: 0.3,
136            top_p: 0.9,
137            max_tokens: 256,
138            repetition_penalty: 1.1,
139        }
140    }
141
142    /// Custom model configuration
143    pub fn custom(model_id: impl Into<String>) -> Self {
144        Self {
145            model_id: model_id.into(),
146            quantization: Quantization::Q4,
147            temperature: 0.3,
148            top_p: 0.9,
149            max_tokens: 256,
150            repetition_penalty: 1.1,
151        }
152    }
153
154    /// Set quantization level
155    pub fn with_quantization(mut self, q: Quantization) -> Self {
156        self.quantization = q;
157        self
158    }
159
160    /// Set temperature
161    pub fn with_temperature(mut self, t: f64) -> Self {
162        self.temperature = t;
163        self
164    }
165
166    /// Set max tokens
167    pub fn with_max_tokens(mut self, max: usize) -> Self {
168        self.max_tokens = max;
169        self
170    }
171}
172
173// ============================================================================
174// LLM Provider
175// ============================================================================
176
177/// Global model instance
178static GLOBAL_MODEL: OnceLock<Arc<Mutex<ModelState>>> = OnceLock::new();
179
180struct ModelState {
181    model: Model,
182    config: LLMConfig,
183}
184
185/// LLM provider for text generation
186pub struct LLM;
187
188impl LLM {
189    /// Initialize with default config (Qwen2.5-0.5B)
190    pub async fn init() -> Result<()> {
191        Self::init_with_config(LLMConfig::default()).await
192    }
193
194    /// Initialize with custom config
195    pub async fn init_with_config(config: LLMConfig) -> Result<()> {
196        if GLOBAL_MODEL.get().is_some() {
197            tracing::warn!("LLM already initialized, skipping");
198            return Ok(());
199        }
200
201        tracing::info!("Loading LLM: {} ({:?})", config.model_id, config.quantization);
202
203        let mut builder = TextModelBuilder::new(config.model_id.clone());
204
205        // Apply quantization
206        builder = match config.quantization {
207            Quantization::Q4 => builder.with_isq(IsqType::Q4_0),
208            Quantization::Q8 => builder.with_isq(IsqType::Q8_0),
209            Quantization::None => builder,
210        };
211
212        let model = builder
213            .with_logging()
214            .build()
215            .await
216            .map_err(|e| Error::Model(format!("Failed to load model: {e}")))?;
217
218        let state = ModelState {
219            model,
220            config: config.clone(),
221        };
222
223        let _ = GLOBAL_MODEL.set(Arc::new(Mutex::new(state)));
224        tracing::info!("LLM loaded successfully: {}", config.model_id);
225
226        Ok(())
227    }
228
229    /// Check if model is loaded
230    pub fn is_loaded() -> bool {
231        GLOBAL_MODEL.get().is_some()
232    }
233
234    /// Get current model ID
235    pub async fn model_id() -> Option<String> {
236        let state = GLOBAL_MODEL.get()?;
237        let guard = state.lock().await;
238        Some(guard.config.model_id.clone())
239    }
240
241    fn get_state() -> Result<Arc<Mutex<ModelState>>> {
242        GLOBAL_MODEL
243            .get()
244            .cloned()
245            .ok_or_else(|| Error::Model("LLM not initialized. Call LLM::init() first.".into()))
246    }
247}
248
249// ============================================================================
250// Generation Functions
251// ============================================================================
252
253/// Default max tokens if not specified (prevent runaway generation)
254const DEFAULT_MAX_TOKENS: usize = 256;
255
256/// Generate text from a prompt
257pub async fn generate(prompt: &str, max_tokens: Option<usize>) -> Result<String> {
258    let state = LLM::get_state()?;
259    let guard = state.lock().await;
260
261    let limit = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
262    let messages = TextMessages::new()
263        .add_message(TextMessageRole::User, prompt);
264
265    let request: RequestBuilder = messages.into();
266    let request = request
267        .set_sampler_max_len(limit)
268        .set_sampler_frequency_penalty(1.2)
269        .set_sampler_presence_penalty(0.6)
270        .set_sampler_temperature(0.7);
271
272    let response = guard
273        .model
274        .send_chat_request(request)
275        .await
276        .map_err(|e| Error::Model(format!("Generation error: {e}")))?;
277
278    let content = response
279        .choices
280        .first()
281        .and_then(|c| c.message.content.as_ref())
282        .map(|s| s.to_string())
283        .unwrap_or_default();
284
285    Ok(content)
286}
287
288/// Generate with system prompt
289pub async fn generate_with_system(system: &str, user: &str, max_tokens: Option<usize>) -> Result<String> {
290    let state = LLM::get_state()?;
291    let guard = state.lock().await;
292
293    let limit = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
294    let messages = TextMessages::new()
295        .add_message(TextMessageRole::System, system)
296        .add_message(TextMessageRole::User, user);
297
298    let request: RequestBuilder = messages.into();
299    let request = request
300        .set_sampler_max_len(limit)
301        .set_sampler_frequency_penalty(1.2)
302        .set_sampler_presence_penalty(0.6)
303        .set_sampler_temperature(0.7);
304
305    let response = guard
306        .model
307        .send_chat_request(request)
308        .await
309        .map_err(|e| Error::Model(format!("Generation error: {e}")))?;
310
311    let content = response
312        .choices
313        .first()
314        .and_then(|c| c.message.content.as_ref())
315        .map(|s| s.to_string())
316        .unwrap_or_default();
317
318    Ok(content)
319}
320
321/// Summarize content to target length
322pub async fn summarize(content: &str, target_words: usize) -> Result<String> {
323    let system = "You are a concise summarizer. Output only the summary, nothing else.";
324    let user = format!(
325        "Summarize in about {target_words} words:\n\n{content}"
326    );
327
328    // Words to tokens ratio is ~1.3, add buffer
329    let max_tokens = (target_words as f32 * 1.5) as usize + 20;
330    generate_with_system(system, &user, Some(max_tokens)).await
331}
332
333/// Polish and organize content chunks
334pub async fn polish(chunks: &[&str], query: Option<&str>) -> Result<String> {
335    let content = chunks.join("\n\n---\n\n");
336
337    let system = "You organize and clean up text. Remove redundancy, improve flow. Output only the cleaned text.";
338
339    let user = match query {
340        Some(q) => format!(
341            "Clean up this content for someone asking: \"{q}\"\n\nContent:\n{content}"
342        ),
343        None => format!("Clean up this content:\n\n{content}"),
344    };
345
346    // Use reasonable limit based on input size, max 512
347    let max_tokens = (content.split_whitespace().count() * 2).min(512);
348    generate_with_system(system, &user, Some(max_tokens)).await
349}
350
351/// Polish content - reorganize and clean up text WITHOUT adding new information
352/// This is used after optimization to make the output more readable.
353/// CRITICAL: The LLM only reorganizes - it does NOT add any new facts or information.
354pub async fn polish_content(content: &str) -> Result<String> {
355    let system = r#"You are a text organizer. Your ONLY job is to clean up and reorganize the given text.
356
357CRITICAL RULES:
3581. Output ONLY content from the input - do NOT add any new information
3592. Remove redundancy and duplicates
3603. Improve flow and organization
3614. Keep all facts exactly as stated in the input
3625. Do NOT generate answers, explanations, or new content
3636. If the text mentions a question, do NOT answer it - just include the question
364
365Output the cleaned, organized version of the text only."#;
366
367    let user = format!("Organize this text:\n\n{content}");
368
369    // Use ~80% of input size as output limit
370    let max_tokens = (content.split_whitespace().count() as f32 * 1.3 * 0.8) as usize;
371    let max_tokens = max_tokens.clamp(100, 1024);
372
373    generate_with_system(system, &user, Some(max_tokens)).await
374}
375
376/// Score content relevance to a query (returns 0.0-1.0)
377pub async fn score_relevance(content: &str, query: &str) -> Result<f32> {
378    let system = "Output ONLY a number. No explanation.";
379    let user = format!(
380        "Relevance score (0.0=unrelated, 1.0=perfect match):\nQ: {}\nC: {}",
381        query,
382        &content[..content.len().min(200)]
383    );
384
385    let response = generate_with_system(system, &user, Some(5)).await?;
386
387    // Extract first number from response
388    let score = response
389        .split_whitespace()
390        .find_map(|word| {
391            word.trim_matches(|c: char| !c.is_numeric() && c != '.')
392                .parse::<f32>()
393                .ok()
394        })
395        .unwrap_or(0.5)
396        .clamp(0.0, 1.0);
397
398    Ok(score)
399}
400
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_llm_init() {
        let result = LLM::init_with_config(LLMConfig::smollm2()).await;
        assert!(result.is_ok());
        assert!(LLM::is_loaded());
    }

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_generate() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let response = generate("What is 2+2?", Some(20)).await;
        assert!(response.is_ok());
        let text = response.unwrap();
        assert!(!text.is_empty());
    }

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_summarize() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let long_text = "Machine learning is a subset of artificial intelligence. \
                        It involves training algorithms on data to make predictions. \
                        Deep learning uses neural networks with many layers. \
                        These techniques power modern AI applications. \
                        Machine learning models learn patterns from data and use them \
                        to make decisions without being explicitly programmed.";

        let summary = summarize(long_text, 50).await;
        assert!(summary.is_ok());
        let text = summary.unwrap();
        assert!(!text.is_empty(), "Summary should not be empty");
    }

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_score_relevance() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let content = "Rust is a systems programming language focused on safety.";
        let query = "What is Rust?";

        let score = score_relevance(content, query).await;
        assert!(score.is_ok());
        let s = score.unwrap();
        assert!((0.0..=1.0).contains(&s), "score out of range: {s}");
    }

    #[tokio::test]
    #[ignore] // Requires model download
    async fn test_polish_content() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let messy = "The cat sat. The cat sat on mat. Cat was sitting. Mat was soft.";

        let polished = polish_content(messy).await;
        assert!(polished.is_ok());
        let text = polished.unwrap();
        assert!(!text.is_empty());
    }

    #[test]
    fn test_config_presets() {
        let smol = LLMConfig::smollm2();
        assert!(smol.model_id.contains("SmolLM2"));

        let qwen = LLMConfig::default();
        assert!(qwen.model_id.contains("Qwen"));
        assert_eq!(qwen.quantization, Quantization::Q4);
    }

    #[test]
    fn test_quantization_default_is_q4() {
        assert_eq!(Quantization::default(), Quantization::Q4);
    }

    #[test]
    fn test_config_builders() {
        let cfg = LLMConfig::custom("org/some-model")
            .with_quantization(Quantization::Q8)
            .with_temperature(0.0)
            .with_max_tokens(32);

        assert_eq!(cfg.model_id, "org/some-model");
        assert_eq!(cfg.quantization, Quantization::Q8);
        assert!(cfg.temperature == 0.0);
        assert_eq!(cfg.max_tokens, 32);
    }
}