1use std::any::Any;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::action::ActionError;
6use crate::token::ErasedToken;
7
/// Shared, thread-safe logging callback.
///
/// [`TransitionContext::log`] invokes it as `(level, message)`.
pub type LogFn = Arc<dyn Fn(&str, &str) + Send + Sync>;
10
11#[derive(Debug, Clone)]
13pub struct OutputEntry {
14 pub place_name: Arc<str>,
15 pub token: ErasedToken,
16}
17
18pub struct TransitionContext {
28 transition_name: Arc<str>,
29 inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
30 reads: HashMap<Arc<str>, Vec<ErasedToken>>,
31 allowed_outputs: HashSet<Arc<str>>,
32 outputs: Vec<OutputEntry>,
33 execution_ctx: HashMap<String, Box<dyn Any + Send + Sync>>,
34 log_fn: Option<LogFn>,
35}
36
37impl TransitionContext {
38 pub fn new(
39 transition_name: Arc<str>,
40 inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
41 reads: HashMap<Arc<str>, Vec<ErasedToken>>,
42 allowed_outputs: HashSet<Arc<str>>,
43 log_fn: Option<LogFn>,
44 ) -> Self {
45 Self {
46 transition_name,
47 inputs,
48 reads,
49 allowed_outputs,
50 outputs: Vec::new(),
51 execution_ctx: HashMap::new(),
52 log_fn,
53 }
54 }
55
56 pub fn input<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
60 let tokens = self.inputs.get(place_name).ok_or_else(|| {
61 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
62 })?;
63 if tokens.len() != 1 {
64 return Err(ActionError::new(format!(
65 "Place '{place_name}' consumed {} tokens, use inputs() for batched access",
66 tokens.len()
67 )));
68 }
69 self.downcast_value::<T>(&tokens[0], place_name)
70 }
71
72 pub fn inputs<T: Send + Sync + 'static>(
74 &self,
75 place_name: &str,
76 ) -> Result<Vec<Arc<T>>, ActionError> {
77 let tokens = self.inputs.get(place_name).ok_or_else(|| {
78 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
79 })?;
80 tokens
81 .iter()
82 .map(|t| self.downcast_value::<T>(t, place_name))
83 .collect()
84 }
85
86 pub fn input_raw(&self, place_name: &str) -> Result<Arc<dyn Any + Send + Sync>, ActionError> {
88 let tokens = self.inputs.get(place_name).ok_or_else(|| {
89 ActionError::new(format!("Place '{place_name}' not in declared inputs"))
90 })?;
91 if tokens.is_empty() {
92 return Err(ActionError::new(format!(
93 "No tokens for place '{place_name}'"
94 )));
95 }
96 Ok(Arc::clone(&tokens[0].value))
97 }
98
99 pub fn input_place_names(&self) -> Vec<Arc<str>> {
101 self.inputs.keys().cloned().collect()
102 }
103
104 pub fn read<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
108 let tokens = self.reads.get(place_name).ok_or_else(|| {
109 ActionError::new(format!("Place '{place_name}' not in declared reads"))
110 })?;
111 if tokens.is_empty() {
112 return Err(ActionError::new(format!(
113 "No tokens for read place '{place_name}'"
114 )));
115 }
116 self.downcast_value::<T>(&tokens[0], place_name)
117 }
118
119 pub fn reads<T: Send + Sync + 'static>(
121 &self,
122 place_name: &str,
123 ) -> Result<Vec<Arc<T>>, ActionError> {
124 let tokens = self.reads.get(place_name).ok_or_else(|| {
125 ActionError::new(format!("Place '{place_name}' not in declared reads"))
126 })?;
127 tokens
128 .iter()
129 .map(|t| self.downcast_value::<T>(t, place_name))
130 .collect()
131 }
132
133 pub fn read_place_names(&self) -> Vec<Arc<str>> {
135 self.reads.keys().cloned().collect()
136 }
137
138 pub fn output<T: Send + Sync + 'static>(
142 &mut self,
143 place_name: &str,
144 value: T,
145 ) -> Result<(), ActionError> {
146 let name = self.require_output(place_name)?;
147 self.outputs.push(OutputEntry {
148 place_name: name,
149 token: ErasedToken {
150 value: Arc::new(value),
151 created_at: crate::token::now_millis(),
152 },
153 });
154 Ok(())
155 }
156
157 pub fn output_raw(
159 &mut self,
160 place_name: &str,
161 value: Arc<dyn Any + Send + Sync>,
162 ) -> Result<(), ActionError> {
163 let name = self.require_output(place_name)?;
164 self.outputs.push(OutputEntry {
165 place_name: name,
166 token: ErasedToken {
167 value,
168 created_at: crate::token::now_millis(),
169 },
170 });
171 Ok(())
172 }
173
174 pub fn output_many<T: Send + Sync + 'static>(
194 &mut self,
195 place_name: &str,
196 values: impl IntoIterator<Item = T>,
197 ) -> Result<(), ActionError> {
198 let name = self.require_output(place_name)?;
199 let iter = values.into_iter();
200 let (lower, _) = iter.size_hint();
201 self.outputs.reserve(lower);
202 let created_at = crate::token::now_millis();
203 for value in iter {
204 self.outputs.push(OutputEntry {
205 place_name: Arc::clone(&name),
206 token: ErasedToken {
207 value: Arc::new(value),
208 created_at,
209 },
210 });
211 }
212 Ok(())
213 }
214
215 fn require_output(&self, place_name: &str) -> Result<Arc<str>, ActionError> {
216 self.allowed_outputs
217 .get(place_name)
218 .cloned()
219 .ok_or_else(|| {
220 ActionError::new(format!(
221 "Place '{}' not in declared outputs: {:?}",
222 place_name,
223 self.allowed_outputs.iter().collect::<Vec<_>>()
224 ))
225 })
226 }
227
228 pub fn output_place_names(&self) -> Vec<Arc<str>> {
230 self.allowed_outputs.iter().cloned().collect()
231 }
232
233 pub fn transition_name(&self) -> &str {
237 &self.transition_name
238 }
239
240 pub fn set_execution_context<T: Send + Sync + 'static>(&mut self, key: &str, value: T) {
244 self.execution_ctx.insert(key.to_string(), Box::new(value));
245 }
246
247 pub fn execution_context<T: 'static>(&self, key: &str) -> Option<&T> {
249 self.execution_ctx
250 .get(key)
251 .and_then(|v| v.downcast_ref::<T>())
252 }
253
254 pub fn has_execution_context(&self, key: &str) -> bool {
256 self.execution_ctx.contains_key(key)
257 }
258
259 pub fn log(&self, level: &str, message: &str) {
263 if let Some(ref log_fn) = self.log_fn {
264 log_fn(level, message);
265 }
266 }
267
268 pub fn take_outputs(&mut self) -> Vec<OutputEntry> {
272 std::mem::take(&mut self.outputs)
273 }
274
275 pub fn take_inputs(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
277 std::mem::take(&mut self.inputs)
278 }
279
280 pub fn take_reads(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
282 std::mem::take(&mut self.reads)
283 }
284
285 pub fn outputs(&self) -> &[OutputEntry] {
287 &self.outputs
288 }
289
290 fn downcast_value<T: Send + Sync + 'static>(
291 &self,
292 token: &ErasedToken,
293 place_name: &str,
294 ) -> Result<Arc<T>, ActionError> {
295 let any_arc = Arc::clone(&token.value);
297 if any_arc.downcast_ref::<T>().is_none() {
299 return Err(ActionError::new(format!(
300 "Type mismatch for place '{place_name}': expected {}",
301 std::any::type_name::<T>()
302 )));
303 }
304 let raw = Arc::into_raw(any_arc);
306 let typed = unsafe { Arc::from_raw(raw.cast::<T>()) };
307 Ok(typed)
308 }
309}
310
311impl std::fmt::Debug for TransitionContext {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.debug_struct("TransitionContext")
314 .field("transition_name", &self.transition_name)
315 .field("input_count", &self.inputs.len())
316 .field("read_count", &self.reads.len())
317 .field("output_count", &self.outputs.len())
318 .finish()
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for transition "T" with no inputs, no reads, no
    /// logger, and `place_name` as its only declared output place.
    fn ctx_with_output(place_name: &str) -> TransitionContext {
        let allowed: HashSet<Arc<str>> = std::iter::once(Arc::<str>::from(place_name)).collect();
        TransitionContext::new(
            Arc::from("T"),
            HashMap::new(),
            HashMap::new(),
            allowed,
            None,
        )
    }

    /// Clones every emitted value out of `ctx` as a `T`, in emission order.
    /// Panics if any emitted value is not a `T`.
    fn downcast_values<T: Send + Sync + 'static + Clone>(ctx: &TransitionContext) -> Vec<T> {
        let mut values = Vec::with_capacity(ctx.outputs().len());
        for entry in ctx.outputs() {
            let typed = entry.token.value.downcast_ref::<T>().unwrap();
            values.push(typed.clone());
        }
        values
    }

    #[test]
    fn output_many_from_array_appends_in_order() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", [1, 2, 3]).unwrap();

        let emitted = downcast_values::<i32>(&ctx);
        assert_eq!(emitted, vec![1, 2, 3]);
        for entry in ctx.outputs() {
            assert!(&*entry.place_name == "out");
        }
    }

    #[test]
    fn output_many_from_vec() {
        let mut ctx = ctx_with_output("out");
        let values = vec!["a".to_string(), "b".to_string()];
        ctx.output_many("out", values).unwrap();

        let expected = vec!["a".to_string(), "b".to_string()];
        assert_eq!(downcast_values::<String>(&ctx), expected);
    }

    #[test]
    fn output_many_from_range_iterator() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", 0..5i32).unwrap();

        let emitted = downcast_values::<i32>(&ctx);
        assert_eq!(emitted, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn output_many_empty_is_ok_and_no_op() {
        let mut ctx = ctx_with_output("out");
        let nothing: [i32; 0] = [];
        ctx.output_many("out", nothing).unwrap();

        assert!(ctx.outputs().is_empty());
    }

    #[test]
    fn output_many_undeclared_place_errors_before_appending() {
        let mut ctx = ctx_with_output("out");
        let err = ctx.output_many("nope", [1, 2, 3]).unwrap_err();

        let rendered = format!("{err:?}");
        assert!(rendered.contains("not in declared outputs"));
        assert!(ctx.outputs().is_empty());
    }

    #[test]
    fn output_many_shares_timestamp_across_tokens() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", [10i32, 20, 30]).unwrap();

        let stamps: Vec<_> = ctx
            .outputs()
            .iter()
            .map(|entry| entry.token.created_at)
            .collect();
        assert_eq!(stamps.len(), 3);
        // Every timestamp equals the first one, i.e. all are identical.
        assert!(stamps.iter().all(|stamp| *stamp == stamps[0]));
    }

    #[test]
    fn single_output_still_works_alongside_output_many() {
        let mut ctx = ctx_with_output("out");
        ctx.output("out", 42i32).unwrap();
        ctx.output_many("out", [43, 44]).unwrap();

        let emitted = downcast_values::<i32>(&ctx);
        assert_eq!(emitted, vec![42, 43, 44]);
    }
}