Skip to main content

action_core/client_server/
client.rs

1use std::marker::PhantomData;
2
3use core_types::{ActionGoalId, RtDuration, Timestamp};
4
5use crate::message::{ActionLiveness, ActionResult, ActionSchema, ActionSessionHealth, GoalStatus};
6
7use super::ActionChannel;
8
9mod trait_impl;
10
11pub struct BasicActionClient<G, F, R> {
12    action_name: String,
13    schema: ActionSchema,
14    closed: bool,
15    next_goal_id: u64,
16    heartbeat_timeout: Option<RtDuration>,
17    heartbeat_stalled_after: Option<RtDuration>,
18    channel: Option<std::sync::Arc<ActionChannel<G, F, R>>>,
19    _marker: PhantomData<(G, F, R)>,
20}
21
22impl<G, F, R> BasicActionClient<G, F, R> {
23    pub fn new(action_name: impl Into<String>, schema: ActionSchema) -> Self {
24        Self {
25            action_name: action_name.into(),
26            schema,
27            closed: false,
28            next_goal_id: 0,
29            heartbeat_timeout: Some(RtDuration::from_secs(5)),
30            heartbeat_stalled_after: Some(RtDuration::from_secs(2)),
31            channel: None,
32            _marker: PhantomData,
33        }
34    }
35
36    pub fn with_channel(
37        action_name: impl Into<String>,
38        schema: ActionSchema,
39        channel: std::sync::Arc<ActionChannel<G, F, R>>,
40    ) -> Self {
41        Self {
42            action_name: action_name.into(),
43            schema,
44            closed: false,
45            next_goal_id: 0,
46            heartbeat_timeout: Some(RtDuration::from_secs(5)),
47            heartbeat_stalled_after: Some(RtDuration::from_secs(2)),
48            channel: Some(channel),
49            _marker: PhantomData,
50        }
51    }
52
53    pub fn with_heartbeat_timeout(mut self, timeout: Option<RtDuration>) -> Self {
54        self.heartbeat_timeout = timeout;
55        self
56    }
57
58    pub fn with_stalled_threshold(mut self, stalled_after: Option<RtDuration>) -> Self {
59        self.heartbeat_stalled_after = stalled_after;
60        self
61    }
62
63    fn goal_health_at(&self, goal_id: ActionGoalId, now: Timestamp) -> Option<ActionSessionHealth> {
64        let ch = self.channel.as_ref()?;
65
66        let status = ch.statuses.lock().unwrap().get(&goal_id.0).copied()?;
67        let last_heartbeat = ch.heartbeats.lock().unwrap().get(&goal_id.0).copied();
68        let last_feedback = ch
69            .feedback_timestamps
70            .lock()
71            .unwrap()
72            .get(&goal_id.0)
73            .copied();
74        let last_result = ch
75            .result_timestamps
76            .lock()
77            .unwrap()
78            .get(&goal_id.0)
79            .copied();
80
81        let timeout_nanos = self.heartbeat_timeout.map(|value| value.as_nanos());
82        let stalled_nanos = self.heartbeat_stalled_after.map(|value| value.as_nanos());
83
84        let liveness = if status.is_terminal() {
85            ActionLiveness::Completed
86        } else if let (Some(last), Some(timeout)) = (last_heartbeat, timeout_nanos) {
87            let elapsed = now.0.saturating_sub(last.0);
88            if elapsed >= timeout {
89                ActionLiveness::TimedOut
90            } else if let Some(stalled_after) = stalled_nanos {
91                if elapsed >= stalled_after {
92                    ActionLiveness::Stalled
93                } else {
94                    ActionLiveness::Active
95                }
96            } else {
97                ActionLiveness::Active
98            }
99        } else {
100            ActionLiveness::Unknown
101        };
102
103        Some(ActionSessionHealth {
104            goal_id,
105            status,
106            liveness,
107            heartbeat_timeout_nanos: timeout_nanos,
108            stalled_threshold_nanos: stalled_nanos,
109            last_heartbeat_at_unix_nanos: last_heartbeat.map(|value| value.0),
110            last_feedback_at_unix_nanos: last_feedback.map(|value| value.0),
111            last_result_at_unix_nanos: last_result.map(|value| value.0),
112        })
113    }
114
115    fn apply_heartbeat_timeouts_at(&self, now: Timestamp) -> usize {
116        let Some(timeout) = self.heartbeat_timeout else {
117            return 0;
118        };
119        let Some(ch) = &self.channel else {
120            return 0;
121        };
122
123        let timeout_nanos = timeout.as_nanos();
124        let expired_ids: Vec<u64> = {
125            let statuses = ch.statuses.lock().unwrap();
126            let heartbeats = ch.heartbeats.lock().unwrap();
127            statuses
128                .iter()
129                .filter_map(|(goal_id, status)| {
130                    if !matches!(status, GoalStatus::Executing | GoalStatus::Canceling) {
131                        return None;
132                    }
133                    let last = heartbeats.get(goal_id).copied().unwrap_or(Timestamp(0));
134                    let elapsed = now.0.saturating_sub(last.0);
135                    if elapsed >= timeout_nanos {
136                        Some(*goal_id)
137                    } else {
138                        None
139                    }
140                })
141                .collect()
142        };
143
144        if expired_ids.is_empty() {
145            return 0;
146        }
147
148        let mut statuses = ch.statuses.lock().unwrap();
149        let mut results = ch.results.lock().unwrap();
150        let mut heartbeats = ch.heartbeats.lock().unwrap();
151        let mut transitioned = 0usize;
152
153        for goal_id in expired_ids {
154            let Some(status) = statuses.get_mut(&goal_id) else {
155                continue;
156            };
157            if !matches!(*status, GoalStatus::Executing | GoalStatus::Canceling) {
158                continue;
159            }
160
161            *status = GoalStatus::Failed;
162            results.entry(goal_id).or_insert(ActionResult {
163                status: GoalStatus::Failed,
164                value: None,
165                error: Some("action heartbeat timeout".to_string()),
166            });
167            heartbeats.remove(&goal_id);
168            ch.result_timestamps.lock().unwrap().insert(goal_id, now);
169            transitioned += 1;
170        }
171
172        transitioned
173    }
174}