1use serde::{Deserialize, Serialize};
35use mistralrs::{
36 IsqType, Model, RequestBuilder, TextMessageRole, TextMessages, TextModelBuilder,
37};
38use std::sync::{Arc, OnceLock};
39use tokio::sync::Mutex;
40
41use crate::core::error::{Error, Result};
42
/// Configuration for loading and sampling from a local text LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// Hugging Face model identifier, e.g. "Qwen/Qwen2.5-0.5B-Instruct".
    pub model_id: String,

    /// In-situ quantization level applied while the weights are loaded.
    pub quantization: Quantization,

    /// Sampling temperature.
    pub temperature: f64,

    /// Nucleus (top-p) sampling threshold.
    pub top_p: f64,

    /// Default cap on the number of generated tokens.
    pub max_tokens: usize,

    /// Penalty applied to repeated tokens.
    pub repetition_penalty: f32,
}
68
/// Weight quantization level used at model load time (ISQ).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Quantization {
    /// No quantization — full-precision weights.
    None,
    /// 4-bit quantization (mapped to `IsqType::Q4_0`).
    Q4,
    /// 8-bit quantization (mapped to `IsqType::Q8_0`).
    Q8,
}
79
80impl Default for Quantization {
81 fn default() -> Self {
82 Self::Q4
83 }
84}
85
86impl Default for LLMConfig {
87 fn default() -> Self {
88 Self::qwen_0_5b()
90 }
91}
92
93impl LLMConfig {
94 pub fn smollm2() -> Self {
96 Self {
97 model_id: "HuggingFaceTB/SmolLM2-135M-Instruct".to_string(),
98 quantization: Quantization::Q4,
99 temperature: 0.3,
100 top_p: 0.9,
101 max_tokens: 256,
102 repetition_penalty: 1.1,
103 }
104 }
105
106 pub fn smollm2_360m() -> Self {
108 Self {
109 model_id: "HuggingFaceTB/SmolLM2-360M-Instruct".to_string(),
110 quantization: Quantization::Q4,
111 temperature: 0.3,
112 top_p: 0.9,
113 max_tokens: 256,
114 repetition_penalty: 1.1,
115 }
116 }
117
118 pub fn qwen_0_5b() -> Self {
120 Self {
121 model_id: "Qwen/Qwen2.5-0.5B-Instruct".to_string(),
122 quantization: Quantization::Q4,
123 temperature: 0.3,
124 top_p: 0.9,
125 max_tokens: 256,
126 repetition_penalty: 1.1,
127 }
128 }
129
130 pub fn phi3_mini() -> Self {
132 Self {
133 model_id: "microsoft/Phi-3-mini-4k-instruct".to_string(),
134 quantization: Quantization::Q4,
135 temperature: 0.3,
136 top_p: 0.9,
137 max_tokens: 256,
138 repetition_penalty: 1.1,
139 }
140 }
141
142 pub fn custom(model_id: impl Into<String>) -> Self {
144 Self {
145 model_id: model_id.into(),
146 quantization: Quantization::Q4,
147 temperature: 0.3,
148 top_p: 0.9,
149 max_tokens: 256,
150 repetition_penalty: 1.1,
151 }
152 }
153
154 pub fn with_quantization(mut self, q: Quantization) -> Self {
156 self.quantization = q;
157 self
158 }
159
160 pub fn with_temperature(mut self, t: f64) -> Self {
162 self.temperature = t;
163 self
164 }
165
166 pub fn with_max_tokens(mut self, max: usize) -> Self {
168 self.max_tokens = max;
169 self
170 }
171}
172
/// Process-wide singleton holding the loaded model; set once by `LLM::init*`.
static GLOBAL_MODEL: OnceLock<Arc<Mutex<ModelState>>> = OnceLock::new();

