use crate::embeddings::EmbeddingProvider;
use crate::episode::Episode;
use crate::types::TaskType;
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
/// Embedding provider wrapper that can specialize base embeddings per task.
///
/// Holds a shared base [`EmbeddingProvider`] plus a map of per-task linear
/// adapters; adapters are added by `train_adapter` and looked up by
/// `get_adapted_embedding`.
#[derive(Clone)]
pub struct ContextAwareEmbeddings {
    // Shared, task-agnostic embedding backend (cheap to clone via Arc).
    base_embeddings: Arc<dyn EmbeddingProvider>,
    // One trained linear adapter per task type; empty until trained.
    task_adapters: HashMap<TaskType, TaskAdapter>,
}
/// A learned linear transformation that specializes embeddings for one task.
#[derive(Debug, Clone)]
pub struct TaskAdapter {
    /// Task this adapter was trained for.
    pub task_type: TaskType,
    /// Square row-major matrix M; `adapt` computes `out[i] = Σ_j x[j] * M[j][i]`
    /// (i.e. Mᵀ·x). Initialized to the identity before training.
    pub adaptation_matrix: Vec<Vec<f32>>,
    /// Number of contrastive pairs used during training (0 = untrained).
    pub trained_on_count: usize,
}
/// Euclidean (L2) distance between two vectors.
///
/// NOTE: `zip` silently truncates to the shorter slice if lengths differ.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_of_squares.sqrt()
}
impl TaskAdapter {
    /// Builds an adapter whose matrix is the identity, i.e. one that leaves
    /// embeddings unchanged until it is trained.
    fn new_identity(task_type: TaskType, dimension: usize) -> Self {
        // Identity: row i holds a single 1.0 at column i.
        let matrix = (0..dimension)
            .map(|i| {
                let mut row = vec![0.0; dimension];
                row[i] = 1.0;
                row
            })
            .collect();
        Self {
            task_type,
            adaptation_matrix: matrix,
            trained_on_count: 0,
        }
    }
    /// Applies the adaptation matrix to `base_embedding`, returning the
    /// task-specialized embedding.
    ///
    /// Signature kept `Vec`-consuming for backward compatibility; prefer
    /// [`TaskAdapter::apply`] when you only hold a slice.
    #[must_use]
    pub fn adapt(&self, base_embedding: Vec<f32>) -> Vec<f32> {
        self.transform(&base_embedding)
    }
    /// Slice-based variant of [`TaskAdapter::adapt`]. Previously this cloned
    /// the slice via `to_vec()` just to call `adapt`; now both methods share
    /// an allocation-free core.
    #[must_use]
    pub fn apply(&self, base_embedding: &[f32]) -> Vec<f32> {
        self.transform(base_embedding)
    }
    /// Core matrix-vector product: `out[i] = Σ_j x[j] * M[j][i]` (i.e. Mᵀ·x),
    /// matching the original indexing convention.
    fn transform(&self, x: &[f32]) -> Vec<f32> {
        let mut adapted = vec![0.0; x.len()];
        // Row-wise iteration zipped with the input components avoids
        // per-element bounds checks; `zip` truncates (rather than panics)
        // if the matrix dimension and `x.len()` disagree.
        for (xj, row) in x.iter().zip(&self.adaptation_matrix) {
            for (cell, m) in adapted.iter_mut().zip(row) {
                *cell += xj * m;
            }
        }
        adapted
    }
}
/// One training triplet for contrastive adapter training.
///
/// The adapter is trained to pull `positive` closer to `anchor` than
/// `negative` is, by at least the training margin.
#[derive(Debug, Clone)]
pub struct ContrastivePair {
    /// Reference episode.
    pub anchor: Episode,
    /// Episode that should be embedded near the anchor.
    pub positive: Episode,
    /// Episode that should be embedded far from the anchor.
    pub negative: Episode,
}
impl ContextAwareEmbeddings {
    /// Accumulates into `gradient` the triplet-loss contribution of one
    /// (anchor, positive, negative) triple under the current `adapter`.
    ///
    /// The hinge condition uses non-squared Euclidean distances; the
    /// accumulated term is the gradient of the *squared* distances (the
    /// constant factor 2 and the 1/d normalization are absorbed into the
    /// learning rate).
    fn update_gradient_for_triplet(
        adapter: &TaskAdapter,
        anchor_emb: &[f32],
        positive_emb: &[f32],
        negative_emb: &[f32],
        gradient: &mut [Vec<f32>],
        dim: usize,
        margin: f32,
    ) {
        let anchor_adapted = adapter.adapt(anchor_emb.to_vec());
        let positive_adapted = adapter.adapt(positive_emb.to_vec());
        let negative_adapted = adapter.adapt(negative_emb.to_vec());
        let d_pos = euclidean_distance(&anchor_adapted, &positive_adapted);
        let d_neg = euclidean_distance(&anchor_adapted, &negative_adapted);
        // Standard triplet hinge: only violating triples (positive not at
        // least `margin` closer than negative) contribute a gradient.
        let loss = (d_pos - d_neg + margin).max(0.0);
        if loss > 0.0 {
            // `adapt` computes adapted[i] = Σ_j x[j] * M[j][i], so the
            // derivative w.r.t. M[j][i] pairs the j-th *input* component with
            // the i-th *adapted* component. Accumulate it into gradient[j][i]
            // so that `apply_gradient_update` (matrix[i][j] -= lr *
            // gradient[i][j]) touches the matching cell. The previous code
            // stored this at gradient[i][j], applying the TRANSPOSED gradient
            // to the matrix.
            for (j, grad_row) in gradient.iter_mut().enumerate().take(dim) {
                for (i, grad_cell) in grad_row.iter_mut().enumerate().take(dim) {
                    let grad_pos = (anchor_emb[j] - positive_emb[j])
                        * (anchor_adapted[i] - positive_adapted[i]);
                    let grad_neg = (anchor_emb[j] - negative_emb[j])
                        * (anchor_adapted[i] - negative_adapted[i]);
                    *grad_cell += grad_pos - grad_neg;
                }
            }
        }
    }
    /// One batch-SGD step: matrix -= lr * (gradient averaged over `num_pairs`).
    fn apply_gradient_update(
        adaptation_matrix: &mut [Vec<f32>],
        gradient: &[Vec<f32>],
        learning_rate: f32,
        num_pairs: usize,
        dim: usize,
    ) {
        // Caller guarantees num_pairs > 0 (train_adapter bails on empty input),
        // so the division is well-defined.
        for (i, matrix_row) in adaptation_matrix.iter_mut().enumerate().take(dim) {
            for (j, matrix_cell) in matrix_row.iter_mut().enumerate().take(dim) {
                *matrix_cell -= learning_rate * gradient[i][j] / num_pairs as f32;
            }
        }
    }
    /// Wraps a base embedding provider with an (initially empty) set of
    /// per-task adapters.
    #[must_use]
    pub fn new(base_embeddings: Arc<dyn EmbeddingProvider>) -> Self {
        Self {
            base_embeddings,
            task_adapters: HashMap::new(),
        }
    }
    /// Embeds `text` and, if a trained adapter exists for `task_type`,
    /// applies it; otherwise returns the raw base embedding.
    ///
    /// # Errors
    /// Propagates any error from the underlying embedding provider.
    pub async fn get_adapted_embedding(
        &self,
        text: &str,
        task_type: Option<TaskType>,
    ) -> Result<Vec<f32>> {
        let base_embedding = self.base_embeddings.embed_text(text).await?;
        if let Some(task) = task_type {
            if let Some(adapter) = self.task_adapters.get(&task) {
                return Ok(adapter.adapt(base_embedding));
            }
        }
        Ok(base_embedding)
    }
    /// Raw (un-adapted) base embedding for `text`.
    ///
    /// # Errors
    /// Propagates any error from the underlying embedding provider.
    pub async fn get_embedding(&self, text: &str) -> Result<Vec<f32>> {
        self.base_embeddings.embed_text(text).await
    }
    /// Trains (or retrains) the adapter for `task_type` on contrastive
    /// triplets by minimizing a triplet margin loss with full-batch gradient
    /// descent, then installs it (replacing any previous adapter).
    ///
    /// # Errors
    /// Fails if `contrastive_pairs` is empty or if embedding any episode's
    /// task description fails.
    pub async fn train_adapter(
        &mut self,
        task_type: TaskType,
        contrastive_pairs: &[ContrastivePair],
    ) -> Result<()> {
        if contrastive_pairs.is_empty() {
            anyhow::bail!("Cannot train adapter with empty training set");
        }
        let dim = self.base_embeddings.embedding_dimension();
        let mut adapter = TaskAdapter::new_identity(task_type, dim);
        const LEARNING_RATE: f32 = 0.01;
        const EPOCHS: usize = 100;
        const MARGIN: f32 = 0.5;
        // Embed every triplet exactly once up front; the training epochs then
        // run purely offline on the cached embeddings.
        let mut embedded_pairs = Vec::with_capacity(contrastive_pairs.len());
        for pair in contrastive_pairs {
            let anchor_emb = self
                .base_embeddings
                .embed_text(&pair.anchor.task_description)
                .await?;
            let positive_emb = self
                .base_embeddings
                .embed_text(&pair.positive.task_description)
                .await?;
            let negative_emb = self
                .base_embeddings
                .embed_text(&pair.negative.task_description)
                .await?;
            embedded_pairs.push((anchor_emb, positive_emb, negative_emb));
        }
        for _epoch in 0..EPOCHS {
            // Fresh full-batch gradient each epoch.
            let mut gradient = vec![vec![0.0; dim]; dim];
            for (anchor_emb, positive_emb, negative_emb) in &embedded_pairs {
                Self::update_gradient_for_triplet(
                    &adapter,
                    anchor_emb,
                    positive_emb,
                    negative_emb,
                    &mut gradient,
                    dim,
                    MARGIN,
                );
            }
            Self::apply_gradient_update(
                &mut adapter.adaptation_matrix,
                &gradient,
                LEARNING_RATE,
                contrastive_pairs.len(),
                dim,
            );
        }
        adapter.trained_on_count = contrastive_pairs.len();
        self.task_adapters.insert(task_type, adapter);
        Ok(())
    }
    /// Whether a trained adapter exists for `task_type`.
    #[must_use]
    pub fn has_adapter(&self, task_type: TaskType) -> bool {
        self.task_adapters.contains_key(&task_type)
    }
    /// Number of trained adapters currently installed.
    #[must_use]
    pub fn adapter_count(&self) -> usize {
        self.task_adapters.len()
    }
    /// Borrows the adapter for `task_type`, if one has been trained.
    #[must_use]
    pub fn get_adapter(&self, task_type: TaskType) -> Option<&TaskAdapter> {
        self.task_adapters.get(&task_type)
    }
    /// Access to the wrapped base embedding provider.
    #[must_use]
    pub fn base_provider(&self) -> &Arc<dyn EmbeddingProvider> {
        &self.base_embeddings
    }
}
#[cfg(test)]
pub mod tests;