// reasoning_parser/factory.rs
1// Factory and registry for creating model-specific reasoning parsers.
2// Now with parser pooling support for efficient reuse across requests.
3
4use std::{collections::HashMap, sync::Arc};
5
6use parking_lot::RwLock;
7use tokio::sync::Mutex;
8
9use crate::{
10    parsers::{
11        BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
12        MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
13    },
14    traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE},
15};
16
/// Type alias for pooled parser instances.
/// Uses tokio::Mutex (not std/parking_lot) so a task can hold the parser
/// across `.await` points without blocking the async executor.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Type alias for parser creator functions.
/// Stored behind an `Arc` so the registry can hand out clones of a creator
/// without re-boxing the closure.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
23
/// Registry for model-specific parsers with pooling support.
///
/// Cloning is cheap: all state is shared through `Arc`s, so clones observe
/// the same creators, pool, and pattern table.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions for parsers (used when pool is empty)
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances for reuse
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model pattern to parser name mappings
    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
}
34
35impl ParserRegistry {
36    /// Create a new empty registry.
37    pub fn new() -> Self {
38        Self {
39            creators: Arc::new(RwLock::new(HashMap::new())),
40            pool: Arc::new(RwLock::new(HashMap::new())),
41            patterns: Arc::new(RwLock::new(Vec::new())),
42        }
43    }
44
45    /// Register a parser creator for a given parser type.
46    pub fn register_parser<F>(&self, name: &str, creator: F)
47    where
48        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
49    {
50        let mut creators = self.creators.write();
51        creators.insert(name.to_string(), Arc::new(creator));
52    }
53
54    /// Register a model pattern to parser mapping.
55    /// Patterns are checked in order, first match wins.
56    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
57        let mut patterns = self.patterns.write();
58        patterns.push((pattern.to_string(), parser_name.to_string()));
59    }
60
61    /// Get a pooled parser by exact name.
62    /// Returns a shared parser instance from the pool, creating one if needed.
63    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
64        // First check if we have a pooled instance
65        {
66            let pool = self.pool.read();
67            if let Some(parser) = pool.get(name) {
68                return Some(Arc::clone(parser));
69            }
70        }
71
72        // If not in pool, create one and add to pool
73        let creators = self.creators.read();
74        if let Some(creator) = creators.get(name) {
75            let parser = Arc::new(Mutex::new(creator()));
76
77            // Add to pool for future use
78            let mut pool = self.pool.write();
79            pool.insert(name.to_string(), Arc::clone(&parser));
80
81            Some(parser)
82        } else {
83            None
84        }
85    }
86
87    /// Check if a parser with the given name is registered.
88    pub fn has_parser(&self, name: &str) -> bool {
89        let creators = self.creators.read();
90        creators.contains_key(name)
91    }
92
93    /// Create a fresh parser instance by exact name (not pooled).
94    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
95    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
96        let creators = self.creators.read();
97        creators.get(name).map(|creator| creator())
98    }
99
100    /// Find a pooled parser for a given model ID by pattern matching.
101    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
102        let patterns = self.patterns.read();
103        let model_lower = model_id.to_lowercase();
104
105        for (pattern, parser_name) in patterns.iter() {
106            if model_lower.contains(&pattern.to_lowercase()) {
107                return self.get_pooled_parser(parser_name);
108            }
109        }
110        None
111    }
112
113    /// Check if a parser can be created for a specific model without actually creating it.
114    /// Returns true if a parser is available (registered) for this model.
115    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
116        let patterns = self.patterns.read();
117        let model_lower = model_id.to_lowercase();
118
119        for (pattern, parser_name) in patterns.iter() {
120            if model_lower.contains(&pattern.to_lowercase()) {
121                let creators = self.creators.read();
122                return creators.contains_key(parser_name);
123            }
124        }
125        false
126    }
127
128    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
129    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
130    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
131        let patterns = self.patterns.read();
132        let model_lower = model_id.to_lowercase();
133
134        for (pattern, parser_name) in patterns.iter() {
135            if model_lower.contains(&pattern.to_lowercase()) {
136                return self.create_parser(parser_name);
137            }
138        }
139        None
140    }
141
142    /// Clear the parser pool, forcing new instances to be created.
143    /// Useful for testing or when parsers need to be reset globally.
144    pub fn clear_pool(&self) {
145        let mut pool = self.pool.write();
146        pool.clear();
147    }
148}
149
150impl Default for ParserRegistry {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
/// Factory for creating reasoning parsers based on model type.
///
/// Cloning is cheap: the wrapped registry shares all of its state via `Arc`s,
/// so clones see the same creators, pool, and patterns.
#[derive(Clone)]
pub struct ParserFactory {
    /// Underlying registry holding creators, pooled instances, and model patterns.
    registry: ParserRegistry,
}
161
162impl ParserFactory {
163    /// Create a new factory with default parsers registered.
164    pub fn new() -> Self {
165        let registry = ParserRegistry::new();
166
167        // Register base parser
168        registry.register_parser("base", || {
169            Box::new(BaseReasoningParser::new(ParserConfig::default()))
170        });
171
172        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
173        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
174
175        // Register Qwen3 parser (starts with in_reasoning=false)
176        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
177
178        // Register Qwen3-thinking parser (starts with in_reasoning=true)
179        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
180
181        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
182        registry.register_parser("kimi", || Box::new(KimiParser::new()));
183
184        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
185        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
186
187        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
188        registry.register_parser("step3", || Box::new(Step3Parser::new()));
189
190        // Register MiniMax parser (appends <think> token at the beginning)
191        registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
192
193        // Register Cohere Command parser (uses <|START_THINKING|> / <|END_THINKING|>)
194        registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
195
196        // Register NanoV3 parser (same format as DeepSeek-R1)
197        registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
198
199        // Register model patterns
200        registry.register_pattern("deepseek-r1", "deepseek_r1");
201        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
202        registry.register_pattern("qwen-thinking", "qwen3_thinking");
203        registry.register_pattern("qwen3", "qwen3");
204        registry.register_pattern("qwen", "qwen3");
205        registry.register_pattern("glm45", "glm45");
206        registry.register_pattern("glm47", "glm45"); // glm47 uses same reasoning format as glm45
207        registry.register_pattern("kimi", "kimi");
208        registry.register_pattern("step3", "step3");
209        registry.register_pattern("minimax", "minimax");
210        registry.register_pattern("minimax-m2", "minimax");
211        registry.register_pattern("mm-m2", "minimax");
212
213        // Cohere Command models use <|START_THINKING|> / <|END_THINKING|>
214        registry.register_pattern("command-r", "cohere_cmd");
215        registry.register_pattern("command-a", "cohere_cmd");
216        registry.register_pattern("c4ai-command", "cohere_cmd");
217        registry.register_pattern("cohere", "cohere_cmd");
218
219        // Nano V3 / Nemotron uses same format as DeepSeek-R1 (initial_in_reasoning=true)
220        registry.register_pattern("nemotron-nano", "nano_v3");
221        registry.register_pattern("nemotron-super", "nano_v3");
222        registry.register_pattern("nano-v3", "nano_v3");
223
224        Self { registry }
225    }
226
227    /// Get a pooled parser for the given model ID.
228    /// Returns a shared instance that can be used concurrently.
229    /// Falls back to a passthrough parser if model is not recognized.
230    #[expect(
231        clippy::expect_used,
232        reason = "passthrough parser is registered on the line above; None indicates a bug in registration logic"
233    )]
234    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
235        // First try to find by pattern
236        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
237            return parser;
238        }
239
240        // Fall back to no-op parser (get or create passthrough in pool)
241        self.registry
242            .get_pooled_parser("passthrough")
243            .unwrap_or_else(|| {
244                // Register passthrough if not already registered
245                self.registry.register_parser("passthrough", || {
246                    let config = ParserConfig {
247                        think_start_token: String::new(),
248                        think_end_token: String::new(),
249                        stream_reasoning: true,
250                        max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
251                        initial_in_reasoning: false,
252                    };
253                    Box::new(
254                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
255                    )
256                });
257                self.registry
258                    .get_pooled_parser("passthrough")
259                    .expect("passthrough parser was just registered")
260            })
261    }
262
263    /// Create a new parser instance for the given model ID.
264    /// Returns a fresh instance (not pooled).
265    /// Use this when you need an isolated parser instance.
266    pub fn create(&self, model_id: &str) -> Box<dyn ReasoningParser> {
267        // First try to find by pattern
268        if let Some(parser) = self.registry.create_for_model(model_id) {
269            return parser;
270        }
271
272        // Fall back to no-op parser (base parser without reasoning detection)
273        let config = ParserConfig {
274            think_start_token: String::new(),
275            think_end_token: String::new(),
276            stream_reasoning: true,
277            max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
278            initial_in_reasoning: false,
279        };
280        Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
281    }
282
283    /// Get the internal registry for custom registration.
284    pub fn registry(&self) -> &ParserRegistry {
285        &self.registry
286    }
287
288    /// Clear the parser pool.
289    /// Useful for testing or when parsers need to be reset globally.
290    pub fn clear_pool(&self) {
291        self.registry.clear_pool();
292    }
293}
294
295impl Default for ParserFactory {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use super::*;

    // --- Pattern-resolution tests: a model ID containing a registered
    // --- substring pattern must resolve to the expected parser type.

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill");
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b");
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat");
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unrecognized model IDs must fall back to the no-op passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model");
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model ID before comparing.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1");
        let parser2 = factory.create("QWEN3");
        let parser3 = factory.create("Kimi");

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model");
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2");
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2");
        assert_eq!(minimax.model_type(), "minimax");

        // Also test alternate patterns
        let mm = factory.create("mm-m2-chat");
        assert_eq!(mm.model_type(), "minimax");
    }

    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat");
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b");
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super");
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        // Test various Cohere model patterns
        let command_r = factory.create("command-r-plus");
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025");
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01");
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed");
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- Pooling tests: get_pooled must hand out a single shared Arc per name.

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {i} reasoning</think>answer");
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks share pooled parsers concurrently and parse
    // realistically sized payloads; also checks rough throughput.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let product = task_id * request_id;
                    let reasoning_text = format!(
                        "Task {task_id} is processing request {request_id}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {task_id} * {request_id} = {product}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {task_id} request {request_id} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {product} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                    );

                    let input = format!("<think>{reasoning_text}</think>{answer_text}");

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {task_id}")));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result.reasoning_text.contains(&format!("Task {task_id}")));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            #[expect(clippy::print_stderr, reason = "test diagnostic output")]
                            {
                                eprintln!("Parse error: {e:?}");
                            }
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        // Print stats for debugging
        #[expect(clippy::print_stdout, reason = "test diagnostic output")]
        {
            println!("High concurrency test: {num_tasks} tasks, {requests_per_task} requests each");
            println!("Completed in {duration:?}, {successes} successes, {errors} errors");
            println!(
                "Throughput: {:.0} requests/sec",
                (total_requests as f64) / duration.as_secs_f64()
            );
        }

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Performance check: should handle at least 1000 req/sec
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {throughput:.0} req/sec",
        );
    }

    // Interleaves pool reads, pool clears, and mixed-model lookups to ensure
    // the registry's locking never deadlocks or panics under contention.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}