posemesh_compute_node/
engine.rs

1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17    dms::client::DmsClient,
18    heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19    poller::{jittered_delay_ms, PollerConfig},
20    session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23/// Registry mapping capability strings to runner instances.
24#[derive(Default)]
25pub struct RunnerRegistry {
26    runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30    /// Create an empty registry.
31    pub fn new() -> Self {
32        Self {
33            runners: HashMap::new(),
34        }
35    }
36
37    /// Register a runner by its capability. Last registration wins on duplicates.
38    pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39        let key = runner.capability().to_string();
40        self.runners.insert(key, Arc::new(runner));
41        self
42    }
43
44    /// Retrieve a runner by capability.
45    pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46        self.runners.get(capability).cloned()
47    }
48
49    /// Snapshot of registered capability strings.
50    pub fn capabilities(&self) -> Vec<String> {
51        let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52        caps.sort();
53        caps
54    }
55
56    /// Dispatch task to the appropriate runner based on `lease.task.capability`.
57    pub async fn run_for_lease(
58        &self,
59        lease: &LeaseEnvelope,
60        input: &dyn InputSource,
61        output: &dyn ArtifactSink,
62        ctrl: &dyn ControlPlane,
63    ) -> std::result::Result<(), crate::errors::ExecutorError> {
64        let cap = lease.task.capability.as_str();
65        let runner = self
66            .get(cap)
67            .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
68        let ctx = TaskCtx {
69            lease,
70            input,
71            output,
72            ctrl,
73        };
74        runner
75            .run(ctx)
76            .await
77            .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
78    }
79}
80
81/// Run the node main loop. Networking and storage are wired in later prompts.
82pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
83    let shutdown = CancellationToken::new();
84    let signal_token = shutdown.clone();
85    let signal_task = tokio::spawn(async move {
86        if tokio::signal::ctrl_c().await.is_ok() {
87            signal_token.cancel();
88        }
89    });
90
91    let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
92
93    shutdown.cancel();
94    let _ = signal_task.await;
95
96    result
97}
98
99pub async fn run_node_with_shutdown(
100    cfg: crate::config::NodeConfig,
101    runners: RunnerRegistry,
102    shutdown: CancellationToken,
103) -> Result<()> {
104    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
105    info!("DDS SIWE authentication configured; waiting for DDS registration callback");
106    let siwe_handle = siwe.start().await?;
107    info!("DDS SIWE token manager started");
108
109    let poll_cfg = PollerConfig {
110        backoff_ms_min: cfg.poll_backoff_ms_min,
111        backoff_ms_max: cfg.poll_backoff_ms_max,
112    };
113
114    loop {
115        if shutdown.is_cancelled() {
116            break;
117        }
118
119        // Ensure SIWE token is available before attempting DMS operations
120        if let Err(err) = siwe_handle.bearer().await {
121            warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
122            let delay_ms = jittered_delay_ms(poll_cfg);
123            tokio::select! {
124                _ = shutdown.cancelled() => break,
125                _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
126            }
127        }
128
129        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
130        let dms_client = match crate::dms::client::DmsClient::new(
131            cfg.dms_base_url.clone(),
132            timeout,
133            std::sync::Arc::new(siwe_handle.clone()),
134        ) {
135            Ok(client) => client,
136            Err(err) => {
137                warn!(error = %err, "Failed to create DMS client; backing off");
138                let delay_ms = jittered_delay_ms(poll_cfg);
139                tokio::select! {
140                    _ = shutdown.cancelled() => break,
141                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
142                }
143            }
144        };
145
146        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
147            Ok(true) => {
148                // Successful task execution; immediately attempt next poll.
149                continue;
150            }
151            Ok(false) => {
152                let delay_ms = jittered_delay_ms(poll_cfg);
153                info!(delay_ms, "No lease available; backing off before next poll");
154                tokio::select! {
155                    _ = shutdown.cancelled() => break,
156                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
157                }
158            }
159            Err(err) => {
160                warn!(error = %err, "DMS cycle failed; backing off");
161                let delay_ms = jittered_delay_ms(poll_cfg);
162                tokio::select! {
163                    _ = shutdown.cancelled() => break,
164                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
165                }
166            }
167        }
168    }
169
170    siwe_handle.shutdown().await;
171    info!("Shutdown signal received; exiting run_node loop");
172
173    Ok(())
174}
175
176/// Build storage ports (input/output) for a given lease by constructing a TokenRef
177/// from the lease's access token and delegating to storage::build_ports.
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
179    let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
180    crate::storage::build_ports(lease, token)
181}
182
183/// Apply heartbeat token refresh: if HeartbeatResponse carries a new access token,
184/// swap it into the provided TokenRef so subsequent storage requests use it.
185pub fn apply_heartbeat_token_update(
186    token: &crate::storage::TokenRef,
187    hb: &crate::dms::types::HeartbeatResponse,
188) {
189    if let Some(new) = hb.access_token.clone() {
190        token.swap(new);
191    }
192}
193
194/// Merge fields from a heartbeat response into the cached lease.
195pub fn merge_heartbeat_into_lease(
196    lease: &mut LeaseEnvelope,
197    hb: &crate::dms::types::HeartbeatResponse,
198) {
199    if let Some(token) = hb.access_token.clone() {
200        lease.access_token = Some(token);
201    }
202    if let Some(expiry) = hb.access_token_expires_at {
203        lease.access_token_expires_at = Some(expiry);
204    }
205    if let Some(expiry) = hb.lease_expires_at {
206        lease.lease_expires_at = Some(expiry);
207    }
208    if let Some(cancel) = hb.cancel {
209        lease.cancel = cancel;
210    }
211    if let Some(status) = hb.status.clone() {
212        lease.status = Some(status);
213    }
214    if let Some(domain_id) = hb.domain_id {
215        lease.domain_id = Some(domain_id);
216    }
217    if let Some(url) = hb.domain_server_url.clone() {
218        lease.domain_server_url = Some(url);
219    }
220    if let Some(task) = hb.task.clone() {
221        lease.task = task;
222    } else {
223        if let Some(task_id) = hb.task_id {
224            lease.task.id = task_id;
225        }
226        if let Some(job_id) = hb.job_id {
227            lease.task.job_id = Some(job_id);
228        }
229        if let Some(attempts) = hb.attempts {
230            lease.task.attempts = Some(attempts);
231        }
232        if let Some(max_attempts) = hb.max_attempts {
233            lease.task.max_attempts = Some(max_attempts);
234        }
235        if let Some(deps_remaining) = hb.deps_remaining {
236            lease.task.deps_remaining = Some(deps_remaining);
237        }
238    }
239}
240
241/// Run a single poll→run→complete/fail cycle using DMS client and the runner registry.
242/// This is a minimal integration used by tests; `run_node` wiring remains separate.
243pub async fn run_cycle_with_dms(
244    _cfg: &crate::config::NodeConfig,
245    dms: &DmsClient,
246    reg: &RunnerRegistry,
247) -> Result<bool> {
248    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
249    use serde_json::json;
250
251    let capabilities = reg.capabilities();
252    let capability = capabilities
253        .first()
254        .cloned()
255        .ok_or_else(|| anyhow!("no runners registered"))?;
256
257    // Lease a task from DMS
258    let mut lease = match dms.lease_by_capability(&capability).await? {
259        Some(lease) => lease,
260        None => {
261            return Ok(false);
262        }
263    };
264    if lease.access_token.is_none() {
265        tracing::warn!(
266            "Lease missing access token; storage client will fall back to legacy token flow"
267        );
268    }
269
270    // Initialise session state for heartbeats and token rotation.
271    let selector = CapabilitySelector::new(capabilities.clone());
272    let session = SessionManager::new(selector);
273    let policy = HeartbeatPolicy::default_policy();
274    let mut rng = StdRng::from_entropy();
275    let snapshot = session
276        .start_session(&lease, Instant::now(), &policy, &mut rng)
277        .await
278        .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
279    if snapshot.cancel() {
280        warn!(
281            task_id = %snapshot.task_id(),
282            "Lease already marked as cancelled; skipping execution"
283        );
284        return Ok(true);
285    }
286
287    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
288
289    let heartbeat_initial = dms
290        .heartbeat(
291            lease.task.id,
292            &HeartbeatRequest {
293                progress: json!({}),
294                events: json!({}),
295            },
296        )
297        .await?;
298    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
299    merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
300    session
301        .apply_heartbeat(
302            &heartbeat_initial,
303            Some(json!({})),
304            Instant::now(),
305            &policy,
306            &mut rng,
307        )
308        .await
309        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
310
311    let ports = crate::storage::build_ports(&lease, token_ref.clone())?;
312
313    let (progress_tx, progress_rx) = progress_channel();
314    let control_state = Arc::new(Mutex::new(ControlState::default()));
315    {
316        let mut guard = control_state.lock().await;
317        guard.progress = json!({});
318        guard.events = json!({});
319    }
320
321    let runner_cancel = CancellationToken::new();
322    let heartbeat_shutdown = CancellationToken::new();
323
324    let ctrl = EngineControlPlane::new(
325        runner_cancel.clone(),
326        progress_tx.clone(),
327        control_state.clone(),
328    );
329
330    // Trigger an immediate heartbeat once the loop starts to refresh tokens.
331    progress_tx.update(json!({}), json!({}));
332
333    let heartbeat_driver = HeartbeatDriver::new(
334        dms.clone(),
335        HeartbeatDriverArgs {
336            session: session.clone(),
337            policy,
338            rng,
339            progress_rx,
340            state: control_state.clone(),
341            token_ref: token_ref.clone(),
342            runner_cancel: runner_cancel.clone(),
343            shutdown: heartbeat_shutdown.clone(),
344            task_id: lease.task.id,
345        },
346    );
347    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
348
349    let run_res = reg
350        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl)
351        .await;
352
353    heartbeat_shutdown.cancel();
354    let heartbeat_result = match heartbeat_handle.await {
355        Ok(result) => result,
356        Err(err) => {
357            warn!(error = %err, "heartbeat loop task failed");
358            HeartbeatLoopResult::Completed
359        }
360    };
361
362    match heartbeat_result {
363        HeartbeatLoopResult::Completed => {}
364        HeartbeatLoopResult::Cancelled => {
365            info!(
366                task_id = %lease.task.id,
367                "Lease cancelled during execution; skipping completion"
368            );
369            runner_cancel.cancel();
370            return Ok(true);
371        }
372        HeartbeatLoopResult::LostLease(err) => {
373            warn!(
374                task_id = %lease.task.id,
375                error = %err,
376                "Lease lost during heartbeat; abandoning task"
377            );
378            runner_cancel.cancel();
379            return Ok(true);
380        }
381    }
382
383    let uploaded_artifacts = ports.uploaded_artifacts();
384    let artifacts_json: Vec<Value> = uploaded_artifacts
385        .iter()
386        .map(|artifact| {
387            json!({
388                "logical_path": artifact.logical_path,
389                "name": artifact.name,
390                "data_type": artifact.data_type,
391                "id": artifact.id,
392            })
393        })
394        .collect();
395    let job_info = json!({
396        "task_id": lease.task.id,
397        "job_id": lease.task.job_id,
398        "domain_id": lease.domain_id,
399        "capability": lease.task.capability,
400    });
401
402    // Complete or fail the task depending on runner outcome.
403    match run_res {
404        Ok(()) => {
405            let body = CompleteTaskRequest {
406                outputs_index: json!({ "artifacts": artifacts_json.clone() }),
407                result: json!({
408                    "job": job_info,
409                    "artifacts": artifacts_json,
410                }),
411            };
412            dms.complete(lease.task.id, &body).await?;
413        }
414        Err(err) => {
415            error!(
416                task_id = %lease.task.id,
417                job_id = ?lease.task.job_id,
418                capability = %lease.task.capability,
419                error = %err,
420                debug = ?err,
421                "Runner execution failed; reporting failure to DMS"
422            );
423            let body = FailTaskRequest {
424                reason: err.to_string(),
425                details: json!({
426                    "job": job_info,
427                    "artifacts": artifacts_json,
428                }),
429            };
430            dms.fail(lease.task.id, &body)
431                .await
432                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
433        }
434    }
435
436    Ok(true)
437}
438
439#[derive(Default)]
440pub struct ControlState {
441    progress: Value,
442    events: Value,
443}
444
445struct EngineControlPlane {
446    cancel: CancellationToken,
447    progress_tx: ProgressSender,
448    state: Arc<Mutex<ControlState>>,
449}
450
451impl EngineControlPlane {
452    pub fn new(
453        cancel: CancellationToken,
454        progress_tx: ProgressSender,
455        state: Arc<Mutex<ControlState>>,
456    ) -> Self {
457        Self {
458            cancel,
459            progress_tx,
460            state,
461        }
462    }
463}
464
465#[async_trait]
466impl ControlPlane for EngineControlPlane {
467    async fn is_cancelled(&self) -> bool {
468        self.cancel.is_cancelled()
469    }
470
471    async fn progress(&self, value: Value) -> Result<()> {
472        let events = {
473            let mut state = self.state.lock().await;
474            state.progress = value.clone();
475            state.events.clone()
476        };
477        self.progress_tx.update(value, events);
478        Ok(())
479    }
480
481    async fn log_event(&self, fields: Value) -> Result<()> {
482        let progress = {
483            let mut state = self.state.lock().await;
484            state.events = fields.clone();
485            state.progress.clone()
486        };
487        self.progress_tx.update(progress, fields);
488        Ok(())
489    }
490}
491
492pub enum HeartbeatLoopResult {
493    Completed,
494    Cancelled,
495    LostLease(anyhow::Error),
496}
497
498#[async_trait]
499pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
500    async fn post_heartbeat(
501        &self,
502        task_id: Uuid,
503        body: &crate::dms::types::HeartbeatRequest,
504    ) -> Result<crate::dms::types::HeartbeatResponse>;
505}
506
507#[async_trait]
508impl HeartbeatTransport for DmsClient {
509    async fn post_heartbeat(
510        &self,
511        task_id: Uuid,
512        body: &crate::dms::types::HeartbeatRequest,
513    ) -> Result<crate::dms::types::HeartbeatResponse> {
514        self.heartbeat(task_id, body).await
515    }
516}
517
518pub struct HeartbeatDriverArgs {
519    pub session: SessionManager,
520    pub policy: HeartbeatPolicy,
521    pub rng: StdRng,
522    pub progress_rx: ProgressReceiver,
523    pub state: Arc<Mutex<ControlState>>,
524    pub token_ref: crate::storage::TokenRef,
525    pub runner_cancel: CancellationToken,
526    pub shutdown: CancellationToken,
527    pub task_id: Uuid,
528}
529
530pub struct HeartbeatDriver<T>
531where
532    T: HeartbeatTransport,
533{
534    transport: T,
535    session: SessionManager,
536    policy: HeartbeatPolicy,
537    rng: StdRng,
538    progress_rx: ProgressReceiver,
539    state: Arc<Mutex<ControlState>>,
540    token_ref: crate::storage::TokenRef,
541    runner_cancel: CancellationToken,
542    shutdown: CancellationToken,
543    task_id: Uuid,
544    last_progress: Value,
545}
546
547impl<T> HeartbeatDriver<T>
548where
549    T: HeartbeatTransport,
550{
551    pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
552        Self {
553            transport,
554            session: args.session,
555            policy: args.policy,
556            rng: args.rng,
557            progress_rx: args.progress_rx,
558            state: args.state,
559            token_ref: args.token_ref,
560            runner_cancel: args.runner_cancel,
561            shutdown: args.shutdown,
562            task_id: args.task_id,
563            last_progress: Value::default(),
564        }
565    }
566
567    pub async fn run(mut self) -> HeartbeatLoopResult {
568        loop {
569            if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
570                return HeartbeatLoopResult::Completed;
571            }
572
573            let snapshot = match self.session.snapshot().await {
574                Some(s) => s,
575                None => return HeartbeatLoopResult::Completed,
576            };
577
578            let ttl_delay = snapshot
579                .next_heartbeat_due()
580                .map(|due| due.saturating_duration_since(Instant::now()));
581
582            if let Some(delay) = ttl_delay {
583                tokio::select! {
584                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
585                    progress = self.progress_rx.recv() => {
586                        if let Some(data) = progress {
587                            if let Some(outcome) = self.handle_progress(data).await {
588                                return outcome;
589                            }
590                        } else {
591                            return HeartbeatLoopResult::Completed;
592                        }
593                    }
594                    _ = tokio::time::sleep(delay) => {
595                        if let Some(outcome) = self.handle_ttl().await {
596                            return outcome;
597                        }
598                    }
599                }
600            } else {
601                tokio::select! {
602                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
603                    progress = self.progress_rx.recv() => {
604                        if let Some(data) = progress {
605                            if let Some(outcome) = self.handle_progress(data).await {
606                                return outcome;
607                            }
608                        } else {
609                            return HeartbeatLoopResult::Completed;
610                        }
611                    }
612                }
613            }
614        }
615    }
616
617    async fn handle_progress(
618        &mut self,
619        data: crate::heartbeat::HeartbeatData,
620    ) -> Option<HeartbeatLoopResult> {
621        self.last_progress = data.progress.clone();
622        self.send_and_update(data.progress, data.events).await
623    }
624
625    async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
626        let (progress, events) = self.snapshot_state().await;
627        self.send_and_update(progress, events).await
628    }
629
630    async fn snapshot_state(&self) -> (Value, Value) {
631        let state = self.state.lock().await;
632        (state.progress.clone(), state.events.clone())
633    }
634
635    async fn send_and_update(
636        &mut self,
637        progress: Value,
638        events: Value,
639    ) -> Option<HeartbeatLoopResult> {
640        let request = crate::dms::types::HeartbeatRequest {
641            progress: progress.clone(),
642            events: events.clone(),
643        };
644
645        match self.transport.post_heartbeat(self.task_id, &request).await {
646            Ok(update) => {
647                apply_heartbeat_token_update(&self.token_ref, &update);
648                if let Some(task) = &update.task {
649                    self.task_id = task.id;
650                } else if let Some(task_id) = update.task_id {
651                    self.task_id = task_id;
652                }
653                if let Err(err) = self
654                    .session
655                    .apply_heartbeat(
656                        &update,
657                        Some(progress.clone()),
658                        Instant::now(),
659                        &self.policy,
660                        &mut self.rng,
661                    )
662                    .await
663                {
664                    return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
665                }
666                if update.cancel.unwrap_or(false) {
667                    self.runner_cancel.cancel();
668                    return Some(HeartbeatLoopResult::Cancelled);
669                }
670                None
671            }
672            Err(err) => {
673                self.runner_cancel.cancel();
674                Some(HeartbeatLoopResult::LostLease(err))
675            }
676        }
677    }
678}