Skip to main content

enact_context/
calibrator.rs

//! Prompt Calibrator
//!
//! Constructs calibrated prompts for spawned callables by:
//! - Selecting relevant context segments based on priority
//! - Applying token budgets for the target model
//! - Formatting context for optimal model consumption
//!
//! @see packages/enact-schemas/src/context.schemas.ts

10use crate::budget::ContextBudget;
11use crate::segment::{ContextPriority, ContextSegment, ContextSegmentType};
12use crate::token_counter::TokenCounter;
13use crate::window::ContextWindow;
14use chrono::{DateTime, Utc};
15use enact_core::kernel::ExecutionId;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19
/// Process-wide counter used to hand out segment sequence numbers.
/// The first number handed out is 1000.
// Both this static and `next_sequence` are private and not referenced anywhere
// else in this file, which is why the `dead_code` lint is silenced on each.
#[allow(dead_code)]
static SEGMENT_SEQUENCE: AtomicU64 = AtomicU64::new(1000);

/// Returns the counter's current value and atomically advances it by one.
#[allow(dead_code)]
fn next_sequence() -> u64 {
    const STEP: u64 = 1;
    SEGMENT_SEQUENCE.fetch_add(STEP, Ordering::SeqCst)
}
28
29/// Prompt calibration configuration
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct CalibrationConfig {
33    /// Maximum tokens for the calibrated prompt
34    pub max_tokens: usize,
35
36    /// Minimum tokens to reserve for response
37    pub response_reserve: usize,
38
39    /// Priority threshold - only include segments at or above this priority
40    pub min_priority: ContextPriority,
41
42    /// Whether to include system context
43    pub include_system: bool,
44
45    /// Whether to include conversation history
46    pub include_history: bool,
47
48    /// Maximum history messages to include
49    pub max_history_messages: usize,
50
51    /// Whether to include working memory
52    pub include_working_memory: bool,
53
54    /// Whether to include RAG context
55    pub include_rag: bool,
56
57    /// Maximum RAG chunks to include
58    pub max_rag_chunks: usize,
59
60    /// Custom segment filters by type
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub segment_filters: Option<HashMap<String, bool>>,
63}
64
65impl Default for CalibrationConfig {
66    fn default() -> Self {
67        Self {
68            max_tokens: 8000,
69            response_reserve: 2000,
70            min_priority: ContextPriority::Low,
71            include_system: true,
72            include_history: true,
73            max_history_messages: 20,
74            include_working_memory: true,
75            include_rag: true,
76            max_rag_chunks: 5,
77            segment_filters: None,
78        }
79    }
80}
81
82impl CalibrationConfig {
83    /// Create a minimal config for quick tasks
84    pub fn minimal() -> Self {
85        Self {
86            max_tokens: 4000,
87            response_reserve: 1000,
88            min_priority: ContextPriority::High,
89            include_system: true,
90            include_history: false,
91            max_history_messages: 0,
92            include_working_memory: false,
93            include_rag: false,
94            max_rag_chunks: 0,
95            segment_filters: None,
96        }
97    }
98
99    /// Create a full context config for complex tasks
100    pub fn full_context() -> Self {
101        Self {
102            max_tokens: 32000,
103            response_reserve: 4000,
104            min_priority: ContextPriority::Low,
105            include_system: true,
106            include_history: true,
107            max_history_messages: 50,
108            include_working_memory: true,
109            include_rag: true,
110            max_rag_chunks: 10,
111            segment_filters: None,
112        }
113    }
114
115    /// Available tokens after response reserve
116    pub fn available_tokens(&self) -> usize {
117        self.max_tokens.saturating_sub(self.response_reserve)
118    }
119}
120
121/// Result of prompt calibration
122#[derive(Debug, Clone, Serialize, Deserialize)]
123#[serde(rename_all = "camelCase")]
124pub struct CalibratedPrompt {
125    /// Execution ID this prompt was calibrated for
126    pub execution_id: ExecutionId,
127
128    /// The calibrated segments in order
129    pub segments: Vec<ContextSegment>,
130
131    /// Total tokens used
132    pub total_tokens: usize,
133
134    /// Tokens available for response
135    pub response_tokens: usize,
136
137    /// Segments that were excluded due to budget
138    pub excluded_count: usize,
139
140    /// Calibration timestamp
141    pub calibrated_at: DateTime<Utc>,
142
143    /// Configuration used
144    pub config: CalibrationConfig,
145}
146
147impl CalibratedPrompt {
148    /// Get the formatted prompt as a string
149    pub fn as_text(&self) -> String {
150        self.segments
151            .iter()
152            .map(|s| s.content.clone())
153            .collect::<Vec<_>>()
154            .join("\n\n")
155    }
156
157    /// Get segments by type
158    pub fn segments_by_type(&self, segment_type: ContextSegmentType) -> Vec<&ContextSegment> {
159        self.segments
160            .iter()
161            .filter(|s| s.segment_type == segment_type)
162            .collect()
163    }
164
165    /// Check if prompt has system context
166    pub fn has_system(&self) -> bool {
167        self.segments
168            .iter()
169            .any(|s| s.segment_type == ContextSegmentType::System)
170    }
171
172    /// Check if prompt has history
173    pub fn has_history(&self) -> bool {
174        self.segments
175            .iter()
176            .any(|s| s.segment_type == ContextSegmentType::History)
177    }
178}
179
180/// Prompt Calibrator - constructs calibrated prompts for spawned callables
181pub struct PromptCalibrator {
182    token_counter: TokenCounter,
183}
184
185impl PromptCalibrator {
186    /// Create a new calibrator
187    pub fn new() -> Self {
188        Self {
189            token_counter: TokenCounter::default(),
190        }
191    }
192
193    /// Calibrate a prompt from a context window
194    pub fn calibrate(
195        &self,
196        window: &ContextWindow,
197        config: &CalibrationConfig,
198    ) -> CalibratedPrompt {
199        let execution_id = window.budget().execution_id.clone();
200        let available = config.available_tokens();
201
202        // Get all segments sorted by priority (highest first)
203        let mut segments = window.segments().to_vec();
204        segments.sort_by(|a, b| b.priority.cmp(&a.priority));
205
206        // Filter and select segments within budget
207        let mut selected: Vec<ContextSegment> = Vec::new();
208        let mut total_tokens = 0;
209        let mut excluded_count = 0;
210        let mut history_count = 0;
211        let mut rag_count = 0;
212
213        for segment in segments {
214            // Check priority threshold
215            if segment.priority < config.min_priority {
216                excluded_count += 1;
217                continue;
218            }
219
220            // Apply type-specific filters
221            match segment.segment_type {
222                ContextSegmentType::System if !config.include_system => {
223                    excluded_count += 1;
224                    continue;
225                }
226                ContextSegmentType::History if !config.include_history => {
227                    excluded_count += 1;
228                    continue;
229                }
230                ContextSegmentType::History if history_count >= config.max_history_messages => {
231                    excluded_count += 1;
232                    continue;
233                }
234                ContextSegmentType::WorkingMemory if !config.include_working_memory => {
235                    excluded_count += 1;
236                    continue;
237                }
238                ContextSegmentType::RagContext if !config.include_rag => {
239                    excluded_count += 1;
240                    continue;
241                }
242                ContextSegmentType::RagContext if rag_count >= config.max_rag_chunks => {
243                    excluded_count += 1;
244                    continue;
245                }
246                _ => {}
247            }
248
249            // Check if segment fits in budget
250            let segment_tokens = segment.token_count;
251            if total_tokens + segment_tokens > available {
252                excluded_count += 1;
253                continue;
254            }
255
256            // Include segment
257            total_tokens += segment_tokens;
258            if segment.segment_type == ContextSegmentType::History {
259                history_count += 1;
260            }
261            if segment.segment_type == ContextSegmentType::RagContext {
262                rag_count += 1;
263            }
264            selected.push(segment);
265        }
266
267        // Re-sort by natural order (system first, then by sequence)
268        selected.sort_by(|a, b| {
269            // System always first
270            if a.segment_type == ContextSegmentType::System
271                && b.segment_type != ContextSegmentType::System
272            {
273                return std::cmp::Ordering::Less;
274            }
275            if b.segment_type == ContextSegmentType::System
276                && a.segment_type != ContextSegmentType::System
277            {
278                return std::cmp::Ordering::Greater;
279            }
280            // Then by sequence
281            a.sequence.cmp(&b.sequence)
282        });
283
284        CalibratedPrompt {
285            execution_id,
286            segments: selected,
287            total_tokens,
288            response_tokens: config.max_tokens.saturating_sub(total_tokens),
289            excluded_count,
290            calibrated_at: Utc::now(),
291            config: config.clone(),
292        }
293    }
294
295    /// Calibrate from raw segments (without a full window)
296    pub fn calibrate_segments(
297        &self,
298        execution_id: ExecutionId,
299        segments: Vec<ContextSegment>,
300        config: &CalibrationConfig,
301    ) -> CalibratedPrompt {
302        // Create a temporary budget and window
303        let budget = ContextBudget::new(
304            execution_id.clone(),
305            config.max_tokens,
306            config.response_reserve,
307        );
308        let mut window = ContextWindow::new(budget).expect("valid budget");
309
310        for segment in segments {
311            let _ = window.add_segment(segment);
312        }
313
314        self.calibrate(&window, config)
315    }
316
317    /// Create a child prompt from parent context
318    ///
319    /// This is used when spawning child callables to provide them with
320    /// relevant context from the parent execution.
321    pub fn calibrate_for_child(
322        &self,
323        parent_window: &ContextWindow,
324        child_execution_id: ExecutionId,
325        task_description: &str,
326        config: &CalibrationConfig,
327    ) -> CalibratedPrompt {
328        let available = config.available_tokens();
329
330        // Start with essential context for the child
331        let mut selected: Vec<ContextSegment> = Vec::new();
332        let mut total_tokens = 0;
333
334        // Add task description as system context
335        let task_content = format!(
336            "You are executing a sub-task. Task: {}\n\nParent context follows:",
337            task_description
338        );
339        let task_tokens = self.token_counter.count(&task_content);
340        if task_tokens <= available {
341            let task_segment = ContextSegment::system(task_content, task_tokens);
342            total_tokens += task_tokens;
343            selected.push(task_segment);
344        }
345
346        // Get parent segments sorted by priority
347        let mut parent_segments = parent_window.segments().to_vec();
348        parent_segments.sort_by(|a, b| b.priority.cmp(&a.priority));
349
350        let mut excluded_count = 0;
351
352        // Add relevant parent context
353        for segment in parent_segments {
354            // Skip low priority for child contexts
355            if segment.priority < ContextPriority::Medium {
356                excluded_count += 1;
357                continue;
358            }
359
360            // Check budget
361            let segment_tokens = segment.token_count;
362            if total_tokens + segment_tokens > available {
363                excluded_count += 1;
364                continue;
365            }
366
367            total_tokens += segment_tokens;
368            selected.push(segment);
369        }
370
371        // Sort for natural reading order
372        selected.sort_by(|a, b| {
373            if a.segment_type == ContextSegmentType::System
374                && b.segment_type != ContextSegmentType::System
375            {
376                return std::cmp::Ordering::Less;
377            }
378            if b.segment_type == ContextSegmentType::System
379                && a.segment_type != ContextSegmentType::System
380            {
381                return std::cmp::Ordering::Greater;
382            }
383            a.sequence.cmp(&b.sequence)
384        });
385
386        CalibratedPrompt {
387            execution_id: child_execution_id,
388            segments: selected,
389            total_tokens,
390            response_tokens: config.max_tokens.saturating_sub(total_tokens),
391            excluded_count,
392            calibrated_at: Utc::now(),
393            config: config.clone(),
394        }
395    }
396}
397
398impl Default for PromptCalibrator {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a new execution id for a test fixture.
    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    /// Builds a window on the default budget preset and loads `segments` into
    /// it in order, panicking if any load fails.
    fn window_with(segments: Vec<ContextSegment>) -> ContextWindow {
        let budget = ContextBudget::preset_default(test_execution_id());
        let mut window = ContextWindow::new(budget).unwrap();
        for segment in segments {
            window.add_segment(segment).unwrap();
        }
        window
    }

    #[test]
    fn test_calibration_config_defaults() {
        let defaults = CalibrationConfig::default();
        assert_eq!(defaults.max_tokens, 8000);
        assert_eq!(defaults.response_reserve, 2000);
        assert_eq!(defaults.available_tokens(), 6000);
    }

    #[test]
    fn test_calibration_config_minimal() {
        let minimal = CalibrationConfig::minimal();
        assert!(!minimal.include_history);
        assert!(!minimal.include_working_memory);
        assert_eq!(minimal.min_priority, ContextPriority::High);
    }

    #[test]
    fn test_calibrate_empty_window() {
        let window = window_with(Vec::new());
        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());

        assert!(prompt.segments.is_empty());
        assert_eq!(prompt.total_tokens, 0);
    }

    #[test]
    fn test_calibrate_with_segments() {
        let window = window_with(vec![
            ContextSegment::system("You are a helpful assistant.", 10),
            ContextSegment::user_input("Hello!", 5, 1),
        ]);
        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());

        assert_eq!(prompt.segments.len(), 2);
        assert!(prompt.total_tokens > 0);
        assert!(prompt.has_system());
    }

    #[test]
    fn test_calibrate_respects_priority() {
        let window = window_with(vec![
            ContextSegment::system("System prompt", 10),
            ContextSegment::new(
                ContextSegmentType::History,
                "Low priority history".to_string(),
                20,
                1,
            )
            .with_priority(ContextPriority::Low),
        ]);
        let strict = CalibrationConfig {
            min_priority: ContextPriority::High,
            ..Default::default()
        };
        let prompt = PromptCalibrator::new().calibrate(&window, &strict);

        // Only the system segment clears the High threshold; the Low-priority
        // history segment must be filtered out.
        assert_eq!(prompt.segments.len(), 1);
        assert!(prompt.has_system());
        assert!(!prompt.has_history());
    }

    #[test]
    fn test_calibrate_for_child() {
        let parent = window_with(vec![
            ContextSegment::system("Parent system prompt", 15),
            ContextSegment::user_input("Parent user input", 10, 1),
        ]);
        let prompt = PromptCalibrator::new().calibrate_for_child(
            &parent,
            ExecutionId::new(),
            "Analyze data",
            &CalibrationConfig::default(),
        );

        // The synthesized task preamble must be present.
        assert!(prompt.total_tokens > 0);
        assert!(prompt
            .segments
            .iter()
            .any(|segment| segment.content.contains("sub-task")));
    }

    #[test]
    fn test_calibrated_prompt_as_text() {
        let window = window_with(vec![
            ContextSegment::system("System", 5),
            ContextSegment::user_input("User", 5, 1),
        ]);
        let prompt = PromptCalibrator::new().calibrate(&window, &CalibrationConfig::default());

        let rendered = prompt.as_text();
        assert!(rendered.contains("System"));
        assert!(rendered.contains("User"));
    }
}