1use compute_runner_api::LeaseEnvelope;
2use rand::distributions::{Distribution, Uniform};
3use rand::Rng;
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use url::Url;
9use uuid::Uuid;
10
/// Lifecycle phase recorded in a [`SessionState`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Assigned by `SessionManager::start_session`; no heartbeat has been applied yet.
    Pending,
    /// Assigned by `SessionManager::apply_heartbeat` once a heartbeat has been applied.
    Running,
}
17
/// Everything tracked about the currently leased task.
///
/// Populated from a `LeaseEnvelope` by `SessionManager::start_session` and
/// refreshed by `SessionManager::apply_heartbeat`.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionState {
    /// Identifier of the leased task (`lease.task.id`).
    pub task_id: Uuid,
    /// Job the task belongs to, when the lease provides one.
    pub job_id: Option<Uuid>,
    /// Capability name carried by the task.
    pub capability: String,
    /// Task metadata as raw JSON.
    pub meta: Value,
    /// Input identifiers copied from the task.
    // NOTE(review): presumably content identifiers (CIDs) — confirm against the API crate.
    pub inputs_cids: Vec<String>,
    /// Domain identifier from the lease, if any.
    pub domain_id: Option<Uuid>,
    /// Domain server URL taken from the lease, or from the task metadata as a fallback.
    pub domain_server_url: Option<Url>,
    /// Wall-clock (UTC) time at which the lease expires, if the lease states one.
    pub lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Access token supplied with the lease, if any.
    pub access_token: Option<String>,
    /// Wall-clock (UTC) expiry of `access_token`, if the lease states one.
    pub access_token_expires_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Progress payload passed to the most recent heartbeat; `None` right after start.
    pub last_progress: Option<Value>,
    /// Monotonic instant at which the next heartbeat is due; `None` when the
    /// lease has no expiry to derive it from.
    pub next_heartbeat_due: Option<Instant>,
    /// Current lifecycle phase.
    pub status: SessionStatus,
    /// Cancellation flag copied from the most recent lease envelope.
    pub cancel: bool,
}
35
36#[derive(Debug, Clone)]
38pub struct SessionSnapshot(SessionState);
39
40impl SessionSnapshot {
41 pub fn task_id(&self) -> Uuid {
42 self.0.task_id
43 }
44
45 pub fn job_id(&self) -> Option<Uuid> {
46 self.0.job_id
47 }
48
49 pub fn capability(&self) -> &str {
50 &self.0.capability
51 }
52
53 pub fn meta(&self) -> &Value {
54 &self.0.meta
55 }
56
57 pub fn inputs_cids(&self) -> &[String] {
58 &self.0.inputs_cids
59 }
60
61 pub fn domain_id(&self) -> Option<Uuid> {
62 self.0.domain_id
63 }
64
65 pub fn domain_server_url(&self) -> Option<&Url> {
66 self.0.domain_server_url.as_ref()
67 }
68
69 pub fn access_token(&self) -> Option<&str> {
70 self.0.access_token.as_deref()
71 }
72
73 pub fn access_token_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
74 self.0.access_token_expires_at
75 }
76
77 pub fn lease_expires_at(&self) -> Option<chrono::DateTime<chrono::Utc>> {
78 self.0.lease_expires_at
79 }
80
81 pub fn next_heartbeat_due(&self) -> Option<Instant> {
82 self.0.next_heartbeat_due
83 }
84
85 pub fn status(&self) -> SessionStatus {
86 self.0.status
87 }
88
89 pub fn cancel(&self) -> bool {
90 self.0.cancel
91 }
92}
93
/// Ordered list of capability names that are accepted for a session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilitySelector {
    capabilities: Vec<String>,
}

impl CapabilitySelector {
    /// Wraps `capabilities`, keeping the order in which they were supplied.
    pub fn new(capabilities: Vec<String>) -> Self {
        CapabilitySelector { capabilities }
    }

    /// Returns the first configured capability, or `None` when the list is empty.
    pub fn choose(&self) -> Option<&str> {
        let preferred = self.capabilities.first()?;
        Some(preferred.as_str())
    }

    /// Reports whether `capability` exactly equals one of the configured names.
    pub fn accepts(&self, capability: &str) -> bool {
        self.capabilities
            .iter()
            .map(String::as_str)
            .any(|known| known == capability)
    }

