1use std::{collections::HashMap, sync::Arc};
5
6use parking_lot::RwLock;
7use tokio::sync::Mutex;
8
9use crate::{
10 parsers::{
11 BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser,
12 MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser,
13 },
14 traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE},
15};
16
/// A parser instance shared across tasks: a cloneable handle (`Arc`) guarding
/// the parser's mutable state behind an async-aware `tokio::sync::Mutex`, so
/// holders may keep the lock across `.await` points.
pub type PooledParser = Arc<Mutex<Box<dyn ReasoningParser>>>;

/// Thread-safe constructor closure that builds a fresh parser instance.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ReasoningParser> + Send + Sync>;
23
/// Thread-safe registry of parser constructors, a lazily-populated pool of
/// shared parser instances, and ordered model-id matching rules.
///
/// Cloning is cheap: all state lives behind `Arc`s, so clones share it.
#[derive(Clone)]
pub struct ParserRegistry {
    // Parser name -> constructor closure.
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    // Parser name -> cached shared instance (created on first pooled lookup).
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    // Ordered (model-id substring pattern, parser name) pairs; first match wins.
    patterns: Arc<RwLock<Vec<(String, String)>>>,
}
34
35impl ParserRegistry {
36 pub fn new() -> Self {
38 Self {
39 creators: Arc::new(RwLock::new(HashMap::new())),
40 pool: Arc::new(RwLock::new(HashMap::new())),
41 patterns: Arc::new(RwLock::new(Vec::new())),
42 }
43 }
44
45 pub fn register_parser<F>(&self, name: &str, creator: F)
47 where
48 F: Fn() -> Box<dyn ReasoningParser> + Send + Sync + 'static,
49 {
50 let mut creators = self.creators.write();
51 creators.insert(name.to_string(), Arc::new(creator));
52 }
53
54 pub fn register_pattern(&self, pattern: &str, parser_name: &str) {
57 let mut patterns = self.patterns.write();
58 patterns.push((pattern.to_string(), parser_name.to_string()));
59 }
60
61 pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
64 {
66 let pool = self.pool.read();
67 if let Some(parser) = pool.get(name) {
68 return Some(Arc::clone(parser));
69 }
70 }
71
72 let creators = self.creators.read();
74 if let Some(creator) = creators.get(name) {
75 let parser = Arc::new(Mutex::new(creator()));
76
77 let mut pool = self.pool.write();
79 pool.insert(name.to_string(), Arc::clone(&parser));
80
81 Some(parser)
82 } else {
83 None
84 }
85 }
86
87 pub fn has_parser(&self, name: &str) -> bool {
89 let creators = self.creators.read();
90 creators.contains_key(name)
91 }
92
93 pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
96 let creators = self.creators.read();
97 creators.get(name).map(|creator| creator())
98 }
99
100 pub fn find_pooled_parser_for_model(&self, model_id: &str) -> Option<PooledParser> {
102 let patterns = self.patterns.read();
103 let model_lower = model_id.to_lowercase();
104
105 for (pattern, parser_name) in patterns.iter() {
106 if model_lower.contains(&pattern.to_lowercase()) {
107 return self.get_pooled_parser(parser_name);
108 }
109 }
110 None
111 }
112
113 pub fn has_parser_for_model(&self, model_id: &str) -> bool {
116 let patterns = self.patterns.read();
117 let model_lower = model_id.to_lowercase();
118
119 for (pattern, parser_name) in patterns.iter() {
120 if model_lower.contains(&pattern.to_lowercase()) {
121 let creators = self.creators.read();
122 return creators.contains_key(parser_name);
123 }
124 }
125 false
126 }
127
128 pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
131 let patterns = self.patterns.read();
132 let model_lower = model_id.to_lowercase();
133
134 for (pattern, parser_name) in patterns.iter() {
135 if model_lower.contains(&pattern.to_lowercase()) {
136 return self.create_parser(parser_name);
137 }
138 }
139 None
140 }
141
142 pub fn clear_pool(&self) {
145 let mut pool = self.pool.write();
146 pool.clear();
147 }
148}
149
150impl Default for ParserRegistry {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
/// Convenience wrapper around [`ParserRegistry`] that comes pre-loaded with
/// all built-in reasoning parsers and their model-id matching patterns.
///
/// Cloning shares the underlying registry (see `ParserRegistry`'s `Clone`).
#[derive(Clone)]
pub struct ParserFactory {
    registry: ParserRegistry,
}
161
162impl ParserFactory {
163 pub fn new() -> Self {
165 let registry = ParserRegistry::new();
166
167 registry.register_parser("base", || {
169 Box::new(BaseReasoningParser::new(ParserConfig::default()))
170 });
171
172 registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new()));
174
175 registry.register_parser("qwen3", || Box::new(Qwen3Parser::new()));
177
178 registry.register_parser("qwen3_thinking", || Box::new(QwenThinkingParser::new()));
180
181 registry.register_parser("kimi", || Box::new(KimiParser::new()));
183
184 registry.register_parser("glm45", || Box::new(Glm45Parser::new()));
186
187 registry.register_parser("step3", || Box::new(Step3Parser::new()));
189
190 registry.register_parser("minimax", || Box::new(MiniMaxParser::new()));
192
193 registry.register_parser("cohere_cmd", || Box::new(CohereCmdParser::new()));
195
196 registry.register_parser("nano_v3", || Box::new(NanoV3Parser::new()));
198
199 registry.register_pattern("deepseek-r1", "deepseek_r1");
201 registry.register_pattern("qwen3-thinking", "qwen3_thinking");
202 registry.register_pattern("qwen-thinking", "qwen3_thinking");
203 registry.register_pattern("qwen3", "qwen3");
204 registry.register_pattern("qwen", "qwen3");
205 registry.register_pattern("glm45", "glm45");
206 registry.register_pattern("glm47", "glm45"); registry.register_pattern("kimi", "kimi");
208 registry.register_pattern("step3", "step3");
209 registry.register_pattern("minimax", "minimax");
210 registry.register_pattern("minimax-m2", "minimax");
211 registry.register_pattern("mm-m2", "minimax");
212
213 registry.register_pattern("command-r", "cohere_cmd");
215 registry.register_pattern("command-a", "cohere_cmd");
216 registry.register_pattern("c4ai-command", "cohere_cmd");
217 registry.register_pattern("cohere", "cohere_cmd");
218
219 registry.register_pattern("nemotron-nano", "nano_v3");
221 registry.register_pattern("nemotron-super", "nano_v3");
222 registry.register_pattern("nano-v3", "nano_v3");
223
224 Self { registry }
225 }
226
227 #[expect(
231 clippy::expect_used,
232 reason = "passthrough parser is registered on the line above; None indicates a bug in registration logic"
233 )]
234 pub fn get_pooled(&self, model_id: &str) -> PooledParser {
235 if let Some(parser) = self.registry.find_pooled_parser_for_model(model_id) {
237 return parser;
238 }
239
240 self.registry
242 .get_pooled_parser("passthrough")
243 .unwrap_or_else(|| {
244 self.registry.register_parser("passthrough", || {
246 let config = ParserConfig {
247 think_start_token: String::new(),
248 think_end_token: String::new(),
249 stream_reasoning: true,
250 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
251 initial_in_reasoning: false,
252 };
253 Box::new(
254 BaseReasoningParser::new(config).with_model_type("passthrough".to_string()),
255 )
256 });
257 self.registry
258 .get_pooled_parser("passthrough")
259 .expect("passthrough parser was just registered")
260 })
261 }
262
263 pub fn create(&self, model_id: &str) -> Box<dyn ReasoningParser> {
267 if let Some(parser) = self.registry.create_for_model(model_id) {
269 return parser;
270 }
271
272 let config = ParserConfig {
274 think_start_token: String::new(),
275 think_end_token: String::new(),
276 stream_reasoning: true,
277 max_buffer_size: DEFAULT_MAX_BUFFER_SIZE,
278 initial_in_reasoning: false,
279 };
280 Box::new(BaseReasoningParser::new(config).with_model_type("passthrough".to_string()))
281 }
282
283 pub fn registry(&self) -> &ParserRegistry {
285 &self.registry
286 }
287
288 pub fn clear_pool(&self) {
291 self.registry.clear_pool();
292 }
293}
294
295impl Default for ParserFactory {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use super::*;

    // --- Model-id -> parser resolution via `ParserFactory::create` ---------

    #[test]
    fn test_factory_creates_deepseek_r1() {
        let factory = ParserFactory::new();
        let parser = factory.create("deepseek-r1-distill");
        assert_eq!(parser.model_type(), "deepseek_r1");
    }

    #[test]
    fn test_factory_creates_qwen3() {
        let factory = ParserFactory::new();
        let parser = factory.create("qwen3-7b");
        assert_eq!(parser.model_type(), "qwen3");
    }

    #[test]
    fn test_factory_creates_kimi() {
        let factory = ParserFactory::new();
        let parser = factory.create("kimi-chat");
        assert_eq!(parser.model_type(), "kimi");
    }

    // Unknown model ids must fall back to the passthrough parser.
    #[test]
    fn test_factory_fallback_to_passthrough() {
        let factory = ParserFactory::new();
        let parser = factory.create("unknown-model");
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Pattern matching lowercases the model id before comparison.
    #[test]
    fn test_case_insensitive_matching() {
        let factory = ParserFactory::new();
        let parser1 = factory.create("DeepSeek-R1");
        let parser2 = factory.create("QWEN3");
        let parser3 = factory.create("Kimi");

        assert_eq!(parser1.model_type(), "deepseek_r1");
        assert_eq!(parser2.model_type(), "qwen3");
        assert_eq!(parser3.model_type(), "kimi");
    }

    #[test]
    fn test_step3_model() {
        let factory = ParserFactory::new();
        let step3 = factory.create("step3-model");
        assert_eq!(step3.model_type(), "step3");
    }

    #[test]
    fn test_glm45_model() {
        let factory = ParserFactory::new();
        let glm45 = factory.create("glm45-v2");
        assert_eq!(glm45.model_type(), "glm45");
    }

    #[test]
    fn test_minimax_model() {
        let factory = ParserFactory::new();
        let minimax = factory.create("minimax-m2");
        assert_eq!(minimax.model_type(), "minimax");

        // The "mm-m2" alias pattern also maps to the minimax parser.
        let mm = factory.create("mm-m2-chat");
        assert_eq!(mm.model_type(), "minimax");
    }

    #[test]
    fn test_nano_v3_model() {
        let factory = ParserFactory::new();

        let nano = factory.create("nano-v3-chat");
        assert_eq!(nano.model_type(), "nano_v3");

        let nemotron_nano = factory.create("nemotron-nano-4b");
        assert_eq!(nemotron_nano.model_type(), "nano_v3");

        // Full repository-style ids still match via substring search.
        let nemotron_super = factory.create("NVIDIA-Nemotron/nemotron-super");
        assert_eq!(nemotron_super.model_type(), "nano_v3");
    }

    #[test]
    fn test_cohere_cmd_model() {
        let factory = ParserFactory::new();

        let command_r = factory.create("command-r-plus");
        assert_eq!(command_r.model_type(), "cohere_cmd");

        let command_a = factory.create("command-a-03-2025");
        assert_eq!(command_a.model_type(), "cohere_cmd");

        let c4ai = factory.create("c4ai-command-r-v01");
        assert_eq!(c4ai.model_type(), "cohere_cmd");

        let cohere = factory.create("cohere-embed");
        assert_eq!(cohere.model_type(), "cohere_cmd");
    }

    // --- Pooled access: one shared instance per parser name -----------------

    #[tokio::test]
    async fn test_pooled_parser_reuse() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");
        let parser2 = factory.get_pooled("deepseek-r1");

        // Same model id -> the very same pooled Arc.
        assert!(Arc::ptr_eq(&parser1, &parser2));

        // A different parser gets its own pooled instance.
        let parser3 = factory.get_pooled("qwen3");
        assert!(!Arc::ptr_eq(&parser1, &parser3));
    }

    #[tokio::test]
    async fn test_pooled_parser_concurrent_access() {
        let factory = ParserFactory::new();
        let parser = factory.get_pooled("deepseek-r1");

        let mut handles = vec![];

        // Three tasks contend on the same pooled parser's async mutex.
        for i in 0..3 {
            let parser_clone = Arc::clone(&parser);
            let handle = tokio::spawn(async move {
                let mut parser = parser_clone.lock().await;
                let input = format!("thread {i} reasoning</think>answer");
                let result = parser.detect_and_parse_reasoning(&input).unwrap();
                assert_eq!(result.normal_text, "answer");
                assert!(result.reasoning_text.contains("reasoning"));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_pool_clearing() {
        let factory = ParserFactory::new();

        let parser1 = factory.get_pooled("deepseek-r1");

        factory.clear_pool();

        // After clearing, the pool re-creates a fresh instance.
        let parser2 = factory.get_pooled("deepseek-r1");

        assert!(!Arc::ptr_eq(&parser1, &parser2));
    }

    #[tokio::test]
    async fn test_passthrough_parser_pooling() {
        let factory = ParserFactory::new();

        // Distinct unknown model ids share the single pooled passthrough parser.
        let parser1 = factory.get_pooled("unknown-model-1");
        let parser2 = factory.get_pooled("unknown-model-2");

        assert!(Arc::ptr_eq(&parser1, &parser2));

        let parser = parser1.lock().await;
        assert_eq!(parser.model_type(), "passthrough");
    }

    // Stress test: many tasks hammering a handful of pooled parsers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_high_concurrency_parser_access() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::Instant,
        };

        let factory = ParserFactory::new();
        let num_tasks = 100;
        let requests_per_task = 50;
        let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let start = Instant::now();
        let mut handles = vec![];

        for task_id in 0..num_tasks {
            let factory = factory.clone();
            let models = models.clone();
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = tokio::spawn(async move {
                for request_id in 0..requests_per_task {
                    // Rotate through the model list so tasks contend on shared parsers.
                    let model = &models[(task_id + request_id) % models.len()];
                    let parser = factory.get_pooled(model);

                    let mut p = parser.lock().await;

                    // Build a realistically-sized reasoning + answer payload.
                    let product = task_id * request_id;
                    let reasoning_text = format!(
                        "Task {task_id} is processing request {request_id}. Let me think through this step by step. \
                         First, I need to understand the problem. The problem involves analyzing data \
                         and making calculations. Let me break this down: \n\
                         1. Initial analysis shows that we have multiple variables to consider. \
                         2. The data suggests a pattern that needs further investigation. \
                         3. Computing the values: {task_id} * {request_id} = {product}. \
                         4. Cross-referencing with previous results indicates consistency. \
                         5. The mathematical proof follows from the axioms... \
                         6. Considering edge cases and boundary conditions... \
                         7. Validating against known constraints... \
                         8. The conclusion follows logically from premises A, B, and C. \
                         This reasoning chain demonstrates the validity of our approach.",
                    );

                    let answer_text = format!(
                        "Based on my analysis, the answer for task {task_id} request {request_id} is: \
                         The solution involves multiple steps as outlined in the reasoning. \
                         The final result is {product} with confidence level high. \
                         This conclusion is supported by rigorous mathematical analysis \
                         and has been validated against multiple test cases. \
                         The implementation should handle edge cases appropriately.",
                    );

                    let input = format!("<think>{reasoning_text}</think>{answer_text}");

                    match p.detect_and_parse_reasoning(&input) {
                        Ok(result) => {
                            assert!(result.normal_text.contains(&format!("task {task_id}")));

                            // Some parsers may not emit reasoning for this input;
                            // only check its content when present.
                            if !result.reasoning_text.is_empty() {
                                assert!(result.reasoning_text.contains(&format!("Task {task_id}")));
                                assert!(result.reasoning_text.len() > 500);
                            }

                            assert!(result.normal_text.len() > 100);
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            #[expect(clippy::print_stderr, reason = "test diagnostic output")]
                            {
                                eprintln!("Parse error: {e:?}");
                            }
                            error_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    // Release the parser lock before the next request.
                    drop(p);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let duration = start.elapsed();
        let total_requests = num_tasks * requests_per_task;
        let successes = success_count.load(Ordering::Relaxed);
        let errors = error_count.load(Ordering::Relaxed);

        #[expect(clippy::print_stdout, reason = "test diagnostic output")]
        {
            println!("High concurrency test: {num_tasks} tasks, {requests_per_task} requests each");
            println!("Completed in {duration:?}, {successes} successes, {errors} errors");
            println!(
                "Throughput: {:.0} requests/sec",
                (total_requests as f64) / duration.as_secs_f64()
            );
        }

        assert_eq!(successes, total_requests);
        assert_eq!(errors, 0);

        // NOTE(review): wall-clock throughput floor — may be flaky on slow
        // CI runners; confirm this threshold is acceptable for the CI fleet.
        let throughput = (total_requests as f64) / duration.as_secs_f64();
        assert!(
            throughput > 1000.0,
            "Throughput too low: {throughput:.0} req/sec",
        );
    }

    // Races pooled lookups against pool clearing to shake out lock bugs;
    // passing means no panic/deadlock (no value assertions by design).
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_pool_modifications() {
        let factory = ParserFactory::new();
        let mut handles = vec![];

        // Task 1: repeatedly fetch the same pooled parser.
        let factory1 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..100 {
                let _parser = factory1.get_pooled("deepseek-r1");
            }
        }));

        // Task 2: periodically clear the pool out from under the readers.
        let factory2 = factory.clone();
        handles.push(tokio::spawn(async move {
            for _ in 0..10 {
                factory2.clear_pool();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        }));

        // Task 3: fetch a mix of known and unknown (passthrough) parsers.
        let factory3 = factory.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..100 {
                let models = ["qwen3", "kimi", "unknown"];
                let _parser = factory3.get_pooled(models[i % 3]);
            }
        }));

        for handle in handles {
            handle.await.unwrap();
        }
    }
}