1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock},
7};
8
9use tokio::sync::Mutex;
10
11use crate::{
12 parsers::{
13 BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, MiniMaxParser, Qwen3Parser,
14 QwenThinkingParser, Step3Parser,
15 },
16 traits::{ParseError, ParserConfig, ReasoningParser},
17};
18
/// A parser instance shared across tasks: callers lock the async mutex for
/// exclusive access while parsing a single request.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Factory closure that builds a fresh parser instance on demand.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;

#[derive(Clone)]
pub struct ParserRegistry {
    // Parser constructors keyed by parser name.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    // Lazily-created shared instances keyed by parser name.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    // (model-id substring pattern, parser name), matched in insertion order.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
36
37impl ParserRegistry {
38 pub fn new() -> Self {
40 Self {
41 creators: Arc::new(RwLock::new(HashMap::new())),
42 pool: Arc::new(RwLock::new(HashMap::new())),
43 patterns: Arc::new(RwLock::new(Vec::new())),
44 }
45 }
46
47 pub fn register_parser<F>(&self, name: &str, creator: F)
49 where
50 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
51 {
52 let mut creators = self.creators.write().unwrap();
53 creators.insert(name.to_string(), Arc::new(creator));
54 }
55
56 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
59 let mut patterns = self.patterns.write().unwrap();
60 patterns.push((pattern.to_string(), parser_name.to_string()));
61 }
62
63 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66 {
68 let pool = self.pool.read().unwrap();
69 if let Some(parser) = pool.get(name) {
70 return Some(Arc::clone(parser));
71 }
72 }
73
74 let creators = self.creators.read().unwrap();
76 if let Some(creator) = creators.get(name) {
77 let parser = Arc::new(Mutex::new(creator()));
78
79 let mut pool = self.pool.write().unwrap();
81 pool.insert(name.to_string(), Arc::clone(&parser));
82
83 Some(parser)
84 } else {
85 None
86 }
87 }
88
89 pub fn has_parser(&self, name: &str) -> bool {
91 let creators = self.creators.read().unwrap();
92 creators.contains_key(name)
93 }
94
95 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
98 let creators = self.creators.read().unwrap();
99 creators.get(name).map(|creator| creator())
100 }
101
102 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
104 let patterns = self.patterns.read().unwrap();
105 let model_lower = model_id.to_lowercase();
106
107 for (pattern, parser_name) in patterns.iter() {
108 if model_lower.contains(&pattern.to_lowercase()) {
109 return self.get_pooled_parser(parser_name);
110 }
111 }
112 None
113 }
114
115 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
118 let patterns = self.patterns.read().unwrap();
119 let model_lower = model_id.to_lowercase();
120
121 for (pattern, parser_name) in patterns.iter() {
122 if model_lower.contains(&pattern.to_lowercase()) {
123 let creators = self.creators.read().unwrap();
124 return creators.contains_key(parser_name);
125 }
126 }
127 false
128 }
129
130 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
133 let patterns = self.patterns.read().unwrap();
134 let model_lower = model_id.to_lowercase();
135
136 for (pattern, parser_name) in patterns.iter() {
137 if model_lower.contains(&pattern.to_lowercase()) {
138 return self.create_parser(parser_name);
139 }
140 }
141 None
142 }
143
144 pub fn clear_pool(&self) {
147 let mut pool = self.pool.write().unwrap();
148 pool.clear();
149 }
150}
151
152impl Default for ParserRegistry {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
/// High-level entry point for obtaining reasoning parsers by model id; wraps a
/// [`ParserRegistry`] pre-populated by [`ParserFactory::new`].
#[derive(Clone)]
pub struct ParserFactory {
    // Cloning the factory shares the registry (its internals are Arc-backed).
    registry: ParserRegistry,
}
163
164impl ParserFactory {
165 pub fn new() -> Self {
167 let registry = ParserRegistry::new();
168
169 registry.register_parser("base", || {
171 Box::new(BaseReasoningParser::new(ParserConfig::default()))
172 });
173
174 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
176
177 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
179
180 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
182
183 registry.register_parser("kimi", || Box::new(KimiParser::new()));
185
186 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
188
189 registry.register_parser("step3", || Box::new(Step3Parser::new()));
191
192 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
194
195 registry.register_pattern("deepseek-r1", "deepseek_r1");
197 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
198 registry.register_pattern("qwen-thinking", "qwen3_thinking");
199 registry.register_pattern("qwen3", "qwen3");
200 registry.register_pattern("qwen", "qwen3");
201 registry.register_pattern("glm45", "glm45");
202 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi", "kimi");
204 registry.register_pattern("step3", "step3");
205 registry.register_pattern("minimax", "minimax");
206 registry.register_pattern("minimax-m2", "minimax");
207 registry.register_pattern("mm-m2", "minimax");
208
209 registry.register_pattern("nemotron-nano", "qwen3");
211 registry.register_pattern("nano-v3", "qwen3");
212
213 Self { registry }
214 }
215
216 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
220 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
222 return parser;
223 }
224
225 self.registry
227 .get_pooled_parser("passthrough")
228 .unwrap_or_else(|| {
229 self.registry.register_parser("passthrough", || {
231 let config = ParserConfig {
232 think_start_token: "".to_string(),
233 think_end_token: "".to_string(),
234 stream_reasoning: true,
235 max_buffer_size: 65536,
236 initial_in_reasoning: false,
237 };
238 Box::new(
239 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
240 )
241 });
242 self.registry.get_pooled_parser("passthrough").unwrap()
243 })
244 }
245
246 pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
250 if let Some(parser) = self.registry.create_for_model(model_id) {
252 return Ok(parser);
253 }
254
255 let config = ParserConfig {
257 think_start_token: "".to_string(),
258 think_end_token: "".to_string(),
259 stream_reasoning: true,
260 max_buffer_size: 65536,
261 initial_in_reasoning: false,
262 };
263 Ok(Box::new(
264 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
265 ))
266 }
267
268 pub fn registry(&self) -> &ParserRegistry {
270 &self.registry
271 }
272
273 pub fn clear_pool(&self) {
276 self.registry.clear_pool();
277 }
278}
279
280impl Default for ParserFactory {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    // "deepseek-r1" pattern should match as a substring of the model id.
    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill").unwrap();
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b").unwrap();
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat").unwrap();
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unknown model ids fall back to the passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model").unwrap();
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id before comparing.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1").unwrap();
        let parser2 = factory.create("QWEN3").unwrap();
        let parser3 = factory.create("Kimi").unwrap();

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model").unwrap();
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2").unwrap();
        assert_eq!(glm45.model_type(), "glm45");
    }

    // Both the "minimax" and "mm-m2" patterns route to the minimax parser.
    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2").unwrap();
        assert_eq!(minimax.model_type(), "minimax");

        let mm = factory.create("mm-m2-chat").unwrap();
        assert_eq!(mm.model_type(), "minimax");
    }

    // Pooled lookups for the same model return the same Arc; different
    // models get distinct pooled instances.
    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    // Several tasks lock the same pooled parser in turn; each parse must see
    // only its own input (the async Mutex serializes access).
    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {} reasoning</think>answer", i);
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // clear_pool drops cached instances, so the next lookup creates a new one.
    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    // All unknown models share one pooled passthrough parser.
    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammer pooled parsers across several models,
    // counting successes/errors and asserting a minimum throughput.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through the model list so tasks contend on the
                    // same pooled parsers.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    // Long, unique-per-request reasoning text exercises the
                    // parser buffer and lets assertions identify the task.
                    let reasoning_text = format!(
                        "Task {} is processing request {}. Let me think through this step by step. \
                         First, I need to understand the problem. The problem involves analyzing data \
                         and making calculations. Let me break this down: \n\
                         1. Initial analysis shows that we have multiple variables to consider. \
                         2. The data suggests a pattern that needs further investigation. \
                         3. Computing the values: {} * {} = {}. \
                         4. Cross-referencing with previous results indicates consistency. \
                         5. The mathematical proof follows from the axioms... \
                         6. Considering edge cases and boundary conditions... \
                         7. Validating against known constraints... \
                         8. The conclusion follows logically from premises A, B, and C. \
                         This reasoning chain demonstrates the validity of our approach.",
                        task_id, request_id, task_id, request_id, task_id * request_id
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {} request {} is: \
                         The solution involves multiple steps as outlined in the reasoning. \
                         The final result is {} with confidence level high. \
                         This conclusion is supported by rigorous mathematical analysis \
                         and has been validated against multiple test cases. \
                         The implementation should handle edge cases appropriately.",
                        task_id,
                        request_id,
                        task_id * request_id
                    );

                    let input = format!("<think>{}</think>{}", reasoning_text, answer_text);

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            // The normal text must belong to this task.
                            assert!(result.normal_text.contains(&format!("task {}", task_id)));

                            // Reasoning may be empty depending on parser mode;
                            // when present it must match this task too.
                            if !result.reasoning_text.is_empty() {
                                assert!(result
                                    .reasoning_text
                                    .contains(&format!("Task {}", task_id)));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            eprintln!("Parse error: {:?}", e);
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the pooled parser before the next iteration.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        println!(
            "High concurrency test: {} tasks, {} requests each",
            num_tasks, requests_per_task
        );
        println!(
            "Completed in {:?}, {} successes, {} errors",
            duration, successes, errors
        );
        println!(
            "Throughput: {:.0} requests/sec",
            (total_requests as f64) / duration.as_secs_f64()
        );

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // NOTE(review): a wall-clock throughput floor can flake on slow or
        // heavily loaded CI machines.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {:.0} req/sec",
            throughput
        );
    }

    // Interleaves pooled lookups with clear_pool calls to shake out races
    // between the pool's read and write paths.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: repeatedly fetch the same pooled parser.
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: periodically clear the pool while task 1 is fetching.
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: fetch a rotating mix of known and unknown models.
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}