1use serde::{Deserialize, Serialize};
12
13use super::context_field::{ContextState, ViewKind};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ContextPolicy {
18 pub name: String,
19 #[serde(rename = "match")]
20 pub match_pattern: String,
21 pub action: PolicyAction,
22 #[serde(default)]
23 pub condition: Option<PolicyCondition>,
24 #[serde(default)]
25 pub reason: Option<String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30pub enum PolicyAction {
31 Exclude,
32 Include,
33 Pin,
34 SetView { view: String },
35 MaxTokens { limit: usize },
36 MarkOutdated,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum PolicyCondition {
42 SourceSeenBefore,
43 SourceModifiedRecently,
44 TokensAbove { threshold: usize },
45 Always,
46}
47
48#[derive(Debug, Clone, Default, Serialize, Deserialize)]
50pub struct PolicySet {
51 pub policies: Vec<ContextPolicy>,
52}
53
54impl PolicySet {
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn defaults() -> Self {
61 Self {
62 policies: vec![
63 ContextPolicy {
64 name: "never_include_secrets".to_string(),
65 match_pattern: "**/.env*".to_string(),
66 action: PolicyAction::Exclude,
67 condition: None,
68 reason: Some("secrets".to_string()),
69 },
70 ContextPolicy {
71 name: "exclude_private_keys".to_string(),
72 match_pattern: "**/*private_key*".to_string(),
73 action: PolicyAction::Exclude,
74 condition: None,
75 reason: Some("private key material".to_string()),
76 },
77 ContextPolicy {
78 name: "exclude_credentials".to_string(),
79 match_pattern: "**/credentials*".to_string(),
80 action: PolicyAction::Exclude,
81 condition: None,
82 reason: Some("credentials".to_string()),
83 },
84 ContextPolicy {
85 name: "delta_after_first_read".to_string(),
86 match_pattern: "src/**".to_string(),
87 action: PolicyAction::SetView {
88 view: "diff".to_string(),
89 },
90 condition: Some(PolicyCondition::SourceSeenBefore),
91 reason: Some("predictive coding: only send prediction errors".to_string()),
92 },
93 ContextPolicy {
94 name: "compress_large_files".to_string(),
95 match_pattern: "**/*".to_string(),
96 action: PolicyAction::SetView {
97 view: "signatures".to_string(),
98 },
99 condition: Some(PolicyCondition::TokensAbove { threshold: 8000 }),
100 reason: Some("large file budget protection".to_string()),
101 },
102 ],
103 }
104 }
105
106 pub fn evaluate(
108 &self,
109 path: &str,
110 seen_before: bool,
111 token_count: usize,
112 ) -> Vec<PolicyEvalResult> {
113 let mut results = Vec::new();
114 for policy in &self.policies {
115 if !path_matches(&policy.match_pattern, path) {
116 continue;
117 }
118 if let Some(ref condition) = policy.condition {
119 if !check_condition(condition, seen_before, token_count) {
120 continue;
121 }
122 }
123 results.push(PolicyEvalResult {
124 policy_name: policy.name.clone(),
125 action: policy.action.clone(),
126 reason: policy.reason.clone().unwrap_or_else(|| policy.name.clone()),
127 });
128 }
129 results
130 }
131
132 pub fn effective_state(
134 &self,
135 path: &str,
136 current: ContextState,
137 seen_before: bool,
138 token_count: usize,
139 ) -> ContextState {
140 let evals = self.evaluate(path, seen_before, token_count);
141 let mut state = current;
142 for eval in &evals {
143 match &eval.action {
144 PolicyAction::Exclude => state = ContextState::Excluded,
145 PolicyAction::Pin => state = ContextState::Pinned,
146 PolicyAction::Include => {
147 if state == ContextState::Candidate {
148 state = ContextState::Included;
149 }
150 }
151 PolicyAction::MarkOutdated => state = ContextState::Stale,
152 PolicyAction::MaxTokens { limit } => {
153 if token_count > *limit {
154 state = ContextState::Excluded;
155 }
156 }
157 PolicyAction::SetView { .. } => {}
158 }
159 }
160 state
161 }
162
163 pub fn recommended_view(
165 &self,
166 path: &str,
167 seen_before: bool,
168 token_count: usize,
169 ) -> Option<ViewKind> {
170 let evals = self.evaluate(path, seen_before, token_count);
171 for eval in evals.iter().rev() {
172 if let PolicyAction::SetView { view } = &eval.action {
173 return Some(ViewKind::parse(view));
174 }
175 }
176 None
177 }
178
179 pub fn load_project(project_root: &std::path::Path) -> Self {
181 let path = project_root.join(".lean-ctx").join("policies.json");
182 std::fs::read_to_string(&path)
183 .ok()
184 .and_then(|s| serde_json::from_str(&s).ok())
185 .unwrap_or_else(Self::defaults)
186 }
187
188 pub fn save_project(&self, project_root: &std::path::Path) -> Result<(), String> {
190 let dir = project_root.join(".lean-ctx");
191 std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
192 let path = dir.join("policies.json");
193 let json = serde_json::to_string_pretty(self).map_err(|e| e.to_string())?;
194 crate::config_io::write_atomic(&path, &json)
195 }
196}
197
198#[derive(Debug, Clone)]
199pub struct PolicyEvalResult {
200 pub policy_name: String,
201 pub action: PolicyAction,
202 pub reason: String,
203}
204
/// Tests `path` against a simplified glob `pattern`.
///
/// This is intentionally a loose matcher, not a full glob engine. Rules are
/// tried in this order and the first applicable one decides:
///
/// 1. `**/*` matches every path.
/// 2. `**/<rest>`: all `*` are removed from `<rest>` and the path matches if
///    it *contains* what is left (substring match, deliberately permissive).
/// 3. `<dir>/**`: the path must be `<dir>` itself or start with `<dir>`
///    followed by a path separator (`/` or `\`). An empty `<dir>` matches
///    everything.
/// 4. A single `**` in the middle: the path must start with the text before
///    it and end with the text after it.
/// 5. A trailing `*`: the path must start with the text before it.
/// 6. Otherwise the path must end with the pattern (which includes equality).
fn path_matches(pattern: &str, path: &str) -> bool {
    if pattern == "**/*" {
        return true;
    }

    if let Some(rest) = pattern.strip_prefix("**/") {
        // `contains` also covers the "ends with" case, so one check suffices.
        if rest.contains('*') {
            let needle = rest.replace('*', "");
            return path.contains(&needle);
        }
        return path.contains(rest);
    }

    if let Some(dir) = pattern.strip_suffix("/**") {
        if dir.is_empty() {
            return true;
        }
        // Require a separator boundary after the directory prefix so that
        // `src/**` does not match siblings such as `src_old/main.rs`.
        return match path.strip_prefix(dir) {
            Some(tail) => matches!(tail.chars().next(), None | Some('/') | Some('\\')),
            None => false,
        };
    }

    // Exactly one `**` somewhere inside the pattern.
    if let Some((head, tail)) = pattern.split_once("**") {
        if !tail.contains("**") {
            return path.starts_with(head) && path.ends_with(tail);
        }
    }

    if let Some(prefix) = pattern.strip_suffix('*') {
        return path.starts_with(prefix);
    }

    // `ends_with` is true for equal strings as well.
    path.ends_with(pattern)
}
235
236fn check_condition(condition: &PolicyCondition, seen_before: bool, token_count: usize) -> bool {
237 match condition {
238 PolicyCondition::SourceSeenBefore => seen_before,
239 PolicyCondition::TokensAbove { threshold } => token_count > *threshold,
240 PolicyCondition::SourceModifiedRecently | PolicyCondition::Always => true,
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// True when any result requests the named view.
    fn sets_view(results: &[PolicyEvalResult], wanted: &str) -> bool {
        results
            .iter()
            .any(|r| matches!(&r.action, PolicyAction::SetView { view } if view == wanted))
    }

    /// True when any result is an `Exclude` action.
    fn excludes(results: &[PolicyEvalResult]) -> bool {
        results
            .iter()
            .any(|r| matches!(r.action, PolicyAction::Exclude))
    }

    #[test]
    fn default_policies_exclude_env_files() {
        let ps = PolicySet::defaults();
        let results = ps.evaluate(".env", false, 100);
        assert!(excludes(&results), "should exclude .env files");
    }

    #[test]
    fn default_policies_exclude_private_keys() {
        let ps = PolicySet::defaults();
        let results = ps.evaluate("secrets/private_key.pem", false, 100);
        assert!(excludes(&results), "should exclude private key files");
    }

    #[test]
    fn delta_policy_only_when_seen_before() {
        let ps = PolicySet::defaults();
        let first = ps.evaluate("src/main.rs", false, 500);
        let second = ps.evaluate("src/main.rs", true, 500);
        assert!(
            !sets_view(&first, "diff"),
            "should NOT suggest diff on first read"
        );
        assert!(
            sets_view(&second, "diff"),
            "should suggest diff on subsequent read"
        );
    }

    #[test]
    fn large_file_policy_triggers_above_threshold() {
        let ps = PolicySet::defaults();
        let small = ps.evaluate("src/main.rs", false, 500);
        let large = ps.evaluate("src/main.rs", false, 10000);
        assert!(!sets_view(&small, "signatures"));
        assert!(sets_view(&large, "signatures"));
    }

    #[test]
    fn effective_state_excludes_secrets() {
        let ps = PolicySet::defaults();
        let state = ps.effective_state(".env.local", ContextState::Candidate, false, 100);
        assert_eq!(state, ContextState::Excluded);
    }

    #[test]
    fn recommended_view_for_seen_file() {
        let ps = PolicySet::defaults();
        let view = ps.recommended_view("src/main.rs", true, 500);
        assert_eq!(view, Some(ViewKind::Diff));
    }

    #[test]
    fn recommended_view_none_for_new_file() {
        let ps = PolicySet::defaults();
        let view = ps.recommended_view("src/main.rs", false, 500);
        // An unseen, small file triggers neither the diff rule nor the
        // large-file rule, so no view may be recommended at all. Accepting
        // `Diff` here would hide exactly the regression this test is for.
        assert!(
            view.is_none(),
            "no view should be recommended for an unseen small file"
        );
    }

    #[test]
    fn path_matches_glob_patterns() {
        assert!(path_matches("**/.env*", ".env"));
        assert!(path_matches("**/.env*", ".env.local"));
        assert!(path_matches("**/.env*", "config/.env.prod"));
        assert!(path_matches("src/**", "src/main.rs"));
        assert!(path_matches("src/**", "src/core/mod.rs"));
        assert!(path_matches("**/*", "anything.txt"));
        assert!(!path_matches("src/**", "tests/test.rs"));
    }

    #[test]
    fn empty_policy_set_changes_nothing() {
        let ps = PolicySet::new();
        let state = ps.effective_state("src/main.rs", ContextState::Included, false, 100);
        assert_eq!(state, ContextState::Included);
    }

    #[test]
    fn custom_policy_works() {
        let ps = PolicySet {
            policies: vec![ContextPolicy {
                name: "pin_readme".to_string(),
                match_pattern: "README.md".to_string(),
                action: PolicyAction::Pin,
                condition: None,
                reason: None,
            }],
        };
        let state = ps.effective_state("README.md", ContextState::Candidate, false, 100);
        assert_eq!(state, ContextState::Pinned);
    }
}