1use crate::{Error, Result};
4use candle_core::{Device, Module, Tensor};
5use candle_nn::{linear, Linear, VarBuilder};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9use tracing::{debug, info, warn};
10
/// Supported voice-conversion model architectures.
///
/// Each variant carries a recommended default configuration (see
/// [`ModelType::default_config`]) and a real-time capability flag (see
/// [`ModelType::supports_realtime`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelType {
    /// Encoder/decoder network; supports real-time processing.
    NeuralVC,
    /// CycleGAN-style architecture; offline only.
    CycleGAN,
    /// AutoVC-style autoencoder (bottleneck + speaker embedding defaults);
    /// supports real-time processing.
    AutoVC,
    /// StarGAN-style multi-domain architecture; offline only.
    StarGAN,
    /// WaveNet-style sample-level architecture; offline only.
    WaveNet,
    /// Transformer architecture; supports real-time processing.
    Transformer,
    /// Caller-defined architecture; uses the library-wide default config.
    Custom,
}
29
30impl ModelType {
31 pub fn default_config(&self) -> ModelConfig {
33 match self {
34 ModelType::NeuralVC => ModelConfig {
35 input_dim: 80,
36 hidden_dim: 256,
37 output_dim: 80,
38 num_layers: 4,
39 dropout: 0.1,
40 activation: ActivationType::ReLU,
41 normalization: NormalizationType::BatchNorm,
42 model_specific: HashMap::new(),
43 },
44 ModelType::CycleGAN => ModelConfig {
45 input_dim: 80,
46 hidden_dim: 512,
47 output_dim: 80,
48 num_layers: 6,
49 dropout: 0.0,
50 activation: ActivationType::LeakyReLU,
51 normalization: NormalizationType::InstanceNorm,
52 model_specific: HashMap::from([
53 ("discriminator_layers".to_string(), 3.0),
54 ("lambda_cycle".to_string(), 10.0),
55 ]),
56 },
57 ModelType::AutoVC => ModelConfig {
58 input_dim: 80,
59 hidden_dim: 512,
60 output_dim: 80,
61 num_layers: 8,
62 dropout: 0.1,
63 activation: ActivationType::ReLU,
64 normalization: NormalizationType::BatchNorm,
65 model_specific: HashMap::from([
66 ("bottleneck_dim".to_string(), 32.0),
67 ("speaker_embedding_dim".to_string(), 256.0),
68 ]),
69 },
70 ModelType::StarGAN => ModelConfig {
71 input_dim: 80,
72 hidden_dim: 512,
73 output_dim: 80,
74 num_layers: 6,
75 dropout: 0.0,
76 activation: ActivationType::ReLU,
77 normalization: NormalizationType::InstanceNorm,
78 model_specific: HashMap::from([
79 ("domain_embedding_dim".to_string(), 8.0),
80 ("num_domains".to_string(), 4.0),
81 ]),
82 },
83 ModelType::WaveNet => ModelConfig {
84 input_dim: 1,
85 hidden_dim: 256,
86 output_dim: 256,
87 num_layers: 30,
88 dropout: 0.0,
89 activation: ActivationType::Tanh,
90 normalization: NormalizationType::None,
91 model_specific: HashMap::from([
92 ("dilation_channels".to_string(), 32.0),
93 ("residual_channels".to_string(), 32.0),
94 ("skip_channels".to_string(), 256.0),
95 ]),
96 },
97 ModelType::Transformer => ModelConfig {
98 input_dim: 80,
99 hidden_dim: 512,
100 output_dim: 80,
101 num_layers: 6,
102 dropout: 0.1,
103 activation: ActivationType::GELU,
104 normalization: NormalizationType::LayerNorm,
105 model_specific: HashMap::from([
106 ("num_heads".to_string(), 8.0),
107 ("ff_dim".to_string(), 2048.0),
108 ]),
109 },
110 ModelType::Custom => ModelConfig::default(),
111 }
112 }
113
114 pub fn supports_realtime(&self) -> bool {
116 match self {
117 ModelType::NeuralVC => true,
118 ModelType::AutoVC => true,
119 ModelType::WaveNet => false, ModelType::Transformer => true,
121 ModelType::CycleGAN => false, ModelType::StarGAN => false, ModelType::Custom => false, }
125 }
126}
127
/// Architecture-level hyperparameters shared by all network types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Dimensionality of one input feature frame.
    pub input_dim: usize,
    /// Width of the hidden layers.
    pub hidden_dim: usize,
    /// Dimensionality of one output feature frame.
    pub output_dim: usize,
    /// Total number of layers in the network.
    pub num_layers: usize,
    /// Dropout probability (not yet applied by the bundled networks).
    pub dropout: f32,
    /// Activation function applied between layers.
    pub activation: ActivationType,
    /// Normalization scheme (not yet applied by the bundled networks).
    pub normalization: NormalizationType,
    /// Architecture-specific scalar hyperparameters keyed by name
    /// (e.g. "num_heads" for Transformer, "lambda_cycle" for CycleGAN).
    pub model_specific: HashMap<String, f32>,
}
148
149impl Default for ModelConfig {
150 fn default() -> Self {
151 Self {
152 input_dim: 80,
153 hidden_dim: 256,
154 output_dim: 80,
155 num_layers: 4,
156 dropout: 0.1,
157 activation: ActivationType::ReLU,
158 normalization: NormalizationType::BatchNorm,
159 model_specific: HashMap::new(),
160 }
161 }
162}
163
/// Activation functions selectable via [`ModelConfig::activation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
    /// max(0, x)
    ReLU,
    /// max(x, 0.01 * x) as implemented by the bundled networks.
    LeakyReLU,
    /// Hyperbolic tangent.
    Tanh,
    /// 1 / (1 + e^(-x))
    Sigmoid,
    /// Gaussian error linear unit.
    GELU,
    /// Mapped to SiLU (x * sigmoid(x)) in the bundled networks.
    Swish,
}
180
/// Normalization schemes selectable via [`ModelConfig::normalization`].
///
/// NOTE(review): the bundled networks do not currently apply any
/// normalization; this setting is carried in the config only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationType {
    /// No normalization.
    None,
    BatchNorm,
    LayerNorm,
    InstanceNorm,
    GroupNorm,
}
195
/// A voice-conversion model: a neural network together with its
/// configuration, device placement, and descriptive metadata.
#[derive(Debug)]
pub struct ConversionModel {
    /// Architecture of the underlying network.
    pub model_type: ModelType,
    /// Hyperparameters the network was built with.
    pub config: ModelConfig,
    // Boxed trait object so one wrapper type can hold any architecture.
    network: Box<dyn NeuralNetwork>,
    // Device new tensors are created on (see `audio_to_tensor`).
    device: Device,
    // True once real weights have been loaded; `process_tensor` warns if not.
    parameters_loaded: bool,
    // Descriptive information persisted alongside the weights as JSON.
    metadata: ModelMetadata,
}
212
/// Descriptive metadata persisted next to the model weights as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Human-readable model name.
    pub name: String,
    /// Version string of the model.
    pub version: String,
    /// Name of the training dataset, if known.
    pub training_dataset: Option<String>,
    /// Number of training epochs, if known.
    pub training_epochs: Option<u32>,
    /// Total trainable parameter count, if known.
    pub parameter_count: Option<u64>,
    /// Audio sample rate in Hz the model expects.
    pub sample_rate: u32,
    /// Timestamp at which the metadata/model was created.
    pub created_at: Option<std::time::SystemTime>,
}
231
232impl Default for ModelMetadata {
233 fn default() -> Self {
234 Self {
235 name: "Untitled Model".to_string(),
236 version: "1.0.0".to_string(),
237 training_dataset: None,
238 training_epochs: None,
239 parameter_count: None,
240 sample_rate: 22050,
241 created_at: Some(std::time::SystemTime::now()),
242 }
243 }
244}
245
/// Abstraction over the concrete network architectures so that
/// [`ConversionModel`] can hold any of them behind a single interface.
pub trait NeuralNetwork: std::fmt::Debug + Send + Sync {
    /// Runs a forward pass over `input` and returns the output tensor.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;

