use alaz_core::{AlazError, Result};
use qdrant_client::Qdrant;
use qdrant_client::qdrant::{
Condition, Filter, PointStruct, SearchPointsBuilder, UpsertPointsBuilder,
};
use tracing::debug;
use crate::client::{COLLECTION_COLBERT, point_id};
pub(crate) fn average_embedding(tokens: &[Vec<f32>]) -> Vec<f32> {
if tokens.is_empty() {
return vec![];
}
let dim = tokens[0].len();
let mut avg = vec![0.0_f32; dim];
let count = tokens.len() as f32;
for token in tokens {
for (i, &val) in token.iter().enumerate() {
if i < dim {
avg[i] += val;
}
}
}
for v in &mut avg {
*v /= count;
}
avg
}
pub(crate) use alaz_core::cosine_similarity;
pub(crate) fn max_sim(query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
query_tokens
.iter()
.map(|qt| {
doc_tokens
.iter()
.map(|dt| cosine_similarity(qt, dt))
.fold(f32::NEG_INFINITY, f32::max)
})
.sum()
}
pub struct ColbertOps;
impl ColbertOps {
pub async fn upsert_colbert(
client: &Qdrant,
entity_type: &str,
entity_id: &str,
project_id: Option<&str>,
token_embeddings: Vec<Vec<f32>>,
) -> Result<()> {
if token_embeddings.is_empty() {
return Err(AlazError::Validation(
"token_embeddings cannot be empty".into(),
));
}
let pid = point_id(entity_type, entity_id);
let avg = average_embedding(&token_embeddings);
let tokens_json = serde_json::to_string(&token_embeddings)
.map_err(|e| AlazError::Qdrant(format!("failed to serialize token embeddings: {e}")))?;
let mut payload = qdrant_client::Payload::new();
payload.insert("entity_type", entity_type);
payload.insert("entity_id", entity_id);
payload.insert("token_vectors", tokens_json.as_str());
if let Some(project) = project_id {
payload.insert("project_id", project);
}
let point = PointStruct::new(pid, avg, payload);
client
.upsert_points(UpsertPointsBuilder::new(COLLECTION_COLBERT, vec![point]).wait(true))
.await
.map_err(|e| {
AlazError::Qdrant(format!(
"failed to upsert ColBERT {entity_type}:{entity_id}: {e}"
))
})?;
debug!(
entity_type,
entity_id,
num_tokens = token_embeddings.len(),
"upserted ColBERT embeddings"
);
Ok(())
}
pub async fn search_colbert(
client: &Qdrant,
query_tokens: Vec<Vec<f32>>,
project: Option<&str>,
limit: usize,
) -> Result<Vec<(String, String, f32)>> {
if query_tokens.is_empty() {
return Ok(vec![]);
}
let avg_query = average_embedding(&query_tokens);
let candidate_limit = (limit * 4).max(20) as u64;
let mut builder = SearchPointsBuilder::new(COLLECTION_COLBERT, avg_query, candidate_limit)
.with_payload(true);
if let Some(project_id) = project {
let filter = Filter::must([Condition::matches("project_id", project_id.to_string())]);
builder = builder.filter(filter);
}
let response = client
.search_points(builder)
.await
.map_err(|e| AlazError::Qdrant(format!("failed to search ColBERT collection: {e}")))?;
let mut scored: Vec<(String, String, f32)> = response
.result
.into_iter()
.filter_map(|point| {
let entity_type = point
.payload
.get("entity_type")
.and_then(|v| v.as_str())
.map(|s| s.to_string())?;
let entity_id = point
.payload
.get("entity_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())?;
let token_vectors_json = point
.payload
.get("token_vectors")
.and_then(|v| v.as_str())?;
let doc_tokens: Vec<Vec<f32>> = serde_json::from_str(token_vectors_json).ok()?;
let score = max_sim(&query_tokens, &doc_tokens);
Some((entity_type, entity_id, score))
})
.collect();
scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn average_embedding_empty_input() {
let result = average_embedding(&[]);
assert!(result.is_empty());
}
#[test]
fn average_embedding_single_vector() {
let tokens = vec![vec![1.0, 2.0, 3.0]];
let avg = average_embedding(&tokens);
assert_eq!(avg, vec![1.0, 2.0, 3.0]);
}
#[test]
fn average_embedding_multiple_vectors() {
let tokens = vec![vec![1.0, 0.0, 3.0], vec![3.0, 4.0, 1.0]];
let avg = average_embedding(&tokens);
assert!((avg[0] - 2.0).abs() < 0.001);
assert!((avg[1] - 2.0).abs() < 0.001);
assert!((avg[2] - 2.0).abs() < 0.001);
}
#[test]
fn cosine_identical_vectors_returns_one() {
let v = vec![1.0, 2.0, 3.0];
let sim = cosine_similarity(&v, &v);
assert!((sim - 1.0).abs() < 0.001, "expected ~1.0, got {sim}");
}
#[test]
fn cosine_orthogonal_vectors_returns_zero() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let sim = cosine_similarity(&a, &b);
assert!(sim.abs() < 0.001, "expected ~0.0, got {sim}");
}
#[test]
fn cosine_opposite_vectors_returns_neg_one() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
let sim = cosine_similarity(&a, &b);
assert!((sim - (-1.0)).abs() < 0.001, "expected ~-1.0, got {sim}");
}
#[test]
fn cosine_zero_vector_returns_zero() {
let a = vec![1.0, 2.0, 3.0];
let zero = vec![0.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &zero)).abs() < 0.001);
assert!((cosine_similarity(&zero, &a)).abs() < 0.001);
assert!((cosine_similarity(&zero, &zero)).abs() < 0.001);
}
#[test]
fn max_sim_single_tokens() {
let query = vec![vec![1.0, 0.0]];
let doc = vec![vec![1.0, 0.0]];
let score = max_sim(&query, &doc);
assert!((score - 1.0).abs() < 0.001, "expected ~1.0, got {score}");
}
#[test]
fn max_sim_multiple_tokens() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let score = max_sim(&query, &doc);
assert!((score - 2.0).abs() < 0.001, "expected ~2.0, got {score}");
}
#[test]
fn point_id_deterministic() {
let id1 = point_id("knowledge_item", "abc123");
let id2 = point_id("knowledge_item", "abc123");
assert_eq!(id1, id2);
}
#[test]
fn point_id_different_inputs_differ() {
let id1 = point_id("knowledge_item", "abc123");
let id2 = point_id("episode", "abc123");
let id3 = point_id("knowledge_item", "xyz789");
assert_ne!(id1, id2);
assert_ne!(id1, id3);
}
}