1use crate::segment::{ContextPriority, ContextSegment};
9use crate::token_counter::TokenCounter;
10use chrono::{DateTime, Utc};
11use enact_core::kernel::{ExecutionId, StepId};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15
/// Process-wide counter that hands out sequence numbers for the segments
/// created in this module.
///
/// NOTE(review): the counter starts at 2000, presumably to keep these numbers
/// clear of a range used elsewhere — confirm.
static STEP_SEQUENCE: AtomicU64 = AtomicU64::new(2_000);

/// Returns the current counter value and atomically advances it by one, so
/// every call (from any thread) yields a distinct number.
fn next_sequence() -> u64 {
    STEP_SEQUENCE.fetch_add(1, Ordering::SeqCst)
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(rename_all = "camelCase")]
26pub struct StepContextConfig {
27 pub max_tokens: usize,
29
30 pub include_tool_results: bool,
32
33 pub include_reasoning: bool,
35
36 pub include_errors: bool,
38
39 pub max_tool_results: usize,
41
42 pub truncate_long_content: bool,
44
45 pub max_content_length: usize,
47}
48
49impl Default for StepContextConfig {
50 fn default() -> Self {
51 Self {
52 max_tokens: 2000,
53 include_tool_results: true,
54 include_reasoning: true,
55 include_errors: true,
56 max_tool_results: 5,
57 truncate_long_content: true,
58 max_content_length: 1000,
59 }
60 }
61}
62
/// A single insight extracted from the outcome of a step.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StepLearning {
    /// Unique identifier; `StepContextBuilder` generates these as `learn_<uuid v4>`.
    pub id: String,

    /// Step the learning was extracted from.
    pub step_id: StepId,

    /// Execution that step belongs to.
    pub execution_id: ExecutionId,

    /// Category of the insight.
    pub learning_type: LearningType,

    /// Human-readable description of the insight.
    pub content: String,

    /// How reliable the insight is judged to be.
    /// NOTE(review): presumably in `[0.0, 1.0]` (the builder uses 0.7–0.9);
    /// not enforced here.
    pub confidence: f64,

    /// How relevant the insight is judged to be.
    /// NOTE(review): presumably in `[0.0, 1.0]` (the builder uses 0.6–0.8);
    /// not enforced here.
    pub relevance: f64,

    /// Free-form labels such as `"error"`, `"tool"`, or a tool/step name.
    pub tags: Vec<String>,

    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
}
94
/// Category of a [`StepLearning`], serialized in snake_case
/// (e.g. `"error_recovery"`).
///
/// Within this file `StepContextBuilder` only produces `SuccessPattern`,
/// `ErrorRecovery` and `ToolInsight`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LearningType {
    /// Produced from a `success_pattern` metadata entry of a step that has an
    /// output and no error.
    SuccessPattern,
    /// Produced when a step reports an error (and errors are included).
    ErrorRecovery,
    /// Produced for each successful tool call of a step that has an output and
    /// no error.
    ToolInsight,
    /// Not produced by `StepContextBuilder` in this file.
    DecisionRationale,
    /// Not produced by `StepContextBuilder` in this file.
    DomainKnowledge,
    /// Not produced by `StepContextBuilder` in this file.
    ConstraintDiscovered,
    /// Not produced by `StepContextBuilder` in this file.
    UserPreference,
}
114
/// Output of [`StepContextBuilder::build_context`] and
/// [`StepContextBuilder::build_child_context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StepContextResult {
    /// Execution the context belongs to (the parent execution for child contexts).
    pub execution_id: ExecutionId,

    /// Step the context was built for (the child step for child contexts).
    pub step_id: StepId,

    /// Context segments, in the order they were added.
    pub segments: Vec<ContextSegment>,

    /// Learnings extracted from the step; always empty for child contexts.
    pub learnings: Vec<StepLearning>,

    /// Sum of the token counts of `segments`.
    pub total_tokens: usize,

    /// Time (UTC) at which the result was produced.
    pub processed_at: DateTime<Utc>,
}
137
/// Turns a step's input/output, tool calls and errors into token-budgeted
/// [`ContextSegment`]s plus extracted [`StepLearning`]s.
pub struct StepContextBuilder {
    /// Counts the tokens of generated segment text.
    token_counter: TokenCounter,
    /// Budget, inclusion and truncation settings.
    config: StepContextConfig,
}
143
144impl StepContextBuilder {
145 pub fn new() -> Self {
147 Self {
148 token_counter: TokenCounter::default(),
149 config: StepContextConfig::default(),
150 }
151 }
152
153 pub fn with_config(config: StepContextConfig) -> Self {
155 Self {
156 token_counter: TokenCounter::default(),
157 config,
158 }
159 }
160
161 #[allow(clippy::too_many_arguments)]
163 pub fn build_context(
164 &self,
165 execution_id: ExecutionId,
166 step_id: StepId,
167 step_type: &str,
168 input: &str,
169 output: Option<&str>,
170 tool_calls: &[ToolCallInfo],
171 error: Option<&str>,
172 metadata: &HashMap<String, String>,
173 ) -> StepContextResult {
174 let mut segments = Vec::new();
175 let mut learnings = Vec::new();
176 let mut total_tokens = 0;
177
178 let step_summary = self.build_step_summary(step_type, input, output);
180 let summary_tokens = self.token_counter.count(&step_summary);
181
182 if total_tokens + summary_tokens <= self.config.max_tokens {
183 segments.push(ContextSegment::history(
184 step_summary.clone(),
185 summary_tokens,
186 next_sequence(),
187 ));
188 total_tokens += summary_tokens;
189 }
190
191 if self.config.include_tool_results {
193 let tool_context = self.extract_tool_context(tool_calls, step_id.clone());
194 for segment in tool_context {
195 let tokens = segment.token_count;
196 if total_tokens + tokens <= self.config.max_tokens {
197 total_tokens += tokens;
198 segments.push(segment);
199 }
200 }
201 }
202
203 if self.config.include_errors {
205 if let Some(err) = error {
206 let error_learning =
207 self.extract_error_learning(execution_id.clone(), step_id.clone(), err);
208 learnings.push(error_learning);
209
210 let error_content = format!("Error encountered: {}", self.truncate_content(err));
211 let error_tokens = self.token_counter.count(&error_content);
212 let error_segment = ContextSegment::tool_results(
213 error_content,
214 error_tokens,
215 next_sequence(),
216 step_id.clone(),
217 )
218 .with_priority(ContextPriority::High);
219
220 if total_tokens + error_tokens <= self.config.max_tokens {
221 total_tokens += error_tokens;
222 segments.push(error_segment);
223 }
224 }
225 }
226
227 if error.is_none() && output.is_some() {
229 let success_learnings = self.extract_success_learnings(
230 execution_id.clone(),
231 step_id.clone(),
232 step_type,
233 tool_calls,
234 metadata,
235 );
236 learnings.extend(success_learnings);
237 }
238
239 StepContextResult {
240 execution_id,
241 step_id,
242 segments,
243 learnings,
244 total_tokens,
245 processed_at: Utc::now(),
246 }
247 }
248
249 fn build_step_summary(&self, step_type: &str, input: &str, output: Option<&str>) -> String {
251 let truncated_input = self.truncate_content(input);
252 let truncated_output = output
253 .map(|o| self.truncate_content(o))
254 .unwrap_or_else(|| "(pending)".to_string());
255
256 format!(
257 "[Step: {}]\nInput: {}\nOutput: {}",
258 step_type, truncated_input, truncated_output
259 )
260 }
261
262 fn extract_tool_context(
264 &self,
265 tool_calls: &[ToolCallInfo],
266 step_id: StepId,
267 ) -> Vec<ContextSegment> {
268 tool_calls
269 .iter()
270 .take(self.config.max_tool_results)
271 .map(|tc| {
272 let content = format!(
273 "Tool: {}\nArgs: {}\nResult: {}",
274 tc.tool_name,
275 self.truncate_content(&tc.arguments),
276 tc.result
277 .as_ref()
278 .map(|r| self.truncate_content(r))
279 .unwrap_or_else(|| "(pending)".to_string())
280 );
281 let tokens = self.token_counter.count(&content);
282
283 ContextSegment::tool_results(content, tokens, next_sequence(), step_id.clone())
284 .with_priority(if tc.success {
285 ContextPriority::Medium
286 } else {
287 ContextPriority::High
288 })
289 })
290 .collect()
291 }
292
293 fn extract_error_learning(
295 &self,
296 execution_id: ExecutionId,
297 step_id: StepId,
298 error: &str,
299 ) -> StepLearning {
300 StepLearning {
301 id: format!("learn_{}", uuid::Uuid::new_v4()),
302 step_id,
303 execution_id,
304 learning_type: LearningType::ErrorRecovery,
305 content: format!(
306 "Error encountered: {}. Consider alternative approaches.",
307 error
308 ),
309 confidence: 0.7,
310 relevance: 0.8,
311 tags: vec!["error".to_string(), "recovery".to_string()],
312 created_at: Utc::now(),
313 }
314 }
315
316 fn extract_success_learnings(
318 &self,
319 execution_id: ExecutionId,
320 step_id: StepId,
321 step_type: &str,
322 tool_calls: &[ToolCallInfo],
323 metadata: &HashMap<String, String>,
324 ) -> Vec<StepLearning> {
325 let mut learnings = Vec::new();
326
327 for tc in tool_calls.iter().filter(|tc| tc.success) {
329 learnings.push(StepLearning {
330 id: format!("learn_{}", uuid::Uuid::new_v4()),
331 step_id: step_id.clone(),
332 execution_id: execution_id.clone(),
333 learning_type: LearningType::ToolInsight,
334 content: format!(
335 "Tool '{}' succeeded with pattern: {}",
336 tc.tool_name,
337 self.truncate_content(&tc.arguments)
338 ),
339 confidence: 0.8,
340 relevance: 0.6,
341 tags: vec!["tool".to_string(), tc.tool_name.clone()],
342 created_at: Utc::now(),
343 });
344 }
345
346 if let Some(pattern) = metadata.get("success_pattern") {
348 learnings.push(StepLearning {
349 id: format!("learn_{}", uuid::Uuid::new_v4()),
350 step_id: step_id.clone(),
351 execution_id: execution_id.clone(),
352 learning_type: LearningType::SuccessPattern,
353 content: format!("Step '{}' success pattern: {}", step_type, pattern),
354 confidence: 0.9,
355 relevance: 0.7,
356 tags: vec!["pattern".to_string(), step_type.to_string()],
357 created_at: Utc::now(),
358 });
359 }
360
361 learnings
362 }
363
364 fn truncate_content(&self, content: &str) -> String {
366 if self.config.truncate_long_content && content.len() > self.config.max_content_length {
367 format!(
368 "{}... [truncated, {} chars total]",
369 &content[..self.config.max_content_length],
370 content.len()
371 )
372 } else {
373 content.to_string()
374 }
375 }
376
377 pub fn build_child_context(
379 &self,
380 parent_execution_id: ExecutionId,
381 parent_step_id: StepId,
382 child_step_id: StepId,
383 task: &str,
384 parent_context: &[ContextSegment],
385 ) -> StepContextResult {
386 let mut segments = Vec::new();
387 let mut total_tokens = 0;
388
389 let task_content = format!(
391 "Sub-task spawned from parent step.\nTask: {}\nParent step: {}",
392 task,
393 parent_step_id.as_str()
394 );
395 let task_tokens = self.token_counter.count(&task_content);
396 let task_segment = ContextSegment::system(task_content, task_tokens);
397 total_tokens += task_tokens;
398 segments.push(task_segment);
399
400 for segment in parent_context {
402 if segment.priority >= ContextPriority::Medium {
403 let tokens = segment.token_count;
404 if total_tokens + tokens <= self.config.max_tokens {
405 total_tokens += tokens;
406 segments.push(segment.clone());
407 }
408 }
409 }
410
411 StepContextResult {
412 execution_id: parent_execution_id,
413 step_id: child_step_id,
414 segments,
415 learnings: Vec::new(),
416 total_tokens,
417 processed_at: Utc::now(),
418 }
419 }
420}
421
422impl Default for StepContextBuilder {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
/// Description of one tool invocation made during a step.
///
/// NOTE(review): unlike the other serialized types in this file, this struct
/// has no `#[serde(rename_all = "camelCase")]`, so its fields serialize in
/// snake_case. Confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallInfo {
    /// Name of the invoked tool.
    pub tool_name: String,

    /// Tool arguments as text. Presumably serialized JSON (the tests pass JSON
    /// strings); not validated here.
    pub arguments: String,

    /// Output on success, or the error text on failure (see
    /// `ToolCallInfo::failed`). `None` is rendered as `(pending)` in context.
    pub result: Option<String>,

    /// Whether the call succeeded. Failed calls get `High` priority in the
    /// context, and only successful calls yield `ToolInsight` learnings.
    pub success: bool,

    /// Duration in milliseconds, if recorded. Not read by `StepContextBuilder`.
    pub duration_ms: Option<u64>,
}
446
447impl ToolCallInfo {
448 pub fn success(
450 tool_name: impl Into<String>,
451 arguments: impl Into<String>,
452 result: impl Into<String>,
453 ) -> Self {
454 Self {
455 tool_name: tool_name.into(),
456 arguments: arguments.into(),
457 result: Some(result.into()),
458 success: true,
459 duration_ms: None,
460 }
461 }
462
463 pub fn failed(
465 tool_name: impl Into<String>,
466 arguments: impl Into<String>,
467 error: impl Into<String>,
468 ) -> Self {
469 Self {
470 tool_name: tool_name.into(),
471 arguments: arguments.into(),
472 result: Some(error.into()),
473 success: false,
474 duration_ms: None,
475 }
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    // Every test uses freshly generated identifiers.
    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    fn test_step_id() -> StepId {
        StepId::new()
    }

    /// Calls `build_context` with fresh execution and step ids.
    fn build_with_fresh_ids(
        builder: &StepContextBuilder,
        step_type: &str,
        input: &str,
        output: Option<&str>,
        tool_calls: &[ToolCallInfo],
        error: Option<&str>,
        metadata: &HashMap<String, String>,
    ) -> StepContextResult {
        builder.build_context(
            test_execution_id(),
            test_step_id(),
            step_type,
            input,
            output,
            tool_calls,
            error,
            metadata,
        )
    }

    #[test]
    fn test_step_context_config_defaults() {
        let defaults = StepContextConfig::default();
        assert_eq!(defaults.max_tokens, 2000);
        assert!(defaults.include_tool_results);
        assert!(defaults.include_errors);
    }

    #[test]
    fn test_build_context_basic() {
        let ctx = build_with_fresh_ids(
            &StepContextBuilder::new(),
            "llm_call",
            "What is 2+2?",
            Some("4"),
            &[],
            None,
            &HashMap::new(),
        );

        assert!(!ctx.segments.is_empty());
        assert!(ctx.total_tokens > 0);
    }

    #[test]
    fn test_build_context_with_error() {
        let ctx = build_with_fresh_ids(
            &StepContextBuilder::new(),
            "tool_call",
            "fetch data",
            None,
            &[],
            Some("Connection timeout"),
            &HashMap::new(),
        );

        let first = ctx
            .learnings
            .first()
            .expect("an error should produce at least one learning");
        assert_eq!(first.learning_type, LearningType::ErrorRecovery);
    }

    #[test]
    fn test_build_context_with_tool_calls() {
        let calls = [
            ToolCallInfo::success("search", r#"{"query": "test"}"#, "Found 5 results"),
            ToolCallInfo::failed("fetch", r#"{"url": "..."}"#, "404 Not Found"),
        ];

        let ctx = build_with_fresh_ids(
            &StepContextBuilder::new(),
            "multi_tool",
            "search and fetch",
            Some("partial results"),
            &calls,
            None,
            &HashMap::new(),
        );

        assert!(ctx.segments.len() >= 2);
        let has_tool_insight = ctx
            .learnings
            .iter()
            .any(|learning| learning.learning_type == LearningType::ToolInsight);
        assert!(has_tool_insight);
    }

    #[test]
    fn test_truncate_long_content() {
        let builder = StepContextBuilder::with_config(StepContextConfig {
            max_content_length: 50,
            ..StepContextConfig::default()
        });

        let oversized_input = "a".repeat(100);
        let ctx = build_with_fresh_ids(
            &builder,
            "test",
            &oversized_input,
            None,
            &[],
            None,
            &HashMap::new(),
        );

        assert!(ctx.segments[0].content.contains("truncated"));
    }

    #[test]
    fn test_build_child_context() {
        let builder = StepContextBuilder::new();
        let counter = TokenCounter::default();

        let system_text = "Parent system context";
        let history_text = "Some history";

        let parent_context = vec![
            ContextSegment::system(system_text, counter.count(system_text)),
            ContextSegment::new(
                crate::segment::ContextSegmentType::History,
                history_text.to_string(),
                counter.count(history_text),
                1,
            )
            .with_priority(ContextPriority::Low),
        ];

        let child = builder.build_child_context(
            test_execution_id(),
            test_step_id(),
            StepId::new(),
            "Analyze the data",
            &parent_context,
        );

        let any_segment_contains =
            |needle: &str| child.segments.iter().any(|s| s.content.contains(needle));
        assert!(any_segment_contains("Sub-task"));
        assert!(any_segment_contains("Parent system"));
    }

    #[test]
    fn test_learning_types() {
        let builder = StepContextBuilder::new();
        let metadata = HashMap::from([(
            "success_pattern".to_string(),
            "retry with backoff".to_string(),
        )]);

        let ctx = build_with_fresh_ids(
            &builder,
            "api_call",
            "fetch user",
            Some("user data"),
            &[ToolCallInfo::success("http", "{}", "200 OK")],
            None,
            &metadata,
        );

        let kinds: Vec<LearningType> = ctx
            .learnings
            .iter()
            .map(|learning| learning.learning_type)
            .collect();
        assert!(kinds.contains(&LearningType::ToolInsight));
        assert!(kinds.contains(&LearningType::SuccessPattern));
    }
}