Skip to main content

libpetri_core/
context.rs

1use std::any::Any;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::action::ActionError;
6use crate::token::ErasedToken;
7
/// Shared, thread-safe logging callback handed to transition actions.
///
/// [`TransitionContext::log`] invokes it as `log_fn(level, message)`.
pub type LogFn = Arc<dyn Fn(&str, &str) + Sync + Send>;
10
11/// An output entry: place name + erased token.
12#[derive(Debug, Clone)]
13pub struct OutputEntry {
14    pub place_name: Arc<str>,
15    pub token: ErasedToken,
16}
17
18/// Context provided to transition actions.
19///
20/// Provides filtered access based on structure:
21/// - Input places (consumed tokens)
22/// - Read places (context tokens, not consumed)
23/// - Output places (where to produce tokens)
24///
25/// Enforces the structure contract — actions can only access places
26/// declared in the transition's structure.
27pub struct TransitionContext {
28    transition_name: Arc<str>,
29    inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
30    reads: HashMap<Arc<str>, Vec<ErasedToken>>,
31    allowed_outputs: HashSet<Arc<str>>,
32    outputs: Vec<OutputEntry>,
33    execution_ctx: HashMap<String, Box<dyn Any + Send + Sync>>,
34    log_fn: Option<LogFn>,
35}
36
37impl TransitionContext {
38    pub fn new(
39        transition_name: Arc<str>,
40        inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
41        reads: HashMap<Arc<str>, Vec<ErasedToken>>,
42        allowed_outputs: HashSet<Arc<str>>,
43        log_fn: Option<LogFn>,
44    ) -> Self {
45        Self {
46            transition_name,
47            inputs,
48            reads,
49            allowed_outputs,
50            outputs: Vec::new(),
51            execution_ctx: HashMap::new(),
52            log_fn,
53        }
54    }
55
56    // ==================== Input Access (consumed) ====================
57
58    /// Get single consumed input value. Returns error if place not declared or wrong type.
59    pub fn input<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
60        let tokens = self.inputs.get(place_name).ok_or_else(|| {
61            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
62        })?;
63        if tokens.len() != 1 {
64            return Err(ActionError::new(format!(
65                "Place '{place_name}' consumed {} tokens, use inputs() for batched access",
66                tokens.len()
67            )));
68        }
69        self.downcast_value::<T>(&tokens[0], place_name)
70    }
71
72    /// Get all consumed input values for a place.
73    pub fn inputs<T: Send + Sync + 'static>(
74        &self,
75        place_name: &str,
76    ) -> Result<Vec<Arc<T>>, ActionError> {
77        let tokens = self.inputs.get(place_name).ok_or_else(|| {
78            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
79        })?;
80        tokens
81            .iter()
82            .map(|t| self.downcast_value::<T>(t, place_name))
83            .collect()
84    }
85
86    /// Get the raw (type-erased) value of the first input token.
87    pub fn input_raw(&self, place_name: &str) -> Result<Arc<dyn Any + Send + Sync>, ActionError> {
88        let tokens = self.inputs.get(place_name).ok_or_else(|| {
89            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
90        })?;
91        if tokens.is_empty() {
92            return Err(ActionError::new(format!(
93                "No tokens for place '{place_name}'"
94            )));
95        }
96        Ok(Arc::clone(&tokens[0].value))
97    }
98
99    /// Returns the names of all declared input places.
100    pub fn input_place_names(&self) -> Vec<Arc<str>> {
101        self.inputs.keys().cloned().collect()
102    }
103
104    // ==================== Read Access (not consumed) ====================
105
106    /// Get read-only context value. Returns error if place not declared.
107    pub fn read<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
108        let tokens = self.reads.get(place_name).ok_or_else(|| {
109            ActionError::new(format!("Place '{place_name}' not in declared reads"))
110        })?;
111        if tokens.is_empty() {
112            return Err(ActionError::new(format!(
113                "No tokens for read place '{place_name}'"
114            )));
115        }
116        self.downcast_value::<T>(&tokens[0], place_name)
117    }
118
119    /// Get all read-only context values for a place.
120    pub fn reads<T: Send + Sync + 'static>(
121        &self,
122        place_name: &str,
123    ) -> Result<Vec<Arc<T>>, ActionError> {
124        let tokens = self.reads.get(place_name).ok_or_else(|| {
125            ActionError::new(format!("Place '{place_name}' not in declared reads"))
126        })?;
127        tokens
128            .iter()
129            .map(|t| self.downcast_value::<T>(t, place_name))
130            .collect()
131    }
132
133    /// Returns the names of all declared read places.
134    pub fn read_place_names(&self) -> Vec<Arc<str>> {
135        self.reads.keys().cloned().collect()
136    }
137
138    // ==================== Output Access ====================
139
140    /// Add output value. Returns error if place not declared as output.
141    pub fn output<T: Send + Sync + 'static>(
142        &mut self,
143        place_name: &str,
144        value: T,
145    ) -> Result<(), ActionError> {
146        let name = self.require_output(place_name)?;
147        self.outputs.push(OutputEntry {
148            place_name: name,
149            token: ErasedToken {
150                value: Arc::new(value),
151                created_at: crate::token::now_millis(),
152            },
153        });
154        Ok(())
155    }
156
157    /// Add a raw (type-erased) output value.
158    pub fn output_raw(
159        &mut self,
160        place_name: &str,
161        value: Arc<dyn Any + Send + Sync>,
162    ) -> Result<(), ActionError> {
163        let name = self.require_output(place_name)?;
164        self.outputs.push(OutputEntry {
165            place_name: name,
166            token: ErasedToken {
167                value,
168                created_at: crate::token::now_millis(),
169            },
170        });
171        Ok(())
172    }
173
174    /// Add multiple output values to the same place in a single call.
175    ///
176    /// Equivalent to calling [`output`](Self::output) once per element, but:
177    /// - Validates the declared-output set **once** before iterating, so
178    ///   an undeclared place returns `Err` before any element is appended.
179    /// - Pre-reserves capacity on the internal output collector from the
180    ///   iterator's `size_hint().0`.
181    /// - Shares a single `created_at` timestamp across all produced tokens,
182    ///   matching the "fired at time T" semantics of a single action firing.
183    ///
184    /// Accepts anything that implements [`IntoIterator`], including arrays,
185    /// `Vec`, slice iterators, and iterator adaptors.
186    ///
187    /// # Example
188    /// ```ignore
189    /// ctx.output_many("out", [1, 2, 3])?;
190    /// ctx.output_many("out", vec!["a", "b"])?;
191    /// ctx.output_many("out", (0..5))?;
192    /// ```
193    pub fn output_many<T: Send + Sync + 'static>(
194        &mut self,
195        place_name: &str,
196        values: impl IntoIterator<Item = T>,
197    ) -> Result<(), ActionError> {
198        let name = self.require_output(place_name)?;
199        let iter = values.into_iter();
200        let (lower, _) = iter.size_hint();
201        self.outputs.reserve(lower);
202        let created_at = crate::token::now_millis();
203        for value in iter {
204            self.outputs.push(OutputEntry {
205                place_name: Arc::clone(&name),
206                token: ErasedToken {
207                    value: Arc::new(value),
208                    created_at,
209                },
210            });
211        }
212        Ok(())
213    }
214
215    fn require_output(&self, place_name: &str) -> Result<Arc<str>, ActionError> {
216        self.allowed_outputs
217            .get(place_name)
218            .cloned()
219            .ok_or_else(|| {
220                ActionError::new(format!(
221                    "Place '{}' not in declared outputs: {:?}",
222                    place_name,
223                    self.allowed_outputs.iter().collect::<Vec<_>>()
224                ))
225            })
226    }
227
228    /// Returns the names of all declared output places.
229    pub fn output_place_names(&self) -> Vec<Arc<str>> {
230        self.allowed_outputs.iter().cloned().collect()
231    }
232
233    // ==================== Structure Info ====================
234
235    /// Returns the transition name.
236    pub fn transition_name(&self) -> &str {
237        &self.transition_name
238    }
239
240    // ==================== Execution Context ====================
241
242    /// Store an execution context value.
243    pub fn set_execution_context<T: Send + Sync + 'static>(&mut self, key: &str, value: T) {
244        self.execution_ctx.insert(key.to_string(), Box::new(value));
245    }
246
247    /// Retrieve an execution context value.
248    pub fn execution_context<T: 'static>(&self, key: &str) -> Option<&T> {
249        self.execution_ctx
250            .get(key)
251            .and_then(|v| v.downcast_ref::<T>())
252    }
253
254    /// Check if an execution context key exists.
255    pub fn has_execution_context(&self, key: &str) -> bool {
256        self.execution_ctx.contains_key(key)
257    }
258
259    // ==================== Logging ====================
260
261    /// Emits a structured log message.
262    pub fn log(&self, level: &str, message: &str) {
263        if let Some(ref log_fn) = self.log_fn {
264            log_fn(level, message);
265        }
266    }
267
268    // ==================== Internal ====================
269
270    /// Collects all output entries (used by executor).
271    pub fn take_outputs(&mut self) -> Vec<OutputEntry> {
272        std::mem::take(&mut self.outputs)
273    }
274
275    /// Reclaims the inputs HashMap for reuse (used by executor to avoid per-firing allocation).
276    pub fn take_inputs(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
277        std::mem::take(&mut self.inputs)
278    }
279
280    /// Reclaims the reads HashMap for reuse (used by executor to avoid per-firing allocation).
281    pub fn take_reads(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
282        std::mem::take(&mut self.reads)
283    }
284
285    /// Returns a reference to the output entries.
286    pub fn outputs(&self) -> &[OutputEntry] {
287        &self.outputs
288    }
289
290    fn downcast_value<T: Send + Sync + 'static>(
291        &self,
292        token: &ErasedToken,
293        place_name: &str,
294    ) -> Result<Arc<T>, ActionError> {
295        // Try to downcast the inner Arc
296        let any_arc = Arc::clone(&token.value);
297        // First check if the type matches
298        if any_arc.downcast_ref::<T>().is_none() {
299            return Err(ActionError::new(format!(
300                "Type mismatch for place '{place_name}': expected {}",
301                std::any::type_name::<T>()
302            )));
303        }
304        // Safety: we just verified the type
305        let raw = Arc::into_raw(any_arc);
306        let typed = unsafe { Arc::from_raw(raw.cast::<T>()) };
307        Ok(typed)
308    }
309}
310
311impl std::fmt::Debug for TransitionContext {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("TransitionContext")
314            .field("transition_name", &self.transition_name)
315            .field("input_count", &self.inputs.len())
316            .field("read_count", &self.reads.len())
317            .field("output_count", &self.outputs.len())
318            .finish()
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with a single declared output place and no inputs/reads.
    fn ctx_with_output(place_name: &str) -> TransitionContext {
        let mut allowed = HashSet::new();
        allowed.insert(Arc::<str>::from(place_name));
        TransitionContext::new(
            Arc::from("T"),
            HashMap::new(),
            HashMap::new(),
            allowed,
            None,
        )
    }

    /// Wraps a value in an `ErasedToken` the same way the output paths do.
    fn erased<T: Send + Sync + 'static>(value: T) -> ErasedToken {
        ErasedToken {
            value: Arc::new(value),
            created_at: crate::token::now_millis(),
        }
    }

    /// Converts `(place, tokens)` pairs into the map shape the context expects.
    fn place_map(entries: Vec<(&str, Vec<ErasedToken>)>) -> HashMap<Arc<str>, Vec<ErasedToken>> {
        entries
            .into_iter()
            .map(|(name, tokens)| (Arc::<str>::from(name), tokens))
            .collect()
    }

    /// Builds a context with the given inputs and reads and no declared outputs.
    fn ctx_with_io(
        inputs: Vec<(&str, Vec<ErasedToken>)>,
        reads: Vec<(&str, Vec<ErasedToken>)>,
    ) -> TransitionContext {
        TransitionContext::new(
            Arc::from("T"),
            place_map(inputs),
            place_map(reads),
            HashSet::new(),
            None,
        )
    }

    fn downcast_values<T: Send + Sync + 'static + Clone>(ctx: &TransitionContext) -> Vec<T> {
        ctx.outputs()
            .iter()
            .map(|e| (*e.token.value.downcast_ref::<T>().unwrap()).clone())
            .collect()
    }

    #[test]
    fn output_many_from_array_appends_in_order() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", [1, 2, 3]).unwrap();
        assert_eq!(downcast_values::<i32>(&ctx), vec![1, 2, 3]);
        assert!(ctx.outputs().iter().all(|e| &*e.place_name == "out"));
    }

    #[test]
    fn output_many_from_vec() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", vec!["a".to_string(), "b".to_string()])
            .unwrap();
        assert_eq!(
            downcast_values::<String>(&ctx),
            vec!["a".to_string(), "b".to_string()]
        );
    }

    #[test]
    fn output_many_from_range_iterator() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", 0..5i32).unwrap();
        assert_eq!(downcast_values::<i32>(&ctx), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn output_many_empty_is_ok_and_no_op() {
        let mut ctx = ctx_with_output("out");
        let empty: [i32; 0] = [];
        ctx.output_many("out", empty).unwrap();
        assert!(ctx.outputs().is_empty());
    }

    #[test]
    fn output_many_undeclared_place_errors_before_appending() {
        let mut ctx = ctx_with_output("out");
        let err = ctx.output_many("nope", [1, 2, 3]).unwrap_err();
        assert!(format!("{err:?}").contains("not in declared outputs"));
        assert!(ctx.outputs().is_empty());
    }

    #[test]
    fn output_many_shares_timestamp_across_tokens() {
        let mut ctx = ctx_with_output("out");
        ctx.output_many("out", [10i32, 20, 30]).unwrap();
        let ts: Vec<_> = ctx.outputs().iter().map(|e| e.token.created_at).collect();
        assert_eq!(ts.len(), 3);
        assert!(ts.windows(2).all(|w| w[0] == w[1]));
    }

    #[test]
    fn single_output_still_works_alongside_output_many() {
        // Smoke test: the existing output() path is untouched by the bulk addition.
        let mut ctx = ctx_with_output("out");
        ctx.output("out", 42i32).unwrap();
        ctx.output_many("out", [43, 44]).unwrap();
        assert_eq!(downcast_values::<i32>(&ctx), vec![42, 43, 44]);
    }

    #[test]
    fn input_downcasts_single_token_without_leaking() {
        let ctx = ctx_with_io(vec![("in", vec![erased(7i32)])], vec![]);
        let value = ctx.input::<i32>("in").unwrap();
        assert_eq!(*value, 7);
        // One reference is held by the context's token, one by `value`.
        assert_eq!(Arc::strong_count(&value), 2);
    }

    #[test]
    fn input_type_mismatch_is_an_error() {
        let ctx = ctx_with_io(vec![("in", vec![erased(7i32)])], vec![]);
        let err = ctx.input::<String>("in").unwrap_err();
        assert!(format!("{err:?}").contains("Type mismatch for place 'in'"));
    }

    #[test]
    fn input_undeclared_place_is_an_error() {
        let ctx = ctx_with_io(vec![], vec![]);
        let err = ctx.input::<i32>("missing").unwrap_err();
        assert!(format!("{err:?}").contains("not in declared inputs"));
    }

    #[test]
    fn input_rejects_batched_place_but_inputs_returns_all() {
        let ctx = ctx_with_io(vec![("in", vec![erased(1i32), erased(2i32)])], vec![]);
        assert!(ctx.input::<i32>("in").is_err());
        let all: Vec<i32> = ctx.inputs::<i32>("in").unwrap().iter().map(|v| **v).collect();
        assert_eq!(all, vec![1, 2]);
    }

    #[test]
    fn read_returns_first_token_and_reads_returns_all() {
        let ctx = ctx_with_io(
            vec![],
            vec![("cfg", vec![erased("a".to_string()), erased("b".to_string())])],
        );
        assert_eq!(ctx.read::<String>("cfg").unwrap().as_str(), "a");
        let all: Vec<String> = ctx
            .reads::<String>("cfg")
            .unwrap()
            .iter()
            .map(|v| (**v).clone())
            .collect();
        assert_eq!(all, vec!["a".to_string(), "b".to_string()]);
        assert!(ctx.read::<i32>("cfg").is_err());
        let err = ctx.read::<String>("missing").unwrap_err();
        assert!(format!("{err:?}").contains("not in declared reads"));
    }
}