/// The loaded model together with the config it was built from.
struct ModelState {
    model: Model,
    config: LLMConfig,
}
184
185pub struct LLM;
187
188impl LLM {
189 pub async fn init() -> Result<()> {
191 Self::init_with_config(LLMConfig::default()).await
192 }
193
194 pub async fn init_with_config(config: LLMConfig) -> Result<()> {
196 if GLOBAL_MODEL.get().is_some() {
197 tracing::warn!("LLM already initialized, skipping");
198 return Ok(());
199 }
200
201 tracing::info!("Loading LLM: {} ({:?})", config.model_id, config.quantization);
202
203 let mut builder = TextModelBuilder::new(config.model_id.clone());
204
205 builder = match config.quantization {
207 Quantization::Q4 => builder.with_isq(IsqType::Q4_0),
208 Quantization::Q8 => builder.with_isq(IsqType::Q8_0),
209 Quantization::None => builder,
210 };
211
212 let model = builder
213 .with_logging()
214 .build()
215 .await
216 .map_err(|e| Error::Model(format!("Failed to load model: {e}")))?;
217
218 let state = ModelState {
219 model,
220 config: config.clone(),
221 };
222
223 let _ = GLOBAL_MODEL.set(Arc::new(Mutex::new(state)));
224 tracing::info!("LLM loaded successfully: {}", config.model_id);
225
226 Ok(())
227 }
228
229 pub fn is_loaded() -> bool {
231 GLOBAL_MODEL.get().is_some()
232 }
233
234 pub async fn model_id() -> Option<String> {
236 let state = GLOBAL_MODEL.get()?;
237 let guard = state.lock().await;
238 Some(guard.config.model_id.clone())
239 }
240
241 fn get_state() -> Result<Arc<Mutex<ModelState>>> {
242 GLOBAL_MODEL
243 .get()
244 .cloned()
245 .ok_or_else(|| Error::Model("LLM not initialized. Call LLM::init() first.".into()))
246 }
247}
248
249const DEFAULT_MAX_TOKENS: usize = 256;
255
256pub async fn generate(prompt: &str, max_tokens: Option<usize>) -> Result<String> {
258 let state = LLM::get_state()?;
259 let guard = state.lock().await;
260
261 let limit = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
262 let messages = TextMessages::new()
263 .add_message(TextMessageRole::User, prompt);
264
265 let request: RequestBuilder = messages.into();
266 let request = request
267 .set_sampler_max_len(limit)
268 .set_sampler_frequency_penalty(1.2)
269 .set_sampler_presence_penalty(0.6)
270 .set_sampler_temperature(0.7);
271
272 let response = guard
273 .model
274 .send_chat_request(request)
275 .await
276 .map_err(|e| Error::Model(format!("Generation error: {e}")))?;
277
278 let content = response
279 .choices
280 .first()
281 .and_then(|c| c.message.content.as_ref())
282 .map(|s| s.to_string())
283 .unwrap_or_default();
284
285 Ok(content)
286}
287
288pub async fn generate_with_system(system: &str, user: &str, max_tokens: Option<usize>) -> Result<String> {
290 let state = LLM::get_state()?;
291 let guard = state.lock().await;
292
293 let limit = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
294 let messages = TextMessages::new()
295 .add_message(TextMessageRole::System, system)
296 .add_message(TextMessageRole::User, user);
297
298 let request: RequestBuilder = messages.into();
299 let request = request
300 .set_sampler_max_len(limit)
301 .set_sampler_frequency_penalty(1.2)
302 .set_sampler_presence_penalty(0.6)
303 .set_sampler_temperature(0.7);
304
305 let response = guard
306 .model
307 .send_chat_request(request)
308 .await
309 .map_err(|e| Error::Model(format!("Generation error: {e}")))?;
310
311 let content = response
312 .choices
313 .first()
314 .and_then(|c| c.message.content.as_ref())
315 .map(|s| s.to_string())
316 .unwrap_or_default();
317
318 Ok(content)
319}
320
321pub async fn summarize(content: &str, target_words: usize) -> Result<String> {
323 let system = "You are a concise summarizer. Output only the summary, nothing else.";
324 let user = format!(
325 "Summarize in about {target_words} words:\n\n{content}"
326 );
327
328 let max_tokens = (target_words as f32 * 1.5) as usize + 20;
330 generate_with_system(system, &user, Some(max_tokens)).await
331}
332
333pub async fn polish(chunks: &[&str], query: Option<&str>) -> Result<String> {
335 let content = chunks.join("\n\n---\n\n");
336
337 let system = "You organize and clean up text. Remove redundancy, improve flow. Output only the cleaned text.";
338
339 let user = match query {
340 Some(q) => format!(
341 "Clean up this content for someone asking: \"{q}\"\n\nContent:\n{content}"
342 ),
343 None => format!("Clean up this content:\n\n{content}"),
344 };
345
346 let max_tokens = (content.split_whitespace().count() * 2).min(512);
348 generate_with_system(system, &user, Some(max_tokens)).await
349}
350
351pub async fn polish_content(content: &str) -> Result<String> {
355 let system = r#"You are a text organizer. Your ONLY job is to clean up and reorganize the given text.
356
357CRITICAL RULES:
3581. Output ONLY content from the input - do NOT add any new information
3592. Remove redundancy and duplicates
3603. Improve flow and organization
3614. Keep all facts exactly as stated in the input
3625. Do NOT generate answers, explanations, or new content
3636. If the text mentions a question, do NOT answer it - just include the question
364
365Output the cleaned, organized version of the text only."#;
366
367 let user = format!("Organize this text:\n\n{content}");
368
369 let max_tokens = (content.split_whitespace().count() as f32 * 1.3 * 0.8) as usize;
371 let max_tokens = max_tokens.clamp(100, 1024);
372
373 generate_with_system(system, &user, Some(max_tokens)).await
374}
375
376pub async fn score_relevance(content: &str, query: &str) -> Result<f32> {
378 let system = "Output ONLY a number. No explanation.";
379 let user = format!(
380 "Relevance score (0.0=unrelated, 1.0=perfect match):\nQ: {}\nC: {}",
381 query,
382 &content[..content.len().min(200)]
383 );
384
385 let response = generate_with_system(system, &user, Some(5)).await?;
386
387 let score = response
389 .split_whitespace()
390 .find_map(|word| {
391 word.trim_matches(|c: char| !c.is_numeric() && c != '.')
392 .parse::<f32>()
393 .ok()
394 })
395 .unwrap_or(0.5)
396 .clamp(0.0, 1.0);
397
398 Ok(score)
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    // Model-backed tests are #[ignore]d: they download weights and are slow.

    #[tokio::test]
    #[ignore]
    async fn test_llm_init() {
        assert!(LLM::init_with_config(LLMConfig::smollm2()).await.is_ok());
        assert!(LLM::is_loaded());
    }

    #[tokio::test]
    #[ignore]
    async fn test_generate() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let out = generate("What is 2+2?", Some(20)).await;
        assert!(out.is_ok());
        assert!(!out.unwrap().is_empty());
    }

    #[tokio::test]
    #[ignore]
    async fn test_summarize() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let source = "Machine learning is a subset of artificial intelligence. \
            It involves training algorithms on data to make predictions. \
            Deep learning uses neural networks with many layers. \
            These techniques power modern AI applications. \
            Machine learning models learn patterns from data and use them \
            to make decisions without being explicitly programmed.";

        let result = summarize(source, 50).await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty(), "Summary should not be empty");
    }

    #[tokio::test]
    #[ignore]
    async fn test_score_relevance() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let result = score_relevance(
            "Rust is a systems programming language focused on safety.",
            "What is Rust?",
        )
        .await;
        assert!(result.is_ok());
        let score = result.unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[tokio::test]
    #[ignore]
    async fn test_polish_content() {
        LLM::init_with_config(LLMConfig::smollm2()).await.unwrap();

        let messy = "The cat sat. The cat sat on mat. Cat was sitting. Mat was soft.";

        let result = polish_content(messy).await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    #[test]
    fn test_config_presets() {
        assert!(LLMConfig::smollm2().model_id.contains("SmolLM2"));
        assert!(LLMConfig::default().model_id.contains("Qwen"));
    }
}