Skip to main content

reasoning_parser/
factory.rs

// Factory and registry for creating model-specific reasoning parsers,
// with parser pooling so instances can be reused efficiently across requests.
3
4use std::{
5    collections::HashMap,
6    sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12    parsers::{
13        BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
14        MiniMaxParser, Qwen3Parser, QwenThinkingParser, Step3Parser,
15    },
16    traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
19/// Type alias for pooled parser instances.
20/// Uses tokio::Mutex to avoid blocking the async executor.
21pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;
22
23/// Type alias for parser creator functions.
24type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
25
26/// Registry for model-specific parsers with pooling support.
27#[derive(Clone)]
28pub struct ParserRegistry {
29    /// Creator functions for parsers (used when pool is empty)
30    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
31    /// Pooled parser instances for reuse
32    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
33    /// Model pattern to parser name mappings
34    patterns: Arc<RwLock<Vec<(String, String)>>>, // (pattern, parser_name)
35}
36
37impl ParserRegistry {
38    /// Create a new empty registry.
39    pub fn new() -> Self {
40        Self {
41            creators: Arc::new(RwLock::new(HashMap::new())),
42            pool: Arc::new(RwLock::new(HashMap::new())),
43            patterns: Arc::new(RwLock::new(Vec::new())),
44        }
45    }
46
47    /// Register a parser creator for a given parser type.
48    pub fn register_parser<F>(&self, name: &str, creator: F)
49    where
50        F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51    {
52        let mut creators = self.creators.write().unwrap();
53        creators.insert(name.to_string(), Arc::new(creator));
54    }
55
56    /// Register a model pattern to parser mapping.
57    /// Patterns are checked in order, first match wins.
58    pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59        let mut patterns = self.patterns.write().unwrap();
60        patterns.push((pattern.to_string(), parser_name.to_string()));
61    }
62
63    /// Get a pooled parser by exact name.
64    /// Returns a shared parser instance from the pool, creating one if needed.
65    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66        // First check if we have a pooled instance
67        {
68            let pool = self.pool.read().unwrap();
69            if let Some(parser) = pool.get(name) {
70                return Some(Arc::clone(parser));
71            }
72        }
73
74        // If not in pool, create one and add to pool
75        let creators = self.creators.read().unwrap();
76        if let Some(creator) = creators.get(name) {
77            let parser = Arc::new(Mutex::new(creator()));
78
79            // Add to pool for future use
80            let mut pool = self.pool.write().unwrap();
81            pool.insert(name.to_string(), Arc::clone(&parser));
82
83            Some(parser)
84        } else {
85            None
86        }
87    }
88
89    /// Check if a parser with the given name is registered.
90    pub fn has_parser(&self, name: &str) -> bool {
91        let creators = self.creators.read().unwrap();
92        creators.contains_key(name)
93    }
94
95    /// Create a fresh parser instance by exact name (not pooled).
96    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
97    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98        let creators = self.creators.read().unwrap();
99        creators.get(name).map(|creator| creator())
100    }
101
102    /// Find a pooled parser for a given model ID by pattern matching.
103    pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104        let patterns = self.patterns.read().unwrap();
105        let model_lower = model_id.to_lowercase();
106
107        for (pattern, parser_name) in patterns.iter() {
108            if model_lower.contains(&pattern.to_lowercase()) {
109                return self.get_pooled_parser(parser_name);
110            }
111        }
112        None
113    }
114
115    /// Check if a parser can be created for a specific model without actually creating it.
116    /// Returns true if a parser is available (registered) for this model.
117    pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118        let patterns = self.patterns.read().unwrap();
119        let model_lower = model_id.to_lowercase();
120
121        for (pattern, parser_name) in patterns.iter() {
122            if model_lower.contains(&pattern.to_lowercase()) {
123                let creators = self.creators.read().unwrap();
124                return creators.contains_key(parser_name);
125            }
126        }
127        false
128    }
129
130    /// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
131    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
132    pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133        let patterns = self.patterns.read().unwrap();
134        let model_lower = model_id.to_lowercase();
135
136        for (pattern, parser_name) in patterns.iter() {
137            if model_lower.contains(&pattern.to_lowercase()) {
138                return self.create_parser(parser_name);
139            }
140        }
141        None
142    }
143
144    /// Clear the parser pool, forcing new instances to be created.
145    /// Useful for testing or when parsers need to be reset globally.
146    pub fn clear_pool(&self) {
147        let mut pool = self.pool.write().unwrap();
148        pool.clear();
149    }
150}
151
152impl Default for ParserRegistry {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158/// Factory for creating reasoning parsers based on model type.
159#[derive(Clone)]
160pub struct ParserFactory {
161    registry: ParserRegistry,
162}
163
164impl ParserFactory {
165    /// Create a new factory with default parsers registered.
166    pub fn new() -> Self {
167        let registry = ParserRegistry::new();
168
169        // Register base parser
170        registry.register_parser("base", || {
171            Box::new(BaseReasoningParser::new(ParserConfig::default()))
172        });
173
174        // Register DeepSeek-R1 parser (starts with in_reasoning=true)
175        registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177        // Register Qwen3 parser (starts with in_reasoning=false)
178        registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180        // Register Qwen3-thinking parser (starts with in_reasoning=true)
181        registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183        // Register Kimi parser with Unicode tokens (starts with in_reasoning=false)
184        registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186        // Register GLM45 parser (same format as Qwen3 but separate for debugging)
187        registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189        // Register Step3 parser (same format as DeepSeek-R1 but separate for debugging)
190        registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192        // Register MiniMax parser (appends <think> token at the beginning)
193        registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195        // Register Cohere Command parser (uses <|START_THINKING|> / <|END_THINKING|>)
196        registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
197
198        // Register model patterns
199        registry.register_pattern("deepseek-r1", "deepseek_r1");
200        registry.register_pattern("qwen3-thinking", "qwen3_thinking");
201        registry.register_pattern("qwen-thinking", "qwen3_thinking");
202        registry.register_pattern("qwen3", "qwen3");
203        registry.register_pattern("qwen", "qwen3");
204        registry.register_pattern("glm45", "glm45");
205        registry.register_pattern("glm47", "glm45"); // glm47 uses same reasoning format as glm45
206        registry.register_pattern("kimi", "kimi");
207        registry.register_pattern("step3", "step3");
208        registry.register_pattern("minimax", "minimax");
209        registry.register_pattern("minimax-m2", "minimax");
210        registry.register_pattern("mm-m2", "minimax");
211
212        // Cohere Command models use <|START_THINKING|> / <|END_THINKING|>
213        registry.register_pattern("command-r", "cohere_cmd");
214        registry.register_pattern("command-a", "cohere_cmd");
215        registry.register_pattern("c4ai-command", "cohere_cmd");
216        registry.register_pattern("cohere", "cohere_cmd");
217
218        // Nano V3 uses same format as Qwen3 (requires explicit <think> token)
219        registry.register_pattern("nemotron-nano", "qwen3");
220        registry.register_pattern("nano-v3", "qwen3");
221
222        Self { registry }
223    }
224
225    /// Get a pooled parser for the given model ID.
226    /// Returns a shared instance that can be used concurrently.
227    /// Falls back to a passthrough parser if model is not recognized.
228    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
229        // First try to find by pattern
230        if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
231            return parser;
232        }
233
234        // Fall back to no-op parser (get or create passthrough in pool)
235        self.registry
236            .get_pooled_parser("passthrough")
237            .unwrap_or_else(|| {
238                // Register passthrough if not already registered
239                self.registry.register_parser("passthrough", || {
240                    let config = ParserConfig {
241                        think_start_token: "".to_string(),
242                        think_end_token: "".to_string(),
243                        stream_reasoning: true,
244                        max_buffer_size: 65536,
245                        initial_in_reasoning: false,
246                    };
247                    Box::new(
248                        BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
249                    )
250                });
251                self.registry.get_pooled_parser("passthrough").unwrap()
252            })
253    }
254
255    /// Create a new parser instance for the given model ID.
256    /// Returns a fresh instance (not pooled).
257    /// Use this when you need an isolated parser instance.
258    pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
259        // First try to find by pattern
260        if let Some(parser) = self.registry.create_for_model(model_id) {
261            return Ok(parser);
262        }
263
264        // Fall back to no-op parser (base parser without reasoning detection)
265        let config = ParserConfig {
266            think_start_token: "".to_string(),
267            think_end_token: "".to_string(),
268            stream_reasoning: true,
269            max_buffer_size: 65536,
270            initial_in_reasoning: false,
271        };
272        Ok(Box::new(
273            BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
274        ))
275    }
276
277    /// Get the internal registry for custom registration.
278    pub fn registry(&self) -> &ParserRegistry {
279        &self.registry
280    }
281
282    /// Clear the parser pool.
283    /// Useful for testing or when parsers need to be reset globally.
284    pub fn clear_pool(&self) {
285        self.registry.clear_pool();
286    }
287}
288
289impl Default for ParserFactory {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2").unwrap();
        assert_eq!(minimax.model_type(), "minimax");

        // Also test alternate patterns
        let mm = factory.create("mm-m2-chat").unwrap();
        assert_eq!(mm.model_type(), "minimax");
    }

    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        // Test various Cohere model patterns
        let command_r = factory.create("command-r-plus").unwrap();
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025").unwrap();
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01").unwrap();
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed").unwrap();
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        // Get the same parser twice - should be the same instance
        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Both should point to the same Arc
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // Different models should get different parsers
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        // Spawn multiple async tasks that use the same parser
        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        // Get a pooled parser
        let parser1 = factory.get_pooled("deepseek-r1");

        // Clear the pool
        factory.clear_pool();

        // Get another parser - should be a new instance
        let parser2 = factory.get_pooled("deepseek-r1");

        // They should be different instances (different Arc pointers)
        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Unknown models should get passthrough parser
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        // Both should use the same passthrough parser instance
        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        // Track successful operations
        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through different models
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    // Use async lock - tokio::Mutex doesn't poison
                    let mut p = parser.lock().await;

                    // Simulate realistic parsing work with substantial text
                    // Typical reasoning can be 500-5000 tokens
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {} * {} = {}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // Note: Some parsers with stream_reasoning=true won't accumulate reasoning text
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // For parsers that accumulate reasoning (stream_reasoning=false)
                            // the reasoning_text should be populated
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500); // Ensure substantial reasoning
                            }

                            // Normal text should always be present
                            assert!(result.normal_text.len() > 100); // Ensure substantial answer
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Explicitly drop the lock to release it quickly
                    drop(p);
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);
        let throughput = (total_requests as f64) / duration.as_secs_f64();

        // Print stats for debugging
        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!("Throughput: {:.0} requests/sec", throughput);

        // All requests should succeed
        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Throughput is reported above but deliberately not asserted: a
        // wall-clock threshold depends on the machine and its load, and would
        // make this correctness test flaky on slow or busy CI runners.
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: Continuously get parsers
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: Continuously clear pool
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: Get different parsers
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        // Wait for all tasks - should not deadlock or panic
        for handle in handles {
            handle.await.unwrap();
        }
    }
}