Skip to main content

runkon_flow/
extensions.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// A type-keyed, type-erased bag of values used to carry executor-specific data
/// through `ActionParams` without widening the generic API surface.
///
/// Each entry is keyed by the `TypeId` of the stored value, so at most one value per
/// concrete type is held. Values sit behind `Arc<dyn Any + Send + Sync>`. Cloning the
/// map therefore copies only the keys and `Arc` handles (a reference-count bump per
/// entry), never the values themselves. The `Send + Sync` bound is what lets the map
/// cross thread boundaries, as `parallel.rs`'s per-thread copies require.
#[derive(Clone, Default)]
pub struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
14
15impl std::fmt::Debug for Extensions {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("Extensions")
18            .field("len", &self.map.len())
19            .finish()
20    }
21}
22
23impl Extensions {
24    /// Insert a value of type `T`, replacing any previously inserted value of the same type.
25    pub fn insert<T: Any + Send + Sync + 'static>(&mut self, value: T) {
26        self.map.insert(TypeId::of::<T>(), Arc::new(value));
27    }
28
29    /// Retrieve a cloned `Arc<T>` for a value of type `T`, if one was inserted.
30    pub fn get<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
31        self.map
32            .get(&TypeId::of::<T>())
33            .and_then(|arc| arc.clone().downcast::<T>().ok())
34    }
35}
36
/// Claude-specific per-step parameters, passed through `ActionParams.extensions`.
///
/// Same convention as `OutputSchema`: executor-specific types live in extensions,
/// not on the shared `ActionParams` surface.
///
/// Derives the common traits so the value can be debug-printed, compared, cloned out
/// of an `Arc`, and default-constructed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClaudeActionParams {
    /// Optional per-step cap on turns; `None` means no cap was supplied for this step.
    pub max_turns: Option<u32>,
}
43
/// LLM-runtime rollup metrics, stored in the `Extensions` map on both
/// `WorkflowRun` and `WorkflowResult`. Only present when at least one
/// LLM-backed step ran and reported metrics via `metadata_keys`.
///
/// `model` is included here rather than on the harness-neutral `WorkflowRun`
/// because every LLM has a model identifier while non-LLM executors do not.
/// (Issue #2987 may revisit if a non-LLM "model" concept emerges.)
///
/// Every field is an `Option`, so an unreported metric (`None`) stays distinct from
/// a reported zero. `Default` yields all-`None`, which is convenient for building
/// partial values with struct-update syntax.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmRunMetrics {
    /// Rolled-up input-token count, if reported.
    pub total_input_tokens: Option<i64>,
    /// Rolled-up output-token count, if reported.
    pub total_output_tokens: Option<i64>,
    /// Rolled-up cache-read input-token count, if reported.
    pub total_cache_read_input_tokens: Option<i64>,
    /// Rolled-up cache-creation input-token count, if reported.
    pub total_cache_creation_input_tokens: Option<i64>,
    /// Rolled-up turn count, if reported.
    pub total_turns: Option<i64>,
    /// Rolled-up cost in US dollars, if reported.
    pub total_cost_usd: Option<f64>,
    /// Model identifier, if reported.
    pub model: Option<String>,
}
60
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get_returns_value() {
        let mut ext = Extensions::default();
        ext.insert(42u32);
        let v = ext.get::<u32>().expect("should find u32");
        assert_eq!(*v, 42u32);
    }

    #[test]
    fn get_missing_type_returns_none() {
        let ext = Extensions::default();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn insert_replaces_previous_value() {
        let mut ext = Extensions::default();
        ext.insert(1u32);
        ext.insert(2u32);
        let v = ext.get::<u32>().expect("should find u32");
        assert_eq!(*v, 2u32);
    }

    #[test]
    fn different_types_are_stored_independently() {
        let mut ext = Extensions::default();
        ext.insert(10u32);
        ext.insert("hello");
        assert_eq!(*ext.get::<u32>().unwrap(), 10u32);
        assert_eq!(*ext.get::<&str>().unwrap(), "hello");
    }

    #[test]
    fn clone_shares_arc_not_data() {
        let mut ext = Extensions::default();
        ext.insert(String::from("shared"));
        let cloned = ext.clone();
        let a = ext.get::<String>().unwrap();
        let b = cloned.get::<String>().unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn clone_has_its_own_map() {
        // Values are shared via `Arc`, but the map itself is copied: inserts into the
        // clone must not become visible through the original.
        let mut ext = Extensions::default();
        ext.insert(1u32);
        let mut cloned = ext.clone();
        cloned.insert(2u32);
        cloned.insert(String::from("only in clone"));
        assert_eq!(*ext.get::<u32>().unwrap(), 1u32);
        assert!(ext.get::<String>().is_none());
        assert_eq!(*cloned.get::<u32>().unwrap(), 2u32);
    }

    #[test]
    fn debug_reports_entry_count_only() {
        let mut ext = Extensions::default();
        assert_eq!(format!("{:?}", ext), "Extensions { len: 0 }");
        ext.insert(1u32);
        ext.insert("two");
        assert_eq!(format!("{:?}", ext), "Extensions { len: 2 }");
    }

    #[test]
    fn extensions_is_send_and_sync() {
        // `parallel.rs`'s per-thread copies require the map to cross thread
        // boundaries; pin that as a compile-time guarantee.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Extensions>();
    }

    #[test]
    fn claude_action_params_round_trips() {
        let mut ext = Extensions::default();
        ext.insert(ClaudeActionParams {
            max_turns: Some(50),
        });
        let v = ext
            .get::<ClaudeActionParams>()
            .expect("should find ClaudeActionParams");
        assert_eq!(v.max_turns, Some(50));
    }

    #[test]
    fn llm_run_metrics_round_trips() {
        let mut ext = Extensions::default();
        ext.insert(LlmRunMetrics {
            total_input_tokens: Some(100),
            total_output_tokens: Some(200),
            total_cache_read_input_tokens: Some(50),
            total_cache_creation_input_tokens: Some(25),
            total_turns: Some(3),
            total_cost_usd: Some(0.05),
            model: Some("claude-opus-4".to_string()),
        });
        let v = ext
            .get::<LlmRunMetrics>()
            .expect("should find LlmRunMetrics");
        assert_eq!(v.total_input_tokens, Some(100));
        assert_eq!(v.total_output_tokens, Some(200));
        assert_eq!(v.total_cache_read_input_tokens, Some(50));
        assert_eq!(v.total_cache_creation_input_tokens, Some(25));
        assert_eq!(v.total_turns, Some(3));
        assert_eq!(v.total_cost_usd, Some(0.05));
        assert_eq!(v.model.as_deref(), Some("claude-opus-4"));
    }
}