// posemesh_compute_node/session.rs

1use compute_runner_api::LeaseEnvelope;
2use rand::distributions::{Distribution, Uniform};
3use rand::Rng;
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use url::Url;
9use uuid::Uuid;
10
/// Lifecycle phase of the node's current session.
///
/// A session is created as [`SessionStatus::Pending`] and becomes
/// [`SessionStatus::Running`] once a heartbeat has been applied to it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionStatus {
    /// Set by `SessionManager::start_session`; no heartbeat applied yet.
    Pending,
    /// Set by `SessionManager::apply_heartbeat`.
    Running,
}
17
18#[derive(Debug, Clone, PartialEq)]
19pub struct SessionState {
20    pub task_id: Uuid,
21    pub job_id: Option<Uuid>,
22    pub capability: String,
23    pub meta: Value,
24    pub inputs_cids: Vec<String>,
25    pub domain_id: Option<Uuid>,
26    pub domain_server_url: Option<Url>,
27    pub lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
28    pub access_token: Option<String>,
29    pub access_token_expires_at: Option<chrono::DateTime<chrono::Utc>>,
30    pub last_progress: Option<Value>,
31    pub next_heartbeat_due: Option<Instant>,
32    pub status: SessionStatus,
33    pub cancel: bool,
34}
35
36/// Immutable snapshot of session state.
37#[derive(Debug, Clone)]
38pub struct SessionSnapshot(SessionState);
39
40impl SessionSnapshot {
41    pub fn task_id(&self) -> Uuid {
42        self.0.task_id
43    }
44
45    pub fn job_id(&self) -> Option<Uuid> {
46        self.0.job_id
47    }
48
49    pub fn capability(&self) -> &str {
50        &self.0.capability
51    }
52
53    pub fn meta(&self) -> &Value {
54        &self.0.meta
55    }
56
57    pub fn inputs_cids(&self) -> &[String] {
58        &self.0.inputs_cids
59    }
60
61    pub fn domain_id(&self) -> Option<Uuid> {
62        self.0.domain_id
63    }
64
65    pub fn domain_server_url(&self) -> Option<&Url> {
66        self.0.domain_server_url.as_ref()
67    }
68
69    pub fn access_token(&self) -> Option<&str> {
70        self.0.access_token.as_deref()
71    }
72
73    pub fn access_token_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
74        self.0.access_token_expires_at
75    }
76
77    pub fn lease_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
78        self.0.lease_expires_at
79    }
80
81    pub fn next_heartbeat_due(&self) -> Option<Instant> {
82        self.0.next_heartbeat_due
83    }
84
85    pub fn status(&self) -> SessionStatus {
86        self.0.status
87    }
88
89    pub fn cancel(&self) -> bool {
90        self.0.cancel
91    }
92}
93
/// Capabilities configured for the node, kept in configuration order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilitySelector {
    capabilities: Vec<String>,
}

impl CapabilitySelector {
    /// Wraps the configured capability list; order is preserved.
    pub fn new(capabilities: Vec<String>) -> Self {
        CapabilitySelector { capabilities }
    }

    /// Returns the first configured capability, or `None` when nothing is configured.
    pub fn choose(&self) -> Option<&str> {
        match self.capabilities.as_slice() {
            [] => None,
            [head, ..] => Some(head.as_str()),
        }
    }

    /// Reports whether `capability` exactly matches one of the configured entries.
    pub fn accepts(&self, capability: &str) -> bool {
        self.capabilities
            .iter()
            .map(String::as_str)
            .any(|known| known == capability)
    }

