/// Returns the result of `x.mul_scalar(scale)`.
///
/// `x` is only borrowed; the result is returned as an owned `Tensor`.
// NOTE(review): presumably an element-wise multiplication by `scale` — confirm
// against the definition of `mul_scalar`.
fn scale_tensor(x: &Tensor, scale: f32) -> Tensor {
x.mul_scalar(scale)
}
/// Local shorthand that forwards `x` to `crate::nn::transformer::transpose_last_two`
/// and returns whatever that function produces.
// NOTE(review): judging by the name, this swaps the final two axes of `x` —
// the actual behaviour lives in the transformer module; confirm there.
fn transpose_last_two(x: &Tensor) -> Tensor {
crate::nn::transformer::transpose_last_two(x)
}
/// Local shorthand that forwards `a` and `b` (in that order) to
/// `crate::nn::transformer::matmul_batched` and returns its result.
// NOTE(review): shape requirements for `a` and `b` are defined by
// `matmul_batched`, which is not visible from this file section — confirm there.
fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
crate::nn::transformer::matmul_batched(a, b)
}
/// Local shorthand that forwards `x` and `dim` to `crate::nn::functional::softmax`
/// and returns its result.
// NOTE(review): `dim` is signed, so negative (from-the-end) axes may be
// accepted — that is decided by the functional implementation; confirm there.
fn softmax(x: &Tensor, dim: i32) -> Tensor {
crate::nn::functional::softmax(x, dim)
}
/// Zeroes a fixed, index-based subset of `x`'s values and rescales the rest.
///
/// * When `p <= 0.0` the input is returned unchanged (as a clone).
/// * Otherwise the value at position `i` of `x.data()` is replaced by `0.0`
///   when `(i % 100) / 100 < p`; every other value is multiplied by
///   `1 / (1 - p)`. The result keeps the shape of `x`.
///
/// Because `(i % 100) / 100` never exceeds `0.99`, any `p >= 1.0` zeroes
/// every value.
// NOTE(review): the mask depends only on the element index, not on a random
// source, so the same positions are dropped on every call. If stochastic
// dropout is intended this needs an RNG — left as-is to preserve behaviour.
fn dropout(x: &Tensor, p: f32) -> Tensor {
    if p <= 0.0 {
        return x.clone();
    }

    // Survivors are scaled up to compensate for the zeroed positions.
    let keep_scale = 1.0 / (1.0 - p);
    let values = x.data();

    let masked = values
        .iter()
        .enumerate()
        .map(|(idx, &value)| {
            // Position of this element inside its run of 100, mapped to [0, 1).
            let slot = (idx % 100) as f32 / 100.0;
            if slot < p {
                0.0
            } else {
                value * keep_scale
            }
        })
        .collect::<Vec<_>>();

    Tensor::new(&masked, x.shape())
}
/// Reshapes `x` through `crate::nn::transformer::reshape_from_attention`.
///
/// The embedding width passed to the reshape is the product of the entries
/// at index 1 and index 3 of `x.shape()`; `batch_size` and `seq_len` are
/// forwarded unchanged.
// NOTE(review): assumes `x` is laid out as (batch, heads, seq, head_dim), so
// that index 1 × index 3 is the embedding size — TODO confirm with callers.
// Indexing the shape will presumably panic if `x` has fewer than four axes.
fn concat_heads(x: &Tensor, batch_size: usize, seq_len: usize) -> Tensor {
    let dims = x.shape();
    let embed_dim = dims[1] * dims[3];
    crate::nn::transformer::reshape_from_attention(x, batch_size, seq_len, embed_dim)
}
/// Local shorthand that forwards `x` to `crate::nn::functional::gelu` and
/// returns its result.
// NOTE(review): whether the exact or the tanh-approximated GELU is used is
// decided by the functional implementation — confirm there.
fn gelu(x: &Tensor) -> Tensor {
crate::nn::functional::gelu(x)
}
/// Local shorthand that forwards `x`, `weight`, `bias` and `eps` (in that
/// order) to `crate::nn::functional::layer_norm` and returns its result.
// NOTE(review): presumably `weight`/`bias` are the learned scale and shift and
// `eps` the variance stabiliser; which axes are normalised is decided by the
// functional implementation — confirm there.
fn layer_norm(x: &Tensor, weight: &Tensor, bias: &Tensor, eps: f32) -> Tensor {
crate::nn::functional::layer_norm(x, weight, bias, eps)
}
/// Configuration for a temperature-scaled contrastive loss.
///
/// The only state is the temperature that similarity scores are divided by
/// before the loss is computed. Derives `Clone` so it can be duplicated like
/// the sibling `TripletLoss` configuration type.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Divisor applied to similarity scores in `forward`.
    temperature: f32,
}
impl ContrastiveLoss {
    /// Temperature used by [`ContrastiveLoss::new`].
    const DEFAULT_TEMPERATURE: f32 = 0.07;

