use crate::segment::{ContextPriority, ContextSegment};
use crate::token_counter::TokenCounter;
use chrono::{DateTime, Utc};
use enact_core::kernel::{ExecutionId, StepId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
// Process-wide monotonic counter used to order context segments.
// Starts at 2000 — presumably to leave room below for manually-sequenced
// segments (e.g. the literal `1` used in tests); TODO confirm the convention.
static STEP_SEQUENCE: AtomicU64 = AtomicU64::new(2000);
/// Returns the next globally unique, strictly increasing sequence number.
/// SeqCst keeps the ordering total across threads.
fn next_sequence() -> u64 {
STEP_SEQUENCE.fetch_add(1, Ordering::SeqCst)
}
/// Tunables controlling how much of a step's activity is turned into context.
/// Serialized with camelCase keys for interop with external configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StepContextConfig {
// Hard budget: segments that would push the total past this are dropped.
pub max_tokens: usize,
// Include segments describing tool calls and their results.
pub include_tool_results: bool,
// NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
pub include_reasoning: bool,
// Include an error segment (and error learning) when a step fails.
pub include_errors: bool,
// At most this many tool-call segments are emitted per step.
pub max_tool_results: usize,
// When true, content longer than `max_content_length` bytes is cut short.
pub truncate_long_content: bool,
// Truncation threshold, in bytes of the source string.
pub max_content_length: usize,
}
impl Default for StepContextConfig {
fn default() -> Self {
Self {
max_tokens: 2000,
include_tool_results: true,
include_reasoning: true,
include_errors: true,
max_tool_results: 5,
truncate_long_content: true,
max_content_length: 1000,
}
}
}
/// A piece of knowledge extracted from a single step's outcome, suitable for
/// feeding back into later steps or executions.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StepLearning {
// Unique id, formatted "learn_<uuid>" by the extractors in this module.
pub id: String,
pub step_id: StepId,
pub execution_id: ExecutionId,
pub learning_type: LearningType,
// Human-readable description of what was learned.
pub content: String,
// Extractors in this file assign values in [0.6, 0.9]; presumably a 0..1
// scale — confirm against consumers.
pub confidence: f64,
pub relevance: f64,
// Free-form labels (e.g. "error", "tool", the tool name, the step type).
pub tags: Vec<String>,
pub created_at: DateTime<Utc>,
}
/// Category of a [`StepLearning`]. Serialized in snake_case.
/// Only `SuccessPattern`, `ErrorRecovery`, and `ToolInsight` are produced by
/// this module; the remaining variants are presumably produced elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LearningType {
SuccessPattern,
ErrorRecovery,
ToolInsight,
DecisionRationale,
DomainKnowledge,
ConstraintDiscovered,
UserPreference,
}
/// Output of context building for one step: the packed segments, any
/// extracted learnings, and the token total actually consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StepContextResult {
pub execution_id: ExecutionId,
pub step_id: StepId,
// Segments that fit within the configured token budget, in insertion order.
pub segments: Vec<ContextSegment>,
// Learnings are collected regardless of the token budget.
pub learnings: Vec<StepLearning>,
// Sum of token counts of the segments above.
pub total_tokens: usize,
pub processed_at: DateTime<Utc>,
}
/// Builds [`StepContextResult`]s from step inputs/outputs, tool calls, and
/// errors, packing segments greedily under a token budget.
pub struct StepContextBuilder {
// Estimates token counts for budget enforcement.
token_counter: TokenCounter,
config: StepContextConfig,
}
impl StepContextBuilder {
pub fn new() -> Self {
Self {
token_counter: TokenCounter::default(),
config: StepContextConfig::default(),
}
}
/// Creates a builder with a caller-supplied configuration and the default
/// token counter.
pub fn with_config(config: StepContextConfig) -> Self {
    let token_counter = TokenCounter::default();
    Self { token_counter, config }
}
/// Assembles the context for a completed (or failed) step.
///
/// Packing is greedy: each candidate segment is appended only while the
/// running total stays within `config.max_tokens`; segments that do not fit
/// are dropped, not truncated. Learnings are collected regardless of the
/// token budget.
///
/// Order of candidates: step summary, tool-call results (if enabled), then
/// an error segment (if enabled and an error occurred). On success with
/// output present, success learnings are extracted as well.
#[allow(clippy::too_many_arguments)]
pub fn build_context(
    &self,
    execution_id: ExecutionId,
    step_id: StepId,
    step_type: &str,
    input: &str,
    output: Option<&str>,
    tool_calls: &[ToolCallInfo],
    error: Option<&str>,
    metadata: &HashMap<String, String>,
) -> StepContextResult {
    let mut segments = Vec::new();
    let mut learnings = Vec::new();
    let mut total_tokens = 0;
    // 1) Summary of the step itself (type, input, output).
    let step_summary = self.build_step_summary(step_type, input, output);
    let summary_tokens = self.token_counter.count(&step_summary);
    if total_tokens + summary_tokens <= self.config.max_tokens {
        // `step_summary` is not used again, so move it instead of cloning.
        segments.push(ContextSegment::history(
            step_summary,
            summary_tokens,
            next_sequence(),
        ));
        total_tokens += summary_tokens;
    }
    // 2) Up to `max_tool_results` tool-call segments, budget permitting.
    if self.config.include_tool_results {
        let tool_context = self.extract_tool_context(tool_calls, step_id.clone());
        for segment in tool_context {
            let tokens = segment.token_count;
            if total_tokens + tokens <= self.config.max_tokens {
                total_tokens += tokens;
                segments.push(segment);
            }
        }
    }
    // 3) On error: always record an ErrorRecovery learning; add the
    //    high-priority error segment only if it fits the budget.
    if self.config.include_errors {
        if let Some(err) = error {
            let error_learning =
                self.extract_error_learning(execution_id.clone(), step_id.clone(), err);
            learnings.push(error_learning);
            let error_content = format!("Error encountered: {}", self.truncate_content(err));
            let error_tokens = self.token_counter.count(&error_content);
            let error_segment = ContextSegment::tool_results(
                error_content,
                error_tokens,
                next_sequence(),
                step_id.clone(),
            )
            .with_priority(ContextPriority::High);
            if total_tokens + error_tokens <= self.config.max_tokens {
                total_tokens += error_tokens;
                segments.push(error_segment);
            }
        }
    }
    // 4) Success learnings only when the step produced output with no error.
    if error.is_none() && output.is_some() {
        let success_learnings = self.extract_success_learnings(
            execution_id.clone(),
            step_id.clone(),
            step_type,
            tool_calls,
            metadata,
        );
        learnings.extend(success_learnings);
    }
    StepContextResult {
        execution_id,
        step_id,
        segments,
        learnings,
        total_tokens,
        processed_at: Utc::now(),
    }
}
/// Renders a compact summary of a step: its type, its (possibly truncated)
/// input, and its output — "(pending)" when no output exists yet.
fn build_step_summary(&self, step_type: &str, input: &str, output: Option<&str>) -> String {
    let shown_output = match output {
        Some(o) => self.truncate_content(o),
        None => "(pending)".to_string(),
    };
    let shown_input = self.truncate_content(input);
    format!(
        "[Step: {}]\nInput: {}\nOutput: {}",
        step_type, shown_input, shown_output
    )
}
/// Converts up to `max_tool_results` tool calls into context segments.
/// Failed calls are given High priority so they survive budget pressure.
fn extract_tool_context(
    &self,
    tool_calls: &[ToolCallInfo],
    step_id: StepId,
) -> Vec<ContextSegment> {
    let limit = self.config.max_tool_results;
    let mut segments = Vec::with_capacity(tool_calls.len().min(limit));
    for tc in tool_calls.iter().take(limit) {
        let result_text = match tc.result.as_ref() {
            Some(r) => self.truncate_content(r),
            None => "(pending)".to_string(),
        };
        let content = format!(
            "Tool: {}\nArgs: {}\nResult: {}",
            tc.tool_name,
            self.truncate_content(&tc.arguments),
            result_text
        );
        let tokens = self.token_counter.count(&content);
        let priority = if tc.success {
            ContextPriority::Medium
        } else {
            ContextPriority::High
        };
        let segment =
            ContextSegment::tool_results(content, tokens, next_sequence(), step_id.clone())
                .with_priority(priority);
        segments.push(segment);
    }
    segments
}
/// Builds an ErrorRecovery learning from a step failure. Confidence and
/// relevance are fixed heuristics (0.7 / 0.8).
fn extract_error_learning(
    &self,
    execution_id: ExecutionId,
    step_id: StepId,
    error: &str,
) -> StepLearning {
    let content = format!(
        "Error encountered: {}. Consider alternative approaches.",
        error
    );
    StepLearning {
        id: format!("learn_{}", uuid::Uuid::new_v4()),
        execution_id,
        step_id,
        learning_type: LearningType::ErrorRecovery,
        content,
        confidence: 0.7,
        relevance: 0.8,
        tags: vec!["error".to_string(), "recovery".to_string()],
        created_at: Utc::now(),
    }
}
/// Extracts learnings from a successful step: one ToolInsight per successful
/// tool call, plus a SuccessPattern when the caller supplied a
/// "success_pattern" metadata entry.
fn extract_success_learnings(
    &self,
    execution_id: ExecutionId,
    step_id: StepId,
    step_type: &str,
    tool_calls: &[ToolCallInfo],
    metadata: &HashMap<String, String>,
) -> Vec<StepLearning> {
    let mut learnings: Vec<StepLearning> = tool_calls
        .iter()
        .filter(|tc| tc.success)
        .map(|tc| StepLearning {
            id: format!("learn_{}", uuid::Uuid::new_v4()),
            step_id: step_id.clone(),
            execution_id: execution_id.clone(),
            learning_type: LearningType::ToolInsight,
            content: format!(
                "Tool '{}' succeeded with pattern: {}",
                tc.tool_name,
                self.truncate_content(&tc.arguments)
            ),
            confidence: 0.8,
            relevance: 0.6,
            tags: vec!["tool".to_string(), tc.tool_name.clone()],
            created_at: Utc::now(),
        })
        .collect();
    if let Some(pattern) = metadata.get("success_pattern") {
        learnings.push(StepLearning {
            id: format!("learn_{}", uuid::Uuid::new_v4()),
            step_id: step_id.clone(),
            execution_id: execution_id.clone(),
            learning_type: LearningType::SuccessPattern,
            content: format!("Step '{}' success pattern: {}", step_type, pattern),
            confidence: 0.9,
            relevance: 0.7,
            tags: vec!["pattern".to_string(), step_type.to_string()],
            created_at: Utc::now(),
        });
    }
    learnings
}
/// Truncates `content` to roughly `max_content_length` bytes when truncation
/// is enabled, appending a marker with the original character count.
///
/// Bug fix: `max_content_length` is a byte index, and slicing a `&str` at a
/// byte that falls inside a multi-byte UTF-8 character panics. The cut point
/// is therefore walked back to the nearest char boundary before slicing.
/// The marker now reports the actual character count (`chars().count()`),
/// matching its "chars total" wording; for ASCII input this is identical to
/// the previous byte-length output.
fn truncate_content(&self, content: &str) -> String {
    if self.config.truncate_long_content && content.len() > self.config.max_content_length {
        // Walk back (at most 3 bytes) to a valid UTF-8 boundary.
        let mut cut = self.config.max_content_length;
        while !content.is_char_boundary(cut) {
            cut -= 1;
        }
        format!(
            "{}... [truncated, {} chars total]",
            &content[..cut],
            content.chars().count()
        )
    } else {
        content.to_string()
    }
}
/// Builds the starting context for a child step spawned by a parent step.
///
/// The child receives a system segment describing the sub-task, followed by
/// every parent segment of Medium-or-higher priority that still fits within
/// `config.max_tokens`. The result reuses the parent's execution id and
/// carries no learnings.
pub fn build_child_context(
&self,
parent_execution_id: ExecutionId,
parent_step_id: StepId,
child_step_id: StepId,
task: &str,
parent_context: &[ContextSegment],
) -> StepContextResult {
let mut segments = Vec::new();
let mut total_tokens = 0;
let task_content = format!(
"Sub-task spawned from parent step.\nTask: {}\nParent step: {}",
task,
parent_step_id.as_str()
);
let task_tokens = self.token_counter.count(&task_content);
let task_segment = ContextSegment::system(task_content, task_tokens);
// NOTE(review): the task segment is added without checking max_tokens —
// presumably intentional (the task must always be present), but it means
// the total can exceed the budget for very long tasks. Confirm.
total_tokens += task_tokens;
segments.push(task_segment);
// Inherit high-value parent segments while the budget allows.
for segment in parent_context {
if segment.priority >= ContextPriority::Medium {
let tokens = segment.token_count;
if total_tokens + tokens <= self.config.max_tokens {
total_tokens += tokens;
segments.push(segment.clone());
}
}
}
StepContextResult {
execution_id: parent_execution_id,
step_id: child_step_id,
segments,
learnings: Vec::new(),
total_tokens,
processed_at: Utc::now(),
}
}
}
impl Default for StepContextBuilder {
fn default() -> Self {
Self::new()
}
}
/// Record of a single tool invocation made during a step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallInfo {
pub tool_name: String,
// Arguments as passed to the tool; stored as a string (often JSON here,
// judging by the tests — not enforced).
pub arguments: String,
// Tool output on success, or the error text on failure; `None` while pending.
pub result: Option<String>,
pub success: bool,
// Wall-clock duration; never populated by the constructors in this file.
pub duration_ms: Option<u64>,
}
impl ToolCallInfo {
    /// Builds a record for a tool call that completed successfully.
    pub fn success(
        tool_name: impl Into<String>,
        arguments: impl Into<String>,
        result: impl Into<String>,
    ) -> Self {
        Self::finished(tool_name, arguments, result, true)
    }