    /// Borrows every configured capability in the order given to [`Self::new`].
    pub fn all(&self) -> &[String] {
        self.capabilities.as_slice()
    }
}
117
118#[derive(Debug, Clone, Copy)]
120pub struct HeartbeatPolicy {
121 pub min_ratio: f64,
122 pub max_ratio: f64,
123}
124
125impl HeartbeatPolicy {
126 pub const fn new(min_ratio: f64, max_ratio: f64) -> Self {
127 Self {
128 min_ratio,
129 max_ratio,
130 }
131 }
132
133 pub const fn default_policy() -> Self {
134 Self {
135 min_ratio: 0.55,
136 max_ratio: 0.65,
137 }
138 }
139
140 fn sample_ratio<R: Rng>(&self, rng: &mut R) -> f64 {
141 let min = self.min_ratio.min(self.max_ratio).max(0.0);
142 let max = self.max_ratio.max(self.min_ratio).max(min);
143 let dist = Uniform::new_inclusive(min, max);
144 dist.sample(rng)
145 }
146}
147
/// Failures reported by [`SessionManager`] operations.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum SessionError {
    /// The lease carried no task.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant — confirm whether it is still used elsewhere.
    #[error("lease did not include a task")]
    MissingTask,
    /// The task's capability (`got`) is not one the selector accepts (`expected`).
    /// Returned by `SessionManager::start_session`.
    #[error("task capability `{got}` not in configured set {expected:?}")]
    CapabilityMismatch { expected: Vec<String>, got: String },
    /// A heartbeat was applied while no session was stored.
    /// Returned by `SessionManager::apply_heartbeat`.
    #[error("no active session")]
    NoActiveSession,
}
157
158#[derive(Clone)]
159pub struct SessionManager {
160 selector: CapabilitySelector,
161 state: Arc<Mutex<Option<SessionState>>>,
162}
163
164impl SessionManager {
165 pub fn new(selector: CapabilitySelector) -> Self {
166 Self {
167 selector,
168 state: Arc::new(Mutex::new(None)),
169 }
170 }
171
172 pub fn selector(&self) -> &CapabilitySelector {
173 &self.selector
174 }
175
176 pub async fn snapshot(&self) -> Option<SessionSnapshot> {
177 let guard = self.state.lock().await;
178 guard.as_ref().cloned().map(SessionSnapshot)
179 }
180
181 pub async fn clear(&self) {
182 *self.state.lock().await = None;
183 }
184
185 pub async fn start_session<R: Rng>(
186 &self,
187 lease: &LeaseEnvelope,
188 now: Instant,
189 policy: &HeartbeatPolicy,
190 rng: &mut R,
191 ) -> Result<SessionSnapshot, SessionError> {
192 let task = lease.task.clone();
193 if !self.selector.accepts(&task.capability) {
194 return Err(SessionError::CapabilityMismatch {
195 expected: self.selector.all().to_vec(),
196 got: task.capability,
197 });
198 }
199
200 let mut state = SessionState {
201 task_id: task.id,
202 job_id: task.job_id,
203 capability: task.capability,
204 meta: task.meta,
205 inputs_cids: task.inputs_cids,
206 domain_id: lease.domain_id,
207 domain_server_url: extract_domain_server_url(lease),
208 lease_expires_at: lease.lease_expires_at,
209 access_token: lease.access_token.clone(),
210 access_token_expires_at: lease.access_token_expires_at,
211 last_progress: None,
212 next_heartbeat_due: None,
213 status: SessionStatus::Pending,
214 cancel: lease.cancel,
215 };
216 state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
217
218 *self.state.lock().await = Some(state.clone());
219 Ok(SessionSnapshot(state))
220 }
221
222 pub async fn apply_heartbeat<R: Rng>(
223 &self,
224 lease: &LeaseEnvelope,
225 progress: Option<Value>,
226 now: Instant,
227 policy: &HeartbeatPolicy,
228 rng: &mut R,
229 ) -> Result<SessionSnapshot, SessionError> {
230 let mut guard = self.state.lock().await;
231 let state = guard.as_mut().ok_or(SessionError::NoActiveSession)?;
232
233 let task = lease.task.clone();
234 state.task_id = task.id;
235 state.job_id = task.job_id;
236 state.capability = task.capability;
237 state.meta = task.meta;
238 state.inputs_cids = task.inputs_cids;
239
240 if let Some(domain_id) = lease.domain_id {
241 state.domain_id = Some(domain_id);
242 }
243 if let Some(url) = extract_domain_server_url(lease) {
244 state.domain_server_url = Some(url);
245 }
246
247 if let Some(token) = &lease.access_token {
248 state.access_token = Some(token.clone());
249 }
250 if let Some(expiry) = lease.access_token_expires_at {
251 state.access_token_expires_at = Some(expiry);
252 }
253 if let Some(lease_expiry) = lease.lease_expires_at {
254 state.lease_expires_at = Some(lease_expiry);
255 }
256
257 state.last_progress = progress;
258 state.status = SessionStatus::Running;
259 state.cancel = lease.cancel;
260 state.next_heartbeat_due = compute_next_heartbeat(now, state.lease_expires_at, policy, rng);
261
262 Ok(SessionSnapshot(state.clone()))
263 }
264}
265
266fn compute_next_heartbeat<R: Rng>(
267 now: Instant,
268 lease_expires_at: Option<chrono::DateTime<chrono::Utc>>,
269 policy: &HeartbeatPolicy,
270 rng: &mut R,
271) -> Option<Instant> {
272 let expires = lease_expires_at?;
273 let ttl = expires.signed_duration_since(chrono::Utc::now());
274 if ttl.num_milliseconds() <= 0 {
275 return Some(now);
276 }
277 let ttl = ttl.to_std().ok()?;
278 let ratio = policy.sample_ratio(rng).clamp(0.0, 1.0);
279 let mut delay = ttl.mul_f64(ratio);
280 if delay > ttl {
281 delay = ttl;
282 }
283 if delay.is_zero() {
284 delay = Duration::from_millis(100);
285 }
286 Some(now + delay.min(ttl))
287}
288
289fn extract_domain_server_url(lease: &LeaseEnvelope) -> Option<Url> {
290 if let Some(url) = &lease.domain_server_url {
291 return Some(url.clone());
292 }
293 lookup_domain_url_from_meta(&lease.task.meta)
294}
295
296fn lookup_domain_url_from_meta(meta: &Value) -> Option<Url> {
297 meta.get("domain_server_url")
298 .and_then(|value| value.as_str())
299 .and_then(|raw| Url::parse(raw).ok())
300 .or_else(|| {
301 meta.get("legacy")
302 .and_then(|legacy| legacy.get("domain_server_url"))
303 .and_then(|value| value.as_str())
304 .and_then(|raw| Url::parse(raw).ok())
305 })
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration as ChronoDuration, Utc};
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use serde_json::json;
    use uuid::Uuid;

    /// Selector accepting the two capabilities used throughout these tests.
    fn selector() -> CapabilitySelector {
        CapabilitySelector::new(vec![String::from("cap-a"), String::from("cap-b")])
    }

    /// Heartbeat policy shared by every test.
    fn policy() -> HeartbeatPolicy {
        HeartbeatPolicy::default_policy()
    }

    /// Lease for a fresh `cap-a` task that expires ten minutes from now.
    fn lease_base() -> LeaseEnvelope {
        let issued_at = Utc::now();
        let task = compute_runner_api::TaskSpec {
            id: Uuid::new_v4(),
            job_id: Some(Uuid::new_v4()),
            capability: "cap-a".into(),
            capability_filters: json!({}),
            inputs_cids: vec!["cid-1".into()],
            outputs_prefix: None,
            label: None,
            stage: None,
            meta: json!({ "hello": "world" }),
            priority: None,
            attempts: None,
            max_attempts: None,
            deps_remaining: None,
            status: None,
            mode: None,
            organization_filter: None,
            billing_units: None,
            estimated_credit_cost: None,
            debited_amount: None,
            debited_at: None,
            lease_expires_at: None,
        };

        LeaseEnvelope {
            access_token: Some("token".into()),
            access_token_expires_at: Some(issued_at + ChronoDuration::minutes(5)),
            lease_expires_at: Some(issued_at + ChronoDuration::minutes(10)),
            cancel: false,
            status: None,
            domain_id: None,
            domain_server_url: None,
            task,
        }
    }

    #[test]
    fn capability_selector_choose() {
        let sel = selector();
        assert_eq!(sel.choose(), Some("cap-a"));
        assert!(sel.accepts("cap-b"));
        assert!(!sel.accepts("other"));
    }

    #[tokio::test]
    async fn start_session_rejects_unknown_capability() {
        let manager = SessionManager::new(selector());
        let lease = {
            let mut lease = lease_base();
            lease.task.capability = "other".into();
            lease
        };
        let mut rng = StdRng::seed_from_u64(123);

        let outcome = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await;

        let expected = SessionError::CapabilityMismatch {
            expected: vec!["cap-a".into(), "cap-b".into()],
            got: "other".into(),
        };
        assert_eq!(outcome.unwrap_err(), expected);
    }

    #[tokio::test]
    async fn start_session_sets_next_heartbeat() {
        let manager = SessionManager::new(selector());
        let lease = lease_base();
        let mut rng = StdRng::seed_from_u64(7);

        let started = manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        assert!(started.next_heartbeat_due().is_some());
        assert_eq!(started.status(), SessionStatus::Pending);
    }

    #[tokio::test]
    async fn apply_heartbeat_updates_state_and_cancel_flag() {
        let manager = SessionManager::new(selector());
        let mut lease = lease_base();
        let mut rng = StdRng::seed_from_u64(9);
        manager
            .start_session(&lease, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        // Simulate a heartbeat response that cancels the task and rotates the token.
        lease.cancel = true;
        lease.access_token = Some("new-token".into());
        let progress = Some(json!({"pct": 42}));

        let refreshed = manager
            .apply_heartbeat(&lease, progress, Instant::now(), &policy(), &mut rng)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token(), Some("new-token"));
        assert!(refreshed.cancel());
        assert_eq!(refreshed.status(), SessionStatus::Running);
        assert!(refreshed.next_heartbeat_due().is_some());
    }
}