// File: epics_seq/program.rs

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;

use epics_base_rs::client::CaClient;
use tokio::sync::Notify;

use crate::channel::Channel;
use crate::channel_store::ChannelStore;
use crate::error::SeqResult;
use crate::event_flag::EventFlagSet;
use crate::state_set::StateSetContext;
use crate::variables::{ProgramMeta, ProgramVars};

17/// Shared state across all state sets in a program.
18pub struct ProgramShared<V: ProgramVars> {
19    pub store: Arc<ChannelStore>,
20    pub channels: Arc<Vec<Channel>>,
21    pub event_flags: Arc<EventFlagSet>,
22    pub shutdown: Arc<AtomicBool>,
23    pub ss_wakeups: Vec<Arc<Notify>>,
24    pub _phantom: std::marker::PhantomData<V>,
25}
26
27/// Type alias for a state set function.
28///
29/// Each state set is an async function that takes a `StateSetContext`
30/// and runs the state machine loop until shutdown.
31pub type StateSetFn<V> =
32    Box<dyn Fn(StateSetContext<V>) -> Pin<Box<dyn Future<Output = SeqResult<()>> + Send>> + Send + Sync>;
33
34/// Builder for constructing and running a sequencer program.
35pub struct ProgramBuilder<V: ProgramVars, M: ProgramMeta> {
36    pub name: String,
37    pub initial_vars: V,
38    pub macros: HashMap<String, String>,
39    pub state_set_fns: Vec<StateSetFn<V>>,
40    _meta: std::marker::PhantomData<M>,
41}
42
43impl<V: ProgramVars, M: ProgramMeta> ProgramBuilder<V, M> {
44    pub fn new(name: &str, initial_vars: V) -> Self {
45        Self {
46            name: name.to_string(),
47            initial_vars,
48            macros: HashMap::new(),
49            state_set_fns: Vec::new(),
50            _meta: std::marker::PhantomData,
51        }
52    }
53
54    pub fn macros(mut self, macro_str: &str) -> Self {
55        self.macros = crate::macros::parse_macros(macro_str);
56        self
57    }
58
59    pub fn add_ss(mut self, f: StateSetFn<V>) -> Self {
60        self.state_set_fns.push(f);
61        self
62    }
63
64    /// Build and run the program. Blocks until all state sets finish or shutdown.
65    pub async fn run(self) -> SeqResult<()> {
66        let num_channels = M::NUM_CHANNELS;
67        let num_flags = M::NUM_EVENT_FLAGS;
68        let num_ss = self.state_set_fns.len();
69
70        tracing::info!("starting program '{}' with {} state sets, {} channels, {} event flags",
71            self.name, num_ss, num_channels, num_flags);
72
73        // Create CA client
74        let ca_client = CaClient::new()
75            .await
76            .map_err(|e| crate::error::SeqError::Other(format!("CA init failed: {e}")))?;
77
78        // Create shared state
79        let store = Arc::new(ChannelStore::new(num_channels));
80        let shutdown = Arc::new(AtomicBool::new(false));
81
82        // Create per-SS wakeup notifiers
83        let ss_wakeups: Vec<Arc<Notify>> = (0..num_ss).map(|_| Arc::new(Notify::new())).collect();
84
85        // Create event flag set
86        let sync_map = M::event_flag_sync_map();
87        let event_flags = Arc::new(EventFlagSet::new(num_flags, sync_map, ss_wakeups.clone()));
88
89        // Create per-SS dirty flags
90        let dirty_per_ss: Vec<Arc<Vec<AtomicBool>>> = (0..num_ss)
91            .map(|_| {
92                Arc::new(
93                    (0..num_channels)
94                        .map(|_| AtomicBool::new(false))
95                        .collect(),
96                )
97            })
98            .collect();
99
100        // Create and connect channels
101        let channel_defs = M::channel_defs();
102        let mut channels: Vec<Channel> = channel_defs
103            .into_iter()
104            .enumerate()
105            .map(|(id, def)| Channel::new(def, id))
106            .collect();
107
108        for ch in &mut channels {
109            ch.connect(
110                &ca_client,
111                &self.macros,
112                store.clone(),
113                dirty_per_ss.clone(),
114                ss_wakeups.clone(),
115                Some(event_flags.clone()),
116            )
117            .await;
118        }
119
120        let channels = Arc::new(channels);
121
122        // Spawn state set tasks
123        let mut handles = Vec::new();
124        for (ss_id, ss_fn) in self.state_set_fns.into_iter().enumerate() {
125            let ctx = StateSetContext::new(
126                self.initial_vars.clone(),
127                ss_id,
128                num_channels,
129                ss_wakeups[ss_id].clone(),
130                store.clone(),
131                channels.clone(),
132                event_flags.clone(),
133                shutdown.clone(),
134            );
135
136            let handle = tokio::spawn(async move {
137                if let Err(e) = ss_fn(ctx).await {
138                    tracing::error!("state set {ss_id} error: {e}");
139                }
140            });
141            handles.push(handle);
142        }
143
144        // Wait for all state sets to complete
145        for handle in handles {
146            let _ = handle.await;
147        }
148
149        tracing::info!("program '{}' finished", self.name);
150        Ok(())
151    }
152}