    /// Returns every configured capability in configuration order.
    pub fn all(&self) -> &[String] {
        self.capabilities.as_slice()
    }
}
117
118/// Distribution for randomized TTL heartbeats.
119#[derive(Debug, Clone, Copy)]
120pub struct HeartbeatPolicy {
121    pub min_ratio: f64,
122    pub max_ratio: f64,
123}
124
125impl HeartbeatPolicy {
126    pub const fn new(min_ratio: f64, max_ratio: f64) -> Self {
127        Self {
128            min_ratio,
129            max_ratio,
130        }
131    }
132
133    pub const fn default_policy() -> Self {
134        Self {
135            min_ratio: 0.55,
136            max_ratio: 0.65,
137        }
138    }
139
140    fn sample_ratio<R: Rng>(&self, rng: &mut R) -> f64 {
141        let min = self.min_ratio.min(self.max_ratio).max(0.0);
142        let max = self.max_ratio.max(self.min_ratio).max(min);
143        let dist = Uniform::new_inclusive(min, max);
144        dist.sample(rng)
145    }
146}
147
148#[derive(Debug, thiserror::Error, PartialEq, Eq)]
149pub enum SessionError {
150    #[error("lease did not include a task")]
151    MissingTask,
152    #[error("task capability `{got}` not in configured set {expected:?}")]
153    CapabilityMismatch { expected: Vec<String>, got: String },
154    #[error("no active session")]
155    NoActiveSession,
156}
157
158#[derive(Clone)]
159pub struct SessionManager {
160    selector: CapabilitySelector,
161    state: Arc<Mutex<Option<SessionState>>>,
162}
163
164impl SessionManager {
165    pub fn new(selector: CapabilitySelector) -> Self {
166        Self {
167            selector,
168            state: Arc::new(Mutex::new(None)),
169        }
170    }
171
172    pub fn selector(&self) -> &CapabilitySelector {
173        &self.selector
174    }
175
176    pub async fn snapshot(&self) -> Option<SessionSnapshot> {
177        let guard = self.state.lock().await;
178        guard.as_ref().cloned().map(SessionSnapshot)
179    }
180
181    pub async fn clear(&self) {
182        *self.state.lock().await = None;
183    }
184
185    pub async fn start_session<R: Rng>(
186        &self,
187        lease: &LeaseEnvelope,
188        now: Instant,
189        policy: &HeartbeatPolicy,
190        rng: &mut R,
191    ) -> Result<SessionSnapshot, SessionError> {
192        let task = lease.task.clone();
193        if !self.selector.accepts(&task.capability) {
194            return Err(SessionError::CapabilityMismatch {
195                expected: self.selector.all().to_vec(),
196                got: task.capability,
197            });
198        }
199
200        let mut state = SessionState {
201            task_id: task.id,
202            job_id: task.job_id,
203            capability: task.capability,
204            meta: task.meta,
205            inputs_cids: task.inputs_cids,
206            domain_id: lease.domain_id,
207            domain_server_url: extract_domain_server_url(lease),
208            lease_expires_at: lease.lease_expires_at,
209            access_token: lease.access_token.clone(),
210            access_token_expires_at: lease.access_token_expires_at,
211            last_progress: None,
212            next_heartbeat_due: None,
213            status: SessionStatus::Pending,
214            cancel: lease.cancel,
215        };
216        state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
217
218        *self.state.lock().await = Some(state.clone());
219        Ok(SessionSnapshot(state))
220    }
221
222    pub async fn apply_heartbeat<R: Rng>(
223        &self,
224        lease: &LeaseEnvelope,
225        progress: Option<Value>,
226        now: Instant,
227        policy: &HeartbeatPolicy,
228        rng: &mut R,
229    ) -> Result<SessionSnapshot, SessionError> {
230        let mut guard = self.state.lock().await;
231        let state = guard.as_mut().ok_or(SessionError::NoActiveSession)?;
232
233        let task = lease.task.clone();
234        state.task_id = task.id;
235        state.job_id = task.job_id;
236        state.capability = task.capability;
237        state.meta = task.meta;
238        state.inputs_cids = task.inputs_cids;
239
240        if let Some(domain_id) = lease.domain_id {
241            state.domain_id = Some(domain_id);
242        }
243        if let Some(url) = extract_domain_server_url(lease) {
244            state.domain_server_url = Some(url);
245        }
246
247        if let Some(token) = &lease.access_token {
248            state.access_token = Some(token.clone());
249        }
250        if let Some(expiry) = lease.access_token_expires_at {
251            state.access_token_expires_at = Some(expiry);
252        }
253        if let Some(lease_expiry) = lease.lease_expires_at {
254            state.lease_expires_at = Some(lease_expiry);
255        }
256
257        state.last_progress = progress;
258        state.status = SessionStatus::Running;
259        state.cancel = lease.cancel;
260        state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
261
262        Ok(SessionSnapshot(state.clone()))
263    }
264}
265
266fn compute_next_heartbeat<R: Rng>(
267    now: Instant,
268    lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
269    policy: &HeartbeatPolicy,
270    rng: &mut R,
271) -> Option<Instant> {
272    let expires = lease_expires_at?;
273    let ttl = expires.signed_duration_since(chrono::Utc::now());
274    if ttl.num_milliseconds() <= 0 {
275        return Some(now);
276    }
277    let ttl = ttl.to_std().ok()?;
278    let ratio = policy.sample_ratio(rng).clamp(0.0, 1.0);
279    let mut delay = ttl.mul_f64(ratio);
280    if delay > ttl {
281        delay = ttl;
282    }
283    if delay.is_zero() {
284        delay = Duration::from_millis(100);
285    }
286    Some(now + delay.min(ttl))
287}
288
289fn extract_domain_server_url(lease: &LeaseEnvelope) -> Option<Url> {
290    if let Some(url) = &lease.domain_server_url {
291        return Some(url.clone());
292    }
293    lookup_domain_url_from_meta(&lease.task.meta)
294}
295
296fn lookup_domain_url_from_meta(meta: &Value) -> Option<Url> {
297    meta.get("domain_server_url")
298        .and_then(|value| value.as_str())
299        .and_then(|raw| Url::parse(raw).ok())
300        .or_else(|| {
301            meta.get("legacy")
302                .and_then(|legacy| legacy.get("domain_server_url"))
303                .and_then(|value| value.as_str())
304                .and_then(|raw| Url::parse(raw).ok())
305        })
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration as ChronoDuration, Utc};
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use serde_json::json;
    use uuid::Uuid;

    fn selector() -> CapabilitySelector {
        CapabilitySelector::new(vec!["cap-a".to_string(), "cap-b".to_string()])
    }

    fn policy() -> HeartbeatPolicy {
        HeartbeatPolicy::default_policy()
    }

    /// Lease for an accepted capability ("cap-a") expiring 10 minutes from now.
    fn lease_base() -> LeaseEnvelope {
        let now = Utc::now();
        LeaseEnvelope {
            access_token: Some("token".into()),
            access_token_expires_at: Some(now + ChronoDuration::minutes(5)),
            lease_expires_at: Some(now + ChronoDuration::minutes(10)),
            cancel: false,
            status: None,
            domain_id: None,
            domain_server_url: None,
            task: compute_runner_api::TaskSpec {
                id: Uuid::new_v4(),
                job_id: Some(Uuid::new_v4()),
                capability: "cap-a".into(),
                capability_filters: json!({}),
                inputs_cids: vec!["cid-1".into()],
                outputs_prefix: None,
                label: None,
                stage: None,
                meta: json!({ "hello": "world" }),
                priority: None,
                attempts: None,
                max_attempts: None,
                deps_remaining: None,
                status: None,
                mode: None,
                organization_filter: None,
                billing_units: None,
                estimated_credit_cost: None,
                debited_amount: None,
                debited_at: None,
                lease_expires_at: None,
            },
        }
    }

    #[test]
    fn capability_selector_choose() {
        let selector = selector();
        assert_eq!(selector.choose(), Some("cap-a"));
        assert!(selector.accepts("cap-b"));
        assert!(!selector.accepts("other"));
    }

    #[tokio::test]
    async fn start_session_rejects_unknown_capability() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.task.capability = "other".into();
        let mut rng = StdRng::seed_from_u64(123);
        let res = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(
            res.unwrap_err(),
            SessionError::CapabilityMismatch {
                expected: vec!["cap-a".into(), "cap-b".into()],
                got: "other".into()
            }
        );
    }

    #[tokio::test]
    async fn start_session_sets_next_heartbeat() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(7);
        let snapshot = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(snapshot.next_heartbeat_due().is_some());
        assert_eq!(snapshot.status(), SessionStatus::Pending);
    }

    #[tokio::test]
    async fn apply_heartbeat_updates_state_and_cancel_flag() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        let mut rng = StdRng::seed_from_u64(9);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        lease.cancel = true;
        lease.access_token = Some("new-token".into());
        let snapshot = manager
            .apply_heartbeat(
                &lease,
                Some(json!({"pct": 42})),
                Instant::now(),
                &policy(),
                &mut rng,
            )
            .await
            .unwrap();

        assert_eq!(snapshot.access_token(), Some("new-token"));
        assert!(snapshot.cancel());
        assert_eq!(snapshot.status(), SessionStatus::Running);
        assert!(snapshot.next_heartbeat_due().is_some());
    }

    #[tokio::test]
    async fn apply_heartbeat_without_session_is_rejected() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(1);
        let res = manager
            .apply_heartbeat(&lease, None, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(res.unwrap_err(), SessionError::NoActiveSession);
    }

    #[tokio::test]
    async fn apply_heartbeat_keeps_optional_fields_when_omitted() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        let mut rng = StdRng::seed_from_u64(11);
        let started = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        lease.access_token = None;
        lease.access_token_expires_at = None;
        lease.lease_expires_at = None;
        let snapshot = manager
            .apply_heartbeat(&lease, None, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        assert_eq!(snapshot.access_token(), Some("token"));
        assert_eq!(
            snapshot.access_token_expires_at(),
            started.access_token_expires_at()
        );
        assert_eq!(snapshot.lease_expires_at(), started.lease_expires_at());
        assert!(snapshot.next_heartbeat_due().is_some());
    }

    #[tokio::test]
    async fn clear_removes_active_session() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(3);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(manager.snapshot().await.is_some());

        manager.clear().await;
        assert!(manager.snapshot().await.is_none());
    }

    #[test]
    fn domain_server_url_prefers_lease_then_meta_then_legacy() {
        let mut lease = lease_base();
        assert_eq!(extract_domain_server_url(&lease), None);

        // An unparseable top-level value falls back to the legacy location.
        lease.task.meta = json!({
            "domain_server_url": "not a url",
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        assert_eq!(
            extract_domain_server_url(&lease).as_ref().map(Url::as_str),
            Some("https://legacy.example/")
        );

        lease.task.meta["domain_server_url"] = json!("https://meta.example/");
        assert_eq!(
            extract_domain_server_url(&lease).as_ref().map(Url::as_str),
            Some("https://meta.example/")
        );

        lease.domain_server_url = Some(Url::parse("https://lease.example/").unwrap());
        assert_eq!(
            extract_domain_server_url(&lease).as_ref().map(Url::as_str),
            Some("https://lease.example/")
        );
    }

    #[test]
    fn next_heartbeat_handles_missing_expired_and_future_leases() {
        let mut rng = StdRng::seed_from_u64(5);
        let now = Instant::now();

        assert_eq!(compute_next_heartbeat(now, None, &policy(), &mut rng), None);

        let expired = Some(Utc::now() - ChronoDuration::seconds(1));
        assert_eq!(
            compute_next_heartbeat(now, expired, &policy(), &mut rng),
            Some(now)
        );

        let future = Some(Utc::now() + ChronoDuration::seconds(100));
        let due = compute_next_heartbeat(now, future, &policy(), &mut rng).unwrap();
        let delay = due - now;
        // Default policy samples 55%..65% of the ~100 s TTL.
        assert!(delay >= Duration::from_secs(54), "delay too short: {delay:?}");
        assert!(delay <= Duration::from_secs(65), "delay too long: {delay:?}");
    }
}