use std::collections::HashMap;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use crate::types::RankingConfig;
/// A single candidate configuration in the ranking-parameter search space.
///
/// Every field is an `f64` — including conceptually integer parameters like
/// `git_churn_threshold` and `focus_max_hops` — so a point can live in a
/// continuous vector space for sampling/optimization. Integer-valued fields
/// are rounded when converted back (see `to_ranking_config` / `focus_params`).
///
/// Fields (except the two focus parameters) mirror `RankingConfig` fields,
/// as established by the `Default` impl below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterPoint {
    // PageRank parameters.
    pub pagerank_alpha: f64,
    pub pagerank_chat_multiplier: f64,
    // Per-depth path weights.
    pub depth_weight_root: f64,
    pub depth_weight_moderate: f64,
    pub depth_weight_deep: f64,
    pub depth_weight_vendor: f64,
    // Multiplicative relevance boosts.
    pub boost_mentioned_ident: f64,
    pub boost_mentioned_file: f64,
    pub boost_chat_file: f64,
    pub boost_temporal_coupling: f64,
    pub boost_focus_expansion: f64,
    // Git-history signals.
    pub git_recency_decay_days: f64,
    pub git_recency_max_boost: f64,
    pub git_churn_threshold: f64, // rounded to usize in to_ranking_config
    pub git_churn_max_boost: f64,
    // Focus-expansion parameters; not part of RankingConfig (see focus_params).
    pub focus_decay: f64,
    pub focus_max_hops: f64, // rounded to usize in focus_params
}
impl ParameterPoint {
pub fn to_ranking_config(&self) -> RankingConfig {
let mut config = RankingConfig::default();
config.pagerank_alpha = self.pagerank_alpha;
config.pagerank_chat_multiplier = self.pagerank_chat_multiplier;
config.depth_weight_root = self.depth_weight_root;
config.depth_weight_moderate = self.depth_weight_moderate;
config.depth_weight_deep = self.depth_weight_deep;
config.depth_weight_vendor = self.depth_weight_vendor;
config.boost_mentioned_ident = self.boost_mentioned_ident;
config.boost_mentioned_file = self.boost_mentioned_file;
config.boost_chat_file = self.boost_chat_file;
config.boost_temporal_coupling = self.boost_temporal_coupling;
config.boost_focus_expansion = self.boost_focus_expansion;
config.git_recency_decay_days = self.git_recency_decay_days;
config.git_recency_max_boost = self.git_recency_max_boost;
config.git_churn_threshold = self.git_churn_threshold.round() as usize;
config.git_churn_max_boost = self.git_churn_max_boost;
config
}
pub fn focus_params(&self) -> (f64, usize) {
(self.focus_decay, self.focus_max_hops.round() as usize)
}
}
impl Default for ParameterPoint {
fn default() -> Self {
let config = RankingConfig::default();
Self {
pagerank_alpha: config.pagerank_alpha,
pagerank_chat_multiplier: config.pagerank_chat_multiplier,
depth_weight_root: config.depth_weight_root,
depth_weight_moderate: config.depth_weight_moderate,
depth_weight_deep: config.depth_weight_deep,
depth_weight_vendor: config.depth_weight_vendor,
boost_mentioned_ident: config.boost_mentioned_ident,
boost_mentioned_file: config.boost_mentioned_file,
boost_chat_file: config.boost_chat_file,
boost_temporal_coupling: config.boost_temporal_coupling,
boost_focus_expansion: config.boost_focus_expansion,
git_recency_decay_days: config.git_recency_decay_days,
git_recency_max_boost: config.git_recency_max_boost,
git_churn_threshold: config.git_churn_threshold as f64,
git_churn_max_boost: config.git_churn_max_boost,
focus_decay: 0.5,
focus_max_hops: 2.0,
}
}
}
/// An inclusive search range for one scalar parameter, mapped to and from a
/// normalized coordinate in `[0, 1]`.
#[derive(Debug, Clone)]
pub struct ParamRange {
    /// Lower bound (inclusive). Must be positive when `log_scale` is set.
    pub min: f64,
    /// Upper bound (inclusive).
    pub max: f64,
    /// When true, interpolation happens in log space (for multiplicative
    /// parameters whose useful values span orders of magnitude).
    pub log_scale: bool,
}

