use crate::budget::ContextBudget;
use crate::segment::{ContextPriority, ContextSegment, ContextSegmentType};
use crate::token_counter::TokenCounter;
use crate::window::ContextWindow;
use chrono::{DateTime, Utc};
use enact_core::kernel::ExecutionId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
/// First value handed out by [`next_sequence`].
#[allow(dead_code)]
const FIRST_SEGMENT_SEQUENCE: u64 = 1000;

/// Process-wide counter backing [`next_sequence`].
// `dead_code` is allowed because nothing in this file currently calls
// `next_sequence`. NOTE(review): confirm whether the counter is still needed.
#[allow(dead_code)]
static SEGMENT_SEQUENCE: AtomicU64 = AtomicU64::new(FIRST_SEGMENT_SEQUENCE);

/// Returns the next value of the process-wide segment sequence counter.
///
/// Values start at 1000 and grow by one per call. The atomic `fetch_add`
/// keeps concurrent callers from receiving the same value.
#[allow(dead_code)]
fn next_sequence() -> u64 {
    SEGMENT_SEQUENCE.fetch_add(1, Ordering::SeqCst)
}
/// Settings that control how [`PromptCalibrator`] selects context segments.
///
/// Fields serialize with camelCase names, and `segment_filters` is omitted
/// when it is `None`. `PromptCalibrator::calibrate` applies every limit below.
/// `PromptCalibrator::calibrate_for_child` only uses the token limits
/// (`max_tokens` / `response_reserve`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CalibrationConfig {
/// Total token budget; `available_tokens()` subtracts `response_reserve` from it.
pub max_tokens: usize,
/// Tokens held back from segment selection for the model's response.
pub response_reserve: usize,
/// Segments with a priority below this are excluded by `calibrate`.
pub min_priority: ContextPriority,
/// Whether `System` segments may be selected.
pub include_system: bool,
/// Whether `History` segments may be selected.
pub include_history: bool,
/// Maximum number of `History` segments `calibrate` will select.
pub max_history_messages: usize,
/// Whether `WorkingMemory` segments may be selected.
pub include_working_memory: bool,
/// Whether `RagContext` segments may be selected.
pub include_rag: bool,
/// Maximum number of `RagContext` segments `calibrate` will select.
pub max_rag_chunks: usize,
/// Optional named on/off filters.
// NOTE(review): nothing in this file reads `segment_filters`, and the key
// format (segment type names?) is not visible here — confirm with consumers.
#[serde(skip_serializing_if = "Option::is_none")]
pub segment_filters: Option<HashMap<String, bool>>,
}
impl Default for CalibrationConfig {
    /// Balanced preset with an 8000-token budget.
    ///
    /// Of those tokens, 2000 are reserved for the response. Every segment type
    /// is enabled, the priority floor is `Low`, and at most 20 history
    /// messages and 5 RAG chunks are kept.
    fn default() -> Self {
        let max_tokens = 8_000;
        let response_reserve = 2_000;
        Self {
            segment_filters: None,
            min_priority: ContextPriority::Low,
            max_tokens,
            response_reserve,
            include_system: true,
            include_history: true,
            include_working_memory: true,
            include_rag: true,
            max_history_messages: 20,
            max_rag_chunks: 5,
        }
    }
}
impl CalibrationConfig {
    /// Lean preset with a 4000-token budget.
    ///
    /// Of those tokens, 1000 are reserved for the response. Only segments of
    /// `High` priority or above are eligible. History, working memory and RAG
    /// content are switched off.
    pub fn minimal() -> Self {
        Self {
            max_tokens: 4_000,
            response_reserve: 1_000,
            min_priority: ContextPriority::High,
            segment_filters: None,
            include_system: true,
            include_history: false,
            include_working_memory: false,
            include_rag: false,
            max_history_messages: 0,
            max_rag_chunks: 0,
        }
    }

    /// Generous preset with a 32000-token budget.
    ///
    /// Of those tokens, 4000 are reserved for the response. Every segment type
    /// is enabled, with room for up to 50 history messages and 10 RAG chunks.
    pub fn full_context() -> Self {
        Self {
            max_tokens: 32_000,
            response_reserve: 4_000,
            min_priority: ContextPriority::Low,
            segment_filters: None,
            include_system: true,
            include_history: true,
            include_working_memory: true,
            include_rag: true,
            max_history_messages: 50,
            max_rag_chunks: 10,
        }
    }

