Skip to main content

harn_vm/
waitpoints.rs

1#[cfg(test)]
2use std::cell::RefCell;
3use std::collections::{BTreeMap, BTreeSet};
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7use time::format_description::well_known::Rfc3339;
8use time::OffsetDateTime;
9
10use crate::event_log::{
11    sanitize_topic_component, AnyEventLog, EventLog, LogError, LogEvent, Topic,
12};
13#[cfg(test)]
14use tokio::sync::oneshot;
15
/// Prefix of every per-waitpoint state topic; [`waitpoint_topic`] appends the sanitized
/// waitpoint id to it.
pub const WAITPOINT_STATE_TOPIC_PREFIX: &str = "waitpoint.state.";
/// Name of the topic returned by [`waits_topic`], to which wait start and wait terminal
/// events are appended.
pub const WAITPOINT_WAITS_TOPIC: &str = "waitpoint.waits";
18
#[cfg(test)]
thread_local! {
    // Test-only, per-thread registry of one-shot senders. Entries are added by
    // `install_test_wait_signal` and fired/removed by `notify_test_wait_signal`.
    static TEST_WAIT_SIGNALS: RefCell<Vec<WaitpointTestSignal>> = const { RefCell::new(Vec::new()) };
}
23
/// Test-only registration: a one-shot sender to fire when a wait with the given id
/// reaches the given signal kind.
#[cfg(test)]
struct WaitpointTestSignal {
    /// Wait id this registration is matched against.
    wait_id: String,
    /// Which notification (started / interrupted) triggers the sender.
    kind: WaitpointTestSignalKind,
    /// Sender that receives `()` once the matching notification happens.
    tx: oneshot::Sender<()>,
}
30
/// Test-only discriminator for the notifications a test can subscribe to.
#[cfg(test)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum WaitpointTestSignalKind {
    /// Raised by `append_wait_started` after the start event has been appended.
    Started,
    /// Raised by `append_wait_terminal` after a terminal event whose status is
    /// `WaitpointWaitStatus::Interrupted` has been appended.
    Interrupted,
}
37
/// Lifecycle state of a single waitpoint.
///
/// Serialized with snake_case variant names; the default value is [`WaitpointStatus::Open`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitpointStatus {
    /// Not yet completed or cancelled.
    #[default]
    Open,
    /// Terminal: the waitpoint was completed.
    Completed,
    /// Terminal: the waitpoint was cancelled.
    Cancelled,
}
46
47impl WaitpointStatus {
48    pub fn as_str(self) -> &'static str {
49        match self {
50            Self::Open => "open",
51            Self::Completed => "completed",
52            Self::Cancelled => "cancelled",
53        }
54    }
55}
56
/// Terminal outcome of a wait over one or more waitpoints.
///
/// Serialized with snake_case variant names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitpointWaitStatus {
    /// The wait ended as completed.
    Completed,
    /// The wait ended as cancelled.
    Cancelled,
    /// The wait ended because it timed out.
    TimedOut,
    /// The wait was interrupted.
    Interrupted,
}
65
66impl WaitpointWaitStatus {
67    pub fn as_str(self) -> &'static str {
68        match self {
69            Self::Completed => "completed",
70            Self::Cancelled => "cancelled",
71            Self::TimedOut => "timed_out",
72            Self::Interrupted => "interrupted",
73        }
74    }
75
76    fn event_kind(self) -> &'static str {
77        match self {
78            Self::Completed => "waitpoint_wait_completed",
79            Self::Cancelled => "waitpoint_wait_cancelled",
80            Self::TimedOut => "waitpoint_wait_timed_out",
81            Self::Interrupted => "waitpoint_wait_interrupted",
82        }
83    }
84}
85
/// Full state of one waitpoint, as stored in the payload of every waitpoint state event.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointRecord {
    /// Waitpoint identifier; also used to derive the per-waitpoint state topic.
    pub id: String,
    /// Current lifecycle state.
    pub status: WaitpointStatus,
    /// Creation timestamp string (the constructors in this file fill it from `now_rfc3339`).
    pub created_at: String,
    /// Optional actor label of the creator. When a record is first created implicitly by
    /// `complete_waitpoint` / `cancel_waitpoint`, it is set to that call's actor.
    pub created_by: Option<String>,
    /// Set by `complete_waitpoint`; cleared by `cancel_waitpoint`.
    pub completed_at: Option<String>,
    /// Set by `complete_waitpoint`; cleared by `cancel_waitpoint`.
    pub completed_by: Option<String>,
    /// Set by `cancel_waitpoint`; cleared by `complete_waitpoint`.
    pub cancelled_at: Option<String>,
    /// Set by `cancel_waitpoint`; cleared by `complete_waitpoint`.
    pub cancelled_by: Option<String>,
    /// Optional reason supplied to `cancel_waitpoint`; cleared by `complete_waitpoint`.
    pub reason: Option<String>,
    /// Free-form metadata; deserializes to an empty map when the field is absent.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
