1use compute_runner_api::LeaseEnvelope;
2use rand::distributions::{Distribution, Uniform};
3use rand::Rng;
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use url::Url;
9use uuid::Uuid;
10
11use crate::dms::types::HeartbeatResponse;
12
/// Lifecycle phase of a tracked session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// The session was created from a lease and no heartbeat response has
    /// been applied to it yet.
    Pending,
    /// At least one heartbeat response has been applied to the session.
    Running,
}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct SessionState {
22 pub task_id: Uuid,
23 pub job_id: Option<Uuid>,
24 pub capability: String,
25 pub meta: Value,
26 pub inputs_cids: Vec<String>,
27 pub domain_id: Option<Uuid>,
28 pub domain_server_url: Option<Url>,
29 pub lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
30 pub access_token: Option<String>,
31 pub access_token_expires_at: Option<chrono::DateTime<chrono::Utc>>,
32 pub last_progress: Option<Value>,
33 pub next_heartbeat_due: Option<Instant>,
34 pub status: SessionStatus,
35 pub cancel: bool,
36}
37
38#[derive(Debug, Clone)]
40pub struct SessionSnapshot(SessionState);
41
42impl SessionSnapshot {
43 pub fn task_id(&self) -> Uuid {
44 self.0.task_id
45 }
46
47 pub fn job_id(&self) -> Option<Uuid> {
48 self.0.job_id
49 }
50
51 pub fn capability(&self) -> &str {
52 &self.0.capability
53 }
54
55 pub fn meta(&self) -> &Value {
56 &self.0.meta
57 }
58
59 pub fn inputs_cids(&self) -> &[String] {
60 &self.0.inputs_cids
61 }
62
63 pub fn domain_id(&self) -> Option<Uuid> {
64 self.0.domain_id
65 }
66
67 pub fn domain_server_url(&self) -> Option<&Url> {
68 self.0.domain_server_url.as_ref()
69 }
70
71 pub fn access_token(&self) -> Option<&str> {
72 self.0.access_token.as_deref()
73 }
74
75 pub fn access_token_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
76 self.0.access_token_expires_at
77 }
78
79 pub fn lease_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
80 self.0.lease_expires_at
81 }
82
83 pub fn next_heartbeat_due(&self) -> Option<Instant> {
84 self.0.next_heartbeat_due
85 }
86
87 pub fn status(&self) -> SessionStatus {
88 self.0.status
89 }
90
91 pub fn cancel(&self) -> bool {
92 self.0.cancel
93 }
94}
95
/// Ordered list of capability names accepted by this runner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilitySelector {
    capabilities: Vec<String>,
}

impl CapabilitySelector {
    /// Wraps `capabilities`, preserving the order in which they were given.
    pub fn new(capabilities: Vec<String>) -> Self {
        CapabilitySelector { capabilities }
    }

    /// Returns the first configured capability, or `None` when the list is
    /// empty.
    pub fn choose(&self) -> Option<&str> {
        match self.capabilities.as_slice() {
            [preferred, ..] => Some(preferred.as_str()),
            [] => None,
        }
    }

    /// Reports whether `capability` exactly equals one of the configured
    /// names.
    pub fn accepts(&self, capability: &str) -> bool {
        self.capabilities
            .iter()
            .map(String::as_str)
            .any(|known| known == capability)
    }

