Skip to main content

harn_vm/
waitpoints.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5use time::format_description::well_known::Rfc3339;
6use time::OffsetDateTime;
7
8use crate::event_log::{
9    sanitize_topic_component, AnyEventLog, EventLog, LogError, LogEvent, Topic,
10};
11
/// Prefix of each per-waitpoint state topic; `waitpoint_topic` appends the
/// sanitized waitpoint id to it.
pub const WAITPOINT_STATE_TOPIC_PREFIX: &str = "waitpoint.state.";
/// Name of the single topic shared by all wait lifecycle events
/// (`waitpoint_wait_started` plus the terminal wait events).
pub const WAITPOINT_WAITS_TOPIC: &str = "waitpoint.waits";
14
/// Lifecycle state of a single waitpoint.
///
/// Waitpoints are created `Open`. `complete_waitpoint` and `cancel_waitpoint`
/// return an already-terminal record unchanged, so the first terminal state a
/// waitpoint reaches is the one it keeps. Serialized in snake_case
/// (`"open"`, `"completed"`, `"cancelled"`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitpointStatus {
    /// Not yet resolved; this is also the `Default` value.
    #[default]
    Open,
    /// Resolved successfully (terminal).
    Completed,
    /// Resolved by cancellation (terminal).
    Cancelled,
}
23
24impl WaitpointStatus {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::Open => "open",
28            Self::Completed => "completed",
29            Self::Cancelled => "cancelled",
30        }
31    }
32}
33
/// Final outcome of a wait over one or more waitpoints, stored in
/// `WaitpointWaitRecord::status`.
///
/// Each variant maps to its own terminal event kind on the waits topic.
/// Serialized in snake_case (e.g. `TimedOut` is `"timed_out"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitpointWaitStatus {
    /// The awaited waitpoints completed.
    Completed,
    /// The wait ended because a waitpoint was cancelled.
    Cancelled,
    /// NOTE(review): presumably the wait's deadline elapsed — set by the caller.
    TimedOut,
    /// NOTE(review): trigger is defined by the caller; not interpreted here.
    Interrupted,
}
42
43impl WaitpointWaitStatus {
44    pub fn as_str(self) -> &'static str {
45        match self {
46            Self::Completed => "completed",
47            Self::Cancelled => "cancelled",
48            Self::TimedOut => "timed_out",
49            Self::Interrupted => "interrupted",
50        }
51    }
52
53    fn event_kind(self) -> &'static str {
54        match self {
55            Self::Completed => "waitpoint_wait_completed",
56            Self::Cancelled => "waitpoint_wait_cancelled",
57            Self::TimedOut => "waitpoint_wait_timed_out",
58            Self::Interrupted => "waitpoint_wait_interrupted",
59        }
60    }
61}
62
/// Persisted state of a single waitpoint.
///
/// Serialized as the JSON payload of the `waitpoint_created`,
/// `waitpoint_completed` and `waitpoint_cancelled` events on the waitpoint's
/// state topic. `load_waitpoint` treats the last decodable payload as the
/// current state.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointRecord {
    /// Caller-supplied identifier; the state topic is derived from it.
    pub id: String,
    /// Current lifecycle state.
    pub status: WaitpointStatus,
    /// Creation timestamp produced by `now_rfc3339` when written by this module.
    pub created_at: String,
    /// Creating actor. If a waitpoint is completed or cancelled before it was
    /// created, this module records the completing/cancelling actor here.
    pub created_by: Option<String>,
    /// Set by `complete_waitpoint`; cleared by `cancel_waitpoint`.
    pub completed_at: Option<String>,
    /// Set by `complete_waitpoint`; cleared by `cancel_waitpoint`.
    pub completed_by: Option<String>,
    /// Set by `cancel_waitpoint`; cleared by `complete_waitpoint`.
    pub cancelled_at: Option<String>,
    /// Set by `cancel_waitpoint`; cleared by `complete_waitpoint`.
    pub cancelled_by: Option<String>,
    /// Cancellation reason; cleared when the waitpoint is completed.
    pub reason: Option<String>,
    /// Free-form caller metadata; defaults to empty when absent from a payload.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
