mod context;
pub use context::*;
#[cfg(test)]
mod tests;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Heuristic signals describing a single function.
///
/// Values are produced by [`FunctionFeatures::extract`]; a subset of them is
/// flattened into a numeric vector by [`FunctionFeatures::to_vector`].
#[derive(Debug, Clone, Default)]
pub struct FunctionFeatures {
    // ---- name-based signals ----
    /// C/C++ files only: the name begins with a 2-4 byte alphanumeric prefix
    /// followed by `_`, and that prefix is not a common English/verb word.
    pub has_short_prefix: bool,
    /// Name starts with a test-like prefix (`test`, `spec_`, `it_`, ...).
    pub has_test_prefix: bool,
    /// Name ends with a callback/handler-like suffix, or is a JS handler name.
    pub has_handler_suffix: bool,
    /// Name starts with exactly one leading underscore.
    pub has_internal_prefix: bool,
    /// First character of the name is uppercase.
    pub is_capitalized: bool,

    // ---- language-specific signals (language guessed from file extension) ----
    /// Go file and the name starts with an uppercase character.
    pub is_go_exported: bool,
    /// Go file and the name starts with a lowercase character.
    pub is_go_internal: bool,
    /// JS/TS file that also lives in a utility-like path.
    pub is_js_export: bool,
    /// JS/TS file and the name looks like an event handler/callback.
    pub is_js_arrow_handler: bool,
    /// Python file and the name is wrapped in double underscores.
    pub is_python_dunder: bool,
    /// Python file and the name starts with exactly one underscore.
    pub is_python_private: bool,

    // ---- path-based signals ----
    /// Path contains a test-like segment or file-name marker.
    pub in_test_path: bool,
    /// Path contains a utility/shared-code marker.
    pub in_util_path: bool,
    /// Path contains a handler/callback/hook/events marker.
    pub in_handler_path: bool,
    /// Path contains an internal/private marker.
    pub in_internal_path: bool,

    // ---- call-graph signals ----
    /// `fan_in / max_fan_in`, or 0.0 when `max_fan_in` is 0.
    pub fan_in_ratio: f64,
    /// `fan_out / max_fan_out`, or 0.0 when `max_fan_out` is 0.
    pub fan_out_ratio: f64,
    /// `caller_files / fan_in`, or 0.0 when `fan_in` is 0.
    pub caller_file_spread: f64,

    // ---- size signals (computed but not part of `to_vector`) ----
    /// Complexity relative to the average (average floored at 1.0);
    /// 1.0 when complexity is unknown.
    #[allow(dead_code)]
    pub complexity_ratio: f64,
    /// Lines of code relative to the average (average floored at 1.0).
    #[allow(dead_code)]
    pub loc_ratio: f64,
    /// Parameter count relative to the average (average floored at 1.0).
    #[allow(dead_code)]
    pub param_count_ratio: f64,

    /// Copied verbatim from the input metrics.
    pub address_taken: bool,
    /// More than 10 callers.
    pub is_high_fan_in: bool,
    /// File-level context derived from the file path.
    pub file_context: FileContext,
}
/// Raw, borrowed measurements for one function, consumed by
/// `FunctionFeatures::extract`.
///
/// Every field is plain `Copy` data, so the struct derives `Debug`, `Clone`
/// and `Copy`: callers can log it and pass it around freely.
#[derive(Debug, Clone, Copy)]
pub struct FunctionMetrics<'a> {
    /// Function name as written in source.
    pub name: &'a str,
    /// Path of the file that defines the function.
    pub file_path: &'a str,
    /// Number of callers of this function.
    pub fan_in: usize,
    /// Number of callees of this function.
    pub fan_out: usize,
    /// Denominator used to normalise `fan_in` (0 disables the ratio).
    pub max_fan_in: usize,
    /// Denominator used to normalise `fan_out` (0 disables the ratio).
    pub max_fan_out: usize,
    /// Number of files containing callers; divided by `fan_in` to get the
    /// caller file spread.
    pub caller_files: usize,
    /// Complexity score, if one was computed.
    pub complexity: Option<i64>,
    /// Average complexity used as the normalisation baseline.
    pub avg_complexity: f64,
    /// Lines of code in the function.
    pub loc: u32,
    /// Average lines of code used as the normalisation baseline.
    pub avg_loc: f64,
    /// Number of parameters.
    pub param_count: usize,
    /// Average parameter count used as the normalisation baseline.
    pub avg_params: f64,
    /// Passed through unchanged into the extracted features.
    pub address_taken: bool,
}
impl FunctionFeatures {
/// Derives the full feature set for one function from its raw metrics.
///
/// The language is guessed from the lower-cased file extension. Name and
/// path signals are simple substring/prefix/suffix checks, mostly on the
/// lower-cased text; a few checks (Go `Test`/`Handler`/`Func`, underscore
/// prefixes, Python dunders) are case-sensitive and use the original name.
/// Ratios fall back to 0.0 when their denominator is zero, and the
/// size ratios floor the average at 1.0 before dividing.
pub fn extract(m: &FunctionMetrics<'_>) -> Self {
    fn starts_with_any(text: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|p| text.starts_with(*p))
    }
    fn ends_with_any(text: &str, suffixes: &[&str]) -> bool {
        suffixes.iter().any(|s| text.ends_with(*s))
    }
    fn contains_any(text: &str, needles: &[&str]) -> bool {
        needles.iter().any(|n| text.contains(*n))
    }
    fn ratio(part: usize, whole: usize) -> f64 {
        if whole > 0 {
            part as f64 / whole as f64
        } else {
            0.0
        }
    }