    /// Exposes every configured capability in its original order.
    pub fn all(&self) -> &[String] {
        self.capabilities.as_slice()
    }
}
119
120#[derive(Debug, Clone, Copy)]
122pub struct HeartbeatPolicy {
123 pub min_ratio: f64,
124 pub max_ratio: f64,
125}
126
127impl HeartbeatPolicy {
128 pub const fn new(min_ratio: f64, max_ratio: f64) -> Self {
129 Self {
130 min_ratio,
131 max_ratio,
132 }
133 }
134
135 pub const fn default_policy() -> Self {
136 Self {
137 min_ratio: 0.55,
138 max_ratio: 0.65,
139 }
140 }
141
142 fn sample_ratio<R: Rng>(&self, rng: &mut R) -> f64 {
143 let min = self.min_ratio.min(self.max_ratio).max(0.0);
144 let max = self.max_ratio.max(self.min_ratio).max(min);
145 let dist = Uniform::new_inclusive(min, max);
146 dist.sample(rng)
147 }
148}
149
150#[derive(Debug, thiserror::Error, PartialEq, Eq)]
151pub enum SessionError {
152 #[error("lease did not include a task")]
153 MissingTask,
154 #[error("task capability `{got}` not in configured set {expected:?}")]
155 CapabilityMismatch { expected: Vec<String>, got: String },
156 #[error("no active session")]
157 NoActiveSession,
158}
159
160#[derive(Clone)]
161pub struct SessionManager {
162 selector: CapabilitySelector,
163 state: Arc<Mutex<Option<SessionState>>>,
164}
165
166impl SessionManager {
167 pub fn new(selector: CapabilitySelector) -> Self {
168 Self {
169 selector,
170 state: Arc::new(Mutex::new(None)),
171 }
172 }
173
174 pub fn selector(&self) -> &CapabilitySelector {
175 &self.selector
176 }
177
178 pub async fn snapshot(&self) -> Option<SessionSnapshot> {
179 let guard = self.state.lock().await;
180 guard.as_ref().cloned().map(SessionSnapshot)
181 }
182
183 pub async fn clear(&self) {
184 *self.state.lock().await = None;
185 }
186
187 pub async fn start_session<R: Rng>(
188 &self,
189 lease: &LeaseEnvelope,
190 now: Instant,
191 policy: &HeartbeatPolicy,
192 rng: &mut R,
193 ) -> Result<SessionSnapshot, SessionError> {
194 let task = lease.task.clone();
195 if !self.selector.accepts(&task.capability) {
196 return Err(SessionError::CapabilityMismatch {
197 expected: self.selector.all().to_vec(),
198 got: task.capability,
199 });
200 }
201
202 let mut state = SessionState {
203 task_id: task.id,
204 job_id: task.job_id,
205 capability: task.capability,
206 meta: task.meta,
207 inputs_cids: task.inputs_cids,
208 domain_id: lease.domain_id,
209 domain_server_url: extract_domain_server_url(lease),
210 lease_expires_at: lease.lease_expires_at,
211 access_token: lease.access_token.clone(),
212 access_token_expires_at: lease.access_token_expires_at,
213 last_progress: None,
214 next_heartbeat_due: None,
215 status: SessionStatus::Pending,
216 cancel: lease.cancel,
217 };
218 state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
219
220 *self.state.lock().await = Some(state.clone());
221 Ok(SessionSnapshot(state))
222 }
223
224 pub async fn apply_heartbeat<R: Rng>(
225 &self,
226 update: &HeartbeatResponse,
227 progress: Option<Value>,
228 now: Instant,
229 policy: &HeartbeatPolicy,
230 rng: &mut R,
231 ) -> Result<SessionSnapshot, SessionError> {
232 let mut guard = self.state.lock().await;
233 let state = guard.as_mut().ok_or(SessionError::NoActiveSession)?;
234
235 if let Some(task) = &update.task {
236 state.task_id = task.id;
237 state.job_id = task.job_id;
238 state.capability = task.capability.clone();
239 state.meta = task.meta.clone();
240 state.inputs_cids = task.inputs_cids.clone();
241 } else {
242 if let Some(task_id) = update.task_id {
243 state.task_id = task_id;
244 }
245 if let Some(job_id) = update.job_id {
246 state.job_id = Some(job_id);
247 }
248 }
249
250 if let Some(domain_id) = update.domain_id {
251 state.domain_id = Some(domain_id);
252 }
253
254 let mut domain_url = update.domain_server_url.clone();
255 if domain_url.is_none() {
256 if let Some(task) = &update.task {
257 domain_url = lookup_domain_url_from_meta(&task.meta);
258 }
259 }
260 if let Some(url) = domain_url {
261 state.domain_server_url = Some(url);
262 }
263
264 if let Some(token) = &update.access_token {
265 state.access_token = Some(token.clone());
266 }
267 if let Some(expiry) = update.access_token_expires_at {
268 state.access_token_expires_at = Some(expiry);
269 }
270 if let Some(lease_expiry) = update.lease_expires_at {
271 state.lease_expires_at = Some(lease_expiry);
272 }
273
274 state.last_progress = progress;
275 state.status = SessionStatus::Running;
276 if let Some(cancel) = update.cancel {
277 state.cancel = cancel;
278 }
279 state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
280
281 Ok(SessionSnapshot(state.clone()))
282 }
283}
284
285fn compute_next_heartbeat<R: Rng>(
286 now: Instant,
287 lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
288 policy: &HeartbeatPolicy,
289 rng: &mut R,
290) -> Option<Instant> {
291 let expires = lease_expires_at?;
292 let ttl = expires.signed_duration_since(chrono::Utc::now());
293 if ttl.num_milliseconds() <= 0 {
294 return Some(now);
295 }
296 let ttl = ttl.to_std().ok()?;
297 let ratio = policy.sample_ratio(rng).clamp(0.0, 1.0);
298 let mut delay = ttl.mul_f64(ratio);
299 if delay > ttl {
300 delay = ttl;
301 }
302 if delay.is_zero() {
303 delay = Duration::from_millis(100);
304 }
305 Some(now + delay.min(ttl))
306}
307
308fn extract_domain_server_url(lease: &LeaseEnvelope) -> Option<Url> {
309 if let Some(url) = &lease.domain_server_url {
310 return Some(url.clone());
311 }
312 lookup_domain_url_from_meta(&lease.task.meta)
313}
314
315fn lookup_domain_url_from_meta(meta: &Value) -> Option<Url> {
316 meta.get("domain_server_url")
317 .and_then(|value| value.as_str())
318 .and_then(|raw| Url::parse(raw).ok())
319 .or_else(|| {
320 meta.get("legacy")
321 .and_then(|legacy| legacy.get("domain_server_url"))
322 .and_then(|value| value.as_str())
323 .and_then(|raw| Url::parse(raw).ok())
324 })
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration as ChronoDuration, Utc};
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use serde_json::json;
    use uuid::Uuid;

    /// Selector accepting the two capabilities used throughout the tests.
    fn selector() -> CapabilitySelector {
        CapabilitySelector::new(vec!["cap-a".to_string(), "cap-b".to_string()])
    }

    fn policy() -> HeartbeatPolicy {
        HeartbeatPolicy::default_policy()
    }

    /// Lease for a `cap-a` task that expires ten minutes from now.
    fn lease_base() -> LeaseEnvelope {
        let now = Utc::now();
        LeaseEnvelope {
            access_token: Some("token".into()),
            access_token_expires_at: Some(now + ChronoDuration::minutes(5)),
            lease_expires_at: Some(now + ChronoDuration::minutes(10)),
            cancel: false,
            status: None,
            domain_id: None,
            domain_server_url: None,
            task: compute_runner_api::TaskSpec {
                id: Uuid::new_v4(),
                job_id: Some(Uuid::new_v4()),
                capability: "cap-a".into(),
                capability_filters: json!({}),
                inputs_cids: vec!["cid-1".into()],
                outputs_prefix: None,
                label: None,
                stage: None,
                meta: json!({ "hello": "world" }),
                priority: None,
                attempts: None,
                max_attempts: None,
                deps_remaining: None,
                status: None,
                mode: None,
                organization_filter: None,
                billing_units: None,
                estimated_credit_cost: None,
                debited_amount: None,
                debited_at: None,
                lease_expires_at: None,
            },
        }
    }

    /// Heartbeat response that mirrors every field of `lease`.
    fn heartbeat_from_lease(lease: &LeaseEnvelope) -> HeartbeatResponse {
        HeartbeatResponse {
            access_token: lease.access_token.clone(),
            access_token_expires_at: lease.access_token_expires_at,
            lease_expires_at: lease.lease_expires_at,
            cancel: Some(lease.cancel),
            status: lease.status.clone(),
            domain_id: lease.domain_id,
            domain_server_url: lease.domain_server_url.clone(),
            task: Some(lease.task.clone()),
            task_id: Some(lease.task.id),
            job_id: lease.task.job_id,
            attempts: lease.task.attempts,
            max_attempts: lease.task.max_attempts,
            deps_remaining: lease.task.deps_remaining,
        }
    }

    #[test]
    fn capability_selector_choose() {
        let selector = selector();
        assert_eq!(selector.choose(), Some("cap-a"));
        assert!(selector.accepts("cap-b"));
        assert!(!selector.accepts("other"));
    }

    #[test]
    fn capability_selector_empty_has_no_choice() {
        let selector = CapabilitySelector::new(Vec::new());
        assert_eq!(selector.choose(), None);
        assert!(!selector.accepts("cap-a"));
        assert!(selector.all().is_empty());
    }

    #[test]
    fn domain_url_lookup_prefers_top_level_then_legacy() {
        let both = json!({
            "domain_server_url": "https://top.example/",
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        let found = lookup_domain_url_from_meta(&both);
        assert_eq!(found.as_ref().map(Url::as_str), Some("https://top.example/"));

        let legacy_only = json!({
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        let found = lookup_domain_url_from_meta(&legacy_only);
        assert_eq!(
            found.as_ref().map(Url::as_str),
            Some("https://legacy.example/")
        );

        // An unparseable top-level value must not hide the legacy entry.
        let invalid_top = json!({
            "domain_server_url": "not a url",
            "legacy": { "domain_server_url": "https://legacy.example/" }
        });
        let found = lookup_domain_url_from_meta(&invalid_top);
        assert_eq!(
            found.as_ref().map(Url::as_str),
            Some("https://legacy.example/")
        );

        assert!(lookup_domain_url_from_meta(&json!({ "hello": "world" })).is_none());
    }

    #[tokio::test]
    async fn start_session_rejects_unknown_capability() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.task.capability = "other".into();
        let mut rng = StdRng::seed_from_u64(123);
        let res = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(
            res.unwrap_err(),
            SessionError::CapabilityMismatch {
                expected: vec!["cap-a".into(), "cap-b".into()],
                got: "other".into()
            }
        );
        // A rejected lease must not leave a session behind.
        assert!(manager.snapshot().await.is_none());
    }

    #[tokio::test]
    async fn start_session_sets_next_heartbeat() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(7);
        let started = Instant::now();
        let snapshot = manager
            .start_session(&lease, started, &policy(), &mut rng)
            .await
            .unwrap();
        let due = snapshot
            .next_heartbeat_due()
            .expect("a lease with an expiry schedules a heartbeat");
        // The heartbeat lands strictly inside the ten-minute lease window.
        assert!(due > started);
        assert!(due <= started + Duration::from_secs(600));
        assert_eq!(snapshot.status(), SessionStatus::Pending);
    }

    #[tokio::test]
    async fn start_session_with_expired_lease_is_due_immediately() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.lease_expires_at = Some(Utc::now() - ChronoDuration::minutes(1));
        let mut rng = StdRng::seed_from_u64(11);
        let started = Instant::now();
        let snapshot = manager
            .start_session(&lease, started, &policy(), &mut rng)
            .await
            .unwrap();
        assert_eq!(snapshot.next_heartbeat_due(), Some(started));
    }

    #[tokio::test]
    async fn start_session_without_lease_expiry_schedules_nothing() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        lease.lease_expires_at = None;
        let mut rng = StdRng::seed_from_u64(13);
        let snapshot = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(snapshot.next_heartbeat_due().is_none());
    }

    #[tokio::test]
    async fn apply_heartbeat_without_session_is_rejected() {
        let manager = SessionManager::new(selector());
        let update = heartbeat_from_lease(&lease_base());
        let mut rng = StdRng::seed_from_u64(17);
        let res = manager
            .apply_heartbeat(&update, None, Instant::now(), &policy(), &mut rng)
            .await;
        assert_eq!(res.unwrap_err(), SessionError::NoActiveSession);
    }

    #[tokio::test]
    async fn apply_heartbeat_updates_state_and_cancel_flag() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        let mut rng = StdRng::seed_from_u64(9);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        lease.cancel = true;
        lease.access_token = Some("new-token".into());
        let update = heartbeat_from_lease(&lease);
        let snapshot = manager
            .apply_heartbeat(
                &update,
                Some(json!({"pct": 42})),
                Instant::now(),
                &policy(),
                &mut rng,
            )
            .await
            .unwrap();

        assert_eq!(snapshot.access_token(), Some("new-token"));
        assert!(snapshot.cancel());
        assert_eq!(snapshot.status(), SessionStatus::Running);
        assert!(snapshot.next_heartbeat_due().is_some());
    }

    #[tokio::test]
    async fn apply_heartbeat_keeps_previous_values_for_omitted_fields() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(21);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        let mut update = heartbeat_from_lease(&lease);
        update.task = None;
        update.task_id = None;
        update.job_id = None;
        update.access_token = None;
        update.cancel = None;
        let snapshot = manager
            .apply_heartbeat(&update, None, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        assert_eq!(snapshot.task_id(), lease.task.id);
        assert_eq!(snapshot.job_id(), lease.task.job_id);
        assert_eq!(snapshot.access_token(), Some("token"));
        assert!(!snapshot.cancel());
        assert_eq!(snapshot.status(), SessionStatus::Running);
    }

    #[tokio::test]
    async fn clear_drops_active_session() {
        let manager = SessionManager::new(selector());
        let mut rng = StdRng::seed_from_u64(5);
        manager
            .start_session(&lease_base(), Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();
        assert!(manager.snapshot().await.is_some());

        manager.clear().await;
        assert!(manager.snapshot().await.is_none());
    }
}