// reasoning_parser/factory.rs
1// Factory and registry for creating model-specific reasoning parsers.
2// Now with parser pooling support for efficient reuse across requests.
3
4use std::{
5    collections::HashMap,
6    sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12    parsers::{
13        BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
14        MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
15    },
16    traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
/// Type alias for pooled parser instances.
/// Uses `tokio::sync::Mutex` (not `std::sync::Mutex`) so holding the guard
/// across an `.await` point does not block the async executor.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
/// `Send + Sync` is required so creators can be invoked from any thread/task.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
25
/// Registry for model-specific parsers with pooling support.
///
/// Cloning is cheap: all fields are `Arc`s, so clones share the same
/// underlying creators, pool, and pattern list.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings, checked in insertion order
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
36
37impl ParserRegistry {
38    /// Create a new empty registry.
39    pub fn new() -> Self {
40        Self {
41            creators: Arc::new(RwLock::new(HashMap::new())),
42            pool: Arc::new(RwLock::new(HashMap::new())),
43            patterns: Arc::new(RwLock::new(Vec::new())),
44        }
45    }
46
47    /// Register a parser creator for a given parser type.
48    pub fn register_parser<F>(&self, name: &str, creator: F)
49    where
50        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51    {
52        let mut creators = self.creators.write().unwrap();
53        creators.insert(name.to_string(), Arc::new(creator));
54    }
55
56    /// Register a model pattern to parser mapping.
57    /// Patterns are checked in order, first match wins.
58    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59        let mut patterns = self.patterns.write().unwrap();
60        patterns.push((pattern.to_string(), parser_name.to_string()));
61    }
62
63    /// Get a pooled parser by exact name.
64    /// Returns a shared parser instance from the pool, creating one if needed.
65    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66        // First check if we have a pooled instance
67        {
68            let pool = self.pool.read().unwrap();
69            if let Some(parser) = pool.get(name) {
70                return Some(Arc::clone(parser));
71            }
72        }
73
74        // If not in pool, create one and add to pool
75        let creators = self.creators.read().unwrap();
76        if let Some(creator) = creators.get(name) {
77            let parser = Arc::new(Mutex::new(creator()));
78
79            // Add to pool for future use
80            let mut pool = self.pool.write().unwrap();
81            pool.insert(name.to_string(), Arc::clone(&parser));
82
83            Some(parser)
84        } else {
85            None
86        }
87    }
88
89    /// Check if a parser with the given name is registered.
90    pub fn has_parser(&self, name: &str) -> bool {
91        let creators = self.creators.read().unwrap();
92        creators.contains_key(name)
93    }
94
95    /// Create a fresh parser instance by exact name (not pooled).
96    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
97    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98        let creators = self.creators.read().unwrap();
99        creators.get(name).map(|creator| creator())
100    }
101
102    /// Find a pooled parser for a given model ID by pattern matching.
103    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104        let patterns = self.patterns.read().unwrap();
105        let model_lower = model_id.to_lowercase();
106
107        for (pattern, parser_name) in patterns.iter() {
108            if model_lower.contains(&pattern.to_lowercase()) {
109                return self.get_pooled_parser(parser_name);
110            }
111        }
112        None
113    }
114
115    /// Check if a parser can be created for a specific model without actually creating it.
116    /// Returns true if a parser is available (registered) for this model.
117    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118        let patterns = self.patterns.read().unwrap();
119        let model_lower = model_id.to_lowercase();
120
121        for (pattern, parser_name) in patterns.iter() {
122            if model_lower.contains(&pattern.to_lowercase()) {
123                let creators = self.creators.read().unwrap();
124                return creators.contains_key(parser_name);
125            }
126        }
127        false
128    }
129
130    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
131    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
132    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133        let patterns = self.patterns.read().unwrap();
134        let model_lower = model_id.to_lowercase();
135
136        for (pattern, parser_name) in patterns.iter() {
137            if model_lower.contains(&pattern.to_lowercase()) {
138                return self.create_parser(parser_name);
139            }
140        }
141        None
142    }
143
144    /// Clear the parser pool, forcing new instances to be created.
145    /// Useful for testing or when parsers need to be reset globally.
146    pub fn clear_pool(&self) {
147        let mut pool = self.pool.write().unwrap();
148        pool.clear();
149    }
150}
151
152impl Default for ParserRegistry {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
/// Factory for creating reasoning parsers based on model type.
///
/// Thin wrapper around [`ParserRegistry`]; cloning shares the same registry.
#[derive(Clone)]
pub struct ParserFactory {
    // Shared registry holding creators, pooled instances, and model patterns.
    registry: ParserRegistry,
}
163
164impl ParserFactory {
165    /// Create a new factory with default parsers registered.
166    pub fn new() -> Self {
167        let registry = ParserRegistry::new();
168
169        // Register base parser
170        registry.register_parser("base", || {
171            Box::new(BaseReasoningParser::new(ParserConfig::default()))
172        });
173
174        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
175        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177        // Register Qwen3 parser (starts with in_reasoning=false)
178        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180        // Register Qwen3-thinking parser (starts with in_reasoning=true)
181        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
184        registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
187        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
190        registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192        // Register MiniMax parser (appends <think> token at the beginning)
193        registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195        // Register Cohere Command parser (uses <|START_THINKING|> / <|END_THINKING|>)
196        registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
197
198        // Register NanoV3 parser (same format as DeepSeek-R1)
199        registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
200
201        // Register model patterns
202        registry.register_pattern("deepseek-r1", "deepseek_r1");
203        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
204        registry.register_pattern("qwen-thinking", "qwen3_thinking");
205        registry.register_pattern("qwen3", "qwen3");
206        registry.register_pattern("qwen", "qwen3");
207        registry.register_pattern("glm45", "glm45");
208        registry.register_pattern("glm47", "glm45"); // glm47 uses same reasoning format as glm45
209        registry.register_pattern("kimi", "kimi");
210        registry.register_pattern("step3", "step3");
211        registry.register_pattern("minimax", "minimax");
212        registry.register_pattern("minimax-m2", "minimax");
213        registry.register_pattern("mm-m2", "minimax");
214
215        // Cohere Command models use <|START_THINKING|> / <|END_THINKING|>
216        registry.register_pattern("command-r", "cohere_cmd");
217        registry.register_pattern("command-a", "cohere_cmd");
218        registry.register_pattern("c4ai-command", "cohere_cmd");
219        registry.register_pattern("cohere", "cohere_cmd");
220
221        // Nano V3 / Nemotron uses same format as DeepSeek-R1 (initial_in_reasoning=true)
222        registry.register_pattern("nemotron-nano", "nano_v3");
223        registry.register_pattern("nemotron-super", "nano_v3");
224        registry.register_pattern("nano-v3", "nano_v3");
225
226        Self { registry }
227    }
228
229    /// Get a pooled parser for the given model ID.
230    /// Returns a shared instance that can be used concurrently.
231    /// Falls back to a passthrough parser if model is not recognized.
232    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
233        // First try to find by pattern
234        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
235            return parser;
236        }
237
238        // Fall back to no-op parser (get or create passthrough in pool)
239        self.registry
240            .get_pooled_parser("passthrough")
241            .unwrap_or_else(|| {
242                // Register passthrough if not already registered
243                self.registry.register_parser("passthrough", || {
244                    let config = ParserConfig {
245                        think_start_token: "".to_string(),
246                        think_end_token: "".to_string(),
247                        stream_reasoning: true,
248                        max_buffer_size: 65536,
249                        initial_in_reasoning: false,
250                    };
251                    Box::new(
252                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
253                    )
254                });
255                self.registry.get_pooled_parser("passthrough").unwrap()
256            })
257    }
258
259    /// Create a new parser instance for the given model ID.
260    /// Returns a fresh instance (not pooled).
261    /// Use this when you need an isolated parser instance.
262    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
263        // First try to find by pattern
264        if let Some(parser) = self.registry.create_for_model(model_id) {
265            return Ok(parser);
266        }
267
268        // Fall back to no-op parser (base parser without reasoning detection)
269        let config = ParserConfig {
270            think_start_token: "".to_string(),
271            think_end_token: "".to_string(),
272            stream_reasoning: true,
273            max_buffer_size: 65536,
274            initial_in_reasoning: false,
275        };
276        Ok(Box::new(
277            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
278        ))
279    }
280
281    /// Get the internal registry for custom registration.
282    pub fn registry(&self) -> &ParserRegistry {
283        &self.registry
284    }
285
286    /// Clear the parser pool.
287    /// Useful for testing or when parsers need to be reset globally.
288    pub fn clear_pool(&self) {
289        self.registry.clear_pool();
290    }
291}
292
293impl Default for ParserFactory {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    // --- Pattern-matching tests: model id -> parser type -------------------

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        // Unknown model ids must not error; they get the no-op parser.
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        // Pattern matching lowercases the model id before substring search.
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2").unwrap();
        assert_eq!(minimax.model_type(), "minimax");

        // Also test alternate patterns
        let mm = factory.create("mm-m2-chat").unwrap();
        assert_eq!(mm.model_type(), "minimax");
    }

    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat").unwrap();
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b").unwrap();
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super").unwrap();
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        // Test various Cohere model patterns
        let command_r = factory.create("command-r-plus").unwrap();
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025").unwrap();
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01").unwrap();
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed").unwrap();
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- Pooling tests: shared instances, clearing, concurrency ------------

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                // tokio::Mutex: lock().await yields instead of blocking the executor.
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammering the pool across several models.
    // Uses a multi-threaded runtime so lock contention is real.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        // NOTE(review): wall-clock throughput floor — may be flaky on slow CI runners.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    // Race getters against clear_pool() to verify no deadlock/panic when the
    // pool is concurrently read, written, and cleared.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}