1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use super::context_field::{ContextItemId, ContextKind, ViewCosts, ViewKind};
12
/// Lightweight reference to a context item, as recorded by `HandleRegistry::register`.
///
/// A handle carries enough information to list the item in a manifest and to
/// look it up again by label or by item id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextHandle {
    /// Short label made of a one-letter kind prefix plus a 1-based per-kind
    /// counter (e.g. `F1`, `S2`). Stored without a leading `@`.
    pub ref_label: String,
    /// Identifier of the underlying context item; matched by
    /// `HandleRegistry::resolve_by_item`.
    pub item_id: ContextItemId,
    /// Kind of the underlying item; selects the label prefix and counter.
    pub kind: ContextKind,
    /// Source path or name supplied at registration; printed in the manifest.
    pub source_path: String,
    /// Summary text supplied at registration.
    pub summary: String,
    /// Token cost attributed to the handle itself: the `ViewKind::Handle`
    /// estimate when one was supplied, otherwise a length-based heuristic.
    pub handle_tokens: usize,
    /// `(view, token estimate)` pairs for every supplied view except
    /// `ViewKind::Handle`, sorted by ascending `ViewKind::density_rank`.
    pub available_views: Vec<(ViewKind, usize)>,
    /// Score supplied at registration and printed in the manifest as
    /// `phi=…`. NOTE(review): its semantics are defined outside this file.
    pub phi: f64,
    /// Whether the manifest tags this handle with `[pinned]`.
    pub pinned: bool,
}
33
/// Ordered collection of [`ContextHandle`]s that also issues their
/// sequential, per-kind reference labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandleRegistry {
    /// Registered handles, in registration order.
    handles: Vec<ContextHandle>,
    /// Number of labels issued so far for each kind; the next label for a
    /// kind uses this value plus one.
    counters: HashMap<ContextKind, usize>,
}
40
41fn kind_prefix(kind: &ContextKind) -> &'static str {
46 match kind {
47 ContextKind::File => "F",
48 ContextKind::Shell => "S",
49 ContextKind::Knowledge => "K",
50 ContextKind::Memory => "M",
51 ContextKind::Provider => "P",
52 }
53}
54
55impl HandleRegistry {
60 pub fn new() -> Self {
61 Self {
62 handles: Vec::new(),
63 counters: HashMap::new(),
64 }
65 }
66
67 pub fn register(
71 &mut self,
72 item_id: ContextItemId,
73 kind: ContextKind,
74 source_path: &str,
75 summary: &str,
76 view_costs: &ViewCosts,
77 phi: f64,
78 pinned: bool,
79 ) -> &ContextHandle {
80 let counter = self.counters.entry(kind).or_insert(0);
81 *counter += 1;
82 let ref_label = format!("{}{}", kind_prefix(&kind), counter);
83
84 let available_views: Vec<(ViewKind, usize)> = {
85 let mut views: Vec<_> = view_costs
86 .estimates
87 .iter()
88 .filter(|(v, _)| **v != ViewKind::Handle)
89 .map(|(&v, &tokens)| (v, tokens))
90 .collect();
91 views.sort_by_key(|(v, _)| v.density_rank());
92 views
93 };
94
95 let handle_tokens = view_costs
96 .estimates
97 .get(&ViewKind::Handle)
98 .copied()
99 .unwrap_or_else(|| estimate_handle_tokens(source_path, summary));
100
101 let handle = ContextHandle {
102 ref_label,
103 item_id,
104 kind,
105 source_path: source_path.to_string(),
106 summary: summary.to_string(),
107 handle_tokens,
108 available_views,
109 phi,
110 pinned,
111 };
112
113 self.handles.push(handle);
114 self.handles.last().expect("just pushed")
115 }
116
117 pub fn resolve(&self, ref_label: &str) -> Option<&ContextHandle> {
121 let label = ref_label.strip_prefix('@').unwrap_or(ref_label);
122 self.handles.iter().find(|h| h.ref_label == label)
123 }
124
125 pub fn resolve_by_item(&self, item_id: &ContextItemId) -> Option<&ContextHandle> {
127 self.handles.iter().find(|h| h.item_id == *item_id)
128 }
129
130 pub fn all(&self) -> &[ContextHandle] {
132 &self.handles
133 }
134
135 pub fn total_handle_tokens(&self) -> usize {
137 self.handles.iter().map(|h| h.handle_tokens).sum()
138 }
139
140 pub fn format_manifest(&self, budget_total: usize, budget_used: usize) -> String {
142 if self.handles.is_empty() {
143 return String::new();
144 }
145
146 let mut lines = Vec::with_capacity(self.handles.len() + 3);
147 lines.push("Context Handles (expand with ctx_expand @ref):".to_string());
148
149 for h in &self.handles {
150 let best = h
151 .available_views
152 .first()
153 .map_or("full", |(v, _)| v.as_str());
154
155 let cheapest_tokens = h.available_views.iter().map(|(_, t)| *t).min().unwrap_or(0);
156
157 let pinned_tag = if h.pinned { " [pinned]" } else { "" };
158
159 lines.push(format!(
160 "@{} {} {} {}t phi={:.2}{}",
161 h.ref_label, h.source_path, best, cheapest_tokens, h.phi, pinned_tag,
162 ));
163 }
164
165 let remaining_pct = if budget_total > 0 {
166 ((budget_total.saturating_sub(budget_used)) as f64 / budget_total as f64) * 100.0
167 } else {
168 0.0
169 };
170
171 lines.push(String::new());
172 lines.push(format!(
173 "Budget: {budget_used}/{budget_total} tokens ({remaining_pct:.1}% remaining)",
174 ));
175
176 lines.join("\n")
177 }
178}
179
180impl Default for HandleRegistry {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
/// Heuristic token cost of a handle when no explicit estimate is available.
///
/// Adds the byte lengths of `source_path` and `summary` to a fixed
/// allowance of 20, divides by 4 (integer division), and clamps the result
/// to the inclusive range `5..=50`.
///
/// Lengths are measured in UTF-8 bytes, not characters, so non-ASCII text
/// counts for more than its visible length.
fn estimate_handle_tokens(source_path: &str, summary: &str) -> usize {
    let text_bytes = source_path.len() + summary.len();
    let approx_tokens = (text_bytes + 20) / 4;
    approx_tokens.clamp(5, 50)
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds view costs derived from the token size of the full view.
    fn sample_view_costs(full_tokens: usize) -> ViewCosts {
        ViewCosts::from_full_tokens(full_tokens)
    }

    /// Registers an unpinned handle whose view costs come from
    /// `sample_view_costs(full_tokens)`.
    fn add<'r>(
        reg: &'r mut HandleRegistry,
        item_id: ContextItemId,
        kind: ContextKind,
        source_path: &str,
        summary: &str,
        full_tokens: usize,
        phi: f64,
    ) -> &'r ContextHandle {
        reg.register(
            item_id,
            kind,
            source_path,
            summary,
            &sample_view_costs(full_tokens),
            phi,
            false,
        )
    }

    /// Registers an unpinned file handle whose item id and source path are
    /// both derived from `path`.
    fn add_file<'r>(
        reg: &'r mut HandleRegistry,
        path: &str,
        summary: &str,
        full_tokens: usize,
        phi: f64,
    ) -> &'r ContextHandle {
        add(
            reg,
            ContextItemId::from_file(path),
            ContextKind::File,
            path,
            summary,
            full_tokens,
            phi,
        )
    }

    #[test]
    fn label_generation_sequential_per_kind() {
        let mut reg = HandleRegistry::new();

        let first_file = add_file(&mut reg, "a.ts", "module A", 1000, 0.9);
        assert_eq!(first_file.ref_label, "F1");

        let second_file = add_file(&mut reg, "b.ts", "module B", 500, 0.8);
        assert_eq!(second_file.ref_label, "F2");

        let shell = add(
            &mut reg,
            ContextItemId::from_shell("pytest"),
            ContextKind::Shell,
            "pytest_latest",
            "test run output",
            2000,
            0.7,
        );
        assert_eq!(shell.ref_label, "S1");

        let knowledge = reg.register(
            ContextItemId::from_knowledge("domain", "billing"),
            ContextKind::Knowledge,
            "billing_rules",
            "annual billing assumption",
            &sample_view_costs(100),
            0.95,
            true,
        );
        assert_eq!(knowledge.ref_label, "K1");
    }

    #[test]
    fn resolve_by_ref_label() {
        let mut reg = HandleRegistry::new();
        add_file(&mut reg, "x.rs", "file X", 400, 0.85);
        add(
            &mut reg,
            ContextItemId::from_shell("cargo test"),
            ContextKind::Shell,
            "cargo_test",
            "test output",
            800,
            0.6,
        );

        // Bare label.
        let file = reg.resolve("F1");
        assert!(file.is_some());
        assert_eq!(file.unwrap().source_path, "x.rs");

        // Label with the `@` sigil.
        let shell = reg.resolve("@S1");
        assert!(shell.is_some());
        assert_eq!(shell.unwrap().source_path, "cargo_test");

        // Unknown label.
        assert!(reg.resolve("F99").is_none());
    }

    #[test]
    fn resolve_by_item_id() {
        let mut reg = HandleRegistry::new();
        let id = ContextItemId::from_file("main.rs");
        reg.register(
            id.clone(),
            ContextKind::File,
            "main.rs",
            "entrypoint",
            &sample_view_costs(600),
            0.92,
            false,
        );

        let found = reg.resolve_by_item(&id);
        assert!(found.is_some());
        assert_eq!(found.unwrap().ref_label, "F1");

        let unknown = ContextItemId::from_file("nope.rs");
        assert!(reg.resolve_by_item(&unknown).is_none());
    }

    #[test]
    fn manifest_formatting() {
        let mut reg = HandleRegistry::new();
        add_file(
            &mut reg,
            "billing/service.ts",
            "exports: createInvoice, calculateTax",
            2000,
            0.93,
        );
        reg.register(
            ContextItemId::from_knowledge("domain", "annual"),
            ContextKind::Knowledge,
            "annual_billing",
            "assumption",
            &sample_view_costs(200),
            0.95,
            true,
        );

        let manifest = reg.format_manifest(12000, 2460);

        let expected_fragments = [
            "Context Handles",
            "@F1",
            "billing/service.ts",
            "phi=0.93",
            "@K1",
            "[pinned]",
            "Budget: 2460/12000 tokens",
            "remaining",
        ];
        for fragment in expected_fragments {
            assert!(manifest.contains(fragment));
        }
    }

    #[test]
    fn manifest_empty_registry() {
        let manifest = HandleRegistry::new().format_manifest(10000, 0);
        assert!(manifest.is_empty());
    }

    #[test]
    fn total_handle_tokens() {
        let mut reg = HandleRegistry::new();
        add_file(&mut reg, "a.rs", "mod A", 1000, 0.8);
        add_file(&mut reg, "b.rs", "mod B", 2000, 0.7);

        assert_eq!(
            reg.total_handle_tokens(),
            25 + 25,
            "both handles should use ViewKind::Handle cost (25)"
        );
    }

    #[test]
    fn multiple_registrations_sequential() {
        let mut reg = HandleRegistry::new();
        for i in 1..=5 {
            let path = format!("file_{i}.rs");
            add_file(&mut reg, &path, "some module", 500, 0.5);
        }

        let handles = reg.all();
        assert_eq!(handles.len(), 5);
        for (index, handle) in handles.iter().enumerate() {
            assert_eq!(handle.ref_label, format!("F{}", index + 1));
        }
    }

    #[test]
    fn mixed_kinds_independent_counters() {
        let mut reg = HandleRegistry::new();

        add_file(&mut reg, "a.rs", "file", 100, 0.5);
        add(
            &mut reg,
            ContextItemId::from_shell("ls"),
            ContextKind::Shell,
            "ls",
            "listing",
            100,
            0.5,
        );
        add_file(&mut reg, "b.rs", "file", 100, 0.5);
        add(
            &mut reg,
            ContextItemId::from_memory("session"),
            ContextKind::Memory,
            "session_state",
            "memory",
            100,
            0.5,
        );
        add(
            &mut reg,
            ContextItemId::from_provider("github", "pr"),
            ContextKind::Provider,
            "github/pr/123",
            "pull request",
            100,
            0.5,
        );

        let expected = [
            ("F1", "a.rs"),
            ("S1", "ls"),
            ("F2", "b.rs"),
            ("M1", "session_state"),
            ("P1", "github/pr/123"),
        ];
        for (label, path) in expected {
            assert_eq!(reg.resolve(label).unwrap().source_path, path);
        }
    }

    #[test]
    fn available_views_sorted_by_density() {
        let mut reg = HandleRegistry::new();
        let handle = add_file(&mut reg, "c.rs", "module C", 4000, 0.9);

        let ranks: Vec<u8> = handle
            .available_views
            .iter()
            .map(|(view, _)| view.density_rank())
            .collect();

        assert!(
            ranks.windows(2).all(|pair| pair[0] <= pair[1]),
            "views should be sorted by density rank (dense first)"
        );
    }

    #[test]
    fn handle_tokens_fallback_without_handle_view() {
        // No `ViewKind::Handle` entry, so the length-based estimate is used.
        let mut costs = ViewCosts::new();
        costs.set(ViewKind::Full, 5000);
        costs.set(ViewKind::Signatures, 1000);

        let mut reg = HandleRegistry::new();
        let handle = reg.register(
            ContextItemId::from_file("big.rs"),
            ContextKind::File,
            "src/core/big_module.rs",
            "large module with many exports",
            &costs,
            0.88,
            false,
        );

        let estimated = handle.handle_tokens;
        assert!(estimated >= 5, "fallback should produce at least 5 tokens");
        assert!(estimated <= 50, "fallback should produce at most 50 tokens");
    }

    #[test]
    fn budget_remaining_percentage() {
        let mut reg = HandleRegistry::new();
        add_file(&mut reg, "x.rs", "x", 100, 0.5);

        let manifest = reg.format_manifest(10000, 2000);
        assert!(manifest.contains("80.0% remaining"));
    }
}