Skip to main content

libpetri_core/
context.rs

1use std::any::Any;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use crate::action::ActionError;
6use crate::token::ErasedToken;
7
/// Logging callback made available to transition actions.
///
/// It is invoked as `log_fn(level, message)` with two borrowed strings. It is
/// reference-counted and `Send + Sync`, so a single callback can be cloned
/// cheaply and shared across threads.
pub type LogFn = Arc<dyn Fn(&str, &str) + Sync + Send>;
10
11/// An output entry: place name + erased token.
12#[derive(Debug, Clone)]
13pub struct OutputEntry {
14    pub place_name: Arc<str>,
15    pub token: ErasedToken,
16}
17
18/// Context provided to transition actions.
19///
20/// Provides filtered access based on structure:
21/// - Input places (consumed tokens)
22/// - Read places (context tokens, not consumed)
23/// - Output places (where to produce tokens)
24///
25/// Enforces the structure contract — actions can only access places
26/// declared in the transition's structure.
27pub struct TransitionContext {
28    transition_name: Arc<str>,
29    inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
30    reads: HashMap<Arc<str>, Vec<ErasedToken>>,
31    allowed_outputs: HashSet<Arc<str>>,
32    outputs: Vec<OutputEntry>,
33    execution_ctx: HashMap<String, Box<dyn Any + Send + Sync>>,
34    log_fn: Option<LogFn>,
35}
36
37impl TransitionContext {
38    pub fn new(
39        transition_name: Arc<str>,
40        inputs: HashMap<Arc<str>, Vec<ErasedToken>>,
41        reads: HashMap<Arc<str>, Vec<ErasedToken>>,
42        allowed_outputs: HashSet<Arc<str>>,
43        log_fn: Option<LogFn>,
44    ) -> Self {
45        Self {
46            transition_name,
47            inputs,
48            reads,
49            allowed_outputs,
50            outputs: Vec::new(),
51            execution_ctx: HashMap::new(),
52            log_fn,
53        }
54    }
55
56    // ==================== Input Access (consumed) ====================
57
58    /// Get single consumed input value. Returns error if place not declared or wrong type.
59    pub fn input<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
60        let tokens = self.inputs.get(place_name).ok_or_else(|| {
61            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
62        })?;
63        if tokens.len() != 1 {
64            return Err(ActionError::new(format!(
65                "Place '{place_name}' consumed {} tokens, use inputs() for batched access",
66                tokens.len()
67            )));
68        }
69        self.downcast_value::<T>(&tokens[0], place_name)
70    }
71
72    /// Get all consumed input values for a place.
73    pub fn inputs<T: Send + Sync + 'static>(
74        &self,
75        place_name: &str,
76    ) -> Result<Vec<Arc<T>>, ActionError> {
77        let tokens = self.inputs.get(place_name).ok_or_else(|| {
78            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
79        })?;
80        tokens
81            .iter()
82            .map(|t| self.downcast_value::<T>(t, place_name))
83            .collect()
84    }
85
86    /// Get the raw (type-erased) value of the first input token.
87    pub fn input_raw(&self, place_name: &str) -> Result<Arc<dyn Any + Send + Sync>, ActionError> {
88        let tokens = self.inputs.get(place_name).ok_or_else(|| {
89            ActionError::new(format!("Place '{place_name}' not in declared inputs"))
90        })?;
91        if tokens.is_empty() {
92            return Err(ActionError::new(format!(
93                "No tokens for place '{place_name}'"
94            )));
95        }
96        Ok(Arc::clone(&tokens[0].value))
97    }
98
99    /// Returns the names of all declared input places.
100    pub fn input_place_names(&self) -> Vec<Arc<str>> {
101        self.inputs.keys().cloned().collect()
102    }
103
104    // ==================== Read Access (not consumed) ====================
105
106    /// Get read-only context value. Returns error if place not declared.
107    pub fn read<T: Send + Sync + 'static>(&self, place_name: &str) -> Result<Arc<T>, ActionError> {
108        let tokens = self.reads.get(place_name).ok_or_else(|| {
109            ActionError::new(format!("Place '{place_name}' not in declared reads"))
110        })?;
111        if tokens.is_empty() {
112            return Err(ActionError::new(format!(
113                "No tokens for read place '{place_name}'"
114            )));
115        }
116        self.downcast_value::<T>(&tokens[0], place_name)
117    }
118
119    /// Get all read-only context values for a place.
120    pub fn reads<T: Send + Sync + 'static>(
121        &self,
122        place_name: &str,
123    ) -> Result<Vec<Arc<T>>, ActionError> {
124        let tokens = self.reads.get(place_name).ok_or_else(|| {
125            ActionError::new(format!("Place '{place_name}' not in declared reads"))
126        })?;
127        tokens
128            .iter()
129            .map(|t| self.downcast_value::<T>(t, place_name))
130            .collect()
131    }
132
133    /// Returns the names of all declared read places.
134    pub fn read_place_names(&self) -> Vec<Arc<str>> {
135        self.reads.keys().cloned().collect()
136    }
137
138    // ==================== Output Access ====================
139
140    /// Add output value. Returns error if place not declared as output.
141    pub fn output<T: Send + Sync + 'static>(
142        &mut self,
143        place_name: &str,
144        value: T,
145    ) -> Result<(), ActionError> {
146        let name = self.require_output(place_name)?;
147        self.outputs.push(OutputEntry {
148            place_name: name,
149            token: ErasedToken {
150                value: Arc::new(value),
151                created_at: crate::token::now_millis(),
152            },
153        });
154        Ok(())
155    }
156
157    /// Add a raw (type-erased) output value.
158    pub fn output_raw(
159        &mut self,
160        place_name: &str,
161        value: Arc<dyn Any + Send + Sync>,
162    ) -> Result<(), ActionError> {
163        let name = self.require_output(place_name)?;
164        self.outputs.push(OutputEntry {
165            place_name: name,
166            token: ErasedToken {
167                value,
168                created_at: crate::token::now_millis(),
169            },
170        });
171        Ok(())
172    }
173
174    fn require_output(&self, place_name: &str) -> Result<Arc<str>, ActionError> {
175        self.allowed_outputs
176            .get(place_name)
177            .cloned()
178            .ok_or_else(|| {
179                ActionError::new(format!(
180                    "Place '{}' not in declared outputs: {:?}",
181                    place_name,
182                    self.allowed_outputs.iter().collect::<Vec<_>>()
183                ))
184            })
185    }
186
187    /// Returns the names of all declared output places.
188    pub fn output_place_names(&self) -> Vec<Arc<str>> {
189        self.allowed_outputs.iter().cloned().collect()
190    }
191
192    // ==================== Structure Info ====================
193
194    /// Returns the transition name.
195    pub fn transition_name(&self) -> &str {
196        &self.transition_name
197    }
198
199    // ==================== Execution Context ====================
200
201    /// Store an execution context value.
202    pub fn set_execution_context<T: Send + Sync + 'static>(&mut self, key: &str, value: T) {
203        self.execution_ctx.insert(key.to_string(), Box::new(value));
204    }
205
206    /// Retrieve an execution context value.
207    pub fn execution_context<T: 'static>(&self, key: &str) -> Option<&T> {
208        self.execution_ctx
209            .get(key)
210            .and_then(|v| v.downcast_ref::<T>())
211    }
212
213    /// Check if an execution context key exists.
214    pub fn has_execution_context(&self, key: &str) -> bool {
215        self.execution_ctx.contains_key(key)
216    }
217
218    // ==================== Logging ====================
219
220    /// Emits a structured log message.
221    pub fn log(&self, level: &str, message: &str) {
222        if let Some(ref log_fn) = self.log_fn {
223            log_fn(level, message);
224        }
225    }
226
227    // ==================== Internal ====================
228
229    /// Collects all output entries (used by executor).
230    pub fn take_outputs(&mut self) -> Vec<OutputEntry> {
231        std::mem::take(&mut self.outputs)
232    }
233
234    /// Reclaims the inputs HashMap for reuse (used by executor to avoid per-firing allocation).
235    pub fn take_inputs(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
236        std::mem::take(&mut self.inputs)
237    }
238
239    /// Reclaims the reads HashMap for reuse (used by executor to avoid per-firing allocation).
240    pub fn take_reads(&mut self) -> HashMap<Arc<str>, Vec<ErasedToken>> {
241        std::mem::take(&mut self.reads)
242    }
243
244    /// Returns a reference to the output entries.
245    pub fn outputs(&self) -> &[OutputEntry] {
246        &self.outputs
247    }
248
249    fn downcast_value<T: Send + Sync + 'static>(
250        &self,
251        token: &ErasedToken,
252        place_name: &str,
253    ) -> Result<Arc<T>, ActionError> {
254        // Try to downcast the inner Arc
255        let any_arc = Arc::clone(&token.value);
256        // First check if the type matches
257        if any_arc.downcast_ref::<T>().is_none() {
258            return Err(ActionError::new(format!(
259                "Type mismatch for place '{place_name}': expected {}",
260                std::any::type_name::<T>()
261            )));
262        }
263        // Safety: we just verified the type
264        let raw = Arc::into_raw(any_arc);
265        let typed = unsafe { Arc::from_raw(raw.cast::<T>()) };
266        Ok(typed)
267    }
268}
269
270impl std::fmt::Debug for TransitionContext {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        f.debug_struct("TransitionContext")
273            .field("transition_name", &self.transition_name)
274            .field("input_count", &self.inputs.len())
275            .field("read_count", &self.reads.len())
276            .field("output_count", &self.outputs.len())
277            .finish()
278    }
279}