1use std::{
15 path::{Path, PathBuf},
16 time::{SystemTime, UNIX_EPOCH},
17};
18
19use grpc::heddle::v1::{
20 AbortTransactionRequest, AbortTransactionResponse, BeginTransactionRequest,
21 BeginTransactionResponse, CommitTransactionRequest, CommitTransactionResponse,
22 GetTransactionStatusRequest, TransactionStatus, transaction_service_server::TransactionService,
23};
24use objects::{
25 fs_atomic::write_file_atomic,
26 object::{ChangeId, OperationId, ThreadName},
27};
28use oplog::OpRecord;
29use prost::Message;
30use serde::{Deserialize, Serialize};
31use tonic::{Request, Response, Status};
32
33use super::{GrpcLocalService, to_status, with_idempotency};
34
/// On-disk record for a single transaction.
///
/// Serialized as pretty TOML to `<heddle_dir>/state/transactions/<transaction_id>.toml`
/// (see `sentinel_path`). It is the only persistent state for a transaction:
/// every RPC in this file loads it, checks/updates it, and writes it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TransactionSentinel {
    /// Canonical string form of the transaction's `OperationId`.
    transaction_id: String,
    /// `repo_path` copied verbatim from the begin request.
    repo_path: String,
    /// Thread named in the begin request; empty when HEAD was used as the base.
    thread: String,
    /// Free-form message from the begin request.
    message: String,
    /// Lifecycle state: one of `STATE_ACTIVE`, `STATE_COMMITTED`, `STATE_ABORTED`.
    state: String,
    /// Seconds since the Unix epoch when the transaction was begun.
    started_at_secs: i64,
    /// Email of the repository principal at begin time, or empty if unavailable.
    started_by_email: String,
    /// Full string form (`to_string_full`) of the change the transaction is based on.
    base_state: String,
    /// Operations buffered against this transaction. Its length is reported as
    /// `op_count` / `buffered_ops`; commit and abort clear it.
    buffered_ops: Vec<String>,
    /// Reason supplied to abort, when it was non-empty.
    aborted_reason: Option<String>,
}

