1use crate::error::{Result, TextError};
7use crate::transformer::TransformerConfig;
8use std::collections::HashMap;
9use std::fs;
10#[cfg(feature = "serde-support")]
11use std::io::{BufReader, BufWriter};
12use std::path::{Path, PathBuf};
13
14#[cfg(feature = "serde-support")]
15use serde::{Deserialize, Serialize};
16
/// Category tag for a registered model.
///
/// `Custom` carries a user-defined name for model kinds not covered by the
/// built-in variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub enum ModelType {
    /// Sequence-to-sequence transformer architecture.
    Transformer,
    /// Word-embedding model (token -> vector lookup).
    WordEmbedding,
    /// Sentiment analysis model.
    Sentiment,
    /// Language identification model.
    LanguageDetection,
    /// General text classification model.
    TextClassification,
    /// Named entity recognition model.
    NamedEntityRecognition,
    /// Part-of-speech tagging model.
    PartOfSpeech,
    /// User-defined model type with a custom name.
    Custom(String),
}
38
39impl std::fmt::Display for ModelType {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 ModelType::Transformer => write!(f, "transformer"),
43 ModelType::WordEmbedding => write!(f, "word_embedding"),
44 ModelType::Sentiment => write!(f, "sentiment"),
45 ModelType::LanguageDetection => write!(f, "language_detection"),
46 ModelType::TextClassification => write!(f, "text_classification"),
47 ModelType::NamedEntityRecognition => write!(f, "named_entity_recognition"),
48 ModelType::PartOfSpeech => write!(f, "part_of_speech"),
49 ModelType::Custom(name) => write!(f, "custom_{name}"),
50 }
51 }
52}
53
/// Descriptive metadata for a model stored in a [`ModelRegistry`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct ModelMetadata {
    /// Unique identifier; also used as the model's directory name in the registry.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Model version string (semver-style, e.g. "1.0.0").
    pub version: String,
    /// Category of the model.
    pub model_type: ModelType,
    /// Free-form description of the model.
    pub description: String,
    /// Language codes the model supports (e.g. "en").
    pub languages: Vec<String>,
    /// Size of the serialized model in bytes.
    pub size_bytes: u64,
    /// Author or organization that produced the model.
    pub author: String,
    /// License identifier (e.g. "Apache-2.0").
    pub license: String,
    /// Named evaluation metrics (e.g. "perplexity" -> 15.2).
    pub metrics: HashMap<String, f64>,
    /// Creation timestamp, formatted as "%Y-%m-%d %H:%M:%S UTC".
    pub created_at: String,
    /// Directory containing the model's serialized files.
    pub file_path: PathBuf,
    /// Arbitrary string key/value configuration.
    pub config: HashMap<String, String>,
    /// Identifiers of other models or components this model depends on.
    pub dependencies: Vec<String>,
    /// Minimum API version required to load this model.
    pub min_api_version: String,
}
89
90impl ModelMetadata {
91 pub fn new(_id: String, name: String, modeltype: ModelType) -> Self {
93 Self {
94 id: _id,
95 name,
96 version: "1.0.0".to_string(),
97 model_type: modeltype,
98 description: String::new(),
99 languages: vec!["en".to_string()],
100 size_bytes: 0,
101 author: String::new(),
102 license: "Apache-2.0".to_string(),
103 metrics: HashMap::new(),
104 created_at: chrono::Utc::now()
105 .format("%Y-%m-%d %H:%M:%S UTC")
106 .to_string(),
107 file_path: PathBuf::new(),
108 config: HashMap::new(),
109 dependencies: Vec::new(),
110 min_api_version: "0.1.0".to_string(),
111 }
112 }
113
114 pub fn with_version(mut self, version: String) -> Self {
116 self.version = version;
117 self
118 }
119
120 pub fn with_description(mut self, description: String) -> Self {
122 self.description = description;
123 self
124 }
125
126 pub fn with_languages(mut self, languages: Vec<String>) -> Self {
128 self.languages = languages;
129 self
130 }
131
132 pub fn with_metric(mut self, name: String, value: f64) -> Self {
134 self.metrics.insert(name, value);
135 self
136 }
137
138 pub fn with_author(mut self, author: String) -> Self {
140 self.author = author;
141 self
142 }
143
144 pub fn with_file_path(mut self, path: PathBuf) -> Self {
146 self.file_path = path;
147 self
148 }
149
150 pub fn with_config(mut self, key: String, value: String) -> Self {
152 self.config.insert(key, value);
153 self
154 }
155}
156
/// Flat, serialization-friendly representation of a model's parameters.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct SerializableModelData {
    /// Flattened weight tensors, keyed by parameter name.
    pub weights: HashMap<String, Vec<f64>>,
    /// Original tensor shapes for each entry in `weights`.
    pub shapes: HashMap<String, Vec<usize>>,
    /// Optional vocabulary, ordered by token id.
    pub vocabulary: Option<Vec<String>>,
    /// Model configuration as string key/value pairs.
    pub config: HashMap<String, String>,
}
170
/// Interface a model must implement to be stored in and loaded from a
/// [`ModelRegistry`].
pub trait RegistrableModel {
    /// Converts the model's weights, shapes, vocabulary, and configuration
    /// into a serializable form.
    fn serialize(&self) -> Result<SerializableModelData>;

    /// Reconstructs a model from previously serialized data.
    fn deserialize(data: &SerializableModelData) -> Result<Self>
    where
        Self: Sized;

    /// Returns the model's category.
    fn model_type(&self) -> ModelType;