    /// Tokens left for context segments after setting `response_reserve`
    /// aside from `max_tokens`.
    ///
    /// Returns zero when the reserve meets or exceeds the total.
    pub fn available_tokens(&self) -> usize {
        let reserve = self.response_reserve;
        self.max_tokens.saturating_sub(reserve)
    }
}
/// Result of a [`PromptCalibrator`] run.
///
/// Holds the selected segments in prompt order plus bookkeeping about the
/// selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CalibratedPrompt {
/// Execution the prompt was built for.
pub execution_id: ExecutionId,
/// Selected segments: system segments first, then the rest by `sequence`.
pub segments: Vec<ContextSegment>,
/// Sum of `token_count` over `segments`.
pub total_tokens: usize,
/// `config.max_tokens` minus `total_tokens`, saturating at zero.
pub response_tokens: usize,
/// Number of segments the calibrator considered but did not select.
pub excluded_count: usize,
/// When the calibration ran (UTC).
pub calibrated_at: DateTime<Utc>,
/// Copy of the configuration used for this calibration.
pub config: CalibrationConfig,
}
impl CalibratedPrompt {
    /// Concatenates the content of every selected segment, in order.
    ///
    /// Segments are separated by a blank line (`"\n\n"`).
    pub fn as_text(&self) -> String {
        // Borrow each segment's content instead of cloning it; only the joined
        // output string is allocated.
        self.segments
            .iter()
            .map(|s| s.content.as_str())
            .collect::<Vec<_>>()
            .join("\n\n")
    }

    /// Returns the selected segments of type `segment_type`, in prompt order.
    pub fn segments_by_type(&self, segment_type: ContextSegmentType) -> Vec<&ContextSegment> {
        self.segments
            .iter()
            .filter(|s| s.segment_type == segment_type)
            .collect()
    }

    /// Whether at least one `System` segment was selected.
    pub fn has_system(&self) -> bool {
        self.contains_type(ContextSegmentType::System)
    }

    /// Whether at least one `History` segment was selected.
    pub fn has_history(&self) -> bool {
        self.contains_type(ContextSegmentType::History)
    }

