1use crate::budget::ContextBudget;
11use crate::segment::{ContextPriority, ContextSegment, ContextSegmentType};
12use crate::token_counter::TokenCounter;
13use crate::window::ContextWindow;
14use chrono::{DateTime, Utc};
15use enact_core::kernel::ExecutionId;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19
/// Process-wide counter backing [`next_sequence`]; the first value handed out is 1000.
// NOTE(review): nothing visible in this file calls `next_sequence`, hence the
// `dead_code` allowance — confirm whether the counter is still needed.
#[allow(dead_code)]
static SEGMENT_SEQUENCE: AtomicU64 = AtomicU64::new(1000);

/// Hands out the next sequence number.
///
/// Returns the counter's current value and atomically advances it by one, so
/// consecutive calls yield consecutive numbers.
#[allow(dead_code)]
fn next_sequence() -> u64 {
    const STEP: u64 = 1;
    SEGMENT_SEQUENCE.fetch_add(STEP, Ordering::SeqCst)
}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct CalibrationConfig {
33 pub max_tokens: usize,
35
36 pub response_reserve: usize,
38
39 pub min_priority: ContextPriority,
41
42 pub include_system: bool,
44
45 pub include_history: bool,
47
48 pub max_history_messages: usize,
50
51 pub include_working_memory: bool,
53
54 pub include_rag: bool,
56
57 pub max_rag_chunks: usize,
59
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub segment_filters: Option<HashMap<String, bool>>,
63}
64
65impl Default for CalibrationConfig {
66 fn default() -> Self {
67 Self {
68 max_tokens: 8000,
69 response_reserve: 2000,
70 min_priority: ContextPriority::Low,
71 include_system: true,
72 include_history: true,
73 max_history_messages: 20,
74 include_working_memory: true,
75 include_rag: true,
76 max_rag_chunks: 5,
77 segment_filters: None,
78 }
79 }
80}
81
82impl CalibrationConfig {
83 pub fn minimal() -> Self {
85 Self {
86 max_tokens: 4000,
87 response_reserve: 1000,
88 min_priority: ContextPriority::High,
89 include_system: true,
90 include_history: false,
91 max_history_messages: 0,
92 include_working_memory: false,
93 include_rag: false,
94 max_rag_chunks: 0,
95 segment_filters: None,
96 }
97 }
98
99 pub fn full_context() -> Self {
101 Self {
102 max_tokens: 32000,
103 response_reserve: 4000,
104 min_priority: ContextPriority::Low,
105 include_system: true,
106 include_history: true,
107 max_history_messages: 50,
108 include_working_memory: true,
109 include_rag: true,
110 max_rag_chunks: 10,
111 segment_filters: None,
112 }
113 }
114
115 pub fn available_tokens(&self) -> usize {
117 self.max_tokens.saturating_sub(self.response_reserve)
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123#[serde(rename_all = "camelCase")]
124pub struct CalibratedPrompt {
125 pub execution_id: ExecutionId,
127
128 pub segments: Vec<ContextSegment>,
130
131 pub total_tokens: usize,
133
134 pub response_tokens: usize,
136
137 pub excluded_count: usize,
139
140 pub calibrated_at: DateTime<Utc>,
142
143 pub config: CalibrationConfig,
145}
146
147impl CalibratedPrompt {
148 pub fn as_text(&self) -> String {
150 self.segments
151 .iter()
152 .map(|s| s.content.clone())
153 .collect::<Vec<_>>()
154 .join("\n\n")
155 }
156
157 pub fn segments_by_type(&self, segment_type: ContextSegmentType) -> Vec<&ContextSegment> {
159 self.segments
160 .iter()
161 .filter(|s| s.segment_type == segment_type)
162 .collect()
163 }
164
165 pub fn has_system(&self) -> bool {
167 self.segments
168 .iter()
169 .any(|s| s.segment_type == ContextSegmentType::System)
170 }
171
172 pub fn has_history(&self) -> bool {
174 self.segments
175 .iter()
176 .any(|s| s.segment_type == ContextSegmentType::History)
177 }
178}
179
180pub struct PromptCalibrator {
182 token_counter: TokenCounter,
183}
184
185impl PromptCalibrator {
186 pub fn new() -> Self {
188 Self {
189 token_counter: TokenCounter::default(),
190 }
191 }
192
193 pub fn calibrate(
195 &self,
196 window: &ContextWindow,
197 config: &CalibrationConfig,
198 ) -> CalibratedPrompt {
199 let execution_id = window.budget().execution_id.clone();
200 let available = config.available_tokens();
201
202 let mut segments = window.segments().to_vec();
204 segments.sort_by(|a, b| b.priority.cmp(&a.priority));
205
206 let mut selected: Vec<ContextSegment> = Vec::new();
208 let mut total_tokens = 0;
209 let mut excluded_count = 0;
210 let mut history_count = 0;
211 let mut rag_count = 0;
212
213 for segment in segments {
214 if segment.priority < config.min_priority {
216 excluded_count += 1;
217 continue;
218 }
219
220 match segment.segment_type {
222 ContextSegmentType::System if !config.include_system => {
223 excluded_count += 1;
224 continue;
225 }
226 ContextSegmentType::History if !config.include_history => {
227 excluded_count += 1;
228 continue;
229 }
230 ContextSegmentType::History if history_count >= config.max_history_messages => {
231 excluded_count += 1;
232 continue;
233 }
234 ContextSegmentType::WorkingMemory if !config.include_working_memory => {
235 excluded_count += 1;
236 continue;
237 }
238 ContextSegmentType::RagContext if !config.include_rag => {
239 excluded_count += 1;
240 continue;
241 }
242 ContextSegmentType::RagContext if rag_count >= config.max_rag_chunks => {
243 excluded_count += 1;
244 continue;
245 }
246 _ => {}
247 }
248
249 let segment_tokens = segment.token_count;
251 if total_tokens + segment_tokens > available {
252 excluded_count += 1;
253 continue;
254 }
255
256 total_tokens += segment_tokens;
258 if segment.segment_type == ContextSegmentType::History {
259 history_count += 1;
260 }
261 if segment.segment_type == ContextSegmentType::RagContext {
262 rag_count += 1;
263 }
264 selected.push(segment);
265 }
266
267 selected.sort_by(|a, b| {
269 if a.segment_type == ContextSegmentType::System
271 && b.segment_type != ContextSegmentType::System
272 {
273 return std::cmp::Ordering::Less;
274 }
275 if b.segment_type == ContextSegmentType::System
276 && a.segment_type != ContextSegmentType::System
277 {
278 return std::cmp::Ordering::Greater;
279 }
280 a.sequence.cmp(&b.sequence)
282 });
283
284 CalibratedPrompt {
285 execution_id,
286 segments: selected,
287 total_tokens,
288 response_tokens: config.max_tokens.saturating_sub(total_tokens),
289 excluded_count,
290 calibrated_at: Utc::now(),
291 config: config.clone(),
292 }
293 }
294
295 pub fn calibrate_segments(
297 &self,
298 execution_id: ExecutionId,
299 segments: Vec<ContextSegment>,
300 config: &CalibrationConfig,
301 ) -> CalibratedPrompt {
302 let budget = ContextBudget::new(
304 execution_id.clone(),
305 config.max_tokens,
306 config.response_reserve,
307 );
308 let mut window = ContextWindow::new(budget).expect("valid budget");
309
310 for segment in segments {
311 let _ = window.add_segment(segment);
312 }
313
314 self.calibrate(&window, config)
315 }
316
317 pub fn calibrate_for_child(
322 &self,
323 parent_window: &ContextWindow,
324 child_execution_id: ExecutionId,
325 task_description: &str,
326 config: &CalibrationConfig,
327 ) -> CalibratedPrompt {
328 let available = config.available_tokens();
329
330 let mut selected: Vec<ContextSegment> = Vec::new();
332 let mut total_tokens = 0;
333
334 let task_content = format!(
336 "You are executing a sub-task. Task: {}\n\nParent context follows:",
337 task_description
338 );
339 let task_tokens = self.token_counter.count(&task_content);
340 if task_tokens <= available {
341 let task_segment = ContextSegment::system(task_content, task_tokens);
342 total_tokens += task_tokens;
343 selected.push(task_segment);
344 }
345
346 let mut parent_segments = parent_window.segments().to_vec();
348 parent_segments.sort_by(|a, b| b.priority.cmp(&a.priority));
349
350 let mut excluded_count = 0;
351
352 for segment in parent_segments {
354 if segment.priority < ContextPriority::Medium {
356 excluded_count += 1;
357 continue;
358 }
359
360 let segment_tokens = segment.token_count;
362 if total_tokens + segment_tokens > available {
363 excluded_count += 1;
364 continue;
365 }
366
367 total_tokens += segment_tokens;
368 selected.push(segment);
369 }
370
371 selected.sort_by(|a, b| {
373 if a.segment_type == ContextSegmentType::System
374 && b.segment_type != ContextSegmentType::System
375 {
376 return std::cmp::Ordering::Less;
377 }
378 if b.segment_type == ContextSegmentType::System
379 && a.segment_type != ContextSegmentType::System
380 {
381 return std::cmp::Ordering::Greater;
382 }
383 a.sequence.cmp(&b.sequence)
384 });
385
386 CalibratedPrompt {
387 execution_id: child_execution_id,
388 segments: selected,
389 total_tokens,
390 response_tokens: config.max_tokens.saturating_sub(total_tokens),
391 excluded_count,
392 calibrated_at: Utc::now(),
393 config: config.clone(),
394 }
395 }
396}
397
398impl Default for PromptCalibrator {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh execution id for a test.
    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    /// Empty window backed by the default budget preset.
    fn empty_window() -> ContextWindow {
        let budget = ContextBudget::preset_default(test_execution_id());
        ContextWindow::new(budget).unwrap()
    }

    #[test]
    fn test_calibration_config_defaults() {
        let defaults = CalibrationConfig::default();

        assert_eq!(defaults.max_tokens, 8000);
        assert_eq!(defaults.response_reserve, 2000);
        // 8000 total minus the 2000 reserve.
        assert_eq!(defaults.available_tokens(), 6000);
    }

    #[test]
    fn test_calibration_config_minimal() {
        let minimal = CalibrationConfig::minimal();

        assert!(!minimal.include_history);
        assert!(!minimal.include_working_memory);
        assert_eq!(minimal.min_priority, ContextPriority::High);
    }

    #[test]
    fn test_calibrate_empty_window() {
        let window = empty_window();

        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());

        assert_eq!(prompt.segments.len(), 0);
        assert_eq!(prompt.total_tokens, 0);
    }

    #[test]
    fn test_calibrate_with_segments() {
        let mut window = empty_window();
        let system_prompt = ContextSegment::system("You are a helpful assistant.", 10);
        let greeting = ContextSegment::user_input("Hello!", 5, 1);
        window.add_segment(system_prompt).unwrap();
        window.add_segment(greeting).unwrap();

        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());

        assert_eq!(prompt.segments.len(), 2);
        assert!(prompt.total_tokens > 0);
        assert!(prompt.has_system());
    }

    #[test]
    fn test_calibrate_respects_priority() {
        let mut window = empty_window();
        window
            .add_segment(ContextSegment::system("System prompt", 10))
            .unwrap();

        let low_priority_history = ContextSegment::new(
            ContextSegmentType::History,
            "Low priority history".to_string(),
            20,
            1,
        )
        .with_priority(ContextPriority::Low);
        window.add_segment(low_priority_history).unwrap();

        // Raise the floor so the low-priority history entry must be dropped.
        let strict = CalibrationConfig {
            min_priority: ContextPriority::High,
            ..Default::default()
        };
        let prompt = PromptCalibrator::new().calibrate(&window, &strict);

        assert_eq!(prompt.segments.len(), 1);
        assert!(prompt.has_system());
        assert!(!prompt.has_history());
    }

    #[test]
    fn test_calibrate_for_child() {
        let mut parent_window = empty_window();
        parent_window
            .add_segment(ContextSegment::system("Parent system prompt", 15))
            .unwrap();
        parent_window
            .add_segment(ContextSegment::user_input("Parent user input", 10, 1))
            .unwrap();

        let child_id = ExecutionId::new();
        let prompt = PromptCalibrator::new().calibrate_for_child(
            &parent_window,
            child_id,
            "Analyze data",
            &CalibrationConfig::default(),
        );

        assert!(prompt.total_tokens > 0);
        // The generated sub-task preamble must be part of the child prompt.
        let has_task_preamble = prompt
            .segments
            .iter()
            .any(|segment| segment.content.contains("sub-task"));
        assert!(has_task_preamble);
    }

    #[test]
    fn test_calibrated_prompt_as_text() {
        let mut window = empty_window();
        window
            .add_segment(ContextSegment::system("System", 5))
            .unwrap();
        window
            .add_segment(ContextSegment::user_input("User", 5, 1))
            .unwrap();

        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());
        let rendered = prompt.as_text();

        assert!(rendered.contains("System"));
        assert!(rendered.contains("User"));
    }
}