posemesh_compute_node/
engine.rs

1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17    dms::client::DmsClient,
18    heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19    poller::{jittered_delay_ms, PollerConfig},
20    session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23/// Registry mapping capability strings to runner instances.
24#[derive(Default)]
25pub struct RunnerRegistry {
26    runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30    /// Create an empty registry.
31    pub fn new() -> Self {
32        Self {
33            runners: HashMap::new(),
34        }
35    }
36
37    /// Register a runner by its capability. Last registration wins on duplicates.
38    pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39        let key = runner.capability().to_string();
40        self.runners.insert(key, Arc::new(runner));
41        self
42    }
43
44    /// Retrieve a runner by capability.
45    pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46        self.runners.get(capability).cloned()
47    }
48
49    /// Snapshot of registered capability strings.
50    pub fn capabilities(&self) -> Vec<String> {
51        let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52        caps.sort();
53        caps
54    }
55
56    /// Dispatch task to the appropriate runner based on `lease.task.capability`.
57    pub async fn run_for_lease(
58        &self,
59        lease: &LeaseEnvelope,
60        input: &dyn InputSource,
61        output: &dyn ArtifactSink,
62        ctrl: &dyn ControlPlane,
63    ) -> std::result::Result<(), crate::errors::ExecutorError> {
64        let cap = lease.task.capability.as_str();
65        let runner = self
66            .get(cap)
67            .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
68        let ctx = TaskCtx {
69            lease,
70            input,
71            output,
72            ctrl,
73        };
74        runner
75            .run(ctx)
76            .await
77            .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
78    }
79}
80
81/// Run the node main loop. Networking and storage are wired in later prompts.
82pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
83    let shutdown = CancellationToken::new();
84    let signal_token = shutdown.clone();
85    let signal_task = tokio::spawn(async move {
86        if tokio::signal::ctrl_c().await.is_ok() {
87            signal_token.cancel();
88        }
89    });
90
91    let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
92
93    shutdown.cancel();
94    let _ = signal_task.await;
95
96    result
97}
98
99pub async fn run_node_with_shutdown(
100    cfg: crate::config::NodeConfig,
101    runners: RunnerRegistry,
102    shutdown: CancellationToken,
103) -> Result<()> {
104    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
105    info!("DDS SIWE authentication configured; waiting for DDS registration callback");
106    let siwe_handle = siwe.start().await?;
107    info!("DDS SIWE token manager started");
108
109    let poll_cfg = PollerConfig {
110        backoff_ms_min: cfg.poll_backoff_ms_min,
111        backoff_ms_max: cfg.poll_backoff_ms_max,
112    };
113
114    loop {
115        if shutdown.is_cancelled() {
116            break;
117        }
118
119        let bearer = match siwe_handle.bearer().await {
120            Ok(token) => token,
121            Err(err) => {
122                warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
123                let delay_ms = jittered_delay_ms(poll_cfg);
124                tokio::select! {
125                    _ = shutdown.cancelled() => break,
126                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
127                }
128            }
129        };
130
131        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
132        let dms_client = match crate::dms::client::DmsClient::new(
133            cfg.dms_base_url.clone(),
134            timeout,
135            Some(bearer),
136        ) {
137            Ok(client) => client,
138            Err(err) => {
139                warn!(error = %err, "Failed to create DMS client; backing off");
140                let delay_ms = jittered_delay_ms(poll_cfg);
141                tokio::select! {
142                    _ = shutdown.cancelled() => break,
143                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
144                }
145            }
146        };
147
148        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
149            Ok(true) => {
150                // Successful task execution; immediately attempt next poll.
151                continue;
152            }
153            Ok(false) => {
154                let delay_ms = jittered_delay_ms(poll_cfg);
155                info!(delay_ms, "No lease available; backing off before next poll");
156                tokio::select! {
157                    _ = shutdown.cancelled() => break,
158                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
159                }
160            }
161            Err(err) => {
162                warn!(error = %err, "DMS cycle failed; backing off");
163                let delay_ms = jittered_delay_ms(poll_cfg);
164                tokio::select! {
165                    _ = shutdown.cancelled() => break,
166                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
167                }
168            }
169        }
170    }
171
172    siwe_handle.shutdown().await;
173    info!("Shutdown signal received; exiting run_node loop");
174
175    Ok(())
176}
177
178/// Build storage ports (input/output) for a given lease by constructing a TokenRef
179/// from the lease's access token and delegating to storage::build_ports.
180pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181    let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182    crate::storage::build_ports(lease, token)
183}
184
185/// Apply heartbeat token refresh: if HeartbeatResponse carries a new access token,
186/// swap it into the provided TokenRef so subsequent storage requests use it.
187pub fn apply_heartbeat_token_update(
188    token: &crate::storage::TokenRef,
189    hb: &crate::dms::types::HeartbeatResponse,
190) {
191    if let Some(new) = hb.access_token.clone() {
192        token.swap(new);
193    }
194}
195
196/// Run a single poll→run→complete/fail cycle using DMS client and the runner registry.
197/// This is a minimal integration used by tests; `run_node` wiring remains separate.
198pub async fn run_cycle_with_dms(
199    _cfg: &crate::config::NodeConfig,
200    dms: &DmsClient,
201    reg: &RunnerRegistry,
202) -> Result<bool> {
203    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
204    use serde_json::json;
205
206    let capabilities = reg.capabilities();
207    let capability = capabilities
208        .first()
209        .cloned()
210        .ok_or_else(|| anyhow!("no runners registered"))?;
211
212    // Lease a task from DMS
213    let mut lease = match dms.lease_by_capability(&capability).await? {
214        Some(lease) => lease,
215        None => {
216            return Ok(false);
217        }
218    };
219    if lease.access_token.is_none() {
220        tracing::warn!(
221            "Lease missing access token; storage client will fall back to legacy token flow"
222        );
223    }
224
225    // Initialise session state for heartbeats and token rotation.
226    let selector = CapabilitySelector::new(capabilities.clone());
227    let session = SessionManager::new(selector);
228    let policy = HeartbeatPolicy::default_policy();
229    let mut rng = StdRng::from_entropy();
230    let snapshot = session
231        .start_session(&lease, Instant::now(), &policy, &mut rng)
232        .await
233        .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
234    if snapshot.cancel() {
235        warn!(
236            task_id = %snapshot.task_id(),
237            "Lease already marked as cancelled; skipping execution"
238        );
239        return Ok(true);
240    }
241
242    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
243
244    let heartbeat_initial = dms
245        .heartbeat(
246            lease.task.id,
247            &HeartbeatRequest {
248                progress: json!({}),
249                events: json!({}),
250            },
251        )
252        .await?;
253    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
254    session
255        .apply_heartbeat(
256            &heartbeat_initial,
257            Some(json!({})),
258            Instant::now(),
259            &policy,
260            &mut rng,
261        )
262        .await
263        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
264    lease.access_token = heartbeat_initial.access_token.clone();
265    lease.access_token_expires_at = heartbeat_initial.access_token_expires_at;
266    lease.lease_expires_at = heartbeat_initial.lease_expires_at;
267    lease.cancel = heartbeat_initial.cancel;
268
269    let ports = crate::storage::build_ports(&lease, token_ref.clone())?;
270
271    let (progress_tx, progress_rx) = progress_channel();
272    let control_state = Arc::new(Mutex::new(ControlState::default()));
273    {
274        let mut guard = control_state.lock().await;
275        guard.progress = json!({});
276        guard.events = json!({});
277    }
278
279    let runner_cancel = CancellationToken::new();
280    let heartbeat_shutdown = CancellationToken::new();
281
282    let ctrl = EngineControlPlane::new(
283        runner_cancel.clone(),
284        progress_tx.clone(),
285        control_state.clone(),
286    );
287
288    // Trigger an immediate heartbeat once the loop starts to refresh tokens.
289    progress_tx.update(json!({}), json!({}));
290
291    let heartbeat_driver = HeartbeatDriver::new(
292        dms.clone(),
293        HeartbeatDriverArgs {
294            session: session.clone(),
295            policy,
296            rng,
297            progress_rx,
298            state: control_state.clone(),
299            token_ref: token_ref.clone(),
300            runner_cancel: runner_cancel.clone(),
301            shutdown: heartbeat_shutdown.clone(),
302            task_id: lease.task.id,
303        },
304    );
305    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
306
307    let run_res = reg
308        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl)
309        .await;
310
311    heartbeat_shutdown.cancel();
312    let heartbeat_result = match heartbeat_handle.await {
313        Ok(result) => result,
314        Err(err) => {
315            warn!(error = %err, "heartbeat loop task failed");
316            HeartbeatLoopResult::Completed
317        }
318    };
319
320    match heartbeat_result {
321        HeartbeatLoopResult::Completed => {}
322        HeartbeatLoopResult::Cancelled => {
323            info!(
324                task_id = %lease.task.id,
325                "Lease cancelled during execution; skipping completion"
326            );
327            runner_cancel.cancel();
328            return Ok(true);
329        }
330        HeartbeatLoopResult::LostLease(err) => {
331            warn!(
332                task_id = %lease.task.id,
333                error = %err,
334                "Lease lost during heartbeat; abandoning task"
335            );
336            runner_cancel.cancel();
337            return Ok(true);
338        }
339    }
340
341    let uploaded_artifacts = ports.uploaded_artifacts();
342    let artifacts_json: Vec<Value> = uploaded_artifacts
343        .iter()
344        .map(|artifact| {
345            json!({
346                "logical_path": artifact.logical_path,
347                "name": artifact.name,
348                "data_type": artifact.data_type,
349                "id": artifact.id,
350            })
351        })
352        .collect();
353    let job_info = json!({
354        "task_id": lease.task.id,
355        "job_id": lease.task.job_id,
356        "domain_id": lease.domain_id,
357        "capability": lease.task.capability,
358    });
359
360    // Complete or fail the task depending on runner outcome.
361    match run_res {
362        Ok(()) => {
363            let body = CompleteTaskRequest {
364                outputs_index: json!({ "artifacts": artifacts_json.clone() }),
365                result: json!({
366                    "job": job_info,
367                    "artifacts": artifacts_json,
368                }),
369            };
370            dms.complete(lease.task.id, &body).await?;
371        }
372        Err(err) => {
373            error!(
374                task_id = %lease.task.id,
375                job_id = ?lease.task.job_id,
376                capability = %lease.task.capability,
377                error = %err,
378                debug = ?err,
379                "Runner execution failed; reporting failure to DMS"
380            );
381            let body = FailTaskRequest {
382                reason: err.to_string(),
383                details: json!({
384                    "job": job_info,
385                    "artifacts": artifacts_json,
386                }),
387            };
388            dms.fail(lease.task.id, &body)
389                .await
390                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
391        }
392    }
393
394    Ok(true)
395}
396
397#[derive(Default)]
398pub struct ControlState {
399    progress: Value,
400    events: Value,
401}
402
403struct EngineControlPlane {
404    cancel: CancellationToken,
405    progress_tx: ProgressSender,
406    state: Arc<Mutex<ControlState>>,
407}
408
409impl EngineControlPlane {
410    pub fn new(
411        cancel: CancellationToken,
412        progress_tx: ProgressSender,
413        state: Arc<Mutex<ControlState>>,
414    ) -> Self {
415        Self {
416            cancel,
417            progress_tx,
418            state,
419        }
420    }
421}
422
423#[async_trait]
424impl ControlPlane for EngineControlPlane {
425    async fn is_cancelled(&self) -> bool {
426        self.cancel.is_cancelled()
427    }
428
429    async fn progress(&self, value: Value) -> Result<()> {
430        let events = {
431            let mut state = self.state.lock().await;
432            state.progress = value.clone();
433            state.events.clone()
434        };
435        self.progress_tx.update(value, events);
436        Ok(())
437    }
438
439    async fn log_event(&self, fields: Value) -> Result<()> {
440        let progress = {
441            let mut state = self.state.lock().await;
442            state.events = fields.clone();
443            state.progress.clone()
444        };
445        self.progress_tx.update(progress, fields);
446        Ok(())
447    }
448}
449
450pub enum HeartbeatLoopResult {
451    Completed,
452    Cancelled,
453    LostLease(anyhow::Error),
454}
455
456#[async_trait]
457pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
458    async fn post_heartbeat(
459        &self,
460        task_id: Uuid,
461        body: &crate::dms::types::HeartbeatRequest,
462    ) -> Result<crate::dms::types::HeartbeatResponse>;
463}
464
465#[async_trait]
466impl HeartbeatTransport for DmsClient {
467    async fn post_heartbeat(
468        &self,
469        task_id: Uuid,
470        body: &crate::dms::types::HeartbeatRequest,
471    ) -> Result<crate::dms::types::HeartbeatResponse> {
472        self.heartbeat(task_id, body).await
473    }
474}
475
476pub struct HeartbeatDriverArgs {
477    pub session: SessionManager,
478    pub policy: HeartbeatPolicy,
479    pub rng: StdRng,
480    pub progress_rx: ProgressReceiver,
481    pub state: Arc<Mutex<ControlState>>,
482    pub token_ref: crate::storage::TokenRef,
483    pub runner_cancel: CancellationToken,
484    pub shutdown: CancellationToken,
485    pub task_id: Uuid,
486}
487
488pub struct HeartbeatDriver<T>
489where
490    T: HeartbeatTransport,
491{
492    transport: T,
493    session: SessionManager,
494    policy: HeartbeatPolicy,
495    rng: StdRng,
496    progress_rx: ProgressReceiver,
497    state: Arc<Mutex<ControlState>>,
498    token_ref: crate::storage::TokenRef,
499    runner_cancel: CancellationToken,
500    shutdown: CancellationToken,
501    task_id: Uuid,
502    last_progress: Value,
503}
504
505impl<T> HeartbeatDriver<T>
506where
507    T: HeartbeatTransport,
508{
509    pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
510        Self {
511            transport,
512            session: args.session,
513            policy: args.policy,
514            rng: args.rng,
515            progress_rx: args.progress_rx,
516            state: args.state,
517            token_ref: args.token_ref,
518            runner_cancel: args.runner_cancel,
519            shutdown: args.shutdown,
520            task_id: args.task_id,
521            last_progress: Value::default(),
522        }
523    }
524
525    pub async fn run(mut self) -> HeartbeatLoopResult {
526        loop {
527            if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
528                return HeartbeatLoopResult::Completed;
529            }
530
531            let snapshot = match self.session.snapshot().await {
532                Some(s) => s,
533                None => return HeartbeatLoopResult::Completed,
534            };
535
536            let ttl_delay = snapshot
537                .next_heartbeat_due()
538                .map(|due| due.saturating_duration_since(Instant::now()));
539
540            if let Some(delay) = ttl_delay {
541                tokio::select! {
542                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
543                    progress = self.progress_rx.recv() => {
544                        if let Some(data) = progress {
545                            if let Some(outcome) = self.handle_progress(data).await {
546                                return outcome;
547                            }
548                        } else {
549                            return HeartbeatLoopResult::Completed;
550                        }
551                    }
552                    _ = tokio::time::sleep(delay) => {
553                        if let Some(outcome) = self.handle_ttl().await {
554                            return outcome;
555                        }
556                    }
557                }
558            } else {
559                tokio::select! {
560                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
561                    progress = self.progress_rx.recv() => {
562                        if let Some(data) = progress {
563                            if let Some(outcome) = self.handle_progress(data).await {
564                                return outcome;
565                            }
566                        } else {
567                            return HeartbeatLoopResult::Completed;
568                        }
569                    }
570                }
571            }
572        }
573    }
574
575    async fn handle_progress(
576        &mut self,
577        data: crate::heartbeat::HeartbeatData,
578    ) -> Option<HeartbeatLoopResult> {
579        self.last_progress = data.progress.clone();
580        self.send_and_update(data.progress, data.events).await
581    }
582
583    async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
584        let (progress, events) = self.snapshot_state().await;
585        self.send_and_update(progress, events).await
586    }
587
588    async fn snapshot_state(&self) -> (Value, Value) {
589        let state = self.state.lock().await;
590        (state.progress.clone(), state.events.clone())
591    }
592
593    async fn send_and_update(
594        &mut self,
595        progress: Value,
596        events: Value,
597    ) -> Option<HeartbeatLoopResult> {
598        let request = crate::dms::types::HeartbeatRequest {
599            progress: progress.clone(),
600            events: events.clone(),
601        };
602
603        match self.transport.post_heartbeat(self.task_id, &request).await {
604            Ok(lease) => {
605                apply_heartbeat_token_update(&self.token_ref, &lease);
606                if let Err(err) = self
607                    .session
608                    .apply_heartbeat(
609                        &lease,
610                        Some(progress.clone()),
611                        Instant::now(),
612                        &self.policy,
613                        &mut self.rng,
614                    )
615                    .await
616                {
617                    return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
618                }
619                if lease.cancel {
620                    self.runner_cancel.cancel();
621                    return Some(HeartbeatLoopResult::Cancelled);
622                }
623                None
624            }
625            Err(err) => {
626                self.runner_cancel.cancel();
627                Some(HeartbeatLoopResult::LostLease(err))
628            }
629        }
630    }
631}