1use std::{
15 path::{Path, PathBuf},
16 time::{SystemTime, UNIX_EPOCH},
17};
18
19use grpc::heddle::v1::{
20 AbortTransactionRequest, AbortTransactionResponse, BeginTransactionRequest,
21 BeginTransactionResponse, CommitTransactionRequest, CommitTransactionResponse,
22 GetTransactionStatusRequest, TransactionStatus, transaction_service_server::TransactionService,
23};
24use objects::{
25 fs_atomic::write_file_atomic,
26 object::{ChangeId, OperationId},
27};
28use oplog::OpRecord;
29use prost::Message;
30use serde::{Deserialize, Serialize};
31use tonic::{Request, Response, Status};
32
33use super::{GrpcLocalService, to_status, with_idempotency};
34
/// On-disk record of a single transaction, serialized as TOML at the path
/// produced by `sentinel_path`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TransactionSentinel {
    /// Identifier of the transaction; the sentinel file name is derived from it.
    transaction_id: String,
    /// `repo_path` copied from the begin request.
    repo_path: String,
    /// `thread` copied from the begin request; may be empty.
    thread: String,
    /// `message` copied from the begin request.
    message: String,
    /// Lifecycle state: one of `STATE_ACTIVE`, `STATE_COMMITTED`, `STATE_ABORTED`.
    state: String,
    /// Start time in whole seconds since the Unix epoch (see `now_secs`).
    started_at_secs: i64,
    /// Email of the principal that began the transaction; empty when none was available.
    started_by_email: String,
    /// Full string form of the change id the transaction was based on.
    base_state: String,
    /// Operations buffered while the transaction is active; cleared on commit and abort.
    buffered_ops: Vec<String>,
    /// Reason supplied with an abort request, when it was non-empty.
    aborted_reason: Option<String>,
}
59
/// Sentinel `state` value written when a transaction is begun.
const STATE_ACTIVE: &str = "active";
/// Sentinel `state` value written by a successful commit.
const STATE_COMMITTED: &str = "committed";
/// Sentinel `state` value written by a successful abort.
const STATE_ABORTED: &str = "aborted";
63
64fn sentinel_path(repo: &repo::Repository, transaction_id: &str) -> PathBuf {
66 repo.heddle_dir()
67 .join("state")
68 .join("transactions")
69 .join(format!("{transaction_id}.toml"))
70}
71
72fn load_sentinel(path: &Path) -> Result<TransactionSentinel, Status> {
75 let bytes = std::fs::read(path).map_err(|err| {
76 if err.kind() == std::io::ErrorKind::NotFound {
77 Status::not_found(format!(
78 "transaction sentinel not found at {}",
79 path.display()
80 ))
81 } else {
82 Status::internal(format!("read sentinel failed: {err}"))
83 }
84 })?;
85 let text = std::str::from_utf8(&bytes)
86 .map_err(|err| Status::internal(format!("sentinel not utf8: {err}")))?;
87 toml::from_str(text).map_err(|err| Status::internal(format!("sentinel parse failed: {err}")))
88}
89
90fn save_sentinel(path: &Path, sentinel: &TransactionSentinel) -> Result<(), Status> {
92 let serialized = toml::to_string_pretty(sentinel)
93 .map_err(|err| Status::internal(format!("sentinel serialize failed: {err}")))?;
94 write_file_atomic(path, serialized.as_bytes())
95 .map_err(|err| Status::internal(format!("sentinel write failed: {err}")))
96}
97
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and
/// saturates at `i64::MAX` instead of wrapping if the second count does not
/// fit in an `i64`.
fn now_secs() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            // Checked conversion: a plain `as i64` would silently wrap.
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX)
        })
}
105
/// `TransactionService` implementation backed by a `GrpcLocalService`.
#[derive(Clone)]
pub struct LocalTransactionService {
    /// Local service that supplies the repository (`repo()`) and the dedup store (`dedup()`).
    inner: GrpcLocalService,
}
114
115impl LocalTransactionService {
116 pub fn new(inner: GrpcLocalService) -> Self {
117 Self { inner }
118 }
119}
120
121#[tonic::async_trait]
122impl TransactionService for LocalTransactionService {
123 async fn begin_transaction(
124 &self,
125 request: Request<BeginTransactionRequest>,
126 ) -> Result<Response<BeginTransactionResponse>, Status> {
127 let req = request.into_inner();
128 let request_body = req.encode_to_vec();
129 let client_op = req.client_operation_id.clone();
130 let inner = self.inner.clone();
131
132 let response = with_idempotency(
133 self.inner.dedup(),
134 &client_op,
135 "TransactionService.BeginTransaction",
136 &request_body,
137 |resp: &BeginTransactionResponse| resp.encode_to_vec(),
138 |bytes| {
139 BeginTransactionResponse::decode(bytes.as_slice())
140 .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
141 },
142 move || async move {
143 let repo = inner.repo();
144
145 let base_change_id = if !req.thread.is_empty() {
150 repo.refs().get_thread(&req.thread).map_err(to_status)?
151 } else {
152 repo.head().map_err(to_status)?
153 };
154 let base_state = base_change_id
155 .ok_or_else(|| {
156 Status::failed_precondition(
157 "cannot begin transaction: no base state (repository has no snapshots)",
158 )
159 })?
160 .to_string_full();
161
162 let started_by_email = repo.get_principal().map(|p| p.email).unwrap_or_default();
163 let started_at_secs = now_secs();
164 let transaction_id = OperationId::new().to_string();
165
166 let sentinel = TransactionSentinel {
167 transaction_id: transaction_id.clone(),
168 repo_path: req.repo_path.clone(),
169 thread: req.thread.clone(),
170 message: req.message.clone(),
171 state: STATE_ACTIVE.to_string(),
172 started_at_secs,
173 started_by_email,
174 base_state,
175 buffered_ops: Vec::new(),
176 aborted_reason: None,
177 };
178 let path = sentinel_path(repo, &transaction_id);
179 save_sentinel(&path, &sentinel)?;
180
181 Ok(BeginTransactionResponse {
182 transaction_id,
183 started_at: Some(prost_types::Timestamp {
184 seconds: started_at_secs,
185 nanos: 0,
186 }),
187 })
188 },
189 )
190 .await?;
191
192 Ok(Response::new(response))
193 }
194
195 async fn commit_transaction(
196 &self,
197 request: Request<CommitTransactionRequest>,
198 ) -> Result<Response<CommitTransactionResponse>, Status> {
199 let req = request.into_inner();
200 let request_body = req.encode_to_vec();
201 let client_op = req.client_operation_id.clone();
202 let inner = self.inner.clone();
203
204 let response = with_idempotency(
205 self.inner.dedup(),
206 &client_op,
207 "TransactionService.CommitTransaction",
208 &request_body,
209 |resp: &CommitTransactionResponse| resp.encode_to_vec(),
210 |bytes| {
211 CommitTransactionResponse::decode(bytes.as_slice())
212 .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
213 },
214 move || async move {
215 let repo = inner.repo();
216 let path = sentinel_path(repo, &req.transaction_id);
217 let mut sentinel = load_sentinel(&path)?;
218
219 if sentinel.state != STATE_ACTIVE {
220 return Err(Status::failed_precondition(format!(
221 "transaction already {}",
222 sentinel.state
223 )));
224 }
225
226 let op_count = sentinel.buffered_ops.len() as u32;
232 let transaction_id = sentinel.transaction_id.clone();
233 sentinel.state = STATE_COMMITTED.to_string();
234 sentinel.buffered_ops.clear();
235 save_sentinel(&path, &sentinel)?;
236
237 if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionCommit {
238 transaction_id,
239 op_count,
240 }]) {
241 tracing::warn!(error = %err, txn = %sentinel.transaction_id,
242 "transaction-service: failed to record TransactionCommit");
243 }
244
245 Ok(CommitTransactionResponse {
246 state_id: ChangeId::parse(&sentinel.base_state)
249 .map(|id| id.as_bytes().to_vec())
250 .unwrap_or_default(),
251 op_count,
252 })
253 },
254 )
255 .await?;
256
257 Ok(Response::new(response))
258 }
259
260 async fn abort_transaction(
261 &self,
262 request: Request<AbortTransactionRequest>,
263 ) -> Result<Response<AbortTransactionResponse>, Status> {
264 let req = request.into_inner();
265 let request_body = req.encode_to_vec();
266 let client_op = req.client_operation_id.clone();
267 let inner = self.inner.clone();
268
269 let response = with_idempotency(
270 self.inner.dedup(),
271 &client_op,
272 "TransactionService.AbortTransaction",
273 &request_body,
274 |resp: &AbortTransactionResponse| resp.encode_to_vec(),
275 |bytes| {
276 AbortTransactionResponse::decode(bytes.as_slice())
277 .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
278 },
279 move || async move {
280 let repo = inner.repo();
281 let path = sentinel_path(repo, &req.transaction_id);
282 let mut sentinel = load_sentinel(&path)?;
283
284 if sentinel.state != STATE_ACTIVE {
285 return Err(Status::failed_precondition(format!(
286 "transaction already {}",
287 sentinel.state
288 )));
289 }
290
291 let reason = if req.reason.is_empty() {
292 None
293 } else {
294 Some(req.reason.clone())
295 };
296 let transaction_id = sentinel.transaction_id.clone();
297 sentinel.state = STATE_ABORTED.to_string();
298 sentinel.aborted_reason = reason.clone();
299 sentinel.buffered_ops.clear();
303 save_sentinel(&path, &sentinel)?;
304
305 if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionAbort {
310 transaction_id,
311 reason: reason.unwrap_or_default(),
312 }]) {
313 tracing::warn!(error = %err, txn = %sentinel.transaction_id,
314 "transaction-service: failed to record TransactionAbort");
315 }
316
317 Ok(AbortTransactionResponse { aborted: true })
318 },
319 )
320 .await?;
321
322 Ok(Response::new(response))
323 }
324
325 async fn get_transaction_status(
326 &self,
327 request: Request<GetTransactionStatusRequest>,
328 ) -> Result<Response<TransactionStatus>, Status> {
329 let req = request.into_inner();
330 let repo = self.inner.repo();
331 let path = sentinel_path(repo, &req.transaction_id);
332 let sentinel = load_sentinel(&path)?;
333
334 Ok(Response::new(TransactionStatus {
335 transaction_id: sentinel.transaction_id,
336 state: sentinel.state,
337 started_at: Some(prost_types::Timestamp {
338 seconds: sentinel.started_at_secs,
339 nanos: 0,
340 }),
341 buffered_ops: sentinel.buffered_ops.len() as u32,
342 }))
343 }
344}
345
346#[cfg(test)]
347mod tests {
348 use std::sync::Arc;
349
350 use repo::{Repository, operation_dedup::OperationDedupStore};
351 use tempfile::TempDir;
352
353 use super::*;
354
355 fn build_service() -> (TempDir, LocalTransactionService) {
358 let tmp = TempDir::new().expect("tempdir");
359 let repo = Repository::init_default(tmp.path()).expect("init repo");
360 assert!(repo.head().expect("head").is_some(), "head should be set");
363 let dedup = OperationDedupStore::open(repo.heddle_dir()).expect("dedup open");
364 let service = GrpcLocalService::new(Arc::new(repo), Arc::new(dedup));
365 (tmp, LocalTransactionService::new(service))
366 }
367
368 fn begin_req() -> BeginTransactionRequest {
369 BeginTransactionRequest {
370 repo_path: String::new(),
371 thread: String::new(),
372 message: "test txn".to_string(),
373 client_operation_id: String::new(),
374 }
375 }
376
377 #[tokio::test]
378 async fn begin_creates_active_sentinel() {
379 let (_tmp, svc) = build_service();
380 let resp = svc
381 .begin_transaction(Request::new(begin_req()))
382 .await
383 .expect("begin")
384 .into_inner();
385 assert!(!resp.transaction_id.is_empty());
386 assert!(resp.started_at.as_ref().map(|t| t.seconds).unwrap_or(0) > 0);
387
388 let status = svc
389 .get_transaction_status(Request::new(GetTransactionStatusRequest {
390 repo_path: String::new(),
391 transaction_id: resp.transaction_id.clone(),
392 }))
393 .await
394 .expect("status")
395 .into_inner();
396 assert_eq!(status.state, STATE_ACTIVE);
397 assert_eq!(status.buffered_ops, 0);
398 }
399
400 #[tokio::test]
401 async fn commit_flips_state_to_committed() {
402 let (_tmp, svc) = build_service();
403 let begin = svc
404 .begin_transaction(Request::new(begin_req()))
405 .await
406 .expect("begin")
407 .into_inner();
408
409 let commit = svc
410 .commit_transaction(Request::new(CommitTransactionRequest {
411 repo_path: String::new(),
412 transaction_id: begin.transaction_id.clone(),
413 client_operation_id: String::new(),
414 }))
415 .await
416 .expect("commit")
417 .into_inner();
418 assert!(!commit.state_id.is_empty());
419 assert_eq!(commit.op_count, 0);
420
421 let status = svc
422 .get_transaction_status(Request::new(GetTransactionStatusRequest {
423 repo_path: String::new(),
424 transaction_id: begin.transaction_id,
425 }))
426 .await
427 .expect("status")
428 .into_inner();
429 assert_eq!(status.state, STATE_COMMITTED);
430 }
431
432 #[tokio::test]
433 async fn abort_records_reason() {
434 let (_tmp, svc) = build_service();
435 let begin = svc
436 .begin_transaction(Request::new(begin_req()))
437 .await
438 .expect("begin")
439 .into_inner();
440
441 let abort = svc
442 .abort_transaction(Request::new(AbortTransactionRequest {
443 repo_path: String::new(),
444 transaction_id: begin.transaction_id.clone(),
445 reason: "user cancelled".to_string(),
446 client_operation_id: String::new(),
447 }))
448 .await
449 .expect("abort")
450 .into_inner();
451 assert!(abort.aborted);
452
453 let path = sentinel_path(svc.inner.repo(), &begin.transaction_id);
456 let sentinel = load_sentinel(&path).expect("load");
457 assert_eq!(sentinel.state, STATE_ABORTED);
458 assert_eq!(sentinel.aborted_reason.as_deref(), Some("user cancelled"));
459 }
460
461 #[tokio::test]
462 async fn commit_after_abort_returns_failed_precondition() {
463 let (_tmp, svc) = build_service();
464 let begin = svc
465 .begin_transaction(Request::new(begin_req()))
466 .await
467 .expect("begin")
468 .into_inner();
469 svc.abort_transaction(Request::new(AbortTransactionRequest {
470 repo_path: String::new(),
471 transaction_id: begin.transaction_id.clone(),
472 reason: String::new(),
473 client_operation_id: String::new(),
474 }))
475 .await
476 .expect("abort");
477
478 let err = svc
479 .commit_transaction(Request::new(CommitTransactionRequest {
480 repo_path: String::new(),
481 transaction_id: begin.transaction_id,
482 client_operation_id: String::new(),
483 }))
484 .await
485 .expect_err("commit must fail");
486 assert_eq!(err.code(), tonic::Code::FailedPrecondition);
487 }
488
489 #[tokio::test]
490 async fn get_status_returns_current_state() {
491 let (_tmp, svc) = build_service();
492 let begin = svc
493 .begin_transaction(Request::new(begin_req()))
494 .await
495 .expect("begin")
496 .into_inner();
497
498 let status = svc
499 .get_transaction_status(Request::new(GetTransactionStatusRequest {
500 repo_path: String::new(),
501 transaction_id: begin.transaction_id.clone(),
502 }))
503 .await
504 .expect("status")
505 .into_inner();
506 assert_eq!(status.transaction_id, begin.transaction_id);
507 assert_eq!(status.state, STATE_ACTIVE);
508 assert_eq!(status.started_at, begin.started_at);
509 }
510
511 #[tokio::test]
512 async fn commit_clears_buffered_ops_and_records_oplog_entry() {
513 let (_tmp, svc) = build_service();
514 let begin = svc
515 .begin_transaction(Request::new(begin_req()))
516 .await
517 .expect("begin")
518 .into_inner();
519
520 let path = sentinel_path(svc.inner.repo(), &begin.transaction_id);
524 let mut sentinel = load_sentinel(&path).expect("load");
525 sentinel.buffered_ops = vec!["capture".into(), "merge".into()];
526 save_sentinel(&path, &sentinel).expect("save");
527
528 let before_tail_len = svc
531 .inner
532 .repo()
533 .oplog()
534 .recent(64)
535 .expect("oplog recent")
536 .len();
537
538 let commit = svc
539 .commit_transaction(Request::new(CommitTransactionRequest {
540 repo_path: String::new(),
541 transaction_id: begin.transaction_id.clone(),
542 client_operation_id: String::new(),
543 }))
544 .await
545 .expect("commit")
546 .into_inner();
547 assert_eq!(commit.op_count, 2, "wire response carries the count");
548
549 let after = load_sentinel(&path).expect("load after commit");
551 assert_eq!(after.state, STATE_COMMITTED);
552 assert!(
553 after.buffered_ops.is_empty(),
554 "commit must drain buffered_ops so a re-run cannot double-replay"
555 );
556
557 let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
560 assert!(
561 tail.len() > before_tail_len,
562 "commit should have appended at least one oplog entry"
563 );
564 let last = tail.last().expect("non-empty tail");
565 match &last.operation {
566 oplog::OpRecord::TransactionCommit {
567 transaction_id,
568 op_count,
569 } => {
570 assert_eq!(transaction_id, &begin.transaction_id);
571 assert_eq!(*op_count, 2);
572 }
573 other => panic!("expected TransactionCommit at oplog tail, got {other:?}"),
574 }
575 }
576
577 #[tokio::test]
578 async fn abort_records_oplog_entry_with_reason() {
579 let (_tmp, svc) = build_service();
580 let begin = svc
581 .begin_transaction(Request::new(begin_req()))
582 .await
583 .expect("begin")
584 .into_inner();
585
586 let before_tail_len = svc
587 .inner
588 .repo()
589 .oplog()
590 .recent(64)
591 .expect("oplog recent")
592 .len();
593
594 svc.abort_transaction(Request::new(AbortTransactionRequest {
595 repo_path: String::new(),
596 transaction_id: begin.transaction_id.clone(),
597 reason: "rollback please".to_string(),
598 client_operation_id: String::new(),
599 }))
600 .await
601 .expect("abort");
602
603 let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
604 assert!(
605 tail.len() > before_tail_len,
606 "abort should have appended at least one oplog entry"
607 );
608 let last = tail.last().expect("non-empty tail");
609 match &last.operation {
610 oplog::OpRecord::TransactionAbort {
611 transaction_id,
612 reason,
613 } => {
614 assert_eq!(transaction_id, &begin.transaction_id);
615 assert_eq!(reason, "rollback please");
616 }
617 other => panic!("expected TransactionAbort at oplog tail, got {other:?}"),
618 }
619 }
620
621 #[tokio::test]
622 async fn begin_idempotent_returns_same_transaction_id() {
623 let (_tmp, svc) = build_service();
624 let client_op = OperationId::new().to_string();
625
626 let mut req = begin_req();
627 req.client_operation_id = client_op.clone();
628
629 let first = svc
630 .begin_transaction(Request::new(req.clone()))
631 .await
632 .expect("begin1")
633 .into_inner();
634 let second = svc
635 .begin_transaction(Request::new(req))
636 .await
637 .expect("begin2")
638 .into_inner();
639 assert_eq!(first.transaction_id, second.transaction_id);
640 assert_eq!(first.started_at, second.started_at);
641 }
642}