impl ParamRange {
    /// A range interpolated linearly between `min` and `max`.
    pub fn linear(min: f64, max: f64) -> Self {
        Self {
            min,
            max,
            log_scale: false,
        }
    }

    /// A range interpolated geometrically; `min` and `max` must be positive.
    pub fn log(min: f64, max: f64) -> Self {
        Self {
            min,
            max,
            log_scale: true,
        }
    }

    /// Map a normalized coordinate (clamped to `[0, 1]`) to a parameter value.
    pub fn decode(&self, normalized: f64) -> f64 {
        let t = normalized.clamp(0.0, 1.0);
        if self.log_scale {
            let log_min = self.min.ln();
            let log_max = self.max.ln();
            (log_min + t * (log_max - log_min)).exp()
        } else {
            self.min + t * (self.max - self.min)
        }
    }

    /// Inverse of `decode`: map a parameter value to `[0, 1]`.
    ///
    /// Out-of-range values are clamped into `[min, max]` on BOTH scales.
    /// (Previously only the log path clamped; linear `encode` of an
    /// out-of-range value returned a coordinate outside `[0, 1]`, breaking
    /// the encode/decode contract.)
    pub fn encode(&self, value: f64) -> f64 {
        let v = value.clamp(self.min, self.max);
        if self.log_scale {
            let log_min = self.min.ln();
            let log_max = self.max.ln();
            (v.ln() - log_min) / (log_max - log_min)
        } else {
            (v - self.min) / (self.max - self.min)
        }
    }
}
/// The full search space: one `ParamRange` per tunable parameter, keyed by
/// the corresponding `ParameterPoint` field name.
#[derive(Debug, Clone)]
pub struct ParameterGrid {
    // Keyed by parameter name; a deterministic dimension order is recovered
    // by sorting the keys in `param_names`.
    pub ranges: HashMap<String, ParamRange>,
}
impl Default for ParameterGrid {
fn default() -> Self {
let mut ranges = HashMap::new();
ranges.insert("pagerank_alpha".into(), ParamRange::linear(0.70, 0.95));
ranges.insert(
"pagerank_chat_multiplier".into(),
ParamRange::log(10.0, 200.0),
);
ranges.insert("depth_weight_root".into(), ParamRange::linear(0.5, 2.0));
ranges.insert("depth_weight_moderate".into(), ParamRange::linear(0.2, 1.0));
ranges.insert("depth_weight_deep".into(), ParamRange::linear(0.05, 0.5));
ranges.insert("depth_weight_vendor".into(), ParamRange::log(0.001, 0.1));
ranges.insert("boost_mentioned_ident".into(), ParamRange::log(2.0, 50.0));
ranges.insert("boost_mentioned_file".into(), ParamRange::log(2.0, 20.0));
ranges.insert("boost_chat_file".into(), ParamRange::log(5.0, 100.0));
ranges.insert("boost_temporal_coupling".into(), ParamRange::log(1.0, 10.0));
ranges.insert("boost_focus_expansion".into(), ParamRange::log(1.0, 20.0));
ranges.insert(
"git_recency_decay_days".into(),
ParamRange::linear(7.0, 90.0),
);
ranges.insert("git_recency_max_boost".into(), ParamRange::log(2.0, 20.0));
ranges.insert("git_churn_threshold".into(), ParamRange::linear(3.0, 15.0));
ranges.insert("git_churn_max_boost".into(), ParamRange::log(2.0, 15.0));
ranges.insert("focus_decay".into(), ParamRange::linear(0.2, 0.8));
ranges.insert("focus_max_hops".into(), ParamRange::linear(1.0, 3.0));
Self { ranges }
}
}
impl ParameterGrid {
    /// Decode a normalized vector — one coordinate per parameter, ordered by
    /// `param_names()` — into a concrete `ParameterPoint`.
    ///
    /// # Panics
    /// Panics when `normalized.len() != self.ndim()`.
    pub fn decode(&self, normalized: &[f64]) -> ParameterPoint {
        let names = self.param_names();
        assert_eq!(normalized.len(), names.len(), "Dimension mismatch");
        let mut values: HashMap<&str, f64> = HashMap::new();
        for (name, &coord) in names.iter().zip(normalized) {
            values.insert(name.as_str(), self.ranges[name].decode(coord));
        }
        // Small lookup helper; every key below is present because `names`
        // covers the full key set of `ranges`.
        let v = |key: &str| values[key];
        ParameterPoint {
            pagerank_alpha: v("pagerank_alpha"),
            pagerank_chat_multiplier: v("pagerank_chat_multiplier"),
            depth_weight_root: v("depth_weight_root"),
            depth_weight_moderate: v("depth_weight_moderate"),
            depth_weight_deep: v("depth_weight_deep"),
            depth_weight_vendor: v("depth_weight_vendor"),
            boost_mentioned_ident: v("boost_mentioned_ident"),
            boost_mentioned_file: v("boost_mentioned_file"),
            boost_chat_file: v("boost_chat_file"),
            boost_temporal_coupling: v("boost_temporal_coupling"),
            boost_focus_expansion: v("boost_focus_expansion"),
            git_recency_decay_days: v("git_recency_decay_days"),
            git_recency_max_boost: v("git_recency_max_boost"),
            git_churn_threshold: v("git_churn_threshold"),
            git_churn_max_boost: v("git_churn_max_boost"),
            focus_decay: v("focus_decay"),
            focus_max_hops: v("focus_max_hops"),
        }
    }

