1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12 parsers::{
13 BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
14 MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
15 },
16 traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
/// A parser instance shared through the registry pool.
///
/// `Arc` lets many callers hold the same parser; the `tokio::sync::Mutex`
/// serializes access so a parser's internal streaming state is not
/// corrupted by concurrent use (the async lock may be held across awaits).
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Factory closure that builds a fresh parser instance on demand.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
/// Thread-safe registry that maps parser names to creator closures,
/// caches constructed parsers in a pool, and matches model-id substrings
/// to parser names.
///
/// Cloning is cheap: all state sits behind shared `Arc`s, so clones
/// observe the same registry.
#[derive(Clone)]
pub struct ParserRegistry {
    // Parser name -> factory closure used to build new instances.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    // Parser name -> cached shared instance, built lazily on first request.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    // Ordered (pattern, parser_name) pairs; the first case-insensitive
    // substring match against a model id wins, so registration order matters.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
37impl ParserRegistry {
38 pub fn new() -> Self {
40 Self {
41 creators: Arc::new(RwLock::new(HashMap::new())),
42 pool: Arc::new(RwLock::new(HashMap::new())),
43 patterns: Arc::new(RwLock::new(Vec::new())),
44 }
45 }
46
47 pub fn register_parser<F>(&self, name: &str, creator: F)
49 where
50 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51 {
52 let mut creators = self.creators.write().unwrap();
53 creators.insert(name.to_string(), Arc::new(creator));
54 }
55
56 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59 let mut patterns = self.patterns.write().unwrap();
60 patterns.push((pattern.to_string(), parser_name.to_string()));
61 }
62
63 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66 {
68 let pool = self.pool.read().unwrap();
69 if let Some(parser) = pool.get(name) {
70 return Some(Arc::clone(parser));
71 }
72 }
73
74 let creators = self.creators.read().unwrap();
76 if let Some(creator) = creators.get(name) {
77 let parser = Arc::new(Mutex::new(creator()));
78
79 let mut pool = self.pool.write().unwrap();
81 pool.insert(name.to_string(), Arc::clone(&parser));
82
83 Some(parser)
84 } else {
85 None
86 }
87 }
88
89 pub fn has_parser(&self, name: &str) -> bool {
91 let creators = self.creators.read().unwrap();
92 creators.contains_key(name)
93 }
94
95 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98 let creators = self.creators.read().unwrap();
99 creators.get(name).map(|creator| creator())
100 }
101
102 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104 let patterns = self.patterns.read().unwrap();
105 let model_lower = model_id.to_lowercase();
106
107 for (pattern, parser_name) in patterns.iter() {
108 if model_lower.contains(&pattern.to_lowercase()) {
109 return self.get_pooled_parser(parser_name);
110 }
111 }
112 None
113 }
114
115 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118 let patterns = self.patterns.read().unwrap();
119 let model_lower = model_id.to_lowercase();
120
121 for (pattern, parser_name) in patterns.iter() {
122 if model_lower.contains(&pattern.to_lowercase()) {
123 let creators = self.creators.read().unwrap();
124 return creators.contains_key(parser_name);
125 }
126 }
127 false
128 }
129
130 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133 let patterns = self.patterns.read().unwrap();
134 let model_lower = model_id.to_lowercase();
135
136 for (pattern, parser_name) in patterns.iter() {
137 if model_lower.contains(&pattern.to_lowercase()) {
138 return self.create_parser(parser_name);
139 }
140 }
141 None
142 }
143
144 pub fn clear_pool(&self) {
147 let mut pool = self.pool.write().unwrap();
148 pool.clear();
149 }
150}
151
152impl Default for ParserRegistry {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
/// High-level entry point: wraps a [`ParserRegistry`] pre-populated with
/// every built-in reasoning parser and the model-id patterns that select
/// them. Cloning shares the same underlying registry.
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}
164impl ParserFactory {
165 pub fn new() -> Self {
167 let registry = ParserRegistry::new();
168
169 registry.register_parser("base", || {
171 Box::new(BaseReasoningParser::new(ParserConfig::default()))
172 });
173
174 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183 registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189 registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195 registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
197
198 registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
200
201 registry.register_pattern("deepseek-r1", "deepseek_r1");
203 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
204 registry.register_pattern("qwen-thinking", "qwen3_thinking");
205 registry.register_pattern("qwen3", "qwen3");
206 registry.register_pattern("qwen", "qwen3");
207 registry.register_pattern("glm45", "glm45");
208 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi", "kimi");
210 registry.register_pattern("step3", "step3");
211 registry.register_pattern("minimax", "minimax");
212 registry.register_pattern("minimax-m2", "minimax");
213 registry.register_pattern("mm-m2", "minimax");
214
215 registry.register_pattern("command-r", "cohere_cmd");
217 registry.register_pattern("command-a", "cohere_cmd");
218 registry.register_pattern("c4ai-command", "cohere_cmd");
219 registry.register_pattern("cohere", "cohere_cmd");
220
221 registry.register_pattern("nemotron-nano", "nano_v3");
223 registry.register_pattern("nemotron-super", "nano_v3");
224 registry.register_pattern("nano-v3", "nano_v3");
225
226 Self { registry }
227 }
228
229 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
233 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
235 return parser;
236 }
237
238 self.registry
240 .get_pooled_parser("passthrough")
241 .unwrap_or_else(|| {
242 self.registry.register_parser("passthrough", || {
244 let config = ParserConfig {
245 think_start_token: "".to_string(),
246 think_end_token: "".to_string(),
247 stream_reasoning: true,
248 max_buffer_size: 65536,
249 initial_in_reasoning: false,
250 };
251 Box::new(
252 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
253 )
254 });
255 self.registry.get_pooled_parser("passthrough").unwrap()
256 })
257 }
258
259 pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
263 if let Some(parser) = self.registry.create_for_model(model_id) {
265 return Ok(parser);
266 }
267
268 let config = ParserConfig {
270 think_start_token: "".to_string(),
271 think_end_token: "".to_string(),
272 stream_reasoning: true,
273 max_buffer_size: 65536,
274 initial_in_reasoning: false,
275 };
276 Ok(Box::new(
277 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
278 ))
279 }
280
281 pub fn registry(&self) -> &ParserRegistry {
283 &self.registry
284 }
285
286 pub fn clear_pool(&self) {
289 self.registry.clear_pool();
290 }
291}
292
293impl Default for ParserFactory {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    // --- Pattern-resolution tests: each asserts that a representative
    // model id resolves to the expected parser via `model_type()`.

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unknown model ids must fall back to the passthrough parser rather
    // than erroring out.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id, so mixed-case ids must
    // resolve identically.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    // Both the "minimax" and "mm-m2" aliases select the minimax parser.
    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2").unwrap();
        assert_eq!(minimax.model_type(), "minimax");

        let mm = factory.create("mm-m2-chat").unwrap();
        assert_eq!(mm.model_type(), "minimax");
    }

    // All nano_v3 aliases (nano-v3, nemotron-nano, nemotron-super) map to
    // the same parser, including ids with an org prefix.
    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat").unwrap();
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b").unwrap();
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super").unwrap();
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    // Every Cohere Command alias routes to the cohere_cmd parser.
    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        let command_r = factory.create("command-r-plus").unwrap();
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025").unwrap();
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01").unwrap();
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed").unwrap();
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- Pooling tests.

    // Two requests for the same model share one Arc; a different model
    // gets a distinct pooled instance.
    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    // Concurrent tasks locking the same pooled parser must each parse
    // their own input correctly (the Mutex serializes access).
    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // clear_pool drops cached instances, so the next request builds a
    // fresh parser (different Arc identity).
    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    // All unknown models share a single lazily-registered passthrough
    // parser instance.
    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: 100 tasks x 50 requests across four pooled parsers.
    // Verifies correctness under contention and a minimum throughput.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through the models so pooled parsers are shared
                    // across tasks.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    // Large reasoning payload (> 500 bytes) embedding the
                    // task id, so assertions below can verify the right
                    // task's text came back through the parser.
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                         First, I need to understand the problem. The problem involves analyzing data \
                         and making calculations. Let me break this down: \n\
                         1. Initial analysis shows that we have multiple variables to consider. \
                         2. The data suggests a pattern that needs further investigation. \
                         3. Computing the values: {} * {} = {}. \
                         4. Cross-referencing with previous results indicates consistency. \
                         5. The mathematical proof follows from the axioms... \
                         6. Considering edge cases and boundary conditions... \
                         7. Validating against known constraints... \
                         8. The conclusion follows logically from premises A, B, and C. \
                         This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                         The solution involves multiple steps as outlined in the reasoning. \
                         The final result is {} with confidence level high. \
                         This conclusion is supported by rigorous mathematical analysis \
                         and has been validated against multiple test cases. \
                         The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // Some parsers may not expose reasoning; when
                            // they do, it must be this task's full payload.
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the parser lock before the next iteration.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // NOTE(review): wall-clock threshold — may be flaky on very slow
        // CI machines; confirm the bound is intentional.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    // Hammers the pool with concurrent gets and clears to shake out
    // lock-ordering / consistency bugs; passes if nothing panics.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: repeatedly fetch the same pooled parser.
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: periodically clear the pool while others are fetching.
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: rotate through known and unknown models.
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}