    /// Builds a record for a failed tool call; the error text is stored in
    /// the `result` field, mirroring `success`.
    pub fn failed(
        tool_name: impl Into<String>,
        arguments: impl Into<String>,
        error: impl Into<String>,
    ) -> Self {
        Self::finished(tool_name, arguments, error, false)
    }

    /// Shared constructor for calls that have finished either way; the
    /// duration is unknown at construction time.
    fn finished(
        tool_name: impl Into<String>,
        arguments: impl Into<String>,
        outcome: impl Into<String>,
        success: bool,
    ) -> Self {
        Self {
            tool_name: tool_name.into(),
            arguments: arguments.into(),
            result: Some(outcome.into()),
            success,
            duration_ms: None,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Helpers producing fresh ids for each test case.
fn test_execution_id() -> ExecutionId {
ExecutionId::new()
}
fn test_step_id() -> StepId {
StepId::new()
}
// Default config carries the documented budget and inclusion flags.
#[test]
fn test_step_context_config_defaults() {
let config = StepContextConfig::default();
assert_eq!(config.max_tokens, 2000);
assert!(config.include_tool_results);
assert!(config.include_errors);
}
// A plain step with output produces at least a summary segment.
#[test]
fn test_build_context_basic() {
let builder = StepContextBuilder::new();
let result = builder.build_context(
test_execution_id(),
test_step_id(),
"llm_call",
"What is 2+2?",
Some("4"),
&[],
None,
&HashMap::new(),
);
assert!(!result.segments.is_empty());
assert!(result.total_tokens > 0);
}
// A failing step yields an ErrorRecovery learning.
#[test]
fn test_build_context_with_error() {
let builder = StepContextBuilder::new();
let result = builder.build_context(
test_execution_id(),
test_step_id(),
"tool_call",
"fetch data",
None,
&[],
Some("Connection timeout"),
&HashMap::new(),
);
assert!(!result.learnings.is_empty());
assert_eq!(
result.learnings[0].learning_type,
LearningType::ErrorRecovery
);
}
// Tool calls become segments; successful ones also become ToolInsight learnings.
#[test]
fn test_build_context_with_tool_calls() {
let builder = StepContextBuilder::new();
let tool_calls = vec![
ToolCallInfo::success("search", r#"{"query": "test"}"#, "Found 5 results"),
ToolCallInfo::failed("fetch", r#"{"url": "..."}"#, "404 Not Found"),
];
let result = builder.build_context(
test_execution_id(),
test_step_id(),
"multi_tool",
"search and fetch",
Some("partial results"),
&tool_calls,
None,
&HashMap::new(),
);
assert!(result.segments.len() >= 2);
assert!(result
.learnings
.iter()
.any(|l| l.learning_type == LearningType::ToolInsight));
}
// Content longer than max_content_length is cut and marked "[truncated...]".
#[test]
fn test_truncate_long_content() {
let config = StepContextConfig {
max_content_length: 50,
..Default::default()
};
let builder = StepContextBuilder::with_config(config);
let long_content = "a".repeat(100);
let result = builder.build_context(
test_execution_id(),
test_step_id(),
"test",
&long_content,
None,
&[],
None,
&HashMap::new(),
);
assert!(result.segments[0].content.contains("truncated"));
}
// Child context starts with a sub-task system segment and inherits the
// parent's Medium+ priority segments (the Low-priority one is filtered out).
#[test]
fn test_build_child_context() {
let builder = StepContextBuilder::new();
let token_counter = TokenCounter::default();
let system_content = "Parent system context";
let system_tokens = token_counter.count(system_content);
let history_content = "Some history";
let history_tokens = token_counter.count(history_content);
let parent_context = vec![
ContextSegment::system(system_content, system_tokens),
ContextSegment::new(
crate::segment::ContextSegmentType::History,
history_content.to_string(),
history_tokens,
1,
)
.with_priority(ContextPriority::Low),
];
let result = builder.build_child_context(
test_execution_id(),
test_step_id(),
StepId::new(),
"Analyze the data",
&parent_context,
);
assert!(result
.segments
.iter()
.any(|s| s.content.contains("Sub-task")));
assert!(result
.segments
.iter()
.any(|s| s.content.contains("Parent system")));
}
// A "success_pattern" metadata entry yields a SuccessPattern learning in
// addition to the per-tool ToolInsight learnings.
#[test]
fn test_learning_types() {
let builder = StepContextBuilder::new();
let mut metadata = HashMap::new();
metadata.insert(
"success_pattern".to_string(),
"retry with backoff".to_string(),
);
let result = builder.build_context(
test_execution_id(),
test_step_id(),
"api_call",
"fetch user",
Some("user data"),
&[ToolCallInfo::success("http", "{}", "200 OK")],
None,
&metadata,
);
assert!(result
.learnings
.iter()
.any(|l| l.learning_type == LearningType::ToolInsight));
assert!(result
.learnings
.iter()
.any(|l| l.learning_type == LearningType::SuccessPattern));
}
}