use candle_core::Tensor;
/// Settings that bound a voice-cloning generation run.
///
/// Start from [`VoiceCloneConfig::default`] and override individual fields
/// with struct-update syntax, e.g.
/// `VoiceCloneConfig { max_new_tokens: 512, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct VoiceCloneConfig {
    /// Token budget for generation (default `2048`).
    /// NOTE(review): presumably a cap on newly generated tokens; it is
    /// enforced outside this file — confirm at the call site.
    pub max_new_tokens: usize,
    /// Optional lower bound on output duration (default `None`).
    /// NOTE(review): unit presumably seconds — confirm against the consumer.
    pub min_duration: Option<f32>,
    /// Optional upper bound on output duration (default `Some(30.0)`).
    /// NOTE(review): unit presumably seconds — confirm against the consumer.
    pub max_duration: Option<f32>,
}

impl Default for VoiceCloneConfig {
    /// Returns `max_new_tokens = 2048`, no minimum duration, and a maximum
    /// duration of `30.0`.
    fn default() -> Self {
        let max_duration = Some(30.0);
        Self {
            max_duration,
            min_duration: None,
            max_new_tokens: 2048,
        }
    }
}
/// A reference-voice prompt for voice cloning.
///
/// Each prompt is in one of two modes, selected by a pair of public flags:
///
/// * **x-vector only**: built by [`VoiceClonePromptItem::x_vector_only`]
///   (or [`VoiceClonePromptItem::new`] with `x_vector_only_mode = true`).
///   Only `ref_spk_embedding` is required.
/// * **ICL**: built by [`VoiceClonePromptItem::icl`] (or `new` with
///   `x_vector_only_mode = false`). [`validate`](Self::validate) additionally
///   requires `ref_code` and a non-blank `ref_text`.
///
/// Because the flags are public, they can be set inconsistently. The
/// predicates give `x_vector_only_mode` precedence:
/// [`is_icl`](Self::is_icl) is true only when `icl_mode` is set *and*
/// `x_vector_only_mode` is not. If both flags are false, neither predicate is
/// true and `validate` does not reject the prompt.
#[derive(Debug, Clone)]
pub struct VoiceClonePromptItem {
    /// Reference codes; required in ICL mode.
    /// NOTE(review): the unit tests use a `(100, 32)` `I64` tensor. It is
    /// presumably codec token ids; confirm the expected shape.
    pub ref_code: Option<Tensor>,
    /// Speaker embedding; present in every mode.
    /// NOTE(review): the unit tests use a `(1024,)` `F32` tensor; confirm the
    /// expected shape with the model.
    pub ref_spk_embedding: Tensor,
    /// When true, the prompt is x-vector only. Takes precedence over `icl_mode`.
    pub x_vector_only_mode: bool,
    /// When true (and `x_vector_only_mode` is false), the prompt is in ICL mode.
    pub icl_mode: bool,
    /// Reference text; must be present and non-blank in ICL mode.
    /// NOTE(review): presumably the transcript of the reference audio.
    pub ref_text: Option<String>,
}

impl VoiceClonePromptItem {
    /// Builds an x-vector-only prompt that carries no reference codes or text.
    pub fn x_vector_only(ref_spk_embedding: Tensor) -> Self {
        Self {
            ref_code: None,
            ref_spk_embedding,
            x_vector_only_mode: true,
            icl_mode: false,
            ref_text: None,
        }
    }

    /// Builds an ICL prompt from reference codes, a speaker embedding and the
    /// reference text.
    ///
    /// The text is not checked here. Call [`validate`](Self::validate) to
    /// reject a blank `ref_text`.
    pub fn icl(ref_code: Tensor, ref_spk_embedding: Tensor, ref_text: String) -> Self {
        Self {
            ref_code: Some(ref_code),
            ref_spk_embedding,
            x_vector_only_mode: false,
            icl_mode: true,
            ref_text: Some(ref_text),
        }
    }

    /// Builds a prompt whose mode is chosen by `x_vector_only_mode`.
    ///
    /// `icl_mode` is always set to the negation of `x_vector_only_mode`, so
    /// the two flags are consistent. Field presence is not checked here; use
    /// [`validate`](Self::validate).
    pub fn new(
        ref_code: Option<Tensor>,
        ref_spk_embedding: Tensor,
        x_vector_only_mode: bool,
        ref_text: Option<String>,
    ) -> Self {
        Self {
            ref_code,
            ref_spk_embedding,
            x_vector_only_mode,
            icl_mode: !x_vector_only_mode,
            ref_text,
        }
    }

