reasoning_parser/
factory.rs

1// Factory and registry for creating model-specific reasoning parsers.
2// Now with parser pooling support for efficient reuse across requests.
3
4use std::{
5    collections::HashMap,
6    sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12    parsers::{
13        BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, MiniMaxParser, Qwen3Parser,
14        QwenThinkingParser, Step3Parser,
15    },
16    traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
/// Shared handle to a pooled parser instance.
///
/// The boxed parser sits behind a `tokio::sync::Mutex` inside an `Arc`, so every holder of
/// a clone refers to the same parser and must `lock().await` it before use. The async
/// mutex is used so that waiting for the lock does not block the executor thread.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
22
/// Type alias for parser creator functions.
///
/// Each call produces a new boxed parser. The closure is reference-counted and
/// `Send + Sync`, so one creator can be stored in the registry and cloned out of it.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
25
/// Registry for model-specific parsers with pooling support.
///
/// All three tables live behind `Arc<RwLock<..>>`, so the derived `Clone` yields a handle
/// that shares the same creators, pool, and patterns as the value it was cloned from.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions keyed by parser name (used when the pool has no instance yet).
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances keyed by parser name, handed out for reuse.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings, kept in registration order.
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
36
37impl ParserRegistry {
38    /// Create a new empty registry.
39    pub fn new() -> Self {
40        Self {
41            creators: Arc::new(RwLock::new(HashMap::new())),
42            pool: Arc::new(RwLock::new(HashMap::new())),
43            patterns: Arc::new(RwLock::new(Vec::new())),
44        }
45    }
46
47    /// Register a parser creator for a given parser type.
48    pub fn register_parser<F>(&self, name: &str, creator: F)
49    where
50        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51    {
52        let mut creators = self.creators.write().unwrap();
53        creators.insert(name.to_string(), Arc::new(creator));
54    }
55
56    /// Register a model pattern to parser mapping.
57    /// Patterns are checked in order, first match wins.
58    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59        let mut patterns = self.patterns.write().unwrap();
60        patterns.push((pattern.to_string(), parser_name.to_string()));
61    }
62
63    /// Get a pooled parser by exact name.
64    /// Returns a shared parser instance from the pool, creating one if needed.
65    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66        // First check if we have a pooled instance
67        {
68            let pool = self.pool.read().unwrap();
69            if let Some(parser) = pool.get(name) {
70                return Some(Arc::clone(parser));
71            }
72        }
73
74        // If not in pool, create one and add to pool
75        let creators = self.creators.read().unwrap();
76        if let Some(creator) = creators.get(name) {
77            let parser = Arc::new(Mutex::new(creator()));
78
79            // Add to pool for future use
80            let mut pool = self.pool.write().unwrap();
81            pool.insert(name.to_string(), Arc::clone(&parser));
82
83            Some(parser)
84        } else {
85            None
86        }
87    }
88
89    /// Check if a parser with the given name is registered.
90    pub fn has_parser(&self, name: &str) -> bool {
91        let creators = self.creators.read().unwrap();
92        creators.contains_key(name)
93    }
94
95    /// Create a fresh parser instance by exact name (not pooled).
96    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
97    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98        let creators = self.creators.read().unwrap();
99        creators.get(name).map(|creator| creator())
100    }
101
102    /// Find a pooled parser for a given model ID by pattern matching.
103    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104        let patterns = self.patterns.read().unwrap();
105        let model_lower = model_id.to_lowercase();
106
107        for (pattern, parser_name) in patterns.iter() {
108            if model_lower.contains(&pattern.to_lowercase()) {
109                return self.get_pooled_parser(parser_name);
110            }
111        }
112        None
113    }
114
115    /// Check if a parser can be created for a specific model without actually creating it.
116    /// Returns true if a parser is available (registered) for this model.
117    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118        let patterns = self.patterns.read().unwrap();
119        let model_lower = model_id.to_lowercase();
120
121        for (pattern, parser_name) in patterns.iter() {
122            if model_lower.contains(&pattern.to_lowercase()) {
123                let creators = self.creators.read().unwrap();
124                return creators.contains_key(parser_name);
125            }
126        }
127        false
128    }
129
130    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
131    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
132    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133        let patterns = self.patterns.read().unwrap();
134        let model_lower = model_id.to_lowercase();
135
136        for (pattern, parser_name) in patterns.iter() {
137            if model_lower.contains(&pattern.to_lowercase()) {
138                return self.create_parser(parser_name);
139            }
140        }
141        None
142    }
143
144    /// Clear the parser pool, forcing new instances to be created.
145    /// Useful for testing or when parsers need to be reset globally.
146    pub fn clear_pool(&self) {
147        let mut pool = self.pool.write().unwrap();
148        pool.clear();
149    }
150}
151
152impl Default for ParserRegistry {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
/// Factory for creating reasoning parsers based on model type.
///
/// Wraps a [`ParserRegistry`]; because the registry's tables are reference-counted,
/// cloning the factory yields a handle onto the same registrations and parser pool.
#[derive(Clone)]
pub struct ParserFactory {
    /// Registry holding the parser creators, model patterns, and pooled instances.
    registry: ParserRegistry,
}
163
164impl ParserFactory {
165    /// Create a new factory with default parsers registered.
166    pub fn new() -> Self {
167        let registry = ParserRegistry::new();
168
169        // Register base parser
170        registry.register_parser("base", || {
171            Box::new(BaseReasoningParser::new(ParserConfig::default()))
172        });
173
174        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
175        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177        // Register Qwen3 parser (starts with in_reasoning=false)
178        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180        // Register Qwen3-thinking parser (starts with in_reasoning=true)
181        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
184        registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
187        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
190        registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192        // Register MiniMax parser (appends <think> token at the beginning)
193        registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195        // Register model patterns
196        registry.register_pattern("deepseek-r1", "deepseek_r1");
197        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
198        registry.register_pattern("qwen-thinking", "qwen3_thinking");
199        registry.register_pattern("qwen3", "qwen3");
200        registry.register_pattern("qwen", "qwen3");
201        registry.register_pattern("glm45", "glm45");
202        registry.register_pattern("glm47", "glm45"); // glm47 uses same reasoning format as glm45
203        registry.register_pattern("kimi", "kimi");
204        registry.register_pattern("step3", "step3");
205        registry.register_pattern("minimax", "minimax");
206        registry.register_pattern("minimax-m2", "minimax");
207        registry.register_pattern("mm-m2", "minimax");
208
209        // Nano V3 uses same format as Qwen3 (requires explicit <think> token)
210        registry.register_pattern("nemotron-nano", "qwen3");
211        registry.register_pattern("nano-v3", "qwen3");
212
213        Self { registry }
214    }
215
216    /// Get a pooled parser for the given model ID.
217    /// Returns a shared instance that can be used concurrently.
218    /// Falls back to a passthrough parser if model is not recognized.
219    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
220        // First try to find by pattern
221        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
222            return parser;
223        }
224
225        // Fall back to no-op parser (get or create passthrough in pool)
226        self.registry
227            .get_pooled_parser("passthrough")
228            .unwrap_or_else(|| {
229                // Register passthrough if not already registered
230                self.registry.register_parser("passthrough", || {
231                    let config = ParserConfig {
232                        think_start_token: "".to_string(),
233                        think_end_token: "".to_string(),
234                        stream_reasoning: true,
235                        max_buffer_size: 65536,
236                        initial_in_reasoning: false,
237                    };
238                    Box::new(
239                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
240                    )
241                });
242                self.registry.get_pooled_parser("passthrough").unwrap()
243            })
244    }
245
246    /// Create a new parser instance for the given model ID.
247    /// Returns a fresh instance (not pooled).
248    /// Use this when you need an isolated parser instance.
249    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
250        // First try to find by pattern
251        if let Some(parser) = self.registry.create_for_model(model_id) {
252            return Ok(parser);
253        }
254
255        // Fall back to no-op parser (base parser without reasoning detection)
256        let config = ParserConfig {
257            think_start_token: "".to_string(),
258            think_end_token: "".to_string(),
259            stream_reasoning: true,
260            max_buffer_size: 65536,
261            initial_in_reasoning: false,
262        };
263        Ok(Box::new(
264            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
265        ))
266    }
267
268    /// Get the internal registry for custom registration.
269    pub fn registry(&self) -> &ParserRegistry {
270        &self.registry
271    }
272
273    /// Clear the parser pool.
274    /// Useful for testing or when parsers need to be reset globally.
275    pub fn clear_pool(&self) {
276        self.registry.clear_pool();
277    }
278}
279
280impl Default for ParserFactory {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh (non-pooled) parser for `model_id` from a new factory and checks
    /// which parser type was selected.
    fn assert_creates(model_id: &str, expected_model_type: &str) {
        let parser = ParserFactory::new().create(model_id).unwrap();
        assert_eq!(parser.model_type(), expected_model_type);
    }

    #[test]
    fn test_factory_creates_deepseek_r1() {
        assert_creates("deepseek-r1-distill", "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        assert_creates("qwen3-7b", "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        assert_creates("kimi-chat", "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        assert_creates("unknown-model", "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let deepseek = factory.create("DeepSeek-R1").unwrap();
        let qwen = factory.create("QWEN3").unwrap();
        let kimi = factory.create("Kimi").unwrap();

        let selected = [
            (deepseek.model_type(), "deepseek_r1"),
            (qwen.model_type(), "qwen3"),
            (kimi.model_type(), "kimi"),
        ];
        for (actual, expected) in selected.iter() {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_step3_model() {
        assert_creates("step3-model", "step3");
    }

    #[test]
    fn test_glm45_model() {
        assert_creates("glm45-v2", "glm45");
    }

    #[test]
    fn test_minimax_model() {
        assert_creates("minimax-m2", "minimax");

        // The alternate pattern selects the same parser.
        assert_creates("mm-m2-chat", "minimax");
    }

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Two lookups for one model must hand back the very same allocation.
        let first = factory.get_pooled("deepseek-r1");
        let second = factory.get_pooled("deepseek-r1");
        assert!(Arc::ptr_eq(&first, &second));

        // A model served by a different parser gets a different pooled instance.
        let other = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&first, &other));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let shared = factory.get_pooled("deepseek-r1");

        // Several tasks take turns locking the one shared parser.
        let workers: Vec<_> = (0..3)
            .map(|i| {
                let shared = Arc::clone(&shared);
                tokio::spawn(async move {
                    let mut guard = shared.lock().await;
                    let input = format!("thread {} reasoning</think>answer", i);
                    let parsed = guard.detect_and_parse_reasoning(&input).unwrap();
                    assert_eq!(parsed.normal_text, "answer");
                    assert!(parsed.reasoning_text.contains("reasoning"));
                })
            })
            .collect();

        // A panic inside any task surfaces here as a join error.
        for worker in workers {
            worker.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let before_clear = factory.get_pooled("deepseek-r1");
        factory.clear_pool();
        let after_clear = factory.get_pooled("deepseek-r1");

        // Clearing forces a new instance, so the two handles must not alias.
        assert!(!Arc::ptr_eq(&before_clear, &after_clear));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unrecognized models all fall back to the single pooled passthrough parser.
        let first = factory.get_pooled("unknown-model-1");
        let second = factory.get_pooled("unknown-model-2");
        assert!(Arc::ptr_eq(&first, &second));

        let guard = first.lock().await;
        assert_eq!(guard.model_type(), "passthrough");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        const NUM_TASKS: usize = 100;
        const REQUESTS_PER_TASK: usize = 50;
        const MODELS: [&str; 4] = ["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let factory = ParserFactory::new();

        // Outcome counters shared by every task.
        let ok_total = Arc::new(AtomicUsize::new(0));
        let err_total = Arc::new(AtomicUsize::new(0));

        let started = Instant::now();

        let workers: Vec<_> = (0..NUM_TASKS)
            .map(|task_id| {
                let factory = factory.clone();
                let ok_total = Arc::clone(&ok_total);
                let err_total = Arc::clone(&err_total);

                tokio::spawn(async move {
                    for request_id in 0..REQUESTS_PER_TASK {
                        // Rotate through the models so tasks contend on every pooled parser.
                        let model = MODELS[(task_id + request_id) % MODELS.len()];
                        let pooled = factory.get_pooled(model);

                        // Async lock - tokio::Mutex doesn't poison.
                        let mut guard = pooled.lock().await;

                        // Realistically sized payload: typical reasoning can be 500-5000 tokens.
                        let reasoning_text = format!(
                            "Task {} is processing request {}. Let me think through this step by step. \
                            First, I need to understand the problem. The problem involves analyzing data \
                            and making calculations. Let me break this down: \n\
                            1. Initial analysis shows that we have multiple variables to consider. \
                            2. The data suggests a pattern that needs further investigation. \
                            3. Computing the values: {} * {} = {}. \
                            4. Cross-referencing with previous results indicates consistency. \
                            5. The mathematical proof follows from the axioms... \
                            6. Considering edge cases and boundary conditions... \
                            7. Validating against known constraints... \
                            8. The conclusion follows logically from premises A, B, and C. \
                            This reasoning chain demonstrates the validity of our approach.",
                            task_id,
                            request_id,
                            task_id,
                            request_id,
                            task_id * request_id
                        );

                        let answer_text = format!(
                            "Based on my analysis, the answer for task {} request {} is: \
                            The solution involves multiple steps as outlined in the reasoning. \
                            The final result is {} with confidence level high. \
                            This conclusion is supported by rigorous mathematical analysis \
                            and has been validated against multiple test cases. \
                            The implementation should handle edge cases appropriately.",
                            task_id,
                            request_id,
                            task_id * request_id
                        );

                        let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                        match guard.detect_and_parse_reasoning(&input) {
                            Ok(parsed) => {
                                // Note: Some parsers with stream_reasoning=true won't accumulate
                                // reasoning text, so only the normal text is checked unconditionally.
                                assert!(parsed.normal_text.contains(&format!("task {}", task_id)));

                                // Parsers that do accumulate reasoning must return all of it.
                                if !parsed.reasoning_text.is_empty() {
                                    assert!(parsed
                                        .reasoning_text
                                        .contains(&format!("Task {}", task_id)));
                                    assert!(parsed.reasoning_text.len() > 500);
                                }

                                // The answer portion is always expected to be substantial.
                                assert!(parsed.normal_text.len() > 100);
                                ok_total.fetch_add(1, Ordering::Relaxed);
                            }
                            Err(e) => {
                                eprintln!("Parse error: {:?}", e);
                                err_total.fetch_add(1, Ordering::Relaxed);
                            }
                        }

                        // Release the parser before the next iteration asks for one.
                        drop(guard);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.await.unwrap();
        }

        let elapsed = started.elapsed();
        let total_requests = NUM_TASKS * REQUESTS_PER_TASK;
        let successes = ok_total.load(Ordering::Relaxed);
        let errors = err_total.load(Ordering::Relaxed);
        let throughput = (total_requests as f64) / elapsed.as_secs_f64();

        // Stats for debugging.
        println!(
            "High concurrency test: {} tasks, {} requests each",
            NUM_TASKS, REQUESTS_PER_TASK
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            elapsed, successes, errors
        );
        println!("Throughput: {:.0} requests/sec", throughput);

        // Every request must have parsed successfully.
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec.
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();

        // Reader: repeatedly fetches the same pooled parser.
        let reader = {
            let factory = factory.clone();
            tokio::spawn(async move {
                for _ in 0..100 {
                    let _parser = factory.get_pooled("deepseek-r1");
                }
            })
        };

        // Clearer: periodically empties the pool underneath the readers.
        let clearer = {
            let factory = factory.clone();
            tokio::spawn(async move {
                for _ in 0..10 {
                    factory.clear_pool();
                    tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
                }
            })
        };

        // Mixed reader: cycles through several models, including an unrecognized one.
        let mixed_reader = {
            let factory = factory.clone();
            tokio::spawn(async move {
                let models = ["qwen3", "kimi", "unknown"];
                for model in models.iter().copied().cycle().take(100) {
                    let _parser = factory.get_pooled(model);
                }
            })
        };

        // All tasks must finish - no deadlock and no panic.
        for worker in vec![reader, clearer, mixed_reader] {
            worker.await.unwrap();
        }
    }
}