1use std::{collections::HashMap, sync::Arc};
5
6use parking_lot::RwLock;
7use tokio::sync::Mutex;
8
9use crate::{
10 parsers::{
11 BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
12 MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
13 },
14 traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE},
15};
16
/// A parser instance shared through the registry pool: `Arc` for shared
/// ownership across tasks, `tokio::sync::Mutex` so exclusive access can be
/// held across `.await` points (see the concurrent-access tests below).
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Factory closure that produces a fresh boxed parser on demand.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
23
/// Thread-safe registry mapping parser names to creator closures, cached
/// pooled instances, and model-id substring patterns. `Clone` is cheap:
/// all state lives behind shared `Arc`s, so clones observe the same data.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Named factories used to build new parser instances.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Cache of shared parser instances, one per parser name.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Ordered `(pattern, parser_name)` pairs; first substring match wins.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
34
35impl ParserRegistry {
36 pub fn new() -> Self {
38 Self {
39 creators: Arc::new(RwLock::new(HashMap::new())),
40 pool: Arc::new(RwLock::new(HashMap::new())),
41 patterns: Arc::new(RwLock::new(Vec::new())),
42 }
43 }
44
45 pub fn register_parser<F>(&self, name: &str, creator: F)
47 where
48 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
49 {
50 let mut creators = self.creators.write();
51 creators.insert(name.to_string(), Arc::new(creator));
52 }
53
54 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
57 let mut patterns = self.patterns.write();
58 patterns.push((pattern.to_string(), parser_name.to_string()));
59 }
60
61 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
64 {
66 let pool = self.pool.read();
67 if let Some(parser) = pool.get(name) {
68 return Some(Arc::clone(parser));
69 }
70 }
71
72 let creators = self.creators.read();
74 if let Some(creator) = creators.get(name) {
75 let parser = Arc::new(Mutex::new(creator()));
76
77 let mut pool = self.pool.write();
79 pool.insert(name.to_string(), Arc::clone(&parser));
80
81 Some(parser)
82 } else {
83 None
84 }
85 }
86
87 pub fn has_parser(&self, name: &str) -> bool {
89 let creators = self.creators.read();
90 creators.contains_key(name)
91 }
92
93 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
96 let creators = self.creators.read();
97 creators.get(name).map(|creator| creator())
98 }
99
100 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
102 let patterns = self.patterns.read();
103 let model_lower = model_id.to_lowercase();
104
105 for (pattern, parser_name) in patterns.iter() {
106 if model_lower.contains(&pattern.to_lowercase()) {
107 return self.get_pooled_parser(parser_name);
108 }
109 }
110 None
111 }
112
113 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
116 let patterns = self.patterns.read();
117 let model_lower = model_id.to_lowercase();
118
119 for (pattern, parser_name) in patterns.iter() {
120 if model_lower.contains(&pattern.to_lowercase()) {
121 let creators = self.creators.read();
122 return creators.contains_key(parser_name);
123 }
124 }
125 false
126 }
127
128 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
131 let patterns = self.patterns.read();
132 let model_lower = model_id.to_lowercase();
133
134 for (pattern, parser_name) in patterns.iter() {
135 if model_lower.contains(&pattern.to_lowercase()) {
136 return self.create_parser(parser_name);
137 }
138 }
139 None
140 }
141
142 pub fn list_parsers(&self) -> Vec<String> {
144 let mut parsers: Vec<_> = self.creators.read().keys().cloned().collect();
145 parsers.sort_unstable();
146 parsers
147 }
148
149 pub fn clear_pool(&self) {
152 let mut pool = self.pool.write();
153 pool.clear();
154 }
155}
156
157impl Default for ParserRegistry {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
/// High-level entry point wrapping a [`ParserRegistry`] pre-populated with
/// the built-in reasoning parsers and the model-id patterns selecting them.
/// `Clone` shares the underlying registry (registry state is `Arc`-backed).
#[derive(Clone)]
pub struct ParserFactory {
    // Shared registry; all clones of this factory observe the same pool.
    registry: ParserRegistry,
}
168
169impl ParserFactory {
170 pub fn new() -> Self {
172 let registry = ParserRegistry::new();
173
174 registry.register_parser("base", || {
176 Box::new(BaseReasoningParser::new(ParserConfig::default()))
177 });
178
179 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
181
182 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
184
185 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
187
188 registry.register_parser("kimi", || Box::new(KimiParser::new()));
190
191 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
193
194 registry.register_parser("step3", || Box::new(Step3Parser::new()));
196
197 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
199
200 registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
202
203 registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
205
206 registry.register_parser("deepseek_v31", || {
208 let config = ParserConfig {
209 think_start_token: "<think>".to_string(),
210 think_end_token: "</think>".to_string(),
211 stream_reasoning: true,
212 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
213 always_in_reasoning: false,
214 };
215 Box::new(BaseReasoningParser::new(config).with_model_type("deepseek_v31".to_string()))
216 });
217
218 registry.register_parser("kimi_k25", || {
220 let config = ParserConfig {
221 think_start_token: "<think>".to_string(),
222 think_end_token: "</think>".to_string(),
223 stream_reasoning: true,
224 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
225 always_in_reasoning: false,
226 };
227 Box::new(BaseReasoningParser::new(config).with_model_type("kimi_k25".to_string()))
228 });
229
230 registry.register_parser("kimi_thinking", || {
232 let config = ParserConfig {
233 think_start_token: "<think>".to_string(),
234 think_end_token: "</think>".to_string(),
235 stream_reasoning: true,
236 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
237 always_in_reasoning: true,
238 };
239 Box::new(BaseReasoningParser::new(config).with_model_type("kimi_thinking".to_string()))
240 });
241
242 registry.register_pattern("deepseek-r1", "deepseek_r1");
244 registry.register_pattern("deepseek-v3.1", "deepseek_v31");
245 registry.register_pattern("deepseek-v3-1", "deepseek_v31");
246 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
247 registry.register_pattern("qwen-thinking", "qwen3_thinking");
248 registry.register_pattern("qwen3", "qwen3");
249 registry.register_pattern("qwen", "qwen3");
250 registry.register_pattern("glm45", "glm45");
251 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi-k2-thinking", "kimi_thinking");
253 registry.register_pattern("kimi-k2.5", "kimi_k25");
254 registry.register_pattern("kimi", "kimi"); registry.register_pattern("step3", "step3");
256 registry.register_pattern("minimax", "minimax");
257 registry.register_pattern("minimax-m2", "minimax");
258 registry.register_pattern("mm-m2", "minimax");
259
260 registry.register_pattern("command-r", "cohere_cmd");
262 registry.register_pattern("command-a", "cohere_cmd");
263 registry.register_pattern("c4ai-command", "cohere_cmd");
264 registry.register_pattern("cohere", "cohere_cmd");
265
266 registry.register_pattern("nemotron-nano", "nano_v3");
268 registry.register_pattern("nemotron-super", "nano_v3");
269 registry.register_pattern("nano-v3", "nano_v3");
270
271 Self { registry }
272 }
273
274 #[expect(
278 clippy::expect_used,
279 reason = "passthrough parser is registered on the line above; None indicates a bug in registration logic"
280 )]
281 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
282 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
284 return parser;
285 }
286
287 self.registry
289 .get_pooled_parser("passthrough")
290 .unwrap_or_else(|| {
291 self.registry.register_parser("passthrough", || {
293 let config = ParserConfig {
294 think_start_token: String::new(),
295 think_end_token: String::new(),
296 stream_reasoning: true,
297 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
298 always_in_reasoning: false,
299 };
300 Box::new(
301 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
302 )
303 });
304 self.registry
305 .get_pooled_parser("passthrough")
306 .expect("passthrough parser was just registered")
307 })
308 }
309
310 pub fn create(&self, model_id: &str) -> Box<dyn ReasoningParser> {
314 if let Some(parser) = self.registry.create_for_model(model_id) {
316 return parser;
317 }
318
319 let config = ParserConfig {
321 think_start_token: String::new(),
322 think_end_token: String::new(),
323 stream_reasoning: true,
324 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
325 always_in_reasoning: false,
326 };
327 Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
328 }
329
330 pub fn registry(&self) -> &ParserRegistry {
332 &self.registry
333 }
334
335 pub fn list_parsers(&self) -> Vec<String> {
337 self.registry.list_parsers()
338 }
339
340 pub fn clear_pool(&self) {
343 self.registry.clear_pool();
344 }
345}
346
347impl Default for ParserFactory {
348 fn default() -> Self {
349 Self::new()
350 }
351}
352
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use super::*;

    // "deepseek-r1" substring selects the DeepSeek R1 parser.
    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill");
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    // "qwen3" substring selects the Qwen3 parser.
    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b");
        assert_eq!(parser.model_type(), "qwen3");
    }

    // "kimi" substring selects the Kimi parser.
    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat");
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unknown model ids fall back to the passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model");
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id, so case must not matter.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1");
        let parser2 = factory.create("QWEN3");
        let parser3 = factory.create("Kimi");

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model");
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2");
        assert_eq!(glm45.model_type(), "glm45");
    }

    // Both the "minimax" and the "mm-m2" pattern aliases resolve to minimax.
    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2");
        assert_eq!(minimax.model_type(), "minimax");

        let mm = factory.create("mm-m2-chat");
        assert_eq!(mm.model_type(), "minimax");
    }

    // nano-v3, nemotron-nano and nemotron-super all map to nano_v3.
    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat");
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b");
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super");
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    // All four Cohere pattern aliases resolve to cohere_cmd.
    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        let command_r = factory.create("command-r-plus");
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025");
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01");
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed");
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // Pooling invariant: same model -> same Arc; different model -> different.
    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    // The tokio Mutex serializes concurrent access to one pooled parser.
    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {i} reasoning</think>answer");
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // clear_pool drops cached instances, so the next lookup builds a new one.
    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    // All unknown models share the single pooled passthrough parser.
    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammering the pool concurrently must neither
    // error nor corrupt parse results, and should sustain high throughput.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through the models so tasks contend on the
                    // same pooled parsers.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    // Build a large, task-unique input so cross-task state
                    // leakage would be detectable in the assertions below.
                    let product = task_id * request_id;
                    let reasoning_text = format!(
                        "Task {task_id} is processing request {request_id}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {task_id} * {request_id} = {product}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {task_id} request {request_id} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {product} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                    );

                    let input = format!("<think>{reasoning_text}</think>{answer_text}");

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // The answer must belong to *this* task.
                            assert!(result.normal_text.contains(&format!("task {task_id}")));

                            // Some parsers may not emit reasoning; when they
                            // do, it must also belong to this task.
                            if !result.reasoning_text.is_empty() {
                                assert!(result.reasoning_text.contains(&format!("Task {task_id}")));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            #[expect(clippy::print_stderr, reason = "test diagnostic output")]
                            {
                                eprintln!("Parse error: {e:?}");
                            }
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the parser lock before the next iteration.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        #[expect(clippy::print_stdout, reason = "test diagnostic output")]
        {
            println!("High concurrency test: {num_tasks} tasks, {requests_per_task} requests each");
            println!("Completed in {duration:?}, {successes} successes, {errors} errors");
            println!(
                "Throughput: {:.0} requests/sec",
                (total_requests as f64) / duration.as_secs_f64()
            );
        }

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // Loose sanity bound; mostly catches pathological lock contention.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {throughput:.0} req/sec",
        );
    }

    // Interleave lookups with pool clearing to shake out lock-ordering or
    // check-then-act races in the registry.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: repeatedly fetch the same pooled parser.
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: periodically clear the pool while task 1 is fetching.
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: fetch a mix of known and unknown models.
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}