1use std::{collections::HashMap, sync::Arc};
5
6use parking_lot::RwLock;
7use tokio::sync::Mutex;
8
9use crate::{
10 parsers::{
11 BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
12 MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
13 },
14 traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE},
15};
16
/// A parser instance shared through the pool: `Arc` for shared ownership,
/// `tokio::sync::Mutex` so a caller can hold the parser across `.await` points.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Thread-safe factory closure that builds a fresh parser instance on demand.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
23
/// Registry mapping parser names to creator functions, plus a pool of shared
/// parser instances and model-id patterns used to pick a parser for a model.
///
/// Cloning is cheap: all state lives behind `Arc`s, so clones share the same
/// underlying registry.
#[derive(Clone)]
pub struct ParserRegistry {
    // Named factory closures used to build parser instances.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    // Cache of already-built parsers, shared between callers by name.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    // Ordered (substring pattern, parser name) pairs; first match wins.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
34
35impl ParserRegistry {
36 pub fn new() -> Self {
38 Self {
39 creators: Arc::new(RwLock::new(HashMap::new())),
40 pool: Arc::new(RwLock::new(HashMap::new())),
41 patterns: Arc::new(RwLock::new(Vec::new())),
42 }
43 }
44
45 pub fn register_parser<F>(&self, name: &str, creator: F)
47 where
48 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
49 {
50 let mut creators = self.creators.write();
51 creators.insert(name.to_string(), Arc::new(creator));
52 }
53
54 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
57 let mut patterns = self.patterns.write();
58 patterns.push((pattern.to_string(), parser_name.to_string()));
59 }
60
61 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
64 {
66 let pool = self.pool.read();
67 if let Some(parser) = pool.get(name) {
68 return Some(Arc::clone(parser));
69 }
70 }
71
72 let creators = self.creators.read();
74 if let Some(creator) = creators.get(name) {
75 let parser = Arc::new(Mutex::new(creator()));
76
77 let mut pool = self.pool.write();
79 pool.insert(name.to_string(), Arc::clone(&parser));
80
81 Some(parser)
82 } else {
83 None
84 }
85 }
86
87 pub fn has_parser(&self, name: &str) -> bool {
89 let creators = self.creators.read();
90 creators.contains_key(name)
91 }
92
93 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
96 let creators = self.creators.read();
97 creators.get(name).map(|creator| creator())
98 }
99
100 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
102 let patterns = self.patterns.read();
103 let model_lower = model_id.to_lowercase();
104
105 for (pattern, parser_name) in patterns.iter() {
106 if model_lower.contains(&pattern.to_lowercase()) {
107 return self.get_pooled_parser(parser_name);
108 }
109 }
110 None
111 }
112
113 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
116 let patterns = self.patterns.read();
117 let model_lower = model_id.to_lowercase();
118
119 for (pattern, parser_name) in patterns.iter() {
120 if model_lower.contains(&pattern.to_lowercase()) {
121 let creators = self.creators.read();
122 return creators.contains_key(parser_name);
123 }
124 }
125 false
126 }
127
128 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
131 let patterns = self.patterns.read();
132 let model_lower = model_id.to_lowercase();
133
134 for (pattern, parser_name) in patterns.iter() {
135 if model_lower.contains(&pattern.to_lowercase()) {
136 return self.create_parser(parser_name);
137 }
138 }
139 None
140 }
141
142 pub fn list_parsers(&self) -> Vec<String> {
144 let mut parsers: Vec<_> = self.creators.read().keys().cloned().collect();
145 parsers.sort_unstable();
146 parsers
147 }
148
149 pub fn clear_pool(&self) {
152 let mut pool = self.pool.write();
153 pool.clear();
154 }
155}
156
impl Default for ParserRegistry {
    /// Equivalent to [`ParserRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
162
/// Convenience facade over [`ParserRegistry`] that comes pre-loaded with the
/// built-in reasoning parsers and their model-id patterns.
///
/// Cloning shares the underlying registry (see [`ParserRegistry`]).
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}
168
169impl ParserFactory {
170 pub fn new() -> Self {
172 let registry = ParserRegistry::new();
173
174 registry.register_parser("base", || {
176 Box::new(BaseReasoningParser::new(ParserConfig::default()))
177 });
178
179 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
181
182 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
184
185 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
187
188 registry.register_parser("kimi", || Box::new(KimiParser::new()));
190
191 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
193
194 registry.register_parser("step3", || Box::new(Step3Parser::new()));
196
197 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
199
200 registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
202
203 registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
205
206 registry.register_pattern("deepseek-r1", "deepseek_r1");
208 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
209 registry.register_pattern("qwen-thinking", "qwen3_thinking");
210 registry.register_pattern("qwen3", "qwen3");
211 registry.register_pattern("qwen", "qwen3");
212 registry.register_pattern("glm45", "glm45");
213 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi", "kimi");
215 registry.register_pattern("step3", "step3");
216 registry.register_pattern("minimax", "minimax");
217 registry.register_pattern("minimax-m2", "minimax");
218 registry.register_pattern("mm-m2", "minimax");
219
220 registry.register_pattern("command-r", "cohere_cmd");
222 registry.register_pattern("command-a", "cohere_cmd");
223 registry.register_pattern("c4ai-command", "cohere_cmd");
224 registry.register_pattern("cohere", "cohere_cmd");
225
226 registry.register_pattern("nemotron-nano", "nano_v3");
228 registry.register_pattern("nemotron-super", "nano_v3");
229 registry.register_pattern("nano-v3", "nano_v3");
230
231 Self { registry }
232 }
233
234 #[expect(
238 clippy::expect_used,
239 reason = "passthrough parser is registered on the line above; None indicates a bug in registration logic"
240 )]
241 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
242 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
244 return parser;
245 }
246
247 self.registry
249 .get_pooled_parser("passthrough")
250 .unwrap_or_else(|| {
251 self.registry.register_parser("passthrough", || {
253 let config = ParserConfig {
254 think_start_token: String::new(),
255 think_end_token: String::new(),
256 stream_reasoning: true,
257 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
258 initial_in_reasoning: false,
259 };
260 Box::new(
261 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
262 )
263 });
264 self.registry
265 .get_pooled_parser("passthrough")
266 .expect("passthrough parser was just registered")
267 })
268 }
269
270 pub fn create(&self, model_id: &str) -> Box<dyn ReasoningParser> {
274 if let Some(parser) = self.registry.create_for_model(model_id) {
276 return parser;
277 }
278
279 let config = ParserConfig {
281 think_start_token: String::new(),
282 think_end_token: String::new(),
283 stream_reasoning: true,
284 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
285 initial_in_reasoning: false,
286 };
287 Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
288 }
289
290 pub fn registry(&self) -> &ParserRegistry {
292 &self.registry
293 }
294
295 pub fn list_parsers(&self) -> Vec<String> {
297 self.registry.list_parsers()
298 }
299
300 pub fn clear_pool(&self) {
303 self.registry.clear_pool();
304 }
305}
306
impl Default for ParserFactory {
    /// Equivalent to [`ParserFactory::new`]: a factory pre-loaded with the
    /// built-in parsers and patterns.
    fn default() -> Self {
        Self::new()
    }
}
312
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use super::*;

    // A model id containing a registered pattern maps to that parser.
    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill");
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b");
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat");
        assert_eq!(parser.model_type(), "kimi");
    }

    // Model ids with no matching pattern fall back to the passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model");
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id before comparison.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1");
        let parser2 = factory.create("QWEN3");
        let parser3 = factory.create("Kimi");

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model");
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2");
        assert_eq!(glm45.model_type(), "glm45");
    }

    // Both the "minimax" and "mm-m2" patterns resolve to the minimax parser.
    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2");
        assert_eq!(minimax.model_type(), "minimax");

        let mm = factory.create("mm-m2-chat");
        assert_eq!(mm.model_type(), "minimax");
    }

    // nano-v3, nemotron-nano, and nemotron-super all map to nano_v3.
    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat");
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b");
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super");
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    // Every Cohere Command-family pattern resolves to cohere_cmd.
    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        let command_r = factory.create("command-r-plus");
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025");
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01");
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed");
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // Repeated lookups for the same model must share one pooled instance;
    // different models get distinct instances.
    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    // The tokio Mutex serializes concurrent use of a single pooled parser;
    // each task must see a consistent parse of its own input.
    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {i} reasoning</think>answer");
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // clear_pool drops cached instances, so the next lookup rebuilds a
    // distinct parser (different Arc pointer).
    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    // All unknown model ids share the single pooled passthrough parser.
    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammer the pool concurrently; every request
    // must parse correctly, and throughput is sanity-checked at the end.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through the model list so pooled parsers are
                    // contended across tasks.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    let product = task_id * request_id;
                    // Large, task-unique reasoning payload (backslash-newline
                    // continuations keep it a single runtime string).
                    let reasoning_text = format!(
                        "Task {task_id} is processing request {request_id}. Let me think through this step by step. \
                        First, I need to understand the problem. The problem involves analyzing data \
                        and making calculations. Let me break this down: \n\
                        1. Initial analysis shows that we have multiple variables to consider. \
                        2. The data suggests a pattern that needs further investigation. \
                        3. Computing the values: {task_id} * {request_id} = {product}. \
                        4. Cross-referencing with previous results indicates consistency. \
                        5. The mathematical proof follows from the axioms... \
                        6. Considering edge cases and boundary conditions... \
                        7. Validating against known constraints... \
                        8. The conclusion follows logically from premises A, B, and C. \
                        This reasoning chain demonstrates the validity of our approach.",
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {task_id} request {request_id} is: \
                        The solution involves multiple steps as outlined in the reasoning. \
                        The final result is {product} with confidence level high. \
                        This conclusion is supported by rigorous mathematical analysis \
                        and has been validated against multiple test cases. \
                        The implementation should handle edge cases appropriately.",
                    );

                    let input = format!("<think>{reasoning_text}</think>{answer_text}");

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            assert!(result.normal_text.contains(&format!("task {task_id}")));

                            // When reasoning is captured, it must belong to
                            // this task and be substantial.
                            if !result.reasoning_text.is_empty() {
                                assert!(result.reasoning_text.contains(&format!("Task {task_id}")));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            #[expect(clippy::print_stderr, reason = "test diagnostic output")]
                            {
                                eprintln!("Parse error: {e:?}");
                            }
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the parser lock before the next iteration.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        #[expect(clippy::print_stdout, reason = "test diagnostic output")]
        {
            println!("High concurrency test: {num_tasks} tasks, {requests_per_task} requests each");
            println!("Completed in {duration:?}, {successes} successes, {errors} errors");
            println!(
                "Throughput: {:.0} requests/sec",
                (total_requests as f64) / duration.as_secs_f64()
            );
        }

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {throughput:.0} req/sec",
        );
    }

    // Interleaves pool reads with clear_pool calls to shake out races
    // between lazy creation and pool clearing.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}