1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12 parsers::{
13 BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
14 MiniMaxParser, Qwen3Parser, QwenThinkingParser, Step3Parser,
15 },
16 traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
/// A shared, lockable parser instance that can be reused across requests.
/// `tokio::sync::Mutex` is used (not `std`) so the lock may be held across
/// `.await` points by callers.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Factory closure that builds a fresh boxed parser instance on demand.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
25
/// Thread-safe registry that maps parser names to factory closures, caches
/// pooled parser instances, and resolves model ids to parsers via substring
/// patterns. Cloning is cheap: all state is behind shared `Arc`s.
#[derive(Clone)]
pub struct ParserRegistry {
    // Factory closures keyed by parser name.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    // Lazily-created shared parser instances keyed by parser name.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    // Ordered (pattern, parser_name) pairs; matched first-wins against the
    // lowercased model id, so registration order is significant.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
36
37impl ParserRegistry {
38 pub fn new() -> Self {
40 Self {
41 creators: Arc::new(RwLock::new(HashMap::new())),
42 pool: Arc::new(RwLock::new(HashMap::new())),
43 patterns: Arc::new(RwLock::new(Vec::new())),
44 }
45 }
46
47 pub fn register_parser<F>(&self, name: &str, creator: F)
49 where
50 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51 {
52 let mut creators = self.creators.write().unwrap();
53 creators.insert(name.to_string(), Arc::new(creator));
54 }
55
56 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59 let mut patterns = self.patterns.write().unwrap();
60 patterns.push((pattern.to_string(), parser_name.to_string()));
61 }
62
63 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66 {
68 let pool = self.pool.read().unwrap();
69 if let Some(parser) = pool.get(name) {
70 return Some(Arc::clone(parser));
71 }
72 }
73
74 let creators = self.creators.read().unwrap();
76 if let Some(creator) = creators.get(name) {
77 let parser = Arc::new(Mutex::new(creator()));
78
79 let mut pool = self.pool.write().unwrap();
81 pool.insert(name.to_string(), Arc::clone(&parser));
82
83 Some(parser)
84 } else {
85 None
86 }
87 }
88
89 pub fn has_parser(&self, name: &str) -> bool {
91 let creators = self.creators.read().unwrap();
92 creators.contains_key(name)
93 }
94
95 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98 let creators = self.creators.read().unwrap();
99 creators.get(name).map(|creator| creator())
100 }
101
102 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104 let patterns = self.patterns.read().unwrap();
105 let model_lower = model_id.to_lowercase();
106
107 for (pattern, parser_name) in patterns.iter() {
108 if model_lower.contains(&pattern.to_lowercase()) {
109 return self.get_pooled_parser(parser_name);
110 }
111 }
112 None
113 }
114
115 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118 let patterns = self.patterns.read().unwrap();
119 let model_lower = model_id.to_lowercase();
120
121 for (pattern, parser_name) in patterns.iter() {
122 if model_lower.contains(&pattern.to_lowercase()) {
123 let creators = self.creators.read().unwrap();
124 return creators.contains_key(parser_name);
125 }
126 }
127 false
128 }
129
130 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133 let patterns = self.patterns.read().unwrap();
134 let model_lower = model_id.to_lowercase();
135
136 for (pattern, parser_name) in patterns.iter() {
137 if model_lower.contains(&pattern.to_lowercase()) {
138 return self.create_parser(parser_name);
139 }
140 }
141 None
142 }
143
144 pub fn clear_pool(&self) {
147 let mut pool = self.pool.write().unwrap();
148 pool.clear();
149 }
150}
151
152impl Default for ParserRegistry {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
/// High-level entry point that wires up all built-in reasoning parsers and
/// their model-id patterns, and resolves model ids to pooled or fresh parser
/// instances. Cloning shares the underlying registry.
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}
163
164impl ParserFactory {
165 pub fn new() -> Self {
167 let registry = ParserRegistry::new();
168
169 registry.register_parser("base", || {
171 Box::new(BaseReasoningParser::new(ParserConfig::default()))
172 });
173
174 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183 registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189 registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195 registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
197
198 registry.register_pattern("deepseek-r1", "deepseek_r1");
200 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
201 registry.register_pattern("qwen-thinking", "qwen3_thinking");
202 registry.register_pattern("qwen3", "qwen3");
203 registry.register_pattern("qwen", "qwen3");
204 registry.register_pattern("glm45", "glm45");
205 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi", "kimi");
207 registry.register_pattern("step3", "step3");
208 registry.register_pattern("minimax", "minimax");
209 registry.register_pattern("minimax-m2", "minimax");
210 registry.register_pattern("mm-m2", "minimax");
211
212 registry.register_pattern("command-r", "cohere_cmd");
214 registry.register_pattern("command-a", "cohere_cmd");
215 registry.register_pattern("c4ai-command", "cohere_cmd");
216 registry.register_pattern("cohere", "cohere_cmd");
217
218 registry.register_pattern("nemotron-nano", "qwen3");
220 registry.register_pattern("nano-v3", "qwen3");
221
222 Self { registry }
223 }
224
225 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
229 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
231 return parser;
232 }
233
234 self.registry
236 .get_pooled_parser("passthrough")
237 .unwrap_or_else(|| {
238 self.registry.register_parser("passthrough", || {
240 let config = ParserConfig {
241 think_start_token: "".to_string(),
242 think_end_token: "".to_string(),
243 stream_reasoning: true,
244 max_buffer_size: 65536,
245 initial_in_reasoning: false,
246 };
247 Box::new(
248 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
249 )
250 });
251 self.registry.get_pooled_parser("passthrough").unwrap()
252 })
253 }
254
255 pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
259 if let Some(parser) = self.registry.create_for_model(model_id) {
261 return Ok(parser);
262 }
263
264 let config = ParserConfig {
266 think_start_token: "".to_string(),
267 think_end_token: "".to_string(),
268 stream_reasoning: true,
269 max_buffer_size: 65536,
270 initial_in_reasoning: false,
271 };
272 Ok(Box::new(
273 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
274 ))
275 }
276
277 pub fn registry(&self) -> &ParserRegistry {
279 &self.registry
280 }
281
282 pub fn clear_pool(&self) {
285 self.registry.clear_pool();
286 }
287}
288
289impl Default for ParserFactory {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
#[cfg(test)]
mod tests {
    //! Unit and concurrency tests for `ParserFactory` / `ParserRegistry`.
    use super::*;

    // --- model-id pattern resolution -------------------------------------

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unknown model ids fall back to the no-op "passthrough" parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id before comparing.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2").unwrap();
        assert_eq!(minimax.model_type(), "minimax");

        // The "mm-m2" alias pattern also routes to the minimax parser.
        let mm = factory.create("mm-m2-chat").unwrap();
        assert_eq!(mm.model_type(), "minimax");
    }

    // All four Cohere-family patterns route to the same parser.
    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        let command_r = factory.create("command-r-plus").unwrap();
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025").unwrap();
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01").unwrap();
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed").unwrap();
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- pooling behavior -------------------------------------------------

    // Two lookups for the same model must return the same pooled Arc;
    // different models get distinct instances.
    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    // Multiple tasks can lock and use the same pooled parser; the tokio
    // Mutex serializes access.
    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                // NOTE(review): input has no explicit think-start token; the
                // deepseek-r1 parser evidently treats leading text up to
                // </think> as reasoning — confirm against its implementation.
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // clear_pool drops cached instances, so the next lookup re-creates a
    // fresh Arc.
    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    // Every unknown model id shares a single pooled passthrough instance.
    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // --- stress tests -----------------------------------------------------

    // Many tasks hammer the pool concurrently; every parse must succeed.
    // NOTE(review): the final throughput assertion (>1000 req/sec) is
    // environment-dependent and may be flaky on very slow CI machines.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate across models so tasks contend on shared
                    // pooled instances.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    // Large reasoning payload to exercise buffering.
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                         First, I need to understand the problem. The problem involves analyzing data \
                         and making calculations. Let me break this down: \n\
                         1. Initial analysis shows that we have multiple variables to consider. \
                         2. The data suggests a pattern that needs further investigation. \
                         3. Computing the values: {} * {} = {}. \
                         4. Cross-referencing with previous results indicates consistency. \
                         5. The mathematical proof follows from the axioms... \
                         6. Considering edge cases and boundary conditions... \
                         7. Validating against known constraints... \
                         8. The conclusion follows logically from premises A, B, and C. \
                         This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                         The solution involves multiple steps as outlined in the reasoning. \
                         The final result is {} with confidence level high. \
                         This conclusion is supported by rigorous mathematical analysis \
                         and has been validated against multiple test cases. \
                         The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // The answer must survive into normal_text.
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // Some parsers may stream reasoning away; only
                            // check content when it was captured.
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the parser lock before the next iteration.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    // Concurrent readers, pool-clearers, and mixed lookups must not panic
    // or deadlock.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: repeatedly fetch the same pooled parser.
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: periodically clear the pool while task 1 is fetching.
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: fetch a rotating mix of known and unknown models.
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}