ds_r1_rs/inference/
engine.rs

1//! # Inference Engine
2//!
3//! Main inference engine for text generation and reasoning.
4
5use crate::inference::generation::{
6    GenerationCache, GenerationConfig, GenerationOutput, TextGenerator,
7};
8use crate::inference::reasoning::ReasoningOutput;
9use crate::inference::sampling::SamplingConfig;
10use crate::model::DeepSeekR1Model;
11use crate::utils::error::Result;
12use crate::utils::tokenizer::{Tokenizer, TokenizerConfig};
13use serde::{Deserialize, Serialize};
14
15/// Types of problems that can be solved
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub enum ProblemType {
18    Mathematical,
19    Logical,
20    CodeAnalysis,
21    General,
22}
23
24/// Mathematical solution output with structured steps
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct MathSolutionOutput {
27    pub problem: String,
28    pub reasoning_steps: Vec<String>,
29    pub final_answer: Option<String>,
30    pub confidence: f32,
31}
32
33/// Code explanation output with structured analysis
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CodeExplanationOutput {
36    pub original_code: String,
37    pub language: Option<String>,
38    pub reasoning_steps: Vec<String>,
39    pub summary: String,
40    pub confidence: f32,
41}
42
43/// Logical reasoning solution output
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct LogicalSolutionOutput {
46    pub problem: String,
47    pub reasoning_steps: Vec<String>,
48    pub conclusion: String,
49    pub confidence: f32,
50}
51
52/// Main inference engine
53pub struct InferenceEngine {
54    model: DeepSeekR1Model,
55    tokenizer: Tokenizer,
56    text_generator: TextGenerator,
57    generation_cache: GenerationCache,
58    default_config: GenerationConfig,
59}
60
61impl InferenceEngine {
62    /// Create a new inference engine
63    pub fn new(model: DeepSeekR1Model) -> Result<Self> {
64        let tokenizer_config = TokenizerConfig {
65            vocab_size: model.config().vocab_size,
66            ..TokenizerConfig::default()
67        };
68        let tokenizer = Tokenizer::new(tokenizer_config)?;
69
70        let sampling_config = SamplingConfig::default();
71        let text_generator = TextGenerator::new(sampling_config);
72
73        let generation_cache = GenerationCache::new();
74        let default_config = GenerationConfig::default();
75
76        Ok(Self {
77            model,
78            tokenizer,
79            text_generator,
80            generation_cache,
81            default_config,
82        })
83    }
84
85    /// Create a new inference engine with custom configurations
86    pub fn with_configs(
87        model: DeepSeekR1Model,
88        tokenizer_config: TokenizerConfig,
89        sampling_config: SamplingConfig,
90        generation_config: GenerationConfig,
91    ) -> Result<Self> {
92        // Ensure tokenizer vocab size matches model vocab size
93        let tokenizer_config = TokenizerConfig {
94            vocab_size: model.config().vocab_size,
95            ..tokenizer_config
96        };
97        let tokenizer = Tokenizer::new(tokenizer_config)?;
98        let text_generator = TextGenerator::new(sampling_config);
99        let generation_cache = GenerationCache::new();
100
101        Ok(Self {
102            model,
103            tokenizer,
104            text_generator,
105            generation_cache,
106            default_config: generation_config,
107        })
108    }
109
110    /// Generate text from a prompt using default configuration
111    pub fn generate_text(&mut self, prompt: &str) -> Result<String> {
112        let output = self.generate_text_with_config(prompt, &self.default_config.clone())?;
113        Ok(output.text)
114    }
115
116    /// Generate text from a prompt with custom configuration
117    pub fn generate_text_with_config(
118        &mut self,
119        prompt: &str,
120        config: &GenerationConfig,
121    ) -> Result<GenerationOutput> {
122        self.text_generator.generate_with_cache(
123            &mut self.model,
124            &self.tokenizer,
125            prompt,
126            config,
127            &mut self.generation_cache,
128        )
129    }
130
131    /// Generate text with streaming (returns tokens as they are generated)
132    pub fn generate_text_streaming<F>(
133        &mut self,
134        prompt: &str,
135        config: &GenerationConfig,
136        mut callback: F,
137    ) -> Result<GenerationOutput>
138    where
139        F: FnMut(&str) -> Result<bool>, // Returns false to stop generation
140    {
141        // For now, implement as non-streaming but call callback with full result
142        // TODO: Implement true streaming in future iterations
143        let output = self.generate_text_with_config(prompt, config)?;
144
145        // Call callback with the generated text
146        let should_continue = callback(&output.text)?;
147        if !should_continue {
148            return Ok(GenerationOutput::new(
149                output.text,
150                output.tokens_generated,
151                crate::inference::generation::StopReason::Error("Stopped by callback".to_string()),
152            ));
153        }
154
155        Ok(output)
156    }
157
158    /// Set generation configuration
159    pub fn set_generation_config(&mut self, config: GenerationConfig) {
160        self.default_config = config;
161    }
162
163    /// Get current generation configuration
164    pub fn generation_config(&self) -> &GenerationConfig {
165        &self.default_config
166    }
167
168    /// Clear generation cache
169    pub fn clear_cache(&mut self) {
170        self.generation_cache.clear();
171    }
172
173    /// Get tokenizer reference
174    pub fn tokenizer(&self) -> &Tokenizer {
175        &self.tokenizer
176    }
177
178    /// Generate text with reasoning awareness
179    pub fn generate_with_reasoning(&mut self, prompt: &str) -> Result<ReasoningOutput> {
180        self.text_generator.generate_with_reasoning(
181            &mut self.model,
182            &self.tokenizer,
183            prompt,
184            &self.default_config,
185        )
186    }
187
188    /// Generate text with reasoning awareness and custom config
189    pub fn generate_with_reasoning_config(
190        &mut self,
191        prompt: &str,
192        config: &GenerationConfig,
193    ) -> Result<ReasoningOutput> {
194        self.text_generator.generate_with_reasoning(
195            &mut self.model,
196            &self.tokenizer,
197            prompt,
198            config,
199        )
200    }
201
202    /// Generate text with automatic reasoning detection
203    pub fn generate_with_reasoning_detection(
204        &mut self,
205        prompt: &str,
206    ) -> Result<(GenerationOutput, Option<ReasoningOutput>)> {
207        self.text_generator.generate_with_reasoning_detection(
208            &mut self.model,
209            &self.tokenizer,
210            prompt,
211            &self.default_config,
212        )
213    }
214
215    /// Generate structured reasoning for a given prompt
216    pub fn generate_structured_reasoning(&mut self, prompt: &str) -> Result<ReasoningOutput> {
217        self.text_generator.generate_structured_reasoning(
218            &mut self.model,
219            &self.tokenizer,
220            prompt,
221            &self.default_config,
222        )
223    }
224
225    /// Solve a mathematical problem with reasoning
226    pub fn solve_math_problem(&mut self, problem: &str) -> Result<ReasoningOutput> {
227        let math_prompt = format!(
228            "Solve this mathematical problem step by step: {}\n\n<think>Let me break this down step by step and show my reasoning.</think>",
229            problem
230        );
231
232        let mut config = self.default_config.clone();
233        #[cfg(test)]
234        {
235            config.max_tokens = config.max_tokens.min(8);
236        }
237        #[cfg(not(test))]
238        {
239            config.max_tokens = 512;
240        }
241
242        self.generate_with_reasoning_config(&math_prompt, &config)
243    }
244
245    /// Solve a mathematical problem with detailed step-by-step reasoning
246    pub fn solve_math_problem_detailed(&mut self, problem: &str) -> Result<MathSolutionOutput> {
247        let reasoning_output = self.solve_math_problem(problem)?;
248
249        // Extract final answer using simple string matching
250        let final_answer = self.extract_final_answer(&reasoning_output.final_answer);
251
252        Ok(MathSolutionOutput {
253            problem: problem.to_string(),
254            reasoning_steps: reasoning_output.thinking_chain,
255            final_answer,
256            confidence: reasoning_output.confidence,
257        })
258    }
259
260    /// Explain code with reasoning
261    pub fn explain_code(&mut self, code: &str) -> Result<ReasoningOutput> {
262        let code_prompt = format!(
263            "Explain this code step by step:\n\n```\n{}\n```\n\n<think>Let me analyze this code line by line and explain what it does.</think>",
264            code
265        );
266
267        let mut config = self.default_config.clone();
268        #[cfg(test)]
269        {
270            config.max_tokens = config.max_tokens.min(8);
271        }
272        #[cfg(not(test))]
273        {
274            config.max_tokens = 512; // Allow more tokens for detailed explanations
275        }
276
277        self.generate_with_reasoning_config(&code_prompt, &config)
278    }
279
280    /// Explain code with detailed analysis
281    pub fn explain_code_detailed(
282        &mut self,
283        code: &str,
284        language: Option<&str>,
285    ) -> Result<CodeExplanationOutput> {
286        let language_hint = language
287            .map(|lang| format!(" ({})", lang))
288            .unwrap_or_default();
289        let code_prompt = format!(
290            "Analyze and explain this{} code in detail:\n\n```\n{}\n```\n\n<think>Let me break down this code step by step, explaining the purpose, logic, and any important details.</think>",
291            language_hint, code
292        );
293
294        let config = self.default_config.clone();
295        let reasoning_output = self.generate_with_reasoning_config(&code_prompt, &config)?;
296
297        // Extract code summary from final answer
298        let summary = self.extract_code_summary(&reasoning_output.final_answer);
299
300        Ok(CodeExplanationOutput {
301            original_code: code.to_string(),
302            language: language.map(|s| s.to_string()),
303            reasoning_steps: reasoning_output.thinking_chain,
304            summary,
305            confidence: reasoning_output.confidence,
306        })
307    }
308
309    /// Solve logical reasoning problems
310    pub fn solve_logical_problem(&mut self, problem: &str) -> Result<ReasoningOutput> {
311        let logic_prompt = format!(
312            "Solve this logical reasoning problem step by step: {}\n\n<think>Let me work through this logic problem systematically, considering all the given information and constraints.</think>",
313            problem
314        );
315
316        let mut config = self.default_config.clone();
317        #[cfg(test)]
318        {
319            config.max_tokens = config.max_tokens.min(8);
320        }
321        #[cfg(not(test))]
322        {
323            config.max_tokens = 512;
324        }
325
326        self.generate_with_reasoning_config(&logic_prompt, &config)
327    }
328
329    /// Solve logical reasoning problems with detailed analysis
330    pub fn solve_logical_problem_detailed(
331        &mut self,
332        problem: &str,
333    ) -> Result<LogicalSolutionOutput> {
334        let reasoning_output = self.solve_logical_problem(problem)?;
335
336        // Extract logical conclusion from final answer
337        let conclusion = self.extract_logical_conclusion(&reasoning_output.final_answer);
338
339        Ok(LogicalSolutionOutput {
340            problem: problem.to_string(),
341            reasoning_steps: reasoning_output.thinking_chain,
342            conclusion,
343            confidence: reasoning_output.confidence,
344        })
345    }
346
347    /// General problem solving with adaptive prompting
348    pub fn solve_problem(
349        &mut self,
350        problem: &str,
351        problem_type: ProblemType,
352    ) -> Result<ReasoningOutput> {
353        match problem_type {
354            ProblemType::Mathematical => self.solve_math_problem(problem),
355            ProblemType::Logical => self.solve_logical_problem(problem),
356            ProblemType::CodeAnalysis => self.explain_code(problem),
357            ProblemType::General => {
358                let general_prompt = format!(
359                    "Analyze and solve this problem: {}\n\n<think>Let me think about this problem carefully and work through it step by step.</think>",
360                    problem
361                );
362                let config = self.default_config.clone();
363                self.generate_with_reasoning_config(&general_prompt, &config)
364            }
365        }
366    }
367
368    // Helper methods for extracting structured information from reasoning
369
370    /// Extract final numerical answer from text
371    fn extract_final_answer(&self, text: &str) -> Option<String> {
372        let lower_text = text.to_lowercase();
373
374        // Look for common answer patterns
375        if let Some(pos) = lower_text.find("answer is ") {
376            let after_answer = &text[pos + 10..];
377            if let Some(number) = self.extract_first_number(after_answer) {
378                return Some(number);
379            }
380        }
381
382        if let Some(pos) = lower_text.find("result: ") {
383            let after_result = &text[pos + 8..];
384            if let Some(number) = self.extract_first_number(after_result) {
385                return Some(number);
386            }
387        }
388
389        // Look for equals sign
390        if let Some(pos) = text.rfind('=') {
391            let after_equals = &text[pos + 1..];
392            if let Some(number) = self.extract_first_number(after_equals) {
393                return Some(number);
394            }
395        }
396
397        None
398    }
399
400    /// Extract the first number from a string
401    fn extract_first_number(&self, text: &str) -> Option<String> {
402        let mut number_str = String::new();
403        let mut found_digit = false;
404
405        for ch in text.trim().chars() {
406            if ch.is_ascii_digit() || (ch == '.' && found_digit && !number_str.contains('.')) {
407                number_str.push(ch);
408                found_digit = true;
409            } else if found_digit || !ch.is_whitespace() {
410                break;
411            }
412        }
413
414        if found_digit && !number_str.is_empty() {
415            Some(number_str)
416        } else {
417            None
418        }
419    }
420
421    /// Extract code summary from final answer
422    fn extract_code_summary(&self, text: &str) -> String {
423        // Extract the first sentence or paragraph as summary
424        if let Some(period_pos) = text.find('.') {
425            text[..period_pos + 1].trim().to_string()
426        } else {
427            text.trim().to_string()
428        }
429    }
430
431    /// Extract logical conclusion from final answer
432    fn extract_logical_conclusion(&self, text: &str) -> String {
433        // Look for conclusion indicators
434        if let Some(therefore_pos) = text.to_lowercase().find("therefore") {
435            text[therefore_pos..].trim().to_string()
436        } else if let Some(conclusion_pos) = text.to_lowercase().find("conclusion") {
437            text[conclusion_pos..].trim().to_string()
438        } else {
439            text.trim().to_string()
440        }
441    }
442}
443
444#[cfg(test)]
445mod tests {
446    use super::*;
447    use crate::model::{DeepSeekR1Model, ModelConfig};
448
449    #[test]
450    fn test_inference_engine_creation() {
451        let config = ModelConfig::default();
452        let model = DeepSeekR1Model::new(config).unwrap();
453        let engine = InferenceEngine::new(model);
454        assert!(engine.is_ok());
455    }
456
457    #[test]
458    fn test_inference_engine_with_configs() {
459        let model_config = ModelConfig::default();
460        let model = DeepSeekR1Model::new(model_config).unwrap();
461
462        let tokenizer_config = TokenizerConfig::default();
463        let sampling_config = SamplingConfig::default();
464        let generation_config = GenerationConfig::default();
465
466        let engine = InferenceEngine::with_configs(
467            model,
468            tokenizer_config,
469            sampling_config,
470            generation_config,
471        );
472        assert!(engine.is_ok());
473    }
474
475    #[test]
476    fn test_generation_config_management() {
477        let config = ModelConfig::default();
478        let model = DeepSeekR1Model::new(config).unwrap();
479        let mut engine = InferenceEngine::new(model).unwrap();
480
481        // Test default config
482        let default_config = engine.generation_config();
483        assert_eq!(default_config.max_tokens, 256);
484
485        // Test setting new config
486        let mut new_config = GenerationConfig::default();
487        new_config.max_tokens = 512;
488        engine.set_generation_config(new_config);
489
490        let updated_config = engine.generation_config();
491        assert_eq!(updated_config.max_tokens, 512);
492    }
493
494    #[test]
495    fn test_cache_management() {
496        let config = ModelConfig::default();
497        let model = DeepSeekR1Model::new(config).unwrap();
498        let mut engine = InferenceEngine::new(model).unwrap();
499
500        // Should not panic
501        engine.clear_cache();
502    }
503
504    #[test]
505    fn test_tokenizer_access() {
506        let config = ModelConfig::default();
507        let model = DeepSeekR1Model::new(config).unwrap();
508        let engine = InferenceEngine::new(model).unwrap();
509
510        let tokenizer = engine.tokenizer();
511        assert!(tokenizer.vocab_size() > 0);
512    }
513
514    #[test]
515    fn test_problem_solving_methods_exist() {
516        let config = ModelConfig::default();
517        let model = DeepSeekR1Model::new(config).unwrap();
518        let engine = InferenceEngine::new(model).unwrap();
519
520        // Just verify the engine was created successfully - don't run actual inference
521        assert!(engine.tokenizer().vocab_size() > 0);
522        assert_eq!(engine.generation_config().max_tokens, 256);
523    }
524
525    #[test]
526    fn test_extract_first_number() {
527        let config = ModelConfig::default();
528        let model = DeepSeekR1Model::new(config).unwrap();
529        let engine = InferenceEngine::new(model).unwrap();
530
531        assert_eq!(engine.extract_first_number("42"), Some("42".to_string()));
532        assert_eq!(
533            engine.extract_first_number("3.14"),
534            Some("3.14".to_string())
535        );
536        assert_eq!(
537            engine.extract_first_number("  123  "),
538            Some("123".to_string())
539        );
540        assert_eq!(engine.extract_first_number("no numbers here"), None);
541    }
542
543    #[test]
544    fn test_extract_final_answer() {
545        let config = ModelConfig::default();
546        let model = DeepSeekR1Model::new(config).unwrap();
547        let engine = InferenceEngine::new(model).unwrap();
548
549        assert_eq!(
550            engine.extract_final_answer("The answer is 42"),
551            Some("42".to_string())
552        );
553        assert_eq!(
554            engine.extract_final_answer("2 + 2 = 4"),
555            Some("4".to_string())
556        );
557        assert_eq!(
558            engine.extract_final_answer("result: 3.14"),
559            Some("3.14".to_string())
560        );
561    }
562
563    // Note: Full generation testing requires a working model implementation
564    // which will be available in later tasks
565}