use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use std::sync::Arc;
use super::event_store::{AppendOptions, CanonicalEventDraft, EventStoreError};
use super::message::Message;
use super::outbox::OutboxError;
use super::storage::{RunRecord, RuntimeCheckpointStore, StorageError};
use crate::state::PersistedState;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct TransactionScopeId(String);
impl TransactionScopeId {
pub fn new(value: impl Into<String>) -> Result<Self, CommitError> {
let value = value.into();
if value.trim().is_empty() {
return Err(CommitError::Validation(
"transaction scope id must be non-empty".to_string(),
));
}
Ok(Self(value))
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Sink that accepts canonical event drafts for staging.
///
/// The trait is object-safe and requires `Sync + Send`, so a shared handle can be
/// passed across threads. `stage` returns nothing, so implementations have no way
/// to report a failure to the caller through this interface.
pub trait CanonicalEventStager: Sync + Send {
    /// Stages `draft`, taking ownership of it.
    fn stage(&self, draft: CanonicalEventDraft);
}
/// A canonical event draft paired with the options it should be appended with.
#[derive(Clone, Debug, PartialEq)]
pub struct StagedCanonicalEvent {
    /// The event draft to append.
    pub draft: CanonicalEventDraft,
    /// Append options; set to `AppendOptions::default()` by [`StagedCanonicalEvent::new`].
    pub append_options: AppendOptions,
}

impl StagedCanonicalEvent {
    /// Wraps `draft` with default append options.
    #[must_use]
    pub fn new(draft: CanonicalEventDraft) -> Self {
        let append_options = AppendOptions::default();
        Self {
            draft,
            append_options,
        }
    }

    /// Replaces the append options, keeping the draft unchanged.
    #[must_use]
    pub fn with_options(self, options: AppendOptions) -> Self {
        Self {
            append_options: options,
            ..self
        }
    }
}
/// A single commit against a thread: an optional message append, the run record
/// to persist, and an optional thread-state snapshot.
///
/// Call [`ThreadCommit::validate`] to check its structural invariants.
#[derive(Clone, Debug)]
pub struct ThreadCommit {
    /// Target thread; `validate` rejects blank values.
    pub thread_id: String,
    /// Messages to append to the thread (may be empty).
    pub message_delta: Vec<Message>,
    /// Message-count guard; `validate` requires `Some` whenever `message_delta`
    /// is non-empty.
    pub expected_message_count: Option<u64>,
    /// Run record persisted with the commit; its `thread_id` must equal `thread_id`.
    pub run_projection: RunRecord,
    /// Thread-state snapshot to persist, if any. The constructors leave this `None`;
    /// set it with `with_thread_state_snapshot`.
    pub thread_state_snapshot: Option<PersistedState>,
}

/// Former name of [`ThreadCommit`], kept for source compatibility.
#[deprecated(since = "0.6.0", note = "Use `ThreadCommit`.")]
pub type Checkpoint = ThreadCommit;
impl ThreadCommit {
    /// Builds a commit that appends `message_delta` to `thread_id` and persists
    /// `run_projection`, with no thread-state snapshot attached.
    ///
    /// `expected_message_count` is the message-count guard;
    /// [`ThreadCommit::validate`] requires it whenever `message_delta` is non-empty.
    pub fn append_messages(
        thread_id: impl Into<String>,
        message_delta: Vec<Message>,
        expected_message_count: Option<u64>,
        run_projection: RunRecord,
    ) -> Self {
        let thread_id: String = thread_id.into();
        Self {
            thread_id,
            message_delta,
            expected_message_count,
            run_projection,
            thread_state_snapshot: None,
        }
    }

    /// Deprecated alias of [`ThreadCommit::append_messages`].
    #[deprecated(since = "0.6.0", note = "Use `append_messages`.")]
    pub fn append(
        thread_id: impl Into<String>,
        message_delta: Vec<Message>,
        expected_message_count: Option<u64>,
        run_projection: RunRecord,
    ) -> Self {
        Self::append_messages(thread_id, message_delta, expected_message_count, run_projection)
    }

    /// Attaches `thread_state` as the snapshot to persist, replacing any existing one.
    #[must_use]
    pub fn with_thread_state_snapshot(self, thread_state: PersistedState) -> Self {
        Self {
            thread_state_snapshot: Some(thread_state),
            ..self
        }
    }

    /// Deprecated alias of [`ThreadCommit::with_thread_state_snapshot`].
    #[deprecated(since = "0.6.0", note = "Use `with_thread_state_snapshot`.")]
    #[must_use]
    pub fn with_thread_state(self, thread_state: PersistedState) -> Self {
        self.with_thread_state_snapshot(thread_state)
    }

    /// Builds a commit that only persists `run_projection`. It has no messages, no
    /// message-count guard and no snapshot.
    pub fn run_projection_only(thread_id: impl Into<String>, run_projection: RunRecord) -> Self {
        Self::append_messages(thread_id, vec![], None, run_projection)
    }

    /// Deprecated alias of [`ThreadCommit::run_projection_only`].
    #[deprecated(since = "0.6.0", note = "Use `run_projection_only`.")]
    pub fn checkpoint_only(thread_id: impl Into<String>, run_projection: RunRecord) -> Self {
        Self::run_projection_only(thread_id, run_projection)
    }

    /// Checks the structural invariants of this commit.
    ///
    /// Checks run in this order, and the first failure is reported:
    /// 1. `thread_id` is not blank (empty or whitespace-only);
    /// 2. `run_projection.thread_id` equals `thread_id` exactly;
    /// 3. `run_projection.run_id` is not blank;
    /// 4. `run_projection.agent_id` is not blank;
    /// 5. a non-empty `message_delta` carries an `expected_message_count` guard.
    ///
    /// # Errors
    ///
    /// Returns [`CommitError::Validation`] describing the first violated invariant.
    pub fn validate(&self) -> Result<(), CommitError> {
        let invalid =
            |reason: String| -> Result<(), CommitError> { Err(CommitError::Validation(reason)) };
        let run = &self.run_projection;

        if self.thread_id.trim().is_empty() {
            return invalid("thread_id must be non-empty".to_string());
        }
        if run.thread_id != self.thread_id {
            return invalid(format!(
                "run_projection.thread_id '{}' must match thread commit thread_id '{}'",
                run.thread_id, self.thread_id
            ));
        }
        if run.run_id.trim().is_empty() {
            return invalid("run_projection.run_id must be non-empty".to_string());
        }
        if run.agent_id.trim().is_empty() {
            return invalid("run_projection.agent_id must be non-empty".to_string());
        }

        // Appending messages without a count guard could silently interleave with a
        // concurrent writer, so the guard is mandatory for non-empty deltas.
        let unguarded_delta =
            !self.message_delta.is_empty() && self.expected_message_count.is_none();
        if unguarded_delta {
            return invalid(
                "append with a non-empty message delta requires an expected_message_count guard"
                    .to_string(),
            );
        }
        Ok(())
    }
}
/// Result of a successful [`CommitCoordinator::commit_checkpoint`] call.
///
/// Currently a unit marker that carries no data.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ThreadCommitOutcome;

/// Former name of [`ThreadCommitOutcome`], kept for source compatibility.
#[deprecated(since = "0.6.0", note = "Use `ThreadCommitOutcome`.")]
pub type CheckpointCommitOutcome = ThreadCommitOutcome;
/// Errors produced while validating or committing a [`ThreadCommit`].
#[derive(Debug, Error)]
pub enum CommitError {
    /// The commit plan (or an identifier) failed a structural check.
    #[error("validation error: {0}")]
    Validation(String),

    /// Writing to the thread/run store failed; created from [`StorageError`] via `?`.
    #[error("thread run store write failed: {0}")]
    StoreWrite(#[from] StorageError),

    /// The thread's message count did not match the commit's guard.
    ///
    /// Produced from a storage-level version conflict by
    /// [`CommitError::reclassify_append_conflict`].
    #[error(
        "message version conflict on thread '{thread_id}': expected {expected}, actual {actual}"
    )]
    MessageVersionConflict {
        /// Thread whose message count conflicted.
        thread_id: String,
        /// Count the commit expected.
        expected: u64,
        /// Count actually observed.
        actual: u64,
    },

    /// Appending a canonical event failed; created from [`EventStoreError`] via `?`.
    #[error("canonical event append failed: {0}")]
    EventAppend(#[from] EventStoreError),

    /// Inserting into the outbox failed; created from [`OutboxError`] via `?`.
    #[error("outbox insert failed: {0}")]
    OutboxInsert(#[from] OutboxError),

    /// Generic commit failure described by a message.
    #[error("commit failed: {0}")]
    Commit(String),

    /// Participants of a commit disagreed on the transaction scope.
    #[error("transaction scope mismatch: {0}")]
    ScopeMismatch(String),
}
impl CommitError {
    /// Re-labels a storage-level version conflict as a message conflict on `thread_id`.
    ///
    /// `StoreWrite(StorageError::VersionConflict { .. })` becomes
    /// [`CommitError::MessageVersionConflict`] with the same `expected`/`actual`
    /// values. Every other error is returned unchanged.
    #[must_use]
    pub fn reclassify_append_conflict(self, thread_id: &str) -> Self {
        if let CommitError::StoreWrite(StorageError::VersionConflict { expected, actual }) = &self {
            return CommitError::MessageVersionConflict {
                thread_id: thread_id.to_owned(),
                expected: *expected,
                actual: *actual,
            };
        }
        self
    }
}
/// Commits [`ThreadCommit`] plans within a single transaction scope.
#[async_trait]
pub trait CommitCoordinator: Send + Sync {
    /// Returns the transaction scope this coordinator commits within.
    fn scope(&self) -> TransactionScopeId;

    /// Returns an identity string for the backing thread/run storage, if the
    /// implementation exposes one. Defaults to `None`.
    fn thread_run_storage_identity(&self) -> Option<String> {
        None
    }

    /// Returns a shared [`RuntimeCheckpointStore`] handle for reads.
    fn reader(&self) -> Arc<dyn RuntimeCheckpointStore>;

    /// Commits `plan`.
    ///
    /// # Errors
    ///
    /// Returns a [`CommitError`] when the commit cannot be applied.
    /// NOTE(review): whether implementations call [`ThreadCommit::validate`]
    /// is not visible here — confirm per implementation.
    async fn commit_checkpoint(&self, plan: ThreadCommit) -> Result<ThreadCommitOutcome, CommitError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contract::event_store::{
        CanonicalEventDraft, CanonicalEventKind, EventScope, EventVisibility,
    };
    use serde_json::json;

    /// Builds a public canonical event draft scoped to thread `t-1` / run `run-1`.
    fn sample_draft(kind: &str) -> CanonicalEventDraft {
        let mut draft = CanonicalEventDraft::new(
            vec![EventScope::thread("t-1"), EventScope::run("run-1")],
            CanonicalEventKind::new(kind).unwrap(),
            json!({"kind": kind}),
            "test",
        )
        .unwrap();
        draft.visibility = EventVisibility::Public;
        draft
    }

    /// Builds a run record that passes `ThreadCommit::validate` for thread `t-1`.
    fn sample_run_record() -> crate::contract::storage::RunRecord {
        crate::contract::storage::RunRecord {
            run_id: "run-1".to_string(),
            thread_id: "t-1".to_string(),
            agent_id: "agent-1".to_string(),
            resolution_id: None,
            activation: None,
            ..Default::default()
        }
    }

    #[test]
    fn transaction_scope_id_rejects_blank() {
        assert!(TransactionScopeId::new("").is_err());
        assert!(TransactionScopeId::new(" ").is_err());
        assert!(TransactionScopeId::new("pg::main").is_ok());
    }

    #[test]
    fn transaction_scope_id_preserves_value() {
        let id = TransactionScopeId::new("pg::main").unwrap();
        assert_eq!(id.as_str(), "pg::main");
    }

    #[test]
    fn staged_canonical_event_with_options_round_trip() {
        let draft = sample_draft("RunStarted");
        let opts = AppendOptions {
            writer_id: Some("runtime".to_string()),
            idempotency_key: Some("k-1".to_string()),
            ..Default::default()
        };
        let staged = StagedCanonicalEvent::new(draft.clone()).with_options(opts.clone());
        assert_eq!(staged.draft, draft);
        assert_eq!(staged.append_options, opts);
    }

    #[test]
    fn plan_checkpoint_only_validates() {
        let plan = ThreadCommit::run_projection_only("t-1", sample_run_record());
        plan.validate().unwrap();
    }

    #[test]
    fn plan_rejects_blank_thread_id() {
        let mut run = sample_run_record();
        run.thread_id = String::new();
        let plan = ThreadCommit::run_projection_only("", run);
        let err = plan.validate().unwrap_err();
        assert!(matches!(err, CommitError::Validation(_)));
    }

    #[test]
    fn plan_rejects_thread_run_mismatch() {
        let mut run = sample_run_record();
        run.thread_id = "other-thread".to_string();
        let plan = ThreadCommit::run_projection_only("t-1", run);
        let err = plan.validate().unwrap_err();
        assert!(matches!(
            err,
            CommitError::Validation(message) if message.contains("run_projection.thread_id")
        ));
    }

    #[test]
    fn plan_rejects_blank_run_id() {
        let mut run = sample_run_record();
        run.run_id = " ".to_string();
        let plan = ThreadCommit::run_projection_only("t-1", run);
        let err = plan.validate().unwrap_err();
        assert!(matches!(
            err,
            CommitError::Validation(message) if message.contains("run_projection.run_id")
        ));
    }

    #[test]
    fn plan_rejects_blank_agent_id() {
        let mut run = sample_run_record();
        run.agent_id.clear();
        let plan = ThreadCommit::run_projection_only("t-1", run);
        let err = plan.validate().unwrap_err();
        assert!(matches!(
            err,
            CommitError::Validation(message) if message.contains("run_projection.agent_id")
        ));
    }

    #[test]
    fn checkpoint_only_allows_empty_message_state_write() {
        let plan = ThreadCommit::run_projection_only("t-1", sample_run_record());
        assert_eq!(plan.expected_message_count, None);
        assert!(plan.message_delta.is_empty());
        plan.validate().unwrap();
    }

    #[test]
    fn unguarded_append_of_non_empty_messages_is_rejected() {
        let plan = ThreadCommit::append_messages(
            "t-1",
            vec![Message::user("a")],
            None,
            sample_run_record(),
        );
        let err = plan.validate().unwrap_err();
        assert!(
            matches!(&err, CommitError::Validation(message) if message.contains("expected_message_count")),
            "expected message-count guard validation error, got {err:?}"
        );
    }

    #[test]
    fn append_plan_carries_delta_and_expected_version() {
        let plan = ThreadCommit::append_messages(
            "t-1",
            vec![Message::user("hi")],
            Some(3),
            sample_run_record(),
        );
        assert_eq!(plan.expected_message_count, Some(3));
        assert_eq!(plan.message_delta.len(), 1);
        plan.validate().unwrap();
    }

    #[test]
    fn state_only_checkpoint_accepts_none_version() {
        let plan = ThreadCommit::append_messages("t-1", Vec::new(), None, sample_run_record());
        assert_eq!(plan.expected_message_count, None);
        plan.validate().unwrap();
    }

    #[test]
    fn append_plan_still_validates_run_thread_match() {
        let mut run = sample_run_record();
        run.thread_id = "other-thread".to_string();
        let plan = ThreadCommit::append_messages("t-1", Vec::new(), Some(0), run);
        let err = plan.validate().unwrap_err();
        assert!(matches!(
            err,
            CommitError::Validation(message) if message.contains("run_projection.thread_id")
        ));
    }

    #[test]
    fn message_version_conflict_displays_thread_expected_actual() {
        let err = CommitError::MessageVersionConflict {
            thread_id: "t-1".to_string(),
            expected: 2,
            actual: 5,
        };
        // Pin the full message: substring checks on single digits are too weak to
        // catch swapped or dropped fields.
        assert_eq!(
            err.to_string(),
            "message version conflict on thread 't-1': expected 2, actual 5"
        );
    }

    #[test]
    fn reclassify_turns_storage_version_conflict_into_message_conflict() {
        let err = CommitError::StoreWrite(StorageError::VersionConflict {
            expected: 2,
            actual: 5,
        })
        .reclassify_append_conflict("t-1");
        match err {
            CommitError::MessageVersionConflict {
                thread_id,
                expected,
                actual,
            } => {
                assert_eq!(thread_id, "t-1");
                assert_eq!(expected, 2);
                assert_eq!(actual, 5);
            }
            other => panic!("expected MessageVersionConflict, got {other:?}"),
        }
    }

    #[test]
    fn reclassify_leaves_unrelated_errors_untouched() {
        let err = CommitError::Validation("boom".to_string()).reclassify_append_conflict("t-1");
        assert!(
            matches!(&err, CommitError::Validation(message) if message == "boom"),
            "expected untouched validation error, got {err:?}"
        );
    }
}