100
101impl WaitpointRecord {
102    pub fn open(
103        id: impl Into<String>,
104        created_by: Option<String>,
105        metadata: BTreeMap<String, serde_json::Value>,
106    ) -> Self {
107        Self {
108            id: id.into(),
109            status: WaitpointStatus::Open,
110            created_at: now_rfc3339(),
111            created_by,
112            completed_at: None,
113            completed_by: None,
114            cancelled_at: None,
115            cancelled_by: None,
116            reason: None,
117            metadata,
118        }
119    }
120
121    pub fn is_terminal(&self) -> bool {
122        matches!(
123            self.status,
124            WaitpointStatus::Completed | WaitpointStatus::Cancelled
125        )
126    }
127}
128
/// Payload of the `waitpoint_wait_started` event appended by `append_wait_started`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointWaitStartRecord {
    /// Identifier of the wait; also written to the `wait_id` event header.
    pub wait_id: String,
    /// Ids of the waitpoints being waited on; joined with `,` into the `waitpoints` header.
    pub waitpoint_ids: Vec<String>,
    /// Timestamp string for when the wait began (supplied by the caller).
    pub started_at: String,
    /// Optional trace identifier supplied by the caller.
    pub trace_id: Option<String>,
    /// Optional event id supplied by the caller; presumably the event being replayed —
    /// TODO confirm against the producer.
    pub replay_of_event_id: Option<String>,
}
137
/// Payload of a terminal wait event appended by `append_wait_terminal` and read back by
/// `find_wait_terminal`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointWaitRecord {
    /// Identifier of the wait; also written to the `wait_id` event header.
    pub wait_id: String,
    /// Ids of the waitpoints that were waited on.
    pub waitpoint_ids: Vec<String>,
    /// Terminal outcome; selects the event kind used when appending.
    pub status: WaitpointWaitStatus,
    /// Timestamp string for when the wait began (supplied by the caller).
    pub started_at: String,
    /// Timestamp string for when the wait reached its terminal state (supplied by the caller).
    pub resolved_at: String,
    /// Waitpoint records attached by the caller; presumably a snapshot taken at
    /// resolution time — TODO confirm against the producer.
    pub waitpoints: Vec<WaitpointRecord>,
    /// Optional waitpoint id; presumably the one that caused a cancelled outcome —
    /// TODO confirm against the producer.
    pub cancelled_waitpoint_id: Option<String>,
    /// Optional trace identifier supplied by the caller.
    pub trace_id: Option<String>,
    /// Optional event id supplied by the caller; presumably the event being replayed —
    /// TODO confirm against the producer.
    pub replay_of_event_id: Option<String>,
    /// Optional free-form explanation of the outcome.
    pub reason: Option<String>,
}
151
/// Aggregate result computed by `resolve_waitpoints` for a set of requested waitpoint ids.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WaitpointResolution {
    /// No ids were requested, or at least one requested waitpoint is missing or still open
    /// (and none examined so far was cancelled).
    Pending,
    /// Every requested waitpoint is present and completed.
    Completed,
    /// A requested waitpoint is cancelled; carries the id of the first such waitpoint in
    /// request order.
    Cancelled { waitpoint_id: String },
}
158
/// Normalizes a list of waitpoint ids.
///
/// Each id is trimmed of surrounding whitespace; ids that are empty after trimming are
/// dropped, and only the first occurrence of each trimmed id is kept. The relative order
/// of first occurrences is preserved.
pub fn dedupe_waitpoint_ids(ids: &[String]) -> Vec<String> {
    // Track borrowed slices so each surviving id is allocated exactly once (on output).
    let mut seen: BTreeSet<&str> = BTreeSet::new();
    ids.iter()
        .map(|id| id.trim())
        .filter(|id| !id.is_empty() && seen.insert(*id))
        .map(str::to_owned)
        .collect()
}
173
174pub fn waitpoint_topic(id: &str) -> Result<Topic, LogError> {
175    Topic::new(format!(
176        "{WAITPOINT_STATE_TOPIC_PREFIX}{}",
177        sanitize_topic_component(id)
178    ))
179}
180
/// Returns the topic that holds all wait start/terminal events.
///
/// # Errors
/// Returns whatever error `Topic::new` reports for [`WAITPOINT_WAITS_TOPIC`].
pub fn waits_topic() -> Result<Topic, LogError> {
    Topic::new(WAITPOINT_WAITS_TOPIC)
}
184
185pub async fn load_waitpoint(
186    log: &Arc<AnyEventLog>,
187    id: &str,
188) -> Result<Option<WaitpointRecord>, LogError> {
189    let events = log
190        .read_range(&waitpoint_topic(id)?, None, usize::MAX)
191        .await?;
192    let mut latest = None;
193    for (_, event) in events {
194        if !matches!(
195            event.kind.as_str(),
196            "waitpoint_created" | "waitpoint_completed" | "waitpoint_cancelled"
197        ) {
198            continue;
199        }
200        let Ok(record) = serde_json::from_value::<WaitpointRecord>(event.payload) else {
201            continue;
202        };
203        latest = Some(record);
204    }
205    Ok(latest)
206}
207
208pub async fn load_waitpoints(
209    log: &Arc<AnyEventLog>,
210    ids: &[String],
211) -> Result<Vec<WaitpointRecord>, LogError> {
212    let mut out = Vec::new();
213    for id in dedupe_waitpoint_ids(ids) {
214        if let Some(record) = load_waitpoint(log, &id).await? {
215            out.push(record);
216        }
217    }
218    Ok(out)
219}
220
221pub fn resolve_waitpoints(ids: &[String], waitpoints: &[WaitpointRecord]) -> WaitpointResolution {
222    let mut by_id = BTreeMap::new();
223    for waitpoint in waitpoints {
224        by_id.insert(waitpoint.id.as_str(), waitpoint);
225    }
226    let ids = dedupe_waitpoint_ids(ids);
227    if ids.is_empty() {
228        return WaitpointResolution::Pending;
229    }
230
231    let mut all_completed = true;
232    for id in ids {
233        let Some(waitpoint) = by_id.get(id.as_str()) else {
234            all_completed = false;
235            continue;
236        };
237        match waitpoint.status {
238            WaitpointStatus::Completed => {}
239            WaitpointStatus::Cancelled => {
240                return WaitpointResolution::Cancelled {
241                    waitpoint_id: waitpoint.id.clone(),
242                };
243            }
244            WaitpointStatus::Open => {
245                all_completed = false;
246            }
247        }
248    }
249
250    if all_completed {
251        WaitpointResolution::Completed
252    } else {
253        WaitpointResolution::Pending
254    }
255}
256
257pub async fn create_waitpoint(
258    log: &Arc<AnyEventLog>,
259    id: &str,
260    created_by: Option<String>,
261    metadata: BTreeMap<String, serde_json::Value>,
262) -> Result<WaitpointRecord, LogError> {
263    if let Some(existing) = load_waitpoint(log, id).await? {
264        return Ok(existing);
265    }
266    let record = WaitpointRecord::open(id, created_by, metadata);
267    append_waitpoint_state(log, "waitpoint_created", &record).await?;
268    Ok(record)
269}
270
271pub async fn complete_waitpoint(
272    log: &Arc<AnyEventLog>,
273    id: &str,
274    completed_by: Option<String>,
275) -> Result<WaitpointRecord, LogError> {
276    let existing = load_waitpoint(log, id).await?;
277    if let Some(existing) = existing.as_ref() {
278        if existing.is_terminal() {
279            return Ok(existing.clone());
280        }
281    }
282
283    let now = now_rfc3339();
284    let mut record = existing.unwrap_or_else(|| WaitpointRecord {
285        id: id.to_string(),
286        status: WaitpointStatus::Open,
287        created_at: now.clone(),
288        created_by: completed_by.clone(),
289        completed_at: None,
290        completed_by: None,
291        cancelled_at: None,
292        cancelled_by: None,
293        reason: None,
294        metadata: BTreeMap::new(),
295    });
296    record.status = WaitpointStatus::Completed;
297    record.completed_at = Some(now);
298    record.completed_by = completed_by;
299    record.cancelled_at = None;
300    record.cancelled_by = None;
301    record.reason = None;
302    append_waitpoint_state(log, "waitpoint_completed", &record).await?;
303    Ok(record)
304}
305
306pub async fn cancel_waitpoint(
307    log: &Arc<AnyEventLog>,
308    id: &str,
309    cancelled_by: Option<String>,
310    reason: Option<String>,
311) -> Result<WaitpointRecord, LogError> {
312    let existing = load_waitpoint(log, id).await?;
313    if let Some(existing) = existing.as_ref() {
314        if existing.is_terminal() {
315            return Ok(existing.clone());
316        }
317    }
318
319    let now = now_rfc3339();
320    let mut record = existing.unwrap_or_else(|| WaitpointRecord {
321        id: id.to_string(),
322        status: WaitpointStatus::Open,
323        created_at: now.clone(),
324        created_by: cancelled_by.clone(),
325        completed_at: None,
326        completed_by: None,
327        cancelled_at: None,
328        cancelled_by: None,
329        reason: None,
330        metadata: BTreeMap::new(),
331    });
332    record.status = WaitpointStatus::Cancelled;
333    record.completed_at = None;
334    record.completed_by = None;
335    record.cancelled_at = Some(now);
336    record.cancelled_by = cancelled_by;
337    record.reason = reason;
338    append_waitpoint_state(log, "waitpoint_cancelled", &record).await?;
339    Ok(record)
340}
341
342pub async fn append_wait_started(
343    log: &Arc<AnyEventLog>,
344    record: &WaitpointWaitStartRecord,
345) -> Result<(), LogError> {
346    log.append(
347        &waits_topic()?,
348        LogEvent::new(
349            "waitpoint_wait_started",
350            serde_json::to_value(record).map_err(|error| {
351                LogError::Serde(format!("waitpoint wait encode error: {error}"))
352            })?,
353        )
354        .with_headers(wait_headers(&record.wait_id, &record.waitpoint_ids)),
355    )
356    .await
357    .map(|_| ())?;
358    notify_test_wait_started(&record.wait_id);
359    Ok(())
360}
361
362pub async fn append_wait_terminal(
363    log: &Arc<AnyEventLog>,
364    record: &WaitpointWaitRecord,
365) -> Result<(), LogError> {
366    log.append(
367        &waits_topic()?,
368        LogEvent::new(
369            record.status.event_kind(),
370            serde_json::to_value(record).map_err(|error| {
371                LogError::Serde(format!("waitpoint wait encode error: {error}"))
372            })?,
373        )
374        .with_headers(wait_headers(&record.wait_id, &record.waitpoint_ids)),
375    )
376    .await
377    .map(|_| ())?;
378    if record.status == WaitpointWaitStatus::Interrupted {
379        notify_test_wait_interrupted(&record.wait_id);
380    }
381    Ok(())
382}
383
/// Test-only: registers `tx` on the current thread so that it is fired the next time a
/// notification of `kind` is raised for `wait_id`.
#[cfg(test)]
pub(crate) fn install_test_wait_signal(
    wait_id: impl Into<String>,
    kind: WaitpointTestSignalKind,
    tx: oneshot::Sender<()>,
) {
    let signal = WaitpointTestSignal {
        wait_id: wait_id.into(),
        kind,
        tx,
    };
    TEST_WAIT_SIGNALS.with(|slot| slot.borrow_mut().push(signal));
}
398
/// Test-only: drops every signal registered on the current thread without firing it.
#[cfg(test)]
pub(crate) fn clear_test_wait_signals() {
    TEST_WAIT_SIGNALS.with(|slot| {
        let mut registry = slot.borrow_mut();
        registry.clear();
    });
}
403
/// Hook invoked after a wait-start event has been appended.
///
/// In test builds it fires the registered "started" signals for `wait_id`; in all other
/// builds it does nothing.
fn notify_test_wait_started(wait_id: &str) {
    #[cfg(test)]
    notify_test_wait_signal(wait_id, WaitpointTestSignalKind::Started);
    // Outside of tests the argument is intentionally unused.
    #[cfg(not(test))]
    let _ = wait_id;
}
411
/// Hook invoked after an interrupted terminal wait event has been appended.
///
/// In test builds it fires the registered "interrupted" signals for `wait_id`; in all
/// other builds it does nothing.
fn notify_test_wait_interrupted(wait_id: &str) {
    #[cfg(test)]
    notify_test_wait_signal(wait_id, WaitpointTestSignalKind::Interrupted);
    // Outside of tests the argument is intentionally unused.
    #[cfg(not(test))]
    let _ = wait_id;
}
419
/// Test-only: fires and removes every signal registered on the current thread for
/// `wait_id` and `kind`, leaving all other registrations in place and in order.
///
/// A send failure (receiver already dropped) is ignored.
#[cfg(test)]
fn notify_test_wait_signal(wait_id: &str, kind: WaitpointTestSignalKind) {
    TEST_WAIT_SIGNALS.with(|slot| {
        let mut registry = slot.borrow_mut();
        // Sending consumes the sender, so move everything out and push back the
        // registrations that do not match.
        let pending = std::mem::take(&mut *registry);
        for signal in pending {
            if signal.wait_id == wait_id && signal.kind == kind {
                let _ = signal.tx.send(());
            } else {
                registry.push(signal);
            }
        }
    });
}
435
436pub async fn find_wait_terminal(
437    log: &Arc<AnyEventLog>,
438    wait_id: &str,
439) -> Result<Option<WaitpointWaitRecord>, LogError> {
440    let events = log.read_range(&waits_topic()?, None, usize::MAX).await?;
441    let mut latest = None;
442    for (_, event) in events {
443        if !matches!(
444            event.kind.as_str(),
445            "waitpoint_wait_completed"
446                | "waitpoint_wait_cancelled"
447                | "waitpoint_wait_timed_out"
448                | "waitpoint_wait_interrupted"
449        ) {
450            continue;
451        }
452        if event.headers.get("wait_id").map(String::as_str) != Some(wait_id) {
453            continue;
454        }
455        let Ok(record) = serde_json::from_value::<WaitpointWaitRecord>(event.payload) else {
456            continue;
457        };
458        latest = Some(record);
459    }
460    Ok(latest)
461}
462
463async fn append_waitpoint_state(
464    log: &Arc<AnyEventLog>,
465    kind: &str,
466    record: &WaitpointRecord,
467) -> Result<(), LogError> {
468    log.append(
469        &waitpoint_topic(&record.id)?,
470        LogEvent::new(
471            kind,
472            serde_json::to_value(record)
473                .map_err(|error| LogError::Serde(format!("waitpoint encode error: {error}")))?,
474        )
475        .with_headers(waitpoint_headers(record)),
476    )
477    .await
478    .map(|_| ())
479}
480
/// Builds the headers attached to wait events: `wait_id` holds the wait id and
/// `waitpoints` holds the waitpoint ids joined with `,`.
fn wait_headers(wait_id: &str, waitpoint_ids: &[String]) -> BTreeMap<String, String> {
    BTreeMap::from([
        ("wait_id".to_string(), wait_id.to_string()),
        ("waitpoints".to_string(), waitpoint_ids.join(",")),
    ])
}
487
488fn waitpoint_headers(record: &WaitpointRecord) -> BTreeMap<String, String> {
489    let mut headers = BTreeMap::new();
490    headers.insert("waitpoint_id".to_string(), record.id.clone());
491    headers.insert("status".to_string(), record.status.as_str().to_string());
492    if let Some(created_by) = record.created_by.as_ref() {
493        headers.insert("created_by".to_string(), created_by.clone());
494    }
495    if let Some(completed_by) = record.completed_by.as_ref() {
496        headers.insert("completed_by".to_string(), completed_by.clone());
497    }
498    if let Some(cancelled_by) = record.cancelled_by.as_ref() {
499        headers.insert("cancelled_by".to_string(), cancelled_by.clone());
500    }
501    headers
502}
503
504fn now_rfc3339() -> String {
505    OffsetDateTime::now_utc()
506        .format(&Rfc3339)
507        .unwrap_or_else(|_| OffsetDateTime::now_utc().to_string())
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::{FileEventLog, MemoryEventLog};

    /// Waitpoint state written through one file-backed log must be readable after the
    /// log is reopened from the same directory.
    #[tokio::test]
    async fn waitpoint_state_persists_across_file_reopen() {
        let dir = tempfile::tempdir().expect("tempdir");
        let first = Arc::new(AnyEventLog::File(
            FileEventLog::open(dir.path().to_path_buf(), 32).expect("open file log"),
        ));
        // Create, then complete, so the latest state event is `waitpoint_completed`.
        create_waitpoint(&first, "demo", Some("creator".to_string()), BTreeMap::new())
            .await
            .expect("create waitpoint");
        complete_waitpoint(&first, "demo", Some("completer".to_string()))
            .await
            .expect("complete waitpoint");

        // Open a second log instance over the same directory and read the state back.
        let reopened = Arc::new(AnyEventLog::File(
            FileEventLog::open(dir.path().to_path_buf(), 32).expect("reopen file log"),
        ));
        let state = load_waitpoint(&reopened, "demo")
            .await
            .expect("load state")
            .expect("waitpoint exists");
        assert_eq!(state.status, WaitpointStatus::Completed);
        assert_eq!(state.completed_by.as_deref(), Some("completer"));
    }

    /// `find_wait_terminal` must ignore the start event and return the terminal record
    /// appended for the same wait id.
    #[tokio::test]
    async fn wait_terminal_lookup_returns_latest_terminal_record() {
        let log = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(32)));
        // Non-terminal event for the wait; the lookup must skip it.
        append_wait_started(
            &log,
            &WaitpointWaitStartRecord {
                wait_id: "wait-demo".to_string(),
                waitpoint_ids: vec!["a".to_string(), "b".to_string()],
                started_at: "2026-01-01T00:00:00Z".to_string(),
                trace_id: Some("trace-demo".to_string()),
                replay_of_event_id: None,
            },
        )
        .await
        .expect("append wait start");
        // Terminal event (timed out) for the same wait id.
        append_wait_terminal(
            &log,
            &WaitpointWaitRecord {
                wait_id: "wait-demo".to_string(),
                waitpoint_ids: vec!["a".to_string(), "b".to_string()],
                status: WaitpointWaitStatus::TimedOut,
                started_at: "2026-01-01T00:00:00Z".to_string(),
                resolved_at: "2026-01-01T00:01:00Z".to_string(),
                waitpoints: Vec::new(),
                cancelled_waitpoint_id: None,
                trace_id: Some("trace-demo".to_string()),
                replay_of_event_id: None,
                reason: Some("deadline elapsed".to_string()),
            },
        )
        .await
        .expect("append wait result");

        let record = find_wait_terminal(&log, "wait-demo")
            .await
            .expect("lookup wait result")
            .expect("wait result exists");
        assert_eq!(record.status, WaitpointWaitStatus::TimedOut);
        assert_eq!(record.reason.as_deref(), Some("deadline elapsed"));
    }
}