    /// Checks that the fields required by the prompt's mode are present.
    ///
    /// Only ICL prompts (see [`is_icl`](Self::is_icl)) have requirements.
    /// X-vector-only prompts, and prompts with neither flag set, always pass.
    ///
    /// # Errors
    ///
    /// For an ICL prompt, returns a message naming the offending field when
    /// `ref_code` is `None`, when `ref_text` is `None`, or when `ref_text` is
    /// empty or whitespace-only.
    pub fn validate(&self) -> Result<(), String> {
        // Use the predicate so validation and mode detection cannot disagree.
        if !self.is_icl() {
            return Ok(());
        }
        if self.ref_code.is_none() {
            return Err("ICL mode requires ref_code to be present".to_string());
        }
        match self.ref_text.as_deref() {
            None => Err("ICL mode requires ref_text to be present".to_string()),
            // Treat a blank text as missing: ICL mode requires reference text,
            // and `Some("")` would otherwise slip through unnoticed.
            Some(text) if text.trim().is_empty() => {
                Err("ICL mode requires ref_text to be non-empty".to_string())
            }
            Some(_) => Ok(()),
        }
    }

    /// Returns `true` when `icl_mode` is set and `x_vector_only_mode` is not.
    pub fn is_icl(&self) -> bool {
        self.icl_mode && !self.x_vector_only_mode
    }

