Skip to main content

enact_context/
step_context.rs

1//! Step Context Builder
2//!
3//! Extracts learnings and builds context when steps are discovered.
4//! Used during agentic execution when StepSource::Agent spawns new steps.
5//!
6//! @see packages/enact-schemas/src/context.schemas.ts
7
8use crate::segment::{ContextPriority, ContextSegment};
9use crate::token_counter::TokenCounter;
10use chrono::{DateTime, Utc};
11use enact_core::kernel::{ExecutionId, StepId};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15
/// Process-wide counter handing out sequence numbers to the segments built in
/// this module. It starts at 2000 (NOTE(review): presumably to keep these
/// sequences apart from ones allocated elsewhere — confirm).
static STEP_SEQUENCE: AtomicU64 = AtomicU64::new(2000);

/// Reserve and return the next segment sequence number.
///
/// Each call yields the counter's current value and advances it by one, so
/// values are unique for the lifetime of the process.
fn next_sequence() -> u64 {
    STEP_SEQUENCE.fetch_add(1, Ordering::SeqCst)
}
22
23/// Context extraction configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(rename_all = "camelCase")]
26pub struct StepContextConfig {
27    /// Maximum tokens for step context
28    pub max_tokens: usize,
29
30    /// Include tool call results
31    pub include_tool_results: bool,
32
33    /// Include reasoning/chain-of-thought
34    pub include_reasoning: bool,
35
36    /// Include error context
37    pub include_errors: bool,
38
39    /// Maximum tool results to include
40    pub max_tool_results: usize,
41
42    /// Truncate long content
43    pub truncate_long_content: bool,
44
45    /// Maximum content length before truncation
46    pub max_content_length: usize,
47}
48
49impl Default for StepContextConfig {
50    fn default() -> Self {
51        Self {
52            max_tokens: 2000,
53            include_tool_results: true,
54            include_reasoning: true,
55            include_errors: true,
56            max_tool_results: 5,
57            truncate_long_content: true,
58            max_content_length: 1000,
59        }
60    }
61}
62
63/// Extracted learning from a step
64#[derive(Debug, Clone, Serialize, Deserialize)]
65#[serde(rename_all = "camelCase")]
66pub struct StepLearning {
67    /// Unique ID for this learning
68    pub id: String,
69
70    /// Step that produced this learning
71    pub step_id: StepId,
72
73    /// Execution containing the step
74    pub execution_id: ExecutionId,
75
76    /// Type of learning
77    pub learning_type: LearningType,
78
79    /// The learning content
80    pub content: String,
81
82    /// Confidence in this learning (0.0 - 1.0)
83    pub confidence: f64,
84
85    /// Relevance to future steps (0.0 - 1.0)
86    pub relevance: f64,
87
88    /// Tags for categorization
89    pub tags: Vec<String>,
90
91    /// Timestamp
92    pub created_at: DateTime<Utc>,
93}
94
95/// Types of learnings that can be extracted
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
97#[serde(rename_all = "snake_case")]
98pub enum LearningType {
99    /// Successful action pattern
100    SuccessPattern,
101    /// Error and recovery
102    ErrorRecovery,
103    /// Tool usage insight
104    ToolInsight,
105    /// Decision rationale
106    DecisionRationale,
107    /// Domain knowledge
108    DomainKnowledge,
109    /// Constraint discovered
110    ConstraintDiscovered,
111    /// User preference
112    UserPreference,
113}
114
115/// Result of step context extraction
116#[derive(Debug, Clone, Serialize, Deserialize)]
117#[serde(rename_all = "camelCase")]
118pub struct StepContextResult {
119    /// Execution ID
120    pub execution_id: ExecutionId,
121
122    /// Step ID that was processed
123    pub step_id: StepId,
124
125    /// Extracted context segments
126    pub segments: Vec<ContextSegment>,
127
128    /// Extracted learnings
129    pub learnings: Vec<StepLearning>,
130
131    /// Total tokens in extracted context
132    pub total_tokens: usize,
133
134    /// Processing timestamp
135    pub processed_at: DateTime<Utc>,
136}
137
138/// Step Context Builder - extracts learnings when steps are discovered
139pub struct StepContextBuilder {
140    token_counter: TokenCounter,
141    config: StepContextConfig,
142}
143
144impl StepContextBuilder {
145    /// Create a new builder with default config
146    pub fn new() -> Self {
147        Self {
148            token_counter: TokenCounter::default(),
149            config: StepContextConfig::default(),
150        }
151    }
152
153    /// Create with custom config
154    pub fn with_config(config: StepContextConfig) -> Self {
155        Self {
156            token_counter: TokenCounter::default(),
157            config,
158        }
159    }
160
161    /// Build context from step data
162    #[allow(clippy::too_many_arguments)]
163    pub fn build_context(
164        &self,
165        execution_id: ExecutionId,
166        step_id: StepId,
167        step_type: &str,
168        input: &str,
169        output: Option<&str>,
170        tool_calls: &[ToolCallInfo],
171        error: Option<&str>,
172        metadata: &HashMap<String, String>,
173    ) -> StepContextResult {
174        let mut segments = Vec::new();
175        let mut learnings = Vec::new();
176        let mut total_tokens = 0;
177
178        // Extract main step context
179        let step_summary = self.build_step_summary(step_type, input, output);
180        let summary_tokens = self.token_counter.count(&step_summary);
181
182        if total_tokens + summary_tokens <= self.config.max_tokens {
183            segments.push(ContextSegment::history(
184                step_summary.clone(),
185                summary_tokens,
186                next_sequence(),
187            ));
188            total_tokens += summary_tokens;
189        }
190
191        // Extract tool results
192        if self.config.include_tool_results {
193            let tool_context = self.extract_tool_context(tool_calls, step_id.clone());
194            for segment in tool_context {
195                let tokens = segment.token_count;
196                if total_tokens + tokens <= self.config.max_tokens {
197                    total_tokens += tokens;
198                    segments.push(segment);
199                }
200            }
201        }
202
203        // Extract error context
204        if self.config.include_errors {
205            if let Some(err) = error {
206                let error_learning =
207                    self.extract_error_learning(execution_id.clone(), step_id.clone(), err);
208                learnings.push(error_learning);
209
210                let error_content = format!("Error encountered: {}", self.truncate_content(err));
211                let error_tokens = self.token_counter.count(&error_content);
212                let error_segment = ContextSegment::tool_results(
213                    error_content,
214                    error_tokens,
215                    next_sequence(),
216                    step_id.clone(),
217                )
218                .with_priority(ContextPriority::High);
219
220                if total_tokens + error_tokens <= self.config.max_tokens {
221                    total_tokens += error_tokens;
222                    segments.push(error_segment);
223                }
224            }
225        }
226
227        // Extract learnings from successful execution
228        if error.is_none() && output.is_some() {
229            let success_learnings = self.extract_success_learnings(
230                execution_id.clone(),
231                step_id.clone(),
232                step_type,
233                tool_calls,
234                metadata,
235            );
236            learnings.extend(success_learnings);
237        }
238
239        StepContextResult {
240            execution_id,
241            step_id,
242            segments,
243            learnings,
244            total_tokens,
245            processed_at: Utc::now(),
246        }
247    }
248
249    /// Build a summary of the step
250    fn build_step_summary(&self, step_type: &str, input: &str, output: Option<&str>) -> String {
251        let truncated_input = self.truncate_content(input);
252        let truncated_output = output
253            .map(|o| self.truncate_content(o))
254            .unwrap_or_else(|| "(pending)".to_string());
255
256        format!(
257            "[Step: {}]\nInput: {}\nOutput: {}",
258            step_type, truncated_input, truncated_output
259        )
260    }
261
262    /// Extract context from tool calls
263    fn extract_tool_context(
264        &self,
265        tool_calls: &[ToolCallInfo],
266        step_id: StepId,
267    ) -> Vec<ContextSegment> {
268        tool_calls
269            .iter()
270            .take(self.config.max_tool_results)
271            .map(|tc| {
272                let content = format!(
273                    "Tool: {}\nArgs: {}\nResult: {}",
274                    tc.tool_name,
275                    self.truncate_content(&tc.arguments),
276                    tc.result
277                        .as_ref()
278                        .map(|r| self.truncate_content(r))
279                        .unwrap_or_else(|| "(pending)".to_string())
280                );
281                let tokens = self.token_counter.count(&content);
282
283                ContextSegment::tool_results(content, tokens, next_sequence(), step_id.clone())
284                    .with_priority(if tc.success {
285                        ContextPriority::Medium
286                    } else {
287                        ContextPriority::High
288                    })
289            })
290            .collect()
291    }
292
293    /// Extract learning from an error
294    fn extract_error_learning(
295        &self,
296        execution_id: ExecutionId,
297        step_id: StepId,
298        error: &str,
299    ) -> StepLearning {
300        StepLearning {
301            id: format!("learn_{}", uuid::Uuid::new_v4()),
302            step_id,
303            execution_id,
304            learning_type: LearningType::ErrorRecovery,
305            content: format!(
306                "Error encountered: {}. Consider alternative approaches.",
307                error
308            ),
309            confidence: 0.7,
310            relevance: 0.8,
311            tags: vec!["error".to_string(), "recovery".to_string()],
312            created_at: Utc::now(),
313        }
314    }
315
316    /// Extract learnings from successful execution
317    fn extract_success_learnings(
318        &self,
319        execution_id: ExecutionId,
320        step_id: StepId,
321        step_type: &str,
322        tool_calls: &[ToolCallInfo],
323        metadata: &HashMap<String, String>,
324    ) -> Vec<StepLearning> {
325        let mut learnings = Vec::new();
326
327        // Learn from successful tool usage
328        for tc in tool_calls.iter().filter(|tc| tc.success) {
329            learnings.push(StepLearning {
330                id: format!("learn_{}", uuid::Uuid::new_v4()),
331                step_id: step_id.clone(),
332                execution_id: execution_id.clone(),
333                learning_type: LearningType::ToolInsight,
334                content: format!(
335                    "Tool '{}' succeeded with pattern: {}",
336                    tc.tool_name,
337                    self.truncate_content(&tc.arguments)
338                ),
339                confidence: 0.8,
340                relevance: 0.6,
341                tags: vec!["tool".to_string(), tc.tool_name.clone()],
342                created_at: Utc::now(),
343            });
344        }
345
346        // Learn from metadata hints
347        if let Some(pattern) = metadata.get("success_pattern") {
348            learnings.push(StepLearning {
349                id: format!("learn_{}", uuid::Uuid::new_v4()),
350                step_id: step_id.clone(),
351                execution_id: execution_id.clone(),
352                learning_type: LearningType::SuccessPattern,
353                content: format!("Step '{}' success pattern: {}", step_type, pattern),
354                confidence: 0.9,
355                relevance: 0.7,
356                tags: vec!["pattern".to_string(), step_type.to_string()],
357                created_at: Utc::now(),
358            });
359        }
360
361        learnings
362    }
363
364    /// Truncate content if too long
365    fn truncate_content(&self, content: &str) -> String {
366        if self.config.truncate_long_content && content.len() > self.config.max_content_length {
367            format!(
368                "{}... [truncated, {} chars total]",
369                &content[..self.config.max_content_length],
370                content.len()
371            )
372        } else {
373            content.to_string()
374        }
375    }
376
377    /// Build context for a child step being spawned
378    pub fn build_child_context(
379        &self,
380        parent_execution_id: ExecutionId,
381        parent_step_id: StepId,
382        child_step_id: StepId,
383        task: &str,
384        parent_context: &[ContextSegment],
385    ) -> StepContextResult {
386        let mut segments = Vec::new();
387        let mut total_tokens = 0;
388
389        // Add child task context
390        let task_content = format!(
391            "Sub-task spawned from parent step.\nTask: {}\nParent step: {}",
392            task,
393            parent_step_id.as_str()
394        );
395        let task_tokens = self.token_counter.count(&task_content);
396        let task_segment = ContextSegment::system(task_content, task_tokens);
397        total_tokens += task_tokens;
398        segments.push(task_segment);
399
400        // Include relevant parent context
401        for segment in parent_context {
402            if segment.priority >= ContextPriority::Medium {
403                let tokens = segment.token_count;
404                if total_tokens + tokens <= self.config.max_tokens {
405                    total_tokens += tokens;
406                    segments.push(segment.clone());
407                }
408            }
409        }
410
411        StepContextResult {
412            execution_id: parent_execution_id,
413            step_id: child_step_id,
414            segments,
415            learnings: Vec::new(),
416            total_tokens,
417            processed_at: Utc::now(),
418        }
419    }
420}
421
422impl Default for StepContextBuilder {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
428/// Information about a tool call
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct ToolCallInfo {
431    /// Tool name
432    pub tool_name: String,
433
434    /// Arguments as JSON string
435    pub arguments: String,
436
437    /// Result if completed
438    pub result: Option<String>,
439
440    /// Whether the call succeeded
441    pub success: bool,
442
443    /// Duration in milliseconds
444    pub duration_ms: Option<u64>,
445}
446
447impl ToolCallInfo {
448    /// Create a successful tool call
449    pub fn success(
450        tool_name: impl Into<String>,
451        arguments: impl Into<String>,
452        result: impl Into<String>,
453    ) -> Self {
454        Self {
455            tool_name: tool_name.into(),
456            arguments: arguments.into(),
457            result: Some(result.into()),
458            success: true,
459            duration_ms: None,
460        }
461    }
462
463    /// Create a failed tool call
464    pub fn failed(
465        tool_name: impl Into<String>,
466        arguments: impl Into<String>,
467        error: impl Into<String>,
468    ) -> Self {
469        Self {
470            tool_name: tool_name.into(),
471            arguments: arguments.into(),
472            result: Some(error.into()),
473            success: false,
474            duration_ms: None,
475        }
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    fn test_step_id() -> StepId {
        StepId::new()
    }

    #[test]
    fn test_step_context_config_defaults() {
        // Pin every default, not just a subset.
        let config = StepContextConfig::default();
        assert_eq!(config.max_tokens, 2000);
        assert!(config.include_tool_results);
        assert!(config.include_reasoning);
        assert!(config.include_errors);
        assert_eq!(config.max_tool_results, 5);
        assert!(config.truncate_long_content);
        assert_eq!(config.max_content_length, 1000);
    }

    #[test]
    fn test_build_context_basic() {
        let builder = StepContextBuilder::new();
        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "llm_call",
            "What is 2+2?",
            Some("4"),
            &[],
            None,
            &HashMap::new(),
        );

        assert!(!result.segments.is_empty());
        assert!(result.total_tokens > 0);
    }

    #[test]
    fn test_build_context_with_error() {
        let builder = StepContextBuilder::new();
        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "tool_call",
            "fetch data",
            None,
            &[],
            Some("Connection timeout"),
            &HashMap::new(),
        );

        assert!(!result.learnings.is_empty());
        assert_eq!(
            result.learnings[0].learning_type,
            LearningType::ErrorRecovery
        );
    }

    #[test]
    fn test_build_context_with_tool_calls() {
        let builder = StepContextBuilder::new();
        let tool_calls = vec![
            ToolCallInfo::success("search", r#"{"query": "test"}"#, "Found 5 results"),
            ToolCallInfo::failed("fetch", r#"{"url": "..."}"#, "404 Not Found"),
        ];

        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "multi_tool",
            "search and fetch",
            Some("partial results"),
            &tool_calls,
            None,
            &HashMap::new(),
        );

        // Should have tool result segments
        assert!(result.segments.len() >= 2);
        // Should have tool insight learning from successful call
        assert!(result
            .learnings
            .iter()
            .any(|l| l.learning_type == LearningType::ToolInsight));
    }

    #[test]
    fn test_total_tokens_matches_kept_segments() {
        let builder = StepContextBuilder::new();
        let tool_calls = vec![
            ToolCallInfo::success("search", r#"{"query": "test"}"#, "Found 5 results"),
            ToolCallInfo::failed("fetch", r#"{"url": "..."}"#, "404 Not Found"),
        ];

        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "multi_tool",
            "search and fetch",
            None,
            &tool_calls,
            Some("Connection reset"),
            &HashMap::new(),
        );

        // total_tokens must be exactly the sum of the segments actually kept,
        // and must never exceed the configured budget.
        let sum: usize = result.segments.iter().map(|s| s.token_count).sum();
        assert_eq!(result.total_tokens, sum);
        assert!(result.total_tokens <= StepContextConfig::default().max_tokens);
    }

    #[test]
    fn test_zero_token_budget_yields_no_segments() {
        let config = StepContextConfig {
            max_tokens: 0,
            ..Default::default()
        };
        let builder = StepContextBuilder::with_config(config);

        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "llm_call",
            "What is 2+2?",
            Some("4"),
            &[],
            None,
            &HashMap::new(),
        );

        assert!(result.segments.is_empty());
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn test_truncate_long_content() {
        let config = StepContextConfig {
            max_content_length: 50,
            ..Default::default()
        };
        let builder = StepContextBuilder::with_config(config);

        let long_content = "a".repeat(100);
        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "test",
            &long_content,
            None,
            &[],
            None,
            &HashMap::new(),
        );

        // Content should be truncated
        assert!(result.segments[0].content.contains("truncated"));
    }

    #[test]
    fn test_truncation_disabled_keeps_full_content() {
        let config = StepContextConfig {
            truncate_long_content: false,
            max_content_length: 10,
            ..Default::default()
        };
        let builder = StepContextBuilder::with_config(config);

        let long_content = "b".repeat(100);
        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "test",
            &long_content,
            None,
            &[],
            None,
            &HashMap::new(),
        );

        assert!(result.segments[0].content.contains(&long_content));
        assert!(!result.segments[0].content.contains("truncated"));
    }

    #[test]
    fn test_build_child_context() {
        let builder = StepContextBuilder::new();
        let token_counter = TokenCounter::default();

        let system_content = "Parent system context";
        let system_tokens = token_counter.count(system_content);

        let history_content = "Some history";
        let history_tokens = token_counter.count(history_content);

        let parent_context = vec![
            ContextSegment::system(system_content, system_tokens),
            ContextSegment::new(
                crate::segment::ContextSegmentType::History,
                history_content.to_string(),
                history_tokens,
                1,
            )
            .with_priority(ContextPriority::Low),
        ];

        let result = builder.build_child_context(
            test_execution_id(),
            test_step_id(),
            StepId::new(),
            "Analyze the data",
            &parent_context,
        );

        // Should have task segment
        assert!(result
            .segments
            .iter()
            .any(|s| s.content.contains("Sub-task")));
        // The system-level parent segment should be carried over into the child.
        assert!(result
            .segments
            .iter()
            .any(|s| s.content.contains("Parent system")));
    }

    #[test]
    fn test_learning_types() {
        let builder = StepContextBuilder::new();
        let mut metadata = HashMap::new();
        metadata.insert(
            "success_pattern".to_string(),
            "retry with backoff".to_string(),
        );

        let result = builder.build_context(
            test_execution_id(),
            test_step_id(),
            "api_call",
            "fetch user",
            Some("user data"),
            &[ToolCallInfo::success("http", "{}", "200 OK")],
            None,
            &metadata,
        );

        // Should have both tool insight and success pattern learnings
        assert!(result
            .learnings
            .iter()
            .any(|l| l.learning_type == LearningType::ToolInsight));
        assert!(result
            .learnings
            .iter()
            .any(|l| l.learning_type == LearningType::SuccessPattern));
    }
}