runkon_flow/
extensions.rs1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// A type-keyed bag of values.
///
/// Holds at most one value per concrete type `T` (the key is `TypeId::of::<T>()`).
/// Every value sits behind an [`Arc`], so cloning the bag copies only the map of
/// handles. Clones share the stored values rather than duplicating them.
#[derive(Clone, Default)]
pub struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Sync + Send>>,
}
14
15impl std::fmt::Debug for Extensions {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.debug_struct("Extensions")
18 .field("len", &self.map.len())
19 .finish()
20 }
21}
22
23impl Extensions {
24 pub fn insert<T: Any + Send + Sync + 'static>(&mut self, value: T) {
26 self.map.insert(TypeId::of::<T>(), Arc::new(value));
27 }
28
29 pub fn get<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
31 self.map
32 .get(&TypeId::of::<T>())
33 .and_then(|arc| arc.clone().downcast::<T>().ok())
34 }
35}
36
/// Parameters for a Claude action.
///
/// Can be attached to an [`Extensions`] map and later retrieved with
/// `Extensions::get::<ClaudeActionParams>()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClaudeActionParams {
    /// Optional limit on the number of turns; `None` when not specified.
    pub max_turns: Option<u32>,
}
43
/// Aggregate metrics for an LLM run.
///
/// Every field is optional: `None` means the figure is not available.
/// `Default` yields all-`None` metrics, which allows partial construction
/// with `LlmRunMetrics { model: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmRunMetrics {
    /// Total input tokens.
    pub total_input_tokens: Option<i64>,
    /// Total output tokens.
    pub total_output_tokens: Option<i64>,
    /// Total cache-read input tokens.
    pub total_cache_read_input_tokens: Option<i64>,
    /// Total cache-creation input tokens.
    pub total_cache_creation_input_tokens: Option<i64>,
    /// Total number of turns.
    pub total_turns: Option<i64>,
    /// Total cost in US dollars.
    pub total_cost_usd: Option<f64>,
    /// Model identifier, e.g. `"claude-opus-4"`.
    pub model: Option<String>,
}
60
// Unit tests for `Extensions` and the typed payloads stored in it.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get_returns_value() {
        let mut ext = Extensions::default();
        ext.insert(42u32);
        let v = ext.get::<u32>().expect("should find u32");
        assert_eq!(*v, 42u32);
    }

    #[test]
    fn get_missing_type_returns_none() {
        let ext = Extensions::default();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn insert_replaces_previous_value() {
        let mut ext = Extensions::default();
        ext.insert(1u32);
        ext.insert(2u32);
        let v = ext.get::<u32>().expect("should find u32");
        assert_eq!(*v, 2u32);
    }

    #[test]
    fn different_types_are_stored_independently() {
        let mut ext = Extensions::default();
        ext.insert(10u32);
        ext.insert("hello");
        assert_eq!(*ext.get::<u32>().unwrap(), 10u32);
        assert_eq!(*ext.get::<&str>().unwrap(), "hello");
    }

    #[test]
    fn clone_shares_arc_not_data() {
        let mut ext = Extensions::default();
        ext.insert(String::from("shared"));
        let cloned = ext.clone();
        let a = ext.get::<String>().unwrap();
        let b = cloned.get::<String>().unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    // A clone shares existing values but owns its own map. Inserts after
    // cloning must not leak back into the original.
    #[test]
    fn clone_map_is_independent_for_later_inserts() {
        let mut ext = Extensions::default();
        ext.insert(1u32);
        let mut cloned = ext.clone();
        cloned.insert(2u32);
        cloned.insert(true);
        assert_eq!(*ext.get::<u32>().unwrap(), 1u32);
        assert!(ext.get::<bool>().is_none());
        assert_eq!(*cloned.get::<u32>().unwrap(), 2u32);
        assert!(*cloned.get::<bool>().unwrap());
    }

    // The Debug impl must report only the entry count, never the values.
    #[test]
    fn debug_reports_only_entry_count() {
        let mut ext = Extensions::default();
        assert_eq!(format!("{:?}", ext), "Extensions { len: 0 }");
        ext.insert(1u8);
        ext.insert(String::from("secret"));
        let rendered = format!("{:?}", ext);
        assert_eq!(rendered, "Extensions { len: 2 }");
        assert!(!rendered.contains("secret"));
    }

    #[test]
    fn claude_action_params_round_trips() {
        let mut ext = Extensions::default();
        ext.insert(ClaudeActionParams {
            max_turns: Some(50),
        });
        let v = ext
            .get::<ClaudeActionParams>()
            .expect("should find ClaudeActionParams");
        assert_eq!(v.max_turns, Some(50));
    }

    #[test]
    fn llm_run_metrics_round_trips() {
        let mut ext = Extensions::default();
        ext.insert(LlmRunMetrics {
            total_input_tokens: Some(100),
            total_output_tokens: Some(200),
            total_cache_read_input_tokens: Some(50),
            total_cache_creation_input_tokens: Some(25),
            total_turns: Some(3),
            total_cost_usd: Some(0.05),
            model: Some("claude-opus-4".to_string()),
        });
        let v = ext
            .get::<LlmRunMetrics>()
            .expect("should find LlmRunMetrics");
        assert_eq!(v.total_input_tokens, Some(100));
        assert_eq!(v.total_output_tokens, Some(200));
        assert_eq!(v.total_cache_read_input_tokens, Some(50));
        assert_eq!(v.total_cache_creation_input_tokens, Some(25));
        assert_eq!(v.total_turns, Some(3));
        // Exact comparison is safe here: the same literal is stored and read back.
        assert_eq!(v.total_cost_usd, Some(0.05));
        assert_eq!(v.model.as_deref(), Some("claude-opus-4"));
    }
}