    /// Expected per-sample input shape (leading batch dimensions excluded).
    fn input_shape(&self) -> &[usize];

    /// Produced per-sample output shape (leading batch dimensions excluded).
    fn output_shape(&self) -> &[usize];

    /// Loads serialized weights into the network.
    fn load_weights(&mut self, weights: &[u8]) -> Result<()>;

    /// Serializes the network weights to bytes.
    fn save_weights(&self) -> Result<Vec<u8>>;

    /// Total number of trainable parameters.
    fn parameter_count(&self) -> u64;

    /// Toggles training mode (for dropout etc., where implemented).
    fn set_training(&mut self, training: bool);

    /// Clones the network behind a fresh trait object.
    fn clone_network(&self) -> Box<dyn NeuralNetwork>;
}
272
273impl ConversionModel {
274 pub fn new(model_type: ModelType) -> Self {
276 let config = model_type.default_config();
277 Self::with_config(model_type, config)
278 }
279
280 pub fn with_config(model_type: ModelType, config: ModelConfig) -> Self {
282 let device = Device::Cpu; let network =
284 Self::create_network(model_type, &config, &device).expect("operation should succeed");
285
286 Self {
287 model_type,
288 config,
289 network,
290 device,
291 parameters_loaded: false,
292 metadata: ModelMetadata::default(),
293 }
294 }
295
296 pub async fn load_from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
298 let path = path.as_ref();
299 info!("Loading conversion model from: {:?}", path);
300
301 if !path.exists() {
303 return Err(Error::model(format!("Model file not found: {path:?}")));
304 }
305
306 let model_type = ModelType::NeuralVC;
309 let mut model = Self::new(model_type);
310
311 if let Some(weights_path) = Self::find_weights_file(path) {
313 model.load_weights_file(&weights_path).await?;
314 }
315
316 if let Some(metadata_path) = Self::find_metadata_file(path) {
318 model.load_metadata_file(&metadata_path).await?;
319 }
320
321 info!("Successfully loaded model: {}", model.metadata.name);
322 Ok(model)
323 }
324
325 pub async fn load_from_bytes(bytes: &[u8], model_type: ModelType) -> Result<Self> {
327 debug!("Loading model from {} bytes", bytes.len());
328
329 let mut model = Self::new(model_type);
330 model.network.load_weights(bytes)?;
331 model.parameters_loaded = true;
332
333 Ok(model)
334 }
335
336 pub async fn process_tensor(&self, input: &Tensor) -> Result<Tensor> {
338 if !self.parameters_loaded {
339 warn!("Model parameters not loaded, using uninitialized weights");
340 }
341
342 debug!("Processing tensor with shape: {:?}", input.shape());
343
344 let expected_shape = self.network.input_shape();
346 let input_shape = input.shape().dims();
347
348 if input_shape.len() < expected_shape.len() {
349 return Err(Error::model(format!(
350 "Input tensor has insufficient dimensions: expected {expected_shape:?}, got {input_shape:?}"
351 )));
352 }
353
354 let output = self.network.forward(input)?;
356
357 debug!("Model output shape: {:?}", output.shape());
358 Ok(output)
359 }
360
361 pub async fn process(&self, input: &[f32]) -> Result<Vec<f32>> {
363 let input_tensor = self.audio_to_tensor(input)?;
365
366 let output_tensor = self.process_tensor(&input_tensor).await?;
368
369 self.tensor_to_audio(&output_tensor)
371 }
372
373 pub fn set_device(&mut self, device: Device) -> Result<()> {
375 info!("Moving model to device: {:?}", device);
376 self.device = device;
377 Ok(())
379 }
380
381 pub fn info(&self) -> ModelInfo {
383 ModelInfo {
384 model_type: self.model_type,
385 config: self.config.clone(),
386 metadata: self.metadata.clone(),
387 device: format!("{:?}", self.device),
388 parameters_loaded: self.parameters_loaded,
389 parameter_count: self.network.parameter_count(),
390 supports_realtime: self.model_type.supports_realtime(),
391 }
392 }
393
394 pub async fn save_to_path<P: AsRef<Path>>(&self, path: P) -> Result<()> {
396 let path = path.as_ref();
397 info!("Saving model to: {:?}", path);
398
399 if let Some(parent) = path.parent() {
401 std::fs::create_dir_all(parent)?;
402 }
403
404 let weights = self.network.save_weights()?;
406 let weights_path = path.with_extension("weights");
407 std::fs::write(&weights_path, weights)?;
408
409 let metadata_json = serde_json::to_string_pretty(&self.metadata)?;
411 let metadata_path = path.with_extension("json");
412 std::fs::write(&metadata_path, metadata_json)?;
413
414 let config_json = serde_json::to_string_pretty(&self.config)?;
416 let config_path = path.with_extension("config.json");
417 std::fs::write(&config_path, config_json)?;
418
419 info!("Model saved successfully");
420 Ok(())
421 }
422
423 fn audio_to_tensor(&self, audio: &[f32]) -> Result<Tensor> {
425 let input_shape = self.network.input_shape();
427
428 match input_shape.len() {
429 1 => {
430 let feature_size = input_shape[0];
432 if audio.len() != feature_size {
433 return Err(Error::model(format!(
434 "Input audio length {} doesn't match expected feature size {}",
435 audio.len(),
436 feature_size
437 )));
438 }
439 Tensor::from_vec(audio.to_vec(), (1, audio.len()), &self.device)
440 }
441 2 => {
442 let _batch_size = 1;
444 let feature_size = input_shape[1];
445 let time_steps = audio.len() / feature_size;
446
447 if !audio.len().is_multiple_of(feature_size) {
448 let mut padded_audio = audio.to_vec();
450 let padding_needed = feature_size - (audio.len() % feature_size);
451 padded_audio.extend(vec![0.0; padding_needed]);
452
453 let new_time_steps = padded_audio.len() / feature_size;
454 Tensor::from_vec(padded_audio, (new_time_steps, feature_size), &self.device)
455 } else {
456 Tensor::from_vec(audio.to_vec(), (time_steps, feature_size), &self.device)
457 }
458 }
459 3 => {
460 let batch_size = 1;
462 let feature_size = input_shape[2];
463 let time_steps = audio.len() / feature_size;
464
465 if !audio.len().is_multiple_of(feature_size) {
466 let mut padded_audio = audio.to_vec();
467 let padding_needed = feature_size - (audio.len() % feature_size);
468 padded_audio.extend(vec![0.0; padding_needed]);
469
470 let new_time_steps = padded_audio.len() / feature_size;
471 Tensor::from_vec(
472 padded_audio,
473 (batch_size, new_time_steps, feature_size),
474 &self.device,
475 )
476 } else {
477 Tensor::from_vec(
478 audio.to_vec(),
479 (batch_size, time_steps, feature_size),
480 &self.device,
481 )
482 }
483 }
484 _ => {
485 return Err(Error::model(format!(
486 "Unsupported input shape dimensionality: {}",
487 input_shape.len()
488 )));
489 }
490 }
491 .map_err(|e| Error::model(format!("Failed to create input tensor: {e}")))
492 }
493
494 fn tensor_to_audio(&self, tensor: &Tensor) -> Result<Vec<f32>> {
496 match tensor.shape().dims().len() {
497 1 => {
498 tensor
500 .to_vec1::<f32>()
501 .map_err(|e| Error::model(format!("Failed to convert tensor to audio: {e}")))
502 }
503 2 => {
504 let squeezed = tensor
506 .squeeze(0)
507 .map_err(|e| Error::model(format!("Failed to squeeze tensor: {e}")))?;
508 squeezed
509 .to_vec1::<f32>()
510 .map_err(|e| Error::model(format!("Failed to convert tensor to audio: {e}")))
511 }
512 _ => Err(Error::model(format!(
513 "Unsupported tensor shape for audio conversion: {:?}",
514 tensor.shape()
515 ))),
516 }
517 }
518
519 fn create_network(
521 model_type: ModelType,
522 config: &ModelConfig,
523 device: &Device,
524 ) -> Result<Box<dyn NeuralNetwork>> {
525 match model_type {
526 ModelType::NeuralVC => Ok(Box::new(NeuralVCNetwork::new(config, device)?)),
527 ModelType::AutoVC => Ok(Box::new(AutoVCNetwork::new(config, device)?)),
528 ModelType::Transformer => Ok(Box::new(TransformerNetwork::new(config, device)?)),
529 _ => {
530 warn!(
532 "Model type {:?} not fully implemented, using simple feedforward network",
533 model_type
534 );
535 Ok(Box::new(SimpleNetwork::new(config, device)?))
536 }
537 }
538 }
539
540 fn find_weights_file(base_path: &Path) -> Option<std::path::PathBuf> {
543 let weights_path = base_path.with_extension("weights");
544 if weights_path.exists() {
545 Some(weights_path)
546 } else {
547 None
548 }
549 }
550
551 fn find_metadata_file(base_path: &Path) -> Option<std::path::PathBuf> {
552 let metadata_path = base_path.with_extension("json");
553 if metadata_path.exists() {
554 Some(metadata_path)
555 } else {
556 None
557 }
558 }
559
560 async fn load_weights_file(&mut self, path: &Path) -> Result<()> {
561 let weights = std::fs::read(path)?;
562 self.network.load_weights(&weights)?;
563 self.parameters_loaded = true;
564 Ok(())
565 }
566
567 async fn load_metadata_file(&mut self, path: &Path) -> Result<()> {
568 let metadata_json = std::fs::read_to_string(path)?;
569 self.metadata = serde_json::from_str(&metadata_json)?;
570 Ok(())
571 }
572}
573
574impl Default for ConversionModel {
575 fn default() -> Self {
576 Self::new(ModelType::NeuralVC)
577 }
578}
579
/// Snapshot of a model's state, as returned by [`ConversionModel::info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Architecture of the model.
    pub model_type: ModelType,
    /// Configuration the network was built with.
    pub config: ModelConfig,
    /// Descriptive metadata.
    pub metadata: ModelMetadata,
    /// Debug-formatted device name (e.g. "Cpu").
    pub device: String,
    /// Whether real weights have been loaded.
    pub parameters_loaded: bool,
    /// Parameter count reported by the network.
    pub parameter_count: u64,
    /// Whether the architecture supports real-time processing.
    pub supports_realtime: bool,
}
598
/// Plain feed-forward network of linear layers, used as the fallback for
/// architectures that are not fully implemented.
#[derive(Debug)]
struct SimpleNetwork {
    // Hidden layers followed by one final output projection.
    layers: Vec<Linear>,
    config: ModelConfig,
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    // Mirrors `set_training`; no layer currently changes behavior with it.
    training: bool,
}
610
611impl SimpleNetwork {
612 fn new(config: &ModelConfig, device: &Device) -> Result<Self> {
613 let varmap = candle_nn::VarMap::new();
614 let vs = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, device);
615
616 let mut layers = Vec::new();
617 let mut current_dim = config.input_dim;
618
619 for i in 0..config.num_layers - 1 {
621 let layer = linear(current_dim, config.hidden_dim, vs.pp(format!("layer_{i}")))
622 .map_err(|e| Error::model(format!("Failed to create layer {i}: {e}")))?;
623 layers.push(layer);
624 current_dim = config.hidden_dim;
625 }
626
627 let output_layer = linear(current_dim, config.output_dim, vs.pp("output"))
629 .map_err(|e| Error::model(format!("Failed to create output layer: {e}")))?;
630 layers.push(output_layer);
631
632 Ok(Self {
633 layers,
634 config: config.clone(),
635 input_shape: vec![config.input_dim],
636 output_shape: vec![config.output_dim],
637 training: false,
638 })
639 }
640}
641
642impl NeuralNetwork for SimpleNetwork {
643 fn forward(&self, input: &Tensor) -> Result<Tensor> {
644 let mut x = input.clone();
645
646 for (i, layer) in self.layers.iter().enumerate() {
647 x = layer
648 .forward(&x)
649 .map_err(|e| Error::model(format!("Forward pass failed at layer {i}: {e}")))?;
650
651 if i < self.layers.len() - 1 {
653 x = match self.config.activation {
654 ActivationType::ReLU => x.relu()?,
655 ActivationType::LeakyReLU => {
656 let scaled = (x.clone() * 0.01)?;
657 x.maximum(&scaled)?
658 }
659 ActivationType::Tanh => x.tanh()?,
660 ActivationType::Sigmoid => {
661 let neg_x = x.neg()?;
663 let exp_neg_x = neg_x.exp()?;
664 let one_plus_exp = (exp_neg_x + 1.0)?;
665 one_plus_exp.recip()?
666 }
667 ActivationType::GELU => x.gelu()?,
668 ActivationType::Swish => x.silu()?,
669 };
670 }
671 }
672
673 Ok(x)
674 }
675
676 fn input_shape(&self) -> &[usize] {
677 &self.input_shape
678 }
679
680 fn output_shape(&self) -> &[usize] {
681 &self.output_shape
682 }
683
684 fn load_weights(&mut self, _weights: &[u8]) -> Result<()> {
685 Ok(())
687 }
688
689 fn save_weights(&self) -> Result<Vec<u8>> {
690 Ok(vec![0; 1024])
692 }
693
694 fn parameter_count(&self) -> u64 {
695 let mut count = 0;
696 let mut current_dim = self.config.input_dim;
697
698 for _ in 0..self.config.num_layers - 1 {
699 count += (current_dim * self.config.hidden_dim + self.config.hidden_dim) as u64;
700 current_dim = self.config.hidden_dim;
701 }
702
703 count += (current_dim * self.config.output_dim + self.config.output_dim) as u64;
705
706 count
707 }
708
709 fn set_training(&mut self, training: bool) {
710 self.training = training;
711 }
712
713 fn clone_network(&self) -> Box<dyn NeuralNetwork> {
714 Box::new(SimpleNetwork {
715 layers: Vec::new(), config: self.config.clone(),
717 input_shape: self.input_shape.clone(),
718 output_shape: self.output_shape.clone(),
719 training: self.training,
720 })
721 }
722}
723
/// Encoder/decoder network of linear layers backing the NeuralVC (and,
/// via alias, AutoVC) model types.
#[derive(Debug)]
struct NeuralVCNetwork {
    // First half of the layers: input_dim -> hidden_dim -> ... -> hidden_dim.
    encoder: Vec<Linear>,
    // Second half: hidden_dim -> ... -> output_dim.
    decoder: Vec<Linear>,
    config: ModelConfig,
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    // Mirrors `set_training`; no layer currently changes behavior with it.
    training: bool,
}
734
735impl NeuralVCNetwork {
736 fn new(config: &ModelConfig, device: &Device) -> Result<Self> {
737 let varmap = candle_nn::VarMap::new();
738 let vs = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, device);
739
740 let mut encoder = Vec::new();
742 let mut current_dim = config.input_dim;
743
744 for i in 0..config.num_layers / 2 {
745 let layer = linear(
746 current_dim,
747 config.hidden_dim,
748 vs.pp(format!("encoder_{i}")),
749 )
750 .map_err(|e| Error::model(format!("Failed to create encoder layer {i}: {e}")))?;
751 encoder.push(layer);
752 current_dim = config.hidden_dim;
753 }
754
755 let mut decoder = Vec::new();
757 for i in 0..config.num_layers / 2 {
758 let output_dim = if i == config.num_layers / 2 - 1 {
759 config.output_dim
760 } else {
761 config.hidden_dim
762 };
763
764 let layer = linear(current_dim, output_dim, vs.pp(format!("decoder_{i}")))
765 .map_err(|e| Error::model(format!("Failed to create decoder layer {i}: {e}")))?;
766 decoder.push(layer);
767 current_dim = output_dim;
768 }
769
770 Ok(Self {
771 encoder,
772 decoder,
773 config: config.clone(),
774 input_shape: vec![config.input_dim],
775 output_shape: vec![config.output_dim],
776 training: false,
777 })
778 }
779}
780
781impl NeuralNetwork for NeuralVCNetwork {
782 fn forward(&self, input: &Tensor) -> Result<Tensor> {
783 let mut x = input.clone();
784
785 for (i, layer) in self.encoder.iter().enumerate() {
787 x = layer.forward(&x).map_err(|e| {
788 Error::model(format!("Encoder forward pass failed at layer {i}: {e}"))
789 })?;
790
791 x = match self.config.activation {
792 ActivationType::ReLU => x.relu()?,
793 ActivationType::LeakyReLU => {
794 let scaled = (x.clone() * 0.01)?;
795 x.maximum(&scaled)?
796 }
797 ActivationType::Tanh => x.tanh()?,
798 ActivationType::Sigmoid => {
799 let neg_x = x.neg()?;
801 let exp_neg_x = neg_x.exp()?;
802 let one_plus_exp = (exp_neg_x + 1.0)?;
803 one_plus_exp.recip()?
804 }
805 ActivationType::GELU => x.gelu()?,
806 ActivationType::Swish => x.silu()?,
807 };
808 }
809
810 for (i, layer) in self.decoder.iter().enumerate() {
812 x = layer.forward(&x).map_err(|e| {
813 Error::model(format!("Decoder forward pass failed at layer {i}: {e}"))
814 })?;
815
816 if i < self.decoder.len() - 1 {
818 x = match self.config.activation {
819 ActivationType::ReLU => x.relu()?,
820 ActivationType::LeakyReLU => {
821 let scaled = (x.clone() * 0.01)?;
822 x.maximum(&scaled)?
823 }
824 ActivationType::Tanh => x.tanh()?,
825 ActivationType::Sigmoid => {
826 let neg_x = x.neg()?;
828 let exp_neg_x = neg_x.exp()?;
829 let one_plus_exp = (exp_neg_x + 1.0)?;
830 one_plus_exp.recip()?
831 }
832 ActivationType::GELU => x.gelu()?,
833 ActivationType::Swish => x.silu()?,
834 };
835 }
836 }
837
838 Ok(x)
839 }
840
841 fn input_shape(&self) -> &[usize] {
842 &self.input_shape
843 }
844
845 fn output_shape(&self) -> &[usize] {
846 &self.output_shape
847 }
848
849 fn load_weights(&mut self, _weights: &[u8]) -> Result<()> {
850 Ok(())
851 }
852
853 fn save_weights(&self) -> Result<Vec<u8>> {
854 Ok(vec![0; 2048])
855 }
856
857 fn parameter_count(&self) -> u64 {
858 let encoder_params = (self.config.input_dim * self.config.hidden_dim
860 + self.config.hidden_dim) as u64
861 * (self.config.num_layers / 2) as u64;
862 let decoder_params = (self.config.hidden_dim * self.config.output_dim
863 + self.config.output_dim) as u64
864 * (self.config.num_layers / 2) as u64;
865 encoder_params + decoder_params
866 }
867
868 fn set_training(&mut self, training: bool) {
869 self.training = training;
870 }
871
872 fn clone_network(&self) -> Box<dyn NeuralNetwork> {
873 Box::new(NeuralVCNetwork {
874 encoder: Vec::new(),
875 decoder: Vec::new(),
876 config: self.config.clone(),
877 input_shape: self.input_shape.clone(),
878 output_shape: self.output_shape.clone(),
879 training: self.training,
880 })
881 }
882}
883
884type AutoVCNetwork = NeuralVCNetwork; type TransformerNetwork = SimpleNetwork; #[cfg(test)]
891mod tests {
892 use super::*;
893
894 #[test]
895 fn test_model_type_properties() {
896 assert!(ModelType::NeuralVC.supports_realtime());
897 assert!(!ModelType::CycleGAN.supports_realtime());
898 assert!(ModelType::Transformer.supports_realtime());
899 }
900
901 #[test]
902 fn test_model_config_creation() {
903 let config = ModelType::NeuralVC.default_config();
904 assert_eq!(config.input_dim, 80);
905 assert_eq!(config.hidden_dim, 256);
906 assert_eq!(config.output_dim, 80);
907 }
908
909 #[test]
910 fn test_model_creation() {
911 let model = ConversionModel::new(ModelType::NeuralVC);
912 assert_eq!(model.model_type, ModelType::NeuralVC);
913 assert!(!model.parameters_loaded);
914 }
915
916 #[tokio::test]
917 async fn test_model_processing() {
918 let model = ConversionModel::new(ModelType::NeuralVC);
919 let input = vec![0.1; 80];
921
922 let result = model.process(&input).await;
923 match &result {
924 Ok(output) => {
925 println!("Test passed successfully, output length: {}", output.len());
926 assert_eq!(output.len(), 80, "Output should have same length as input");
927 }
928 Err(e) => {
929 println!("Test failed with error: {e:?}");
930 }
931 }
932 assert!(
933 result.is_ok(),
934 "Model processing should succeed: {:?}",
935 result.err()
936 );
937 }
938
939 #[test]
940 fn test_model_info() {
941 let model = ConversionModel::new(ModelType::AutoVC);
942 let info = model.info();
943
944 assert_eq!(info.model_type, ModelType::AutoVC);
945 assert!(!info.parameters_loaded);
946 assert!(info.parameter_count > 0);
947 }
948}