use crate::errors::{TrustformersError, Result};
use crate::tensor::Tensor;
use super::config::{CFGConfig, GenerationConfig};
use super::cache::KVCache;
pub struct CFGGenerator {
cfg_config: Option<CFGConfig>,
vocab_size: usize,
config: GenerationConfig,
}
impl CFGGenerator {
    /// Creates a CFG generator for the given generation config and vocabulary size.
    ///
    /// The optional CFG settings are cloned out of `config.guided_generation`
    /// *before* `config` is moved into the base generator.
    pub fn new(config: GenerationConfig, vocab_size: usize) -> Result<Self> {
        let cfg_config = config.guided_generation.as_ref().and_then(|g| g.cfg.clone());
        Ok(Self {
            base_generator: super::core::TextGenerator::new(config, vocab_size),
            cfg_config,
        })
    }

    /// Generates tokens using classifier-free guidance.
    ///
    /// At each step, logits are computed twice — once conditioned on the
    /// prompt (`conditional_logits_fn`) and once unconditioned
    /// (`unconditional_logits_fn`) — then combined via
    /// [`Self::apply_cfg_guidance`]. If no CFG config is present, this
    /// delegates to plain generation with the conditional model only.
    ///
    /// # Errors
    /// Propagates errors from the logits callbacks, guidance combination,
    /// and token sampling.
    pub fn generate_with_cfg(
        &self,
        input_ids: &[usize],
        conditional_logits_fn: impl Fn(&[usize], Option<&KVCache>) -> Result<(Tensor, Option<KVCache>)>,
        unconditional_logits_fn: impl Fn(
            &[usize],
            Option<&KVCache>,
        ) -> Result<(Tensor, Option<KVCache>)>,
    ) -> Result<Vec<Vec<usize>>> {
        // No CFG settings -> ordinary generation with the conditional model.
        let cfg_config = match self.cfg_config.as_ref() {
            Some(config) => config,
            None => return self.base_generator.generate(input_ids, conditional_logits_fn),
        };
        let mut sequences = vec![input_ids.to_vec()];
        // Separate KV caches: the conditional and unconditional passes see
        // different contexts and must not share cached key/value states.
        let mut conditional_cache =
            if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None };
        let mut unconditional_cache =
            if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None };
        let max_length = self.base_generator.get_max_length(input_ids.len());
        for step in 0..max_length {
            let (conditional_logits, new_conditional_cache) =
                conditional_logits_fn(&sequences[0], conditional_cache.as_ref())?;
            conditional_cache = new_conditional_cache;
            let (unconditional_logits, new_unconditional_cache) =
                unconditional_logits_fn(&sequences[0], unconditional_cache.as_ref())?;
            unconditional_cache = new_unconditional_cache;
            let guided_logits = self.apply_cfg_guidance(
                &conditional_logits,
                &unconditional_logits,
                cfg_config.guidance_scale,
                cfg_config.dynamic_thresholding,
                cfg_config.threshold_percentile,
            )?;
            let next_token = self.base_generator.sample_token(&guided_logits)?;
            sequences[0].push(next_token);
            if self.base_generator.should_stop(&sequences[0], next_token, step + 1) {
                break;
            }
        }
        Ok(sequences)
    }

    /// Combines conditional and unconditional logits with the CFG formula:
    /// `uncond + scale * (cond - uncond)` (scale 1.0 = purely conditional).
    ///
    /// When `dynamic_thresholding` is set, the combined logits are clamped
    /// to a per-step percentile of their absolute values to prevent the
    /// guidance term from producing extreme spikes.
    ///
    /// # Errors
    /// Returns a tensor-op error if the two logit vectors differ in length
    /// or if either tensor is not `F32`.
    fn apply_cfg_guidance(
        &self,
        conditional_logits: &Tensor,
        unconditional_logits: &Tensor,
        guidance_scale: f32,
        dynamic_thresholding: bool,
        threshold_percentile: f32,
    ) -> Result<Tensor> {
        match (conditional_logits, unconditional_logits) {
            (Tensor::F32(cond_arr), Tensor::F32(uncond_arr)) => {
                let cond_data: Vec<f32> = cond_arr.iter().cloned().collect();
                let uncond_data: Vec<f32> = uncond_arr.iter().cloned().collect();
                if cond_data.len() != uncond_data.len() {
                    return Err(TrustformersError::tensor_op_error(
                        "Conditional and unconditional logits must have same length",
                        "apply_cfg_guidance",
                    ));
                }
                let mut guided_logits: Vec<f32> = uncond_data
                    .iter()
                    .zip(cond_data.iter())
                    .map(|(&uncond, &cond)| uncond + guidance_scale * (cond - uncond))
                    .collect();
                if dynamic_thresholding {
                    guided_logits =
                        self.apply_dynamic_thresholding(guided_logits, threshold_percentile)?;
                }
                Tensor::from_vec(guided_logits, &[cond_data.len()])
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for CFG guidance",
                "apply_cfg_guidance",
            )),
        }
    }

    /// Clamps every logit into `[-t, t]` where `t` is the `percentile`-th
    /// value of the sorted absolute logits.
    ///
    /// Fixes over the original implementation:
    /// - empty `logits` are returned unchanged instead of underflowing on
    ///   `len() - 1`;
    /// - sorting uses `f32::total_cmp`, which is a total order, so a NaN
    ///   logit no longer panics the sort.
    fn apply_dynamic_thresholding(
        &self,
        mut logits: Vec<f32>,
        percentile: f32,
    ) -> Result<Vec<f32>> {
        // Nothing to threshold; also guards the `len() - 1` index below
        // against underflow.
        if logits.is_empty() {
            return Ok(logits);
        }
        let mut sorted_abs_logits: Vec<f32> = logits.iter().map(|&x| x.abs()).collect();
        // `total_cmp` orders NaN deterministically instead of panicking.
        sorted_abs_logits.sort_by(|a, b| a.total_cmp(b));
        let threshold_idx = ((sorted_abs_logits.len() as f32 * percentile) as usize)
            .min(sorted_abs_logits.len() - 1);
        let threshold = sorted_abs_logits[threshold_idx];
        for logit in &mut logits {
            *logit = logit.clamp(-threshold, threshold);
        }
        Ok(logits)
    }

    /// Generates tokens while steering *away* from a negative prompt.
    ///
    /// Each step evaluates both a positive and a negative logits function
    /// (with independent KV caches) and samples from
    /// `positive - negative_scale * negative`.
    ///
    /// # Errors
    /// Propagates errors from the logits callbacks, guidance combination,
    /// and token sampling.
    pub fn generate_with_negative_prompt(
        &self,
        input_ids: &[usize],
        positive_logits_fn: impl Fn(&[usize], Option<&KVCache>) -> Result<(Tensor, Option<KVCache>)>,
        negative_logits_fn: impl Fn(&[usize], Option<&KVCache>) -> Result<(Tensor, Option<KVCache>)>,
        negative_scale: f32,
    ) -> Result<Vec<Vec<usize>>> {
        let mut sequences = vec![input_ids.to_vec()];
        // Independent caches: positive and negative passes have different contexts.
        let mut positive_cache =
            if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None };
        let mut negative_cache =
            if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None };
        let max_length = self.base_generator.get_max_length(input_ids.len());
        for step in 0..max_length {
            let (positive_logits, new_positive_cache) =
                positive_logits_fn(&sequences[0], positive_cache.as_ref())?;
            positive_cache = new_positive_cache;
            let (negative_logits, new_negative_cache) =
                negative_logits_fn(&sequences[0], negative_cache.as_ref())?;
            negative_cache = new_negative_cache;
            let guided_logits =
                self.apply_negative_guidance(&positive_logits, &negative_logits, negative_scale)?;
            let next_token = self.base_generator.sample_token(&guided_logits)?;
            sequences[0].push(next_token);
            if self.base_generator.should_stop(&sequences[0], next_token, step + 1) {
                break;
            }
        }
        Ok(sequences)
    }

    /// Computes `positive - negative_scale * negative` element-wise.
    ///
    /// # Errors
    /// Returns a tensor-op error if the two logit vectors differ in length
    /// or if either tensor is not `F32`.
    fn apply_negative_guidance(
        &self,
        positive_logits: &Tensor,
        negative_logits: &Tensor,
        negative_scale: f32,
    ) -> Result<Tensor> {
        match (positive_logits, negative_logits) {
            (Tensor::F32(pos_arr), Tensor::F32(neg_arr)) => {
                let pos_data: Vec<f32> = pos_arr.iter().cloned().collect();
                let neg_data: Vec<f32> = neg_arr.iter().cloned().collect();
                if pos_data.len() != neg_data.len() {
                    return Err(TrustformersError::tensor_op_error(
                        "Positive and negative logits must have same length",
                        "apply_negative_guidance",
                    ));
                }
                let guided_logits: Vec<f32> = pos_data
                    .iter()
                    .zip(neg_data.iter())
                    .map(|(&pos, &neg)| pos - negative_scale * neg)
                    .collect();
                Tensor::from_vec(guided_logits, &[pos_data.len()])
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for negative guidance",
                "apply_negative_guidance",
            )),
        }
    }

    /// Generates tokens with CFG over multiple conditions, each with its own
    /// guidance scale and its own KV cache.
    ///
    /// The guided logits are
    /// `uncond + sum_i(scale_i * (cond_i - uncond))`.
    ///
    /// # Errors
    /// Returns an invalid-input error when `condition_logits_fns` and
    /// `condition_scales` have different lengths; otherwise propagates errors
    /// from the logits callbacks, guidance combination, and sampling.
    pub fn generate_with_multi_condition_cfg(
        &self,
        input_ids: &[usize],
        condition_logits_fns: Vec<
            Box<dyn Fn(&[usize], Option<&KVCache>) -> Result<(Tensor, Option<KVCache>)>>,
        >,
        condition_scales: Vec<f32>,
        unconditional_logits_fn: impl Fn(
            &[usize],
            Option<&KVCache>,
        ) -> Result<(Tensor, Option<KVCache>)>,
    ) -> Result<Vec<Vec<usize>>> {
        if condition_logits_fns.len() != condition_scales.len() {
            return Err(TrustformersError::invalid_input(
                "Number of condition functions must match number of scales".to_string(),
            ));
        }
        let mut sequences = vec![input_ids.to_vec()];
        // One cache per condition, plus one for the unconditional pass.
        let mut condition_caches: Vec<Option<KVCache>> = (0..condition_logits_fns.len())
            .map(|_| if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None })
            .collect();
        let mut unconditional_cache =
            if self.base_generator.config.use_cache { Some(KVCache::new()) } else { None };
        let max_length = self.base_generator.get_max_length(input_ids.len());
        for step in 0..max_length {
            let (unconditional_logits, new_unconditional_cache) =
                unconditional_logits_fn(&sequences[0], unconditional_cache.as_ref())?;
            unconditional_cache = new_unconditional_cache;
            let mut condition_logits = Vec::new();
            for (i, condition_fn) in condition_logits_fns.iter().enumerate() {
                let (logits, new_cache) =
                    condition_fn(&sequences[0], condition_caches[i].as_ref())?;
                condition_caches[i] = new_cache;
                condition_logits.push(logits);
            }
            let guided_logits = self.apply_multi_condition_cfg(
                &unconditional_logits,
                &condition_logits,
                &condition_scales,
            )?;
            let next_token = self.base_generator.sample_token(&guided_logits)?;
            sequences[0].push(next_token);
            if self.base_generator.should_stop(&sequences[0], next_token, step + 1) {
                break;
            }
        }
        Ok(sequences)
    }

    /// Accumulates `scale_i * (cond_i - uncond)` for every condition on top
    /// of the unconditional logits.
    ///
    /// # Errors
    /// Returns a tensor-op error if any condition tensor differs in length
    /// from the unconditional logits, or if any tensor is not `F32`.
    fn apply_multi_condition_cfg(
        &self,
        unconditional_logits: &Tensor,
        condition_logits: &[Tensor],
        condition_scales: &[f32],
    ) -> Result<Tensor> {
        match unconditional_logits {
            Tensor::F32(uncond_arr) => {
                let uncond_data: Vec<f32> = uncond_arr.iter().cloned().collect();
                let mut guided_logits = uncond_data.clone();
                for (condition_tensor, &scale) in
                    condition_logits.iter().zip(condition_scales.iter())
                {
                    match condition_tensor {
                        Tensor::F32(cond_arr) => {
                            let cond_data: Vec<f32> = cond_arr.iter().cloned().collect();
                            if cond_data.len() != uncond_data.len() {
                                return Err(TrustformersError::tensor_op_error(
                                    "All logits must have same length",
                                    "apply_multi_condition_cfg",
                                ));
                            }
                            for (i, &cond) in cond_data.iter().enumerate() {
                                guided_logits[i] += scale * (cond - uncond_data[i]);
                            }
                        },
                        _ => {
                            return Err(TrustformersError::tensor_op_error(
                                "Unsupported tensor type for condition",
                                "apply_multi_condition_cfg",
                            ))
                        },
                    }
                }
                Tensor::from_vec(guided_logits, &[uncond_data.len()])
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for unconditional logits",
                "apply_multi_condition_cfg",
            )),
        }
    }
}