use chrono::{DateTime, Utc};
use enact_core::kernel::StepId;
use serde::{Deserialize, Serialize};
/// Category tag for a [`ContextSegment`].
///
/// The category selects the segment's default priority and whether it is
/// compressible by default (see `default_priority` and `is_compressible`).
/// Variants serialize under snake_case names, e.g. `working_memory`.
///
/// NOTE(review): the intended meaning of each variant is inferred from its
/// name only; the per-variant notes below state just the defaults that this
/// file assigns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
/// Serialized as `system`. Default priority `Critical`; not compressible.
System,
/// Serialized as `history`. Default priority `Medium`; compressible.
History,
/// Serialized as `working_memory`. Default priority `Medium`; compressible.
WorkingMemory,
/// Serialized as `tool_results`. Default priority `Medium`; compressible.
ToolResults,
/// Serialized as `rag_context`. Default priority `Low`; compressible.
RagContext,
/// Serialized as `user_input`. Default priority `Critical`; not compressible.
UserInput,
/// Serialized as `agent_scratchpad`. Default priority `Low`; compressible.
AgentScratchpad,
/// Serialized as `child_summary`. Default priority `High`; not compressible.
ChildSummary,
/// Serialized as `guidance`. Default priority `High`; not compressible.
Guidance,
}
impl ContextSegmentType {
    /// Reports whether segments of this type are compressible by default.
    ///
    /// `System`, `UserInput`, `ChildSummary` and `Guidance` return `false`;
    /// every other variant returns `true`. The match is kept exhaustive (no
    /// wildcard arm) so adding a variant forces a decision here.
    pub fn is_compressible(&self) -> bool {
        match *self {
            Self::System | Self::UserInput | Self::ChildSummary | Self::Guidance => false,
            Self::History
            | Self::WorkingMemory
            | Self::ToolResults
            | Self::RagContext
            | Self::AgentScratchpad => true,
        }
    }

    /// Returns the priority a segment of this type starts with.
    ///
    /// Variants are grouped by tier, highest first. The match is exhaustive
    /// on purpose so a new variant cannot silently inherit a default.
    pub fn default_priority(&self) -> ContextPriority {
        match *self {
            Self::System | Self::UserInput => ContextPriority::Critical,
            Self::Guidance | Self::ChildSummary => ContextPriority::High,
            Self::History | Self::WorkingMemory | Self::ToolResults => ContextPriority::Medium,
            Self::RagContext | Self::AgentScratchpad => ContextPriority::Low,
        }
    }
}
/// Relative importance of a context segment.
///
/// Variants are declared in ascending order with explicit discriminants, so
/// the derived `Ord` gives `Low < Medium < High < Critical`. Serialized as
/// snake_case names (`low`, `medium`, `high`, `critical`).
///
/// NOTE(review): presumably consulted when deciding which segments to keep or
/// drop under a token budget; that logic is not visible here — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextPriority {
/// Lowest tier; compares less than every other variant.
Low = 0,
/// Middle tier, above `Low`.
Medium = 1,
/// Above `Medium`, below `Critical`.
High = 2,
/// Highest tier; compares greater than every other variant.
Critical = 3,
}
/// One piece of content placed into a context, together with the metadata
/// the constructors in this file attach to it.
///
/// Field names serialize in camelCase (`tokenCount`, `sourceStepId`,
/// `addedAt`), except `segment_type`, which is explicitly renamed to `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextSegment {
/// Identifier of the segment; `ContextSegment::new` generates it as
/// `seg_` followed by a v4 UUID.
pub id: String,
/// Category of the segment. Serialized under the key `type`.
#[serde(rename = "type")]
pub segment_type: ContextSegmentType,
/// The segment's text.
pub content: String,
/// Token count for `content` as supplied by the caller; the constructors
/// here store it as given and do not compute or check it.
pub token_count: usize,
/// Priority of the segment. Starts as the type's default and can be
/// replaced with `with_priority`.
pub priority: ContextPriority,
/// Whether the segment is marked compressible. Starts as the type's
/// default and can be cleared with `non_compressible`.
pub compressible: bool,
/// Step associated with this segment. The `tool_results` and
/// `child_summary` constructors set it; the others leave it `None`.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub source_step_id: Option<StepId>,
/// UTC timestamp captured when `ContextSegment::new` built the segment.
pub added_at: DateTime<Utc>,
/// Caller-supplied sequence number; `system` always passes 0.
/// NOTE(review): presumably orders segments within a context — confirm.
pub sequence: u64,
}
impl ContextSegment {
    /// Builds a segment of `segment_type` holding `content`.
    ///
    /// A fresh `seg_<uuid-v4>` id is generated and `added_at` is set to the
    /// current UTC time. `priority` and `compressible` are seeded from the
    /// type's defaults, and `source_step_id` starts as `None`. `token_count`
    /// and `sequence` are stored exactly as given.
    pub fn new(
        segment_type: ContextSegmentType,
        content: String,
        token_count: usize,
        sequence: u64,
    ) -> Self {
        let id = format!("seg_{}", uuid::Uuid::new_v4());
        let priority = segment_type.default_priority();
        let compressible = segment_type.is_compressible();

        Self {
            id,
            segment_type,
            content,
            token_count,
            priority,
            compressible,
            source_step_id: None,
            added_at: Utc::now(),
            sequence,
        }
    }