    let name = m.name;
    let lowered_name = name.to_lowercase();
    let lowered_path = m.file_path.to_lowercase();

    // Language detection by extension.
    let is_go = lowered_path.ends_with(".go");
    let is_js = ends_with_any(&lowered_path, &[".js", ".jsx", ".ts", ".tsx"]);
    let is_python = lowered_path.ends_with(".py");
    let is_c = ends_with_any(&lowered_path, &[".c", ".h", ".cpp", ".hpp"]);

    // Shape of the name itself.
    let leading = name.chars().next();
    let starts_upper = leading.map_or(false, |c| c.is_uppercase());
    let starts_lower = leading.map_or(false, |c| c.is_lowercase());
    let single_underscore = name.starts_with('_') && !name.starts_with("__");

    let is_js_handler = is_js
        && (starts_with_any(&lowered_name, &["on", "handle"])
            || ends_with_any(&lowered_name, &["handler", "callback", "listener"]));

    let has_test_prefix = starts_with_any(&lowered_name, &["test_", "test", "spec_", "it_"])
        || (is_go && name.starts_with("Test"))
        || (is_js && starts_with_any(&lowered_name, &["it(", "describe("]));

    let has_handler_suffix =
        ends_with_any(&lowered_name, &["_cb", "_callback", "_handler", "_hook", "_fn"])
            || (is_go && (name.ends_with("Handler") || name.ends_with("Func")))
            || is_js_handler;

    // Location of the defining file.
    let in_util_path = contains_any(
        &lowered_path,
        &[
            "/util", "/utils", "/common", "/helper", "/helpers", "/lib/", "/shared", "/core/",
        ],
    ) || (is_js && lowered_path.contains("/src/"))
        || contains_any(&lowered_path, &["utils.", "helpers."]);

    let in_test_path = contains_any(
        &lowered_path,
        &[
            "/test",
            "/tests",
            "_test.",
            ".test.",
            ".spec.",
            "/spec",
            "/__tests__",
            "/__mocks__",
        ],
    );

    let in_handler_path = contains_any(
        &lowered_path,
        &["/handler", "/callback", "/hook", "/hooks", "/events"],
    );

    let in_internal_path = contains_any(&lowered_path, &["/internal", "/private", "/_"])
        || (is_go && lowered_path.contains("/pkg/"));

    let complexity_ratio = match m.complexity {
        Some(c) => c as f64 / m.avg_complexity.max(1.0),
        None => 1.0,
    };