/// Sentinel state of a transaction that can still be committed or aborted.
const STATE_ACTIVE: &str = "active";
/// Terminal sentinel state written by a successful commit.
const STATE_COMMITTED: &str = "committed";
/// Terminal sentinel state written by a successful abort.
const STATE_ABORTED: &str = "aborted";
63
64fn parse_transaction_id(raw: &str) -> Result<OperationId, Status> {
65 let transaction_id: OperationId = raw
66 .parse()
67 .map_err(|err| Status::invalid_argument(format!("invalid transaction_id: {err}")))?;
68 if transaction_id.to_string() != raw {
69 return Err(Status::invalid_argument(
70 "invalid transaction_id: expected canonical UUID",
71 ));
72 }
73 Ok(transaction_id)
74}
75
76fn sentinel_path(repo: &repo::Repository, transaction_id: &OperationId) -> PathBuf {
81 repo.heddle_dir()
82 .join("state")
83 .join("transactions")
84 .join(format!("{transaction_id}.toml"))
85}
86
87fn load_sentinel(path: &Path) -> Result<TransactionSentinel, Status> {
90 let bytes = std::fs::read(path).map_err(|err| {
91 if err.kind() == std::io::ErrorKind::NotFound {
92 Status::not_found(format!(
93 "transaction sentinel not found at {}",
94 path.display()
95 ))
96 } else {
97 Status::internal(format!("read sentinel failed: {err}"))
98 }
99 })?;
100 let text = std::str::from_utf8(&bytes)
101 .map_err(|err| Status::internal(format!("sentinel not utf8: {err}")))?;
102 toml::from_str(text).map_err(|err| Status::internal(format!("sentinel parse failed: {err}")))
103}
104
105fn save_sentinel(path: &Path, sentinel: &TransactionSentinel) -> Result<(), Status> {
107 let serialized = toml::to_string_pretty(sentinel)
108 .map_err(|err| Status::internal(format!("sentinel serialize failed: {err}")))?;
109 write_file_atomic(path, serialized.as_bytes())
110 .map_err(|err| Status::internal(format!("sentinel write failed: {err}")))
111}
112
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch. It
/// saturates at `i64::MAX` when the seconds count does not fit in an `i64`,
/// instead of silently wrapping to a negative value as an `as` cast would.
fn now_secs() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
120
121#[derive(Clone)]
126pub struct LocalTransactionService {
127 inner: GrpcLocalService,
128}
129
130impl LocalTransactionService {
131 pub fn new(inner: GrpcLocalService) -> Self {
132 Self { inner }
133 }
134}
135
136#[tonic::async_trait]
137impl TransactionService for LocalTransactionService {
138 async fn begin_transaction(
139 &self,
140 request: Request<BeginTransactionRequest>,
141 ) -> Result<Response<BeginTransactionResponse>, Status> {
142 let req = request.into_inner();
143 let request_body = req.encode_to_vec();
144 let client_op = req.client_operation_id.clone();
145 let inner = self.inner.clone();
146
147 let response = with_idempotency(
148 &self.inner,
149 &client_op,
150 "TransactionService.BeginTransaction",
151 &request_body,
152 move || async move {
153 let repo = inner.repo();
154
155 let base_change_id = if !req.thread.is_empty() {
160 repo.refs()
161 .get_thread(&ThreadName::from(req.thread.as_str()))
162 .map_err(to_status)?
163 } else {
164 repo.head().map_err(to_status)?
165 };
166 let base_state = base_change_id
167 .ok_or_else(|| {
168 Status::failed_precondition(
169 "cannot begin transaction: no base state (repository has no snapshots)",
170 )
171 })?
172 .to_string_full();
173
174 let started_by_email = repo.get_principal().map(|p| p.email).unwrap_or_default();
175 let started_at_secs = now_secs();
176 let transaction_id = OperationId::new();
177 let transaction_id_wire = transaction_id.to_string();
178
179 let sentinel = TransactionSentinel {
180 transaction_id: transaction_id_wire.clone(),
181 repo_path: req.repo_path.clone(),
182 thread: req.thread.clone(),
183 message: req.message.clone(),
184 state: STATE_ACTIVE.to_string(),
185 started_at_secs,
186 started_by_email,
187 base_state,
188 buffered_ops: Vec::new(),
189 aborted_reason: None,
190 };
191 let path = sentinel_path(repo, &transaction_id);
192 save_sentinel(&path, &sentinel)?;
193
194 Ok(BeginTransactionResponse {
195 transaction_id: transaction_id_wire,
196 started_at: Some(prost_types::Timestamp {
197 seconds: started_at_secs,
198 nanos: 0,
199 }),
200 })
201 },
202 )
203 .await?;
204
205 Ok(Response::new(response))
206 }
207
208 async fn commit_transaction(
209 &self,
210 request: Request<CommitTransactionRequest>,
211 ) -> Result<Response<CommitTransactionResponse>, Status> {
212 let req = request.into_inner();
213 let transaction_id = parse_transaction_id(&req.transaction_id)?;
214 let request_body = req.encode_to_vec();
215 let client_op = req.client_operation_id.clone();
216 let inner = self.inner.clone();
217
218 let response = with_idempotency(
219 &self.inner,
220 &client_op,
221 "TransactionService.CommitTransaction",
222 &request_body,
223 move || async move {
224 let repo = inner.repo();
225 let path = sentinel_path(repo, &transaction_id);
226 let mut sentinel = load_sentinel(&path)?;
227
228 if sentinel.state != STATE_ACTIVE {
229 return Err(Status::failed_precondition(format!(
230 "transaction already {}",
231 sentinel.state
232 )));
233 }
234
235 let op_count = sentinel.buffered_ops.len() as u32;
246 let transaction_id = sentinel.transaction_id.clone();
247 sentinel.state = STATE_COMMITTED.to_string();
248 sentinel.buffered_ops.clear();
249 save_sentinel(&path, &sentinel)?;
250
251 if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionCommit {
252 transaction_id,
253 op_count,
254 }]) {
255 tracing::warn!(error = %err, txn = %sentinel.transaction_id,
256 "transaction-service: failed to record TransactionCommit");
257 }
258
259 Ok(CommitTransactionResponse {
260 state_id: ChangeId::parse(&sentinel.base_state)
263 .map(|id| id.as_bytes().to_vec())
264 .unwrap_or_default(),
265 op_count,
266 })
267 },
268 )
269 .await?;
270
271 Ok(Response::new(response))
272 }
273
274 async fn abort_transaction(
275 &self,
276 request: Request<AbortTransactionRequest>,
277 ) -> Result<Response<AbortTransactionResponse>, Status> {
278 let req = request.into_inner();
279 let transaction_id = parse_transaction_id(&req.transaction_id)?;
280 let request_body = req.encode_to_vec();
281 let client_op = req.client_operation_id.clone();
282 let inner = self.inner.clone();
283
284 let response = with_idempotency(
285 &self.inner,
286 &client_op,
287 "TransactionService.AbortTransaction",
288 &request_body,
289 move || async move {
290 let repo = inner.repo();
291 let path = sentinel_path(repo, &transaction_id);
292 let mut sentinel = load_sentinel(&path)?;
293
294 if sentinel.state != STATE_ACTIVE {
295 return Err(Status::failed_precondition(format!(
296 "transaction already {}",
297 sentinel.state
298 )));
299 }
300
301 let reason = if req.reason.is_empty() {
302 None
303 } else {
304 Some(req.reason.clone())
305 };
306 let transaction_id = sentinel.transaction_id.clone();
307 sentinel.state = STATE_ABORTED.to_string();
308 sentinel.aborted_reason = reason.clone();
309 sentinel.buffered_ops.clear();
313 save_sentinel(&path, &sentinel)?;
314
315 if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionAbort {
320 transaction_id,
321 reason: reason.unwrap_or_default(),
322 }]) {
323 tracing::warn!(error = %err, txn = %sentinel.transaction_id,
324 "transaction-service: failed to record TransactionAbort");
325 }
326
327 Ok(AbortTransactionResponse { aborted: true })
328 },
329 )
330 .await?;
331
332 Ok(Response::new(response))
333 }
334
335 async fn get_transaction_status(
336 &self,
337 request: Request<GetTransactionStatusRequest>,
338 ) -> Result<Response<TransactionStatus>, Status> {
339 let req = request.into_inner();
340 let transaction_id = parse_transaction_id(&req.transaction_id)?;
341 let repo = self.inner.repo();
342 let path = sentinel_path(repo, &transaction_id);
343 let sentinel = load_sentinel(&path)?;
344
345 Ok(Response::new(TransactionStatus {
346 transaction_id: sentinel.transaction_id,
347 state: sentinel.state,
348 started_at: Some(prost_types::Timestamp {
349 seconds: sentinel.started_at_secs,
350 nanos: 0,
351 }),
352 buffered_ops: sentinel.buffered_ops.len() as u32,
353 }))
354 }
355}
356
#[cfg(test)]
mod tests {
    use std::{fs, path::Path, sync::Arc};

