1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use epics_base_rs::client::CaChannel;
6use tokio::sync::{Mutex as TokioMutex, Notify};
7
8use crate::channel::Channel;
9use crate::channel_store::ChannelStore;
10use crate::error::{PvOpResult, PvStat, SeqError, SeqResult};
11use crate::event_flag::EventFlagSet;
12use crate::variables::ProgramVars;
13
/// Completion mode requested for a `pv_get` / `pv_put` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompType {
    /// No explicit mode: gets are dispatched like `Sync`, puts take the
    /// path that awaits the write without a timeout.
    Default,
    /// Wait for the operation, bounded by a fixed timeout.
    Sync,
    /// Start the operation in a background task; completion is polled later.
    Async,
}
24
25struct AsyncOpSlot {
27 pending: bool,
28 completed: bool,
29 result: Option<PvOpResult>,
30}
31
32impl AsyncOpSlot {
33 fn new() -> Self {
34 Self {
35 pending: false,
36 completed: false,
37 result: None,
38 }
39 }
40
41 fn reset(&mut self) {
42 self.pending = false;
43 self.completed = false;
44 self.result = None;
45 }
46}
47
48pub struct StateSetContext<V: ProgramVars> {
54 pub local_vars: V,
56 pub ss_id: usize,
58 current_state: usize,
60 next_state: Option<usize>,
62 prev_state: Option<usize>,
64 time_entered: Instant,
66 next_wakeup: Option<Duration>,
68 dirty: Arc<Vec<AtomicBool>>,
70 wakeup: Arc<Notify>,
72 store: Arc<ChannelStore>,
74 channels: Arc<Vec<Channel>>,
76 event_flags: Arc<EventFlagSet>,
78 shutdown: Arc<AtomicBool>,
80 get_slots: Vec<Arc<TokioMutex<AsyncOpSlot>>>,
82 put_slots: Vec<Arc<TokioMutex<AsyncOpSlot>>>,
84 last_op_result: Vec<PvOpResult>,
86}
87
88impl<V: ProgramVars> StateSetContext<V> {
89 pub fn new(
90 initial_vars: V,
91 ss_id: usize,
92 num_channels: usize,
93 wakeup: Arc<Notify>,
94 store: Arc<ChannelStore>,
95 channels: Arc<Vec<Channel>>,
96 event_flags: Arc<EventFlagSet>,
97 shutdown: Arc<AtomicBool>,
98 ) -> Self {
99 let dirty: Vec<AtomicBool> = (0..num_channels).map(|_| AtomicBool::new(false)).collect();
100 let get_slots = (0..num_channels)
101 .map(|_| Arc::new(TokioMutex::new(AsyncOpSlot::new())))
102 .collect();
103 let put_slots = (0..num_channels)
104 .map(|_| Arc::new(TokioMutex::new(AsyncOpSlot::new())))
105 .collect();
106 let last_op_result = (0..num_channels).map(|_| PvOpResult::default()).collect();
107
108 Self {
109 local_vars: initial_vars,
110 ss_id,
111 current_state: 0,
112 next_state: None,
113 prev_state: None,
114 time_entered: Instant::now(),
115 next_wakeup: None,
116 dirty: Arc::new(dirty),
117 wakeup,
118 store,
119 channels,
120 event_flags,
121 shutdown,
122 get_slots,
123 put_slots,
124 last_op_result,
125 }
126 }
127
128 pub fn dirty_flags(&self) -> Arc<Vec<AtomicBool>> {
130 self.dirty.clone()
131 }
132
133 pub fn current_state(&self) -> usize {
137 self.current_state
138 }
139
140 pub fn transition_to(&mut self, state: usize) {
143 self.next_state = Some(state);
144 }
145
146 pub fn has_transition(&self) -> bool {
148 self.next_state.is_some()
149 }
150
151 pub fn is_shutdown(&self) -> bool {
153 self.shutdown.load(Ordering::Acquire)
154 }
155
156 pub fn delay(&mut self, seconds: f64) -> bool {
164 let target = Duration::from_secs_f64(seconds);
165 let elapsed = self.time_entered.elapsed();
166 if elapsed >= target {
167 true
168 } else {
169 let remaining = target - elapsed;
170 self.next_wakeup = Some(match self.next_wakeup {
171 Some(current) => current.min(remaining),
172 None => remaining,
173 });
174 false
175 }
176 }
177
178 pub async fn pv_get(&mut self, ch_id: usize, comp: CompType) -> PvStat {
185 match comp {
186 CompType::Async => self.pv_get_async(ch_id).await,
187 CompType::Default | CompType::Sync => self.pv_get_sync(ch_id).await,
188 }
189 }
190
191 async fn pv_get_sync(&mut self, ch_id: usize) -> PvStat {
192 let ca_ch = match self.get_ca_channel(ch_id) {
193 Ok(ch) => ch,
194 Err(_) => {
195 let result = PvOpResult {
196 stat: PvStat::Disconnected,
197 severity: 3,
198 message: Some("channel not connected".into()),
199 };
200 self.update_last_op_result(ch_id, result);
201 return PvStat::Disconnected;
202 }
203 };
204
205 let timeout = tokio::time::timeout(Duration::from_secs(5), ca_ch.get()).await;
206 match timeout {
207 Ok(Ok((_dbr, value))) => {
208 self.store.set(ch_id, value.clone());
209 self.local_vars.set_channel_value(ch_id, &value);
210 let result = PvOpResult::default();
211 self.update_last_op_result(ch_id, result);
212 PvStat::Ok
213 }
214 Ok(Err(e)) => {
215 let result = PvOpResult {
216 stat: PvStat::Error,
217 severity: 3,
218 message: Some(format!("{e}")),
219 };
220 self.update_last_op_result(ch_id, result);
221 PvStat::Error
222 }
223 Err(_) => {
224 let result = PvOpResult {
225 stat: PvStat::Timeout,
226 severity: 3,
227 message: Some("pvGet timeout (5s)".into()),
228 };
229 self.update_last_op_result(ch_id, result);
230 PvStat::Timeout
231 }
232 }
233 }
234
235 async fn pv_get_async(&mut self, ch_id: usize) -> PvStat {
236 let slot = match self.get_slots.get(ch_id) {
237 Some(s) => s.clone(),
238 None => return PvStat::Error,
239 };
240
241 {
242 let mut s = slot.lock().await;
243 if s.pending {
244 return PvStat::Error; }
246 s.pending = true;
247 s.completed = false;
248 s.result = None;
249 }
250
251 let ca_ch = match self.get_ca_channel(ch_id) {
252 Ok(ch) => ch.clone(),
253 Err(_) => {
254 let mut s = slot.lock().await;
255 s.pending = false;
256 s.completed = true;
257 s.result = Some(PvOpResult {
258 stat: PvStat::Disconnected,
259 severity: 3,
260 message: Some("channel not connected".into()),
261 });
262 return PvStat::Disconnected;
263 }
264 };
265
266 let store = self.store.clone();
267 let wakeup = self.wakeup.clone();
268
269 tokio::spawn(async move {
270 let result = ca_ch.get().await;
271 let mut s = slot.lock().await;
272 if !s.pending {
273 return; }
275 match result {
276 Ok((_dbr, value)) => {
277 store.set(ch_id, value);
278 s.result = Some(PvOpResult::default());
279 }
280 Err(e) => {
281 s.result = Some(PvOpResult {
282 stat: PvStat::Error,
283 severity: 3,
284 message: Some(format!("{e}")),
285 });
286 }
287 }
288 s.pending = false;
289 s.completed = true;
290 wakeup.notify_one();
291 });
292
293 PvStat::Ok
294 }
295
296 pub async fn pv_put(&mut self, ch_id: usize, comp: CompType) -> PvStat {
302 match comp {
303 CompType::Async => self.pv_put_async(ch_id).await,
304 CompType::Sync => self.pv_put_sync(ch_id).await,
305 CompType::Default => self.pv_put_default(ch_id).await,
306 }
307 }
308
309 async fn pv_put_default(&mut self, ch_id: usize) -> PvStat {
310 let value = self.local_vars.get_channel_value(ch_id);
311 self.store.set(ch_id, value.clone());
312
313 let ca_ch = match self.get_ca_channel(ch_id) {
314 Ok(ch) => ch,
315 Err(_) => {
316 let result = PvOpResult {
317 stat: PvStat::Disconnected,
318 severity: 3,
319 message: Some("channel not connected".into()),
320 };
321 self.update_last_op_result(ch_id, result);
322 return PvStat::Disconnected;
323 }
324 };
325
326 match ca_ch.put(&value).await {
327 Ok(()) => {
328 let result = PvOpResult::default();
329 self.update_last_op_result(ch_id, result);
330 PvStat::Ok
331 }
332 Err(e) => {
333 let result = PvOpResult {
334 stat: PvStat::Error,
335 severity: 3,
336 message: Some(format!("{e}")),
337 };
338 self.update_last_op_result(ch_id, result);
339 PvStat::Error
340 }
341 }
342 }
343
344 async fn pv_put_sync(&mut self, ch_id: usize) -> PvStat {
345 let value = self.local_vars.get_channel_value(ch_id);
346 self.store.set(ch_id, value.clone());
347
348 let ca_ch = match self.get_ca_channel(ch_id) {
349 Ok(ch) => ch,
350 Err(_) => {
351 let result = PvOpResult {
352 stat: PvStat::Disconnected,
353 severity: 3,
354 message: Some("channel not connected".into()),
355 };
356 self.update_last_op_result(ch_id, result);
357 return PvStat::Disconnected;
358 }
359 };
360
361 let timeout = tokio::time::timeout(Duration::from_secs(5), ca_ch.put(&value)).await;
362 match timeout {
363 Ok(Ok(())) => {
364 let result = PvOpResult::default();
365 self.update_last_op_result(ch_id, result);
366 PvStat::Ok
367 }
368 Ok(Err(e)) => {
369 let result = PvOpResult {
370 stat: PvStat::Error,
371 severity: 3,
372 message: Some(format!("{e}")),
373 };
374 self.update_last_op_result(ch_id, result);
375 PvStat::Error
376 }
377 Err(_) => {
378 let result = PvOpResult {
379 stat: PvStat::Timeout,
380 severity: 3,
381 message: Some("pvPut timeout (5s)".into()),
382 };
383 self.update_last_op_result(ch_id, result);
384 PvStat::Timeout
385 }
386 }
387 }
388
389 async fn pv_put_async(&mut self, ch_id: usize) -> PvStat {
390 let slot = match self.put_slots.get(ch_id) {
391 Some(s) => s.clone(),
392 None => return PvStat::Error,
393 };
394
395 let value = self.local_vars.get_channel_value(ch_id);
396 self.store.set(ch_id, value.clone());
397
398 {
399 let mut s = slot.lock().await;
400 if s.pending {
401 return PvStat::Error; }
403 s.pending = true;
404 s.completed = false;
405 s.result = None;
406 }
407
408 let ca_ch = match self.get_ca_channel(ch_id) {
409 Ok(ch) => ch.clone(),
410 Err(_) => {
411 let mut s = slot.lock().await;
412 s.pending = false;
413 s.completed = true;
414 s.result = Some(PvOpResult {
415 stat: PvStat::Disconnected,
416 severity: 3,
417 message: Some("channel not connected".into()),
418 });
419 return PvStat::Disconnected;
420 }
421 };
422
423 let wakeup = self.wakeup.clone();
424
425 tokio::spawn(async move {
426 let result = ca_ch.put(&value).await;
427 let mut s = slot.lock().await;
428 if !s.pending {
429 return; }
431 match result {
432 Ok(()) => {
433 s.result = Some(PvOpResult::default());
434 }
435 Err(e) => {
436 s.result = Some(PvOpResult {
437 stat: PvStat::Error,
438 severity: 3,
439 message: Some(format!("{e}")),
440 });
441 }
442 }
443 s.pending = false;
444 s.completed = true;
445 wakeup.notify_one();
446 });
447
448 PvStat::Ok
449 }
450
451 pub async fn pv_get_complete(&mut self, ch_id: usize) -> bool {
454 let slot = match self.get_slots.get(ch_id) {
455 Some(s) => s.clone(),
456 None => return true, };
458
459 let s = slot.lock().await;
460 if s.pending {
461 return false;
462 }
463 if s.completed {
464 if let Some(ref r) = s.result {
466 self.update_last_op_result(ch_id, r.clone());
467 }
468 if let Some(value) = self.store.get(ch_id) {
470 self.local_vars.set_channel_value(ch_id, &value);
471 }
472 return true;
473 }
474 true }
476
477 pub async fn pv_put_complete(&mut self, ch_id: usize) -> bool {
479 let slot = match self.put_slots.get(ch_id) {
480 Some(s) => s.clone(),
481 None => return true,
482 };
483
484 let s = slot.lock().await;
485 if s.pending {
486 return false;
487 }
488 if s.completed {
489 if let Some(ref r) = s.result {
490 self.update_last_op_result(ch_id, r.clone());
491 }
492 return true;
493 }
494 true
495 }
496
497 pub async fn pv_get_cancel(&mut self, ch_id: usize) {
499 if let Some(slot) = self.get_slots.get(ch_id) {
500 let mut s = slot.lock().await;
501 s.reset();
502 }
503 }
504
505 pub async fn pv_put_cancel(&mut self, ch_id: usize) {
507 if let Some(slot) = self.put_slots.get(ch_id) {
508 let mut s = slot.lock().await;
509 s.reset();
510 }
511 }
512
513 pub fn pv_status(&self, ch_id: usize) -> PvStat {
515 self.last_op_result
516 .get(ch_id)
517 .map_or(PvStat::Ok, |r| r.stat)
518 }
519
520 pub fn pv_severity(&self, ch_id: usize) -> i16 {
522 self.last_op_result
523 .get(ch_id)
524 .map_or(0, |r| r.severity)
525 }
526
527 pub fn pv_message(&self, ch_id: usize) -> Option<&str> {
529 self.last_op_result
530 .get(ch_id)
531 .and_then(|r| r.message.as_deref())
532 }
533
534 fn update_last_op_result(&mut self, ch_id: usize, result: PvOpResult) {
535 if ch_id < self.last_op_result.len() {
536 self.last_op_result[ch_id] = result;
537 }
538 }
539
540 fn get_ca_channel(&self, ch_id: usize) -> SeqResult<&CaChannel> {
541 let channel = self
542 .channels
543 .get(ch_id)
544 .ok_or(SeqError::InvalidChannelId(ch_id))?;
545 channel
546 .ca_channel()
547 .ok_or_else(|| SeqError::NotConnected(channel.def.pv_name.clone()))
548 }
549
550 pub fn ef_set(&self, ef_id: usize) {
554 self.event_flags.set(ef_id);
555 }
556
557 pub fn ef_test(&mut self, ef_id: usize) -> bool {
560 let result = self.event_flags.test(ef_id);
561 if result {
562 self.sync_channels_for_flag(ef_id);
563 }
564 result
565 }
566
567 pub fn ef_clear(&self, ef_id: usize) -> bool {
569 self.event_flags.clear(ef_id)
570 }
571
572 pub fn ef_test_and_clear(&mut self, ef_id: usize) -> bool {
575 let was_set = self.event_flags.test_and_clear(ef_id);
576 if was_set {
577 self.sync_channels_for_flag(ef_id);
578 }
579 was_set
580 }
581
582 fn sync_channels_for_flag(&mut self, ef_id: usize) {
584 let ch_ids = self.event_flags.synced_channels(ef_id).to_vec();
585 for ch_id in ch_ids {
586 if let Some(value) = self.store.get(ch_id) {
587 self.local_vars.set_channel_value(ch_id, &value);
588 }
589 }
590 }
591
592 pub fn pv_connected(&self, ch_id: usize) -> bool {
596 self.channels
597 .get(ch_id)
598 .map_or(false, |ch| ch.is_connected())
599 }
600
601 pub fn pv_connect_count(&self) -> usize {
603 self.channels.iter().filter(|ch| ch.is_connected()).count()
604 }
605
606 pub fn pv_channel_count(&self) -> usize {
608 self.channels.len()
609 }
610
611 pub fn sync_dirty_vars(&mut self) {
616 for ch_id in 0..self.dirty.len() {
617 if let Some(flag) = self.dirty.get(ch_id) {
618 if flag.swap(false, Ordering::AcqRel) {
619 if let Some(value) = self.store.get(ch_id) {
620 self.local_vars.set_channel_value(ch_id, &value);
621 }
622 }
623 }
624 }
625 }
626
627 pub fn reset_wakeup(&mut self) {
631 self.next_wakeup = None;
632 }
633
634 pub async fn wait_for_wakeup(&self) {
636 match self.next_wakeup {
637 Some(timeout) => {
638 tokio::select! {
639 _ = self.wakeup.notified() => {}
640 _ = tokio::time::sleep(timeout) => {}
641 }
642 }
643 None => {
644 self.wakeup.notified().await;
646 }
647 }
648 }
649
650 pub fn enter_state(&mut self, state: usize) {
652 self.prev_state = if state != self.current_state {
653 Some(self.current_state)
654 } else {
655 self.prev_state
656 };
657 self.current_state = state;
658 self.next_state = None;
659 self.time_entered = Instant::now();
660 }
661
662 pub fn should_run_entry(&self) -> bool {
664 self.prev_state.map_or(true, |prev| prev != self.current_state)
665 }
666
667 pub fn should_run_exit(&self) -> bool {
669 self.next_state.map_or(false, |next| next != self.current_state)
670 }
671
672 pub fn take_transition(&mut self) -> Option<usize> {
674 self.next_state.take()
675 }
676
677 pub fn wakeup(&self) -> &Arc<Notify> {
679 &self.wakeup
680 }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::Channel;
    use crate::channel_store::ChannelStore;
    use crate::event_flag::EventFlagSet;
    use crate::variables::ProgramVars;
    use epics_base_rs::types::EpicsValue;

    /// Program variables backed by one `f64` per channel.
    #[derive(Clone)]
    struct TestVars {
        values: Vec<f64>,
    }

    impl ProgramVars for TestVars {
        fn get_channel_value(&self, ch_id: usize) -> EpicsValue {
            let raw = match self.values.get(ch_id) {
                Some(stored) => *stored,
                None => 0.0,
            };
            EpicsValue::Double(raw)
        }

        fn set_channel_value(&mut self, ch_id: usize, value: &EpicsValue) {
            // Non-numeric values and channels without an entry are ignored.
            if let (Some(number), Some(entry)) = (value.to_f64(), self.values.get_mut(ch_id)) {
                *entry = number;
            }
        }
    }

    /// Builds a context with `num_channels` variables, an empty channel table
    /// and a single event flag synced to channel 0.
    fn make_ctx(num_channels: usize) -> StateSetContext<TestVars> {
        let notifier = Arc::new(Notify::new());
        let flags = Arc::new(EventFlagSet::new(
            1,
            vec![vec![0]],
            vec![notifier.clone()],
        ));
        StateSetContext::new(
            TestVars {
                values: vec![0.0; num_channels],
            },
            0,
            num_channels,
            notifier,
            Arc::new(ChannelStore::new(num_channels)),
            Arc::new(Vec::<Channel>::new()),
            flags,
            Arc::new(AtomicBool::new(false)),
        )
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_delay_not_elapsed() {
        let mut ctx = make_ctx(0);
        let elapsed = ctx.delay(10.0);
        assert!(!elapsed);
        assert!(ctx.next_wakeup.is_some());
    }

    #[test]
    fn test_delay_elapsed() {
        let mut ctx = make_ctx(0);
        // Pretend the state was entered five seconds ago.
        ctx.time_entered = Instant::now() - std::time::Duration::from_secs(5);
        assert!(ctx.delay(3.0));
    }

    #[test]
    fn test_state_transitions() {
        let mut ctx = make_ctx(0);

        // First entry into the initial state.
        ctx.enter_state(0);
        assert_eq!(ctx.current_state(), 0);
        assert!(ctx.should_run_entry());

        // Request a move to a different state.
        ctx.transition_to(1);
        assert!(ctx.has_transition());
        assert!(ctx.should_run_exit());

        // Take the transition and enter the target.
        let target = ctx.take_transition().unwrap();
        assert_eq!(target, 1);
        ctx.enter_state(target);
        assert_eq!(ctx.current_state(), 1);
        assert!(ctx.should_run_entry());
    }

    #[test]
    fn test_self_transition_no_entry_exit() {
        let mut ctx = make_ctx(0);
        ctx.enter_state(0);
        ctx.transition_to(0);
        // A transition back to the same state must not trigger exit.
        assert!(!ctx.should_run_exit());
    }

    #[test]
    fn test_sync_dirty_vars() {
        let mut ctx = make_ctx(2);
        ctx.store.set(0, EpicsValue::Double(42.0));
        ctx.dirty.get(0).unwrap().store(true, Ordering::Release);

        // Nothing is copied until the sync runs.
        assert_close(ctx.local_vars.values[0], 0.0);
        ctx.sync_dirty_vars();
        assert_close(ctx.local_vars.values[0], 42.0);
        // The channel that was not flagged stays untouched.
        assert_close(ctx.local_vars.values[1], 0.0);
    }

    #[test]
    fn test_ef_set_and_test() {
        let mut ctx = make_ctx(1);
        assert!(!ctx.ef_test(0));
        ctx.ef_set(0);
        assert!(ctx.ef_test(0));
    }

    #[test]
    fn test_ef_test_and_clear() {
        let mut ctx = make_ctx(1);
        ctx.ef_set(0);
        ctx.store.set(0, EpicsValue::Double(99.0));

        assert!(ctx.ef_test_and_clear(0));
        // The flag is gone afterwards...
        assert!(!ctx.ef_test(0));
        // ...and the synced channel value reached the local variable.
        assert_close(ctx.local_vars.values[0], 99.0);
    }

    #[test]
    fn test_shutdown() {
        let ctx = make_ctx(0);
        assert!(!ctx.is_shutdown());
        ctx.shutdown.store(true, Ordering::Release);
        assert!(ctx.is_shutdown());
    }

    #[test]
    fn test_pv_status_default() {
        let ctx = make_ctx(2);
        assert_eq!(ctx.pv_status(0), PvStat::Ok);
        assert_eq!(ctx.pv_severity(0), 0);
        assert_eq!(ctx.pv_message(0), None);
    }

    #[test]
    fn test_pv_status_after_update() {
        let mut ctx = make_ctx(2);
        let recorded = PvOpResult {
            stat: PvStat::Timeout,
            severity: 3,
            message: Some("timeout".into()),
        };
        ctx.update_last_op_result(0, recorded);

        assert_eq!(ctx.pv_status(0), PvStat::Timeout);
        assert_eq!(ctx.pv_severity(0), 3);
        assert_eq!(ctx.pv_message(0), Some("timeout"));
        // The other channel keeps its default status.
        assert_eq!(ctx.pv_status(1), PvStat::Ok);
    }

    #[test]
    fn test_pv_status_invalid_channel() {
        let ctx = make_ctx(1);
        assert_eq!(ctx.pv_status(99), PvStat::Ok);
        assert_eq!(ctx.pv_severity(99), 0);
        assert_eq!(ctx.pv_message(99), None);
    }

    #[tokio::test]
    async fn test_pv_get_disconnected() {
        let mut ctx = make_ctx(1);
        // The channel table is empty, so the get cannot reach a channel.
        let stat = ctx.pv_get(0, CompType::Sync).await;
        assert_eq!(stat, PvStat::Disconnected);
        assert_eq!(ctx.pv_status(0), PvStat::Disconnected);
    }

    #[tokio::test]
    async fn test_pv_put_disconnected() {
        let mut ctx = make_ctx(1);
        let stat = ctx.pv_put(0, CompType::Default).await;
        assert_eq!(stat, PvStat::Disconnected);
    }

    #[tokio::test]
    async fn test_async_get_complete_no_channel() {
        let mut ctx = make_ctx(1);
        let stat = ctx.pv_get(0, CompType::Async).await;
        assert_eq!(stat, PvStat::Disconnected);
        // The slot was completed on the spot, so polling reports completion.
        assert!(ctx.pv_get_complete(0).await);
    }

    #[tokio::test]
    async fn test_async_get_cancel() {
        let mut ctx = make_ctx(1);
        ctx.pv_get_cancel(0).await;
        // An idle slot reports completion.
        assert!(ctx.pv_get_complete(0).await);
    }

    #[tokio::test]
    async fn test_async_put_complete_no_channel() {
        let mut ctx = make_ctx(1);
        let stat = ctx.pv_put(0, CompType::Async).await;
        assert_eq!(stat, PvStat::Disconnected);
        assert!(ctx.pv_put_complete(0).await);
    }
}