    /// Creates a loss with the default temperature of `0.07`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_temperature(Self::DEFAULT_TEMPERATURE)
    }

    /// Creates a loss with a caller-chosen `temperature`.
    ///
    /// The value is stored as given; no validation is performed here.
    // NOTE(review): `forward` passes this to `div_scalar`, so a zero
    // temperature presumably yields non-finite values — confirm.
    #[must_use]
    pub fn with_temperature(temperature: f32) -> Self {
        Self { temperature }
    }

    /// Computes the loss for `anchor` against `positive` and a set of negatives.
    ///
    /// Steps, in order:
    /// 1. `cosine_similarity_batch(anchor, positive)` scaled by `div_scalar`
    ///    with the stored temperature.
    /// 2. Negative scores, also scaled by `div_scalar`:
    ///    * with `Some(negatives)`: `cosine_similarity_many(anchor, negatives)`;
    ///    * with `None`: `cosine_similarity_matrix(anchor, positive)`.
    /// 3. Both results are handed to `info_nce_loss`, whose output is returned.
    // NOTE(review): the `None` branch looks like the usual "in-batch negatives"
    // scheme (every other positive serves as a negative) — confirm against
    // `cosine_similarity_matrix` and `info_nce_loss`.
    #[must_use]
    pub fn forward(
        &self,
        anchor: &Tensor,
        positive: &Tensor,
        negatives: Option<&Tensor>,
    ) -> Tensor {
        let temperature = self.temperature;

        let positive_logits =
            div_scalar(&cosine_similarity_batch(anchor, positive), temperature);

        let negative_logits = match negatives {
            Some(negs) => div_scalar(&cosine_similarity_many(anchor, negs), temperature),
            None => div_scalar(&cosine_similarity_matrix(anchor, positive), temperature),
        };

        info_nce_loss(&positive_logits, &negative_logits)
    }
}
/// `Default` is defined in terms of `ContrastiveLoss::new`, so both
/// construction paths always agree.
impl Default for ContrastiveLoss {
/// Returns `Self::new()`.
fn default() -> Self {
Self::new()
}
}
/// Configuration for a triplet loss: a margin value plus the distance
/// selector to use.
// NOTE(review): how `margin` and `distance` enter the loss is defined by the
// type's methods, which are not visible in this part of the file.
#[derive(Debug, Clone)]
pub struct TripletLoss {
// Margin value; presumably the minimum required gap between the
// anchor–positive and anchor–negative distances — confirm in the impl.
margin: f32,
// Which `TripletDistance` variant to use when comparing embeddings.
distance: TripletDistance,
}
/// Selector for the distance used by `TripletLoss`.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` alongside
/// `PartialEq`; this lets it be used as a map/set key and in contexts that
/// require total equality.
// NOTE(review): the variant descriptions below follow the variant names; the
// code that interprets them is not visible here — confirm against it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TripletDistance {
    /// Presumably the L2 (Euclidean) distance.
    Euclidean,
    /// Presumably the squared L2 distance (no square root).
    SquaredEuclidean,
    /// Presumably a cosine-based distance.
    Cosine,
}