    use oplog::OpLogBackend;
    use repo::{Repository, operation_dedup::OperationDedupStore};
    use tempfile::TempDir;

    use super::*;

    /// Builds a service over a freshly initialised repository.
    ///
    /// The `TempDir` is returned so the caller keeps the repository on disk
    /// for the duration of the test.
    fn build_service() -> (TempDir, LocalTransactionService) {
        let tmp = TempDir::new().expect("tempdir");
        let repo = Repository::init_default(tmp.path()).expect("init repo");
        // begin_transaction needs a base state, so the fresh repo must have a HEAD.
        assert!(repo.head().expect("head").is_some(), "head should be set");
        let dedup = OperationDedupStore::open(repo.heddle_dir()).expect("dedup open");
        let local = GrpcLocalService::new(Arc::new(repo), Arc::new(dedup));
        (tmp, LocalTransactionService::new(local))
    }

    /// A begin request anchored at HEAD with no idempotency key.
    fn begin_req() -> BeginTransactionRequest {
        BeginTransactionRequest {
            message: "test txn".to_string(),
            repo_path: String::new(),
            thread: String::new(),
            client_operation_id: String::new(),
        }
    }

    /// Calls `begin_transaction` and unwraps the tonic response envelope.
    async fn begin_txn(
        svc: &LocalTransactionService,
        req: BeginTransactionRequest,
    ) -> Result<BeginTransactionResponse, Status> {
        svc.begin_transaction(Request::new(req))
            .await
            .map(Response::into_inner)
    }

