1use std::any::Any;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::action::ActionError;
6use crate::token::ErasedToken;
7
/// Shared, thread-safe logging callback.
///
/// The callback receives two string slices. `TransitionContext::log` passes
/// the log level as the first argument and the message as the second.
pub type LogFn = Arc<dyn Fn(&str, &str) + Send + Sync>;
10
11#[derive(Debug, Clone)]
13pub struct OutputEntry {
14 pub place_name: Arc<str>,
15 pub token: ErasedToken,
16}
17
18pub struct TransitionContext {
28 transition_name: Arc<str>,
29 inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
30 reads: HashMap<Arc<str>, Vec<ErasedToken>>,
31 allowed_outputs: HashSet<Arc<str>>,
32 outputs: Vec<OutputEntry>,
33 execution_ctx: HashMap<String, Box<dyn Any + Send + Sync>>,
34 log_fn: Option<LogFn>,
35}
36
37impl TransitionContext {
38 pub fn new(
39 transition_name: Arc<str>,
40 inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
41 reads: HashMap<Arc<str>, Vec<ErasedToken>>,
42 allowed_outputs: HashSet<Arc<str>>,
43 log_fn: Option<LogFn>,
44 ) -> Self {
45 Self {
46 transition_name,
47 inputs,
48 reads,
49 allowed_outputs,
50 outputs: Vec::new(),
51 execution_ctx: HashMap::new(),
52 log_fn,
53 }
54 }
55
56 pub fn input<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
60 let tokens = self.inputs.get(place_name).ok_or_else(|| {
61 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
62 })?;
63 if tokens.len() != 1 {
64 return Err(ActionError::new(format!(
65 "Place '{place_name}' consumed {} tokens, use inputs() for batched access",
66 tokens.len()
67 )));
68 }
69 self.downcast_value::<T>(&tokens[0], place_name)
70 }
71
72 pub fn inputs<T: Send + Sync + 'static>(
74 &self,
75 place_name: &str,
76 ) -> Result<Vec<Arc<T>>, ActionError> {
77 let tokens = self.inputs.get(place_name).ok_or_else(|| {
78 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
79 })?;
80 tokens
81 .iter()
82 .map(|t| self.downcast_value::<T>(t, place_name))
83 .collect()
84 }
85
86 pub fn input_raw(&self, place_name: &str) -> Result<Arc<dyn Any + Send + Sync>, ActionError> {
88 let tokens = self.inputs.get(place_name).ok_or_else(|| {
89 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
90 })?;
91 if tokens.is_empty() {
92 return Err(ActionError::new(format!(
93 "No tokens for place '{place_name}'"
94 )));
95 }
96 Ok(Arc::clone(&tokens[0].value))
97 }
98
99 pub fn input_place_names(&self) -> Vec<Arc<str>> {
101 self.inputs.keys().cloned().collect()
102 }
103
104 pub fn read<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
108 let tokens = self.reads.get(place_name).ok_or_else(|| {
109 ActionError::new(format!("Place '{place_name}' not in declared reads"))
110 })?;
111 if tokens.is_empty() {
112 return Err(ActionError::new(format!(
113 "No tokens for read place '{place_name}'"
114 )));
115 }
116 self.downcast_value::<T>(&tokens[0], place_name)
117 }
118
119 pub fn reads<T: Send + Sync + 'static>(
121 &self,
122 place_name: &str,
123 ) -> Result<Vec<Arc<T>>, ActionError> {
124 let tokens = self.reads.get(place_name).ok_or_else(|| {
125 ActionError::new(format!("Place '{place_name}' not in declared reads"))
126 })?;
127 tokens
128 .iter()
129 .map(|t| self.downcast_value::<T>(t, place_name))
130 .collect()
131 }
132
133 pub fn read_place_names(&self) -> Vec<Arc<str>> {
135 self.reads.keys().cloned().collect()
136 }
137
138 pub fn output<T: Send + Sync + 'static>(
142 &mut self,
143 place_name: &str,
144 value: T,
145 ) -> Result<(), ActionError> {
146 let name = self.require_output(place_name)?;
147 self.outputs.push(OutputEntry {
148 place_name: name,
149 token: ErasedToken {
150 value: Arc::new(value),
151 created_at: crate::token::now_millis(),
152 },
153 });
154 Ok(())
155 }
156
157 pub fn output_raw(
159 &mut self,
160 place_name: &str,
161 value: Arc<dyn Any + Send + Sync>,
162 ) -> Result<(), ActionError> {
163 let name = self.require_output(place_name)?;
164 self.outputs.push(OutputEntry {
165 place_name: name,
166 token: ErasedToken {
167 value,
168 created_at: crate::token::now_millis(),
169 },
170 });
171 Ok(())
172 }
173
174 fn require_output(&self, place_name: &str) -> Result<Arc<str>, ActionError> {
175 self.allowed_outputs
176 .get(place_name)
177 .cloned()
178 .ok_or_else(|| {
179 ActionError::new(format!(
180 "Place '{}' not in declared outputs: {:?}",
181 place_name,
182 self.allowed_outputs.iter().collect::<Vec<_>>()
183 ))
184 })
185 }
186
187 pub fn output_place_names(&self) -> Vec<Arc<str>> {
189 self.allowed_outputs.iter().cloned().collect()
190 }
191
192 pub fn transition_name(&self) -> &str {
196 &self.transition_name
197 }
198
199 pub fn set_execution_context<T: Send + Sync + 'static>(&mut self, key: &str, value: T) {
203 self.execution_ctx.insert(key.to_string(), Box::new(value));
204 }
205
206 pub fn execution_context<T: 'static>(&self, key: &str) -> Option<&T> {
208 self.execution_ctx
209 .get(key)
210 .and_then(|v| v.downcast_ref::<T>())
211 }
212
213 pub fn has_execution_context(&self, key: &str) -> bool {
215 self.execution_ctx.contains_key(key)
216 }
217
218 pub fn log(&self, level: &str, message: &str) {
222 if let Some(ref log_fn) = self.log_fn {
223 log_fn(level, message);
224 }
225 }
226
227 pub fn take_outputs(&mut self) -> Vec<OutputEntry> {
231 std::mem::take(&mut self.outputs)
232 }
233
234 pub fn take_inputs(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
236 std::mem::take(&mut self.inputs)
237 }
238
239 pub fn take_reads(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
241 std::mem::take(&mut self.reads)
242 }
243
244 pub fn outputs(&self) -> &[OutputEntry] {
246 &self.outputs
247 }
248
249 fn downcast_value<T: Send + Sync + 'static>(
250 &self,
251 token: &ErasedToken,
252 place_name: &str,
253 ) -> Result<Arc<T>, ActionError> {
254 let any_arc = Arc::clone(&token.value);
256 if any_arc.downcast_ref::<T>().is_none() {
258 return Err(ActionError::new(format!(
259 "Type mismatch for place '{place_name}': expected {}",
260 std::any::type_name::<T>()
261 )));
262 }
263 let raw = Arc::into_raw(any_arc);
265 let typed = unsafe { Arc::from_raw(raw.cast::<T>()) };
266 Ok(typed)
267 }
268}
269
270impl std::fmt::Debug for TransitionContext {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.debug_struct("TransitionContext")
273 .field("transition_name", &self.transition_name)
274 .field("input_count", &self.inputs.len())
275 .field("read_count", &self.reads.len())
276 .field("output_count", &self.outputs.len())
277 .finish()
278 }
279}