1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::AtomicBool;
5use std::sync::Arc;
6
7use epics_base_rs::client::CaClient;
8use tokio::sync::Notify;
9
10use crate::channel::Channel;
11use crate::channel_store::ChannelStore;
12use crate::error::SeqResult;
13use crate::event_flag::EventFlagSet;
14use crate::state_set::StateSetContext;
15use crate::variables::{ProgramMeta, ProgramVars};
16
17pub struct ProgramShared<V: ProgramVars> {
19 pub store: Arc<ChannelStore>,
20 pub channels: Arc<Vec<Channel>>,
21 pub event_flags: Arc<EventFlagSet>,
22 pub shutdown: Arc<AtomicBool>,
23 pub ss_wakeups: Vec<Arc<Notify>>,
24 pub _phantom: std::marker::PhantomData<V>,
25}
26
27pub type StateSetFn<V> =
32 Box<dyn Fn(StateSetContext<V>) -> Pin<Box<dyn Future<Output = SeqResult<()>> + Send>> + Send + Sync>;
33
34pub struct ProgramBuilder<V: ProgramVars, M: ProgramMeta> {
36 pub name: String,
37 pub initial_vars: V,
38 pub macros: HashMap<String, String>,
39 pub state_set_fns: Vec<StateSetFn<V>>,
40 _meta: std::marker::PhantomData<M>,
41}
42
43impl<V: ProgramVars, M: ProgramMeta> ProgramBuilder<V, M> {
44 pub fn new(name: &str, initial_vars: V) -> Self {
45 Self {
46 name: name.to_string(),
47 initial_vars,
48 macros: HashMap::new(),
49 state_set_fns: Vec::new(),
50 _meta: std::marker::PhantomData,
51 }
52 }
53
54 pub fn macros(mut self, macro_str: &str) -> Self {
55 self.macros = crate::macros::parse_macros(macro_str);
56 self
57 }
58
59 pub fn add_ss(mut self, f: StateSetFn<V>) -> Self {
60 self.state_set_fns.push(f);
61 self
62 }
63
64 pub async fn run(self) -> SeqResult<()> {
66 let num_channels = M::NUM_CHANNELS;
67 let num_flags = M::NUM_EVENT_FLAGS;
68 let num_ss = self.state_set_fns.len();
69
70 tracing::info!("starting program '{}' with {} state sets, {} channels, {} event flags",
71 self.name, num_ss, num_channels, num_flags);
72
73 let ca_client = CaClient::new()
75 .await
76 .map_err(|e| crate::error::SeqError::Other(format!("CA init failed: {e}")))?;
77
78 let store = Arc::new(ChannelStore::new(num_channels));
80 let shutdown = Arc::new(AtomicBool::new(false));
81
82 let ss_wakeups: Vec<Arc<Notify>> = (0..num_ss).map(|_| Arc::new(Notify::new())).collect();
84
85 let sync_map = M::event_flag_sync_map();
87 let event_flags = Arc::new(EventFlagSet::new(num_flags, sync_map, ss_wakeups.clone()));
88
89 let dirty_per_ss: Vec<Arc<Vec<AtomicBool>>> = (0..num_ss)
91 .map(|_| {
92 Arc::new(
93 (0..num_channels)
94 .map(|_| AtomicBool::new(false))
95 .collect(),
96 )
97 })
98 .collect();
99
100 let channel_defs = M::channel_defs();
102 let mut channels: Vec<Channel> = channel_defs
103 .into_iter()
104 .enumerate()
105 .map(|(id, def)| Channel::new(def, id))
106 .collect();
107
108 for ch in &mut channels {
109 ch.connect(
110 &ca_client,
111 &self.macros,
112 store.clone(),
113 dirty_per_ss.clone(),
114 ss_wakeups.clone(),
115 Some(event_flags.clone()),
116 )
117 .await;
118 }
119
120 let channels = Arc::new(channels);
121
122 let mut handles = Vec::new();
124 for (ss_id, ss_fn) in self.state_set_fns.into_iter().enumerate() {
125 let ctx = StateSetContext::new(
126 self.initial_vars.clone(),
127 ss_id,
128 num_channels,
129 ss_wakeups[ss_id].clone(),
130 store.clone(),
131 channels.clone(),
132 event_flags.clone(),
133 shutdown.clone(),
134 );
135
136 let handle = tokio::spawn(async move {
137 if let Err(e) = ss_fn(ctx).await {
138 tracing::error!("state set {ss_id} error: {e}");
139 }
140 });
141 handles.push(handle);
142 }
143
144 for handle in handles {
146 let _ = handle.await;
147 }
148
149 tracing::info!("program '{}' finished", self.name);
150 Ok(())
151 }
152}