    /// Calls `commit_transaction` for `transaction_id` with no idempotency key.
    async fn commit_txn(
        svc: &LocalTransactionService,
        transaction_id: &str,
    ) -> Result<CommitTransactionResponse, Status> {
        let req = CommitTransactionRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
            client_operation_id: String::new(),
        };
        svc.commit_transaction(Request::new(req))
            .await
            .map(Response::into_inner)
    }

    /// Calls `abort_transaction` for `transaction_id` with the given reason.
    async fn abort_txn(
        svc: &LocalTransactionService,
        transaction_id: &str,
        reason: &str,
    ) -> Result<AbortTransactionResponse, Status> {
        let req = AbortTransactionRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
            reason: reason.to_string(),
            client_operation_id: String::new(),
        };
        svc.abort_transaction(Request::new(req))
            .await
            .map(Response::into_inner)
    }

    /// Calls `get_transaction_status` for `transaction_id`.
    async fn status_of(
        svc: &LocalTransactionService,
        transaction_id: &str,
    ) -> Result<TransactionStatus, Status> {
        let req = GetTransactionStatusRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
        };
        svc.get_transaction_status(Request::new(req))
            .await
            .map(Response::into_inner)
    }

    /// Number of entries in the 64 most recent oplog records.
    fn oplog_len(svc: &LocalTransactionService) -> usize {
        svc.inner
            .repo()
            .oplog()
            .recent(64)
            .expect("oplog recent")
            .len()
    }

    fn parse_begin_id(raw: &str) -> OperationId {
        raw.parse::<OperationId>()
            .expect("begin_transaction should return an OperationId")
    }

    /// Where a sentinel would land if the raw, unvalidated string were used as
    /// the file stem. Tests use this to plant traps at path-traversal targets.
    fn legacy_string_sentinel_path(repo: &Repository, transaction_id: &str) -> PathBuf {
        let mut path = repo.heddle_dir().join("state");
        path.push("transactions");
        path.push(format!("{transaction_id}.toml"));
        path
    }

    /// Writes a well-formed `active` sentinel at `path`, creating parent dirs,
    /// and returns its exact bytes so callers can prove it was left untouched.
    fn write_trap_sentinel(path: &Path) -> Vec<u8> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).expect("trap parent");
        }
        let trap = TransactionSentinel {
            transaction_id: OperationId::new().to_string(),
            repo_path: "trap".to_string(),
            thread: String::new(),
            message: "must not be touched".to_string(),
            state: STATE_ACTIVE.to_string(),
            started_at_secs: 1,
            started_by_email: "trap@example.com".to_string(),
            base_state: ChangeId::from_bytes([0; 16]).to_string_full(),
            buffered_ops: vec!["trap-op".to_string()],
            aborted_reason: None,
        };
        let body = toml::to_string_pretty(&trap)
            .expect("serialize trap")
            .into_bytes();
        fs::write(path, &body).expect("write trap");
        body
    }

    /// Asserts that commit, abort and status all reject `raw` as InvalidArgument.
    async fn assert_invalid_transaction_id_rejected(svc: &LocalTransactionService, raw: &str) {
        let commit_err = commit_txn(svc, raw)
            .await
            .expect_err("commit must reject invalid transaction_id");
        assert_eq!(commit_err.code(), tonic::Code::InvalidArgument);
        assert!(
            commit_err.message().contains("invalid transaction_id"),
            "unexpected commit error: {commit_err}"
        );

        let abort_err = abort_txn(svc, raw, "nope")
            .await
            .expect_err("abort must reject invalid transaction_id");
        assert_eq!(abort_err.code(), tonic::Code::InvalidArgument);
        assert!(
            abort_err.message().contains("invalid transaction_id"),
            "unexpected abort error: {abort_err}"
        );

        let status_err = status_of(svc, raw)
            .await
            .expect_err("status must reject invalid transaction_id");
        assert_eq!(status_err.code(), tonic::Code::InvalidArgument);
        assert!(
            status_err.message().contains("invalid transaction_id"),
            "unexpected status error: {status_err}"
        );
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn begin_creates_active_sentinel() {
        let (_tmp, svc) = build_service();
        let resp = begin_txn(&svc, begin_req()).await.expect("begin");
        assert!(!resp.transaction_id.is_empty());
        let started_secs = resp.started_at.as_ref().map_or(0, |t| t.seconds);
        assert!(started_secs > 0);

        let status = status_of(&svc, &resp.transaction_id)
            .await
            .expect("status");
        assert_eq!(status.state, STATE_ACTIVE);
        assert_eq!(status.buffered_ops, 0);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn commit_flips_state_to_committed() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");

        let commit = commit_txn(&svc, &begin.transaction_id)
            .await
            .expect("commit");
        assert!(!commit.state_id.is_empty());
        assert_eq!(commit.op_count, 0);

        let status = status_of(&svc, &begin.transaction_id)
            .await
            .expect("status");
        assert_eq!(status.state, STATE_COMMITTED);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn abort_records_reason() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");

        let abort = abort_txn(&svc, &begin.transaction_id, "user cancelled")
            .await
            .expect("abort");
        assert!(abort.aborted);

        let transaction_id = parse_begin_id(&begin.transaction_id);
        let sentinel =
            load_sentinel(&sentinel_path(svc.inner.repo(), &transaction_id)).expect("load");
        assert_eq!(sentinel.state, STATE_ABORTED);
        assert_eq!(sentinel.aborted_reason.as_deref(), Some("user cancelled"));
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn commit_after_abort_returns_failed_precondition() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");
        abort_txn(&svc, &begin.transaction_id, "")
            .await
            .expect("abort");

        let err = commit_txn(&svc, &begin.transaction_id)
            .await
            .expect_err("commit must fail");
        assert_eq!(err.code(), tonic::Code::FailedPrecondition);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn get_status_returns_current_state() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");

        let status = status_of(&svc, &begin.transaction_id)
            .await
            .expect("status");
        assert_eq!(status.transaction_id, begin.transaction_id);
        assert_eq!(status.state, STATE_ACTIVE);
        assert_eq!(status.started_at, begin.started_at);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn commit_clears_buffered_ops_and_records_oplog_entry() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");

        // Inject buffered ops directly into the sentinel; nothing in the service
        // appends them yet.
        let transaction_id = parse_begin_id(&begin.transaction_id);
        let path = sentinel_path(svc.inner.repo(), &transaction_id);
        let mut sentinel = load_sentinel(&path).expect("load");
        sentinel.buffered_ops = vec!["capture".into(), "merge".into()];
        save_sentinel(&path, &sentinel).expect("save");

        let before_tail_len = oplog_len(&svc);

        let commit = commit_txn(&svc, &begin.transaction_id)
            .await
            .expect("commit");
        assert_eq!(commit.op_count, 2, "wire response carries the count");

        let after = load_sentinel(&path).expect("load after commit");
        assert_eq!(after.state, STATE_COMMITTED);
        assert!(
            after.buffered_ops.is_empty(),
            "commit must drain buffered_ops so a re-run cannot double-replay"
        );

        let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
        assert!(
            tail.len() > before_tail_len,
            "commit should have appended at least one oplog entry"
        );
        let newest = tail.last().expect("non-empty tail");
        match &newest.operation {
            oplog::OpRecord::TransactionCommit {
                transaction_id,
                op_count,
            } => {
                assert_eq!(transaction_id, &begin.transaction_id);
                assert_eq!(*op_count, 2);
            }
            other => panic!("expected TransactionCommit at oplog tail, got {other:?}"),
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn abort_records_oplog_entry_with_reason() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc, begin_req()).await.expect("begin");

        let before_tail_len = oplog_len(&svc);

        abort_txn(&svc, &begin.transaction_id, "rollback please")
            .await
            .expect("abort");

        let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
        assert!(
            tail.len() > before_tail_len,
            "abort should have appended at least one oplog entry"
        );
        let newest = tail.last().expect("non-empty tail");
        match &newest.operation {
            oplog::OpRecord::TransactionAbort {
                transaction_id,
                reason,
            } => {
                assert_eq!(transaction_id, &begin.transaction_id);
                assert_eq!(reason, "rollback please");
            }
            other => panic!("expected TransactionAbort at oplog tail, got {other:?}"),
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn begin_idempotent_returns_same_transaction_id() {
        let (_tmp, svc) = build_service();
        let req = BeginTransactionRequest {
            client_operation_id: OperationId::new().to_string(),
            ..begin_req()
        };

        let first = begin_txn(&svc, req.clone()).await.expect("begin1");
        let second = begin_txn(&svc, req).await.expect("begin2");
        assert_eq!(first.transaction_id, second.transaction_id);
        assert_eq!(first.started_at, second.started_at);
    }

    #[test]
    #[serial_test::serial(process_global)]
    fn sentinel_path_is_derived_from_operation_id() {
        let (_tmp, svc) = build_service();
        let transaction_id = OperationId::new();
        let repo = svc.inner.repo();
        let path = sentinel_path(repo, &transaction_id);
        let expected_file_name = format!("{transaction_id}.toml");

        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some(expected_file_name.as_str())
        );
        assert!(path.starts_with(repo.heddle_dir().join("state").join("transactions")));
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn invalid_transaction_ids_are_rejected_before_sentinel_path_io() {
        let (tmp, svc) = build_service();
        let absolute = tmp.path().join("outside-absolute").display().to_string();
        let invalid_ids: [&str; 5] = ["../../x", "a/b", "..", absolute.as_str(), ""];

        for raw in invalid_ids {
            // Plant a sentinel where a naive string-joined path would resolve,
            // then prove none of the RPCs read or rewrote it.
            let trap_path = legacy_string_sentinel_path(svc.inner.repo(), raw);
            let before = write_trap_sentinel(&trap_path);

            assert_invalid_transaction_id_rejected(&svc, raw).await;

            let after = fs::read(&trap_path).expect("trap should still exist");
            assert_eq!(
                after,
                before,
                "invalid transaction_id {raw:?} must not touch {}",
                trap_path.display()
            );
        }
    }
}