rgrow/python.rs
1use std::collections::HashMap;
2use std::fs::File;
3use std::ops::DerefMut;
4use std::time::Duration;
5
6use crate::base::{NumEvents, NumTiles, RgrowError, RustAny, Tile};
7use crate::canvas::{Canvas, PointSafe2, PointSafeHere};
8use crate::ffs::{FFSRunConfig, FFSRunResult, FFSStateRef};
9use crate::models::atam::ATAM;
10use crate::models::kblock::KBlock;
11use crate::models::ktam::KTAM;
12use crate::models::oldktam::OldKTAM;
13use crate::models::sdc1d::SDC;
14use crate::models::sdc1d_bindreplace::SDC1DBindReplace;
15use crate::ratestore::RateStore;
16use crate::state::{StateEnum, StateStatus, TileCounts, TrackerData};
17use crate::system::{CriticalStateConfig, CriticalStateResult};
18use crate::system::{
19 DimerInfo, DynSystem, EvolveBounds, EvolveOutcome, NeededUpdate, System, TileBondInfo,
20};
21use crate::units::Second;
22use ndarray::Array2;
23use numpy::{
24 IntoPyArray, PyArray1, PyArray2, PyArray3, PyArrayMethods, PyReadonlyArray2, ToPyArray,
25};
26use pyo3::exceptions::PyValueError;
27use pyo3::prelude::*;
28use pyo3::types::PyDict;
29use pyo3::IntoPyObjectExt;
30use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
31
/// A State object.
///
/// Newtype wrapper around [`StateEnum`]. When the `python` feature is enabled it is
/// exposed to Python as the class `State` in the module `rgrow.rgrow`.
#[cfg_attr(feature = "python", pyclass(name = "State", module = "rgrow.rgrow"))]
// `repr(transparent)`: the wrapper has the same layout as the wrapped `StateEnum`.
#[repr(transparent)]
pub struct PyState(pub(crate) StateEnum);
36
/// A single 'assembly', or 'state', containing a canvas with tiles at locations.
/// Generally does not store concentration or temperature information, but does store time simulated.
#[cfg(feature = "python")]
#[pymethods]
impl PyState {
    /// Create an empty state with a canvas of the given shape.
    ///
    /// Parameters
    /// ----------
    /// shape : tuple[int, int]
    ///     Shape of the canvas.
    /// kind : str
    ///     Canvas kind (default "Square").
    /// tracking : str
    ///     Tracker kind (default "None").
    /// n_tile_types : int, optional
    ///     Number of tile types; 1 is used when not given.
    ///
    /// An error is raised if `kind` or `tracking` cannot be converted, or if the
    /// underlying state cannot be created.
    #[new]
    #[pyo3(signature = (shape, kind="Square", tracking="None", n_tile_types=None))]
    pub fn empty(
        shape: (usize, usize),
        kind: &str,
        tracking: &str,
        n_tile_types: Option<usize>,
    ) -> PyResult<Self> {
        Ok(PyState(StateEnum::empty(
            shape,
            kind.try_into()?,
            tracking.try_into()?,
            n_tile_types.unwrap_or(1),
        )?))
    }

    /// Create a state from an existing 2D array of tile values.
    ///
    /// Parameters
    /// ----------
    /// array : NDArray
    ///     2D array of tile values used as the canvas contents.
    /// kind : str
    ///     Canvas kind (default "Square").
    /// tracking : str
    ///     Tracker kind (default "None").
    /// n_tile_types : int, optional
    ///     Number of tile types; 1 is used when not given.
    #[staticmethod]
    #[pyo3(signature = (array, kind="Square", tracking="None", n_tile_types=None))]
    pub fn from_array(
        array: PyReadonlyArray2<crate::base::Tile>,
        kind: &str,
        tracking: &str,
        n_tile_types: Option<usize>,
    ) -> PyResult<Self> {
        Ok(PyState(StateEnum::from_array(
            array.as_array(),
            kind.try_into()?,
            tracking.try_into()?,
            n_tile_types.unwrap_or(1),
        )?))
    }

    /// Return a cloned copy of an array with the total possible next event rate for each point in the canvas.
    /// This is the deepest level of the quadtree for tree-based states.
    ///
    /// Returns
    /// -------
    /// NDArray[np.float64]
    pub fn rate_array<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f64>> {
        // Each stored rate is converted to a plain f64 before copying into NumPy.
        self.0.rate_array().mapv(|x| x.into()).to_pyarray(py)
    }

    #[getter]
    /// float: the total rate of possible next events for the state.
    pub fn total_rate(&self) -> f64 {
        RateStore::total_rate(&self.0).into()
    }

