#![cfg(feature = "postgres")]
use rustc_hash::FxHashMap as FastMap;
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::Barrier;
use weavegraph::{
channels::{
Channel,
errors::{ErrorEvent, ErrorScope, WeaveError},
},
message::Role,
runtimes::{Checkpoint, Checkpointer, PostgresCheckpointer, checkpointer_postgres::StepQuery},
state::VersionedState,
types::NodeKind as Kind,
};
#[path = "common/mod.rs"]
mod support;
use support::state_with_user;
/// Returns the connection URL for the Postgres test database.
///
/// The value of the `WEAVEGRAPH_POSTGRES_TEST_URL` environment variable is
/// used whenever it can be read; otherwise a local default URL is returned.
fn postgres_test_url() -> String {
    const FALLBACK_URL: &str =
        "postgresql://weavegraph:weavegraph@localhost:5432/weavegraph_test";
    std::env::var("WEAVEGRAPH_POSTGRES_TEST_URL").unwrap_or_else(|_| FALLBACK_URL.to_owned())
}
/// Connects a [`PostgresCheckpointer`] to the URL from `postgres_test_url`.
///
/// # Panics
///
/// Panics with the target URL and the underlying error when the connection
/// attempt fails.
async fn postgres_checkpointer() -> PostgresCheckpointer {
    let url = postgres_test_url();
    PostgresCheckpointer::connect(&url)
        .await
        .unwrap_or_else(|error| panic!("failed to connect to postgres at {url}: {error}"))
}
/// Builds a session identifier of the form `wg-pg-<name>-<uuid>`, where the
/// suffix is a freshly generated v4 UUID.
fn session_id(name: &str) -> String {
    let unique_suffix = uuid::Uuid::new_v4();
    format!("wg-pg-{name}-{unique_suffix}")
}
/// Creates a state via `state_with_user(prompt)` and inserts every
/// `(field, value)` pair from `entries` into its `extra` channel.
fn state_with_entries(
    prompt: &str,
    entries: impl IntoIterator<Item = (&'static str, Value)>,
) -> VersionedState {
    let mut state = state_with_user(prompt);
    for (key, payload) in entries {
        // `get_mut` is deliberately invoked once per entry, mirroring how each
        // insertion is performed individually.
        state.extra.get_mut().insert(key.to_owned(), payload);
    }
    state
}
/// Builds a [`Checkpoint`] for `session` at `step_number` wrapping `state`.
///
/// The remaining fields get fixed test defaults: a frontier containing only
/// `Kind::End`, an empty `versions_seen` map, a concurrency limit of 2, the
/// current UTC time as `created_at`, and empty node / channel lists. Tests
/// override individual fields after construction where needed.
fn checkpoint_record(session: &str, step_number: u64, state: VersionedState) -> Checkpoint {
    let created_at = chrono::Utc::now();
    Checkpoint {
        session_id: String::from(session),
        step: step_number,
        state,
        frontier: vec![Kind::End],
        versions_seen: FastMap::default(),
        concurrency_limit: 2,
        created_at,
        ran_nodes: Vec::new(),
        skipped_nodes: Vec::new(),
        updated_channels: Vec::new(),
    }
}
/// Loads the most recent checkpoint stored for `session`.
///
/// # Panics
///
/// Panics when `load_latest` returns an error (the error is included in the
/// panic message so a failing test shows the underlying cause) or when no
/// checkpoint exists for the session.
async fn latest_checkpoint(store: &PostgresCheckpointer, session: &str) -> Checkpoint {
    match store.load_latest(session).await {
        Ok(Some(current)) => current,
        Ok(None) => panic!("expected persisted checkpoint for {session}"),
        // Surface the store error instead of discarding it; without it a
        // failure here gives no hint about what went wrong.
        Err(error) => panic!("load_latest failed for {session}: {error:?}"),
    }
}
/// Round-trips a checkpoint carrying custom metadata and extra state through
/// the store and checks the reloaded copy.
///
/// The test saves step 7 with a custom frontier, concurrency limit, node
/// lists, updated channels and a `versions_seen` entry, reloads the latest
/// checkpoint for the session, and asserts on the step, concurrency limit,
/// frontier, the stored `messages` version, the user message and the extra
/// entries.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn checkpoint_state_and_metadata_survive_postgres_roundtrip() {
    let store = postgres_checkpointer().await;
    let session = session_id("checkpoint-state-roundtrip");
    let expected_frontier = vec![Kind::Custom("resume-node".to_owned()), Kind::End];

    let initial_state = state_with_entries(
        "hello from postgres roundtrip",
        [("counter", json!(7)), ("status", json!("persisted"))],
    );
    let mut checkpoint = checkpoint_record(&session, 7, initial_state);
    checkpoint.frontier = expected_frontier.clone();
    checkpoint.concurrency_limit = 4;
    checkpoint.ran_nodes = vec![Kind::Start];
    checkpoint.skipped_nodes = vec![Kind::Custom("Skipped".to_owned())];
    checkpoint.updated_channels = Vec::from(["messages", "extra"].map(str::to_owned));

    // Per-channel versions recorded for the "ResumeNode" entry.
    let mut channel_versions = FastMap::default();
    for (channel, version) in [("messages", 3), ("extra", 2)] {
        channel_versions.insert(channel.to_owned(), version);
    }
    checkpoint
        .versions_seen
        .insert("ResumeNode".to_owned(), channel_versions);

    store.save(checkpoint).await.expect("save checkpoint");

    let loaded = latest_checkpoint(&store, &session).await;
    let messages_version = loaded
        .versions_seen
        .get("ResumeNode")
        .and_then(|per_channel| per_channel.get("messages").copied());
    assert_eq!(loaded.step, 7);
    assert_eq!(loaded.concurrency_limit, 4);
    assert_eq!(loaded.frontier, expected_frontier);
    assert_eq!(messages_version, Some(3));

    let messages = loaded.state.messages.snapshot();
    assert_eq!(messages.len(), 1);
    let only_message = &messages[0];
    assert_eq!(only_message.role, Role::User);
    assert_eq!(only_message.content, "hello from postgres roundtrip");

    let extra = loaded.state.extra.snapshot();
    for (key, expected) in [("counter", json!(7)), ("status", json!("persisted"))] {
        assert_eq!(extra.get(key), Some(&expected));
    }
}
#[tokio::test(worker_threads = 2, flavor = "multi_thread")]
async fn saved_sessions_appear_in_list_and_missing_session_returns_none() {
let store = postgres_checkpointer().await;
let created_sessions = [
session_id("listed-a"),
session_id("listed-b"),
session_id("listed-c"),
];
for name in &created_sessions {
let snapshot = checkpoint_record(name, 1, state_with_user("session listing"));
store.save(snapshot).await.expect("save checkpoint");
}
let sessions = store.list_sessions().await.expect("list sessions");
for wanted in created_sessions {
assert!(
sessions.iter().any(|listed| listed == &wanted),
"missing {wanted}"
);
}
let absent = session_id("missing-session");
let missing = store
.load_latest(&absent)
.await
.expect("load missing session");
assert!(missing.is_none());
}
/// Saves steps 1 through 5 for one session and pages through them two at a
/// time, expecting steps `[5, 4]` at offset 0 and `[3, 2]` at offset 2, with
/// a total count of 5 and a next page reported for both queries.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn step_query_paginates_results_newest_first() {
    let store = postgres_checkpointer().await;
    let session = session_id("step-query-pagination");

    for recorded_step in 1_u64..=5 {
        let bucket = recorded_step % 2;
        let prompt = format!("state recorded at step {recorded_step}");
        let state = state_with_entries(
            &prompt,
            [("step", json!(recorded_step)), ("bucket", json!(bucket))],
        );
        // Even steps record `Start` as the node that ran, odd steps a custom worker.
        let ran_node = if bucket == 0 {
            Kind::Start
        } else {
            Kind::Custom("worker".to_owned())
        };

        let mut snapshot = checkpoint_record(&session, recorded_step, state);
        snapshot.ran_nodes = vec![ran_node];
        snapshot.skipped_nodes = vec![Kind::End];
        snapshot.updated_channels = vec!["messages".to_owned()];
        store.save(snapshot).await.expect("save checkpoint");
    }

    // Both queries share the page size and differ only in their offset.
    let two_per_page = |offset| StepQuery {
        offset: Some(offset),
        limit: Some(2),
        ..StepQuery::default()
    };

    let newest_page = store
        .query_steps(&session, two_per_page(0))
        .await
        .expect("query newest page");
    let newest_info = &newest_page.page_info;
    assert_eq!(newest_info.total_count, 5);
    assert_eq!(newest_info.page_size, 2);
    assert_eq!(newest_info.offset, 0);
    assert!(newest_info.has_next_page);
    let newest_steps: Vec<_> = newest_page
        .checkpoints
        .iter()
        .map(|entry| entry.step)
        .collect();
    assert_eq!(newest_steps, vec![5, 4]);

    let next_page = store
        .query_steps(&session, two_per_page(2))
        .await
        .expect("query next page");
    let next_steps: Vec<_> = next_page
        .checkpoints
        .iter()
        .map(|entry| entry.step)
        .collect();
    assert_eq!(next_steps, vec![3, 2]);
    assert_eq!(next_page.page_info.total_count, 5);
    assert!(next_page.page_info.has_next_page);
}
/// Stores a checkpoint whose state holds a single runner error event (with
/// details, tags and context) and checks that the reloaded checkpoint exposes
/// exactly that event with the same message, details, tags, context and a
/// runner scope at step 3.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn error_events_survive_postgres_roundtrip() {
    let store = postgres_checkpointer().await;
    let session = session_id("error-events-roundtrip");

    let failure = WeaveError::msg("boom").with_details(json!({"code": "E_BANG"}));
    let recorded_event = ErrorEvent::runner(session.clone(), 3, failure)
        .with_tags(vec!["postgres".to_owned(), "roundtrip".to_owned()])
        .with_context(json!({"retryable": false}));

    let mut state = state_with_user("error persistence");
    state.errors.get_mut().push(recorded_event);

    let mut checkpoint = checkpoint_record(&session, 3, state);
    checkpoint.updated_channels = vec!["errors".to_owned()];
    store.save(checkpoint).await.expect("save checkpoint");

    let loaded = latest_checkpoint(&store, &session).await;
    let errors = loaded.state.errors.snapshot();
    let event = match &errors[..] {
        [only] => only,
        _ => panic!("expected one error event"),
    };

    assert_eq!(event.error.message, "boom");
    assert_eq!(event.error.details, json!({"code": "E_BANG"}));
    assert_eq!(event.tags, vec!["postgres", "roundtrip"]);
    assert_eq!(event.context, json!({"retryable": false}));
    assert!(matches!(event.scope, ErrorScope::Runner { step: 3, .. }));
}
/// Saves the same step-1 checkpoint twice and checks that the session history
/// contains a single entry whose step matches the latest checkpoint, and that
/// the `dedupe` extra value is still present.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn duplicate_save_is_idempotent() {
    let store = postgres_checkpointer().await;
    let session = session_id("duplicate-save");

    let state = state_with_entries("same checkpoint twice", [("dedupe", json!(true))]);
    let original = checkpoint_record(&session, 1, state);
    let replay = original.clone();
    store.save(original).await.expect("first save");
    store.save(replay).await.expect("second save");

    let ten_newest = StepQuery {
        limit: Some(10),
        ..Default::default()
    };
    let history = store
        .query_steps(&session, ten_newest)
        .await
        .expect("query deduplicated history");
    let latest = latest_checkpoint(&store, &session).await;

    assert_eq!(history.page_info.total_count, 1);
    assert_eq!(history.checkpoints.len(), 1);
    assert_eq!(history.checkpoints[0].step, latest.step);

    let extra = latest.state.extra.snapshot();
    assert_eq!(extra.get("dedupe"), Some(&json!(true)));
}
/// Writes step 1, then step 2 with an expected step of 1, and finally tries
/// step 3 while still claiming an expected step of 1.
///
/// The last write must fail with an error mentioning "concurrency conflict",
/// and afterwards the latest checkpoint must still be step 2 with history
/// `[2, 1]`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn save_with_stale_expected_step_is_rejected() {
    let store = postgres_checkpointer().await;
    let session = session_id("stale-expected-step");

    store
        .save(checkpoint_record(
            &session,
            1,
            state_with_user("first write"),
        ))
        .await
        .expect("save first checkpoint");

    store
        .save_with_concurrency_check(
            checkpoint_record(&session, 2, state_with_user("second write")),
            Some(1),
        )
        .await
        .expect("save second checkpoint with matching expected step");

    // The store is now at step 2, so expecting step 1 is out of date.
    let rejection = store
        .save_with_concurrency_check(
            checkpoint_record(&session, 3, state_with_user("stale write")),
            Some(1),
        )
        .await
        .expect_err("stale expected step should fail");
    let rejection_text = rejection.to_string();
    assert!(rejection_text.contains("concurrency conflict"));

    let latest = latest_checkpoint(&store, &session).await;
    let ten_newest = StepQuery {
        limit: Some(10),
        ..Default::default()
    };
    let history = store
        .query_steps(&session, ten_newest)
        .await
        .expect("query history after rejected write");

    assert_eq!(latest.step, 2);
    let latest_messages = latest.state.messages.snapshot();
    assert_eq!(latest_messages[0].content, "second write");
    assert_eq!(history.page_info.total_count, 2);
    let recorded_steps: Vec<_> = history
        .checkpoints
        .iter()
        .map(|entry| entry.step)
        .collect();
    assert_eq!(recorded_steps, vec![2, 1]);
}
/// Saves step 9 and afterwards step 3 for the same session.
///
/// The latest checkpoint must remain the step-9 one (same `winner` extra
/// value and message), while the history lists both steps as `[9, 3]`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn out_of_order_write_does_not_overwrite_higher_step() {
    let store = postgres_checkpointer().await;
    let session = session_id("out-of-order-write");

    let high_state = state_with_entries(
        "higher step wins",
        [("winner", json!("high-step")), ("step", json!(9))],
    );
    store
        .save(checkpoint_record(&session, 9, high_state))
        .await
        .expect("save higher step");

    // The lower step is written only after the higher one is already stored.
    let low_state = state_with_entries(
        "lower step arrives later",
        [("winner", json!("late-low-step")), ("step", json!(3))],
    );
    store
        .save(checkpoint_record(&session, 3, low_state))
        .await
        .expect("save lower step");

    let latest = latest_checkpoint(&store, &session).await;
    let ten_newest = StepQuery {
        limit: Some(10),
        ..Default::default()
    };
    let history = store
        .query_steps(&session, ten_newest)
        .await
        .expect("query out-of-order history");

    assert_eq!(latest.step, 9);
    let latest_extra = latest.state.extra.snapshot();
    assert_eq!(latest_extra.get("winner"), Some(&json!("high-step")));
    let latest_messages = latest.state.messages.snapshot();
    assert_eq!(latest_messages[0].content, "higher step wins");
    assert_eq!(history.page_info.total_count, 2);
    let recorded_steps: Vec<_> = history
        .checkpoints
        .iter()
        .map(|entry| entry.step)
        .collect();
    assert_eq!(recorded_steps, vec![9, 3]);
}
/// Seeds step 1, then spawns two tasks that each try to write step 2 with an
/// expected step of 1.
///
/// A three-party barrier (two writers plus the test body) holds both writers
/// until all parties have arrived. Exactly one write must succeed, the latest
/// checkpoint must be step 2 written by one of the two writers, and the
/// history must contain two entries.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_writers_only_one_wins_concurrency_check() {
    let store = Arc::new(postgres_checkpointer().await);
    let session = session_id("concurrent-writers");
    store
        .save(checkpoint_record(
            &session,
            1,
            state_with_user("seed step"),
        ))
        .await
        .expect("save seed checkpoint");

    let barrier = Arc::new(Barrier::new(3));

    // Spawns one competing writer; each task owns its own handles and session copy.
    let spawn_writer = |prompt: &'static str, label: &'static str| {
        let writer_store = Arc::clone(&store);
        let writer_barrier = Arc::clone(&barrier);
        let writer_session = session.clone();
        tokio::spawn(async move {
            let candidate = checkpoint_record(
                &writer_session,
                2,
                state_with_entries(prompt, [("writer", json!(label)), ("step", json!(2))]),
            );
            writer_barrier.wait().await;
            writer_store
                .save_with_concurrency_check(candidate, Some(1))
                .await
        })
    };
    let first_writer = spawn_writer("writer one", "writer-one");
    let second_writer = spawn_writer("writer two", "writer-two");

    // The test body is the third barrier party.
    barrier.wait().await;
    let first_outcome = first_writer.await.expect("join writer one");
    let second_outcome = second_writer.await.expect("join writer two");
    let successes = usize::from(first_outcome.is_ok()) + usize::from(second_outcome.is_ok());

    let latest = latest_checkpoint(&store, &session).await;
    let ten_newest = StepQuery {
        limit: Some(10),
        ..Default::default()
    };
    let history = store
        .query_steps(&session, ten_newest)
        .await
        .expect("query concurrent writer history");
    let elected = latest.state.extra.snapshot().get("writer").cloned();

    assert_eq!(successes, 1);
    assert_eq!(latest.step, 2);
    let acceptable = [Some(json!("writer-one")), Some(json!("writer-two"))];
    assert!(acceptable.contains(&elected));
    assert_eq!(history.page_info.total_count, 2);
}