    /// Parameter names in a stable (sorted) order; defines dimension order.
    pub fn param_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.ranges.keys().cloned().collect();
        names.sort();
        names
    }

    /// Number of dimensions in the search space.
    pub fn ndim(&self) -> usize {
        self.ranges.len()
    }
}
/// How candidate points are drawn from the search space.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SearchStrategy {
    /// Exhaustive Cartesian grid with `points_per_dim` values per axis
    /// (total size grows as points_per_dim^ndim).
    Grid { points_per_dim: usize },
    /// Latin hypercube sampling: stratified coverage of every dimension.
    LatinHypercube,
    /// Uniform random sampling in normalized space.
    Random,
    /// Heuristic sequential search: `sample_points` seeds it with a small
    /// LHS batch (capped at 20); later points come from `bayesian_next_sample`.
    Bayesian,
}
pub fn sample_points(
grid: &ParameterGrid,
strategy: SearchStrategy,
n_samples: usize,
seed: u64,
) -> Vec<ParameterPoint> {
let mut rng = StdRng::seed_from_u64(seed);
match strategy {
SearchStrategy::Grid { points_per_dim } => sample_grid(grid, points_per_dim),
SearchStrategy::LatinHypercube => sample_lhs(grid, n_samples, &mut rng),
SearchStrategy::Random => sample_random(grid, n_samples, &mut rng),
SearchStrategy::Bayesian => {
sample_lhs(grid, n_samples.min(20), &mut rng)
}
}
}
/// Exhaustive Cartesian grid: `points_per_dim` evenly spaced normalized
/// values per dimension, `points_per_dim^ndim` points in total.
///
/// # Panics
/// Panics when the total grid size overflows `usize`. (The previous
/// unchecked `pow` wrapped silently in release builds for large grids —
/// e.g. 17 dimensions at 20 points each — yielding a truncated sample set.)
fn sample_grid(grid: &ParameterGrid, points_per_dim: usize) -> Vec<ParameterPoint> {
    let names = grid.param_names();
    let ndim = names.len();
    let total = points_per_dim
        .checked_pow(ndim as u32)
        .expect("grid size overflows usize; reduce points_per_dim or dimensions");
    (0..total)
        .map(|idx| {
            // Interpret `idx` as a base-`points_per_dim` number; each digit
            // selects the grid coordinate for one dimension.
            let mut normalized = Vec::with_capacity(ndim);
            let mut remaining = idx;
            for _ in 0..ndim {
                let dim_idx = remaining % points_per_dim;
                remaining /= points_per_dim;
                // Evenly spaced in [0, 1]; a single point sits at the middle.
                let t = if points_per_dim > 1 {
                    dim_idx as f64 / (points_per_dim - 1) as f64
                } else {
                    0.5
                };
                normalized.push(t);
            }
            grid.decode(&normalized)
        })
        .collect()
}
/// Latin hypercube sampling: each dimension is split into `n_samples`
/// strata; an independent random permutation per dimension assigns each
/// sample exactly one stratum per axis, then the point is jittered
/// uniformly within its stratum.
fn sample_lhs<R: Rng>(grid: &ParameterGrid, n_samples: usize, rng: &mut R) -> Vec<ParameterPoint> {
    let ndim = grid.ndim();
    // One shuffled stratum assignment per dimension. (Was `let mut`, but the
    // binding is never mutated after construction — fixed the unused_mut.)
    let strata: Vec<Vec<usize>> = (0..ndim)
        .map(|_| {
            let mut perm: Vec<usize> = (0..n_samples).collect();
            perm.shuffle(rng);
            perm
        })
        .collect();
    (0..n_samples)
        .map(|i| {
            let normalized: Vec<f64> = (0..ndim)
                .map(|d| {
                    let stratum = strata[d][i];
                    let lower = stratum as f64 / n_samples as f64;
                    let upper = (stratum + 1) as f64 / n_samples as f64;
                    // Uniform jitter inside the stratum keeps the design
                    // stratified without snapping to a rigid lattice.
                    lower + rng.r#gen::<f64>() * (upper - lower)
                })
                .collect();
            grid.decode(&normalized)
        })
        .collect()
}
/// Uniform random sampling: each point is drawn i.i.d. in normalized
/// [0, 1]^ndim space and decoded through the grid's ranges.
fn sample_random<R: Rng>(
    grid: &ParameterGrid,
    n_samples: usize,
    rng: &mut R,
) -> Vec<ParameterPoint> {
    let ndim = grid.ndim();
    let mut points = Vec::with_capacity(n_samples);
    for _ in 0..n_samples {
        let coords: Vec<f64> = (0..ndim).map(|_| rng.r#gen()).collect();
        points.push(grid.decode(&coords));
    }
    points
}
/// Choose the next point to evaluate given `(point, score)` history.
///
/// NOTE(review): not true Bayesian optimization (no surrogate model). It
/// scores 1000 uniform random candidates with a cheap heuristic and returns
/// the argmax: distance to already-evaluated points (exploration) plus 0.3x
/// a score-weighted similarity to historically good points (exploitation).
pub fn bayesian_next_sample<R: Rng>(
    grid: &ParameterGrid,
    history: &[(ParameterPoint, f64)],
    rng: &mut R,
) -> ParameterPoint {
    // No history yet: nothing to steer by, sample uniformly at random.
    if history.is_empty() {
        let normalized: Vec<f64> = (0..grid.ndim()).map(|_| rng.r#gen()).collect();
        return grid.decode(&normalized);
    }
    // Best score observed so far; normalizes the exploitation weights.
    let best_score = history
        .iter()
        .map(|(_, s)| *s)
        .fold(f64::NEG_INFINITY, f64::max);
    let n_candidates = 1000;
    let candidates: Vec<ParameterPoint> = sample_random(grid, n_candidates, rng);
    candidates
        .into_iter()
        .max_by(|a, b| {
            // Exploration term: prefer candidates far from anything tried.
            let dist_a = min_distance_to_history(a, history);
            let dist_b = min_distance_to_history(b, history);
            // Exploitation term (weight 0.3): prefer candidates near
            // historically high-scoring points.
            let score_a = dist_a + 0.3 * similarity_to_best(a, history, best_score);
            let score_b = dist_b + 0.3 * similarity_to_best(b, history, best_score);
            // Incomparable (NaN) scores are treated as equal; `max_by`
            // keeps the later element on ties.
            score_a
                .partial_cmp(&score_b)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        // Only reachable if `candidates` were empty (it never is, since
        // n_candidates > 0); falls back to the grid midpoint.
        .unwrap_or_else(|| grid.decode(&vec![0.5; grid.ndim()]))
}
/// Smallest normalized distance from `point` to any historical point;
/// `f64::INFINITY` when the history is empty.
fn min_distance_to_history(point: &ParameterPoint, history: &[(ParameterPoint, f64)]) -> f64 {
    let mut best = f64::INFINITY;
    for (past, _score) in history {
        best = best.min(normalized_distance(point, past));
    }
    best
}
/// Score-weighted average similarity of `point` to the history, where
/// similarity is `1 / (1 + distance)` and each historical point's weight is
/// its score relative to `best_score` (negative weights floored at zero).
/// Returns 0.0 when the total weight is not positive.
fn similarity_to_best(
    point: &ParameterPoint,
    history: &[(ParameterPoint, f64)],
    best_score: f64,
) -> f64 {
    let (total, weight_sum) = history.iter().fold((0.0, 0.0), |(acc, wsum), (h, score)| {
        let weight = (*score / best_score).max(0.0);
        let sim = 1.0 / (1.0 + normalized_distance(point, h));
        (acc + weight * sim, wsum + weight)
    });
    if weight_sum > 0.0 {
        total / weight_sum
    } else {
        0.0
    }
}
fn normalized_distance(a: &ParameterPoint, b: &ParameterPoint) -> f64 {
let grid = ParameterGrid::default();
let names = grid.param_names();
let a_vals = point_to_vec(a, &grid);
let b_vals = point_to_vec(b, &grid);
let sum_sq: f64 = a_vals
.iter()
.zip(b_vals.iter())
.map(|(av, bv)| (av - bv).powi(2))
.sum();
(sum_sq / names.len() as f64).sqrt()
}
/// Project a `ParameterPoint` into normalized [0, 1] coordinates, ordered
/// by `grid.param_names()` (sorted), so vectors are comparable across calls.
fn point_to_vec(point: &ParameterPoint, grid: &ParameterGrid) -> Vec<f64> {
    let names = grid.param_names();
    names
        .iter()
        .map(|name| {
            let range = &grid.ranges[name];
            // Manual name -> field dispatch; must stay in sync with both
            // `ParameterPoint`'s fields and the grid's key set.
            let value = match name.as_str() {
                "pagerank_alpha" => point.pagerank_alpha,
                "pagerank_chat_multiplier" => point.pagerank_chat_multiplier,
                "depth_weight_root" => point.depth_weight_root,
                "depth_weight_moderate" => point.depth_weight_moderate,
                "depth_weight_deep" => point.depth_weight_deep,
                "depth_weight_vendor" => point.depth_weight_vendor,
                "boost_mentioned_ident" => point.boost_mentioned_ident,
                "boost_mentioned_file" => point.boost_mentioned_file,
                "boost_chat_file" => point.boost_chat_file,
                "boost_temporal_coupling" => point.boost_temporal_coupling,
                "boost_focus_expansion" => point.boost_focus_expansion,
                "git_recency_decay_days" => point.git_recency_decay_days,
                "git_recency_max_boost" => point.git_recency_max_boost,
                "git_churn_threshold" => point.git_churn_threshold,
                "git_churn_max_boost" => point.git_churn_max_boost,
                "focus_decay" => point.focus_decay,
                "focus_max_hops" => point.focus_max_hops,
                // Unknown key: fall back to the range midpoint instead of
                // panicking. NOTE(review): this silently hides a grid/struct
                // mismatch — consider `unreachable!()` here.
                _ => 0.5,
            };
            range.encode(value)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Linear decode hits min / midpoint / max exactly.
    #[test]
    fn test_param_range_linear() {
        let range = ParamRange::linear(0.0, 10.0);
        assert!((range.decode(0.0) - 0.0).abs() < 1e-6);
        assert!((range.decode(0.5) - 5.0).abs() < 1e-6);
        assert!((range.decode(1.0) - 10.0).abs() < 1e-6);
    }

    // Log decode is geometric: the midpoint of [1, 100] is 10.
    #[test]
    fn test_param_range_log() {
        let range = ParamRange::log(1.0, 100.0);
        assert!((range.decode(0.0) - 1.0).abs() < 1e-6);
        assert!((range.decode(1.0) - 100.0).abs() < 1e-6);
        assert!((range.decode(0.5) - 10.0).abs() < 1e-6);
    }

    // encode followed by decode recovers in-range values.
    #[test]
    fn test_param_range_roundtrip() {
        let range = ParamRange::log(2.0, 50.0);
        for v in [2.0, 10.0, 25.0, 50.0] {
            let encoded = range.encode(v);
            let decoded = range.decode(encoded);
            assert!((decoded - v).abs() < 1e-6, "Roundtrip failed for {}", v);
        }
    }

    // Decoding all-zero / all-one vectors lands on the range endpoints
    // (checked with a small tolerance on pagerank_alpha's [0.70, 0.95]).
    #[test]
    fn test_grid_decode() {
        let grid = ParameterGrid::default();
        let ndim = grid.ndim();
        let min_point = grid.decode(&vec![0.0; ndim]);
        assert!(min_point.pagerank_alpha >= 0.69);
        let max_point = grid.decode(&vec![1.0; ndim]);
        assert!(max_point.pagerank_alpha <= 0.96);
    }

    // With 10 stratified samples, pagerank_alpha should spread across its
    // range rather than clustering.
    #[test]
    fn test_lhs_coverage() {
        let grid = ParameterGrid::default();
        let mut rng = StdRng::seed_from_u64(42);
        let samples = sample_lhs(&grid, 10, &mut rng);
        assert_eq!(samples.len(), 10);
        let alphas: Vec<_> = samples.iter().map(|s| s.pagerank_alpha).collect();
        let min_alpha = alphas.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_alpha = alphas.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!(max_alpha - min_alpha > 0.1, "LHS should cover the range");
    }

    // NOTE(review): pins RankingConfig::default() values (alpha = 0.85,
    // ident boost = 10.0); must be updated if those defaults change.
    #[test]
    fn test_default_point_to_config() {
        let point = ParameterPoint::default();
        let config = point.to_ranking_config();
        assert!((config.pagerank_alpha - 0.85).abs() < 1e-6);
        assert!((config.boost_mentioned_ident - 10.0).abs() < 1e-6);
    }

    // Distance is a metric: d(p, p) == 0.
    #[test]
    fn test_normalized_distance_same() {
        let p = ParameterPoint::default();
        let dist = normalized_distance(&p, &p);
        assert!(dist < 1e-6, "Distance to self should be 0");
    }

    // Differing in one coordinate gives a small but nonzero RMS distance.
    #[test]
    fn test_normalized_distance_different() {
        let mut p1 = ParameterPoint::default();
        let mut p2 = ParameterPoint::default();
        p1.pagerank_alpha = 0.7;
        p2.pagerank_alpha = 0.95;
        let dist = normalized_distance(&p1, &p2);
        assert!(dist > 0.0, "Different points should have distance > 0");
        assert!(dist < 1.0, "Normalized distance should be < 1");
    }
}