use std::collections::HashMap;
use std::fmt;
use serde::{Deserialize, Serialize};
/// String-backed identifier of a context item.
///
/// Every constructor namespaces the payload with a `kind:` prefix
/// (`file:`, `shell:`, `knowledge:`, `memory:`, `provider:`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ContextItemId(pub String);
impl ContextItemId {
    /// Builds `"{kind}:{rest}"`; shared by all public constructors.
    fn tagged(kind: &str, rest: impl fmt::Display) -> Self {
        Self(format!("{kind}:{rest}"))
    }
    /// Identifier `file:<path>` for a file.
    pub fn from_file(path: &str) -> Self {
        Self::tagged("file", path)
    }
    /// Identifier `shell:<hash>` for a shell command.
    ///
    /// The command text is not embedded; it is hashed with
    /// `crate::core::project_hash::hash_project_root`.
    /// NOTE(review): that helper's name suggests it targets project roots — confirm it
    /// is meant for arbitrary command strings.
    pub fn from_shell(command: &str) -> Self {
        Self::tagged("shell", crate::core::project_hash::hash_project_root(command))
    }
    /// Identifier `knowledge:<category>:<key>`.
    pub fn from_knowledge(category: &str, key: &str) -> Self {
        Self::tagged("knowledge", format_args!("{category}:{key}"))
    }
    /// Identifier `memory:<key>`.
    pub fn from_memory(key: &str) -> Self {
        Self::tagged("memory", key)
    }
    /// Identifier `provider:<provider>:<key>`.
    pub fn from_provider(provider: &str, key: &str) -> Self {
        Self::tagged("provider", format_args!("{provider}:{key}"))
    }
    /// Borrows the full prefixed identifier string.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl fmt::Display for ContextItemId {
    /// Writes the identifier exactly as stored (prefix included).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Category of a context item, serialized in snake_case (`"file"`, `"shell"`, …).
///
/// NOTE(review): the variants presumably mirror the `ContextItemId::from_*`
/// constructor families — confirm before relying on a 1:1 mapping.
/// Variant order is part of the serialized form for index-based serde formats;
/// append new variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextKind {
    File,
    Shell,
    Knowledge,
    Memory,
    Provider,
}
/// State tag of a context item, serialized in snake_case.
///
/// `Candidate` is the `Default` value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextState {
    #[default]
    Candidate,
    Included,
    Excluded,
    Pinned,
    Stale,
    Shadowed,
}
/// A way of rendering a context item, serialized in snake_case.
///
/// `ViewKind::as_str` / `ViewKind::parse` convert to and from the lowercase name, and
/// `ViewKind::density_rank` orders the variants from `Full` (0) to `Handle` (8).
/// Variant order is part of the serialized form for index-based serde formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ViewKind {
    Full,
    Signatures,
    Map,
    Diff,
    Aggressive,
    Entropy,
    Lines,
    Reference,
    Handle,
}
impl ViewKind {
    /// Lowercase name of the view; every value returned here is accepted by `parse`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Full => "full",
            Self::Aggressive => "aggressive",
            Self::Diff => "diff",
            Self::Lines => "lines",
            Self::Entropy => "entropy",
            Self::Signatures => "signatures",
            Self::Map => "map",
            Self::Reference => "reference",
            Self::Handle => "handle",
        }
    }
    /// Parses a view name case-insensitively, ignoring surrounding whitespace.
    ///
    /// Unrecognised names (and `"full"` itself) yield `ViewKind::Full`.
    pub fn parse(s: &str) -> Self {
        let name = s.trim().to_lowercase();
        match name.as_str() {
            "aggressive" => Self::Aggressive,
            "diff" => Self::Diff,
            "lines" => Self::Lines,
            "entropy" => Self::Entropy,
            "signatures" => Self::Signatures,
            "map" => Self::Map,
            "reference" => Self::Reference,
            "handle" => Self::Handle,
            // Fallback: anything else, including "full".
            _ => Self::Full,
        }
    }
    /// Position of the view on the density scale: `Full` is 0, `Handle` is 8.
    pub fn density_rank(&self) -> u8 {
        match *self {
            Self::Full => 0,
            Self::Signatures => 5,
            Self::Map => 6,
            Self::Diff => 2,
            Self::Aggressive => 1,
            Self::Entropy => 4,
            Self::Lines => 3,
            Self::Reference => 7,
            Self::Handle => 8,
        }
    }
}
/// Estimated token cost per view.
///
/// Views absent from `estimates` have no estimate; `ViewCosts::get` reports them as 0.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ViewCosts {
    pub estimates: HashMap<ViewKind, usize>,
}
impl ViewCosts {
    /// Creates an empty cost table.
    pub fn new() -> Self {
        Self::default()
    }
    /// Records (or overwrites) the token estimate for `view`.
    pub fn set(&mut self, view: ViewKind, tokens: usize) {
        self.estimates.insert(view, tokens);
    }
    /// Returns the token estimate for `view`, or 0 when none was recorded.
    pub fn get(&self, view: &ViewKind) -> usize {
        self.estimates.get(view).copied().unwrap_or(0)
    }
    /// Returns the cheapest view other than `ViewKind::Handle`, with its estimate.
    ///
    /// Ties on token count go to the view with the lower `density_rank`.
    /// `estimates` is a `HashMap` whose iteration order is arbitrary, so without this
    /// tie-break the result would vary between runs. Returns `None` when no non-handle
    /// estimate exists.
    pub fn cheapest_content_view(&self) -> Option<(ViewKind, usize)> {
        self.estimates
            .iter()
            .map(|(&view, &tokens)| (view, tokens))
            .filter(|&(view, _)| view != ViewKind::Handle)
            .min_by_key(|&(view, tokens)| (tokens, view.density_rank()))
    }
    /// Builds estimates from the full-view token count.
    ///
    /// The estimates are Full = `full_tokens`, Signatures = /5, Map = /8,
    /// Reference = /20 (integer division) and Handle = a flat 25. No estimate is
    /// recorded for Diff, Aggressive, Entropy or Lines.
    pub fn from_full_tokens(full_tokens: usize) -> Self {
        let mut vc = Self::new();
        vc.set(ViewKind::Full, full_tokens);
        vc.set(ViewKind::Signatures, full_tokens / 5);
        vc.set(ViewKind::Map, full_tokens / 8);
        vc.set(ViewKind::Reference, full_tokens / 20);
        vc.set(ViewKind::Handle, 25);
        vc
    }
}
/// Optional provenance metadata (tool, agent, client, timestamp); every field may be absent.
///
/// NOTE(review): `timestamp` is a free-form string — its format is not enforced here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Provenance {
    pub tool: Option<String>,
    pub agent_id: Option<String>,
    pub client_name: Option<String>,
    pub timestamp: Option<String>,
}
/// Weights applied to the signals by `ContextField::compute_phi`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldWeights {
    // Positive contributions to phi.
    pub w_relevance: f64,
    pub w_surprise: f64,
    pub w_graph: f64,
    pub w_history: f64,
    // Penalties: these two terms are subtracted from phi.
    pub w_cost: f64,
    pub w_redundancy: f64,
}
impl Default for FieldWeights {
    /// Default weighting: the four positive terms sum to 0.80 and the two penalty
    /// terms to 0.20.
    fn default() -> Self {
        // Penalty weights (subtracted in `ContextField::compute_phi`).
        let (w_cost, w_redundancy) = (0.10, 0.10);
        Self {
            w_relevance: 0.35,
            w_graph: 0.20,
            w_surprise: 0.15,
            w_history: 0.10,
            w_cost,
            w_redundancy,
        }
    }
}
/// Raw inputs to the phi potential.
///
/// When produced by `compute_signals_for_path`, every field lies in `[0.0, 1.0]`.
/// The `Default` value has all signals at 0.0.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FieldSignals {
    pub relevance: f64,
    pub surprise: f64,
    // 1 / (1 + distance) as computed by `normalize_graph_proximity`.
    pub graph_proximity: f64,
    pub history_signal: f64,
    // Token cost as a fraction of the total budget (see `normalize_token_cost`).
    pub token_cost_norm: f64,
    pub redundancy: f64,
}
/// Scoring result for one item, as produced by `ContextField::compute_batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldPotential {
    pub signals: FieldSignals,
    // Output of `ContextField::compute_phi` (clamped to [0, 1]).
    pub phi: f64,
    pub view_costs: ViewCosts,
    // Output of `ContextField::select_view` at the batch temperature.
    pub best_view: ViewKind,
}
/// A token allowance (`total`) and how much of it is already consumed (`used`).
#[derive(Debug, Clone, Copy)]
pub struct TokenBudget {
    pub total: usize,
    pub used: usize,
}
impl TokenBudget {
    /// Tokens still available; 0 once `used` reaches or exceeds `total`.
    pub fn remaining(&self) -> usize {
        let Self { total, used } = *self;
        total.saturating_sub(used)
    }
    /// Fraction of the budget consumed, `used / total`.
    ///
    /// A zero-sized budget counts as fully used (1.0). The value is not capped and
    /// exceeds 1.0 when `used > total`.
    pub fn utilization(&self) -> f64 {
        match self.total {
            0 => 1.0,
            total => self.used as f64 / total as f64,
        }
    }
    /// Selection temperature.
    ///
    /// It rises linearly from 0.1 (nothing used) to 2.0 (fully used) and is clamped
    /// to `[0.1, 2.0]` outside that range.
    pub fn temperature(&self) -> f64 {
        let scaled = 0.1 + 1.9 * self.utilization();
        scaled.clamp(0.1, 2.0)
    }
}
pub struct ContextField {
weights: FieldWeights,
}
impl Default for ContextField {
fn default() -> Self {
Self::new()
}
}
impl ContextField {
    /// Creates a field using [`FieldWeights::default`].
    pub fn new() -> Self {
        Self {
            weights: FieldWeights::default(),
        }
    }
    /// Creates a field with caller-supplied weights.
    pub fn with_weights(weights: FieldWeights) -> Self {
        Self { weights }
    }
    /// Computes the potential phi of one item.
    ///
    /// phi = w_relevance·relevance + w_surprise·surprise + w_graph·graph_proximity
    ///     + w_history·history_signal − w_cost·token_cost_norm − w_redundancy·redundancy
    ///
    /// The result is clamped to `[0.0, 1.0]`. A NaN signal still yields NaN, because
    /// `f64::clamp` propagates NaN.
    pub fn compute_phi(&self, signals: &FieldSignals) -> f64 {
        let w = &self.weights;
        let phi = w.w_relevance * signals.relevance
            + w.w_surprise * signals.surprise
            + w.w_graph * signals.graph_proximity
            + w.w_history * signals.history_signal
            - w.w_cost * signals.token_cost_norm
            - w.w_redundancy * signals.redundancy;
        phi.clamp(0.0, 1.0)
    }
    /// Picks the view with the best trade-off between density rank and token cost.
    ///
    /// Each estimated view is scored as `density_bonus * (2 - t) - normalized_cost * t`.
    /// - `density_bonus` falls linearly from 1.0 for `Full` (rank 0) to 0.0 for
    ///   `Handle` (the highest rank).
    /// - `normalized_cost` is the view's estimate divided by the largest estimate.
    ///
    /// Low temperatures therefore favour low-rank views such as `Full`, and high
    /// temperatures favour cheap views.
    ///
    /// `temperature` is limited to `[0.01, 2.0]`, and NaN is treated as 0.01. The
    /// upper bound matches `TokenBudget::temperature`. Above 2.0 the density term
    /// would change sign, and an infinite temperature would produce NaN scores.
    ///
    /// Equal scores go to the lower `density_rank`, so the result does not depend on
    /// `HashMap` iteration order. Returns `ViewKind::Full` when `costs` is empty.
    pub fn select_view(&self, costs: &ViewCosts, temperature: f64) -> ViewKind {
        if costs.estimates.is_empty() {
            return ViewKind::Full;
        }
        let t = if temperature.is_nan() {
            0.01
        } else {
            temperature.clamp(0.01, 2.0)
        };
        let max_cost = costs.estimates.values().copied().max().unwrap_or(1).max(1) as f64;
        // `Handle` carries the highest density rank, so its density bonus is exactly 0.
        let max_rank = f64::from(ViewKind::Handle.density_rank());
        let mut best_view = ViewKind::Full;
        let mut best_score = f64::NEG_INFINITY;
        for (&view, &tokens) in &costs.estimates {
            let normalized_cost = tokens as f64 / max_cost;
            let density_bonus = 1.0 - f64::from(view.density_rank()) / max_rank;
            let score = density_bonus * (2.0 - t) - normalized_cost * t;
            // Deterministic tie-break: prefer the lower density rank on equal scores.
            let wins_tie =
                score == best_score && view.density_rank() < best_view.density_rank();
            if score > best_score || wins_tie {
                best_score = score;
                best_view = view;
            }
        }
        best_view
    }
    /// Computes phi and the selected view for every item.
    ///
    /// All items share a single temperature derived from `budget`. If `items`
    /// contains the same id more than once, the last occurrence wins.
    pub fn compute_batch(
        &self,
        items: &[(ContextItemId, FieldSignals, ViewCosts)],
        budget: TokenBudget,
    ) -> HashMap<ContextItemId, FieldPotential> {
        let temperature = budget.temperature();
        let mut result = HashMap::with_capacity(items.len());
        for (id, signals, costs) in items {
            let potential = FieldPotential {
                signals: signals.clone(),
                phi: self.compute_phi(signals),
                view_costs: costs.clone(),
                best_view: self.select_view(costs, temperature),
            };
            result.insert(id.clone(), potential);
        }
        result
    }
}
/// Scales `score` by `max_score` into `[0.0, 1.0]`.
///
/// A non-positive `max_score` yields 0.0.
pub fn normalize_relevance(score: f64, max_score: f64) -> f64 {
    // Keep the `<=` test (not `> 0.0`) so a NaN `max_score` behaves as before.
    if max_score <= 0.0 {
        0.0
    } else {
        let ratio = score / max_score;
        ratio.clamp(0.0, 1.0)
    }
}
/// Maps a raw surprise value onto `[0.0, 1.0]`.
///
/// Values ≤ 5.0 map to 0, values ≥ 17.0 map to 1, and values in between map
/// linearly.
pub fn normalize_surprise(surprise: f64) -> f64 {
    const FLOOR: f64 = 5.0;
    const SPAN: f64 = 12.0;
    let shifted = surprise - FLOOR;
    (shifted / SPAN).clamp(0.0, 1.0)
}
/// Converts a distance into a proximity in `(0.0, 1.0]` as `1 / (1 + distance)`.
pub fn normalize_graph_proximity(distance: usize) -> f64 {
    let hops = distance as f64;
    (1.0 + hops).recip()
}
/// Token cost as a fraction of the total budget, clamped to `[0.0, 1.0]`.
///
/// A zero budget makes every item maximally expensive (1.0).
pub fn normalize_token_cost(tokens: usize, budget_total: usize) -> f64 {
    match budget_total {
        0 => 1.0,
        total => (tokens as f64 / total as f64).clamp(0.0, 1.0),
    }
}
/// Potential per token (`phi / tokens`); a zero-token item returns `phi` unchanged.
pub fn efficiency(phi: f64, tokens: usize) -> f64 {
    match tokens {
        0 => phi,
        n => phi / n as f64,
    }
}
pub fn compute_signals_for_path(
path: &str,
task: Option<&str>,
file_content: Option<&str>,
budget_total: usize,
full_tokens: usize,
) -> (FieldSignals, ViewCosts) {
let mut signals = FieldSignals::default();
let heatmap = super::heatmap::HeatMap::load();
let heat_entry = heatmap.entries.get(path);
if let Some(task_desc) = task {
let (_, keywords) = super::task_relevance::parse_task_hints(task_desc);
let path_lower = path.to_lowercase();
let keyword_hits = keywords
.iter()
.filter(|kw| path_lower.contains(&kw.to_lowercase()))
.count();
let keyword_score = (keyword_hits as f64 * 0.3).min(1.0);
let freq_score = heat_entry.map_or(0.0, |e| (e.access_count as f64 / 10.0).min(1.0));
signals.relevance = normalize_relevance(keyword_score + freq_score, 2.0);
} else {
let freq = heat_entry.map_or(0.0, |e| e.access_count as f64);
signals.relevance = normalize_relevance(freq, 10.0);
}
if let Some(content) = file_content {
let surprise_val = super::surprise::line_surprise(content);
signals.surprise = normalize_surprise(surprise_val);
}
let depth = path.matches('/').count();
signals.graph_proximity = normalize_graph_proximity(depth);
let access_count = heat_entry.map_or(0, |e| e.access_count);
signals.history_signal = (access_count as f64 / 20.0).min(1.0);
signals.token_cost_norm = normalize_token_cost(full_tokens, budget_total);
signals.redundancy = 0.0;
let view_costs = ViewCosts::from_full_tokens(full_tokens);
(signals, view_costs)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the scoring helpers in the parent module.
    use super::*;
    // Raising relevance alone (all other signals zero) must raise phi.
    #[test]
    fn phi_increases_with_relevance() {
        let field = ContextField::new();
        let low = field.compute_phi(&FieldSignals {
            relevance: 0.2,
            ..Default::default()
        });
        let high = field.compute_phi(&FieldSignals {
            relevance: 0.9,
            ..Default::default()
        });
        assert!(high > low, "higher relevance should yield higher phi");
    }
    // token_cost_norm is weighted by w_cost and subtracted from phi.
    #[test]
    fn phi_decreases_with_cost() {
        let field = ContextField::new();
        let cheap = field.compute_phi(&FieldSignals {
            relevance: 0.5,
            token_cost_norm: 0.1,
            ..Default::default()
        });
        let expensive = field.compute_phi(&FieldSignals {
            relevance: 0.5,
            token_cost_norm: 0.9,
            ..Default::default()
        });
        assert!(cheap > expensive, "higher cost should reduce phi");
    }
    // redundancy is weighted by w_redundancy and subtracted from phi.
    #[test]
    fn phi_decreases_with_redundancy() {
        let field = ContextField::new();
        let unique = field.compute_phi(&FieldSignals {
            relevance: 0.5,
            redundancy: 0.0,
            ..Default::default()
        });
        let redundant = field.compute_phi(&FieldSignals {
            relevance: 0.5,
            redundancy: 0.9,
            ..Default::default()
        });
        assert!(unique > redundant, "redundancy should reduce phi");
    }
    // With every positive signal at 1.0 and no penalties, phi must stay in [0, 1].
    #[test]
    fn phi_is_clamped_to_unit_interval() {
        let field = ContextField::new();
        let phi = field.compute_phi(&FieldSignals {
            relevance: 1.0,
            surprise: 1.0,
            graph_proximity: 1.0,
            history_signal: 1.0,
            token_cost_norm: 0.0,
            redundancy: 0.0,
        });
        assert!(phi <= 1.0);
        assert!(phi >= 0.0);
    }
    // At t = 0.1 the density term dominates, so Full (density rank 0) wins.
    #[test]
    fn view_selection_prefers_dense_at_low_temperature() {
        let field = ContextField::new();
        let costs = ViewCosts::from_full_tokens(5000);
        let view = field.select_view(&costs, 0.1);
        assert_eq!(
            view,
            ViewKind::Full,
            "low temperature (relaxed budget) should prefer full view"
        );
    }
    // At t = 2.0 the density term vanishes and only cost matters.
    #[test]
    fn view_selection_prefers_sparse_at_high_temperature() {
        let field = ContextField::new();
        let costs = ViewCosts::from_full_tokens(5000);
        let view = field.select_view(&costs, 2.0);
        assert_ne!(
            view,
            ViewKind::Full,
            "high temperature (tight budget) should prefer sparser view"
        );
    }
    // Temperature grows with used / total.
    #[test]
    fn budget_temperature_scales_with_utilization() {
        let low = TokenBudget {
            total: 10000,
            used: 1000,
        };
        let high = TokenBudget {
            total: 10000,
            used: 9000,
        };
        assert!(
            high.temperature() > low.temperature(),
            "higher utilization should increase temperature"
        );
    }
    // 5.0 maps to 0, 17.0 maps to 1, linear in between.
    #[test]
    fn normalize_surprise_maps_range() {
        assert!((normalize_surprise(5.0) - 0.0).abs() < 0.01);
        assert!((normalize_surprise(17.0) - 1.0).abs() < 0.01);
        assert!((normalize_surprise(11.0) - 0.5).abs() < 0.01);
    }
    // Proximity is 1 / (1 + distance).
    #[test]
    fn normalize_graph_proximity_inverse_distance() {
        assert!((normalize_graph_proximity(0) - 1.0).abs() < f64::EPSILON);
        assert!((normalize_graph_proximity(1) - 0.5).abs() < f64::EPSILON);
        assert!(normalize_graph_proximity(10) < 0.15);
    }
    // efficiency = phi / tokens for a non-zero token count.
    #[test]
    fn efficiency_ratio_is_phi_per_token() {
        let e = efficiency(0.8, 400);
        assert!((e - 0.002).abs() < 0.0001);
    }
    // Same path, same identifier.
    #[test]
    fn context_item_id_stable() {
        let a = ContextItemId::from_file("src/main.rs");
        let b = ContextItemId::from_file("src/main.rs");
        assert_eq!(a, b);
    }
    // Fixed ratios: Signatures = full / 5, Map = full / 8, Handle = 25.
    #[test]
    fn view_costs_from_full() {
        let vc = ViewCosts::from_full_tokens(5000);
        assert_eq!(vc.get(&ViewKind::Full), 5000);
        assert_eq!(vc.get(&ViewKind::Signatures), 1000);
        assert_eq!(vc.get(&ViewKind::Map), 625);
        assert_eq!(vc.get(&ViewKind::Handle), 25);
    }
    // Each distinct input id yields exactly one result entry.
    #[test]
    fn batch_compute_produces_results_for_all_items() {
        let field = ContextField::new();
        let items = vec![
            (
                ContextItemId::from_file("a.rs"),
                FieldSignals {
                    relevance: 0.8,
                    ..Default::default()
                },
                ViewCosts::from_full_tokens(2000),
            ),
            (
                ContextItemId::from_file("b.rs"),
                FieldSignals {
                    relevance: 0.3,
                    ..Default::default()
                },
                ViewCosts::from_full_tokens(500),
            ),
        ];
        let budget = TokenBudget {
            total: 10000,
            used: 2000,
        };
        let results = field.compute_batch(&items, budget);
        assert_eq!(results.len(), 2);
        assert!(results.contains_key(&ContextItemId::from_file("a.rs")));
        assert!(results.contains_key(&ContextItemId::from_file("b.rs")));
    }
}