    /// Returns the model's configuration as string key/value pairs.
    fn get_config(&self) -> HashMap<String, String>;
}
187
/// File-system backed registry of models with an in-memory instance cache.
pub struct ModelRegistry {
    /// Root directory; each model lives in a subdirectory named by its id.
    registry_dir: PathBuf,
    /// Metadata for every model discovered in the registry, keyed by id.
    models: HashMap<String, ModelMetadata>,
    /// Cache of deserialized model instances, keyed by model id.
    /// Stored as `Any` so heterogeneous model types can share one cache.
    model_cache: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
    /// Maximum number of cached model instances before eviction.
    max_cache_size: usize,
}
199
200impl ModelRegistry {
201 pub fn new<P: AsRef<Path>>(registry_dir: P, dir: P) -> Result<Self> {
203 let _registry_dir = registry_dir.as_ref().to_path_buf();
204
205 if !_registry_dir.exists() {
207 fs::create_dir_all(&_registry_dir).map_err(|e| {
208 TextError::IoError(format!("Failed to create registry directory: {e}"))
209 })?;
210 }
211
212 let mut registry = Self {
213 registry_dir: registry_dir.as_ref().to_path_buf(),
214 models: HashMap::new(),
215 model_cache: HashMap::new(),
216 max_cache_size: 10, };
218
219 registry.scan_registry()?;
221
222 Ok(registry)
223 }
224
225 pub fn with_max_cache_size(mut self, size: usize) -> Self {
227 self.max_cache_size = size;
228 self
229 }
230
231 fn scan_registry(&mut self) -> Result<()> {
233 if !self.registry_dir.exists() {
234 return Ok(());
235 }
236
237 for entry in fs::read_dir(&self.registry_dir)
238 .map_err(|e| TextError::IoError(format!("Failed to read registry directory: {e}")))?
239 {
240 let entry = entry
241 .map_err(|e| TextError::IoError(format!("Failed to read directory entry: {e}")))?;
242
243 if entry
244 .file_type()
245 .map_err(|e| TextError::IoError(format!("Failed to get file type: {e}")))?
246 .is_dir()
247 {
248 let model_dir = entry.path();
249 if let Some(model_id) = model_dir.file_name().and_then(|n| n.to_str()) {
250 if let Ok(metadata) = self.load_model_metadata(&model_dir) {
251 self.models.insert(model_id.to_string(), metadata);
252 }
253 }
254 }
255 }
256
257 Ok(())
258 }
259
260 fn load_model_metadata(&self, modeldir: &Path) -> Result<ModelMetadata> {
262 let metadata_file = modeldir.join("metadata.json");
263 if !metadata_file.exists() {
264 return Err(TextError::InvalidInput(format!(
265 "Metadata file not found: {}",
266 metadata_file.display()
267 )));
268 }
269
270 #[cfg(feature = "serde-support")]
271 {
272 let file = fs::File::open(&metadata_file)
273 .map_err(|e| TextError::IoError(format!("Failed to open metadata file: {e}")))?;
274 let reader = BufReader::new(file);
275 let mut metadata: ModelMetadata = serde_json::from_reader(reader).map_err(|e| {
276 TextError::InvalidInput(format!("Failed to deserialize metadata: {e}"))
277 })?;
278
279 metadata.file_path = modeldir.to_path_buf();
281 Ok(metadata)
282 }
283
284 #[cfg(not(feature = "serde-support"))]
285 {
286 let model_id = modeldir
288 .file_name()
289 .and_then(|n| n.to_str())
290 .unwrap_or("unknown")
291 .to_string();
292
293 Ok(ModelMetadata::new(
294 model_id.clone(),
295 format!("Model {model_id}"),
296 ModelType::Custom("unknown".to_string()),
297 )
298 .with_file_path(modeldir.to_path_buf()))
299 }
300 }
301
302 pub fn register_model<M: RegistrableModel + 'static>(
304 &mut self,
305 model: &M,
306 metadata: ModelMetadata,
307 ) -> Result<()> {
308 let model_dir = self.registry_dir.join(&metadata.id);
310 if !model_dir.exists() {
311 fs::create_dir_all(&model_dir).map_err(|e| {
312 TextError::IoError(format!("Failed to create model directory: {e}"))
313 })?;
314 }
315
316 let serialized = model.serialize()?;
318 self.save_model_data(&model_dir, &serialized)?;
319
320 self.save_model_metadata(&model_dir, &metadata)?;
322
323 self.models.insert(metadata.id.clone(), metadata);
325
326 Ok(())
327 }
328
329 fn save_model_data(&self, modeldir: &Path, data: &SerializableModelData) -> Result<()> {
331 let data_file = modeldir.join("model.json");
332
333 #[cfg(feature = "serde-support")]
334 {
335 let file = fs::File::create(&data_file)
336 .map_err(|e| TextError::IoError(format!("Failed to create model file: {e}")))?;
337 let writer = BufWriter::new(file);
338 serde_json::to_writer_pretty(writer, data).map_err(|e| {
339 TextError::InvalidInput(format!("Failed to serialize model data: {e}"))
340 })?;
341 }
342
343 #[cfg(not(feature = "serde-support"))]
344 {
345 let data_str = format!("{data:#?}");
347 fs::write(&data_file, data_str)
348 .map_err(|e| TextError::IoError(format!("Failed to save model data: {e}")))?;
349 }
350
351 Ok(())
352 }
353
354 fn save_model_metadata(&self, modeldir: &Path, metadata: &ModelMetadata) -> Result<()> {
356 let metadata_file = modeldir.join("metadata.json");
357
358 #[cfg(feature = "serde-support")]
359 {
360 let file = fs::File::create(&metadata_file)
361 .map_err(|e| TextError::IoError(format!("Failed to create metadata file: {e}")))?;
362 let writer = BufWriter::new(file);
363 serde_json::to_writer_pretty(writer, metadata).map_err(|e| {
364 TextError::InvalidInput(format!("Failed to serialize metadata: {e}"))
365 })?;
366 }
367
368 #[cfg(not(feature = "serde-support"))]
369 {
370 let metadata_str = format!("{metadata:#?}");
372 fs::write(&metadata_file, metadata_str)
373 .map_err(|e| TextError::IoError(format!("Failed to save metadata: {e}")))?;
374 }
375
376 Ok(())
377 }
378
379 pub fn list_models(&self) -> Vec<&ModelMetadata> {
381 self.models.values().collect()
382 }
383
384 pub fn list_models_by_type(&self, modeltype: &ModelType) -> Vec<&ModelMetadata> {
386 self.models
387 .values()
388 .filter(|metadata| &metadata.model_type == modeltype)
389 .collect()
390 }
391
392 pub fn get_metadata(&self, model_id: &str) -> Option<&ModelMetadata> {
394 self.models.get(model_id)
395 }
396
397 pub fn load_model<M: RegistrableModel + Send + Sync + 'static>(
399 &mut self,
400 model_id: &str,
401 ) -> Result<&M> {
402 let is_cached = self
404 .model_cache
405 .get(model_id)
406 .and_then(|cached| cached.downcast_ref::<M>())
407 .is_some();
408
409 if is_cached {
410 return Ok(self
412 .model_cache
413 .get(model_id)
414 .unwrap()
415 .downcast_ref::<M>()
416 .unwrap());
417 }
418
419 let metadata = self
421 .models
422 .get(model_id)
423 .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
424
425 let model_data = self.load_model_data(&metadata.file_path)?;
427
428 let model = M::deserialize(&model_data)?;
430
431 self.cache_model(model_id.to_string(), Box::new(model));
433
434 if let Some(cached) = self.model_cache.get(model_id) {
436 if let Some(model) = cached.downcast_ref::<M>() {
437 return Ok(model);
438 }
439 }
440
441 Err(TextError::InvalidInput("Failed to cache model".to_string()))
442 }
443
444 fn load_model_data(&self, modeldir: &Path) -> Result<SerializableModelData> {
446 let data_file = modeldir.join("model.json");
447 if !data_file.exists() {
448 let legacy_file = modeldir.join("model.dat");
450 if legacy_file.exists() {
451 return Ok(SerializableModelData {
452 weights: HashMap::new(),
453 shapes: HashMap::new(),
454 vocabulary: None,
455 config: HashMap::new(),
456 });
457 }
458
459 return Err(TextError::InvalidInput(format!(
460 "Model data file not found: {}",
461 data_file.display()
462 )));
463 }
464
465 #[cfg(feature = "serde-support")]
466 {
467 let file = fs::File::open(&data_file)
468 .map_err(|e| TextError::IoError(format!("Failed to open model data file: {e}")))?;
469 let reader = BufReader::new(file);
470 serde_json::from_reader(reader).map_err(|e| {
471 TextError::InvalidInput(format!("Failed to deserialize model data: {e}"))
472 })
473 }
474
475 #[cfg(not(feature = "serde-support"))]
476 {
477 Ok(SerializableModelData {
479 weights: HashMap::new(),
480 shapes: HashMap::new(),
481 vocabulary: None,
482 config: HashMap::new(),
483 })
484 }
485 }
486
487 fn cache_model(&mut self, model_id: String, model: Box<dyn std::any::Any + Send + Sync>) {
489 if self.model_cache.len() >= self.max_cache_size {
491 if let Some(first_key) = self.model_cache.keys().next().cloned() {
492 self.model_cache.remove(&first_key);
493 }
494 }
495
496 self.model_cache.insert(model_id, model);
497 }
498
499 pub fn remove_model(&mut self, model_id: &str) -> Result<()> {
501 let metadata = self
502 .models
503 .remove(model_id)
504 .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
505
506 if metadata.file_path.exists() {
508 fs::remove_dir_all(&metadata.file_path)
509 .map_err(|e| TextError::IoError(format!("Failed to remove model files: {e}")))?;
510 }
511
512 self.model_cache.remove(model_id);
514
515 Ok(())
516 }
517
518 pub fn clear_cache(&mut self) {
520 self.model_cache.clear();
521 }
522
523 pub fn cache_stats(&self) -> (usize, usize) {
525 (self.model_cache.len(), self.max_cache_size)
526 }
527
528 pub fn search_models(&self, query: &str) -> Vec<&ModelMetadata> {
530 let query_lower = query.to_lowercase();
531 self.models
532 .values()
533 .filter(|metadata| {
534 metadata.name.to_lowercase().contains(&query_lower)
535 || metadata.description.to_lowercase().contains(&query_lower)
536 })
537 .collect()
538 }
539
540 pub fn models_for_language(&self, language: &str) -> Vec<&ModelMetadata> {
542 self.models
543 .values()
544 .filter(|metadata| metadata.languages.contains(&language.to_string()))
545 .collect()
546 }
547
548 pub fn check_model_compatibility(&self, model_id: &str) -> Result<bool> {
550 let metadata = self
551 .models
552 .get(model_id)
553 .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
554
555 let current_version = "0.1.0-beta.4"; let min_version = &metadata.min_api_version;
558
559 Ok(current_version >= min_version.as_str())
562 }
563
564 pub fn model_statistics(&self) -> HashMap<String, usize> {
566 let mut stats = HashMap::new();
567
568 for metadata in self.models.values() {
570 let type_key = metadata.model_type.to_string();
571 *stats.entry(type_key).or_insert(0) += 1;
572 }
573
574 stats.insert("total_models".to_string(), self.models.len());
575 stats.insert("cached_models".to_string(), self.model_cache.len());
576
577 stats
578 }
579
580 pub fn validate_model(&self, model_id: &str) -> Result<bool> {
582 let metadata = self
583 .models
584 .get(model_id)
585 .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
586
587 let model_dir = &metadata.file_path;
589 let data_file = model_dir.join("model.json");
590 let metadata_file = model_dir.join("metadata.json");
591
592 Ok(data_file.exists() && metadata_file.exists())
593 }
594
595 pub fn get_model_info(&self, model_id: &str) -> Result<HashMap<String, String>> {
597 let metadata = self
598 .models
599 .get(model_id)
600 .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
601
602 let mut info = HashMap::new();
603 info.insert("_id".to_string(), metadata.id.clone());
604 info.insert("name".to_string(), metadata.name.clone());
605 info.insert("version".to_string(), metadata.version.clone());
606 info.insert("type".to_string(), metadata.model_type.to_string());
607 info.insert("author".to_string(), metadata.author.clone());
608 info.insert("license".to_string(), metadata.license.clone());
609 info.insert("created_at".to_string(), metadata.created_at.clone());
610 info.insert("size_bytes".to_string(), metadata.size_bytes.to_string());
611 info.insert("languages".to_string(), metadata.languages.join(", "));
612
613 for (metric_name, metric_value) in &metadata.metrics {
615 info.insert(format!("metric_{metric_name}"), metric_value.to_string());
616 }
617
618 Ok(info)
619 }
620}
621
/// Namespace for factory functions that build well-known model
/// configurations with matching metadata.
pub struct PrebuiltModels;
624
625impl PrebuiltModels {
626 pub fn english_transformer_base() -> (TransformerConfig, ModelMetadata) {
628 let config = TransformerConfig {
629 d_model: 512,
630 nheads: 8,
631 d_ff: 2048,
632 n_encoder_layers: 6,
633 n_decoder_layers: 6,
634 max_seqlen: 512,
635 dropout: 0.1,
636 vocab_size: 50000,
637 };
638
639 let metadata = ModelMetadata::new(
640 "english_transformer_base".to_string(),
641 "English Transformer Base".to_string(),
642 ModelType::Transformer,
643 )
644 .with_description("Base transformer model for English text processing".to_string())
645 .with_languages(vec!["en".to_string()])
646 .with_author("SciRS2".to_string())
647 .with_metric("perplexity".to_string(), 15.2)
648 .with_config("d_model".to_string(), "512".to_string())
649 .with_config("n_heads".to_string(), "8".to_string());
650
651 (config, metadata)
652 }
653
654 pub fn multilingual_transformer() -> (TransformerConfig, ModelMetadata) {
656 let config = TransformerConfig {
657 d_model: 768,
658 nheads: 12,
659 d_ff: 3072,
660 n_encoder_layers: 12,
661 n_decoder_layers: 12,
662 max_seqlen: 512,
663 dropout: 0.1,
664 vocab_size: 120000,
665 };
666
667 let metadata = ModelMetadata::new(
668 "multilingual_transformer".to_string(),
669 "Multilingual Transformer".to_string(),
670 ModelType::Transformer,
671 )
672 .with_description("Transformer model supporting multiple languages".to_string())
673 .with_languages(vec![
674 "en".to_string(),
675 "es".to_string(),
676 "fr".to_string(),
677 "de".to_string(),
678 "zh".to_string(),
679 "ja".to_string(),
680 ])
681 .with_author("SciRS2".to_string())
682 .with_metric("bleu_score".to_string(), 28.4)
683 .with_config("d_model".to_string(), "768".to_string())
684 .with_config("n_heads".to_string(), "12".to_string());
685
686 (config, metadata)
687 }
688
689 pub fn scientific_transformer() -> (TransformerConfig, ModelMetadata) {
691 let config = TransformerConfig {
692 d_model: 1024,
693 nheads: 16,
694 d_ff: 4096,
695 n_encoder_layers: 24,
696 n_decoder_layers: 24,
697 max_seqlen: 1024,
698 dropout: 0.1,
699 vocab_size: 200000,
700 };
701
702 let metadata = ModelMetadata::new(
703 "scientific_transformer".to_string(),
704 "Scientific Text Transformer".to_string(),
705 ModelType::Transformer,
706 )
707 .with_description(
708 "Large transformer model specialized for scientific text processing".to_string(),
709 )
710 .with_languages(vec!["en".to_string()])
711 .with_author("SciRS2".to_string())
712 .with_metric("scientific_f1".to_string(), 92.1)
713 .with_config("d_model".to_string(), "1024".to_string())
714 .with_config("n_heads".to_string(), "16".to_string())
715 .with_config("domain".to_string(), "scientific".to_string());
716
717 (config, metadata)
718 }
719
720 pub fn tiny_transformer() -> (TransformerConfig, ModelMetadata) {
722 let config = TransformerConfig {
723 d_model: 128,
724 nheads: 2,
725 d_ff: 512,
726 n_encoder_layers: 2,
727 n_decoder_layers: 2,
728 max_seqlen: 128,
729 dropout: 0.1,
730 vocab_size: 1000,
731 };
732
733 let metadata = ModelMetadata::new(
734 "tiny_transformer".to_string(),
735 "Tiny Transformer".to_string(),
736 ModelType::Transformer,
737 )
738 .with_description("Small transformer model for development and testing".to_string())
739 .with_languages(vec!["en".to_string()])
740 .with_author("SciRS2".to_string())
741 .with_metric("perplexity".to_string(), 25.0)
742 .with_config("d_model".to_string(), "128".to_string())
743 .with_config(
744 "intended_use".to_string(),
745 "development_testing".to_string(),
746 );
747
748 (config, metadata)
749 }
750
751 pub fn large_transformer() -> (TransformerConfig, ModelMetadata) {
753 let config = TransformerConfig {
754 d_model: 1536,
755 nheads: 24,
756 d_ff: 6144,
757 n_encoder_layers: 48,
758 n_decoder_layers: 48,
759 max_seqlen: 2048,
760 dropout: 0.1,
761 vocab_size: 100000,
762 };
763
764 let metadata = ModelMetadata::new(
765 "large_transformer".to_string(),
766 "Large Transformer".to_string(),
767 ModelType::Transformer,
768 )
769 .with_description("Large transformer model for production use".to_string())
770 .with_languages(vec![
771 "en".to_string(),
772 "es".to_string(),
773 "fr".to_string(),
774 "de".to_string(),
775 ])
776 .with_author("SciRS2".to_string())
777 .with_metric("perplexity".to_string(), 8.2)
778 .with_metric("bleu_score".to_string(), 35.7)
779 .with_config("d_model".to_string(), "1536".to_string())
780 .with_config("intended_use".to_string(), "production".to_string());
781
782 (config, metadata)
783 }
784
785 pub fn domain_scientific_large() -> (TransformerConfig, ModelMetadata) {
787 let config = TransformerConfig {
788 d_model: 1024,
789 nheads: 16,
790 d_ff: 4096,
791 n_encoder_layers: 24,
792 n_decoder_layers: 24,
793 max_seqlen: 2048,
794 dropout: 0.05, vocab_size: 150000, };
797
798 let metadata = ModelMetadata::new(
799 "scibert_large".to_string(),
800 "Scientific BERT Large".to_string(),
801 ModelType::Transformer,
802 )
803 .with_description(
804 "Large transformer model pre-trained on scientific literature".to_string(),
805 )
806 .with_languages(vec!["en".to_string()])
807 .with_author("SciRS2".to_string())
808 .with_metric("scientific_f1".to_string(), 94.3)
809 .with_metric("pubmed_qa_accuracy".to_string(), 87.6)
810 .with_config("domain".to_string(), "scientific".to_string())
811 .with_config(
812 "training_corpus".to_string(),
813 "pubmed_arxiv_pmc".to_string(),
814 );
815
816 (config, metadata)
817 }
818
819 pub fn medical_transformer() -> (TransformerConfig, ModelMetadata) {
821 let config = TransformerConfig {
822 d_model: 768,
823 nheads: 12,
824 d_ff: 3072,
825 n_encoder_layers: 12,
826 n_decoder_layers: 12,
827 max_seqlen: 1024,
828 dropout: 0.1,
829 vocab_size: 80000, };
831
832 let metadata = ModelMetadata::new(
833 "medbert".to_string(),
834 "Medical BERT".to_string(),
835 ModelType::Transformer,
836 )
837 .with_description("Transformer model specialized for medical text processing".to_string())
838 .with_languages(vec!["en".to_string()])
839 .with_author("SciRS2".to_string())
840 .with_metric("medical_ner_f1".to_string(), 91.2)
841 .with_metric("clinical_notes_accuracy".to_string(), 85.4)
842 .with_config("domain".to_string(), "medical".to_string())
843 .with_config(
844 "training_corpus".to_string(),
845 "mimic_iii_pubmed".to_string(),
846 );
847
848 (config, metadata)
849 }
850
851 pub fn legal_transformer() -> (TransformerConfig, ModelMetadata) {
853 let config = TransformerConfig {
854 d_model: 768,
855 nheads: 12,
856 d_ff: 3072,
857 n_encoder_layers: 12,
858 n_decoder_layers: 12,
859 max_seqlen: 2048, dropout: 0.1,
861 vocab_size: 60000, };
863
864 let metadata = ModelMetadata::new(
865 "legalbert".to_string(),
866 "Legal BERT".to_string(),
867 ModelType::Transformer,
868 )
869 .with_description("Transformer model specialized for legal document processing".to_string())
870 .with_languages(vec!["en".to_string()])
871 .with_author("SciRS2".to_string())
872 .with_metric("legal_ner_f1".to_string(), 88.7)
873 .with_metric("contract_classification_accuracy".to_string(), 92.1)
874 .with_config("domain".to_string(), "legal".to_string())
875 .with_config(
876 "training_corpus".to_string(),
877 "legal_cases_contracts".to_string(),
878 );
879
880 (config, metadata)
881 }
882
883 pub fn all_prebuilt_models() -> Vec<(TransformerConfig, ModelMetadata)> {
885 vec![
886 Self::english_transformer_base(),
887 Self::multilingual_transformer(),
888 Self::scientific_transformer(),
889 Self::tiny_transformer(),
890 Self::large_transformer(),
891 Self::domain_scientific_large(),
892 Self::medical_transformer(),
893 Self::legal_transformer(),
894 ]
895 }
896
897 pub fn get_by_id(_model_id: &str) -> Option<(TransformerConfig, ModelMetadata)> {
899 match _model_id {
900 "english_transformer_base" => Some(Self::english_transformer_base()),
901 "multilingual_transformer" => Some(Self::multilingual_transformer()),
902 "scientific_transformer" => Some(Self::scientific_transformer()),
903 "tiny_transformer" => Some(Self::tiny_transformer()),
904 "large_transformer" => Some(Self::large_transformer()),
905 "scibiert_large" => Some(Self::domain_scientific_large()),
906 "medbert" => Some(Self::medical_transformer()),
907 "legalbert" => Some(Self::legal_transformer()),
908 _ => None,
909 }
910 }
911}
912
913impl RegistrableModel for crate::transformer::TransformerModel {
915 fn serialize(&self) -> Result<SerializableModelData> {
916 let mut weights = HashMap::new();
917 let mut shapes = HashMap::new();
918 let mut config = HashMap::new();
919
920 config.insert("d_model".to_string(), self.config.d_model.to_string());
922 config.insert("n_heads".to_string(), self.config.nheads.to_string());
923 config.insert("d_ff".to_string(), self.config.d_ff.to_string());
924 config.insert(
925 "n_encoder_layers".to_string(),
926 self.config.n_encoder_layers.to_string(),
927 );
928 config.insert(
929 "n_decoder_layers".to_string(),
930 self.config.n_decoder_layers.to_string(),
931 );
932 config.insert(
933 "max_seq_len".to_string(),
934 self.config.max_seqlen.to_string(),
935 );
936 config.insert("dropout".to_string(), self.config.dropout.to_string());
937 config.insert("vocab_size".to_string(), self.config.vocab_size.to_string());
938
939 let embed_weights = self
941 .token_embedding
942 .get_embeddings()
943 .as_slice()
944 .unwrap()
945 .to_vec();
946 let embedshape = self.token_embedding.get_embeddings().shape().to_vec();
947 weights.insert("token_embeddings".to_string(), embed_weights);
948 shapes.insert("token_embeddings".to_string(), embedshape);
949
950 let pos_embed_weights = vec![0.0f64; self.config.max_seqlen * self.config.d_model];
952 let pos_embedshape = vec![self.config.max_seqlen, self.config.d_model];
953 weights.insert("positional_embeddings".to_string(), pos_embed_weights);
954 shapes.insert("positional_embeddings".to_string(), pos_embedshape);
955
956 for i in 0..self.config.n_encoder_layers {
958 let layer = &self.encoder.get_layers()[i];
959 let (attention, ff, ln1, ln2) = layer.get_components();
960
961 let (w_q, w_k, w_v, w_o) = attention.get_weights();
963 weights.insert(
964 format!("encoder_{i}_attention_wq"),
965 w_q.as_slice().unwrap().to_vec(),
966 );
967 shapes.insert(format!("encoder_{i}_attention_wq"), w_q.shape().to_vec());
968 weights.insert(
969 format!("encoder_{i}_attention_wk"),
970 w_k.as_slice().unwrap().to_vec(),
971 );
972 shapes.insert(format!("encoder_{i}_attention_wk"), w_k.shape().to_vec());
973 weights.insert(
974 format!("encoder_{i}_attention_wv"),
975 w_v.as_slice().unwrap().to_vec(),
976 );
977 shapes.insert(format!("encoder_{i}_attention_wv"), w_v.shape().to_vec());
978 weights.insert(
979 format!("encoder_{i}_attention_wo"),
980 w_o.as_slice().unwrap().to_vec(),
981 );
982 shapes.insert(format!("encoder_{i}_attention_wo"), w_o.shape().to_vec());
983
984 let (w1, w2, b1, b2) = ff.get_weights();
986 weights.insert(
987 format!("encoder_{i}_ff_w1"),
988 w1.as_slice().unwrap().to_vec(),
989 );
990 shapes.insert(format!("encoder_{i}_ff_w1"), w1.shape().to_vec());
991 weights.insert(
992 format!("encoder_{i}_ff_w2"),
993 w2.as_slice().unwrap().to_vec(),
994 );
995 shapes.insert(format!("encoder_{i}_ff_w2"), w2.shape().to_vec());
996 weights.insert(
997 format!("encoder_{i}_ff_b1"),
998 b1.as_slice().unwrap().to_vec(),
999 );
1000 shapes.insert(format!("encoder_{i}_ff_b1"), vec![b1.len()]);
1001 weights.insert(
1002 format!("encoder_{i}_ff_b2"),
1003 b2.as_slice().unwrap().to_vec(),
1004 );
1005 shapes.insert(format!("encoder_{i}_ff_b2"), vec![b2.len()]);
1006
1007 let (gamma1, beta1) = ln1.get_params();
1009 let (gamma2, beta2) = ln2.get_params();
1010 weights.insert(
1011 format!("encoder_{i}_ln1_gamma"),
1012 gamma1.as_slice().unwrap().to_vec(),
1013 );
1014 shapes.insert(format!("encoder_{i}_ln1_gamma"), vec![gamma1.len()]);
1015 weights.insert(
1016 format!("encoder_{i}_ln1_beta"),
1017 beta1.as_slice().unwrap().to_vec(),
1018 );
1019 shapes.insert(format!("encoder_{i}_ln1_beta"), vec![beta1.len()]);
1020 weights.insert(
1021 format!("encoder_{i}_ln2_gamma"),
1022 gamma2.as_slice().unwrap().to_vec(),
1023 );
1024 shapes.insert(format!("encoder_{i}_ln2_gamma"), vec![gamma2.len()]);
1025 weights.insert(
1026 format!("encoder_{i}_ln2_beta"),
1027 beta2.as_slice().unwrap().to_vec(),
1028 );
1029 shapes.insert(format!("encoder_{i}_ln2_beta"), vec![beta2.len()]);
1030 }
1031
1032 for i in 0..self.config.n_decoder_layers {
1034 let self_attn_weight_size = self.config.d_model * self.config.d_model * 4; let self_attn_weights = vec![0.0f64; self_attn_weight_size];
1037 let self_attnshape = vec![self.config.d_model, self.config.d_model * 4];
1038 weights.insert(format!("decoder_{i}_self_attention"), self_attn_weights);
1039 shapes.insert(format!("decoder_{i}_self_attention"), self_attnshape);
1040
1041 let cross_attn_weights = vec![0.0f64; self_attn_weight_size];
1043 let cross_attnshape = vec![self.config.d_model, self.config.d_model * 4];
1044 weights.insert(format!("decoder_{i}_cross_attention"), cross_attn_weights);
1045 shapes.insert(format!("decoder_{i}_cross_attention"), cross_attnshape);
1046
1047 let ff_weight_size = self.config.d_model * self.config.d_ff * 2; let ff_weights = vec![0.0f64; ff_weight_size];
1050 let ffshape = vec![self.config.d_model, self.config.d_ff * 2];
1051 weights.insert(format!("decoder_{i}_feedforward"), ff_weights);
1052 shapes.insert(format!("decoder_{i}_feedforward"), ffshape);
1053
1054 let ln_weights = vec![1.0f64; self.config.d_model];
1056 let lnshape = vec![self.config.d_model];
1057 weights.insert(format!("decoder_{i}_ln1"), ln_weights.clone());
1058 shapes.insert(format!("decoder_{i}_ln1"), lnshape.clone());
1059
1060 weights.insert(format!("decoder_{i}_ln2"), ln_weights.clone());
1061 shapes.insert(format!("decoder_{i}_ln2"), lnshape.clone());
1062
1063 weights.insert(format!("decoder_{i}_ln3"), ln_weights);
1064 shapes.insert(format!("decoder_{i}_ln3"), lnshape);
1065 }
1066
1067 let output_weight_size = self.config.d_model * self.config.vocab_size;
1069 let output_weights = vec![0.0f64; output_weight_size];
1070 let outputshape = vec![self.config.d_model, self.config.vocab_size];
1071 weights.insert("output_projection".to_string(), output_weights);
1072 shapes.insert("output_projection".to_string(), outputshape);
1073
1074 let (vocab_to_id, id_to_vocab) = self.vocabulary();
1076 let vocabulary = Some(
1077 (0..vocab_to_id.len())
1078 .map(|i| {
1079 id_to_vocab
1080 .get(&i)
1081 .cloned()
1082 .unwrap_or_else(|| format!("unk_{i}"))
1083 })
1084 .collect(),
1085 );
1086
1087 Ok(SerializableModelData {
1088 weights,
1089 shapes,
1090 vocabulary,
1091 config,
1092 })
1093 }
1094
1095 fn deserialize(data: &SerializableModelData) -> Result<Self> {
1096 let d_model = data
1098 .config
1099 .get("d_model")
1100 .and_then(|s| s.parse().ok())
1101 .ok_or_else(|| TextError::InvalidInput("Missing d_model config".to_string()))?;
1102 let n_heads = data
1103 .config
1104 .get("n_heads")
1105 .and_then(|s| s.parse().ok())
1106 .ok_or_else(|| TextError::InvalidInput("Missing n_heads config".to_string()))?;
1107 let d_ff = data
1108 .config
1109 .get("d_ff")
1110 .and_then(|s| s.parse().ok())
1111 .ok_or_else(|| TextError::InvalidInput("Missing d_ff config".to_string()))?;
1112 let n_encoder_layers = data
1113 .config
1114 .get("n_encoder_layers")
1115 .and_then(|s| s.parse().ok())
1116 .ok_or_else(|| {
1117 TextError::InvalidInput("Missing n_encoder_layers config".to_string())
1118 })?;
1119 let n_decoder_layers = data
1120 .config
1121 .get("n_decoder_layers")
1122 .and_then(|s| s.parse().ok())
1123 .ok_or_else(|| {
1124 TextError::InvalidInput("Missing n_decoder_layers config".to_string())
1125 })?;
1126 let max_seq_len = data
1127 .config
1128 .get("max_seq_len")
1129 .and_then(|s| s.parse().ok())
1130 .ok_or_else(|| TextError::InvalidInput("Missing max_seq_len config".to_string()))?;
1131 let dropout = data
1132 .config
1133 .get("dropout")
1134 .and_then(|s| s.parse().ok())
1135 .ok_or_else(|| TextError::InvalidInput("Missing dropout config".to_string()))?;
1136 let vocab_size = data
1137 .config
1138 .get("vocab_size")
1139 .and_then(|s| s.parse().ok())
1140 .ok_or_else(|| TextError::InvalidInput("Missing vocab_size config".to_string()))?;
1141
1142 let config = crate::transformer::TransformerConfig {
1143 d_model,
1144 nheads: n_heads,
1145 d_ff,
1146 n_encoder_layers,
1147 n_decoder_layers,
1148 max_seqlen: max_seq_len,
1149 dropout,
1150 vocab_size,
1151 };
1152
1153 let vocabulary = data.vocabulary.clone().unwrap_or_else(|| {
1155 (0..config.vocab_size)
1157 .map(|i| format!("token_{i}"))
1158 .collect()
1159 });
1160
1161 let mut model = crate::transformer::TransformerModel::new(config.clone(), vocabulary)?;
1163
1164 if let (Some(embed_weights), Some(embedshape)) = (
1166 data.weights.get("token_embeddings"),
1167 data.shapes.get("token_embeddings"),
1168 ) {
1169 let embed_array = scirs2_core::ndarray::Array::from_shape_vec(
1170 (embedshape[0], embedshape[1]),
1171 embed_weights.clone(),
1172 )
1173 .map_err(|e| TextError::InvalidInput(format!("Invalid embedding shape: {e}")))?;
1174 model.token_embedding.set_embeddings(embed_array)?;
1175 }
1176
1177 if let (Some(pos_embed_weights), Some(pos_embedshape)) = (
1179 data.weights.get("positional_embeddings"),
1180 data.shapes.get("positional_embeddings"),
1181 ) {
1182 let _pos_embed_array = scirs2_core::ndarray::Array::from_shape_vec(
1183 (pos_embedshape[0], pos_embedshape[1]),
1184 pos_embed_weights.clone(),
1185 )
1186 .map_err(|e| {
1187 TextError::InvalidInput(format!("Invalid positional embedding shape: {e}"))
1188 })?;
1189 }
1192
1193 for i in 0..config.n_encoder_layers {
1195 let encoder_layers = model.encoder.get_layers_mut();
1196 let (attention, ff, ln1, ln2) = encoder_layers[i].get_components_mut();
1197
1198 if let (
1200 Some(wq_weights),
1201 Some(wqshape),
1202 Some(wk_weights),
1203 Some(wkshape),
1204 Some(wv_weights),
1205 Some(wvshape),
1206 Some(wo_weights),
1207 Some(woshape),
1208 ) = (
1209 data.weights.get(&format!("encoder_{i}_attention_wq")),
1210 data.shapes.get(&format!("encoder_{i}_attention_wq")),
1211 data.weights.get(&format!("encoder_{i}_attention_wk")),
1212 data.shapes.get(&format!("encoder_{i}_attention_wk")),
1213 data.weights.get(&format!("encoder_{i}_attention_wv")),
1214 data.shapes.get(&format!("encoder_{i}_attention_wv")),
1215 data.weights.get(&format!("encoder_{i}_attention_wo")),
1216 data.shapes.get(&format!("encoder_{i}_attention_wo")),
1217 ) {
1218 let w_q = scirs2_core::ndarray::Array::from_shape_vec(
1219 (wqshape[0], wqshape[1]),
1220 wq_weights.clone(),
1221 )
1222 .map_err(|e| TextError::InvalidInput(format!("Invalid wq shape: {e}")))?;
1223 let w_k = scirs2_core::ndarray::Array::from_shape_vec(
1224 (wkshape[0], wkshape[1]),
1225 wk_weights.clone(),
1226 )
1227 .map_err(|e| TextError::InvalidInput(format!("Invalid wk shape: {e}")))?;
1228 let w_v = scirs2_core::ndarray::Array::from_shape_vec(
1229 (wvshape[0], wvshape[1]),
1230 wv_weights.clone(),
1231 )
1232 .map_err(|e| TextError::InvalidInput(format!("Invalid wv shape: {e}")))?;
1233 let w_o = scirs2_core::ndarray::Array::from_shape_vec(
1234 (woshape[0], woshape[1]),
1235 wo_weights.clone(),
1236 )
1237 .map_err(|e| TextError::InvalidInput(format!("Invalid wo shape: {e}")))?;
1238
1239 attention.set_weights(w_q, w_k, w_v, w_o)?;
1240 }
1241
1242 if let (
1244 Some(w1_weights),
1245 Some(w1shape),
1246 Some(w2_weights),
1247 Some(w2shape),
1248 Some(b1_weights),
1249 Some(b2_weights),
1250 ) = (
1251 data.weights.get(&format!("encoder_{i}_ff_w1")),
1252 data.shapes.get(&format!("encoder_{i}_ff_w1")),
1253 data.weights.get(&format!("encoder_{i}_ff_w2")),
1254 data.shapes.get(&format!("encoder_{i}_ff_w2")),
1255 data.weights.get(&format!("encoder_{i}_ff_b1")),
1256 data.weights.get(&format!("encoder_{i}_ff_b2")),
1257 ) {
1258 let w1 = scirs2_core::ndarray::Array::from_shape_vec(
1259 (w1shape[0], w1shape[1]),
1260 w1_weights.clone(),
1261 )
1262 .map_err(|e| TextError::InvalidInput(format!("Invalid w1 shape: {e}")))?;
1263 let w2 = scirs2_core::ndarray::Array::from_shape_vec(
1264 (w2shape[0], w2shape[1]),
1265 w2_weights.clone(),
1266 )
1267 .map_err(|e| TextError::InvalidInput(format!("Invalid w2 shape: {e}")))?;
1268 let b1 = scirs2_core::ndarray::Array::from_vec(b1_weights.clone());
1269 let b2 = scirs2_core::ndarray::Array::from_vec(b2_weights.clone());
1270
1271 ff.set_weights(w1, w2, b1, b2)?;
1272 }
1273
1274 if let (Some(gamma1_weights), Some(beta1_weights)) = (
1276 data.weights.get(&format!("encoder_{i}_ln1_gamma")),
1277 data.weights.get(&format!("encoder_{i}_ln1_beta")),
1278 ) {
1279 let gamma1 = scirs2_core::ndarray::Array::from_vec(gamma1_weights.clone());
1280 let beta1 = scirs2_core::ndarray::Array::from_vec(beta1_weights.clone());
1281 ln1.set_params(gamma1, beta1)?;
1282 }
1283
1284 if let (Some(gamma2_weights), Some(beta2_weights)) = (
1285 data.weights.get(&format!("encoder_{i}_ln2_gamma")),
1286 data.weights.get(&format!("encoder_{i}_ln2_beta")),
1287 ) {
1288 let gamma2 = scirs2_core::ndarray::Array::from_vec(gamma2_weights.clone());
1289 let beta2 = scirs2_core::ndarray::Array::from_vec(beta2_weights.clone());
1290 ln2.set_params(gamma2, beta2)?;
1291 }
1292 }
1293
1294 for _i in 0..config.n_decoder_layers {
1296 }
1299
1300 if let (Some(output_weights), Some(outputshape)) = (
1302 data.weights.get("output_projection"),
1303 data.shapes.get("output_projection"),
1304 ) {
1305 let _output_array = scirs2_core::ndarray::Array::from_shape_vec(
1306 scirs2_core::ndarray::IxDyn(outputshape),
1307 output_weights.clone(),
1308 )
1309 .map_err(|e| {
1310 TextError::InvalidInput(format!("Invalid output projection shape: {e}"))
1311 })?;
1312 }
1314
1315 Ok(model)
1316 }
1317
    /// Identify this model as a [`ModelType::Transformer`] in the registry taxonomy.
    fn model_type(&self) -> ModelType {
        ModelType::Transformer
    }
1321
1322 fn get_config(&self) -> HashMap<String, String> {
1323 let mut config = HashMap::new();
1324 config.insert("d_model".to_string(), self.config.d_model.to_string());
1325 config.insert("n_heads".to_string(), self.config.nheads.to_string());
1326 config.insert("d_ff".to_string(), self.config.d_ff.to_string());
1327 config.insert(
1328 "n_encoder_layers".to_string(),
1329 self.config.n_encoder_layers.to_string(),
1330 );
1331 config.insert(
1332 "n_decoder_layers".to_string(),
1333 self.config.n_decoder_layers.to_string(),
1334 );
1335 config.insert(
1336 "max_seq_len".to_string(),
1337 self.config.max_seqlen.to_string(),
1338 );
1339 config.insert("dropout".to_string(), self.config.dropout.to_string());
1340 config.insert("vocab_size".to_string(), self.config.vocab_size.to_string());
1341 config
1342 }
1343}
1344
/// Registry (de)serialization for Word2Vec embedding models.
impl RegistrableModel for crate::embeddings::Word2Vec {
    /// Snapshot this Word2Vec model into a `SerializableModelData`:
    /// hyperparameters go into `config` as strings, the vocabulary is
    /// copied verbatim, and (if trained) the embedding matrix is flattened
    /// into `weights["embeddings"]` with its dimensions in `shapes`.
    fn serialize(&self) -> Result<SerializableModelData> {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();
        let mut config = HashMap::new();
        let vocabulary = Some(self.get_vocabulary());

        // Persist every tunable as a string; `deserialize` parses these back.
        config.insert(
            "vector_size".to_string(),
            self.get_vector_size().to_string(),
        );
        // Debug-formatted enum variant name ("CBOW" / "SkipGram"), matched
        // by string in `deserialize`.
        config.insert(
            "algorithm".to_string(),
            format!("{:?}", self.get_algorithm()),
        );
        config.insert(
            "window_size".to_string(),
            self.get_window_size().to_string(),
        );
        config.insert("min_count".to_string(), self.get_min_count().to_string());
        config.insert(
            "negative_samples".to_string(),
            self.get_negative_samples().to_string(),
        );
        config.insert(
            "learning_rate".to_string(),
            self.get_learning_rate().to_string(),
        );
        // NOTE(review): "epochs" and "subsampling_threshold" are written here
        // but `deserialize` below never reads them (it hard-codes epochs: 5
        // and subsample: 1e-3) — round-tripping loses these two settings.
        config.insert("epochs".to_string(), self.get_epochs().to_string());
        config.insert(
            "subsampling_threshold".to_string(),
            self.get_subsampling_threshold().to_string(),
        );

        if let Some(embeddings) = self.get_embeddings_matrix() {
            // NOTE(review): `as_slice().unwrap()` panics if the matrix is not
            // in standard contiguous layout — presumably it always is here,
            // but an explicit error would be safer; confirm upstream.
            let embed_weights = embeddings.as_slice().unwrap().to_vec();
            let embedshape = embeddings.shape().to_vec();
            weights.insert("embeddings".to_string(), embed_weights);
            shapes.insert("embeddings".to_string(), embedshape);
        }

        Ok(SerializableModelData {
            weights,
            shapes,
            vocabulary,
            config,
        })
    }

