1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
/// Registry mapping capability strings to the runner that handles them.
#[derive(Default)]
pub struct RunnerRegistry {
    // Keyed by `Runner::capability()`; runners are shared via `Arc` so a
    // lookup can hand out a handle without borrowing the registry.
    runners: HashMap<String, Arc<dyn Runner>>,
}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 ) -> std::result::Result<(), crate::errors::ExecutorError> {
64 let cap = lease.task.capability.as_str();
65 let runner = self
66 .get(cap)
67 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
68 let ctx = TaskCtx {
69 lease,
70 input,
71 output,
72 ctrl,
73 };
74 runner
75 .run(ctx)
76 .await
77 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
78 }
79}
80
81pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
83 let shutdown = CancellationToken::new();
84 let signal_token = shutdown.clone();
85 let signal_task = tokio::spawn(async move {
86 if tokio::signal::ctrl_c().await.is_ok() {
87 signal_token.cancel();
88 }
89 });
90
91 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
92
93 shutdown.cancel();
94 let _ = signal_task.await;
95
96 result
97}
98
/// Drives the node's poll loop until `shutdown` is cancelled.
///
/// Sets up SIWE authentication once, then loops: obtain a bearer token,
/// build a DMS client, and run one lease cycle. Every failure path backs off
/// with a jittered delay instead of aborting, so transient auth/network
/// errors are retried indefinitely.
pub async fn run_node_with_shutdown(
    cfg: crate::config::NodeConfig,
    runners: RunnerRegistry,
    shutdown: CancellationToken,
) -> Result<()> {
    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
    info!("DDS SIWE authentication configured; waiting for DDS registration callback");
    let siwe_handle = siwe.start().await?;
    info!("DDS SIWE token manager started");

    let poll_cfg = PollerConfig {
        backoff_ms_min: cfg.poll_backoff_ms_min,
        backoff_ms_max: cfg.poll_backoff_ms_max,
    };

    loop {
        if shutdown.is_cancelled() {
            break;
        }

        // A fresh bearer token is requested each iteration; failure backs
        // off (or breaks on shutdown) rather than terminating the node.
        let bearer = match siwe_handle.bearer().await {
            Ok(token) => token,
            Err(err) => {
                warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
                }
            }
        };

        // NOTE(review): a new DMS client is constructed every iteration so
        // it always carries the current bearer token.
        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
        let dms_client = match crate::dms::client::DmsClient::new(
            cfg.dms_base_url.clone(),
            timeout,
            Some(bearer),
        ) {
            Ok(client) => client,
            Err(err) => {
                warn!(error = %err, "Failed to create DMS client; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
                }
            }
        };

        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
            // A lease was processed: poll again immediately, no backoff.
            Ok(true) => {
                continue;
            }
            // No lease available: wait a jittered delay before re-polling.
            Ok(false) => {
                let delay_ms = jittered_delay_ms(poll_cfg);
                info!(delay_ms, "No lease available; backing off before next poll");
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
            // Cycle error: log, back off, retry.
            Err(err) => {
                warn!(error = %err, "DMS cycle failed; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
        }
    }

    // Stop the SIWE token manager before returning.
    siwe_handle.shutdown().await;
    info!("Shutdown signal received; exiting run_node loop");

    Ok(())
}
177
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182 crate::storage::build_ports(lease, token)
183}
184
185pub fn apply_heartbeat_token_update(
188 token: &crate::storage::TokenRef,
189 hb: &crate::dms::types::HeartbeatResponse,
190) {
191 if let Some(new) = hb.access_token.clone() {
192 token.swap(new);
193 }
194}
195
196pub fn merge_heartbeat_into_lease(
198 lease: &mut LeaseEnvelope,
199 hb: &crate::dms::types::HeartbeatResponse,
200) {
201 if let Some(token) = hb.access_token.clone() {
202 lease.access_token = Some(token);
203 }
204 if let Some(expiry) = hb.access_token_expires_at {
205 lease.access_token_expires_at = Some(expiry);
206 }
207 if let Some(expiry) = hb.lease_expires_at {
208 lease.lease_expires_at = Some(expiry);
209 }
210 if let Some(cancel) = hb.cancel {
211 lease.cancel = cancel;
212 }
213 if let Some(status) = hb.status.clone() {
214 lease.status = Some(status);
215 }
216 if let Some(domain_id) = hb.domain_id {
217 lease.domain_id = Some(domain_id);
218 }
219 if let Some(url) = hb.domain_server_url.clone() {
220 lease.domain_server_url = Some(url);
221 }
222 if let Some(task) = hb.task.clone() {
223 lease.task = task;
224 } else {
225 if let Some(task_id) = hb.task_id {
226 lease.task.id = task_id;
227 }
228 if let Some(job_id) = hb.job_id {
229 lease.task.job_id = Some(job_id);
230 }
231 if let Some(attempts) = hb.attempts {
232 lease.task.attempts = Some(attempts);
233 }
234 if let Some(max_attempts) = hb.max_attempts {
235 lease.task.max_attempts = Some(max_attempts);
236 }
237 if let Some(deps_remaining) = hb.deps_remaining {
238 lease.task.deps_remaining = Some(deps_remaining);
239 }
240 }
241}
242
/// Executes one poll-and-run cycle against the DMS.
///
/// Returns `Ok(false)` when no lease was available (caller should back off)
/// and `Ok(true)` when a lease was obtained — even if it was subsequently
/// cancelled, lost, or the runner failed (the failure is reported to the
/// DMS before returning).
///
/// NOTE(review): only the first capability (alphabetically, per
/// `RunnerRegistry::capabilities`) is polled for a lease — confirm this is
/// intentional for registries with multiple runners.
pub async fn run_cycle_with_dms(
    _cfg: &crate::config::NodeConfig,
    dms: &DmsClient,
    reg: &RunnerRegistry,
) -> Result<bool> {
    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
    use serde_json::json;

    let capabilities = reg.capabilities();
    let capability = capabilities
        .first()
        .cloned()
        .ok_or_else(|| anyhow!("no runners registered"))?;

    // Try to obtain a lease; `None` means no work is queued for us.
    let mut lease = match dms.lease_by_capability(&capability).await? {
        Some(lease) => lease,
        None => {
            return Ok(false);
        }
    };
    if lease.access_token.is_none() {
        tracing::warn!(
            "Lease missing access token; storage client will fall back to legacy token flow"
        );
    }

    // Initialise the heartbeat session state for this lease.
    let selector = CapabilitySelector::new(capabilities.clone());
    let session = SessionManager::new(selector);
    let policy = HeartbeatPolicy::default_policy();
    let mut rng = StdRng::from_entropy();
    let snapshot = session
        .start_session(&lease, Instant::now(), &policy, &mut rng)
        .await
        .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
    // A lease that arrives already cancelled is consumed without running.
    if snapshot.cancel() {
        warn!(
            task_id = %snapshot.task_id(),
            "Lease already marked as cancelled; skipping execution"
        );
        return Ok(true);
    }

    // Shared, swappable storage token; later heartbeats may rotate it.
    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());

    // Immediate first heartbeat: confirms the lease and picks up any updated
    // token/expiry/task metadata before the runner starts.
    let heartbeat_initial = dms
        .heartbeat(
            lease.task.id,
            &HeartbeatRequest {
                progress: json!({}),
                events: json!({}),
            },
        )
        .await?;
    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
    merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
    session
        .apply_heartbeat(
            &heartbeat_initial,
            Some(json!({})),
            Instant::now(),
            &policy,
            &mut rng,
        )
        .await
        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;

    // Ports are built after the first heartbeat so they observe the freshest
    // access token via `token_ref`.
    let ports = crate::storage::build_ports(&lease, token_ref.clone())?;

    let (progress_tx, progress_rx) = progress_channel();
    let control_state = Arc::new(Mutex::new(ControlState::default()));
    // Seed the shared state with empty JSON objects (Default gives Null).
    {
        let mut guard = control_state.lock().await;
        guard.progress = json!({});
        guard.events = json!({});
    }

    // `runner_cancel` stops the runner (e.g. on server-side cancel);
    // `heartbeat_shutdown` stops the heartbeat loop once the runner is done.
    let runner_cancel = CancellationToken::new();
    let heartbeat_shutdown = CancellationToken::new();

    let ctrl = EngineControlPlane::new(
        runner_cancel.clone(),
        progress_tx.clone(),
        control_state.clone(),
    );

    // Prime the progress channel so the heartbeat loop has an initial payload.
    progress_tx.update(json!({}), json!({}));

    let heartbeat_driver = HeartbeatDriver::new(
        dms.clone(),
        HeartbeatDriverArgs {
            session: session.clone(),
            policy,
            rng,
            progress_rx,
            state: control_state.clone(),
            token_ref: token_ref.clone(),
            runner_cancel: runner_cancel.clone(),
            shutdown: heartbeat_shutdown.clone(),
            task_id: lease.task.id,
        },
    );
    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });

    // Execute the task while heartbeats continue in the background.
    let run_res = reg
        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl)
        .await;

    // Runner finished: stop the heartbeat loop and collect its verdict.
    heartbeat_shutdown.cancel();
    let heartbeat_result = match heartbeat_handle.await {
        Ok(result) => result,
        Err(err) => {
            // Join error (panic/abort in the loop task): treat the loop as
            // completed so the runner's own result is still reported below.
            warn!(error = %err, "heartbeat loop task failed");
            HeartbeatLoopResult::Completed
        }
    };

    // A cancelled or lost lease is neither completed nor failed with the
    // DMS — the server already knows — so the work is simply dropped.
    match heartbeat_result {
        HeartbeatLoopResult::Completed => {}
        HeartbeatLoopResult::Cancelled => {
            info!(
                task_id = %lease.task.id,
                "Lease cancelled during execution; skipping completion"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
        HeartbeatLoopResult::LostLease(err) => {
            warn!(
                task_id = %lease.task.id,
                error = %err,
                "Lease lost during heartbeat; abandoning task"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
    }

    // Summarise uploaded artifacts for both the success and failure reports.
    let uploaded_artifacts = ports.uploaded_artifacts();
    let artifacts_json: Vec<Value> = uploaded_artifacts
        .iter()
        .map(|artifact| {
            json!({
                "logical_path": artifact.logical_path,
                "name": artifact.name,
                "data_type": artifact.data_type,
                "id": artifact.id,
            })
        })
        .collect();
    let job_info = json!({
        "task_id": lease.task.id,
        "job_id": lease.task.job_id,
        "domain_id": lease.domain_id,
        "capability": lease.task.capability,
    });

    match run_res {
        // Success: report completion with the artifact index.
        Ok(()) => {
            let body = CompleteTaskRequest {
                outputs_index: json!({ "artifacts": artifacts_json.clone() }),
                result: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.complete(lease.task.id, &body).await?;
        }
        // Runner failure: log it and report the failure to the DMS.
        Err(err) => {
            error!(
                task_id = %lease.task.id,
                job_id = ?lease.task.job_id,
                capability = %lease.task.capability,
                error = %err,
                debug = ?err,
                "Runner execution failed; reporting failure to DMS"
            );
            let body = FailTaskRequest {
                reason: err.to_string(),
                details: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.fail(lease.task.id, &body)
                .await
                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
        }
    }

    Ok(true)
}
440
/// Latest progress/events payloads reported by the runner, shared (behind a
/// mutex) between the control plane and the heartbeat loop.
#[derive(Default)]
pub struct ControlState {
    // Most recent progress JSON from the runner (Default yields Value::Null).
    progress: Value,
    // Most recent event fields from the runner.
    events: Value,
}
446
/// `ControlPlane` implementation handed to runners: exposes the cancellation
/// token and forwards progress/event updates to the heartbeat loop.
struct EngineControlPlane {
    // Observed (never cancelled) here; set by the heartbeat loop / engine.
    cancel: CancellationToken,
    // Channel into the heartbeat loop for fresh progress/events payloads.
    progress_tx: ProgressSender,
    // Shared latest-state cache, also read by TTL-triggered heartbeats.
    state: Arc<Mutex<ControlState>>,
}
452
453impl EngineControlPlane {
454 pub fn new(
455 cancel: CancellationToken,
456 progress_tx: ProgressSender,
457 state: Arc<Mutex<ControlState>>,
458 ) -> Self {
459 Self {
460 cancel,
461 progress_tx,
462 state,
463 }
464 }
465}
466
467#[async_trait]
468impl ControlPlane for EngineControlPlane {
469 async fn is_cancelled(&self) -> bool {
470 self.cancel.is_cancelled()
471 }
472
473 async fn progress(&self, value: Value) -> Result<()> {
474 let events = {
475 let mut state = self.state.lock().await;
476 state.progress = value.clone();
477 state.events.clone()
478 };
479 self.progress_tx.update(value, events);
480 Ok(())
481 }
482
483 async fn log_event(&self, fields: Value) -> Result<()> {
484 let progress = {
485 let mut state = self.state.lock().await;
486 state.events = fields.clone();
487 state.progress.clone()
488 };
489 self.progress_tx.update(progress, fields);
490 Ok(())
491 }
492}
493
/// Terminal outcome of the background heartbeat loop.
pub enum HeartbeatLoopResult {
    /// Loop exited normally: shutdown, runner cancel, or closed channel.
    Completed,
    /// The server requested cancellation of the task.
    Cancelled,
    /// A heartbeat failed or the session rejected the update; the lease
    /// should be treated as gone.
    LostLease(anyhow::Error),
}
499
/// Abstraction over posting heartbeats, letting `HeartbeatDriver` be generic
/// over its transport (implemented below for `DmsClient`).
#[async_trait]
pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
    /// Posts a single heartbeat for `task_id` and returns the server's
    /// response (token/task updates, cancellation, etc.).
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse>;
}
508
509#[async_trait]
510impl HeartbeatTransport for DmsClient {
511 async fn post_heartbeat(
512 &self,
513 task_id: Uuid,
514 body: &crate::dms::types::HeartbeatRequest,
515 ) -> Result<crate::dms::types::HeartbeatResponse> {
516 self.heartbeat(task_id, body).await
517 }
518}
519
/// Constructor arguments for [`HeartbeatDriver`], bundled to keep
/// `HeartbeatDriver::new` to two parameters.
pub struct HeartbeatDriverArgs {
    /// Session state machine updated after each heartbeat response.
    pub session: SessionManager,
    /// Policy governing heartbeat scheduling.
    pub policy: HeartbeatPolicy,
    /// RNG consumed by the session when scheduling (e.g. jitter).
    pub rng: StdRng,
    /// Receives progress/events pushed by the runner's control plane.
    pub progress_rx: ProgressReceiver,
    /// Shared latest progress/events, read for TTL-triggered heartbeats.
    pub state: Arc<Mutex<ControlState>>,
    /// Storage token rotated when heartbeats carry a new access token.
    pub token_ref: crate::storage::TokenRef,
    /// Cancelled by the driver to stop the runner on cancel/lost lease.
    pub runner_cancel: CancellationToken,
    /// Cancelled by the engine once the runner has finished.
    pub shutdown: CancellationToken,
    /// Task id used for heartbeat posts; may be updated by responses.
    pub task_id: Uuid,
}
531
/// Background loop that keeps a lease alive by posting heartbeats, driven by
/// runner progress updates and the session's TTL schedule.
pub struct HeartbeatDriver<T>
where
    T: HeartbeatTransport,
{
    // How heartbeats reach the server (DmsClient in production).
    transport: T,
    session: SessionManager,
    policy: HeartbeatPolicy,
    rng: StdRng,
    progress_rx: ProgressReceiver,
    state: Arc<Mutex<ControlState>>,
    token_ref: crate::storage::TokenRef,
    runner_cancel: CancellationToken,
    shutdown: CancellationToken,
    task_id: Uuid,
    // NOTE(review): written in handle_progress but never read in this file —
    // possibly vestigial; confirm before removing.
    last_progress: Value,
}
548
impl<T> HeartbeatDriver<T>
where
    T: HeartbeatTransport,
{
    /// Builds a driver from its transport and the bundled arguments.
    pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
        Self {
            transport,
            session: args.session,
            policy: args.policy,
            rng: args.rng,
            progress_rx: args.progress_rx,
            state: args.state,
            token_ref: args.token_ref,
            runner_cancel: args.runner_cancel,
            shutdown: args.shutdown,
            task_id: args.task_id,
            last_progress: Value::default(),
        }
    }

    /// Heartbeat loop: sends a heartbeat whenever the runner pushes progress
    /// or the session's TTL deadline elapses, whichever comes first.
    ///
    /// Exits with `Completed` on shutdown, runner cancel, or a closed
    /// progress channel; `Cancelled` when the server requests cancellation;
    /// `LostLease` when a heartbeat fails or the session rejects the update.
    pub async fn run(mut self) -> HeartbeatLoopResult {
        loop {
            // runner_cancel is only observed here at the top of each
            // iteration, not inside the selects below.
            if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
                return HeartbeatLoopResult::Completed;
            }

            // No snapshot means the session is gone; nothing left to do.
            let snapshot = match self.session.snapshot().await {
                Some(s) => s,
                None => return HeartbeatLoopResult::Completed,
            };

            // Time remaining until the next scheduled heartbeat, if any.
            let ttl_delay = snapshot
                .next_heartbeat_due()
                .map(|due| due.saturating_duration_since(Instant::now()));

            if let Some(delay) = ttl_delay {
                // Race shutdown vs runner progress vs the TTL deadline.
                tokio::select! {
                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
                    progress = self.progress_rx.recv() => {
                        if let Some(data) = progress {
                            if let Some(outcome) = self.handle_progress(data).await {
                                return outcome;
                            }
                        } else {
                            // Sender dropped: runner side is finished.
                            return HeartbeatLoopResult::Completed;
                        }
                    }
                    _ = tokio::time::sleep(delay) => {
                        if let Some(outcome) = self.handle_ttl().await {
                            return outcome;
                        }
                    }
                }
            } else {
                // No deadline scheduled: wait only on progress or shutdown.
                tokio::select! {
                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
                    progress = self.progress_rx.recv() => {
                        if let Some(data) = progress {
                            if let Some(outcome) = self.handle_progress(data).await {
                                return outcome;
                            }
                        } else {
                            return HeartbeatLoopResult::Completed;
                        }
                    }
                }
            }
        }
    }

    /// Heartbeat triggered by fresh runner progress.
    async fn handle_progress(
        &mut self,
        data: crate::heartbeat::HeartbeatData,
    ) -> Option<HeartbeatLoopResult> {
        self.last_progress = data.progress.clone();
        self.send_and_update(data.progress, data.events).await
    }

    /// Heartbeat triggered by the TTL deadline; re-sends the latest cached
    /// progress/events from the shared state.
    async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
        let (progress, events) = self.snapshot_state().await;
        self.send_and_update(progress, events).await
    }

    /// Clones the shared progress/events pair under the lock.
    async fn snapshot_state(&self) -> (Value, Value) {
        let state = self.state.lock().await;
        (state.progress.clone(), state.events.clone())
    }

    /// Posts one heartbeat and applies the server's response: token
    /// rotation, task-id changes, session refresh, and cancellation.
    /// Returns `Some(result)` when the loop should terminate.
    async fn send_and_update(
        &mut self,
        progress: Value,
        events: Value,
    ) -> Option<HeartbeatLoopResult> {
        let request = crate::dms::types::HeartbeatRequest {
            progress: progress.clone(),
            events: events.clone(),
        };

        match self.transport.post_heartbeat(self.task_id, &request).await {
            Ok(update) => {
                apply_heartbeat_token_update(&self.token_ref, &update);
                // The server may restate or reassign the task id; a full
                // task payload takes precedence over a bare task_id field.
                if let Some(task) = &update.task {
                    self.task_id = task.id;
                } else if let Some(task_id) = update.task_id {
                    self.task_id = task_id;
                }
                if let Err(err) = self
                    .session
                    .apply_heartbeat(
                        &update,
                        Some(progress.clone()),
                        Instant::now(),
                        &self.policy,
                        &mut self.rng,
                    )
                    .await
                {
                    // Session rejected the update: treat the lease as lost.
                    return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
                }
                if update.cancel.unwrap_or(false) {
                    // Server-side cancel: stop the runner and report it.
                    self.runner_cancel.cancel();
                    return Some(HeartbeatLoopResult::Cancelled);
                }
                None
            }
            Err(err) => {
                // Transport failure: stop the runner and abandon the lease.
                self.runner_cancel.cancel();
                Some(HeartbeatLoopResult::LostLease(err))
            }
        }
    }
}