77
78impl WaitpointRecord {
79    pub fn open(
80        id: impl Into<String>,
81        created_by: Option<String>,
82        metadata: BTreeMap<String, serde_json::Value>,
83    ) -> Self {
84        Self {
85            id: id.into(),
86            status: WaitpointStatus::Open,
87            created_at: now_rfc3339(),
88            created_by,
89            completed_at: None,
90            completed_by: None,
91            cancelled_at: None,
92            cancelled_by: None,
93            reason: None,
94            metadata,
95        }
96    }
97
98    pub fn is_terminal(&self) -> bool {
99        matches!(
100            self.status,
101            WaitpointStatus::Completed | WaitpointStatus::Cancelled
102        )
103    }
104}
105
/// Payload of the `waitpoint_wait_started` event that `append_wait_started`
/// writes to the shared waits topic.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointWaitStartRecord {
    /// Wait identifier; also copied into the event's `wait_id` header.
    pub wait_id: String,
    /// Waitpoints being waited on; joined with `,` into the `waitpoints` header.
    pub waitpoint_ids: Vec<String>,
    /// Caller-supplied start timestamp (presumably RFC 3339 — not validated here).
    pub started_at: String,
    /// Optional trace correlation id; not interpreted by this module.
    pub trace_id: Option<String>,
    /// NOTE(review): presumably the id of the event this wait replays;
    /// not interpreted by this module.
    pub replay_of_event_id: Option<String>,
}
114
/// Payload of a terminal wait event on the waits topic.
///
/// `append_wait_terminal` picks the event kind from `status`.
/// `find_wait_terminal` returns the last such record for a given `wait_id`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WaitpointWaitRecord {
    /// Wait identifier; also copied into the event's `wait_id` header.
    pub wait_id: String,
    /// Waitpoints that were waited on; joined into the `waitpoints` header.
    pub waitpoint_ids: Vec<String>,
    /// Outcome of the wait; also selects the event kind.
    pub status: WaitpointWaitStatus,
    /// Caller-supplied start timestamp.
    pub started_at: String,
    /// Caller-supplied resolution timestamp.
    pub resolved_at: String,
    /// Waitpoint states supplied by the caller — presumably a snapshot taken
    /// when the wait resolved; TODO confirm.
    pub waitpoints: Vec<WaitpointRecord>,
    /// NOTE(review): presumably set when `status` is `Cancelled`, naming the
    /// waitpoint whose cancellation ended the wait.
    pub cancelled_waitpoint_id: Option<String>,
    /// Optional trace correlation id; not interpreted by this module.
    pub trace_id: Option<String>,
    /// NOTE(review): presumably the id of the event this wait replays.
    pub replay_of_event_id: Option<String>,
    /// Optional explanation of the outcome.
    pub reason: Option<String>,
}
128
/// Aggregate outcome computed by `resolve_waitpoints` for a set of ids.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WaitpointResolution {
    /// No requested waitpoint is cancelled, and at least one is still open or
    /// has no known record. This is also the result when no ids are requested.
    Pending,
    /// Every requested waitpoint is completed.
    Completed,
    /// A requested waitpoint is cancelled. Carries the first such id in the
    /// deduplicated request order.
    Cancelled { waitpoint_id: String },
}
135
/// Normalizes a list of waitpoint ids.
///
/// Each id is trimmed of surrounding whitespace. Ids that are empty after
/// trimming are dropped. Duplicates are removed, keeping the first occurrence
/// of each id in its original position.
pub fn dedupe_waitpoint_ids(ids: &[String]) -> Vec<String> {
    // The membership set borrows the trimmed slices, so each kept id is
    // allocated exactly once (for the output).
    let mut seen: BTreeSet<&str> = BTreeSet::new();
    let mut out = Vec::with_capacity(ids.len());
    for id in ids {
        let trimmed = id.trim();
        if !trimmed.is_empty() && seen.insert(trimmed) {
            out.push(trimmed.to_owned());
        }
    }
    out
}
150
151pub fn waitpoint_topic(id: &str) -> Result<Topic, LogError> {
152    Topic::new(format!(
153        "{WAITPOINT_STATE_TOPIC_PREFIX}{}",
154        sanitize_topic_component(id)
155    ))
156}
157
158pub fn waits_topic() -> Result<Topic, LogError> {
159    Topic::new(WAITPOINT_WAITS_TOPIC)
160}
161
162pub async fn load_waitpoint(
163    log: &Arc<AnyEventLog>,
164    id: &str,
165) -> Result<Option<WaitpointRecord>, LogError> {
166    let events = log
167        .read_range(&waitpoint_topic(id)?, None, usize::MAX)
168        .await?;
169    let mut latest = None;
170    for (_, event) in events {
171        if !matches!(
172            event.kind.as_str(),
173            "waitpoint_created" | "waitpoint_completed" | "waitpoint_cancelled"
174        ) {
175            continue;
176        }
177        let Ok(record) = serde_json::from_value::<WaitpointRecord>(event.payload) else {
178            continue;
179        };
180        latest = Some(record);
181    }
182    Ok(latest)
183}
184
185pub async fn load_waitpoints(
186    log: &Arc<AnyEventLog>,
187    ids: &[String],
188) -> Result<Vec<WaitpointRecord>, LogError> {
189    let mut out = Vec::new();
190    for id in dedupe_waitpoint_ids(ids) {
191        if let Some(record) = load_waitpoint(log, &id).await? {
192            out.push(record);
193        }
194    }
195    Ok(out)
196}
197
198pub fn resolve_waitpoints(ids: &[String], waitpoints: &[WaitpointRecord]) -> WaitpointResolution {
199    let mut by_id = BTreeMap::new();
200    for waitpoint in waitpoints {
201        by_id.insert(waitpoint.id.as_str(), waitpoint);
202    }
203    let ids = dedupe_waitpoint_ids(ids);
204    if ids.is_empty() {
205        return WaitpointResolution::Pending;
206    }
207
208    let mut all_completed = true;
209    for id in ids {
210        let Some(waitpoint) = by_id.get(id.as_str()) else {
211            all_completed = false;
212            continue;
213        };
214        match waitpoint.status {
215            WaitpointStatus::Completed => {}
216            WaitpointStatus::Cancelled => {
217                return WaitpointResolution::Cancelled {
218                    waitpoint_id: waitpoint.id.clone(),
219                };
220            }
221            WaitpointStatus::Open => {
222                all_completed = false;
223            }
224        }
225    }
226
227    if all_completed {
228        WaitpointResolution::Completed
229    } else {
230        WaitpointResolution::Pending
231    }
232}
233
234pub async fn create_waitpoint(
235    log: &Arc<AnyEventLog>,
236    id: &str,
237    created_by: Option<String>,
238    metadata: BTreeMap<String, serde_json::Value>,
239) -> Result<WaitpointRecord, LogError> {
240    if let Some(existing) = load_waitpoint(log, id).await? {
241        return Ok(existing);
242    }
243    let record = WaitpointRecord::open(id, created_by, metadata);
244    append_waitpoint_state(log, "waitpoint_created", &record).await?;
245    Ok(record)
246}
247
248pub async fn complete_waitpoint(
249    log: &Arc<AnyEventLog>,
250    id: &str,
251    completed_by: Option<String>,
252) -> Result<WaitpointRecord, LogError> {
253    let existing = load_waitpoint(log, id).await?;
254    if let Some(existing) = existing.as_ref() {
255        if existing.is_terminal() {
256            return Ok(existing.clone());
257        }
258    }
259
260    let now = now_rfc3339();
261    let mut record = existing.unwrap_or_else(|| WaitpointRecord {
262        id: id.to_string(),
263        status: WaitpointStatus::Open,
264        created_at: now.clone(),
265        created_by: completed_by.clone(),
266        completed_at: None,
267        completed_by: None,
268        cancelled_at: None,
269        cancelled_by: None,
270        reason: None,
271        metadata: BTreeMap::new(),
272    });
273    record.status = WaitpointStatus::Completed;
274    record.completed_at = Some(now);
275    record.completed_by = completed_by;
276    record.cancelled_at = None;
277    record.cancelled_by = None;
278    record.reason = None;
279    append_waitpoint_state(log, "waitpoint_completed", &record).await?;
280    Ok(record)
281}
282
283pub async fn cancel_waitpoint(
284    log: &Arc<AnyEventLog>,
285    id: &str,
286    cancelled_by: Option<String>,
287    reason: Option<String>,
288) -> Result<WaitpointRecord, LogError> {
289    let existing = load_waitpoint(log, id).await?;
290    if let Some(existing) = existing.as_ref() {
291        if existing.is_terminal() {
292            return Ok(existing.clone());
293        }
294    }
295
296    let now = now_rfc3339();
297    let mut record = existing.unwrap_or_else(|| WaitpointRecord {
298        id: id.to_string(),
299        status: WaitpointStatus::Open,
300        created_at: now.clone(),
301        created_by: cancelled_by.clone(),
302        completed_at: None,
303        completed_by: None,
304        cancelled_at: None,
305        cancelled_by: None,
306        reason: None,
307        metadata: BTreeMap::new(),
308    });
309    record.status = WaitpointStatus::Cancelled;
310    record.completed_at = None;
311    record.completed_by = None;
312    record.cancelled_at = Some(now);
313    record.cancelled_by = cancelled_by;
314    record.reason = reason;
315    append_waitpoint_state(log, "waitpoint_cancelled", &record).await?;
316    Ok(record)
317}
318
319pub async fn append_wait_started(
320    log: &Arc<AnyEventLog>,
321    record: &WaitpointWaitStartRecord,
322) -> Result<(), LogError> {
323    log.append(
324        &waits_topic()?,
325        LogEvent::new(
326            "waitpoint_wait_started",
327            serde_json::to_value(record).map_err(|error| {
328                LogError::Serde(format!("waitpoint wait encode error: {error}"))
329            })?,
330        )
331        .with_headers(wait_headers(&record.wait_id, &record.waitpoint_ids)),
332    )
333    .await
334    .map(|_| ())
335}
336
337pub async fn append_wait_terminal(
338    log: &Arc<AnyEventLog>,
339    record: &WaitpointWaitRecord,
340) -> Result<(), LogError> {
341    log.append(
342        &waits_topic()?,
343        LogEvent::new(
344            record.status.event_kind(),
345            serde_json::to_value(record).map_err(|error| {
346                LogError::Serde(format!("waitpoint wait encode error: {error}"))
347            })?,
348        )
349        .with_headers(wait_headers(&record.wait_id, &record.waitpoint_ids)),
350    )
351    .await
352    .map(|_| ())
353}
354
355pub async fn find_wait_terminal(
356    log: &Arc<AnyEventLog>,
357    wait_id: &str,
358) -> Result<Option<WaitpointWaitRecord>, LogError> {
359    let events = log.read_range(&waits_topic()?, None, usize::MAX).await?;
360    let mut latest = None;
361    for (_, event) in events {
362        if !matches!(
363            event.kind.as_str(),
364            "waitpoint_wait_completed"
365                | "waitpoint_wait_cancelled"
366                | "waitpoint_wait_timed_out"
367                | "waitpoint_wait_interrupted"
368        ) {
369            continue;
370        }
371        if event.headers.get("wait_id").map(String::as_str) != Some(wait_id) {
372            continue;
373        }
374        let Ok(record) = serde_json::from_value::<WaitpointWaitRecord>(event.payload) else {
375            continue;
376        };
377        latest = Some(record);
378    }
379    Ok(latest)
380}
381
382async fn append_waitpoint_state(
383    log: &Arc<AnyEventLog>,
384    kind: &str,
385    record: &WaitpointRecord,
386) -> Result<(), LogError> {
387    log.append(
388        &waitpoint_topic(&record.id)?,
389        LogEvent::new(
390            kind,
391            serde_json::to_value(record)
392                .map_err(|error| LogError::Serde(format!("waitpoint encode error: {error}")))?,
393        )
394        .with_headers(waitpoint_headers(record)),
395    )
396    .await
397    .map(|_| ())
398}
399
/// Builds the headers attached to wait events.
///
/// `wait_id` is copied verbatim. `waitpoints` is the id list joined with
/// `,`, which is an empty string for an empty list.
fn wait_headers(wait_id: &str, waitpoint_ids: &[String]) -> BTreeMap<String, String> {
    BTreeMap::from([
        ("wait_id".to_string(), wait_id.to_string()),
        ("waitpoints".to_string(), waitpoint_ids.join(",")),
    ])
}
406
407fn waitpoint_headers(record: &WaitpointRecord) -> BTreeMap<String, String> {
408    let mut headers = BTreeMap::new();
409    headers.insert("waitpoint_id".to_string(), record.id.clone());
410    headers.insert("status".to_string(), record.status.as_str().to_string());
411    if let Some(created_by) = record.created_by.as_ref() {
412        headers.insert("created_by".to_string(), created_by.clone());
413    }
414    if let Some(completed_by) = record.completed_by.as_ref() {
415        headers.insert("completed_by".to_string(), completed_by.clone());
416    }
417    if let Some(cancelled_by) = record.cancelled_by.as_ref() {
418        headers.insert("cancelled_by".to_string(), cancelled_by.clone());
419    }
420    headers
421}
422
423fn now_rfc3339() -> String {
424    OffsetDateTime::now_utc()
425        .format(&Rfc3339)
426        .unwrap_or_else(|_| OffsetDateTime::now_utc().to_string())
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event_log::{FileEventLog, MemoryEventLog};

    /// Builds an owned id list from string literals.
    fn ids(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    /// Builds an in-memory record with the given status without touching a log.
    fn record_with_status(id: &str, status: WaitpointStatus) -> WaitpointRecord {
        let mut record = WaitpointRecord::open(id, None, BTreeMap::new());
        record.status = status;
        record
    }

    #[tokio::test]
    async fn waitpoint_state_persists_across_file_reopen() {
        let dir = tempfile::tempdir().expect("tempdir");
        let first = Arc::new(AnyEventLog::File(
            FileEventLog::open(dir.path().to_path_buf(), 32).expect("open file log"),
        ));
        create_waitpoint(&first, "demo", Some("creator".to_string()), BTreeMap::new())
            .await
            .expect("create waitpoint");
        complete_waitpoint(&first, "demo", Some("completer".to_string()))
            .await
            .expect("complete waitpoint");

        let reopened = Arc::new(AnyEventLog::File(
            FileEventLog::open(dir.path().to_path_buf(), 32).expect("reopen file log"),
        ));
        let state = load_waitpoint(&reopened, "demo")
            .await
            .expect("load state")
            .expect("waitpoint exists");
        assert_eq!(state.status, WaitpointStatus::Completed);
        assert_eq!(state.completed_by.as_deref(), Some("completer"));
    }

    #[tokio::test]
    async fn wait_terminal_lookup_returns_latest_terminal_record() {
        let log = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(32)));
        append_wait_started(
            &log,
            &WaitpointWaitStartRecord {
                wait_id: "wait-demo".to_string(),
                waitpoint_ids: vec!["a".to_string(), "b".to_string()],
                started_at: "2026-01-01T00:00:00Z".to_string(),
                trace_id: Some("trace-demo".to_string()),
                replay_of_event_id: None,
            },
        )
        .await
        .expect("append wait start");
        append_wait_terminal(
            &log,
            &WaitpointWaitRecord {
                wait_id: "wait-demo".to_string(),
                waitpoint_ids: vec!["a".to_string(), "b".to_string()],
                status: WaitpointWaitStatus::TimedOut,
                started_at: "2026-01-01T00:00:00Z".to_string(),
                resolved_at: "2026-01-01T00:01:00Z".to_string(),
                waitpoints: Vec::new(),
                cancelled_waitpoint_id: None,
                trace_id: Some("trace-demo".to_string()),
                replay_of_event_id: None,
                reason: Some("deadline elapsed".to_string()),
            },
        )
        .await
        .expect("append wait result");

        let record = find_wait_terminal(&log, "wait-demo")
            .await
            .expect("lookup wait result")
            .expect("wait result exists");
        assert_eq!(record.status, WaitpointWaitStatus::TimedOut);
        assert_eq!(record.reason.as_deref(), Some("deadline elapsed"));
    }

    #[test]
    fn dedupe_trims_skips_blanks_and_keeps_first_occurrence() {
        assert_eq!(
            dedupe_waitpoint_ids(&ids(&[" a ", "b", "a", "   ", "", "b ", "c"])),
            ids(&["a", "b", "c"])
        );
        assert!(dedupe_waitpoint_ids(&[]).is_empty());
    }

    #[test]
    fn resolve_waitpoints_covers_each_outcome() {
        let wanted = ids(&["a", "b"]);

        // Waiting on nothing never resolves.
        assert_eq!(resolve_waitpoints(&[], &[]), WaitpointResolution::Pending);

        // A requested id with no record keeps the wait pending.
        assert_eq!(
            resolve_waitpoints(
                &wanted,
                &[record_with_status("a", WaitpointStatus::Completed)]
            ),
            WaitpointResolution::Pending
        );

        // An open waitpoint keeps the wait pending.
        assert_eq!(
            resolve_waitpoints(
                &wanted,
                &[
                    record_with_status("a", WaitpointStatus::Completed),
                    record_with_status("b", WaitpointStatus::Open),
                ]
            ),
            WaitpointResolution::Pending
        );

        // Every requested waitpoint completed.
        assert_eq!(
            resolve_waitpoints(
                &wanted,
                &[
                    record_with_status("a", WaitpointStatus::Completed),
                    record_with_status("b", WaitpointStatus::Completed),
                ]
            ),
            WaitpointResolution::Completed
        );

        // A cancellation ends the wait and names the cancelled waitpoint.
        assert_eq!(
            resolve_waitpoints(
                &wanted,
                &[
                    record_with_status("a", WaitpointStatus::Completed),
                    record_with_status("b", WaitpointStatus::Cancelled),
                ]
            ),
            WaitpointResolution::Cancelled {
                waitpoint_id: "b".to_string()
            }
        );
    }

    #[tokio::test]
    async fn terminal_waitpoint_ignores_later_transitions() {
        let log = Arc::new(AnyEventLog::Memory(MemoryEventLog::new(32)));
        let cancelled = cancel_waitpoint(
            &log,
            "demo",
            Some("canceller".to_string()),
            Some("no longer needed".to_string()),
        )
        .await
        .expect("cancel waitpoint");
        assert_eq!(cancelled.status, WaitpointStatus::Cancelled);

        // Completing after cancellation must return the cancelled record as-is.
        let after = complete_waitpoint(&log, "demo", Some("completer".to_string()))
            .await
            .expect("complete waitpoint");
        assert_eq!(after, cancelled);

        let loaded = load_waitpoint(&log, "demo")
            .await
            .expect("load state")
            .expect("waitpoint exists");
        assert_eq!(loaded, cancelled);
    }
}