    /// Rebuild a Word2Vec model from serialized data.
    ///
    /// # Errors
    /// Returns `TextError::InvalidInput` when a required config key
    /// (`vector_size`, `window_size`, `min_count`, `algorithm`) is missing or
    /// unparsable, or when the stored embedding data does not match its shape.
    fn deserialize(data: &SerializableModelData) -> Result<Self> {
        let vector_size = data
            .config
            .get("vector_size")
            .and_then(|s| s.parse().ok())
            .ok_or_else(|| TextError::InvalidInput("Missing vector_size config".to_string()))?;
        let window_size = data
            .config
            .get("window_size")
            .and_then(|s| s.parse().ok())
            .ok_or_else(|| TextError::InvalidInput("Missing window_size config".to_string()))?;
        let min_count = data
            .config
            .get("min_count")
            .and_then(|s| s.parse().ok())
            .ok_or_else(|| TextError::InvalidInput("Missing min_count config".to_string()))?;

        // Match the Debug-formatted variant names written by `serialize`.
        let algorithm = match data.config.get("algorithm").map(|s| s.as_str()) {
            Some("CBOW") => crate::embeddings::Word2VecAlgorithm::CBOW,
            Some("SkipGram") => crate::embeddings::Word2VecAlgorithm::SkipGram,
            _ => {
                return Err(TextError::InvalidInput(
                    "Invalid or missing algorithm config".to_string(),
                ))
            }
        };

        // NOTE(review): epochs, learning_rate, negative_samples, subsample,
        // batch_size and hierarchical_softmax are fixed defaults here even
        // when `data.config` carries saved values; only window_size,
        // negative_samples and learning_rate are patched in afterwards via
        // the builder calls below.
        let config = crate::embeddings::Word2VecConfig {
            vector_size,
            window_size,
            min_count,
            epochs: 5, learning_rate: 0.025, algorithm,
            negative_samples: 5, subsample: 1e-3, batch_size: 128, hierarchical_softmax: false, };

        let word2vec = crate::embeddings::Word2Vec::with_config(config);

        if let (Some(vocab), Some(embed_weights), Some(embedshape)) = (
            data.vocabulary.as_ref(),
            data.weights.get("embeddings"),
            data.shapes.get("embeddings"),
        ) {
            // NOTE(review): the matrix is reconstructed only to validate the
            // shape and is then discarded (leading underscore) — the trained
            // weights are NOT installed on the returned model. Restoring them
            // needs an embeddings setter on Word2Vec; TODO confirm/wire up.
            let _embedding_matrix = scirs2_core::ndarray::Array::from_shape_vec(
                (embedshape[0], embedshape[1]),
                embed_weights.clone(),
            )
            .map_err(|e| TextError::InvalidInput(format!("Invalid embedding shape: {e}")))?;

            // NOTE(review): built but never used below — either dead code or
            // a stub for the missing weight-restoration step above.
            let mut word_to_index = HashMap::new();
            for (i, word) in vocab.iter().enumerate() {
                word_to_index.insert(word.clone(), i);
            }

            let mut restored_word2vec = word2vec;

            // Re-apply the saved tunables that have builder methods.
            if let Some(window_size) = data.config.get("window_size").and_then(|s| s.parse().ok()) {
                restored_word2vec = restored_word2vec.with_window_size(window_size);
            }

            if let Some(negative_samples) = data
                .config
                .get("negative_samples")
                .and_then(|s| s.parse().ok())
            {
                restored_word2vec = restored_word2vec.with_negative_samples(negative_samples);
            }

            if let Some(learning_rate) = data
                .config
                .get("learning_rate")
                .and_then(|s| s.parse().ok())
            {
                restored_word2vec = restored_word2vec.with_learning_rate(learning_rate);
            }

            return Ok(restored_word2vec);
        }

        // No vocabulary/weights stored: return the freshly configured model.
        Ok(word2vec)
    }

