posemesh_compute_node/
session.rs

1use compute_runner_api::LeaseEnvelope;
2use rand::distributions::{Distribution, Uniform};
3use rand::Rng;
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use url::Url;
9use uuid::Uuid;
10
11use crate::dms::types::HeartbeatResponse;
12
/// Lifecycle phase of the session tracked by [`SessionManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Status given to a session when it is created from a lease.
    Pending,
    /// Status set once a heartbeat response has been applied to the session.
    Running,
}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct SessionState {
22    pub task_id: Uuid,
23    pub job_id: Option<Uuid>,
24    pub capability: String,
25    pub meta: Value,
26    pub inputs_cids: Vec<String>,
27    pub domain_id: Option<Uuid>,
28    pub domain_server_url: Option<Url>,
29    pub lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
30    pub access_token: Option<String>,
31    pub access_token_expires_at: Option<chrono::DateTime<chrono::Utc>>,
32    pub last_progress: Option<Value>,
33    pub next_heartbeat_due: Option<Instant>,
34    pub status: SessionStatus,
35    pub cancel: bool,
36}
37
38/// Immutable snapshot of session state.
39#[derive(Debug, Clone)]
40pub struct SessionSnapshot(SessionState);
41
42impl SessionSnapshot {
43    pub fn task_id(&self) -> Uuid {
44        self.0.task_id
45    }
46
47    pub fn job_id(&self) -> Option<Uuid> {
48        self.0.job_id
49    }
50
51    pub fn capability(&self) -> &str {
52        &self.0.capability
53    }
54
55    pub fn meta(&self) -> &Value {
56        &self.0.meta
57    }
58
59    pub fn inputs_cids(&self) -> &[String] {
60        &self.0.inputs_cids
61    }
62
63    pub fn domain_id(&self) -> Option<Uuid> {
64        self.0.domain_id
65    }
66
67    pub fn domain_server_url(&self) -> Option<&Url> {
68        self.0.domain_server_url.as_ref()
69    }
70
71    pub fn access_token(&self) -> Option<&str> {
72        self.0.access_token.as_deref()
73    }
74
75    pub fn access_token_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
76        self.0.access_token_expires_at
77    }
78
79    pub fn lease_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
80        self.0.lease_expires_at
81    }
82
83    pub fn next_heartbeat_due(&self) -> Option<Instant> {
84        self.0.next_heartbeat_due
85    }
86
87    pub fn status(&self) -> SessionStatus {
88        self.0.status
89    }
90
91    pub fn cancel(&self) -> bool {
92        self.0.cancel
93    }
94}
95
/// Ordered list of capability names this node is configured to serve.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilitySelector {
    capabilities: Vec<String>,
}

impl CapabilitySelector {
    /// Wraps the given capability names, preserving their order.
    pub fn new(capabilities: Vec<String>) -> Self {
        CapabilitySelector { capabilities }
    }

    /// Returns the first configured capability, or `None` when the list is empty.
    pub fn choose(&self) -> Option<&str> {
        match self.capabilities.first() {
            Some(head) => Some(head.as_str()),
            None => None,
        }
    }

    /// Reports whether `capability` exactly matches one of the configured names.
    pub fn accepts(&self, capability: &str) -> bool {
        for known in &self.capabilities {
            if known.as_str() == capability {
                return true;
            }
        }
        false
    }