    #[getter]
    /// NDArray[np.uint]: a direct, mutable view of the state's canvas. This is potentially unsafe.
    pub fn canvas_view<'py>(
        this: Bound<'py, Self>,
        _py: Python<'py>,
    ) -> PyResult<Bound<'py, PyArray2<crate::base::Tile>>> {
        let t = this.borrow();
        let ra = t.0.raw_array();

        // SAFETY: the State object itself is handed over as the array's owning container,
        // which keeps the State alive while the NumPy array exists.
        // NOTE(review): nothing visible here stops the canvas from being mutated or
        // reallocated from Rust while the view is alive; callers must not keep the view
        // across such operations.
        unsafe { Ok(PyArray2::borrow_from_array(&ra, this.into_any())) }
    }

    /// Return a copy of the state's canvas. This is safe, but can't be modified and is slower than `canvas_view`.
    ///
    /// Returns
    /// -------
    /// NDArray[np.uint]
    ///     A cloned copy of the state's canvas, in raw form.
    pub fn canvas_copy<'py>(
        this: &Bound<'py, Self>,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyArray2<crate::base::Tile>>> {
        let t = this.borrow();
        let ra = t.0.raw_array();

        Ok(PyArray2::from_array(py, &ra))
    }

    /// Return the total possible next event rate at a specific canvas point.
    ///
    /// Parameters
    /// ----------
    /// point: tuple[int, int]
    ///     The canvas point.
    ///
    /// Returns
    /// -------
    /// f64
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     if `point` is out of bounds for the canvas.
    pub fn rate_at_point(&self, point: (usize, usize)) -> PyResult<f64> {
        // The bounds check must come first: `PointSafeHere` is built without validation.
        if self.0.inbounds(point) {
            Ok(self.0.rate_at_point(PointSafeHere(point)).into())
        } else {
            Err(PyValueError::new_err(format!(
                "Point {point:?} is out of bounds."
            )))
        }
    }

    /// Return a copy of the tracker's tracking data.
    ///
    /// Returns
    /// -------
    /// Any
    pub fn tracking_copy(this: &Bound<Self>) -> PyResult<RustAny> {
        let t = this.borrow();
        let ra = t.0.get_tracker_data();

        Ok(ra)
    }

    /// int: the number of tiles in the state.
    #[getter]
    pub fn n_tiles(&self) -> NumTiles {
        self.0.n_tiles()
    }

    /// int: the number of tiles in the state (deprecated, use `n_tiles` instead).
    #[getter]
    pub fn ntiles(&self) -> NumTiles {
        self.0.n_tiles()
    }

    /// int: the total number of events that have occurred in the state.
    #[getter]
    pub fn total_events(&self) -> NumEvents {
        self.0.total_events()
    }

    /// float: the total time the state has simulated, in seconds.
    #[getter]
    pub fn time(&self) -> f64 {
        self.0.time().into()
    }

    /// NDArray[np.uint32]: a copy of the tile counts reported by the state.
    #[getter]
    pub fn tile_counts<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<u32>> {
        self.0.tile_counts().to_pyarray(py)
    }

    /// Short textual summary of the state: tile count, time, events, size and total rate.
    pub fn __repr__(&self) -> String {
        // NOTE(review): `size` is printed as (ncols, nrows); confirm this order is intended.
        format!(
            "State(n_tiles={}, time={} s, events={}, size=({}, {}), total_rate={})",
            self.n_tiles(),
            self.0.time(),
            self.total_events(),
            self.0.ncols(),
            self.0.nrows(),
            self.0.total_rate()
        )
    }

    /// Print the `Debug` representation of the underlying state to standard output.
    pub fn print_debug(&self) {
        println!("{:?}", self.0);
    }

    /// Write the state to a JSON file. This is inefficient, and is likely
    /// useful primarily for debugging.
    ///
    /// Parameters
    /// ----------
    /// filename : str
    ///     Path of the file to create (an existing file is truncated).
    ///
    /// File-creation, write and serialization failures are returned as errors rather
    /// than panicking.
    pub fn write_json(&self, filename: &str) -> Result<(), RgrowError> {
        // Buffer the output: serde_json issues many small writes.
        let mut writer = std::io::BufWriter::new(File::create(filename)?);
        // serde_json errors convert into `std::io::Error`, which this function's `?` already handles.
        serde_json::to_writer(&mut writer, &self.0).map_err(std::io::Error::from)?;
        // Flush explicitly so a late write error is reported instead of being lost on drop.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Read a state from a JSON file, such as one written by `write_json`.
    ///
    /// Parameters
    /// ----------
    /// filename : str
    ///     Path of the file to read.
    ///
    /// Failure to open the file, and malformed or mismatching JSON, are returned as
    /// errors rather than panicking.
    #[staticmethod]
    pub fn read_json(filename: &str) -> Result<Self, RgrowError> {
        let reader = std::io::BufReader::new(File::open(filename)?);
        let state = serde_json::from_reader(reader).map_err(std::io::Error::from)?;
        Ok(PyState(state))
    }

    /// Create a copy of the state.
    ///
    /// This creates a complete clone of the state, including all canvas data,
    /// tracking information, and simulation state (time, events, etc.).
    ///
    /// Returns
    /// -------
    /// State
    ///     A new State object that is a copy of this state.
    ///
    /// Examples
    /// --------
    /// >>> original_state = State((10, 10))
    /// >>> copied_state = original_state.copy()
    /// >>> # The copied state is independent of the original
    /// >>> assert copied_state.time == original_state.time
    /// >>> assert copied_state.total_events == original_state.total_events
    pub fn copy(&self) -> Self {
        PyState(self.0.clone())
    }

    /// Serialize state for pickling.
    ///
    /// Raises `ValueError` if bincode serialization fails.
    fn __getstate__(&self) -> PyResult<Vec<u8>> {
        bincode::serialize(&self.0)
            .map_err(|e| PyValueError::new_err(format!("Failed to serialize state: {e}")))
    }

    /// Deserialize state from pickle data.
    ///
    /// Raises `ValueError` if bincode deserialization fails; the current state is left
    /// unchanged in that case.
    fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
        self.0 = bincode::deserialize(&state)
            .map_err(|e| PyValueError::new_err(format!("Failed to deserialize state: {e}")))?;
        Ok(())
    }

    /// Return arguments for __new__ during unpickling.
    ///
    /// A placeholder (1, 1) shape is returned; the real contents are restored by
    /// `__setstate__`.
    fn __getnewargs__(&self) -> ((usize, usize),) {
        ((1, 1),)
    }

    /// Replay the events from a MovieTracker up to a given event ID.
    ///
    /// This reconstructs the state by replaying all events from the MovieTracker.
    /// The state must have been created with Movie tracking enabled.
    ///
    /// Parameters
    /// ----------
    /// up_to_event : int, optional
    ///     The event ID up to which to replay (inclusive). If not provided,
    ///     all events are replayed.
    ///
    /// Returns
    /// -------
    /// State
    ///     A new State with the events replayed. The returned state has no
    ///     tracker and no rates calculated.
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the state does not have a MovieTracker.
    ///
    /// Examples
    /// --------
    /// >>> # Create a state with movie tracking and evolve it
    /// >>> state = ts.create_state(tracking="Movie")
    /// >>> sys.evolve(state, for_events=100)
    /// >>> # Replay to get state at event 50
    /// >>> replayed = state.replay(up_to_event=50)
    #[pyo3(signature = (up_to_event=None))]
    pub fn replay(&self, up_to_event: Option<u64>) -> PyResult<Self> {
        self.0
            .replay(up_to_event)
            .map(PyState)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Replay events in-place on this state from external event data.
    ///
    /// This modifies the state's canvas by applying the events from the provided
    /// coordinate and tile arrays. Unlike `replay()`, this method takes external
    /// event data rather than using a MovieTracker.
    ///
    /// Parameters
    /// ----------
    /// coords : list[tuple[int, int]]
    ///     List of (row, col) coordinates for each event.
    /// new_tiles : list[int]
    ///     List of tile values for each event.
    /// event_ids : list[int]
    ///     List of event IDs for each event.
    /// up_to_event_id : int
    ///     The event ID up to which to replay (inclusive).
    /// n_tiles : list[int], optional
    ///     Forwarded unchanged to the underlying replay; presumably the per-event
    ///     tile counts (TODO confirm). Defaults to None.
    /// total_time : list[float], optional
    ///     Forwarded to the underlying replay after conversion to seconds values;
    ///     presumably the per-event total simulated time (TODO confirm). Defaults to None.
    /// energy : list[float], optional
    ///     Forwarded unchanged to the underlying replay; presumably the per-event
    ///     energy (TODO confirm). Defaults to None.
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If there is an error during replay.
    ///
    /// Examples
    /// --------
    /// >>> state = State((10, 10))
    /// >>> coords = [(1, 1), (2, 2)]
    /// >>> new_tiles = [1, 2]
    /// >>> event_ids = [0, 1]
    /// >>> state.replay_inplace(coords, new_tiles, event_ids, 1)
    // The argument list mirrors the underlying `replay_inplace`, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    // Explicit defaults keep the three optional arrays optional from Python; without a
    // signature, current PyO3 versions treat trailing `Option<T>` arguments as required.
    #[pyo3(signature = (coords, new_tiles, event_ids, up_to_event_id, n_tiles=None, total_time=None, energy=None))]
    pub fn replay_inplace(
        &mut self,
        coords: Vec<(usize, usize)>,
        new_tiles: Vec<Tile>,
        event_ids: Vec<u64>,
        up_to_event_id: u64,
        n_tiles: Option<Vec<u32>>,
        total_time: Option<Vec<f64>>,
        energy: Option<Vec<f64>>,
    ) -> PyResult<()> {
        let total_time_seconds: Option<Vec<Second>> =
            total_time.map(|v| v.into_iter().map(Second::new).collect());
        self.0
            .replay_inplace(
                &coords,
                &new_tiles,
                &event_ids,
                up_to_event_id,
                n_tiles.as_deref(),
                total_time_seconds.as_deref(),
                energy.as_deref(),
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }
}
350
/// Python argument that is either a single `State` or a sequence of `State` objects.
///
/// Used by `evolve`, which returns a single outcome for `State` and a list of
/// outcomes for `States`.
#[cfg(feature = "python")]
#[derive(FromPyObject)]
pub enum PyStateOrStates<'py> {
    /// A single state.
    #[pyo3(transparent)]
    State(Bound<'py, PyState>),
    /// Several states.
    #[pyo3(transparent)]
    States(Vec<Bound<'py, PyState>>),
}
359
/// Python argument that is either a `State` or an `FFSStateRef`.
#[cfg(feature = "python")]
#[derive(FromPyObject)]
pub enum PyStateOrRef<'py> {
    /// A state object.
    State(Bound<'py, PyState>),
    /// A reference to a state held by an FFS result; users in this file clone the
    /// referenced state (`clone_state`) before reading it.
    Ref(Bound<'py, FFSStateRef>),
}
366
/// Python argument that is a `State`, an `FFSStateRef`, or a raw 2D canvas array of
/// tile values.
#[cfg(feature = "python")]
#[derive(FromPyObject)]
pub enum PyStateOrCanvasRef<'py> {
    /// A state object.
    State(Bound<'py, PyState>),
    /// A reference to a state held by an FFS result.
    Ref(Bound<'py, FFSStateRef>),
    /// A NumPy array of tile values, used directly as the canvas.
    Array(Bound<'py, PyArray2<Tile>>),
}
374
375impl From<FFSStateRef> for PyState {
376 fn from(state: FFSStateRef) -> Self {
377 state.clone_state()
378 }
379}
380
381macro_rules! create_py_system {
382 ($name: ident) => {
383 create_py_system!($name, |tile: u32| tile as usize);
384 };
385 ($name: ident, $tile_index_fn: expr) => {
386 #[cfg(feature = "python")]
387 #[pymethods]
388 impl $name {
389
390
391 #[allow(clippy::too_many_arguments)]
392 #[pyo3(
393 name = "evolve",
394 signature = (state,
395 for_events=None,
396 total_events=None,
397 for_time=None,
398 total_time=None,
399 size_min=None,
400 size_max=None,
401 for_wall_time=None,
402 require_strong_bound=true,
403 show_window=false,
404 start_window_paused=true,
405 parallel=true,
406 initial_timescale=None,
407 initial_max_events_per_sec=None)
408 )]
409 /// Evolve a state (or states), with some bounds on the simulation.
410 ///
411 /// If evolving multiple states, the bounds are applied per-state.
412 ///
413 /// Parameters
414 /// ----------
415 /// state : State or Sequence[State]
416 /// The state or states to evolve.
417 /// for_events : int, optional
418 /// Stop evolving each state after this many events.
419 /// total_events : int, optional
420 /// Stop evelving each state when the state's total number of events (including
421 /// previous events) reaches this.
422 /// for_time : float, optional
423 /// Stop evolving each state after this many seconds of simulated time.
424 /// total_time : float, optional
425 /// Stop evolving each state when the state's total time (including previous steps)
426 /// reaches this.
427 /// size_min : int, optional
428 /// Stop evolving each state when the state's number of tiles is less than or equal to this.
429 /// size_max : int, optional
430 /// Stop evolving each state when the state's number of tiles is greater than or equal to this.
431 /// for_wall_time : float, optional
432 /// Stop evolving each state after this many seconds of wall time.
433 /// require_strong_bound : bool
434 /// Require that the stopping conditions are strong, i.e., they are guaranteed to be eventually
435 /// satisfied under normal conditions.
436 /// show_window : bool
437 /// Show a graphical UI window while evolving (requires rgrow-gui to be installed, and a single state).
438 /// start_window_paused : bool
439 /// If show_window is True, start the GUI window in a paused state. Defaults to True.
440 /// parallel : bool
441 /// Use multiple threads.
442 /// initial_timescale : float, optional
443 /// If show_window is True, set the initial timescale (sim_time/real_time) in the GUI. None means unlimited.
444 /// initial_max_events_per_sec : int, optional
445 /// If show_window is True, set the initial max events per second limit in the GUI. None means unlimited.
446 ///
447 /// Returns
448 /// -------
449 /// EvolveOutcome or List[EvolveOutcome]
450 /// The outcome (stopping condition) of the evolution. If evolving a single state, returns a single outcome.
451 pub fn py_evolve<'py>(
452 &mut self,
453 state: PyStateOrStates<'py>,
454 for_events: Option<u64>,
455 total_events: Option<u64>,
456 for_time: Option<f64>,
457 total_time: Option<f64>,
458 size_min: Option<u32>,
459 size_max: Option<u32>,
460 for_wall_time: Option<f64>,
461 require_strong_bound: bool,
462 show_window: bool,
463 start_window_paused: bool,
464 parallel: bool,
465 initial_timescale: Option<f64>,
466 initial_max_events_per_sec: Option<u64>,
467 py: Python<'py>,
468 ) -> PyResult<Py<PyAny>> {
469 let bounds = EvolveBounds {
470 for_events,
471 for_time,
472 total_events,
473 total_time,
474 size_min,
475 size_max,
476 for_wall_time: for_wall_time.map(Duration::from_secs_f64),
477 };
478
479 if require_strong_bound && !show_window && !bounds.is_strongly_bounded() {
480 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
481 "No strong bounds specified.",
482 ));
483 }
484
485 if !show_window && !bounds.is_weakly_bounded() {
486 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
487 "No weak bounds specified.",
488 ));
489 }
490
491 match state {
492 PyStateOrStates::State(pystate) => {
493 let state = &mut pystate.borrow_mut().0;
494 if show_window {
495 py
496 .detach(|| {
497 System::evolve_in_window(self, state, None, start_window_paused, bounds, initial_timescale, initial_max_events_per_sec)
498 })?
499 .into_py_any(py)
500 } else {
501 py
502 .detach(|| System::evolve(self, state, bounds))?
503 .into_py_any(py)
504 }
505 }
506 PyStateOrStates::States(pystates) => {
507 if show_window {
508 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
509 "Cannot show window with multiple states.",
510 ));
511 }
512 let mut refs = pystates
513 .into_iter()
514 .map(|x| x.borrow_mut())
515 .collect::<Vec<_>>();
516 let mut states = refs.iter_mut().map(|x| x.deref_mut()).collect::<Vec<_>>();
517 let out = py.detach(|| {
518 if parallel {
519 states
520 .par_iter_mut()
521 .map(|state| System::evolve(self, &mut state.0, bounds))
522 .collect::<Vec<_>>()
523 } else {
524 states
525 .iter_mut()
526 .map(|state| System::evolve(self, &mut state.0, bounds))
527 .collect::<Vec<_>>()
528 }});
529 let o: Result<Vec<EvolveOutcome>, PyErr> = out
530 .into_iter()
531 .map(|x| {
532 x.map_err(|y| {
533 pyo3::exceptions::PyValueError::new_err(y.to_string())
534 })
535 })
536 .collect();
537 o.map(|x| x.into_py_any(py).unwrap())
538 }
539 }
540 }
541
542 /// Calculate the number of mismatches in a state.
543 ///
544 /// Parameters
545 /// ----------
546 /// state : State or FFSStateRef
547 /// The state to calculate mismatches for.
548 ///
549 /// Returns
550 /// -------
551 /// int
552 /// The number of mismatches.
553 ///
554 /// See also
555 /// --------
556 /// calc_mismatch_locations
557 /// Calculate the location and direction of mismatches, not jus the number.
558 fn calc_mismatches(&self, state: PyStateOrRef) -> usize {
559 match state {
560 PyStateOrRef::State(s) => System::calc_mismatches(self, &s.borrow().0),
561 PyStateOrRef::Ref(s) => {
562 System::calc_mismatches(self, &s.borrow().clone_state().0)
563 }
564 }
565 }
566
567 /// Calculate information about the dimers the system is able to form.
568 ///
569 /// Returns
570 /// -------
571 /// List[DimerInfo]
572 ///
573 /// Raises
574 /// ------
575 /// ValueError
576 /// If the system doesn't support dimer calculation
577 fn calc_dimers(&self) -> PyResult<Vec<DimerInfo>> {
578 System::calc_dimers(self).map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
579 }
580
581 /// Calculate the locations of mismatches in the state.
582 ///
583 /// This returns a copy of the canvas, with the values set to 0 if there is no mismatch
584 /// in the location, and > 0, in a model defined way, if there is at least one mismatch.
585 /// Most models use v = 8*N + 4*E + 2*S + W, where N, E, S, W are the four directions.
586 /// Thus, a tile with mismatches to the E and W would have v = 4+2 = 6.
587 ///
588 /// Parameters
589 /// ----------
590 /// state : State or FFSStateRef
591 /// The state to calculate mismatches for.
592 ///
593 /// Returns
594 /// -------
595 /// ndarray
596 /// An array of the same shape as the state's canvas, with the values set as described above.
597 fn calc_mismatch_locations<'py>(
598 &mut self,
599 state: PyStateOrRef,
600 py: Python<'py>,
601 ) -> PyResult<Bound<'py, PyArray2<usize>>> {
602 let ra = match state {
603 PyStateOrRef::State(s) => {
604 System::calc_mismatch_locations(self, &s.borrow().0)
605 }
606 PyStateOrRef::Ref(s) => {
607 System::calc_mismatch_locations(self, &s.borrow().clone_state().0)
608 }
609 };
610 Ok(PyArray2::from_array(py, &ra))
611 }
612
613 /// Set a system parameter.
614 ///
615 /// Parameters
616 /// ----------
617 /// param_name : str
618 /// The name of the parameter to set.
619 /// value : Any
620 /// The value to set the parameter to.
621 ///
622 /// Returns
623 /// -------
624 /// NeededUpdate
625 /// The type of state update needed. This can be passed to
626 /// `update_state` to update the state.
627 fn set_param(&mut self, param_name: &str, value: RustAny) -> PyResult<NeededUpdate> {
628 Ok(System::set_param(self, param_name, value.0)?)
629 }
630
631 /// Names of tiles, by tile number.
632 #[getter]
633 fn tile_names(&self) -> Vec<String> {
634 TileBondInfo::tile_names(self)
635 .iter()
636 .map(|x| x.to_string())
637 .collect()
638 }
639
640 #[getter]
641 fn bond_names(&self) -> Vec<String> {
642 TileBondInfo::bond_names(self)
643 .iter()
644 .map(|x| x.to_string())
645 .collect()
646 }
647
648 /// Given a tile name, return the tile number.
649 ///
650 /// Parameters
651 /// ----------
652 /// tile_name : str
653 /// The name of the tile.
654 ///
655 /// Returns
656 /// -------
657 /// int
658 /// The tile number.
659 fn tile_number_from_name(&self, tile_name: &str) -> Option<Tile> {
660 TileBondInfo::tile_names(self)
661 .iter()
662 .position(|x| *x == tile_name)
663 .map(|x| x as Tile)
664 }
665
666 /// Given a tile number, return the color of the tile.
667 ///
668 /// Parameters
669 /// ----------
670 /// tile_number : int
671 /// The tile number.
672 ///
673 /// Returns
674 /// -------
675 /// list[int]
676 /// The color of the tile, as a list of 4 integers (RGBA).
677 fn tile_color(&self, tile_number: Tile) -> [u8; 4] {
678 TileBondInfo::tile_color(self, tile_number)
679 }
680
681 #[getter]
682 fn tile_colors<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<u8>> {
683 let colors = TileBondInfo::tile_colors(self);
684 let mut arr = Array2::zeros((colors.len(), 4));
685 for (i, c) in colors.iter().enumerate() {
686 arr[[i, 0]] = c[0];
687 arr[[i, 1]] = c[1];
688 arr[[i, 2]] = c[2];
689 arr[[i, 3]] = c[3];
690 }
691 arr.into_pyarray(py)
692 }
693
694 /// Returns the current canvas for state as an array of tile names.
695 /// 'empty' indicates empty locations.
696 ///
697 /// Parameters
698 /// ----------
699 /// state : State or FFSStateRef
700 /// The state to return.
701 ///
702 /// Returns
703 /// -------
704 /// NDArray[str]
705 /// The current canvas for the state, as an array of tile names.
706 fn name_canvas<'py>(&self, state: PyStateOrRef<'py>, py: Python<'py>) -> PyResult<Py<PyArray2<Py<PyAny>>>> {
707 let tile_names = TileBondInfo::tile_names(self);
708 let canvas = match &state {
709 PyStateOrRef::State(s) => s.borrow().0.raw_array().to_owned(),
710 PyStateOrRef::Ref(s) => s.borrow().clone_state().0.raw_array().to_owned(),
711 };
712 let tile_index_fn = $tile_index_fn;
713 let name_array = canvas.mapv(|tile| {
714 let tile_index: usize = tile_index_fn(tile);
715 tile_names[tile_index].clone().into_pyobject(py).unwrap().unbind().into()
716 });
717 Ok(name_array.into_pyarray(py).unbind())
718 }
719
720 /// Returns the current canvas for state as an array of tile colors.
721 ///
722 /// Parameters
723 /// ----------
724 /// state : State, FFSStateRef, or NDArray
725 /// The state or canvas array to colorize.
726 ///
727 /// Returns
728 /// -------
729 /// NDArray[uint8]
730 /// The current canvas for the state, as an array of RGBA colors with shape (rows, cols, 4).
731 fn color_canvas<'py>(
732 &self,
733 state: PyStateOrCanvasRef<'py>,
734 py: Python<'py>,
735 ) -> PyResult<Bound<'py, PyArray3<u8>>> {
736 let colors = TileBondInfo::tile_colors(self);
737 let canvas = match &state {
738 PyStateOrCanvasRef::State(s) => s.borrow().0.raw_array().to_owned(),
739 PyStateOrCanvasRef::Ref(s) => s.borrow().clone_state().0.raw_array().to_owned(),
740 PyStateOrCanvasRef::Array(arr) => arr.readonly().as_array().to_owned(),
741 };
742 let tile_index_fn = $tile_index_fn;
743 let (rows, cols) = canvas.dim();
744 let mut color_array = ndarray::Array3::<u8>::zeros((rows, cols, 4));
745 for ((i, j), &tile) in canvas.indexed_iter() {
746 let tile_index: usize = tile_index_fn(tile);
747 let c = colors[tile_index];
748 color_array[[i, j, 0]] = c[0];
749 color_array[[i, j, 1]] = c[1];
750 color_array[[i, j, 2]] = c[2];
751 color_array[[i, j, 3]] = c[3];
752 }
753 Ok(color_array.into_pyarray(py))
754 }
755
756 fn get_param(&mut self, param_name: &str) -> PyResult<RustAny> {
757 Ok(RustAny(System::get_param(self, param_name)?))
758 }
759
760 #[pyo3(signature = (state, needed = &NeededUpdate::All))]
761 fn update_all(&self, state: &mut PyState, needed: &NeededUpdate) {
762 System::update_state(self, &mut state.0, needed)
763 }
764
765 /// Recalculate a state's rates.
766 ///
767 /// This is usually needed when a parameter of the system has
768 /// been changed.
769 ///
770 /// Parameters
771 /// ----------
772 /// state : State
773 /// The state to update.
774 /// needed : NeededUpdate, optional
775 /// The type of update needed. If not provided, all locations
776 /// will be recalculated.
777 #[pyo3(signature = (state, needed = &NeededUpdate::All))]
778 fn update_state(&self, state: &mut PyState, needed: &NeededUpdate) {
779 System::update_state(self, &mut state.0, needed)
780 }
781
782 #[pyo3(name = "setup_state")]
783 fn py_setup_state(&self, state: &mut PyState) -> PyResult<()> {
784 self.setup_state(&mut state.0).map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
785 Ok(())
786 }
787
788 /// Calculate the committor function for a state: the probability that when a simulation
789 /// is started from that state, the assembly will grow to a larger size (cutoff_size)
790 /// rather than melting to zero tiles.
791 ///
792 /// Parameters
793 /// ----------
794 /// state : State
795 /// The state to analyze
796 /// cutoff_size : int
797 /// Size threshold for commitment
798 /// num_trials : int
799 /// Number of trials to run
800 /// max_time : float, optional
801 /// Maximum simulation time per trial
802 /// max_events : int, optional
803 /// Maximum events per trial
804 ///
805 /// Returns
806 /// -------
807 /// float
808 /// Probability of reaching cutoff_size (between 0.0 and 1.0)
809 #[pyo3(name = "calc_committor", signature = (state, cutoff_size, num_trials, max_time=None, max_events=None))]
810 fn py_calc_committor(
811 &mut self,
812 state: &PyState,
813 cutoff_size: NumTiles,
814 num_trials: usize,
815 max_time: Option<f64>,
816 max_events: Option<NumEvents>,
817 py: Python<'_>,
818 ) -> PyResult<f64> {
819
820 let state = &state.0;
821
822 let out = py.detach(|| {
823 self.calc_committor(
824 &state,
825 cutoff_size,
826 max_time,
827 max_events,
828 num_trials,
829 )});
830 out.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
831 }
832
833 /// Calculate the committor function for a state using adaptive sampling: the probability
834 /// that when a simulation is started from that state, the assembly will grow to a larger
835 /// size (cutoff_size) rather than melting to zero tiles. Automatically determines the
836 /// number of trials needed to achieve a specified confidence interval margin.
837 ///
838 /// Parameters
839 /// ----------
840 /// state : State
841 /// The state to analyze
842 /// cutoff_size : int
843 /// Size threshold for commitment
844 /// conf_interval_margin : float
845 /// Confidence interval margin (e.g., 0.05 for 5%)
846 /// max_time : float, optional
847 /// Maximum simulation time per trial
848 /// max_events : int, optional
849 /// Maximum events per trial
850 ///
851 /// Returns
852 /// -------
853 /// tuple[float, int]
854 /// Tuple of (probability of reaching cutoff_size, number of trials run)
855 #[pyo3(name = "calc_committor_adaptive", signature = (state, cutoff_size, conf_interval_margin, max_time=None, max_events=None))]
856 fn py_calc_committor_adaptive(
857 &self,
858 state: &PyState,
859 cutoff_size: NumTiles,
860 conf_interval_margin: f64,
861 max_time: Option<f64>,
862 max_events: Option<NumEvents>,
863 py: Python<'_>,
864 ) -> PyResult<(f64, usize)> {
865 py.detach(|| {
866 self.calc_committor_adaptive(
867 &state.0,
868 cutoff_size,
869 max_time,
870 max_events,
871 conf_interval_margin,
872 )
873 })
874 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
875 }
876
877 /// Calculate the committor function for multiple states using adaptive sampling.
878 ///
879 /// Parameters
880 /// ----------
881 /// states : List[State]
882 /// The states to analyze
883 /// cutoff_size : int
884 /// Size threshold for commitment
885 /// conf_interval_margin : float
886 /// Confidence interval margin (e.g., 0.05 for 5%)
887 /// max_time : float, optional
888 /// Maximum simulation time per trial
889 /// max_events : int, optional
890 /// Maximum events per trial
891 ///
892 /// Returns
893 /// -------
894 /// tuple[NDArray[float64], NDArray[usize]]
895 /// Tuple of (committor probabilities, number of trials for each state)
896 #[pyo3(name = "calc_committors_adaptive", signature = (states, cutoff_size, conf_interval_margin, max_time=None, max_events=None))]
897 fn py_calc_committors_adaptive<'py>(
898 &self,
899 states: Vec<Bound<'py, PyState>>,
900 cutoff_size: NumTiles,
901 conf_interval_margin: f64,
902 max_time: Option<f64>,
903 max_events: Option<NumEvents>,
904 py: Python<'py>,
905 ) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<usize>>)> {
906
907 let refs = states.iter().map(|x| x.borrow()).collect::<Vec<_>>();
908 let states = refs.iter().map(|x| &x.0).collect::<Vec<_>>();
909 let (committors, trials) = py.detach(|| {
910 self.calc_committors_adaptive(&states, cutoff_size, max_time, max_events, conf_interval_margin)
911 }).map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
912
913 Ok((committors.into_pyarray(py), trials.into_pyarray(py)))
914 }
915
916 /// Determine whether the committor probability for a state is above or below a threshold
917 /// with a specified confidence level using adaptive sampling.
918 ///
919 /// This function uses adaptive sampling to determine with the desired confidence whether
920 /// the true committor probability is above or below the given threshold. It continues
921 /// sampling until the confidence interval is narrow enough to make a definitive determination.
922 ///
923 /// Parameters
924 /// ----------
925 /// state : State
926 /// The state to analyze
927 /// cutoff_size : int
928 /// Size threshold for commitment
929 /// threshold : float
930 /// The probability threshold to compare against (e.g., 0.5)
931 /// confidence_level : float
932 /// Confidence level for the threshold test (e.g., 0.95 for 95% confidence)
933 /// max_time : float, optional
934 /// Maximum simulation time per trial
935 /// max_events : int, optional
936 /// Maximum events per trial
937 /// max_trials : int, optional
938 /// Maximum number of trials to run (default: 100000)
939 /// return_on_max_trials : bool, optional
940 /// If True, return results even when max_trials is exceeded (default: False)
941 /// ci_confidence_level : float, optional
942 /// Confidence level for the returned confidence interval (default: None, no CI returned)
943 /// Can be different from confidence_level (e.g., test at 95%, show 99% CI)
944 ///
945 /// Returns
946 /// -------
947 /// tuple[bool, float, tuple[float, float] | None, int, bool]
948 /// Tuple of (is_above_threshold, probability_estimate, confidence_interval, num_trials, exceeded_max_trials) where:
949 /// - is_above_threshold: True if probability is above threshold with given confidence
950 /// - probability_estimate: The estimated probability
951 /// - confidence_interval: Tuple of (lower_bound, upper_bound) or None if ci_confidence_level not provided
952 /// - num_trials: Number of trials performed
953 /// - exceeded_max_trials: True if max_trials was exceeded (warning flag)
954 #[allow(clippy::too_many_arguments)]
955 #[pyo3(name = "calc_committor_threshold_test", signature = (state, cutoff_size, threshold, confidence_level, max_time=None, max_events=None, max_trials=None, return_on_max_trials=false))]
956 fn py_calc_committor_threshold_test(
957 &mut self,
958 state: &mut PyState,
959 cutoff_size: NumTiles,
960 threshold: f64,
961 confidence_level: f64,
962 max_time: Option<f64>,
963 max_events: Option<NumEvents>,
964 max_trials: Option<usize>,
965 return_on_max_trials: bool,
966 py: Python<'_>,
967 ) -> PyResult<(bool, f64, usize, bool)> {
968 py.detach(|| {
969 self.calc_committor_threshold_test(
970 &state.0,
971 cutoff_size,
972 threshold,
973 confidence_level,
974 max_time,
975 max_events,
976 max_trials,
977 return_on_max_trials,
978 )
979 })
980 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
981 }
982
983 /// Calculate forward probability for a given state.
984 ///
985 /// This function calculates the probability that a state will grow by at least
986 /// `forward_step` tiles before shrinking to size 0. Unlike calc_committor which
987 /// uses a fixed cutoff size, this uses a dynamic cutoff based on the current
988 /// state size plus the forward_step parameter.
989 ///
990 /// Parameters
991 /// ----------
992 /// state : State
993 /// The initial state to analyze
994 /// forward_step : int, optional
995 /// Number of tiles to grow beyond current size (default: 1)
996 /// num_trials : int
997 /// Number of simulation trials to run
998 /// max_time : float, optional
999 /// Maximum simulation time per trial
1000 /// max_events : int, optional
1001 /// Maximum number of events per trial
1002 ///
1003 /// Returns
1004 /// -------
1005 /// float
1006 /// Probability of reaching forward_step additional tiles (between 0.0 and 1.0)
1007 #[pyo3(name = "calc_forward_probability", signature = (state, num_trials, forward_step=1, max_time=None, max_events=None))]
1008 fn py_calc_forward_probability(
1009 &mut self,
1010 state: &PyState,
1011 num_trials: usize,
1012 forward_step: NumTiles,
1013 max_time: Option<f64>,
1014 max_events: Option<NumEvents>,
1015 ) -> PyResult<f64> {
1016 let result = self.calc_forward_probability(&state.0, forward_step, max_time, max_events, num_trials);
1017 match result {
1018 Ok(probability) => Ok(probability),
1019 Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())),
1020 }
1021 }
1022
1023 /// Calculate forward probability adaptively for a given state.
1024 ///
1025 /// Uses adaptive sampling to determine the number of trials needed based on a
1026 /// confidence interval margin. Runs until the confidence interval is narrow enough.
1027 ///
1028 /// Parameters
1029 /// ----------
1030 /// state : State
1031 /// The initial state to analyze
1032 /// forward_step : int, optional
1033 /// Number of tiles to grow beyond current size (default: 1)
1034 /// conf_interval_margin : float
1035 /// Desired confidence interval margin (e.g., 0.05 for 5%)
1036 /// max_time : float, optional
1037 /// Maximum simulation time per trial
1038 /// max_events : int, optional
1039 /// Maximum number of events per trial
1040 ///
1041 /// Returns
1042 /// -------
1043 /// tuple[float, int]
1044 /// Tuple of (forward probability, number of trials run)
1045 #[pyo3(name = "calc_forward_probability_adaptive", signature = (state, conf_interval_margin, forward_step=1, max_time=None, max_events=None))]
1046 fn py_calc_forward_probability_adaptive(
1047 &self,
1048 state: &PyState,
1049 conf_interval_margin: f64,
1050 forward_step: NumTiles,
1051 max_time: Option<f64>,
1052 max_events: Option<NumEvents>,
1053 py: Python<'_>,
1054 ) -> PyResult<(f64, usize)> {
1055 let (probability, trials) = py.detach(|| {
1056 self.calc_forward_probability_adaptive(&state.0, forward_step, max_time, max_events, conf_interval_margin)
1057 }).map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
1058
1059 Ok((probability, trials))
1060 }
1061
1062 /// Calculate forward probabilities adaptively for multiple states.
1063 ///
1064 /// Uses adaptive sampling for each state in parallel to determine forward
1065 /// probabilities with specified confidence intervals.
1066 ///
1067 /// Parameters
1068 /// ----------
1069 /// states : list[State]
1070 /// List of initial states to analyze
1071 /// forward_step : int, optional
1072 /// Number of tiles to grow beyond current size for each state (default: 1)
1073 /// conf_interval_margin : float
1074 /// Desired confidence interval margin (e.g., 0.05 for 5%)
1075 /// max_time : float, optional
1076 /// Maximum simulation time per trial
1077 /// max_events : int, optional
1078 /// Maximum number of events per trial
1079 ///
1080 /// Returns
1081 /// -------
1082 /// tuple[NDArray[float64], NDArray[usize]]
1083 /// Tuple of (forward probabilities, number of trials for each state)
1084 #[pyo3(name = "calc_forward_probabilities_adaptive", signature = (states, conf_interval_margin, forward_step=1, max_time=None, max_events=None))]
1085 fn py_calc_forward_probabilities_adaptive<'py>(
1086 &self,
1087 states: Vec<Bound<'py, PyState>>,
1088 conf_interval_margin: f64,
1089 forward_step: NumTiles,
1090 max_time: Option<f64>,
1091 max_events: Option<NumEvents>,
1092 py: Python<'py>,
1093 ) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<usize>>)> {
1094
1095 let refs = states.iter().map(|x| x.borrow()).collect::<Vec<_>>();
1096 let states = refs.iter().map(|x| &x.0).collect::<Vec<_>>();
1097 let (probabilities, trials) = py.detach(|| {
1098 self.calc_forward_probabilities_adaptive(&states, forward_step, max_time, max_events, conf_interval_margin)
1099 }).map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
1100
1101 Ok((probabilities.into_pyarray(py), trials.into_pyarray(py)))
1102 }
1103
1104 /// Run FFS.
1105 ///
1106 /// Parameters
1107 /// ----------
1108 /// config : FFSRunConfig
1109 /// The configuration for the FFS run.
1110 /// **kwargs
1111 /// FFSRunConfig parameters as keyword arguments.
1112 ///
1113 /// Returns
1114 /// -------
1115 /// FFSRunResult
1116 /// The result of the FFS run.
1117 #[pyo3(name = "run_ffs", signature = (config = FFSRunConfig::default(), **kwargs))]
1118 fn py_run_ffs(
1119 &mut self,
1120 config: FFSRunConfig,
1121 kwargs: Option<Bound<PyDict>>,
1122 py: Python<'_>,
1123 ) -> PyResult<FFSRunResult> {
1124 let mut c = config;
1125
1126 if let Some(dict) = kwargs {
1127 for (k, v) in dict.iter() {
1128 c._py_set(&k.extract::<String>()?, v)?;
1129 }
1130 }
1131
1132 let res = py.detach(|| self.run_ffs(&c));
1133 match res {
1134 Ok(res) => Ok(res),
1135 Err(err) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
1136 err.to_string(),
1137 )),
1138 }
1139 }
1140
1141 fn __repr__(&self) -> String {
1142 format!("System({})", System::system_info(self))
1143 }
1144
1145 pub fn print_debug(&self) {
1146 println!("{:?}", self);
1147 }
1148
1149 /// Write the system to a JSON file.
1150 ///
1151 /// Parameters
1152 /// ----------
1153 /// filename : str
1154 /// The name of the file to write to.
1155 pub fn write_json(&self, filename: &str) -> Result<(), RgrowError> {
1156 serde_json::to_writer(File::create(filename)?, self).unwrap();
1157 Ok(())
1158 }
1159
1160
1161
1162 /// Read a system from a JSON file.
1163 ///
1164 /// Parameters
1165 /// ----------
1166 /// filename : str
1167 /// The name of the file to read from.
1168 ///
1169 /// Returns
1170 /// -------
1171 /// Self
1172 #[staticmethod]
1173 pub fn read_json(filename: &str) -> Result<Self, RgrowError> {
1174 Ok(serde_json::from_reader(File::open(filename)?).unwrap())
1175 }
1176
1177 /// Place a tile at a point in the given state.
1178 ///
1179 /// This updates tile counts and rates but does not increment the
1180 /// event counter or record events in the state tracker.
1181 ///
1182 /// Parameters
1183 /// ----------
1184 /// state : PyState
1185 /// The state to modify.
1186 /// point : tuple of int
1187 /// The coordinates at which to place the tile (i, j).
1188 /// tile : int
1189 /// The tile number to place.
1190 /// replace : bool, optional
1191 /// If True (default), any existing tile at the target site is removed
1192 /// first. If False, raises an error if the site is occupied.
1193 ///
1194 /// Returns
1195 /// -------
1196 /// float
1197 /// The energy change from placing the tile.
1198 #[pyo3(name = "place_tile")]
1199 #[pyo3(signature = (state, point, tile, replace=true))]
1200 pub fn py_place_tile(
1201 &self,
1202 state: &mut PyState,
1203 point: (usize, usize),
1204 tile: u32,
1205 replace: bool,
1206 ) -> Result<f64, RgrowError> {
1207 let pt = PointSafe2(point);
1208 let energy_change = self.place_tile(&mut state.0, pt, tile.into(), replace)?;
1209 Ok(energy_change)
1210 }
1211
1212 // /// Find the first state in a trajectory above the critical threshold.
1213 // ///
1214 // /// Iterates through the trajectory (after filtering redundant events),
1215 // /// reconstructing the state at each point and testing if the committor
1216 // /// probability is above the threshold with the specified confidence.
1217 // ///
1218 // /// Parameters
1219 // /// ----------
1220 // /// trajectory : pl.DataFrame
1221 // /// DataFrame with columns: row, col, new_tile, energy
1222 // /// config : CriticalStateConfig, optional
1223 // /// Configuration for the search (uses defaults if not provided)
1224 // ///
1225 // /// Returns
1226 // /// -------
1227 // /// CriticalStateResult | None
1228 // /// The first critical state found, or None if no state is above threshold.
1229 #[pyo3(name = "find_first_critical_state", signature = (end_state, config=CriticalStateConfig::default()))]
1230 pub fn py_find_first_critical_state(
1231 &mut self,
1232 end_state: &PyState,
1233 config: CriticalStateConfig,
1234 py: Python<'_>,
1235 ) -> PyResult<Option<CriticalStateResult>> {
1236 py.detach(|| {
1237 self.find_first_critical_state(&end_state.0, &config)
1238 })
1239 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
1240 }
1241
1242 // TODO: Uncomment when find_last_critical_state is implemented on System trait
1243 // /// Find the last state not above threshold, return the next state.
1244 // ///
1245 // /// Iterates backwards through the trajectory to find the last state that is
1246 // /// NOT above the critical threshold, then returns the next state (which should
1247 // /// be above threshold). This is useful for finding the "critical nucleus".
1248 // ///
1249 // /// Parameters
1250 // /// ----------
1251 // /// trajectory : pl.DataFrame
1252 // /// DataFrame with columns: row, col, new_tile, energy
1253 // /// config : CriticalStateConfig, optional
1254 // /// Configuration for the search (uses defaults if not provided)
1255 // ///
1256 // /// Returns
1257 // /// -------
1258 // /// CriticalStateResult | None
1259 // /// The first state above threshold (following the last subcritical state),
1260 // /// or None if no transition is found.
1261 #[pyo3(name = "find_last_critical_state", signature = (end_state, config=CriticalStateConfig::default()))]
1262 pub fn py_find_last_critical_state(
1263 &mut self,
1264 end_state: &PyState,
1265 config: CriticalStateConfig,
1266 py: Python<'_>,
1267 ) -> PyResult<Option<CriticalStateResult>> {
1268 py.detach(|| {
1269 self.find_last_critical_state(&end_state.0, &config)
1270 })
1271 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
1272 }
1273 }
1274 };
1275}
1276
// Instantiate `create_py_system!` once for every model type exposed to Python.
// KBlock is the only invocation that passes a second argument: a closure that
// shifts the tile value right by 4 bits and casts it to usize.
// NOTE(review): the macro arm that consumes this closure is outside this
// excerpt; presumably it maps a KBlock tile value to its base tile index —
// confirm against the macro definition.
1277create_py_system!(KTAM);
1278create_py_system!(ATAM);
1279create_py_system!(OldKTAM);
1280create_py_system!(SDC);
1281create_py_system!(KBlock, |tile: u32| (tile >> 4) as usize);
1282create_py_system!(SDC1DBindReplace);
1283
1284#[pymethods]
1285impl KBlock {
1286 #[getter]
1287 fn get_seed(&self) -> HashMap<(usize, usize), u32> {
1288 self.seed
1289 .clone()
1290 .into_iter()
1291 .map(|(k, v)| (k.0, v.into()))
1292 .collect()
1293 }
1294
1295 #[setter]
1296 fn set_seed(&mut self, seed: HashMap<(usize, usize), u32>) {
1297 self.seed = seed
1298 .into_iter()
1299 .map(|(k, v)| {
1300 (
1301 PointSafe2(k),
1302 crate::models::kblock::TileType(v as usize).unblocked(),
1303 )
1304 })
1305 .collect();
1306 }
1307
1308 #[getter]
1309 fn get_glue_links<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f64>> {
1310 self.glue_links.mapv(|x| x.into()).to_pyarray(py)
1311 }
1312
1313 #[setter]
1314 fn set_glue_links(&mut self, glue_links: &Bound<PyArray2<f64>>) {
1315 self.glue_links = glue_links.to_owned_array().mapv(|x| x.into());
1316 self.update();
1317 }
1318
1319 fn py_get_tile_raw_glues(&self, tile: u32) -> Vec<usize> {
1320 self.get_tile_raw_glues(tile.into())
1321 }
1322
1323 fn py_get_tile_uncovered_glues(&self, tile: u32) -> Vec<usize> {
1324 self.get_tile_unblocked_glues(tile.into())
1325 }
1326
1327 #[getter]
1328 fn get_cover_concentrations(&self) -> Vec<f64> {
1329 self.blocker_concentrations
1330 .clone()
1331 .into_iter()
1332 .map(|x| x.into())
1333 .collect()
1334 }
1335
1336 #[setter]
1337 fn set_cover_concentrations(&mut self, cover_concentrations: Vec<f64>) {
1338 self.blocker_concentrations = cover_concentrations.into_iter().map(|x| x.into()).collect();
1339 self.update();
1340 }
1341}