    /// Identify this model as a word-embedding model in the registry taxonomy.
    fn model_type(&self) -> ModelType {
        ModelType::WordEmbedding
    }

    /// Export the core Word2Vec hyperparameters as a string-keyed map.
    /// (Subset of what `serialize` stores: vector_size, algorithm,
    /// window_size, min_count.)
    fn get_config(&self) -> HashMap<String, String> {
        let mut config = HashMap::new();
        config.insert(
            "vector_size".to_string(),
            self.get_vector_size().to_string(),
        );
        config.insert(
            "algorithm".to_string(),
            format!("{:?}", self.get_algorithm()),
        );
        config.insert(
            "window_size".to_string(),
            self.get_window_size().to_string(),
        );
        config.insert("min_count".to_string(), self.get_min_count().to_string());
        config
    }
}
1515
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fluent construction of metadata should surface every field it sets.
    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new(
            "test_model".to_string(),
            "Test Model".to_string(),
            ModelType::Transformer,
        )
        .with_version("1.0.0".to_string())
        .with_description("A test model".to_string())
        .with_metric("accuracy".to_string(), 0.95);

        assert_eq!(
            (meta.id.as_str(), meta.name.as_str(), meta.version.as_str()),
            ("test_model", "Test Model", "1.0.0")
        );
        assert_eq!(meta.description, "A test model");
        assert_eq!(meta.metrics.get("accuracy"), Some(&0.95));
    }

    /// A freshly created registry starts with no models and an empty cache.
    #[test]
    fn test_model_registry_creation() {
        let tmp = TempDir::new().unwrap();
        let registry = ModelRegistry::new(tmp.path(), tmp.path()).unwrap();

        assert!(registry.models.is_empty());
        assert!(registry.model_cache.is_empty());
    }

    /// The prebuilt English transformer exposes the expected config/metadata.
    #[test]
    fn test_prebuilt_models() {
        let (cfg, meta) = PrebuiltModels::english_transformer_base();

        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.nheads, 8);
        assert_eq!(meta.id, "english_transformer_base");
        assert_eq!(meta.model_type, ModelType::Transformer);
        assert!(meta.languages.contains(&"en".to_string()));
    }

    /// `Display` renders each variant as its snake_case tag.
    #[test]
    fn test_model_type_display() {
        let cases = [
            (ModelType::Transformer, "transformer"),
            (ModelType::WordEmbedding, "word_embedding"),
            (ModelType::Custom("test".to_string()), "custom_test"),
        ];
        for (model_type, expected) in cases {
            assert_eq!(model_type.to_string(), expected);
        }
    }
}