    /// Returns `true` when `x_vector_only_mode` is set, regardless of `icl_mode`.
    pub fn is_x_vector_only(&self) -> bool {
        self.x_vector_only_mode
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    // Speaker-embedding fixture: a zero `(1024,)` F32 tensor on the CPU.
    fn create_test_embedding() -> Tensor {
        Tensor::zeros((1024,), DType::F32, &Device::Cpu).unwrap()
    }

    // Reference-code fixture: a zero `(100, 32)` I64 tensor on the CPU.
    fn create_test_codes() -> Tensor {
        Tensor::zeros((100, 32), DType::I64, &Device::Cpu).unwrap()
    }

    #[test]
    fn test_x_vector_only_creation() {
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem::x_vector_only(embedding);
        assert!(
            prompt.ref_code.is_none(),
            "x_vector_only should have no ref_code"
        );
        assert!(
            prompt.x_vector_only_mode,
            "x_vector_only_mode should be true"
        );
        assert!(!prompt.icl_mode, "icl_mode should be false");
        assert!(
            prompt.ref_text.is_none(),
            "x_vector_only should have no ref_text"
        );
    }

    #[test]
    fn test_x_vector_only_validates() {
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem::x_vector_only(embedding);
        let result = prompt.validate();
        assert!(
            result.is_ok(),
            "x_vector_only prompt should validate successfully"
        );
    }

    #[test]
    fn test_is_x_vector_only_predicate() {
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem::x_vector_only(embedding);
        assert!(
            prompt.is_x_vector_only(),
            "is_x_vector_only() should return true"
        );
        assert!(
            !prompt.is_icl(),
            "is_icl() should return false for x_vector_only"
        );
    }

    #[test]
    fn test_icl_creation() {
        let codes = create_test_codes();
        let embedding = create_test_embedding();
        let ref_text = "Hello, this is a test.".to_string();
        let prompt = VoiceClonePromptItem::icl(codes, embedding, ref_text.clone());
        assert!(prompt.ref_code.is_some(), "icl should have ref_code");
        assert!(
            !prompt.x_vector_only_mode,
            "x_vector_only_mode should be false for icl"
        );
        assert!(prompt.icl_mode, "icl_mode should be true");
        assert_eq!(prompt.ref_text, Some(ref_text), "ref_text should match");
    }

    #[test]
    fn test_icl_validates_with_all_fields() {
        let codes = create_test_codes();
        let embedding = create_test_embedding();
        let ref_text = "Hello, this is a test.".to_string();
        let prompt = VoiceClonePromptItem::icl(codes, embedding, ref_text);
        let result = prompt.validate();
        assert!(
            result.is_ok(),
            "icl prompt with all fields should validate successfully"
        );
    }

    #[test]
    fn test_is_icl_predicate() {
        let codes = create_test_codes();
        let embedding = create_test_embedding();
        let ref_text = "Hello, this is a test.".to_string();
        let prompt = VoiceClonePromptItem::icl(codes, embedding, ref_text);
        assert!(
            prompt.is_icl(),
            "is_icl() should return true for icl prompt"
        );
        assert!(
            !prompt.is_x_vector_only(),
            "is_x_vector_only() should return false for icl"
        );
    }

    #[test]
    fn test_icl_missing_ref_code_fails_validation() {
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem {
            ref_code: None,
            ref_spk_embedding: embedding,
            x_vector_only_mode: false,
            icl_mode: true,
            ref_text: Some("Some text".to_string()),
        };
        let result = prompt.validate();
        assert!(
            result.is_err(),
            "ICL prompt without ref_code should fail validation"
        );
        let err_msg = result.unwrap_err();
        assert!(
            err_msg.contains("ref_code"),
            "Error should mention ref_code: {}",
            err_msg
        );
    }

    #[test]
    fn test_icl_missing_ref_text_fails_validation() {
        let codes = create_test_codes();
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem {
            ref_code: Some(codes),
            ref_spk_embedding: embedding,
            x_vector_only_mode: false,
            icl_mode: true,
            ref_text: None,
        };
        let result = prompt.validate();
        assert!(
            result.is_err(),
            "ICL prompt without ref_text should fail validation"
        );
        let err_msg = result.unwrap_err();
        assert!(
            err_msg.contains("ref_text"),
            "Error should mention ref_text: {}",
            err_msg
        );
    }

    #[test]
    fn test_new_constructor_x_vector_mode() {
        let embedding = create_test_embedding();
        let prompt = VoiceClonePromptItem::new(None, embedding, true, None);
        assert!(prompt.x_vector_only_mode);
        assert!(
            !prompt.icl_mode,
            "icl_mode should be opposite of x_vector_only_mode"
        );
    }

    #[test]
    fn test_new_constructor_icl_mode() {
        let codes = create_test_codes();
        let embedding = create_test_embedding();
        let ref_text = "Test text".to_string();
        let prompt = VoiceClonePromptItem::new(Some(codes), embedding, false, Some(ref_text));
        assert!(!prompt.x_vector_only_mode);
        assert!(
            prompt.icl_mode,
            "icl_mode should be opposite of x_vector_only_mode"
        );
    }

    // `new` in x-vector mode: missing codes/text must not block validation,
    // and the predicates must agree with the flags `new` set.
    #[test]
    fn test_new_constructor_x_vector_mode_validates_and_predicates() {
        let prompt = VoiceClonePromptItem::new(None, create_test_embedding(), true, None);
        assert!(
            prompt.validate().is_ok(),
            "x-vector prompt built via new() should validate"
        );
        assert!(prompt.is_x_vector_only());
        assert!(!prompt.is_icl());
    }

    // `new` in ICL mode with every required field: validates and is ICL.
    #[test]
    fn test_new_constructor_icl_mode_validates_and_predicates() {
        let prompt = VoiceClonePromptItem::new(
            Some(create_test_codes()),
            create_test_embedding(),
            false,
            Some("Test text".to_string()),
        );
        assert!(
            prompt.validate().is_ok(),
            "complete ICL prompt built via new() should validate"
        );
        assert!(prompt.is_icl());
        assert!(!prompt.is_x_vector_only());
    }

    // `new` performs no presence checks, so validation must catch the gap.
    #[test]
    fn test_new_constructor_icl_missing_ref_text_fails_validation() {
        let prompt = VoiceClonePromptItem::new(
            Some(create_test_codes()),
            create_test_embedding(),
            false,
            None,
        );
        let err_msg = prompt
            .validate()
            .expect_err("ICL prompt built via new() without ref_text should fail");
        assert!(
            err_msg.contains("ref_text"),
            "Error should mention ref_text: {}",
            err_msg
        );
    }

    #[test]
    fn test_voice_clone_config_default() {
        let config = VoiceCloneConfig::default();
        assert_eq!(config.max_new_tokens, 2048);
        assert!(config.min_duration.is_none());
        assert_eq!(config.max_duration, Some(30.0));
    }
}