    Self {
        has_short_prefix: is_c && Self::has_short_prefix(name),
        has_test_prefix,
        has_handler_suffix,
        has_internal_prefix: single_underscore,
        is_capitalized: starts_upper,
        is_go_exported: is_go && starts_upper,
        is_go_internal: is_go && starts_lower,
        is_js_export: is_js && in_util_path,
        is_js_arrow_handler: is_js_handler,
        is_python_dunder: is_python && name.starts_with("__") && name.ends_with("__"),
        is_python_private: is_python && single_underscore,
        in_test_path,
        in_util_path,
        in_handler_path,
        in_internal_path,
        fan_in_ratio: ratio(m.fan_in, m.max_fan_in),
        fan_out_ratio: ratio(m.fan_out, m.max_fan_out),
        caller_file_spread: ratio(m.caller_files, m.fan_in),
        complexity_ratio,
        loc_ratio: m.loc as f64 / m.avg_loc.max(1.0),
        param_count_ratio: m.param_count as f64 / m.avg_params.max(1.0),
        address_taken: m.address_taken,
        is_high_fan_in: m.fan_in > 10,
        file_context: FileContext::from_path(m.file_path),
    }
}
/// Reports whether `name` starts with a short, library-style prefix such as
/// `uv_` or `sdl_`.
///
/// The text before the first `_` must be 2 to 4 bytes long, consist only of
/// alphanumeric characters, and (case-insensitively) not be one of a fixed
/// list of common words like `get`, `set` or `test`.
fn has_short_prefix(name: &str) -> bool {
    const COMMON_WORDS: &[&str] = &[
        "get", "set", "is", "do", "can", "has", "new", "old", "add", "del", "pop", "put", "run",
        "try", "end", "use", "for", "the", "and", "not", "dead", "live", "test", "mock", "fake",
        "stub", "temp", "tmp", "foo", "bar", "baz", "qux", "call", "read", "load", "save", "send",
        "recv",
    ];

    // Everything before the first underscore; no underscore means no prefix.
    let prefix = match name.split_once('_') {
        Some((head, _)) => head,
        None => return false,
    };

    // Length is measured in bytes, matching the byte offset of the underscore.
    if !(2..=4).contains(&prefix.len()) {
        return false;
    }
    if !prefix.chars().all(|c| c.is_alphanumeric()) {
        return false;
    }

    let lowered = prefix.to_lowercase();
    !COMMON_WORDS.iter().any(|word| lowered == *word)
}
/// Flattens the features into a fixed 20-element numeric vector.
///
/// Booleans become 1.0/0.0; the three call-graph ratios are passed through
/// unchanged. The position of each entry is significant: the model tables
/// elsewhere in this file index into this vector by position.
pub fn to_vector(&self) -> [f64; 20] {
    fn flag(set: bool) -> f64 {
        f64::from(u8::from(set))
    }

    [
        // 0-4: name shape
        flag(self.has_short_prefix),
        flag(self.has_test_prefix),
        flag(self.has_handler_suffix),
        flag(self.has_internal_prefix),
        flag(self.is_capitalized),
        // 5-10: language-specific
        flag(self.is_go_exported),
        flag(self.is_go_internal),
        flag(self.is_js_export),
        flag(self.is_js_arrow_handler),
        flag(self.is_python_dunder),
        flag(self.is_python_private),
        // 11-14: path
        flag(self.in_test_path),
        flag(self.in_util_path),
        flag(self.in_handler_path),
        flag(self.in_internal_path),
        // 15-17: call-graph ratios
        self.fan_in_ratio,
        self.fan_out_ratio,
        self.caller_file_spread,
        // 18-19: remaining flags
        flag(self.address_taken),
        flag(self.is_high_fan_in),
    ]
}
/// Rule-based guess that the function is a shared utility.
///
/// True when a high-fan-in function also has a short prefix, is Go-exported
/// or lives in a utility path; when a Go-exported function lives in a
/// utility path; or when the fan-in ratio alone is high enough.
// The `fan_in_ratio > 0.3 && spread > 0.5` rule is implied by the final
// `fan_in_ratio > 0.2` rule; it is kept deliberately, hence the lint allow.
#[allow(clippy::nonminimal_bool)]
pub fn looks_like_utility(&self) -> bool {
    let popular_and_marked = self.is_high_fan_in
        && (self.has_short_prefix || self.is_go_exported || self.in_util_path);
    let exported_from_util = self.is_go_exported && self.in_util_path;
    let widely_spread = self.fan_in_ratio > 0.3 && self.caller_file_spread > 0.5;

    popular_and_marked || exported_from_util || widely_spread || self.fan_in_ratio > 0.2
}
/// Rule-based guess that the function is a handler/callback: any one of the
/// handler-related signals is enough.
pub fn looks_like_handler(&self) -> bool {
    [
        self.has_handler_suffix,
        self.is_js_arrow_handler,
        self.address_taken,
        self.in_handler_path,
    ]
    .contains(&true)
}
/// Rule-based guess that the function is test code: a test-like name, a
/// test-like path, or a file classified as `FileContext::TestFile`.
pub fn looks_like_test(&self) -> bool {
    let in_test_file = matches!(self.file_context, FileContext::TestFile);
    self.has_test_prefix || self.in_test_path || in_test_file
}
/// Rule-based guess that the function is internal/private, judged by its
/// name conventions or by the path it lives in.
pub fn looks_like_internal(&self) -> bool {
    let private_by_name =
        self.has_internal_prefix || self.is_go_internal || self.is_python_private;
    private_by_name || self.in_internal_path
}
}
/// Number of entries in a feature vector; must match the array length
/// returned by `FunctionFeatures::to_vector`.
const NUM_FEATURES: usize = 20;
/// Parameters of a 5-state model with independent Gaussian emissions per
/// feature.
///
/// State indices are those of `FunctionContext::index()`.
/// NOTE(review): the default tables look ordered Utility, Handler, Core,
/// Internal, Test — confirm against the `context` module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextHMM {
/// Prior probability of each state.
pub initial: [f64; 5],
/// `transition[from][to]`: probability of moving between states; only used
/// by sequence decoding.
pub transition: [[f64; 5]; 5],
/// `emission_mean[state][feature]`: Gaussian mean of each feature.
pub emission_mean: [[f64; NUM_FEATURES]; 5],
/// `emission_var[state][feature]`: Gaussian variance of each feature
/// (floored at 0.01 when evaluated).
pub emission_var: [[f64; NUM_FEATURES]; 5],
}
impl Default for ContextHMM {
fn default() -> Self {
Self::new()
}
}
impl ContextHMM {
/// Builds the model with hand-written default parameters.
///
/// Each table has one row per state (indexed like `FunctionContext::index()`)
/// and, for the emission tables, one column per entry of
/// `FunctionFeatures::to_vector` in the same order.
/// NOTE(review): rows appear to be Utility, Handler, Core, Internal, Test —
/// confirm against the `context` module before editing values.
pub fn new() -> Self {
// Prior over states; sums to 1.0.
let initial = [0.15, 0.10, 0.50, 0.20, 0.05];
// Row = previous state, column = next state; each row sums to 1.0 and
// favours staying in the same state.
let transition = [
[0.60, 0.10, 0.15, 0.10, 0.05],
[0.10, 0.50, 0.20, 0.15, 0.05],
[0.10, 0.10, 0.55, 0.20, 0.05],
[0.15, 0.10, 0.25, 0.45, 0.05],
[0.05, 0.05, 0.10, 0.05, 0.75],
];
// Expected value of each feature under each state.
let emission_mean = [
[
0.5, 0.0, 0.1, 0.1, 0.5, 0.6, 0.2, 0.4, 0.1, 0.1, 0.1, 0.0, 0.7, 0.0, 0.1, 0.7,
0.3, 0.7, 0.2, 0.8,
],
[
0.2, 0.0, 0.8, 0.1, 0.3, 0.3, 0.3, 0.2, 0.8, 0.1, 0.1, 0.0, 0.1, 0.8, 0.1, 0.3,
0.4, 0.4, 0.8, 0.3,
],
[
0.1, 0.0, 0.1, 0.1, 0.4, 0.4, 0.4, 0.3, 0.1, 0.1, 0.1, 0.0, 0.1, 0.1, 0.1, 0.3,
0.4, 0.4, 0.1, 0.3,
],
[
0.1, 0.0, 0.1, 0.7, 0.2, 0.1, 0.7, 0.1, 0.1, 0.1, 0.6, 0.0, 0.1, 0.0, 0.7, 0.1,
0.3, 0.3, 0.1, 0.1,
],
[
0.0, 0.9, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.9, 0.0, 0.0, 0.0, 0.1,
0.5, 0.2, 0.0, 0.1,
],
];
// Variance of each feature under each state; a smaller value makes a
// mismatch on that feature cost more in the emission score.
let emission_var = [
[
0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.3, 0.2, 0.3, 0.3, 0.3, 0.3, 0.1, 0.3, 0.3, 0.1,
0.2, 0.1, 0.2, 0.1,
],
[
0.3, 0.3, 0.1, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.3, 0.3, 0.3, 0.3, 0.1, 0.3, 0.2,
0.2, 0.2, 0.1, 0.2,
],
[
0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
0.3, 0.3, 0.3, 0.3,
],
[
0.3, 0.3, 0.3, 0.1, 0.3, 0.3, 0.1, 0.3, 0.3, 0.3, 0.1, 0.3, 0.3, 0.3, 0.1, 0.2,
0.2, 0.2, 0.2, 0.2,
],
[
0.3, 0.05, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.05, 0.3, 0.3, 0.3, 0.2,
0.2, 0.2, 0.2, 0.2,
],
];
Self {
initial,
transition,
emission_mean,
emission_var,
}
}
/// Classifies a single function in isolation (no transition model).
///
/// Picks the state with the highest `ln(prior) + log-emission` score. The
/// first state in `FunctionContext::ALL` wins ties, and `Core` is returned
/// if no score beats negative infinity.
pub fn classify(&self, features: &FunctionFeatures) -> FunctionContext {
    let observed = features.to_vector();
    let mut winner = (FunctionContext::Core, f64::NEG_INFINITY);

    for state in FunctionContext::ALL {
        let emission = self.log_emission_prob(state, &observed);
        let score = emission + self.initial[state.index()].ln();
        if score > winner.1 {
            winner = (state, score);
        }
    }

    winner.0
}
/// Decodes the most likely state sequence for an ordered list of functions
/// using the Viterbi algorithm in log space.
///
/// Returns one context per input element; an empty input yields an empty
/// vector. Ties are resolved in favour of the lower state index.
#[allow(dead_code)]
pub fn classify_sequence(&self, features: &[FunctionFeatures]) -> Vec<FunctionContext> {
    const STATES: usize = 5;

    let steps = features.len();
    if steps == 0 {
        return vec![];
    }

    // best_score[t][s]: best log-score of any path that ends in state `s` at
    // step `t`; came_from[t][s]: the predecessor state on that path.
    let mut best_score = vec![[f64::NEG_INFINITY; 5]; steps];
    let mut came_from = vec![[0usize; 5]; steps];

    // Initialisation: prior plus emission of the first observation.
    let first_obs = features[0].to_vector();
    for (s, cell) in best_score[0].iter_mut().enumerate() {
        *cell = self.initial[s].ln()
            + self.log_emission_prob(FunctionContext::from_index(s), &first_obs);
    }

    // Recursion: extend the best path into every state at every step.
    for t in 1..steps {
        let obs = features[t].to_vector();
        for s in 0..STATES {
            let emission = self.log_emission_prob(FunctionContext::from_index(s), &obs);
            for prev in 0..STATES {
                let candidate = best_score[t - 1][prev] + self.transition[prev][s].ln() + emission;
                if candidate > best_score[t][s] {
                    best_score[t][s] = candidate;
                    came_from[t][s] = prev;
                }
            }
        }
    }

    // Termination: best final state.
    let final_row = &best_score[steps - 1];
    let mut end_state = 0;
    for s in 1..STATES {
        if final_row[s] > final_row[end_state] {
            end_state = s;
        }
    }

    // Backtrack from the final state to the first step.
    let mut path = vec![FunctionContext::Core; steps];
    path[steps - 1] = FunctionContext::from_index(end_state);
    for t in (1..steps).rev() {
        let successor = path[t].index();
        path[t - 1] = FunctionContext::from_index(came_from[t][successor]);
    }
    path
}
/// Log-likelihood (up to an additive constant) of a feature vector under
/// the given state's independent Gaussians.
///
/// Each variance is floored at 0.01 before use. The `ln(2π)` term is
/// omitted; it is the same for every state, so comparisons are unaffected.
fn log_emission_prob(&self, state: FunctionContext, features: &[f64; NUM_FEATURES]) -> f64 {
    let row = state.index();
    let means = self.emission_mean[row].iter();
    let variances = self.emission_var[row].iter();

    let mut total = 0.0;
    for ((&mean, &raw_var), &value) in means.zip(variances).zip(features.iter()) {
        let var = raw_var.max(0.01);
        total += -0.5 * ((value - mean).powi(2) / var + var.ln());
    }
    total
}
/// Re-estimates priors and emission parameters from labelled examples.
///
/// Priors use add-one smoothing over the 5 states. For every state that has
/// at least one example, each feature's mean and variance are replaced by
/// the sample mean and (population) variance, with the variance floored at
/// 0.01. States without examples keep their current emission parameters.
/// The transition matrix is not touched. Empty input is a no-op.
pub fn update(&mut self, examples: &[(FunctionFeatures, FunctionContext)]) {
    if examples.is_empty() {
        return;
    }

    // Sufficient statistics per state: count, sum(x), sum(x^2).
    let mut counts = [0.0f64; 5];
    let mut sums = [[0.0f64; NUM_FEATURES]; 5];
    let mut square_sums = [[0.0f64; NUM_FEATURES]; 5];

    for (features, context) in examples {
        let state = context.index();
        counts[state] += 1.0;
        let observed = features.to_vector();
        for (feat, &value) in observed.iter().enumerate() {
            sums[state][feat] += value;
            square_sums[state][feat] += value * value;
        }
    }

    let total: f64 = counts.iter().sum();
    for (prior, &count) in self.initial.iter_mut().zip(counts.iter()) {
        *prior = (count + 1.0) / (total + 5.0);
    }

    for state in 0..5 {
        let count = counts[state];
        if count > 0.0 {
            for feat in 0..NUM_FEATURES {
                let mean = sums[state][feat] / count;
                // E[x^2] - E[x]^2 can dip below zero from rounding; the floor
                // also keeps the Gaussian from collapsing.
                let variance = (square_sums[state][feat] / count - mean * mean).max(0.01);
                self.emission_mean[state][feat] = mean;
                self.emission_var[state][feat] = variance;
            }
        }
    }
}
pub fn bootstrap_from_graph(
&mut self,
function_data: &[(FunctionFeatures, usize, usize, bool)],
) {
let mut examples = Vec::new();
for (features, _fan_in, _fan_out, _address_taken) in function_data {
let context = if features.looks_like_test() {
FunctionContext::Test
} else if features.looks_like_handler() {
FunctionContext::Handler
} else if features.looks_like_utility() {
FunctionContext::Utility
} else if features.looks_like_internal() {
FunctionContext::Internal
} else {
FunctionContext::Core
};
examples.push((features.clone(), context));
}
self.update(&examples);
}
#[allow(dead_code)] pub fn em_refine(
&mut self,
function_data: &[(FunctionFeatures, usize, usize, bool)],
iterations: usize,
) {
for _iter in 0..iterations {
let mut examples = Vec::new();
for (features, _, _, _) in function_data {
let (context, confidence) = self.classify_with_confidence(features);
if confidence > 0.7 {
examples.push((features.clone(), context));
}
}
if examples.len() > function_data.len() / 4 {
self.update(&examples);
}
}
}
/// Classifies a single function and reports the posterior probability of
/// the winning state.
///
/// Scores are `ln(prior) + log-emission`; the confidence is the softmax of
/// those scores evaluated at the winner, computed with the maximum
/// subtracted for numerical stability. Ties go to the lower state index.
#[allow(dead_code)]
pub fn classify_with_confidence(&self, features: &FunctionFeatures) -> (FunctionContext, f64) {
    let observed = features.to_vector();

    let mut log_scores = [0.0f64; 5];
    for s in 0..5 {
        log_scores[s] =
            self.initial[s].ln() + self.log_emission_prob(FunctionContext::from_index(s), &observed);
    }

    let peak = log_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let normaliser: f64 = log_scores.iter().map(|&lp| (lp - peak).exp()).sum();

    let mut winner = 0;
    let mut winner_score = f64::NEG_INFINITY;
    for (s, &lp) in log_scores.iter().enumerate() {
        if lp > winner_score {
            winner = s;
            winner_score = lp;
        }
    }

    let confidence = (winner_score - peak).exp() / normaliser;
    (FunctionContext::from_index(winner), confidence)
}
}
/// Linear per-state feature weights for the secondary scorer.
///
/// State indices are those of `FunctionContext::index()`; feature indices
/// follow `FunctionFeatures::to_vector`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CRFWeights {
/// `feature_weights[state][feature]`: contribution of each feature to a
/// state's score.
pub feature_weights: [[f64; NUM_FEATURES]; 5],
/// `transition_weights[from][to]`.
/// NOTE(review): initialised and serialised, but not read by any scoring
/// code visible in this file.
pub transition_weights: [[f64; 5]; 5],
}
impl Default for CRFWeights {
fn default() -> Self {
Self::new()
}
}
impl CRFWeights {
/// Builds the hand-seeded default weights.
///
/// All feature weights start at zero except the `(state, feature, weight)`
/// seeds listed below. Transition weights are 1.0 on the diagonal, except
/// state 4, whose self-transition is 2.0.
pub fn new() -> Self {
    // (state index, feature index into `to_vector`, weight)
    const SEEDS: &[(usize, usize, f64)] = &[
        // state 0
        (0, 12, 3.0),
        (0, 15, 3.0),
        (0, 5, 2.0),
        (0, 19, 4.0),
        (0, 0, 2.0),
        (0, 17, 2.0),
        // state 1
        (1, 2, 3.0),
        (1, 8, 2.5),
        (1, 18, 2.0),
        (1, 13, 1.5),
        // state 2
        (2, 4, 0.5),
        // state 3
        (3, 3, 2.0),
        (3, 6, 2.0),
        (3, 10, 2.0),
        (3, 14, 1.5),
        // state 4
        (4, 1, 4.0),
        (4, 11, 4.0),
    ];

    let mut feature_weights = [[0.0f64; NUM_FEATURES]; 5];
    for &(state, feature, weight) in SEEDS {
        feature_weights[state][feature] = weight;
    }

    let mut transition_weights = [[0.0f64; 5]; 5];
    for (state, row) in transition_weights.iter_mut().enumerate() {
        row[state] = 1.0;
    }
    transition_weights[4][4] = 2.0;

    Self {
        feature_weights,
        transition_weights,
    }
}
/// Linear score of `features` for `context`: the dot product of the feature
/// vector with that context's weight row.
pub fn score(&self, features: &FunctionFeatures, context: FunctionContext) -> f64 {
    let observed = features.to_vector();
    let weights = &self.feature_weights[context.index()];

    weights
        .iter()
        .zip(observed.iter())
        .fold(0.0, |acc, (&weight, &value)| acc + weight * value)
}
/// One pass of perceptron-style training over labelled examples.
///
/// For each misclassified example, the weights of the true state are moved
/// towards the feature vector and those of the predicted state away from
/// it, both scaled by `learning_rate`. Correct predictions change nothing.
pub fn train(&mut self, examples: &[(FunctionFeatures, FunctionContext)], learning_rate: f64) {
    for (features, expected) in examples {
        let guess = self.predict(features);
        if guess == *expected {
            continue;
        }

        let observed = features.to_vector();
        let reward_row = expected.index();
        let penalty_row = guess.index();
        for (feat, &value) in observed.iter().enumerate() {
            self.feature_weights[reward_row][feat] += learning_rate * value;
            self.feature_weights[penalty_row][feat] -= learning_rate * value;
        }
    }
}
/// Returns the context with the highest linear score.
///
/// Ties go to the lower state index; `Core` is returned if no score beats
/// negative infinity.
pub fn predict(&self, features: &FunctionFeatures) -> FunctionContext {
    let start = (FunctionContext::Core, f64::NEG_INFINITY);

    let (winner, _) = (0..5)
        .map(FunctionContext::from_index)
        .fold(start, |best, candidate| {
            let score = self.score(features, candidate);
            if score > best.1 {
                (candidate, score)
            } else {
                best
            }
        });

    winner
}
}
/// Front-end classifier that blends the HMM and the linear scorer and
/// memoises results.
pub struct ContextClassifier {
/// Gaussian-emission model; primary signal.
hmm: ContextHMM,
/// Linear scorer; secondary signal in the ensemble.
crf: CRFWeights,
/// Memoised results keyed by the `name` passed to the classify methods.
/// NOTE(review): the key is the name only, so identically named functions
/// share one entry — confirm callers pass qualified names.
cache: HashMap<String, FunctionContext>,
/// Share of the HMM in the ensemble; the linear scorer gets `1 - hmm_weight`.
/// Values of 1.0 or more bypass the ensemble and use the HMM alone.
hmm_weight: f64,
}
impl ContextClassifier {
/// Creates a classifier with the default models and an empty cache.
///
/// The ensemble weight of the HMM is read from the `REPOTOIRE_HMM_WEIGHT`
/// environment variable and defaults to 0.9 when the variable is unset,
/// unparsable or NaN. Because the weight is used as a convex blend factor
/// (`w * hmm + (1 - w) * crf`), parsed values are clamped into `[0.0, 1.0]`
/// so that a stray negative or oversized setting cannot produce negative
/// mixture weights.
pub fn new() -> Self {
    let hmm_weight = std::env::var("REPOTOIRE_HMM_WEIGHT")
        .ok()
        .and_then(|raw| raw.trim().parse::<f64>().ok())
        // NaN would silently disable the ensemble and cannot be round-tripped
        // through JSON by `save`/`load`; treat it like an unparsable value.
        .filter(|weight| !weight.is_nan())
        .map(|weight| weight.clamp(0.0, 1.0))
        .unwrap_or(0.9);

    Self {
        hmm: ContextHMM::new(),
        crf: CRFWeights::new(),
        cache: HashMap::new(),
        hmm_weight,
    }
}
pub fn for_codebase(cache_path: Option<&std::path::Path>) -> Self {
let hmm = if let Some(path) = cache_path {
if path.exists() {
std::fs::read_to_string(path)
.ok()
.and_then(|s| serde_json::from_str(&s).ok())
.unwrap_or_default()
} else {
ContextHMM::new()
}
} else {
ContextHMM::new()
};
Self {
hmm,
crf: CRFWeights::new(),
cache: HashMap::new(),
hmm_weight: 0.9,
}
}
/// Classifies a function, memoising the result under `name`.
///
/// Resolution order: cached result, then the file-level bias (if the file
/// context supplies one), then the ensemble when `hmm_weight < 1.0`, or the
/// HMM alone otherwise.
pub fn classify(&mut self, name: &str, features: &FunctionFeatures) -> FunctionContext {
    if let Some(hit) = self.cache.get(name) {
        return *hit;
    }

    let resolved = match features.file_context.function_bias() {
        Some(bias) => bias,
        None if self.hmm_weight < 1.0 => self.ensemble_classify(features),
        None => self.hmm.classify(features),
    };

    self.cache.insert(name.to_string(), resolved);
    resolved
}
/// Classifies a function and reports a confidence in `[0, 1]`.
///
/// Unlike `classify`, the cache is never consulted (it stores no
/// confidence), but the result is still written to it. A file-level bias,
/// when present, is returned with confidence 1.0; otherwise the ensemble is
/// used when `hmm_weight < 1.0` and the HMM alone otherwise.
pub fn classify_with_confidence(
    &mut self,
    name: &str,
    features: &FunctionFeatures,
) -> (FunctionContext, f64) {
    let (context, confidence) = match features.file_context.function_bias() {
        Some(bias) => (bias, 1.0),
        None if self.hmm_weight < 1.0 => self.ensemble_classify_with_confidence(features),
        None => self.hmm.classify_with_confidence(features),
    };

    self.cache.insert(name.to_string(), context);
    (context, confidence)
}
fn ensemble_classify(&self, features: &FunctionFeatures) -> FunctionContext {
let vec = features.to_vector();
let mut scores = [0.0f64; 5];
let mut hmm_log_probs = [0.0f64; 5];
for s in 0..5 {
let ctx = FunctionContext::from_index(s);
hmm_log_probs[s] = self.hmm.initial[s].ln() + self.hmm.log_emission_prob(ctx, &vec);
}
let hmm_max = hmm_log_probs
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
let hmm_sum: f64 = hmm_log_probs.iter().map(|&lp| (lp - hmm_max).exp()).sum();
let mut crf_scores = [0.0f64; 5];
for s in 0..5 {
crf_scores[s] = self.crf.score(features, FunctionContext::from_index(s));
}
let crf_max = crf_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let crf_sum: f64 = crf_scores.iter().map(|&sc| (sc - crf_max).exp()).sum();
for s in 0..5 {
let hmm_prob = (hmm_log_probs[s] - hmm_max).exp() / hmm_sum;
let crf_prob = (crf_scores[s] - crf_max).exp() / crf_sum;
scores[s] = self.hmm_weight * hmm_prob + (1.0 - self.hmm_weight) * crf_prob;
}
let mut best_idx = 0;
for s in 1..5 {
if scores[s] > scores[best_idx] {
best_idx = s;
}
}
FunctionContext::from_index(best_idx)
}
/// Blends the HMM posterior with the linear scorer's softmax and returns
/// the best state with its share of the blended mass.
///
/// Both score sets are turned into probabilities with a max-shifted
/// softmax, then mixed as `hmm_weight * hmm + (1 - hmm_weight) * crf`.
/// Ties go to the lower state index. If the blended scores do not sum to a
/// positive number the confidence falls back to 0.2.
fn ensemble_classify_with_confidence(
    &self,
    features: &FunctionFeatures,
) -> (FunctionContext, f64) {
    // Softmax with the maximum subtracted first for numerical stability.
    fn softmax(raw: &[f64; 5]) -> [f64; 5] {
        let peak = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let denom: f64 = raw.iter().map(|&x| (x - peak).exp()).sum();
        let mut probs = [0.0f64; 5];
        for (slot, &x) in probs.iter_mut().zip(raw.iter()) {
            *slot = (x - peak).exp() / denom;
        }
        probs
    }

    let observed = features.to_vector();

    let mut hmm_log_scores = [0.0f64; 5];
    let mut crf_raw_scores = [0.0f64; 5];
    for s in 0..5 {
        let ctx = FunctionContext::from_index(s);
        hmm_log_scores[s] = self.hmm.initial[s].ln() + self.hmm.log_emission_prob(ctx, &observed);
        crf_raw_scores[s] = self.crf.score(features, ctx);
    }

    let hmm_probs = softmax(&hmm_log_scores);
    let crf_probs = softmax(&crf_raw_scores);

    let mut blended = [0.0f64; 5];
    for s in 0..5 {
        blended[s] = self.hmm_weight * hmm_probs[s] + (1.0 - self.hmm_weight) * crf_probs[s];
    }

    let mut top = 0;
    for s in 1..5 {
        if blended[s] > blended[top] {
            top = s;
        }
    }

    let mass: f64 = blended.iter().sum();
    let confidence = if mass > 0.0 { blended[top] / mass } else { 0.2 };

    (FunctionContext::from_index(top), confidence)
}
/// Fits the classifier to a codebase and clears the result cache.
///
/// The HMM is always bootstrapped from the heuristic labels. The linear
/// scorer is additionally trained on the HMM's own predictions (learning
/// rate 0.05) only when the ensemble is active (`hmm_weight < 1.0`) and the
/// `REPOTOIRE_ENABLE_CRF_SELF_TRAIN` environment variable is `1` or `true`
/// (case-insensitive).
pub fn train(&mut self, function_data: &[(FunctionFeatures, usize, usize, bool)]) {
    fn crf_self_train_enabled() -> bool {
        match std::env::var("REPOTOIRE_ENABLE_CRF_SELF_TRAIN") {
            Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
            Err(_) => false,
        }
    }

    self.hmm.bootstrap_from_graph(function_data);

    if self.hmm_weight < 1.0 && crf_self_train_enabled() {
        let mut pseudo_labelled = Vec::with_capacity(function_data.len());
        for (features, _, _, _) in function_data {
            let label = self.hmm.classify(features);
            pseudo_labelled.push((features.clone(), label));
        }
        self.crf.train(&pseudo_labelled, 0.05);
    }

    // Cached answers were produced by the previous model parameters.
    self.cache.clear();
}
/// Writes the models and the HMM weight to `path` as pretty-printed JSON
/// with the top-level keys `hmm`, `crf` and `hmm_weight`.
///
/// # Errors
/// Returns an I/O error if serialisation or the file write fails.
pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
    let document = serde_json::json!({
        "hmm": self.hmm,
        "crf": self.crf,
        "hmm_weight": self.hmm_weight,
    });
    let rendered = serde_json::to_string_pretty(&document)?;
    std::fs::write(path, rendered)
}
/// Reads a classifier previously written by `save`.
///
/// Returns `None` if the file cannot be read, is not valid JSON, or lacks a
/// well-formed `hmm`, `crf` or numeric `hmm_weight` entry. The cache of the
/// returned classifier is empty.
pub fn load(path: &std::path::Path) -> Option<Self> {
    let raw = std::fs::read_to_string(path).ok()?;
    let root: serde_json::Value = serde_json::from_str(&raw).ok()?;

    let hmm = root
        .get("hmm")
        .cloned()
        .and_then(|node| serde_json::from_value::<ContextHMM>(node).ok())?;
    let crf = root
        .get("crf")
        .cloned()
        .and_then(|node| serde_json::from_value::<CRFWeights>(node).ok())?;
    let hmm_weight = root.get("hmm_weight").and_then(|node| node.as_f64())?;

    Some(Self {
        hmm,
        crf,
        cache: HashMap::new(),
        hmm_weight,
    })
}
}
impl Default for ContextClassifier {
fn default() -> Self {
Self::new()
}
}