    /// Creates a `System` segment; its sequence number is always 0.
    pub fn system(content: impl Into<String>, token_count: usize) -> Self {
        let text: String = content.into();
        Self::new(ContextSegmentType::System, text, token_count, 0)
    }

    /// Creates a `UserInput` segment.
    pub fn user_input(content: impl Into<String>, token_count: usize, sequence: u64) -> Self {
        let text: String = content.into();
        Self::new(ContextSegmentType::UserInput, text, token_count, sequence)
    }

    /// Creates a `History` segment.
    pub fn history(content: impl Into<String>, token_count: usize, sequence: u64) -> Self {
        let text: String = content.into();
        Self::new(ContextSegmentType::History, text, token_count, sequence)
    }

    /// Creates a `ToolResults` segment and records `step_id` as its source step.
    pub fn tool_results(
        content: impl Into<String>,
        token_count: usize,
        sequence: u64,
        step_id: StepId,
    ) -> Self {
        let base = Self::new(
            ContextSegmentType::ToolResults,
            content.into(),
            token_count,
            sequence,
        );
        Self {
            source_step_id: Some(step_id),
            ..base
        }
    }

    /// Creates a `RagContext` segment.
    pub fn rag_context(content: impl Into<String>, token_count: usize, sequence: u64) -> Self {
        let text: String = content.into();
        Self::new(ContextSegmentType::RagContext, text, token_count, sequence)
    }

    /// Creates a `ChildSummary` segment and records `step_id` as its source step.
    pub fn child_summary(
        content: impl Into<String>,
        token_count: usize,
        sequence: u64,
        step_id: StepId,
    ) -> Self {
        let base = Self::new(
            ContextSegmentType::ChildSummary,
            content.into(),
            token_count,
            sequence,
        );
        Self {
            source_step_id: Some(step_id),
            ..base
        }
    }

    /// Creates a `Guidance` segment.
    pub fn guidance(content: impl Into<String>, token_count: usize, sequence: u64) -> Self {
        let text: String = content.into();
        Self::new(ContextSegmentType::Guidance, text, token_count, sequence)
    }

    /// Consumes the segment and returns it with `priority` replaced.
    pub fn with_priority(self, priority: ContextPriority) -> Self {
        Self { priority, ..self }
    }

    /// Consumes the segment and returns it with `compressible` set to `false`.
    pub fn non_compressible(self) -> Self {
        Self {
            compressible: false,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the default compressibility of a few segment types.
    #[test]
    fn test_segment_type_compressibility() {
        let cases = [
            (ContextSegmentType::System, false),
            (ContextSegmentType::History, true),
            (ContextSegmentType::UserInput, false),
            (ContextSegmentType::ToolResults, true),
        ];
        for &(segment_type, expected) in cases.iter() {
            assert_eq!(segment_type.is_compressible(), expected);
        }
    }

    /// Spot-checks the default priority of a few segment types.
    #[test]
    fn test_segment_type_priority() {
        let cases = [
            (ContextSegmentType::System, ContextPriority::Critical),
            (ContextSegmentType::History, ContextPriority::Medium),
            (ContextSegmentType::RagContext, ContextPriority::Low),
        ];
        for &(segment_type, expected) in cases.iter() {
            assert_eq!(segment_type.default_priority(), expected);
        }
    }

    /// Each priority must compare strictly greater than the one below it.
    #[test]
    fn test_priority_ordering() {
        let ascending = [
            ContextPriority::Low,
            ContextPriority::Medium,
            ContextPriority::High,
            ContextPriority::Critical,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    /// A system segment picks up the `System` defaults.
    #[test]
    fn test_create_segment() {
        let ContextSegment {
            segment_type,
            priority,
            compressible,
            ..
        } = ContextSegment::system("You are helpful", 10);

        assert_eq!(segment_type, ContextSegmentType::System);
        assert_eq!(priority, ContextPriority::Critical);
        assert!(!compressible);
    }
}