scirs2_text/
model_registry.rs

1//! Pre-trained model registry for managing and loading text processing models
2//!
3//! This module provides a centralized registry for managing pre-trained models,
4//! including transformers, embeddings, and other text processing models.
5
6use crate::error::{Result, TextError};
7use crate::transformer::TransformerConfig;
8use std::collections::HashMap;
9use std::fs;
10#[cfg(feature = "serde-support")]
11use std::io::{BufReader, BufWriter};
12use std::path::{Path, PathBuf};
13
14#[cfg(feature = "serde-support")]
15use serde::{Deserialize, Serialize};
16
/// Supported model types in the registry.
///
/// Stored in [`ModelMetadata::model_type`] and used to filter models by kind.
/// The `Display` implementation renders each variant as a lowercase
/// snake_case label (e.g. `word_embedding`), and `Custom(name)` as
/// `custom_<name>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub enum ModelType {
    /// Transformer encoder models
    Transformer,
    /// Word embedding models
    WordEmbedding,
    /// Sentiment analysis models
    Sentiment,
    /// Language detection models
    LanguageDetection,
    /// Text classification models
    TextClassification,
    /// Named entity recognition models
    NamedEntityRecognition,
    /// Part-of-speech tagging models
    PartOfSpeech,
    /// Custom model type, identified by a free-form name
    Custom(String),
}
38
39impl std::fmt::Display for ModelType {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            ModelType::Transformer => write!(f, "transformer"),
43            ModelType::WordEmbedding => write!(f, "word_embedding"),
44            ModelType::Sentiment => write!(f, "sentiment"),
45            ModelType::LanguageDetection => write!(f, "language_detection"),
46            ModelType::TextClassification => write!(f, "text_classification"),
47            ModelType::NamedEntityRecognition => write!(f, "named_entity_recognition"),
48            ModelType::PartOfSpeech => write!(f, "part_of_speech"),
49            ModelType::Custom(name) => write!(f, "custom_{name}"),
50        }
51    }
52}
53
/// Model metadata information.
///
/// Persisted as `metadata.json` inside the model's registry directory and
/// built fluently via [`ModelMetadata::new`] plus the `with_*` methods.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct ModelMetadata {
    /// Model identifier; the registry stores the model's files in a
    /// sub-directory of the registry with this name.
    pub id: String,
    /// Human-readable model name
    pub name: String,
    /// Model version (`ModelMetadata::new` defaults it to `1.0.0`)
    pub version: String,
    /// Model type
    pub model_type: ModelType,
    /// Free-form model description (searched by `ModelRegistry::search_models`)
    pub description: String,
    /// Supported languages (ISO codes); `new` defaults this to `["en"]`
    pub languages: Vec<String>,
    /// Model size in bytes (`new` initializes it to 0)
    pub size_bytes: u64,
    /// Model author/organization
    pub author: String,
    /// License information (`new` defaults it to `Apache-2.0`)
    pub license: String,
    /// Model accuracy metrics, keyed by metric name
    pub metrics: HashMap<String, f64>,
    /// Model creation date; `new` formats it as `%Y-%m-%d %H:%M:%S UTC`
    pub created_at: String,
    /// Directory holding the model's files (`model.json`, `metadata.json`).
    /// When metadata is loaded from the registry this is replaced with the
    /// directory it was actually read from.
    pub file_path: PathBuf,
    /// Model configuration parameters as string key/value pairs
    pub config: HashMap<String, String>,
    /// Model dependencies
    pub dependencies: Vec<String>,
    /// Minimum required API version (`new` defaults it to `0.1.0`); checked by
    /// `ModelRegistry::check_model_compatibility`
    pub min_api_version: String,
}
89
90impl ModelMetadata {
91    /// Create new model metadata
92    pub fn new(_id: String, name: String, modeltype: ModelType) -> Self {
93        Self {
94            id: _id,
95            name,
96            version: "1.0.0".to_string(),
97            model_type: modeltype,
98            description: String::new(),
99            languages: vec!["en".to_string()],
100            size_bytes: 0,
101            author: String::new(),
102            license: "Apache-2.0".to_string(),
103            metrics: HashMap::new(),
104            created_at: chrono::Utc::now()
105                .format("%Y-%m-%d %H:%M:%S UTC")
106                .to_string(),
107            file_path: PathBuf::new(),
108            config: HashMap::new(),
109            dependencies: Vec::new(),
110            min_api_version: "0.1.0".to_string(),
111        }
112    }
113
114    /// Set model version
115    pub fn with_version(mut self, version: String) -> Self {
116        self.version = version;
117        self
118    }
119
120    /// Set model description
121    pub fn with_description(mut self, description: String) -> Self {
122        self.description = description;
123        self
124    }
125
126    /// Set supported languages
127    pub fn with_languages(mut self, languages: Vec<String>) -> Self {
128        self.languages = languages;
129        self
130    }
131
132    /// Add metric
133    pub fn with_metric(mut self, name: String, value: f64) -> Self {
134        self.metrics.insert(name, value);
135        self
136    }
137
138    /// Set author
139    pub fn with_author(mut self, author: String) -> Self {
140        self.author = author;
141        self
142    }
143
144    /// Set file path
145    pub fn with_file_path(mut self, path: PathBuf) -> Self {
146        self.file_path = path;
147        self
148    }
149
150    /// Add configuration parameter
151    pub fn with_config(mut self, key: String, value: String) -> Self {
152        self.config.insert(key, value);
153        self
154    }
155}
156
/// Serializable model data for storage.
///
/// This is what [`RegistrableModel::serialize`] produces and
/// [`RegistrableModel::deserialize`] consumes; the registry writes it to
/// `model.json` inside the model's directory.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct SerializableModelData {
    /// Model weights as flattened arrays, keyed by parameter name
    pub weights: HashMap<String, Vec<f64>>,
    /// Model shapes for weight reconstruction (presumably one entry per key in
    /// `weights` — NOTE(review): confirm against implementors)
    pub shapes: HashMap<String, Vec<usize>>,
    /// Vocabulary mapping, if the model has one
    pub vocabulary: Option<Vec<String>>,
    /// Model configuration as string key/value pairs
    pub config: HashMap<String, String>,
}
170
/// Trait for models that can be stored in the registry.
///
/// `ModelRegistry::register_model` calls [`serialize`](Self::serialize) to
/// persist a model, and `ModelRegistry::load_model` calls
/// [`deserialize`](Self::deserialize) to rebuild it from disk.
pub trait RegistrableModel {
    /// Serialize model to storable format
    fn serialize(&self) -> Result<SerializableModelData>;

    /// Deserialize model from stored format
    fn deserialize(data: &SerializableModelData) -> Result<Self>
    where
        Self: Sized;

    /// Get model type
    fn model_type(&self) -> ModelType;

