use crate::error::{Result, TextError};
use scirs2_core::ndarray::Array1;
use std::path::Path;
/// Tag naming an STS dataset layout.
///
/// NOTE(review): the variant names presumably refer to the STS Benchmark,
/// SICK, and SemEval STS 2012-2016 collections respectively -- confirm.
/// Nothing in the visible part of this file reads this enum; the TSV loader
/// picks its column layout from the number of fields per row instead.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StsDatasetFormat {
/// Presumably the STS Benchmark layout -- TODO confirm.
StsB,
/// Presumably the SICK layout -- TODO confirm.
Sick,
/// Presumably the STS 2012-2016 layout -- TODO confirm.
Sts12to16,
}
/// Result of an STS evaluation run, as produced by `sts_evaluate`.
#[derive(Debug, Clone)]
pub struct StsReport {
/// Pearson correlation between `predictions` and `gold_labels`.
pub pearson: f32,
/// Spearman rank correlation between `predictions` and `gold_labels`.
pub spearman: f32,
/// Mean squared difference between `predictions` and `gold_labels`,
/// computed on the raw values (no rescaling is applied to either side).
pub mse: f32,
/// Predicted similarity for each pair, in input order (cosine similarity
/// of the two sentence embeddings when filled by `sts_evaluate`).
pub predictions: Vec<f32>,
/// Gold similarity score for each pair, in input order.
pub gold_labels: Vec<f32>,
/// Number of sentence pairs that were evaluated.
pub n_pairs: usize,
}
/// Loaded STS rows: `(tokens of sentence 1, tokens of sentence 2, gold score)`.
type StsPairs = Vec<(Vec<String>, Vec<String>, f32)>;
pub fn load_sts_from_tsv(path: impl AsRef<Path>) -> Result<StsPairs> {
use std::fs::File;
use std::io::{BufRead, BufReader};
let file = File::open(path.as_ref()).map_err(|e| TextError::IoError(e.to_string()))?;
let reader = BufReader::new(file);
let mut pairs = Vec::new();
for line in reader.lines() {
let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
let line = line.trim();
if line.is_empty() {
continue;
}
let fields: Vec<&str> = line.split('\t').collect();
let (score_str, s1, s2) = if fields.len() >= 8 {
(fields[4], fields[5], fields[6])
} else if fields.len() >= 3 {
(fields[0], fields[1], fields[2])
} else {
continue;
};
let score: f32 = match score_str.trim().parse() {
Ok(v) => v,
Err(_) => continue, };
let tokens1: Vec<String> = s1.split_whitespace().map(str::to_owned).collect();
let tokens2: Vec<String> = s2.split_whitespace().map(str::to_owned).collect();
pairs.push((tokens1, tokens2, score));
}
Ok(pairs)
}
/// Pearson correlation coefficient of `x` and `y`.
///
/// Returns `0.0` when `x` is empty or when either series has zero spread
/// (all values equal), so the result is never the product of a division by
/// zero.
///
/// The sample count is taken from `x`, and paired terms stop at the shorter
/// slice, so callers are expected to pass slices of equal length.
fn pearson_correlation(x: &[f32], y: &[f32]) -> f32 {
    if x.is_empty() {
        return 0.0;
    }

    let count = x.len() as f32;
    let mean_x = x.iter().sum::<f32>() / count;
    let mean_y = y.iter().sum::<f32>() / count;

    let covariance: f32 = x
        .iter()
        .zip(y)
        .map(|(xi, yi)| (xi - mean_x) * (yi - mean_y))
        .sum();

    // Square root of the summed squared deviations from the mean.
    let spread = |values: &[f32], mean: f32| -> f32 {
        values
            .iter()
            .map(|v| (v - mean).powi(2))
            .sum::<f32>()
            .sqrt()
    };
    let spread_x = spread(x, mean_x);
    let spread_y = spread(y, mean_y);

    if spread_x == 0.0 || spread_y == 0.0 {
        return 0.0;
    }
    covariance / (spread_x * spread_y)
}
/// Spearman rank correlation of `x` and `y`.
///
/// Both series are converted to 1-based ranks (tied values share the average
/// of the ranks they span) and the Pearson correlation of the two rank
/// vectors is returned.
fn spearman_correlation(x: &[f32], y: &[f32]) -> f32 {
    /// Returns the 1-based rank of every element of `v`, in the original
    /// element order. Equal values receive the mean of their rank positions.
    fn rank(v: &[f32]) -> Vec<f32> {
        let mut indexed: Vec<(usize, f32)> = v.iter().copied().enumerate().collect();
        // `total_cmp` is a genuine total order, so the sort stays well-defined
        // (and cannot panic on an inconsistent comparator) when NaN is present.
        // For NaN-free input the resulting ranks match a `partial_cmp` sort:
        // the only other difference is that -0.0 orders before +0.0, and the
        // tie grouping below still merges them because they compare equal.
        indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

        let mut ranks = vec![0.0f32; v.len()];
        let mut i = 0;
        while i < indexed.len() {
            let val = indexed[i].1;
            // Extend `j` past every element tied with `val`. NaN never equals
            // itself, so each NaN forms its own single-element group.
            let mut j = i + 1;
            while j < indexed.len() && indexed[j].1 == val {
                j += 1;
            }
            // Sorted positions i..j hold 1-based ranks i+1..=j; their mean is
            // (i + 1 + j) / 2.
            let avg_rank = (i + j + 1) as f32 / 2.0;
            for &(original_index, _) in &indexed[i..j] {
                ranks[original_index] = avg_rank;
            }
            i = j;
        }
        ranks
    }

    let rx = rank(x);
    let ry = rank(y);
    pearson_correlation(&rx, &ry)
}
/// Scores an embedding function on a set of gold-labelled sentence pairs.
///
/// For every pair, both token sequences are embedded with `embed_fn` (first
/// sentence, then second) and the cosine similarity of the two vectors is
/// recorded as the prediction; a zero-norm embedding on either side yields a
/// prediction of `0.0`. The report contains the Pearson and Spearman
/// correlations between predictions and gold scores, plus the mean squared
/// difference between the raw cosine values and the raw gold scores.
///
/// # Errors
///
/// Returns `TextError::InvalidInput` when `pairs` is empty.
pub fn sts_evaluate(
    embed_fn: &dyn Fn(&[String]) -> Array1<f32>,
    pairs: &[(Vec<String>, Vec<String>, f32)],
) -> Result<StsReport> {
    if pairs.is_empty() {
        return Err(TextError::InvalidInput(
            "STS dataset is empty; at least one pair is required".into(),
        ));
    }

    // Cosine similarity of the two sentence embeddings, 0.0 if either is null.
    let cosine_of = |left: &[String], right: &[String]| -> f32 {
        let a = embed_fn(left);
        let b = embed_fn(right);
        let inner = a.dot(&b);
        let norm_a = a.dot(&a).sqrt();
        let norm_b = b.dot(&b).sqrt();
        if norm_a != 0.0 && norm_b != 0.0 {
            inner / (norm_a * norm_b)
        } else {
            0.0
        }
    };

    let predictions: Vec<f32> = pairs
        .iter()
        .map(|(first, second, _)| cosine_of(first.as_slice(), second.as_slice()))
        .collect();
    let gold_labels: Vec<f32> = pairs.iter().map(|(_, _, gold)| *gold).collect();

    let squared_error: f32 = predictions
        .iter()
        .zip(gold_labels.iter())
        .map(|(predicted, gold)| (predicted - gold).powi(2))
        .sum();
    let mse = squared_error / predictions.len() as f32;

    let pearson = pearson_correlation(&predictions, &gold_labels);
    let spearman = spearman_correlation(&predictions, &gold_labels);

    Ok(StsReport {
        pearson,
        spearman,
        mse,
        predictions,
        gold_labels,
        n_pairs: pairs.len(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Toy embedding: adds one to slot `position % dim` for every token
    /// position. The token text itself is ignored.
    fn bow_embed(tokens: &[String], dim: usize) -> Array1<f32> {
        let mut counts = Array1::zeros(dim);
        for position in 0..tokens.len() {
            counts[position % dim] += 1.0;
        }
        counts
    }

    /// Builds an owned token list from string literals.
    fn sentence(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    /// Runs `sts_evaluate` with the 4-dimensional toy embedding.
    fn evaluate_with_bow(pairs: &[(Vec<String>, Vec<String>, f32)]) -> Result<StsReport> {
        sts_evaluate(&|t| bow_embed(t, 4), pairs)
    }

    #[test]
    fn sts_empty_returns_error() {
        assert!(evaluate_with_bow(&[]).is_err());
    }

    #[test]
    fn sts_single_pair_identical_tokens() {
        let data = vec![(sentence(&["cat"]), sentence(&["cat"]), 5.0f32)];
        let report = evaluate_with_bow(&data).expect("evaluate");
        assert_eq!(report.n_pairs, 1);
        // Identical sentences embed to the same vector, so cosine is 1.
        let deviation = (report.predictions[0] - 1.0).abs();
        assert!(deviation < 1e-5);
    }

    #[test]
    fn sts_mse_is_non_negative() {
        let data = vec![
            (sentence(&["a"]), sentence(&["b"]), 2.5f32),
            (sentence(&["c"]), sentence(&["c"]), 4.0f32),
        ];
        let report = evaluate_with_bow(&data).expect("evaluate");
        assert!(report.mse >= 0.0);
    }
}