    /// Whether any selected segment has type `segment_type`.
    fn contains_type(&self, segment_type: ContextSegmentType) -> bool {
        self.segments.iter().any(|s| s.segment_type == segment_type)
    }
}
/// Selects and orders context segments so that a prompt fits the token
/// budget described by a [`CalibrationConfig`].
pub struct PromptCalibrator {
// Counts tokens for text the calibrator writes itself (the sub-task preamble
// in `calibrate_for_child`). Existing segments are costed by their own
// `token_count` instead.
token_counter: TokenCounter,
}
impl PromptCalibrator {
pub fn new() -> Self {
Self {
token_counter: TokenCounter::default(),
}
}
pub fn calibrate(
&self,
window: &ContextWindow,
config: &CalibrationConfig,
) -> CalibratedPrompt {
let execution_id = window.budget().execution_id.clone();
let available = config.available_tokens();
let mut segments = window.segments().to_vec();
segments.sort_by(|a, b| b.priority.cmp(&a.priority));
let mut selected: Vec<ContextSegment> = Vec::new();
let mut total_tokens = 0;
let mut excluded_count = 0;
let mut history_count = 0;
let mut rag_count = 0;
for segment in segments {
if segment.priority < config.min_priority {
excluded_count += 1;
continue;
}
match segment.segment_type {
ContextSegmentType::System if !config.include_system => {
excluded_count += 1;
continue;
}
ContextSegmentType::History if !config.include_history => {
excluded_count += 1;
continue;
}
ContextSegmentType::History if history_count >= config.max_history_messages => {
excluded_count += 1;
continue;
}
ContextSegmentType::WorkingMemory if !config.include_working_memory => {
excluded_count += 1;
continue;
}
ContextSegmentType::RagContext if !config.include_rag => {
excluded_count += 1;
continue;
}
ContextSegmentType::RagContext if rag_count >= config.max_rag_chunks => {
excluded_count += 1;
continue;
}
_ => {}
}
let segment_tokens = segment.token_count;
if total_tokens + segment_tokens > available {
excluded_count += 1;
continue;
}
total_tokens += segment_tokens;
if segment.segment_type == ContextSegmentType::History {
history_count += 1;
}
if segment.segment_type == ContextSegmentType::RagContext {
rag_count += 1;
}
selected.push(segment);
}
selected.sort_by(|a, b| {
if a.segment_type == ContextSegmentType::System
&& b.segment_type != ContextSegmentType::System
{
return std::cmp::Ordering::Less;
}
if b.segment_type == ContextSegmentType::System
&& a.segment_type != ContextSegmentType::System
{
return std::cmp::Ordering::Greater;
}
a.sequence.cmp(&b.sequence)
});
CalibratedPrompt {
execution_id,
segments: selected,
total_tokens,
response_tokens: config.max_tokens.saturating_sub(total_tokens),
excluded_count,
calibrated_at: Utc::now(),
config: config.clone(),
}
}
pub fn calibrate_segments(
&self,
execution_id: ExecutionId,
segments: Vec<ContextSegment>,
config: &CalibrationConfig,
) -> CalibratedPrompt {
let budget = ContextBudget::new(
execution_id.clone(),
config.max_tokens,
config.response_reserve,
);
let mut window = ContextWindow::new(budget).expect("valid budget");
for segment in segments {
let _ = window.add_segment(segment);
}
self.calibrate(&window, config)
}
pub fn calibrate_for_child(
&self,
parent_window: &ContextWindow,
child_execution_id: ExecutionId,
task_description: &str,
config: &CalibrationConfig,
) -> CalibratedPrompt {
let available = config.available_tokens();
let mut selected: Vec<ContextSegment> = Vec::new();
let mut total_tokens = 0;
let task_content = format!(
"You are executing a sub-task. Task: {}\n\nParent context follows:",
task_description
);
let task_tokens = self.token_counter.count(&task_content);
if task_tokens <= available {
let task_segment = ContextSegment::system(task_content, task_tokens);
total_tokens += task_tokens;
selected.push(task_segment);
}
let mut parent_segments = parent_window.segments().to_vec();
parent_segments.sort_by(|a, b| b.priority.cmp(&a.priority));
let mut excluded_count = 0;
for segment in parent_segments {
if segment.priority < ContextPriority::Medium {
excluded_count += 1;
continue;
}
let segment_tokens = segment.token_count;
if total_tokens + segment_tokens > available {
excluded_count += 1;
continue;
}
total_tokens += segment_tokens;
selected.push(segment);
}
selected.sort_by(|a, b| {
if a.segment_type == ContextSegmentType::System
&& b.segment_type != ContextSegmentType::System
{
return std::cmp::Ordering::Less;
}
if b.segment_type == ContextSegmentType::System
&& a.segment_type != ContextSegmentType::System
{
return std::cmp::Ordering::Greater;
}
a.sequence.cmp(&b.sequence)
});
CalibratedPrompt {
execution_id: child_execution_id,
segments: selected,
total_tokens,
response_tokens: config.max_tokens.saturating_sub(total_tokens),
excluded_count,
calibrated_at: Utc::now(),
config: config.clone(),
}
}
}
impl Default for PromptCalibrator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    /// Empty window using the default budget preset for a new execution.
    fn fresh_window() -> ContextWindow {
        ContextWindow::new(ContextBudget::preset_default(test_execution_id())).unwrap()
    }

    /// Window pre-populated with `segments` in order.
    ///
    /// Panics if any segment is rejected.
    fn window_with(segments: Vec<ContextSegment>) -> ContextWindow {
        let mut window = fresh_window();
        for segment in segments {
            window.add_segment(segment).unwrap();
        }
        window
    }

    #[test]
    fn test_calibration_config_defaults() {
        let defaults = CalibrationConfig::default();
        assert_eq!(defaults.max_tokens, 8000);
        assert_eq!(defaults.response_reserve, 2000);
        assert_eq!(defaults.available_tokens(), 6000);
    }

    #[test]
    fn test_calibration_config_minimal() {
        let lean = CalibrationConfig::minimal();
        assert!(!lean.include_history);
        assert!(!lean.include_working_memory);
        assert_eq!(lean.min_priority, ContextPriority::High);
    }

    #[test]
    fn test_calibrate_empty_window() {
        let outcome = PromptCalibrator::new().calibrate(&fresh_window(), &CalibrationConfig::default());
        assert_eq!(outcome.segments.len(), 0);
        assert_eq!(outcome.total_tokens, 0);
    }

    #[test]
    fn test_calibrate_with_segments() {
        let window = window_with(vec![
            ContextSegment::system("You are a helpful assistant.", 10),
            ContextSegment::user_input("Hello!", 5, 1),
        ]);
        let outcome = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());
        assert_eq!(outcome.segments.len(), 2);
        assert!(outcome.total_tokens > 0);
        assert!(outcome.has_system());
    }

    #[test]
    fn test_calibrate_respects_priority() {
        let low_history = ContextSegment::new(
            ContextSegmentType::History,
            "Low priority history".to_string(),
            20,
            1,
        )
        .with_priority(ContextPriority::Low);
        let window = window_with(vec![ContextSegment::system("System prompt", 10), low_history]);
        let high_only = CalibrationConfig {
            min_priority: ContextPriority::High,
            ..Default::default()
        };
        let outcome = PromptCalibrator::new().calibrate(&window, &high_only);
        assert_eq!(outcome.segments.len(), 1);
        assert!(outcome.has_system());
        assert!(!outcome.has_history());
    }

    #[test]
    fn test_calibrate_for_child() {
        let parent = window_with(vec![
            ContextSegment::system("Parent system prompt", 15),
            ContextSegment::user_input("Parent user input", 10, 1),
        ]);
        let child_id = ExecutionId::new();
        let outcome = PromptCalibrator::new().calibrate_for_child(
            &parent,
            child_id,
            "Analyze data",
            &CalibrationConfig::default(),
        );
        assert!(outcome.total_tokens > 0);
        assert!(outcome
            .segments
            .iter()
            .any(|s| s.content.contains("sub-task")));
    }

    #[test]
    fn test_calibrated_prompt_as_text() {
        let window = window_with(vec![
            ContextSegment::system("System", 5),
            ContextSegment::user_input("User", 5, 1),
        ]);
        let rendered = PromptCalibrator::new()
            .calibrate(&window, &CalibrationConfig::default())
            .as_text();
        assert!(rendered.contains("System"));
        assert!(rendered.contains("User"));
    }
}