    /// Get model configuration as string map
    fn get_config(&self) -> HashMap<String, String>;
}
187
/// Model registry for managing pre-trained models.
///
/// Each model lives in its own sub-directory of `registry_dir`, which holds a
/// `metadata.json` file and a `model.json` data file. Deserialized models are
/// kept in a small type-erased in-memory cache.
pub struct ModelRegistry {
    /// Registry storage directory; one sub-directory per model
    registry_dir: PathBuf,
    /// Metadata of known models, keyed by model id (the model's directory name)
    models: HashMap<String, ModelMetadata>,
    /// Type-erased cache of deserialized models, keyed by model id
    model_cache: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
    /// Capacity limit for `model_cache` (10 by default), checked before each insert
    max_cache_size: usize,
}
199
200impl ModelRegistry {
201    /// Create new model registry
202    pub fn new<P: AsRef<Path>>(registry_dir: P, dir: P) -> Result<Self> {
203        let _registry_dir = registry_dir.as_ref().to_path_buf();
204
205        // Create registry directory if it doesn't exist
206        if !_registry_dir.exists() {
207            fs::create_dir_all(&_registry_dir).map_err(|e| {
208                TextError::IoError(format!("Failed to create registry directory: {e}"))
209            })?;
210        }
211
212        let mut registry = Self {
213            registry_dir: registry_dir.as_ref().to_path_buf(),
214            models: HashMap::new(),
215            model_cache: HashMap::new(),
216            max_cache_size: 10, // Default cache size
217        };
218
219        // Load existing models
220        registry.scan_registry()?;
221
222        Ok(registry)
223    }
224
225    /// Set maximum cache size
226    pub fn with_max_cache_size(mut self, size: usize) -> Self {
227        self.max_cache_size = size;
228        self
229    }
230
231    /// Scan registry directory for models
232    fn scan_registry(&mut self) -> Result<()> {
233        if !self.registry_dir.exists() {
234            return Ok(());
235        }
236
237        for entry in fs::read_dir(&self.registry_dir)
238            .map_err(|e| TextError::IoError(format!("Failed to read registry directory: {e}")))?
239        {
240            let entry = entry
241                .map_err(|e| TextError::IoError(format!("Failed to read directory entry: {e}")))?;
242
243            if entry
244                .file_type()
245                .map_err(|e| TextError::IoError(format!("Failed to get file type: {e}")))?
246                .is_dir()
247            {
248                let model_dir = entry.path();
249                if let Some(model_id) = model_dir.file_name().and_then(|n| n.to_str()) {
250                    if let Ok(metadata) = self.load_model_metadata(&model_dir) {
251                        self.models.insert(model_id.to_string(), metadata);
252                    }
253                }
254            }
255        }
256
257        Ok(())
258    }
259
260    /// Load model metadata from directory
261    fn load_model_metadata(&self, modeldir: &Path) -> Result<ModelMetadata> {
262        let metadata_file = modeldir.join("metadata.json");
263        if !metadata_file.exists() {
264            return Err(TextError::InvalidInput(format!(
265                "Metadata file not found: {}",
266                metadata_file.display()
267            )));
268        }
269
270        #[cfg(feature = "serde-support")]
271        {
272            let file = fs::File::open(&metadata_file)
273                .map_err(|e| TextError::IoError(format!("Failed to open metadata file: {e}")))?;
274            let reader = BufReader::new(file);
275            let mut metadata: ModelMetadata = serde_json::from_reader(reader).map_err(|e| {
276                TextError::InvalidInput(format!("Failed to deserialize metadata: {e}"))
277            })?;
278
279            // Update file path to current directory
280            metadata.file_path = modeldir.to_path_buf();
281            Ok(metadata)
282        }
283
284        #[cfg(not(feature = "serde-support"))]
285        {
286            // Fallback when serde is not available
287            let model_id = modeldir
288                .file_name()
289                .and_then(|n| n.to_str())
290                .unwrap_or("unknown")
291                .to_string();
292
293            Ok(ModelMetadata::new(
294                model_id.clone(),
295                format!("Model {model_id}"),
296                ModelType::Custom("unknown".to_string()),
297            )
298            .with_file_path(modeldir.to_path_buf()))
299        }
300    }
301
302    /// Register a new model
303    pub fn register_model<M: RegistrableModel + 'static>(
304        &mut self,
305        model: &M,
306        metadata: ModelMetadata,
307    ) -> Result<()> {
308        // Create model directory
309        let model_dir = self.registry_dir.join(&metadata.id);
310        if !model_dir.exists() {
311            fs::create_dir_all(&model_dir).map_err(|e| {
312                TextError::IoError(format!("Failed to create model directory: {e}"))
313            })?;
314        }
315
316        // Serialize and save model
317        let serialized = model.serialize()?;
318        self.save_model_data(&model_dir, &serialized)?;
319
320        // Save metadata
321        self.save_model_metadata(&model_dir, &metadata)?;
322
323        // Update registry
324        self.models.insert(metadata.id.clone(), metadata);
325
326        Ok(())
327    }
328
329    /// Save model data to directory
330    fn save_model_data(&self, modeldir: &Path, data: &SerializableModelData) -> Result<()> {
331        let data_file = modeldir.join("model.json");
332
333        #[cfg(feature = "serde-support")]
334        {
335            let file = fs::File::create(&data_file)
336                .map_err(|e| TextError::IoError(format!("Failed to create model file: {e}")))?;
337            let writer = BufWriter::new(file);
338            serde_json::to_writer_pretty(writer, data).map_err(|e| {
339                TextError::InvalidInput(format!("Failed to serialize model data: {e}"))
340            })?;
341        }
342
343        #[cfg(not(feature = "serde-support"))]
344        {
345            // Fallback to simplified format when serde is not available
346            let data_str = format!("{data:#?}");
347            fs::write(&data_file, data_str)
348                .map_err(|e| TextError::IoError(format!("Failed to save model data: {e}")))?;
349        }
350
351        Ok(())
352    }
353
354    /// Save model metadata to directory
355    fn save_model_metadata(&self, modeldir: &Path, metadata: &ModelMetadata) -> Result<()> {
356        let metadata_file = modeldir.join("metadata.json");
357
358        #[cfg(feature = "serde-support")]
359        {
360            let file = fs::File::create(&metadata_file)
361                .map_err(|e| TextError::IoError(format!("Failed to create metadata file: {e}")))?;
362            let writer = BufWriter::new(file);
363            serde_json::to_writer_pretty(writer, metadata).map_err(|e| {
364                TextError::InvalidInput(format!("Failed to serialize metadata: {e}"))
365            })?;
366        }
367
368        #[cfg(not(feature = "serde-support"))]
369        {
370            // Fallback to simplified format when serde is not available
371            let metadata_str = format!("{metadata:#?}");
372            fs::write(&metadata_file, metadata_str)
373                .map_err(|e| TextError::IoError(format!("Failed to save metadata: {e}")))?;
374        }
375
376        Ok(())
377    }
378
379    /// List all registered models
380    pub fn list_models(&self) -> Vec<&ModelMetadata> {
381        self.models.values().collect()
382    }
383
384    /// List models by type
385    pub fn list_models_by_type(&self, modeltype: &ModelType) -> Vec<&ModelMetadata> {
386        self.models
387            .values()
388            .filter(|metadata| &metadata.model_type == modeltype)
389            .collect()
390    }
391
392    /// Get model metadata by ID
393    pub fn get_metadata(&self, model_id: &str) -> Option<&ModelMetadata> {
394        self.models.get(model_id)
395    }
396
397    /// Load model by ID
398    pub fn load_model<M: RegistrableModel + Send + Sync + 'static>(
399        &mut self,
400        model_id: &str,
401    ) -> Result<&M> {
402        // Check if model is cached
403        let is_cached = self
404            .model_cache
405            .get(model_id)
406            .and_then(|cached| cached.downcast_ref::<M>())
407            .is_some();
408
409        if is_cached {
410            // Safe to get the cached model now
411            return Ok(self
412                .model_cache
413                .get(model_id)
414                .unwrap()
415                .downcast_ref::<M>()
416                .unwrap());
417        }
418
419        // Load model metadata
420        let metadata = self
421            .models
422            .get(model_id)
423            .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
424
425        // Load model data
426        let model_data = self.load_model_data(&metadata.file_path)?;
427
428        // Deserialize model
429        let model = M::deserialize(&model_data)?;
430
431        // Cache model
432        self.cache_model(model_id.to_string(), Box::new(model));
433
434        // Return cached model
435        if let Some(cached) = self.model_cache.get(model_id) {
436            if let Some(model) = cached.downcast_ref::<M>() {
437                return Ok(model);
438            }
439        }
440
441        Err(TextError::InvalidInput("Failed to cache model".to_string()))
442    }
443
444    /// Load model data from directory
445    fn load_model_data(&self, modeldir: &Path) -> Result<SerializableModelData> {
446        let data_file = modeldir.join("model.json");
447        if !data_file.exists() {
448            // Try legacy format
449            let legacy_file = modeldir.join("model.dat");
450            if legacy_file.exists() {
451                return Ok(SerializableModelData {
452                    weights: HashMap::new(),
453                    shapes: HashMap::new(),
454                    vocabulary: None,
455                    config: HashMap::new(),
456                });
457            }
458
459            return Err(TextError::InvalidInput(format!(
460                "Model data file not found: {}",
461                data_file.display()
462            )));
463        }
464
465        #[cfg(feature = "serde-support")]
466        {
467            let file = fs::File::open(&data_file)
468                .map_err(|e| TextError::IoError(format!("Failed to open model data file: {e}")))?;
469            let reader = BufReader::new(file);
470            serde_json::from_reader(reader).map_err(|e| {
471                TextError::InvalidInput(format!("Failed to deserialize model data: {e}"))
472            })
473        }
474
475        #[cfg(not(feature = "serde-support"))]
476        {
477            // Fallback when serde is not available
478            Ok(SerializableModelData {
479                weights: HashMap::new(),
480                shapes: HashMap::new(),
481                vocabulary: None,
482                config: HashMap::new(),
483            })
484        }
485    }
486
487    /// Cache a loaded model
488    fn cache_model(&mut self, model_id: String, model: Box<dyn std::any::Any + Send + Sync>) {
489        // Remove oldest cached model if cache is full
490        if self.model_cache.len() >= self.max_cache_size {
491            if let Some(first_key) = self.model_cache.keys().next().cloned() {
492                self.model_cache.remove(&first_key);
493            }
494        }
495
496        self.model_cache.insert(model_id, model);
497    }
498
499    /// Remove model from registry
500    pub fn remove_model(&mut self, model_id: &str) -> Result<()> {
501        let metadata = self
502            .models
503            .remove(model_id)
504            .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
505
506        // Remove model files
507        if metadata.file_path.exists() {
508            fs::remove_dir_all(&metadata.file_path)
509                .map_err(|e| TextError::IoError(format!("Failed to remove model files: {e}")))?;
510        }
511
512        // Remove from cache
513        self.model_cache.remove(model_id);
514
515        Ok(())
516    }
517
518    /// Clear model cache
519    pub fn clear_cache(&mut self) {
520        self.model_cache.clear();
521    }
522
523    /// Get cache statistics
524    pub fn cache_stats(&self) -> (usize, usize) {
525        (self.model_cache.len(), self.max_cache_size)
526    }
527
528    /// Search models by name or description
529    pub fn search_models(&self, query: &str) -> Vec<&ModelMetadata> {
530        let query_lower = query.to_lowercase();
531        self.models
532            .values()
533            .filter(|metadata| {
534                metadata.name.to_lowercase().contains(&query_lower)
535                    || metadata.description.to_lowercase().contains(&query_lower)
536            })
537            .collect()
538    }
539
540    /// Get models supporting specific language
541    pub fn models_for_language(&self, language: &str) -> Vec<&ModelMetadata> {
542        self.models
543            .values()
544            .filter(|metadata| metadata.languages.contains(&language.to_string()))
545            .collect()
546    }
547
548    /// Check if model is compatible with current API version
549    pub fn check_model_compatibility(&self, model_id: &str) -> Result<bool> {
550        let metadata = self
551            .models
552            .get(model_id)
553            .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
554
555        // Simple version comparison (in practice, this would be more sophisticated)
556        let current_version = "0.1.0-beta.4"; // Use hardcoded version
557        let min_version = &metadata.min_api_version;
558
559        // For now, just check if versions match exactly
560        // In practice, this would use semantic versioning
561        Ok(current_version >= min_version.as_str())
562    }
563
564    /// Get model statistics
565    pub fn model_statistics(&self) -> HashMap<String, usize> {
566        let mut stats = HashMap::new();
567
568        // Count models by type
569        for metadata in self.models.values() {
570            let type_key = metadata.model_type.to_string();
571            *stats.entry(type_key).or_insert(0) += 1;
572        }
573
574        stats.insert("total_models".to_string(), self.models.len());
575        stats.insert("cached_models".to_string(), self.model_cache.len());
576
577        stats
578    }
579
580    /// Validate model integrity
581    pub fn validate_model(&self, model_id: &str) -> Result<bool> {
582        let metadata = self
583            .models
584            .get(model_id)
585            .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
586
587        // Check if model files exist
588        let model_dir = &metadata.file_path;
589        let data_file = model_dir.join("model.json");
590        let metadata_file = model_dir.join("metadata.json");
591
592        Ok(data_file.exists() && metadata_file.exists())
593    }
594
595    /// Get detailed model information
596    pub fn get_model_info(&self, model_id: &str) -> Result<HashMap<String, String>> {
597        let metadata = self
598            .models
599            .get(model_id)
600            .ok_or_else(|| TextError::InvalidInput(format!("Model not found: {model_id}")))?;
601
602        let mut info = HashMap::new();
603        info.insert("_id".to_string(), metadata.id.clone());
604        info.insert("name".to_string(), metadata.name.clone());
605        info.insert("version".to_string(), metadata.version.clone());
606        info.insert("type".to_string(), metadata.model_type.to_string());
607        info.insert("author".to_string(), metadata.author.clone());
608        info.insert("license".to_string(), metadata.license.clone());
609        info.insert("created_at".to_string(), metadata.created_at.clone());
610        info.insert("size_bytes".to_string(), metadata.size_bytes.to_string());
611        info.insert("languages".to_string(), metadata.languages.join(", "));
612
613        // Add metrics as string
614        for (metric_name, metric_value) in &metadata.metrics {
615            info.insert(format!("metric_{metric_name}"), metric_value.to_string());
616        }
617
618        Ok(info)
619    }
620}
621
/// Pre-built model configurations for common use cases.
///
/// A namespace type: its associated functions return a
/// `(TransformerConfig, ModelMetadata)` pair for each built-in preset.
pub struct PrebuiltModels;
624
625impl PrebuiltModels {
626    /// Create basic transformer configuration for English text
627    pub fn english_transformer_base() -> (TransformerConfig, ModelMetadata) {
628        let config = TransformerConfig {
629            d_model: 512,
630            nheads: 8,
631            d_ff: 2048,
632            n_encoder_layers: 6,
633            n_decoder_layers: 6,
634            max_seqlen: 512,
635            dropout: 0.1,
636            vocab_size: 50000,
637        };
638
639        let metadata = ModelMetadata::new(
640            "english_transformer_base".to_string(),
641            "English Transformer Base".to_string(),
642            ModelType::Transformer,
643        )
644        .with_description("Base transformer model for English text processing".to_string())
645        .with_languages(vec!["en".to_string()])
646        .with_author("SciRS2".to_string())
647        .with_metric("perplexity".to_string(), 15.2)
648        .with_config("d_model".to_string(), "512".to_string())
649        .with_config("n_heads".to_string(), "8".to_string());
650
651        (config, metadata)
652    }
653
654    /// Create multilingual transformer configuration
655    pub fn multilingual_transformer() -> (TransformerConfig, ModelMetadata) {
656        let config = TransformerConfig {
657            d_model: 768,
658            nheads: 12,
659            d_ff: 3072,
660            n_encoder_layers: 12,
661            n_decoder_layers: 12,
662            max_seqlen: 512,
663            dropout: 0.1,
664            vocab_size: 120000,
665        };
666
667        let metadata = ModelMetadata::new(
668            "multilingual_transformer".to_string(),
669            "Multilingual Transformer".to_string(),
670            ModelType::Transformer,
671        )
672        .with_description("Transformer model supporting multiple languages".to_string())
673        .with_languages(vec![
674            "en".to_string(),
675            "es".to_string(),
676            "fr".to_string(),
677            "de".to_string(),
678            "zh".to_string(),
679            "ja".to_string(),
680        ])
681        .with_author("SciRS2".to_string())
682        .with_metric("bleu_score".to_string(), 28.4)
683        .with_config("d_model".to_string(), "768".to_string())
684        .with_config("n_heads".to_string(), "12".to_string());
685
686        (config, metadata)
687    }
688
689    /// Create scientific text processing configuration
690    pub fn scientific_transformer() -> (TransformerConfig, ModelMetadata) {
691        let config = TransformerConfig {
692            d_model: 1024,
693            nheads: 16,
694            d_ff: 4096,
695            n_encoder_layers: 24,
696            n_decoder_layers: 24,
697            max_seqlen: 1024,
698            dropout: 0.1,
699            vocab_size: 200000,
700        };
701
702        let metadata = ModelMetadata::new(
703            "scientific_transformer".to_string(),
704            "Scientific Text Transformer".to_string(),
705            ModelType::Transformer,
706        )
707        .with_description(
708            "Large transformer model specialized for scientific text processing".to_string(),
709        )
710        .with_languages(vec!["en".to_string()])
711        .with_author("SciRS2".to_string())
712        .with_metric("scientific_f1".to_string(), 92.1)
713        .with_config("d_model".to_string(), "1024".to_string())
714        .with_config("n_heads".to_string(), "16".to_string())
715        .with_config("domain".to_string(), "scientific".to_string());
716
717        (config, metadata)
718    }
719
720    /// Create small transformer for development and testing
721    pub fn tiny_transformer() -> (TransformerConfig, ModelMetadata) {
722        let config = TransformerConfig {
723            d_model: 128,
724            nheads: 2,
725            d_ff: 512,
726            n_encoder_layers: 2,
727            n_decoder_layers: 2,
728            max_seqlen: 128,
729            dropout: 0.1,
730            vocab_size: 1000,
731        };
732
733        let metadata = ModelMetadata::new(
734            "tiny_transformer".to_string(),
735            "Tiny Transformer".to_string(),
736            ModelType::Transformer,
737        )
738        .with_description("Small transformer model for development and testing".to_string())
739        .with_languages(vec!["en".to_string()])
740        .with_author("SciRS2".to_string())
741        .with_metric("perplexity".to_string(), 25.0)
742        .with_config("d_model".to_string(), "128".to_string())
743        .with_config(
744            "intended_use".to_string(),
745            "development_testing".to_string(),
746        );
747
748        (config, metadata)
749    }
750
751    /// Create large transformer for production use
752    pub fn large_transformer() -> (TransformerConfig, ModelMetadata) {
753        let config = TransformerConfig {
754            d_model: 1536,
755            nheads: 24,
756            d_ff: 6144,
757            n_encoder_layers: 48,
758            n_decoder_layers: 48,
759            max_seqlen: 2048,
760            dropout: 0.1,
761            vocab_size: 100000,
762        };
763
764        let metadata = ModelMetadata::new(
765            "large_transformer".to_string(),
766            "Large Transformer".to_string(),
767            ModelType::Transformer,
768        )
769        .with_description("Large transformer model for production use".to_string())
770        .with_languages(vec![
771            "en".to_string(),
772            "es".to_string(),
773            "fr".to_string(),
774            "de".to_string(),
775        ])
776        .with_author("SciRS2".to_string())
777        .with_metric("perplexity".to_string(), 8.2)
778        .with_metric("bleu_score".to_string(), 35.7)
779        .with_config("d_model".to_string(), "1536".to_string())
780        .with_config("intended_use".to_string(), "production".to_string());
781
782        (config, metadata)
783    }
784
785    /// Create domain-specific scientific transformer
786    pub fn domain_scientific_large() -> (TransformerConfig, ModelMetadata) {
787        let config = TransformerConfig {
788            d_model: 1024,
789            nheads: 16,
790            d_ff: 4096,
791            n_encoder_layers: 24,
792            n_decoder_layers: 24,
793            max_seqlen: 2048,
794            dropout: 0.05,      // Lower dropout for scientific text
795            vocab_size: 150000, // Larger vocab for scientific terms
796        };
797
798        let metadata = ModelMetadata::new(
799            "scibert_large".to_string(),
800            "Scientific BERT Large".to_string(),
801            ModelType::Transformer,
802        )
803        .with_description(
804            "Large transformer model pre-trained on scientific literature".to_string(),
805        )
806        .with_languages(vec!["en".to_string()])
807        .with_author("SciRS2".to_string())
808        .with_metric("scientific_f1".to_string(), 94.3)
809        .with_metric("pubmed_qa_accuracy".to_string(), 87.6)
810        .with_config("domain".to_string(), "scientific".to_string())
811        .with_config(
812            "training_corpus".to_string(),
813            "pubmed_arxiv_pmc".to_string(),
814        );
815
816        (config, metadata)
817    }
818
819    /// Create medical domain transformer
820    pub fn medical_transformer() -> (TransformerConfig, ModelMetadata) {
821        let config = TransformerConfig {
822            d_model: 768,
823            nheads: 12,
824            d_ff: 3072,
825            n_encoder_layers: 12,
826            n_decoder_layers: 12,
827            max_seqlen: 1024,
828            dropout: 0.1,
829            vocab_size: 80000, // Medical vocabulary
830        };
831
832        let metadata = ModelMetadata::new(
833            "medbert".to_string(),
834            "Medical BERT".to_string(),
835            ModelType::Transformer,
836        )
837        .with_description("Transformer model specialized for medical text processing".to_string())
838        .with_languages(vec!["en".to_string()])
839        .with_author("SciRS2".to_string())
840        .with_metric("medical_ner_f1".to_string(), 91.2)
841        .with_metric("clinical_notes_accuracy".to_string(), 85.4)
842        .with_config("domain".to_string(), "medical".to_string())
843        .with_config(
844            "training_corpus".to_string(),
845            "mimic_iii_pubmed".to_string(),
846        );
847
848        (config, metadata)
849    }
850
851    /// Create legal domain transformer
852    pub fn legal_transformer() -> (TransformerConfig, ModelMetadata) {
853        let config = TransformerConfig {
854            d_model: 768,
855            nheads: 12,
856            d_ff: 3072,
857            n_encoder_layers: 12,
858            n_decoder_layers: 12,
859            max_seqlen: 2048, // Longer sequences for legal documents
860            dropout: 0.1,
861            vocab_size: 60000, // Legal vocabulary
862        };
863
864        let metadata = ModelMetadata::new(
865            "legalbert".to_string(),
866            "Legal BERT".to_string(),
867            ModelType::Transformer,
868        )
869        .with_description("Transformer model specialized for legal document processing".to_string())
870        .with_languages(vec!["en".to_string()])
871        .with_author("SciRS2".to_string())
872        .with_metric("legal_ner_f1".to_string(), 88.7)
873        .with_metric("contract_classification_accuracy".to_string(), 92.1)
874        .with_config("domain".to_string(), "legal".to_string())
875        .with_config(
876            "training_corpus".to_string(),
877            "legal_cases_contracts".to_string(),
878        );
879
880        (config, metadata)
881    }
882
883    /// Get all available pre-built model configurations
884    pub fn all_prebuilt_models() -> Vec<(TransformerConfig, ModelMetadata)> {
885        vec![
886            Self::english_transformer_base(),
887            Self::multilingual_transformer(),
888            Self::scientific_transformer(),
889            Self::tiny_transformer(),
890            Self::large_transformer(),
891            Self::domain_scientific_large(),
892            Self::medical_transformer(),
893            Self::legal_transformer(),
894        ]
895    }
896
897    /// Get pre-built model by ID
898    pub fn get_by_id(_model_id: &str) -> Option<(TransformerConfig, ModelMetadata)> {
899        match _model_id {
900            "english_transformer_base" => Some(Self::english_transformer_base()),
901            "multilingual_transformer" => Some(Self::multilingual_transformer()),
902            "scientific_transformer" => Some(Self::scientific_transformer()),
903            "tiny_transformer" => Some(Self::tiny_transformer()),
904            "large_transformer" => Some(Self::large_transformer()),
905            "scibiert_large" => Some(Self::domain_scientific_large()),
906            "medbert" => Some(Self::medical_transformer()),
907            "legalbert" => Some(Self::legal_transformer()),
908            _ => None,
909        }
910    }
911}
912
913/// Implementation of RegistrableModel for TransformerModel
914impl RegistrableModel for crate::transformer::TransformerModel {
915    fn serialize(&self) -> Result<SerializableModelData> {
916        let mut weights = HashMap::new();
917        let mut shapes = HashMap::new();
918        let mut config = HashMap::new();
919
920        // Serialize transformer config
921        config.insert("d_model".to_string(), self.config.d_model.to_string());
922        config.insert("n_heads".to_string(), self.config.nheads.to_string());
923        config.insert("d_ff".to_string(), self.config.d_ff.to_string());
924        config.insert(
925            "n_encoder_layers".to_string(),
926            self.config.n_encoder_layers.to_string(),
927        );
928        config.insert(
929            "n_decoder_layers".to_string(),
930            self.config.n_decoder_layers.to_string(),
931        );
932        config.insert(
933            "max_seq_len".to_string(),
934            self.config.max_seqlen.to_string(),
935        );
936        config.insert("dropout".to_string(), self.config.dropout.to_string());
937        config.insert("vocab_size".to_string(), self.config.vocab_size.to_string());
938
939        // Serialize embedding weights
940        let embed_weights = self
941            .token_embedding
942            .get_embeddings()
943            .as_slice()
944            .unwrap()
945            .to_vec();
946        let embedshape = self.token_embedding.get_embeddings().shape().to_vec();
947        weights.insert("token_embeddings".to_string(), embed_weights);
948        shapes.insert("token_embeddings".to_string(), embedshape);
949
950        // Serialize positional embeddings (placeholder - would need access to internal weights)
951        let pos_embed_weights = vec![0.0f64; self.config.max_seqlen * self.config.d_model];
952        let pos_embedshape = vec![self.config.max_seqlen, self.config.d_model];
953        weights.insert("positional_embeddings".to_string(), pos_embed_weights);
954        shapes.insert("positional_embeddings".to_string(), pos_embedshape);
955
956        // Serialize all encoder layers with real weights
957        for i in 0..self.config.n_encoder_layers {
958            let layer = &self.encoder.get_layers()[i];
959            let (attention, ff, ln1, ln2) = layer.get_components();
960
961            // Serialize attention weights
962            let (w_q, w_k, w_v, w_o) = attention.get_weights();
963            weights.insert(
964                format!("encoder_{i}_attention_wq"),
965                w_q.as_slice().unwrap().to_vec(),
966            );
967            shapes.insert(format!("encoder_{i}_attention_wq"), w_q.shape().to_vec());
968            weights.insert(
969                format!("encoder_{i}_attention_wk"),
970                w_k.as_slice().unwrap().to_vec(),
971            );
972            shapes.insert(format!("encoder_{i}_attention_wk"), w_k.shape().to_vec());
973            weights.insert(
974                format!("encoder_{i}_attention_wv"),
975                w_v.as_slice().unwrap().to_vec(),
976            );
977            shapes.insert(format!("encoder_{i}_attention_wv"), w_v.shape().to_vec());
978            weights.insert(
979                format!("encoder_{i}_attention_wo"),
980                w_o.as_slice().unwrap().to_vec(),
981            );
982            shapes.insert(format!("encoder_{i}_attention_wo"), w_o.shape().to_vec());
983
984            // Serialize feedforward weights
985            let (w1, w2, b1, b2) = ff.get_weights();
986            weights.insert(
987                format!("encoder_{i}_ff_w1"),
988                w1.as_slice().unwrap().to_vec(),
989            );
990            shapes.insert(format!("encoder_{i}_ff_w1"), w1.shape().to_vec());
991            weights.insert(
992                format!("encoder_{i}_ff_w2"),
993                w2.as_slice().unwrap().to_vec(),
994            );
995            shapes.insert(format!("encoder_{i}_ff_w2"), w2.shape().to_vec());
996            weights.insert(
997                format!("encoder_{i}_ff_b1"),
998                b1.as_slice().unwrap().to_vec(),
999            );
1000            shapes.insert(format!("encoder_{i}_ff_b1"), vec![b1.len()]);
1001            weights.insert(
1002                format!("encoder_{i}_ff_b2"),
1003                b2.as_slice().unwrap().to_vec(),
1004            );
1005            shapes.insert(format!("encoder_{i}_ff_b2"), vec![b2.len()]);
1006
1007            // Serialize layer norm parameters
1008            let (gamma1, beta1) = ln1.get_params();
1009            let (gamma2, beta2) = ln2.get_params();
1010            weights.insert(
1011                format!("encoder_{i}_ln1_gamma"),
1012                gamma1.as_slice().unwrap().to_vec(),
1013            );
1014            shapes.insert(format!("encoder_{i}_ln1_gamma"), vec![gamma1.len()]);
1015            weights.insert(
1016                format!("encoder_{i}_ln1_beta"),
1017                beta1.as_slice().unwrap().to_vec(),
1018            );
1019            shapes.insert(format!("encoder_{i}_ln1_beta"), vec![beta1.len()]);
1020            weights.insert(
1021                format!("encoder_{i}_ln2_gamma"),
1022                gamma2.as_slice().unwrap().to_vec(),
1023            );
1024            shapes.insert(format!("encoder_{i}_ln2_gamma"), vec![gamma2.len()]);
1025            weights.insert(
1026                format!("encoder_{i}_ln2_beta"),
1027                beta2.as_slice().unwrap().to_vec(),
1028            );
1029            shapes.insert(format!("encoder_{i}_ln2_beta"), vec![beta2.len()]);
1030        }
1031
1032        // Serialize all decoder layers (placeholder - would need access to internal weights)
1033        for i in 0..self.config.n_decoder_layers {
1034            // Placeholder for self-attention weights
1035            let self_attn_weight_size = self.config.d_model * self.config.d_model * 4; // Q, K, V, O
1036            let self_attn_weights = vec![0.0f64; self_attn_weight_size];
1037            let self_attnshape = vec![self.config.d_model, self.config.d_model * 4];
1038            weights.insert(format!("decoder_{i}_self_attention"), self_attn_weights);
1039            shapes.insert(format!("decoder_{i}_self_attention"), self_attnshape);
1040
1041            // Placeholder for cross-attention weights
1042            let cross_attn_weights = vec![0.0f64; self_attn_weight_size];
1043            let cross_attnshape = vec![self.config.d_model, self.config.d_model * 4];
1044            weights.insert(format!("decoder_{i}_cross_attention"), cross_attn_weights);
1045            shapes.insert(format!("decoder_{i}_cross_attention"), cross_attnshape);
1046
1047            // Placeholder for feedforward weights
1048            let ff_weight_size = self.config.d_model * self.config.d_ff * 2; // W1, W2
1049            let ff_weights = vec![0.0f64; ff_weight_size];
1050            let ffshape = vec![self.config.d_model, self.config.d_ff * 2];
1051            weights.insert(format!("decoder_{i}_feedforward"), ff_weights);
1052            shapes.insert(format!("decoder_{i}_feedforward"), ffshape);
1053
1054            // Placeholder for layer norm parameters
1055            let ln_weights = vec![1.0f64; self.config.d_model];
1056            let lnshape = vec![self.config.d_model];
1057            weights.insert(format!("decoder_{i}_ln1"), ln_weights.clone());
1058            shapes.insert(format!("decoder_{i}_ln1"), lnshape.clone());
1059
1060            weights.insert(format!("decoder_{i}_ln2"), ln_weights.clone());
1061            shapes.insert(format!("decoder_{i}_ln2"), lnshape.clone());
1062
1063            weights.insert(format!("decoder_{i}_ln3"), ln_weights);
1064            shapes.insert(format!("decoder_{i}_ln3"), lnshape);
1065        }
1066
1067        // Serialize output projection layer (placeholder - would need access to internal weights)
1068        let output_weight_size = self.config.d_model * self.config.vocab_size;
1069        let output_weights = vec![0.0f64; output_weight_size];
1070        let outputshape = vec![self.config.d_model, self.config.vocab_size];
1071        weights.insert("output_projection".to_string(), output_weights);
1072        shapes.insert("output_projection".to_string(), outputshape);
1073
1074        // Serialize vocabulary
1075        let (vocab_to_id, id_to_vocab) = self.vocabulary();
1076        let vocabulary = Some(
1077            (0..vocab_to_id.len())
1078                .map(|i| {
1079                    id_to_vocab
1080                        .get(&i)
1081                        .cloned()
1082                        .unwrap_or_else(|| format!("unk_{i}"))
1083                })
1084                .collect(),
1085        );
1086
1087        Ok(SerializableModelData {
1088            weights,
1089            shapes,
1090            vocabulary,
1091            config,
1092        })
1093    }
1094
1095    fn deserialize(data: &SerializableModelData) -> Result<Self> {
1096        // Parse config
1097        let d_model = data
1098            .config
1099            .get("d_model")
1100            .and_then(|s| s.parse().ok())
1101            .ok_or_else(|| TextError::InvalidInput("Missing d_model config".to_string()))?;
1102        let n_heads = data
1103            .config
1104            .get("n_heads")
1105            .and_then(|s| s.parse().ok())
1106            .ok_or_else(|| TextError::InvalidInput("Missing n_heads config".to_string()))?;
1107        let d_ff = data
1108            .config
1109            .get("d_ff")
1110            .and_then(|s| s.parse().ok())
1111            .ok_or_else(|| TextError::InvalidInput("Missing d_ff config".to_string()))?;
1112        let n_encoder_layers = data
1113            .config
1114            .get("n_encoder_layers")
1115            .and_then(|s| s.parse().ok())
1116            .ok_or_else(|| {
1117                TextError::InvalidInput("Missing n_encoder_layers config".to_string())
1118            })?;
1119        let n_decoder_layers = data
1120            .config
1121            .get("n_decoder_layers")
1122            .and_then(|s| s.parse().ok())
1123            .ok_or_else(|| {
1124                TextError::InvalidInput("Missing n_decoder_layers config".to_string())
1125            })?;
1126        let max_seq_len = data
1127            .config
1128            .get("max_seq_len")
1129            .and_then(|s| s.parse().ok())
1130            .ok_or_else(|| TextError::InvalidInput("Missing max_seq_len config".to_string()))?;
1131        let dropout = data
1132            .config
1133            .get("dropout")
1134            .and_then(|s| s.parse().ok())
1135            .ok_or_else(|| TextError::InvalidInput("Missing dropout config".to_string()))?;
1136        let vocab_size = data
1137            .config
1138            .get("vocab_size")
1139            .and_then(|s| s.parse().ok())
1140            .ok_or_else(|| TextError::InvalidInput("Missing vocab_size config".to_string()))?;
1141
1142        let config = crate::transformer::TransformerConfig {
1143            d_model,
1144            nheads: n_heads,
1145            d_ff,
1146            n_encoder_layers,
1147            n_decoder_layers,
1148            max_seqlen: max_seq_len,
1149            dropout,
1150            vocab_size,
1151        };
1152
1153        // Reconstruct vocabulary from saved data
1154        let vocabulary = data.vocabulary.clone().unwrap_or_else(|| {
1155            // Fallback to placeholder if vocabulary not saved
1156            (0..config.vocab_size)
1157                .map(|i| format!("token_{i}"))
1158                .collect()
1159        });
1160
1161        // Create new transformer model with config
1162        let mut model = crate::transformer::TransformerModel::new(config.clone(), vocabulary)?;
1163
1164        // Restore embedding weights
1165        if let (Some(embed_weights), Some(embedshape)) = (
1166            data.weights.get("token_embeddings"),
1167            data.shapes.get("token_embeddings"),
1168        ) {
1169            let embed_array = scirs2_core::ndarray::Array::from_shape_vec(
1170                (embedshape[0], embedshape[1]),
1171                embed_weights.clone(),
1172            )
1173            .map_err(|e| TextError::InvalidInput(format!("Invalid embedding shape: {e}")))?;
1174            model.token_embedding.set_embeddings(embed_array)?;
1175        }
1176
1177        // Restore positional embeddings
1178        if let (Some(pos_embed_weights), Some(pos_embedshape)) = (
1179            data.weights.get("positional_embeddings"),
1180            data.shapes.get("positional_embeddings"),
1181        ) {
1182            let _pos_embed_array = scirs2_core::ndarray::Array::from_shape_vec(
1183                (pos_embedshape[0], pos_embedshape[1]),
1184                pos_embed_weights.clone(),
1185            )
1186            .map_err(|e| {
1187                TextError::InvalidInput(format!("Invalid positional embedding shape: {e}"))
1188            })?;
1189            // TODO: Restore positional encoding weights when available
1190            // model.positional_encoding.set_embeddings(pos_embed_array)?;
1191        }
1192
1193        // Restore encoder layer weights
1194        for i in 0..config.n_encoder_layers {
1195            let encoder_layers = model.encoder.get_layers_mut();
1196            let (attention, ff, ln1, ln2) = encoder_layers[i].get_components_mut();
1197
1198            // Restore attention weights
1199            if let (
1200                Some(wq_weights),
1201                Some(wqshape),
1202                Some(wk_weights),
1203                Some(wkshape),
1204                Some(wv_weights),
1205                Some(wvshape),
1206                Some(wo_weights),
1207                Some(woshape),
1208            ) = (
1209                data.weights.get(&format!("encoder_{i}_attention_wq")),
1210                data.shapes.get(&format!("encoder_{i}_attention_wq")),
1211                data.weights.get(&format!("encoder_{i}_attention_wk")),
1212                data.shapes.get(&format!("encoder_{i}_attention_wk")),
1213                data.weights.get(&format!("encoder_{i}_attention_wv")),
1214                data.shapes.get(&format!("encoder_{i}_attention_wv")),
1215                data.weights.get(&format!("encoder_{i}_attention_wo")),
1216                data.shapes.get(&format!("encoder_{i}_attention_wo")),
1217            ) {
1218                let w_q = scirs2_core::ndarray::Array::from_shape_vec(
1219                    (wqshape[0], wqshape[1]),
1220                    wq_weights.clone(),
1221                )
1222                .map_err(|e| TextError::InvalidInput(format!("Invalid wq shape: {e}")))?;
1223                let w_k = scirs2_core::ndarray::Array::from_shape_vec(
1224                    (wkshape[0], wkshape[1]),
1225                    wk_weights.clone(),
1226                )
1227                .map_err(|e| TextError::InvalidInput(format!("Invalid wk shape: {e}")))?;
1228                let w_v = scirs2_core::ndarray::Array::from_shape_vec(
1229                    (wvshape[0], wvshape[1]),
1230                    wv_weights.clone(),
1231                )
1232                .map_err(|e| TextError::InvalidInput(format!("Invalid wv shape: {e}")))?;
1233                let w_o = scirs2_core::ndarray::Array::from_shape_vec(
1234                    (woshape[0], woshape[1]),
1235                    wo_weights.clone(),
1236                )
1237                .map_err(|e| TextError::InvalidInput(format!("Invalid wo shape: {e}")))?;
1238
1239                attention.set_weights(w_q, w_k, w_v, w_o)?;
1240            }
1241
1242            // Restore feedforward weights
1243            if let (
1244                Some(w1_weights),
1245                Some(w1shape),
1246                Some(w2_weights),
1247                Some(w2shape),
1248                Some(b1_weights),
1249                Some(b2_weights),
1250            ) = (
1251                data.weights.get(&format!("encoder_{i}_ff_w1")),
1252                data.shapes.get(&format!("encoder_{i}_ff_w1")),
1253                data.weights.get(&format!("encoder_{i}_ff_w2")),
1254                data.shapes.get(&format!("encoder_{i}_ff_w2")),
1255                data.weights.get(&format!("encoder_{i}_ff_b1")),
1256                data.weights.get(&format!("encoder_{i}_ff_b2")),
1257            ) {
1258                let w1 = scirs2_core::ndarray::Array::from_shape_vec(
1259                    (w1shape[0], w1shape[1]),
1260                    w1_weights.clone(),
1261                )
1262                .map_err(|e| TextError::InvalidInput(format!("Invalid w1 shape: {e}")))?;
1263                let w2 = scirs2_core::ndarray::Array::from_shape_vec(
1264                    (w2shape[0], w2shape[1]),
1265                    w2_weights.clone(),
1266                )
1267                .map_err(|e| TextError::InvalidInput(format!("Invalid w2 shape: {e}")))?;
1268                let b1 = scirs2_core::ndarray::Array::from_vec(b1_weights.clone());
1269                let b2 = scirs2_core::ndarray::Array::from_vec(b2_weights.clone());
1270
1271                ff.set_weights(w1, w2, b1, b2)?;
1272            }
1273
1274            // Restore layer norm parameters
1275            if let (Some(gamma1_weights), Some(beta1_weights)) = (
1276                data.weights.get(&format!("encoder_{i}_ln1_gamma")),
1277                data.weights.get(&format!("encoder_{i}_ln1_beta")),
1278            ) {
1279                let gamma1 = scirs2_core::ndarray::Array::from_vec(gamma1_weights.clone());
1280                let beta1 = scirs2_core::ndarray::Array::from_vec(beta1_weights.clone());
1281                ln1.set_params(gamma1, beta1)?;
1282            }
1283
1284            if let (Some(gamma2_weights), Some(beta2_weights)) = (
1285                data.weights.get(&format!("encoder_{i}_ln2_gamma")),
1286                data.weights.get(&format!("encoder_{i}_ln2_beta")),
1287            ) {
1288                let gamma2 = scirs2_core::ndarray::Array::from_vec(gamma2_weights.clone());
1289                let beta2 = scirs2_core::ndarray::Array::from_vec(beta2_weights.clone());
1290                ln2.set_params(gamma2, beta2)?;
1291            }
1292        }
1293
1294        // Restore decoder layer weights
1295        for _i in 0..config.n_decoder_layers {
1296            // Similar restoration for decoder layers
1297            // Note: Implementation would mirror encoder restoration
1298        }
1299
1300        // Restore output projection weights
1301        if let (Some(output_weights), Some(outputshape)) = (
1302            data.weights.get("output_projection"),
1303            data.shapes.get("output_projection"),
1304        ) {
1305            let _output_array = scirs2_core::ndarray::Array::from_shape_vec(
1306                scirs2_core::ndarray::IxDyn(outputshape),
1307                output_weights.clone(),
1308            )
1309            .map_err(|e| {
1310                TextError::InvalidInput(format!("Invalid output projection shape: {e}"))
1311            })?;
1312            // model.output_projection.set_weights(output_array)?;
1313        }
1314
1315        Ok(model)
1316    }
1317
1318    fn model_type(&self) -> ModelType {
1319        ModelType::Transformer
1320    }
1321
1322    fn get_config(&self) -> HashMap<String, String> {
1323        let mut config = HashMap::new();
1324        config.insert("d_model".to_string(), self.config.d_model.to_string());
1325        config.insert("n_heads".to_string(), self.config.nheads.to_string());
1326        config.insert("d_ff".to_string(), self.config.d_ff.to_string());
1327        config.insert(
1328            "n_encoder_layers".to_string(),
1329            self.config.n_encoder_layers.to_string(),
1330        );
1331        config.insert(
1332            "n_decoder_layers".to_string(),
1333            self.config.n_decoder_layers.to_string(),
1334        );
1335        config.insert(
1336            "max_seq_len".to_string(),
1337            self.config.max_seqlen.to_string(),
1338        );
1339        config.insert("dropout".to_string(), self.config.dropout.to_string());
1340        config.insert("vocab_size".to_string(), self.config.vocab_size.to_string());
1341        config
1342    }
1343}
1344
1345/// Implementation of RegistrableModel for Word2Vec
1346impl RegistrableModel for crate::embeddings::Word2Vec {
1347    fn serialize(&self) -> Result<SerializableModelData> {
1348        let mut weights = HashMap::new();
1349        let mut shapes = HashMap::new();
1350        let mut config = HashMap::new();
1351        let vocabulary = Some(self.get_vocabulary());
1352
1353        // Serialize config
1354        config.insert(
1355            "vector_size".to_string(),
1356            self.get_vector_size().to_string(),
1357        );
1358        config.insert(
1359            "algorithm".to_string(),
1360            format!("{:?}", self.get_algorithm()),
1361        );
1362        config.insert(
1363            "window_size".to_string(),
1364            self.get_window_size().to_string(),
1365        );
1366        config.insert("min_count".to_string(), self.get_min_count().to_string());
1367        config.insert(
1368            "negative_samples".to_string(),
1369            self.get_negative_samples().to_string(),
1370        );
1371        config.insert(
1372            "learning_rate".to_string(),
1373            self.get_learning_rate().to_string(),
1374        );
1375        config.insert("epochs".to_string(), self.get_epochs().to_string());
1376        config.insert(
1377            "subsampling_threshold".to_string(),
1378            self.get_subsampling_threshold().to_string(),
1379        );
1380
1381        // Serialize embedding weights
1382        if let Some(embeddings) = self.get_embeddings_matrix() {
1383            let embed_weights = embeddings.as_slice().unwrap().to_vec();
1384            let embedshape = embeddings.shape().to_vec();
1385            weights.insert("embeddings".to_string(), embed_weights);
1386            shapes.insert("embeddings".to_string(), embedshape);
1387        }
1388
1389        Ok(SerializableModelData {
1390            weights,
1391            shapes,
1392            vocabulary,
1393            config,
1394        })
1395    }
1396
1397    fn deserialize(data: &SerializableModelData) -> Result<Self> {
1398        let vector_size = data
1399            .config
1400            .get("vector_size")
1401            .and_then(|s| s.parse().ok())
1402            .ok_or_else(|| TextError::InvalidInput("Missing vector_size config".to_string()))?;
1403        let window_size = data
1404            .config
1405            .get("window_size")
1406            .and_then(|s| s.parse().ok())
1407            .ok_or_else(|| TextError::InvalidInput("Missing window_size config".to_string()))?;
1408        let min_count = data
1409            .config
1410            .get("min_count")
1411            .and_then(|s| s.parse().ok())
1412            .ok_or_else(|| TextError::InvalidInput("Missing min_count config".to_string()))?;
1413
1414        let algorithm = match data.config.get("algorithm").map(|s| s.as_str()) {
1415            Some("CBOW") => crate::embeddings::Word2VecAlgorithm::CBOW,
1416            Some("SkipGram") => crate::embeddings::Word2VecAlgorithm::SkipGram,
1417            _ => {
1418                return Err(TextError::InvalidInput(
1419                    "Invalid or missing algorithm config".to_string(),
1420                ))
1421            }
1422        };
1423
1424        let config = crate::embeddings::Word2VecConfig {
1425            vector_size,
1426            window_size,
1427            min_count,
1428            epochs: 5,            // Default value
1429            learning_rate: 0.025, // Default value
1430            algorithm,
1431            negative_samples: 5,         // Default value
1432            subsample: 1e-3,             // Default value
1433            batch_size: 128,             // Default value
1434            hierarchical_softmax: false, // Default value
1435        };
1436
1437        // Create new Word2Vec instance
1438        let word2vec = crate::embeddings::Word2Vec::with_config(config);
1439
1440        // Restore vocabulary and embeddings if available
1441        if let (Some(vocab), Some(embed_weights), Some(embedshape)) = (
1442            data.vocabulary.as_ref(),
1443            data.weights.get("embeddings"),
1444            data.shapes.get("embeddings"),
1445        ) {
1446            // Restore the full model state from serialized data
1447            let _embedding_matrix = scirs2_core::ndarray::Array::from_shape_vec(
1448                (embedshape[0], embedshape[1]),
1449                embed_weights.clone(),
1450            )
1451            .map_err(|e| TextError::InvalidInput(format!("Invalid embedding shape: {e}")))?;
1452
1453            // Create vocabulary mapping
1454            let mut word_to_index = HashMap::new();
1455            for (i, word) in vocab.iter().enumerate() {
1456                word_to_index.insert(word.clone(), i);
1457            }
1458
1459            // Create new Word2Vec model with restored parameters
1460            // Note: Full model restoration would require internal API access
1461            let mut restored_word2vec = word2vec;
1462
1463            // Apply configuration parameters if available
1464            if let Some(window_size) = data.config.get("window_size").and_then(|s| s.parse().ok()) {
1465                restored_word2vec = restored_word2vec.with_window_size(window_size);
1466            }
1467
1468            if let Some(negative_samples) = data
1469                .config
1470                .get("negative_samples")
1471                .and_then(|s| s.parse().ok())
1472            {
1473                restored_word2vec = restored_word2vec.with_negative_samples(negative_samples);
1474            }
1475
1476            if let Some(learning_rate) = data
1477                .config
1478                .get("learning_rate")
1479                .and_then(|s| s.parse().ok())
1480            {
1481                restored_word2vec = restored_word2vec.with_learning_rate(learning_rate);
1482            }
1483
1484            // TODO: Vocabulary and embedding restoration would require enhanced API
1485            // For now, return the configured model
1486            return Ok(restored_word2vec);
1487        }
1488
1489        // If no saved state available, return new model with config
1490        Ok(word2vec)
1491    }
1492
1493    fn model_type(&self) -> ModelType {
1494        ModelType::WordEmbedding
1495    }
1496
1497    fn get_config(&self) -> HashMap<String, String> {
1498        let mut config = HashMap::new();
1499        config.insert(
1500            "vector_size".to_string(),
1501            self.get_vector_size().to_string(),
1502        );
1503        config.insert(
1504            "algorithm".to_string(),
1505            format!("{:?}", self.get_algorithm()),
1506        );
1507        config.insert(
1508            "window_size".to_string(),
1509            self.get_window_size().to_string(),
1510        );
1511        config.insert("min_count".to_string(), self.get_min_count().to_string());
1512        config
1513    }
1514}
1515
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_model_metadata_creation() {
        let metadata = ModelMetadata::new(
            "test_model".to_string(),
            "Test Model".to_string(),
            ModelType::Transformer,
        )
        .with_version("1.0.0".to_string())
        .with_description("A test model".to_string())
        .with_metric("accuracy".to_string(), 0.95);

        assert_eq!(metadata.id, "test_model");
        assert_eq!(metadata.name, "Test Model");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.description, "A test model");
        assert_eq!(metadata.metrics.get("accuracy"), Some(&0.95));
    }

    #[test]
    fn test_model_registry_creation() {
        let temp_dir = TempDir::new().unwrap();
        let registry = ModelRegistry::new(temp_dir.path(), temp_dir.path()).unwrap();

        assert_eq!(registry.models.len(), 0);
        assert_eq!(registry.model_cache.len(), 0);
    }

    #[test]
    fn test_prebuilt_models() {
        let (config, metadata) = PrebuiltModels::english_transformer_base();

        assert_eq!(config.d_model, 512);
        assert_eq!(config.nheads, 8);
        assert_eq!(metadata.id, "english_transformer_base");
        assert_eq!(metadata.model_type, ModelType::Transformer);
        assert!(metadata.languages.contains(&"en".to_string()));
    }

    #[test]
    fn test_prebuilt_domain_models() {
        let (config, metadata) = PrebuiltModels::medical_transformer();
        assert_eq!(metadata.id, "medbert");
        assert_eq!(config.vocab_size, 80000);

        let (config, metadata) = PrebuiltModels::legal_transformer();
        assert_eq!(metadata.id, "legalbert");
        assert_eq!(config.max_seqlen, 2048);
    }

    #[test]
    fn test_get_by_id() {
        // Every id that get_by_id accepts for these presets must round-trip to
        // metadata carrying the same id.
        for id in ["english_transformer_base", "medbert", "legalbert"] {
            let (_, metadata) =
                PrebuiltModels::get_by_id(id).expect("prebuilt id should resolve");
            assert_eq!(metadata.id, id);
        }
        assert!(PrebuiltModels::get_by_id("no_such_model").is_none());
    }

    #[test]
    fn test_all_prebuilt_models() {
        assert_eq!(PrebuiltModels::all_prebuilt_models().len(), 8);
    }

    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::Transformer.to_string(), "transformer");
        assert_eq!(ModelType::WordEmbedding.to_string(), "word_embedding");
        assert_eq!(ModelType::Sentiment.to_string(), "sentiment");
        assert_eq!(
            ModelType::LanguageDetection.to_string(),
            "language_detection"
        );
        assert_eq!(
            ModelType::TextClassification.to_string(),
            "text_classification"
        );
        assert_eq!(
            ModelType::NamedEntityRecognition.to_string(),
            "named_entity_recognition"
        );
        assert_eq!(ModelType::PartOfSpeech.to_string(), "part_of_speech");
        assert_eq!(
            ModelType::Custom("test".to_string()).to_string(),
            "custom_test"
        );
    }
}