    /// Borrows every configured capability in its original order.
    pub fn all(&self) -> &[String] {
        self.capabilities.as_slice()
    }
}
119
120/// Distribution for randomized TTL heartbeats.
121#[derive(Debug, Clone, Copy)]
122pub struct HeartbeatPolicy {
123    pub min_ratio: f64,
124    pub max_ratio: f64,
125}
126
127impl HeartbeatPolicy {
128    pub const fn new(min_ratio: f64, max_ratio: f64) -> Self {
129        Self {
130            min_ratio,
131            max_ratio,
132        }
133    }
134
135    pub const fn default_policy() -> Self {
136        Self {
137            min_ratio: 0.55,
138            max_ratio: 0.65,
139        }
140    }
141
142    fn sample_ratio<R: Rng>(&self, rng: &mut R) -> f64 {
143        let min = self.min_ratio.min(self.max_ratio).max(0.0);
144        let max = self.max_ratio.max(self.min_ratio).max(min);
145        let dist = Uniform::new_inclusive(min, max);
146        dist.sample(rng)
147    }
148}
149
150#[derive(Debug, thiserror::Error, PartialEq, Eq)]
151pub enum SessionError {
152    #[error("lease did not include a task")]
153    MissingTask,
154    #[error("task capability `{got}` not in configured set {expected:?}")]
155    CapabilityMismatch { expected: Vec<String>, got: String },
156    #[error("no active session")]
157    NoActiveSession,
158}
159
160#[derive(Clone)]
161pub struct SessionManager {
162    selector: CapabilitySelector,
163    state: Arc<Mutex<Option<SessionState>>>,
164}
165
166impl SessionManager {
167    pub fn new(selector: CapabilitySelector) -> Self {
168        Self {
169            selector,
170            state: Arc::new(Mutex::new(None)),
171        }
172    }
173
174    pub fn selector(&self) -> &CapabilitySelector {
175        &self.selector
176    }
177
178    pub async fn snapshot(&self) -> Option<SessionSnapshot> {
179        let guard = self.state.lock().await;
180        guard.as_ref().cloned().map(SessionSnapshot)
181    }
182
183    pub async fn clear(&self) {
184        *self.state.lock().await = None;
185    }
186
187    pub async fn start_session<R: Rng>(
188        &self,
189        lease: &LeaseEnvelope,
190        now: Instant,
191        policy: &HeartbeatPolicy,
192        rng: &mut R,
193    ) -> Result<SessionSnapshot, SessionError> {
194        let task = lease.task.clone();
195        if !self.selector.accepts(&task.capability) {
196            return Err(SessionError::CapabilityMismatch {
197                expected: self.selector.all().to_vec(),
198                got: task.capability,
199            });
200        }
201
202        let mut state = SessionState {
203            task_id: task.id,
204            job_id: task.job_id,
205            capability: task.capability,
206            meta: task.meta,
207            inputs_cids: task.inputs_cids,
208            domain_id: lease.domain_id,
209            domain_server_url: extract_domain_server_url(lease),
210            lease_expires_at: lease.lease_expires_at,
211            access_token: lease.access_token.clone(),
212            access_token_expires_at: lease.access_token_expires_at,
213            last_progress: None,
214            next_heartbeat_due: None,
215            status: SessionStatus::Pending,
216            cancel: lease.cancel,
217        };
218        state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
219
220        *self.state.lock().await = Some(state.clone());
221        Ok(SessionSnapshot(state))
222    }
223
224    pub async fn apply_heartbeat<R: Rng>(
225        &self,
226        update: &HeartbeatResponse,
227        progress: Option<Value>,
228        now: Instant,
229        policy: &HeartbeatPolicy,
230        rng: &mut R,
231    ) -> Result<SessionSnapshot, SessionError> {
232        let mut guard = self.state.lock().await;
233        let state = guard.as_mut().ok_or(SessionError::NoActiveSession)?;
234
235        if let Some(task) = &update.task {
236            state.task_id = task.id;
237            state.job_id = task.job_id;
238            state.capability = task.capability.clone();
239            state.meta = task.meta.clone();
240            state.inputs_cids = task.inputs_cids.clone();
241        } else {
242            if let Some(task_id) = update.task_id {
243                state.task_id = task_id;
244            }
245            if let Some(job_id) = update.job_id {
246                state.job_id = Some(job_id);
247            }
248        }
249
250        if let Some(domain_id) = update.domain_id {
251            state.domain_id = Some(domain_id);
252        }
253
254        let mut domain_url = update.domain_server_url.clone();
255        if domain_url.is_none() {
256            if let Some(task) = &update.task {
257                domain_url = lookup_domain_url_from_meta(&task.meta);
258            }
259        }
260        if let Some(url) = domain_url {
261            state.domain_server_url = Some(url);
262        }
263
264        if let Some(token) = &update.access_token {
265            state.access_token = Some(token.clone());
266        }
267        if let Some(expiry) = update.access_token_expires_at {
268            state.access_token_expires_at = Some(expiry);
269        }
270        if let Some(lease_expiry) = update.lease_expires_at {
271            state.lease_expires_at = Some(lease_expiry);
272        }
273
274        state.last_progress = progress;
275        state.status = SessionStatus::Running;
276        if let Some(cancel) = update.cancel {
277            state.cancel = cancel;
278        }
279        state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
280
281        Ok(SessionSnapshot(state.clone()))
282    }
283}
284
285fn compute_next_heartbeat<R: Rng>(
286    now: Instant,
287    lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
288    policy: &HeartbeatPolicy,
289    rng: &mut R,
290) -> Option<Instant> {
291    let expires = lease_expires_at?;
292    let ttl = expires.signed_duration_since(chrono::Utc::now());
293    if ttl.num_milliseconds() <= 0 {
294        return Some(now);
295    }
296    let ttl = ttl.to_std().ok()?;
297    let ratio = policy.sample_ratio(rng).clamp(0.0, 1.0);
298    let mut delay = ttl.mul_f64(ratio);
299    if delay > ttl {
300        delay = ttl;
301    }
302    if delay.is_zero() {
303        delay = Duration::from_millis(100);
304    }
305    Some(now + delay.min(ttl))
306}
307
308fn extract_domain_server_url(lease: &LeaseEnvelope) -> Option<Url> {
309    if let Some(url) = &lease.domain_server_url {
310        return Some(url.clone());
311    }
312    lookup_domain_url_from_meta(&lease.task.meta)
313}
314
315fn lookup_domain_url_from_meta(meta: &Value) -> Option<Url> {
316    meta.get("domain_server_url")
317        .and_then(|value| value.as_str())
318        .and_then(|raw| Url::parse(raw).ok())
319        .or_else(|| {
320            meta.get("legacy")
321                .and_then(|legacy| legacy.get("domain_server_url"))
322                .and_then(|value| value.as_str())
323                .and_then(|raw| Url::parse(raw).ok())
324        })
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration as ChronoDuration, Utc};
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use serde_json::json;
    use uuid::Uuid;

    /// Selector accepting `cap-a` and `cap-b`.
    fn selector() -> CapabilitySelector {
        CapabilitySelector::new(vec!["cap-a".to_string(), "cap-b".to_string()])
    }

    fn policy() -> HeartbeatPolicy {
        HeartbeatPolicy::default_policy()
    }

    /// Lease for a `cap-a` task expiring ten minutes from now.
    fn lease_base() -> LeaseEnvelope {
        let now = Utc::now();
        LeaseEnvelope {
            access_token: Some("token".into()),
            access_token_expires_at: Some(now + ChronoDuration::minutes(5)),
            lease_expires_at: Some(now + ChronoDuration::minutes(10)),
            cancel: false,
            status: None,
            domain_id: None,
            domain_server_url: None,
            task: compute_runner_api::TaskSpec {
                id: Uuid::new_v4(),
                job_id: Some(Uuid::new_v4()),
                capability: "cap-a".into(),
                capability_filters: json!({}),
                inputs_cids: vec!["cid-1".into()],
                outputs_prefix: None,
                label: None,
                stage: None,
                meta: json!({ "hello": "world" }),
                priority: None,
                attempts: None,
                max_attempts: None,
                deps_remaining: None,
                status: None,
                mode: None,
                organization_filter: None,
                billing_units: None,
                estimated_credit_cost: None,
                debited_amount: None,
                debited_at: None,
                lease_expires_at: None,
            },
        }
    }

    /// Heartbeat response mirroring every field of `lease`.
    fn heartbeat_from_lease(lease: &LeaseEnvelope) -> HeartbeatResponse {
        HeartbeatResponse {
            access_token: lease.access_token.clone(),
            access_token_expires_at: lease.access_token_expires_at,
            lease_expires_at: lease.lease_expires_at,
            cancel: Some(lease.cancel),
            status: lease.status.clone(),
            domain_id: lease.domain_id,
            domain_server_url: lease.domain_server_url.clone(),
            task: Some(lease.task.clone()),
            task_id: Some(lease.task.id),
            job_id: lease.task.job_id,
            attempts: lease.task.attempts,
            max_attempts: lease.task.max_attempts,
            deps_remaining: lease.task.deps_remaining,
        }
    }

    #[test]
    fn capability_selector_choose() {
        let selector = selector();
        assert_eq!(selector.choose(), Some("cap-a"));
        assert!(selector.accepts("cap-b"));
        assert!(!selector.accepts("other"));
    }

    #[test]
    fn lookup_domain_url_prefers_top_level_then_legacy() {
        let both = json!({
            "domain_server_url": "https://top.example/",
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        assert_eq!(
            lookup_domain_url_from_meta(&both).unwrap().as_str(),
            "https://top.example/"
        );

        let legacy_only = json!({
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        assert_eq!(
            lookup_domain_url_from_meta(&legacy_only).unwrap().as_str(),
            "https://legacy.example/"
        );

        // An unparseable top-level value falls through to the legacy entry.
        let bad_top = json!({
            "domain_server_url": "not a url",
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        assert_eq!(
            lookup_domain_url_from_meta(&bad_top).unwrap().as_str(),
            "https://legacy.example/"
        );

        assert!(lookup_domain_url_from_meta(&json!({ "hello": "world" })).is_none());
    }

    #[test]
    fn compute_next_heartbeat_handles_missing_and_expired_lease() {
        let mut rng = StdRng::seed_from_u64(11);
        let now = Instant::now();
        assert_eq!(compute_next_heartbeat(now, None, &policy(), &mut rng), None);

        let expired = Utc::now() - ChronoDuration::seconds(5);
        assert_eq!(
            compute_next_heartbeat(now, Some(expired), &policy(), &mut rng),
            Some(now)
        );
    }

    #[test]
    fn compute_next_heartbeat_stays_within_policy_window() {
        let mut rng = StdRng::seed_from_u64(13);
        let now = Instant::now();
        let expires = Utc::now() + ChronoDuration::minutes(10);
        let due = compute_next_heartbeat(now, Some(expires), &policy(), &mut rng)
            .expect("a live lease must schedule a heartbeat");
        let delay = due.duration_since(now);
        // 55%–65% of ~600 s, with generous slack for clock skew between calls.
        assert!(delay >= Duration::from_secs(300));
        assert!(delay <= Duration::from_secs(391));
    }

    #[tokio::test]
    async fn start_session_rejects_unknown_capability() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.task.capability = "other".into();
        let mut rng = StdRng::seed_from_u64(123);
        let res = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(
            res.unwrap_err(),
            SessionError::CapabilityMismatch {
                expected: vec!["cap-a".into(), "cap-b".into()],
                got: "other".into()
            }
        );
        // A rejected lease must not create a session.
        assert!(manager.snapshot().await.is_none());
    }

    #[tokio::test]
    async fn start_session_sets_next_heartbeat() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(7);
        let snapshot = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(snapshot.next_heartbeat_due().is_some());
        assert_eq!(snapshot.status(), SessionStatus::Pending);
    }

    #[tokio::test]
    async fn start_session_reads_domain_url_from_meta() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.task.meta = json!({
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        let mut rng = StdRng::seed_from_u64(21);
        let snapshot = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert_eq!(
            snapshot.domain_server_url().map(|url| url.as_str()),
            Some("https://legacy.example/")
        );
    }

    #[tokio::test]
    async fn clear_removes_active_session() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(5);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(manager.snapshot().await.is_some());

        manager.clear().await;
        assert!(manager.snapshot().await.is_none());
    }

    #[tokio::test]
    async fn apply_heartbeat_without_session_errors() {
        let manager = SessionManager::new(selector());
        let update = heartbeat_from_lease(&lease_base());
        let mut rng = StdRng::seed_from_u64(3);
        let res = manager
            .apply_heartbeat(&update, None, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(res.unwrap_err(), SessionError::NoActiveSession);
    }

    #[tokio::test]
    async fn apply_heartbeat_updates_state_and_cancel_flag() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        let mut rng = StdRng::seed_from_u64(9);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        lease.cancel = true;
        lease.access_token = Some("new-token".into());
        let update = heartbeat_from_lease(&lease);
        let snapshot = manager
            .apply_heartbeat(
                &update,
                Some(json!({"pct": 42})),
                Instant::now(),
                &policy(),
                &mut rng,
            )
            .await
            .unwrap();

        assert_eq!(snapshot.access_token(), Some("new-token"));
        assert!(snapshot.cancel());
        assert_eq!(snapshot.status(), SessionStatus::Running);
        assert!(snapshot.next_heartbeat_due().is_some());
    }
}
466}