use crate::inference::generation::{
    GenerationCache, GenerationConfig, GenerationOutput, TextGenerator,
};
use crate::inference::reasoning::ReasoningOutput;
use crate::inference::sampling::SamplingConfig;
use crate::model::DeepSeekR1Model;
use crate::utils::error::Result;
use crate::utils::tokenizer::{Tokenizer, TokenizerConfig};
use serde::{Deserialize, Serialize};

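/// Category of problem handed to [`InferenceEngine::solve_problem`]; each
/// variant selects a different prompt template.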
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ProblemType {
    Mathematical,
    Logical,
    CodeAnalysis,
    General,
}

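/// Structured result returned by [`InferenceEngine::solve_math_problem_detailed`].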
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathSolutionOutput {
    pub problem: String,
    pub reasoning_steps: Vec<String>,
    pub final_answer: Option<String>,
    pub confidence: f32,
}

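/// Structured result returned by [`InferenceEngine::explain_code_detailed`].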
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeExplanationOutput {
    pub original_code: String,
    pub language: Option<String>,
    pub reasoning_steps: Vec<String>,
    pub summary: String,
    pub confidence: f32,
}

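/// Structured result returned by [`InferenceEngine::solve_logical_problem_detailed`].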
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogicalSolutionOutput {
    pub problem: String,
    pub reasoning_steps: Vec<String>,
    pub conclusion: String,
    pub confidence: f32,
}

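/// High-level inference facade bundling the model, tokenizer, text generator,
/// and generation cache behind prompt-oriented helper methods.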
pub struct InferenceEngine {
    model: DeepSeekR1Model,
    tokenizer: Tokenizer,
    text_generator: TextGenerator,
    generation_cache: GenerationCache,
    default_config: GenerationConfig,
}

impl InferenceEngine {
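    /// Builds an engine with default tokenizer, sampling, and generation
    /// configs; the tokenizer's vocab size is taken from the model config.
    ///
    /// A minimal usage sketch (marked `ignore` since the crate path and a
    /// working model are assumptions here):
    ///
    /// ```ignore
    /// let model = DeepSeekR1Model::new(ModelConfig::default())?;
    /// let mut engine = InferenceEngine::new(model)?;
    /// let text = engine.generate_text("Hello")?;
    /// ```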
    pub fn new(model: DeepSeekR1Model) -> Result<Self> {
        // Keep the tokenizer vocab in sync with the model's configuration.
        let tokenizer_config = TokenizerConfig {
            vocab_size: model.config().vocab_size,
            ..TokenizerConfig::default()
        };
        let tokenizer = Tokenizer::new(tokenizer_config)?;

        let sampling_config = SamplingConfig::default();
        let text_generator = TextGenerator::new(sampling_config);

        let generation_cache = GenerationCache::new();
        let default_config = GenerationConfig::default();

        Ok(Self {
            model,
            tokenizer,
            text_generator,
            generation_cache,
            default_config,
        })
    }

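    /// Like [`InferenceEngine::new`] but with caller-supplied configs. Note
    /// that `tokenizer_config.vocab_size` is overridden with the model's
    /// vocab size so the two always agree.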
    pub fn with_configs(
        model: DeepSeekR1Model,
        tokenizer_config: TokenizerConfig,
        sampling_config: SamplingConfig,
        generation_config: GenerationConfig,
    ) -> Result<Self> {
        let tokenizer_config = TokenizerConfig {
            vocab_size: model.config().vocab_size,
            ..tokenizer_config
        };
        let tokenizer = Tokenizer::new(tokenizer_config)?;
        let text_generator = TextGenerator::new(sampling_config);
        let generation_cache = GenerationCache::new();

        Ok(Self {
            model,
            tokenizer,
            text_generator,
            generation_cache,
            default_config: generation_config,
        })
    }

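    /// Generates text for `prompt` with the engine's default config,
    /// returning only the generated string; use
    /// [`InferenceEngine::generate_text_with_config`] for the full
    /// [`GenerationOutput`].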
    pub fn generate_text(&mut self, prompt: &str) -> Result<String> {
        // Clone the config up front so the &mut self call below doesn't
        // conflict with borrowing self.default_config.
        let config = self.default_config.clone();
        let output = self.generate_text_with_config(prompt, &config)?;
        Ok(output.text)
    }

    pub fn generate_text_with_config(
        &mut self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<GenerationOutput> {
        self.text_generator.generate_with_cache(
            &mut self.model,
            &self.tokenizer,
            prompt,
            config,
            &mut self.generation_cache,
        )
    }

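    /// Streaming-style generation. Note that the current implementation
    /// generates the complete output first and invokes `callback` exactly
    /// once with the full text; returning `Ok(false)` from the callback
    /// marks the output as stopped.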
    pub fn generate_text_streaming<F>(
        &mut self,
        prompt: &str,
        config: &GenerationConfig,
        mut callback: F,
    ) -> Result<GenerationOutput>
    where
        F: FnMut(&str) -> Result<bool>,
    {
        let output = self.generate_text_with_config(prompt, config)?;

        // A false return from the callback marks the output as stopped.
        let should_continue = callback(&output.text)?;
        if !should_continue {
            return Ok(GenerationOutput::new(
                output.text,
                output.tokens_generated,
                crate::inference::generation::StopReason::Error("Stopped by callback".to_string()),
            ));
        }

        Ok(output)
    }

    pub fn set_generation_config(&mut self, config: GenerationConfig) {
        self.default_config = config;
    }

    pub fn generation_config(&self) -> &GenerationConfig {
        &self.default_config
    }

    pub fn clear_cache(&mut self) {
        self.generation_cache.clear();
    }

    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }

    pub fn generate_with_reasoning(&mut self, prompt: &str) -> Result<ReasoningOutput> {
        self.text_generator.generate_with_reasoning(
            &mut self.model,
            &self.tokenizer,
            prompt,
            &self.default_config,
        )
    }

    pub fn generate_with_reasoning_config(
        &mut self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<ReasoningOutput> {
        self.text_generator.generate_with_reasoning(
            &mut self.model,
            &self.tokenizer,
            prompt,
            config,
        )
    }

    pub fn generate_with_reasoning_detection(
        &mut self,
        prompt: &str,
    ) -> Result<(GenerationOutput, Option<ReasoningOutput>)> {
        self.text_generator.generate_with_reasoning_detection(
            &mut self.model,
            &self.tokenizer,
            prompt,
            &self.default_config,
        )
    }

    pub fn generate_structured_reasoning(&mut self, prompt: &str) -> Result<ReasoningOutput> {
        self.text_generator.generate_structured_reasoning(
            &mut self.model,
            &self.tokenizer,
            prompt,
            &self.default_config,
        )
    }

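    /// Wraps `problem` in a step-by-step math prompt seeded with a `<think>`
    /// block and runs reasoning generation. The token budget is capped at 8
    /// under `cfg(test)` and set to 512 otherwise.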
    pub fn solve_math_problem(&mut self, problem: &str) -> Result<ReasoningOutput> {
        let math_prompt = format!(
            "Solve this mathematical problem step by step: {}\n\n<think>Let me break this down step by step and show my reasoning.</think>",
            problem
        );

        let mut config = self.default_config.clone();
        // Keep generation cheap under `cargo test`; allow a full budget otherwise.
        #[cfg(test)]
        {
            config.max_tokens = config.max_tokens.min(8);
        }
        #[cfg(not(test))]
        {
            config.max_tokens = 512;
        }

        self.generate_with_reasoning_config(&math_prompt, &config)
    }

    pub fn solve_math_problem_detailed(&mut self, problem: &str) -> Result<MathSolutionOutput> {
        let reasoning_output = self.solve_math_problem(problem)?;

        let final_answer = self.extract_final_answer(&reasoning_output.final_answer);

        Ok(MathSolutionOutput {
            problem: problem.to_string(),
            reasoning_steps: reasoning_output.thinking_chain,
            final_answer,
            confidence: reasoning_output.confidence,
        })
    }

    pub fn explain_code(&mut self, code: &str) -> Result<ReasoningOutput> {
        let code_prompt = format!(
            "Explain this code step by step:\n\n```\n{}\n```\n\n<think>Let me analyze this code line by line and explain what it does.</think>",
            code
        );

        let mut config = self.default_config.clone();
        #[cfg(test)]
        {
            config.max_tokens = config.max_tokens.min(8);
        }
        #[cfg(not(test))]
        {
            config.max_tokens = 512;
        }

        self.generate_with_reasoning_config(&code_prompt, &config)
    }

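    /// Explains `code` (optionally tagged with a `language` hint) and packages
    /// the result as [`CodeExplanationOutput`], using the first sentence of
    /// the final answer as the summary.
    ///
    /// A hedged sketch of typical use (inputs illustrative, marked `ignore`):
    ///
    /// ```ignore
    /// let out = engine.explain_code_detailed("fn main() {}", Some("rust"))?;
    /// println!("{} (confidence {:.2})", out.summary, out.confidence);
    /// ```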
    pub fn explain_code_detailed(
        &mut self,
        code: &str,
        language: Option<&str>,
    ) -> Result<CodeExplanationOutput> {
        let language_hint = language
            .map(|lang| format!(" ({})", lang))
            .unwrap_or_default();
        let code_prompt = format!(
            "Analyze and explain this{} code in detail:\n\n```\n{}\n```\n\n<think>Let me break down this code step by step, explaining the purpose, logic, and any important details.</think>",
            language_hint, code
        );

        let config = self.default_config.clone();
        let reasoning_output = self.generate_with_reasoning_config(&code_prompt, &config)?;

        let summary = self.extract_code_summary(&reasoning_output.final_answer);

        Ok(CodeExplanationOutput {
            original_code: code.to_string(),
            language: language.map(|s| s.to_string()),
            reasoning_steps: reasoning_output.thinking_chain,
            summary,
            confidence: reasoning_output.confidence,
        })
    }

    pub fn solve_logical_problem(&mut self, problem: &str) -> Result<ReasoningOutput> {
        let logic_prompt = format!(
            "Solve this logical reasoning problem step by step: {}\n\n<think>Let me work through this logic problem systematically, considering all the given information and constraints.</think>",
            problem
        );

        let mut config = self.default_config.clone();
        #[cfg(test)]
        {
            config.max_tokens = config.max_tokens.min(8);
        }
        #[cfg(not(test))]
        {
            config.max_tokens = 512;
        }

        self.generate_with_reasoning_config(&logic_prompt, &config)
    }

    pub fn solve_logical_problem_detailed(
        &mut self,
        problem: &str,
    ) -> Result<LogicalSolutionOutput> {
        let reasoning_output = self.solve_logical_problem(problem)?;

        let conclusion = self.extract_logical_conclusion(&reasoning_output.final_answer);

        Ok(LogicalSolutionOutput {
            problem: problem.to_string(),
            reasoning_steps: reasoning_output.thinking_chain,
            conclusion,
            confidence: reasoning_output.confidence,
        })
    }

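    /// Dispatches to the specialized solver for `problem_type`;
    /// [`ProblemType::General`] falls back to a generic step-by-step prompt.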
    pub fn solve_problem(
        &mut self,
        problem: &str,
        problem_type: ProblemType,
    ) -> Result<ReasoningOutput> {
        match problem_type {
            ProblemType::Mathematical => self.solve_math_problem(problem),
            ProblemType::Logical => self.solve_logical_problem(problem),
            ProblemType::CodeAnalysis => self.explain_code(problem),
            ProblemType::General => {
                let general_prompt = format!(
                    "Analyze and solve this problem: {}\n\n<think>Let me think about this problem carefully and work through it step by step.</think>",
                    problem
                );
                let config = self.default_config.clone();
                self.generate_with_reasoning_config(&general_prompt, &config)
            }
        }
    }

    fn extract_final_answer(&self, text: &str) -> Option<String> {
        // Byte offsets found in the lowercased copy are used to index the
        // original text; this assumes lowercasing preserves byte offsets,
        // which holds for ASCII input.
        let lower_text = text.to_lowercase();

        if let Some(pos) = lower_text.find("answer is ") {
            let after_answer = &text[pos + "answer is ".len()..];
            if let Some(number) = self.extract_first_number(after_answer) {
                return Some(number);
            }
        }

        if let Some(pos) = lower_text.find("result: ") {
            let after_result = &text[pos + "result: ".len()..];
            if let Some(number) = self.extract_first_number(after_result) {
                return Some(number);
            }
        }

        if let Some(pos) = text.rfind('=') {
            let after_equals = &text[pos + 1..];
            if let Some(number) = self.extract_first_number(after_equals) {
                return Some(number);
            }
        }

        None
    }

    fn extract_first_number(&self, text: &str) -> Option<String> {
        // Collect the first run of digits, allowing a single decimal point.
        let mut number_str = String::new();
        let mut found_digit = false;

        for ch in text.trim().chars() {
            if ch.is_ascii_digit() || (ch == '.' && found_digit && !number_str.contains('.')) {
                number_str.push(ch);
                found_digit = true;
            } else if found_digit || !ch.is_whitespace() {
                break;
            }
        }

        if found_digit && !number_str.is_empty() {
            Some(number_str)
        } else {
            None
        }
    }

    fn extract_code_summary(&self, text: &str) -> String {
        // Use the first sentence as the summary, falling back to the full text.
        if let Some(period_pos) = text.find('.') {
            text[..period_pos + 1].trim().to_string()
        } else {
            text.trim().to_string()
        }
    }

    fn extract_logical_conclusion(&self, text: &str) -> String {
        // Prefer text starting at a conclusion keyword (case-insensitive; the
        // same ASCII-offset assumption as in extract_final_answer applies).
        let lower = text.to_lowercase();
        if let Some(pos) = lower.find("therefore") {
            text[pos..].trim().to_string()
        } else if let Some(pos) = lower.find("conclusion") {
            text[pos..].trim().to_string()
        } else {
            text.trim().to_string()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DeepSeekR1Model, ModelConfig};

    #[test]
    fn test_inference_engine_creation() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model);
        assert!(engine.is_ok());
    }

    #[test]
    fn test_inference_engine_with_configs() {
        let model_config = ModelConfig::default();
        let model = DeepSeekR1Model::new(model_config).unwrap();

        let tokenizer_config = TokenizerConfig::default();
        let sampling_config = SamplingConfig::default();
        let generation_config = GenerationConfig::default();

        let engine = InferenceEngine::with_configs(
            model,
            tokenizer_config,
            sampling_config,
            generation_config,
        );
        assert!(engine.is_ok());
    }

    #[test]
    fn test_generation_config_management() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let mut engine = InferenceEngine::new(model).unwrap();

        let default_config = engine.generation_config();
        assert_eq!(default_config.max_tokens, 256);

        let new_config = GenerationConfig {
            max_tokens: 512,
            ..GenerationConfig::default()
        };
        engine.set_generation_config(new_config);

        let updated_config = engine.generation_config();
        assert_eq!(updated_config.max_tokens, 512);
    }

    #[test]
    fn test_cache_management() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let mut engine = InferenceEngine::new(model).unwrap();

        // Clearing the cache on a fresh engine should simply not panic.
        engine.clear_cache();
    }

    #[test]
    fn test_tokenizer_access() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model).unwrap();

        let tokenizer = engine.tokenizer();
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_problem_solving_methods_exist() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model).unwrap();

        // Sanity checks that the engine is usable without running generation.
        assert!(engine.tokenizer().vocab_size() > 0);
        assert_eq!(engine.generation_config().max_tokens, 256);
    }

    #[test]
    fn test_extract_first_number() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model).unwrap();

        assert_eq!(engine.extract_first_number("42"), Some("42".to_string()));
        assert_eq!(
            engine.extract_first_number("3.14"),
            Some("3.14".to_string())
        );
        assert_eq!(
            engine.extract_first_number(" 123 "),
            Some("123".to_string())
        );
        assert_eq!(engine.extract_first_number("no numbers here"), None);
    }

    #[test]
    fn test_extract_final_answer() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model).unwrap();

        assert_eq!(
            engine.extract_final_answer("The answer is 42"),
            Some("42".to_string())
        );
        assert_eq!(
            engine.extract_final_answer("2 + 2 = 4"),
            Some("4".to_string())
        );
        assert_eq!(
            engine.extract_final_answer("result: 3.14"),
            Some("3.14".to_string())
        );
    }

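    // A small sketch exercising the sentence/keyword extraction heuristics
    // above; the input strings are illustrative, not taken from any fixture.
    #[test]
    fn test_extract_summary_and_conclusion() {
        let config = ModelConfig::default();
        let model = DeepSeekR1Model::new(config).unwrap();
        let engine = InferenceEngine::new(model).unwrap();

        // The summary is the first sentence, or the whole trimmed text.
        assert_eq!(
            engine.extract_code_summary("Parses input. Then validates it."),
            "Parses input."
        );
        assert_eq!(engine.extract_code_summary("no period"), "no period");

        // Conclusion extraction is keyword-driven and case-insensitive.
        assert_eq!(
            engine.extract_logical_conclusion("All men are mortal. Therefore Socrates is mortal."),
            "Therefore Socrates is mortal."
        );
        assert_eq!(
            engine.extract_logical_conclusion("  no keywords here  "),
            "no keywords here"
        );
    }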
}