// reasoning_parser/factory.rs
1// Factory and registry for creating model-specific reasoning parsers.
2// Now with parser pooling support for efficient reuse across requests.
3
4use std::{collections::HashMap, sync::Arc};
5
6use parking_lot::RwLock;
7use tokio::sync::Mutex;
8
9use crate::{
10    parsers::{
11        BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
12        MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
13    },
14    traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE},
15};
16
/// Type alias for pooled parser instances.
/// Uses `tokio::sync::Mutex` (not a blocking std/parking_lot mutex) so a
/// task holding the parser across `.await` points does not stall the
/// async executor's worker threads.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
/// `Send + Sync` so a creator can be invoked from any thread that holds
/// a clone of the registry.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
23
/// Registry for model-specific parsers with pooling support.
///
/// Cloning is cheap and intentional: all fields are `Arc`-wrapped, so every
/// clone shares the same creator table, parser pool, and pattern list.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings, checked in insertion order
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
34
35impl ParserRegistry {
36    /// Create a new empty registry.
37    pub fn new() -> Self {
38        Self {
39            creators: Arc::new(RwLock::new(HashMap::new())),
40            pool: Arc::new(RwLock::new(HashMap::new())),
41            patterns: Arc::new(RwLock::new(Vec::new())),
42        }
43    }
44
45    /// Register a parser creator for a given parser type.
46    pub fn register_parser<F>(&self, name: &str, creator: F)
47    where
48        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
49    {
50        let mut creators = self.creators.write();
51        creators.insert(name.to_string(), Arc::new(creator));
52    }
53
54    /// Register a model pattern to parser mapping.
55    /// Patterns are checked in order, first match wins.
56    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
57        let mut patterns = self.patterns.write();
58        patterns.push((pattern.to_string(), parser_name.to_string()));
59    }
60
61    /// Get a pooled parser by exact name.
62    /// Returns a shared parser instance from the pool, creating one if needed.
63    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
64        // First check if we have a pooled instance
65        {
66            let pool = self.pool.read();
67            if let Some(parser) = pool.get(name) {
68                return Some(Arc::clone(parser));
69            }
70        }
71
72        // If not in pool, create one and add to pool
73        let creators = self.creators.read();
74        if let Some(creator) = creators.get(name) {
75            let parser = Arc::new(Mutex::new(creator()));
76
77            // Add to pool for future use
78            let mut pool = self.pool.write();
79            pool.insert(name.to_string(), Arc::clone(&parser));
80
81            Some(parser)
82        } else {
83            None
84        }
85    }
86
87    /// Check if a parser with the given name is registered.
88    pub fn has_parser(&self, name: &str) -> bool {
89        let creators = self.creators.read();
90        creators.contains_key(name)
91    }
92
93    /// Create a fresh parser instance by exact name (not pooled).
94    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
95    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
96        let creators = self.creators.read();
97        creators.get(name).map(|creator| creator())
98    }
99
100    /// Find a pooled parser for a given model ID by pattern matching.
101    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
102        let patterns = self.patterns.read();
103        let model_lower = model_id.to_lowercase();
104
105        for (pattern, parser_name) in patterns.iter() {
106            if model_lower.contains(&pattern.to_lowercase()) {
107                return self.get_pooled_parser(parser_name);
108            }
109        }
110        None
111    }
112
113    /// Check if a parser can be created for a specific model without actually creating it.
114    /// Returns true if a parser is available (registered) for this model.
115    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
116        let patterns = self.patterns.read();
117        let model_lower = model_id.to_lowercase();
118
119        for (pattern, parser_name) in patterns.iter() {
120            if model_lower.contains(&pattern.to_lowercase()) {
121                let creators = self.creators.read();
122                return creators.contains_key(parser_name);
123            }
124        }
125        false
126    }
127
128    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
129    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
130    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
131        let patterns = self.patterns.read();
132        let model_lower = model_id.to_lowercase();
133
134        for (pattern, parser_name) in patterns.iter() {
135            if model_lower.contains(&pattern.to_lowercase()) {
136                return self.create_parser(parser_name);
137            }
138        }
139        None
140    }
141
142    /// List all registered parser names in sorted order.
143    pub fn list_parsers(&self) -> Vec<String> {
144        let mut parsers: Vec<_> = self.creators.read().keys().cloned().collect();
145        parsers.sort_unstable();
146        parsers
147    }
148
149    /// Clear the parser pool, forcing new instances to be created.
150    /// Useful for testing or when parsers need to be reset globally.
151    pub fn clear_pool(&self) {
152        let mut pool = self.pool.write();
153        pool.clear();
154    }
155}
156
157impl Default for ParserRegistry {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
/// Factory for creating reasoning parsers based on model type.
///
/// Thin wrapper over a [`ParserRegistry`] pre-populated with the built-in
/// parsers and model-name patterns. Cloning shares the same registry
/// (and therefore the same parser pool).
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}
168
169impl ParserFactory {
170    /// Create a new factory with default parsers registered.
171    pub fn new() -> Self {
172        let registry = ParserRegistry::new();
173
174        // Register base parser
175        registry.register_parser("base", || {
176            Box::new(BaseReasoningParser::new(ParserConfig::default()))
177        });
178
179        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
180        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
181
182        // Register Qwen3 parser (starts with in_reasoning=false)
183        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
184
185        // Register Qwen3-thinking parser (starts with in_reasoning=true)
186        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
187
188        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
189        registry.register_parser("kimi", || Box::new(KimiParser::new()));
190
191        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
192        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
193
194        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
195        registry.register_parser("step3", || Box::new(Step3Parser::new()));
196
197        // Register MiniMax parser (appends <think> token at the beginning)
198        registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
199
200        // Register Cohere Command parser (uses <|START_THINKING|> / <|END_THINKING|>)
201        registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
202
203        // Register NanoV3 parser (same format as DeepSeek-R1)
204        registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
205
206        // Register model patterns
207        registry.register_pattern("deepseek-r1", "deepseek_r1");
208        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
209        registry.register_pattern("qwen-thinking", "qwen3_thinking");
210        registry.register_pattern("qwen3", "qwen3");
211        registry.register_pattern("qwen", "qwen3");
212        registry.register_pattern("glm45", "glm45");
213        registry.register_pattern("glm47", "glm45"); // glm47 uses same reasoning format as glm45
214        registry.register_pattern("kimi", "kimi");
215        registry.register_pattern("step3", "step3");
216        registry.register_pattern("minimax", "minimax");
217        registry.register_pattern("minimax-m2", "minimax");
218        registry.register_pattern("mm-m2", "minimax");
219
220        // Cohere Command models use <|START_THINKING|> / <|END_THINKING|>
221        registry.register_pattern("command-r", "cohere_cmd");
222        registry.register_pattern("command-a", "cohere_cmd");
223        registry.register_pattern("c4ai-command", "cohere_cmd");
224        registry.register_pattern("cohere", "cohere_cmd");
225
226        // Nano V3 / Nemotron uses same format as DeepSeek-R1 (initial_in_reasoning=true)
227        registry.register_pattern("nemotron-nano", "nano_v3");
228        registry.register_pattern("nemotron-super", "nano_v3");
229        registry.register_pattern("nano-v3", "nano_v3");
230
231        Self { registry }
232    }
233
234    /// Get a pooled parser for the given model ID.
235    /// Returns a shared instance that can be used concurrently.
236    /// Falls back to a passthrough parser if model is not recognized.
237    #[expect(
238        clippy::expect_used,
239        reason = "passthrough parser is registered on the line above; None indicates a bug in registration logic"
240    )]
241    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
242        // First try to find by pattern
243        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
244            return parser;
245        }
246
247        // Fall back to no-op parser (get or create passthrough in pool)
248        self.registry
249            .get_pooled_parser("passthrough")
250            .unwrap_or_else(|| {
251                // Register passthrough if not already registered
252                self.registry.register_parser("passthrough", || {
253                    let config = ParserConfig {
254                        think_start_token: String::new(),
255                        think_end_token: String::new(),
256                        stream_reasoning: true,
257                        max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
258                        initial_in_reasoning: false,
259                    };
260                    Box::new(
261                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
262                    )
263                });
264                self.registry
265                    .get_pooled_parser("passthrough")
266                    .expect("passthrough parser was just registered")
267            })
268    }
269
270    /// Create a new parser instance for the given model ID.
271    /// Returns a fresh instance (not pooled).
272    /// Use this when you need an isolated parser instance.
273    pub fn create(&self, model_id: &str) -> Box<dyn ReasoningParser> {
274        // First try to find by pattern
275        if let Some(parser) = self.registry.create_for_model(model_id) {
276            return parser;
277        }
278
279        // Fall back to no-op parser (base parser without reasoning detection)
280        let config = ParserConfig {
281            think_start_token: String::new(),
282            think_end_token: String::new(),
283            stream_reasoning: true,
284            max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
285            initial_in_reasoning: false,
286        };
287        Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
288    }
289
290    /// Get the internal registry for custom registration.
291    pub fn registry(&self) -> &ParserRegistry {
292        &self.registry
293    }
294
295    /// List all registered parser names.
296    pub fn list_parsers(&self) -> Vec<String> {
297        self.registry.list_parsers()
298    }
299
300    /// Clear the parser pool.
301    /// Useful for testing or when parsers need to be reset globally.
302    pub fn clear_pool(&self) {
303        self.registry.clear_pool();
304    }
305}
306
307impl Default for ParserFactory {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use super::*;

    // --- Pattern-matching: model ID -> expected parser model_type ---

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill");
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b");
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat");
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unrecognized model IDs must fall back to the no-op passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model");
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model ID, so casing must not matter.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1");
        let parser2 = factory.create("QWEN3");
        let parser3 = factory.create("Kimi");

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model");
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2");
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2");
        assert_eq!(minimax.model_type(), "minimax");

        // Also test alternate patterns
        let mm = factory.create("mm-m2-chat");
        assert_eq!(mm.model_type(), "minimax");
    }

    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat");
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b");
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        // Org-prefixed IDs match too: lookup is a substring match.
        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super");
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        // Test various Cohere model patterns
        let command_r = factory.create("command-r-plus");
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025");
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01");
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed");
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- Pooling semantics ---

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                // No opening tag: DeepSeek-R1 starts with in_reasoning=true,
                // so text before </think> counts as reasoning.
                let input = format!("thread {i} reasoning</think>answer");
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammering the shared pool across several
    // models, with basic throughput accounting.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let product = task_id * request_id;
                    let reasoning_text = format!(
                        "Task {task_id} is processing request {request_id}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {task_id} * {request_id} = {product}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {task_id} request {request_id} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {product} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                    );

                    let input = format!("<think>{reasoning_text}</think>{answer_text}");

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {task_id}")));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result.reasoning_text.contains(&format!("Task {task_id}")));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            #[expect(clippy::print_stderr, reason = "test diagnostic output")]
                            {
                                eprintln!("Parse error: {e:?}");
                            }
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        #[expect(clippy::print_stdout, reason = "test diagnostic output")]
        {
            println!("High concurrency test: {num_tasks} tasks, {requests_per_task} requests each");
            println!("Completed in {duration:?}, {successes} successes, {errors} errors");
            println!(
                "Throughput: {:.0} requests/sec",
                (total_requests as f64) / duration.as_secs_f64()
            );
        }

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        // NOTE(review): wall-clock threshold — may be flaky on loaded CI.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {throughput:.0} req/sec",
        );
    }

    // Exercise get_pooled / clear_pool races: must complete without
    // deadlock or panic (no correctness assertions on the results).
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}