Skip to main content

daemon/grpc_local_impl/
transaction.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Local `TransactionService`.
3//!
4//! Establishes the *shape* of transactions: a sentinel TOML file under
5//! `.heddle/state/transactions/<id>.toml` records that a transaction is
6//! active, who started it, what its base state is, and (eventually) which
7//! verbs it has buffered. Buffered-op wiring — actually routing CLI verbs
8//! through the open transaction so the sentinel collects an ordered list of
9//! operations — is follow-on work alongside the rewind-on-abort logic. For
10//! now the sentinel is a status object: callers can begin, observe, commit,
11//! or abort, but the worktree is not yet rewound on abort and `state_id` on
12//! commit is the *base* state, not a freshly produced one.
13
14use std::{
15    path::{Path, PathBuf},
16    time::{SystemTime, UNIX_EPOCH},
17};
18
19use grpc::heddle::v1::{
20    AbortTransactionRequest, AbortTransactionResponse, BeginTransactionRequest,
21    BeginTransactionResponse, CommitTransactionRequest, CommitTransactionResponse,
22    GetTransactionStatusRequest, TransactionStatus, transaction_service_server::TransactionService,
23};
24use objects::{
25    fs_atomic::write_file_atomic,
26    object::{ChangeId, OperationId},
27};
28use oplog::OpRecord;
29use prost::Message;
30use serde::{Deserialize, Serialize};
31use tonic::{Request, Response, Status};
32
33use super::{GrpcLocalService, to_status, with_idempotency};
34
/// On-disk transaction sentinel.
///
/// Persisted as pretty-printed TOML at
/// `<heddle_dir>/state/transactions/<transaction_id>.toml`. The sentinel's
/// lifecycle mirrors the RPC surface: written on `begin`, rewritten in place
/// by `commit`/`abort`, and read by `get_status`.
///
/// Serde emits fields in declaration order, so reordering or renaming fields
/// changes the file layout on disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TransactionSentinel {
    /// Server-generated id (an `OperationId` string); also the file stem.
    transaction_id: String,
    /// `repo_path` exactly as supplied on the begin request.
    repo_path: String,
    /// Thread named on the begin request; empty means the base state was
    /// taken from HEAD.
    thread: String,
    /// Free-form message supplied on the begin request.
    message: String,
    /// Lifecycle state: one of [`STATE_ACTIVE`], [`STATE_COMMITTED`], or
    /// [`STATE_ABORTED`]. Only `active` transactions accept commit/abort.
    state: String,
    /// Wall-clock seconds since the UNIX epoch at begin time.
    started_at_secs: i64,
    /// Email of the repository principal at begin time; empty when no
    /// principal was available.
    started_by_email: String,
    /// Full `ChangeId` at begin time, recorded so a future rewind has a target.
    base_state: String,
    /// Verb names buffered into the transaction. `begin` writes it empty and
    /// both `commit` and `abort` drain it; nothing in this service appends
    /// to it. NOTE(review): other callers may already append entries here
    /// (the tests reference `append_op_to_active_for_thread`) — confirm
    /// whether CLI routing is still follow-on work.
    buffered_ops: Vec<String>,
    /// Reason supplied via `AbortTransactionRequest::reason`; `None` unless
    /// the transaction was aborted with a non-empty reason.
    aborted_reason: Option<String>,
}

/// Sentinel `state` for a transaction that has begun and not yet finished.
const STATE_ACTIVE: &str = "active";
/// Sentinel `state` after a successful `commit_transaction`.
const STATE_COMMITTED: &str = "committed";
/// Sentinel `state` after a successful `abort_transaction`.
const STATE_ABORTED: &str = "aborted";
63
64/// Build the on-disk sentinel path for a transaction id.
65fn sentinel_path(repo: &repo::Repository, transaction_id: &str) -> PathBuf {
66    repo.heddle_dir()
67        .join("state")
68        .join("transactions")
69        .join(format!("{transaction_id}.toml"))
70}
71
72/// Read and parse the sentinel for `path`, mapping I/O and parse errors to
73/// `tonic::Status`.
74fn load_sentinel(path: &Path) -> Result<TransactionSentinel, Status> {
75    let bytes = std::fs::read(path).map_err(|err| {
76        if err.kind() == std::io::ErrorKind::NotFound {
77            Status::not_found(format!(
78                "transaction sentinel not found at {}",
79                path.display()
80            ))
81        } else {
82            Status::internal(format!("read sentinel failed: {err}"))
83        }
84    })?;
85    let text = std::str::from_utf8(&bytes)
86        .map_err(|err| Status::internal(format!("sentinel not utf8: {err}")))?;
87    toml::from_str(text).map_err(|err| Status::internal(format!("sentinel parse failed: {err}")))
88}
89
90/// Atomically write the sentinel to `path`.
91fn save_sentinel(path: &Path, sentinel: &TransactionSentinel) -> Result<(), Status> {
92    let serialized = toml::to_string_pretty(sentinel)
93        .map_err(|err| Status::internal(format!("sentinel serialize failed: {err}")))?;
94    write_file_atomic(path, serialized.as_bytes())
95        .map_err(|err| Status::internal(format!("sentinel write failed: {err}")))
96}
97
/// Current wall-clock time as whole seconds since the UNIX epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
105
106/// Local-mode `TransactionService` impl.
107///
108/// Wraps the shared [`GrpcLocalService`] state so the dedup store and
109/// repository handle are available to every RPC.
110#[derive(Clone)]
111pub struct LocalTransactionService {
112    inner: GrpcLocalService,
113}
114
115impl LocalTransactionService {
116    pub fn new(inner: GrpcLocalService) -> Self {
117        Self { inner }
118    }
119}
120
121#[tonic::async_trait]
122impl TransactionService for LocalTransactionService {
123    async fn begin_transaction(
124        &self,
125        request: Request<BeginTransactionRequest>,
126    ) -> Result<Response<BeginTransactionResponse>, Status> {
127        let req = request.into_inner();
128        let request_body = req.encode_to_vec();
129        let client_op = req.client_operation_id.clone();
130        let inner = self.inner.clone();
131
132        let response = with_idempotency(
133            self.inner.dedup(),
134            &client_op,
135            "TransactionService.BeginTransaction",
136            &request_body,
137            |resp: &BeginTransactionResponse| resp.encode_to_vec(),
138            |bytes| {
139                BeginTransactionResponse::decode(bytes.as_slice())
140                    .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
141            },
142            move || async move {
143                let repo = inner.repo();
144
145                // Resolve base_state from the request's thread (if non-empty)
146                // or from current HEAD. Either path can produce `None` if the
147                // repository has no snapshots yet — tests therefore seed at
148                // least one snapshot before calling `begin_transaction`.
149                let base_change_id = if !req.thread.is_empty() {
150                    repo.refs().get_thread(&req.thread).map_err(to_status)?
151                } else {
152                    repo.head().map_err(to_status)?
153                };
154                let base_state = base_change_id
155                    .ok_or_else(|| {
156                        Status::failed_precondition(
157                            "cannot begin transaction: no base state (repository has no snapshots)",
158                        )
159                    })?
160                    .to_string_full();
161
162                let started_by_email = repo.get_principal().map(|p| p.email).unwrap_or_default();
163                let started_at_secs = now_secs();
164                let transaction_id = OperationId::new().to_string();
165
166                let sentinel = TransactionSentinel {
167                    transaction_id: transaction_id.clone(),
168                    repo_path: req.repo_path.clone(),
169                    thread: req.thread.clone(),
170                    message: req.message.clone(),
171                    state: STATE_ACTIVE.to_string(),
172                    started_at_secs,
173                    started_by_email,
174                    base_state,
175                    buffered_ops: Vec::new(),
176                    aborted_reason: None,
177                };
178                let path = sentinel_path(repo, &transaction_id);
179                save_sentinel(&path, &sentinel)?;
180
181                Ok(BeginTransactionResponse {
182                    transaction_id,
183                    started_at: Some(prost_types::Timestamp {
184                        seconds: started_at_secs,
185                        nanos: 0,
186                    }),
187                })
188            },
189        )
190        .await?;
191
192        Ok(Response::new(response))
193    }
194
195    async fn commit_transaction(
196        &self,
197        request: Request<CommitTransactionRequest>,
198    ) -> Result<Response<CommitTransactionResponse>, Status> {
199        let req = request.into_inner();
200        let request_body = req.encode_to_vec();
201        let client_op = req.client_operation_id.clone();
202        let inner = self.inner.clone();
203
204        let response = with_idempotency(
205            self.inner.dedup(),
206            &client_op,
207            "TransactionService.CommitTransaction",
208            &request_body,
209            |resp: &CommitTransactionResponse| resp.encode_to_vec(),
210            |bytes| {
211                CommitTransactionResponse::decode(bytes.as_slice())
212                    .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
213            },
214            move || async move {
215                let repo = inner.repo();
216                let path = sentinel_path(repo, &req.transaction_id);
217                let mut sentinel = load_sentinel(&path)?;
218
219                if sentinel.state != STATE_ACTIVE {
220                    return Err(Status::failed_precondition(format!(
221                        "transaction already {}",
222                        sentinel.state
223                    )));
224                }
225
226                // Capture the buffered op count, drain the list so a
227                // re-run cannot double-replay, flip the sentinel, and
228                // append `OpRecord::TransactionCommit` to the oplog. Real
229                // per-op replay (executing the buffered verbs at commit
230                // time rather than at call time) is the next follow-on.
231                let op_count = sentinel.buffered_ops.len() as u32;
232                let transaction_id = sentinel.transaction_id.clone();
233                sentinel.state = STATE_COMMITTED.to_string();
234                sentinel.buffered_ops.clear();
235                save_sentinel(&path, &sentinel)?;
236
237                if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionCommit {
238                    transaction_id,
239                    op_count,
240                }]) {
241                    tracing::warn!(error = %err, txn = %sentinel.transaction_id,
242                        "transaction-service: failed to record TransactionCommit");
243                }
244
245                Ok(CommitTransactionResponse {
246                    // `base_state` is a hex-display string in the sentinel
247                    // file; decode back to bytes for the wire response.
248                    state_id: ChangeId::parse(&sentinel.base_state)
249                        .map(|id| id.as_bytes().to_vec())
250                        .unwrap_or_default(),
251                    op_count,
252                })
253            },
254        )
255        .await?;
256
257        Ok(Response::new(response))
258    }
259
260    async fn abort_transaction(
261        &self,
262        request: Request<AbortTransactionRequest>,
263    ) -> Result<Response<AbortTransactionResponse>, Status> {
264        let req = request.into_inner();
265        let request_body = req.encode_to_vec();
266        let client_op = req.client_operation_id.clone();
267        let inner = self.inner.clone();
268
269        let response = with_idempotency(
270            self.inner.dedup(),
271            &client_op,
272            "TransactionService.AbortTransaction",
273            &request_body,
274            |resp: &AbortTransactionResponse| resp.encode_to_vec(),
275            |bytes| {
276                AbortTransactionResponse::decode(bytes.as_slice())
277                    .map_err(|err| Status::internal(format!("decode replay failed: {err}")))
278            },
279            move || async move {
280                let repo = inner.repo();
281                let path = sentinel_path(repo, &req.transaction_id);
282                let mut sentinel = load_sentinel(&path)?;
283
284                if sentinel.state != STATE_ACTIVE {
285                    return Err(Status::failed_precondition(format!(
286                        "transaction already {}",
287                        sentinel.state
288                    )));
289                }
290
291                let reason = if req.reason.is_empty() {
292                    None
293                } else {
294                    Some(req.reason.clone())
295                };
296                let transaction_id = sentinel.transaction_id.clone();
297                sentinel.state = STATE_ABORTED.to_string();
298                sentinel.aborted_reason = reason.clone();
299                // Drain buffered ops on abort too — the abort is now
300                // the terminal state, and future reads of this sentinel
301                // shouldn't surface the list as still-pending work.
302                sentinel.buffered_ops.clear();
303                save_sentinel(&path, &sentinel)?;
304
305                // Record `OpRecord::TransactionAbort` so the abort shows
306                // up in the audit trail. Worktree rewind to `base_state`
307                // is follow-on work — today the worktree stays where the
308                // (still-execute-immediately) buffered verbs left it.
309                if let Err(err) = repo.oplog().record_batch(vec![OpRecord::TransactionAbort {
310                    transaction_id,
311                    reason: reason.unwrap_or_default(),
312                }]) {
313                    tracing::warn!(error = %err, txn = %sentinel.transaction_id,
314                        "transaction-service: failed to record TransactionAbort");
315                }
316
317                Ok(AbortTransactionResponse { aborted: true })
318            },
319        )
320        .await?;
321
322        Ok(Response::new(response))
323    }
324
325    async fn get_transaction_status(
326        &self,
327        request: Request<GetTransactionStatusRequest>,
328    ) -> Result<Response<TransactionStatus>, Status> {
329        let req = request.into_inner();
330        let repo = self.inner.repo();
331        let path = sentinel_path(repo, &req.transaction_id);
332        let sentinel = load_sentinel(&path)?;
333
334        Ok(Response::new(TransactionStatus {
335            transaction_id: sentinel.transaction_id,
336            state: sentinel.state,
337            started_at: Some(prost_types::Timestamp {
338                seconds: sentinel.started_at_secs,
339                nanos: 0,
340            }),
341            buffered_ops: sentinel.buffered_ops.len() as u32,
342        }))
343    }
344}
345
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use repo::{Repository, operation_dedup::OperationDedupStore};
    use tempfile::TempDir;

    use super::*;

    /// Fresh repository in a temp dir, wrapped in a [`LocalTransactionService`]
    /// for direct RPC calls.
    ///
    /// Keep the returned [`TempDir`] alive for the whole test; it owns the
    /// repository's directory.
    fn build_service() -> (TempDir, LocalTransactionService) {
        let tmp = TempDir::new().expect("tempdir");
        let repo = Repository::init_default(tmp.path()).expect("init repo");
        // `init_default` seeds the empty-tree snapshot on `main`, so HEAD
        // resolves and `begin_transaction` has a base state to record.
        assert!(repo.head().expect("head").is_some(), "head should be set");
        let dedup = OperationDedupStore::open(repo.heddle_dir()).expect("dedup open");
        let shared = GrpcLocalService::new(Arc::new(repo), Arc::new(dedup));
        (tmp, LocalTransactionService::new(shared))
    }

    /// Minimal begin request: empty thread (base off HEAD), no idempotency key.
    fn begin_req() -> BeginTransactionRequest {
        BeginTransactionRequest {
            repo_path: String::new(),
            thread: String::new(),
            message: "test txn".to_string(),
            client_operation_id: String::new(),
        }
    }

    /// Begin a transaction built from [`begin_req`] and unwrap the response.
    async fn begin_txn(svc: &LocalTransactionService) -> BeginTransactionResponse {
        svc.begin_transaction(Request::new(begin_req()))
            .await
            .expect("begin")
            .into_inner()
    }

    /// Fetch the status of `transaction_id`, panicking on RPC error.
    async fn status_of(svc: &LocalTransactionService, transaction_id: &str) -> TransactionStatus {
        let body = GetTransactionStatusRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
        };
        svc.get_transaction_status(Request::new(body))
            .await
            .expect("status")
            .into_inner()
    }

    /// Commit request for `transaction_id` with no idempotency key.
    fn commit_req(transaction_id: &str) -> Request<CommitTransactionRequest> {
        Request::new(CommitTransactionRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
            client_operation_id: String::new(),
        })
    }

    /// Abort request for `transaction_id` carrying `reason`, with no
    /// idempotency key.
    fn abort_req(transaction_id: &str, reason: &str) -> Request<AbortTransactionRequest> {
        Request::new(AbortTransactionRequest {
            repo_path: String::new(),
            transaction_id: transaction_id.to_string(),
            reason: reason.to_string(),
            client_operation_id: String::new(),
        })
    }

    /// Length of `oplog().recent(64)`, used as a before-snapshot.
    fn oplog_tail_len(svc: &LocalTransactionService) -> usize {
        svc.inner
            .repo()
            .oplog()
            .recent(64)
            .expect("oplog recent")
            .len()
    }

    #[tokio::test]
    async fn begin_creates_active_sentinel() {
        let (_tmp, svc) = build_service();
        let resp = begin_txn(&svc).await;
        assert!(!resp.transaction_id.is_empty());
        let started_secs = resp.started_at.as_ref().map_or(0, |ts| ts.seconds);
        assert!(started_secs > 0);

        let status = status_of(&svc, &resp.transaction_id).await;
        assert_eq!(status.state, STATE_ACTIVE);
        assert_eq!(status.buffered_ops, 0);
    }

    #[tokio::test]
    async fn commit_flips_state_to_committed() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;

        let commit = svc
            .commit_transaction(commit_req(&begin.transaction_id))
            .await
            .expect("commit")
            .into_inner();
        assert!(!commit.state_id.is_empty());
        assert_eq!(commit.op_count, 0);

        let status = status_of(&svc, &begin.transaction_id).await;
        assert_eq!(status.state, STATE_COMMITTED);
    }

    #[tokio::test]
    async fn abort_records_reason() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;

        let abort = svc
            .abort_transaction(abort_req(&begin.transaction_id, "user cancelled"))
            .await
            .expect("abort")
            .into_inner();
        assert!(abort.aborted);

        // `aborted_reason` is not exposed by the status RPC, so inspect the
        // on-disk sentinel to confirm it round-trips through TOML.
        let path = sentinel_path(svc.inner.repo(), &begin.transaction_id);
        let sentinel = load_sentinel(&path).expect("load");
        assert_eq!(sentinel.state, STATE_ABORTED);
        assert_eq!(sentinel.aborted_reason.as_deref(), Some("user cancelled"));
    }

    #[tokio::test]
    async fn commit_after_abort_returns_failed_precondition() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;
        svc.abort_transaction(abort_req(&begin.transaction_id, ""))
            .await
            .expect("abort");

        let err = svc
            .commit_transaction(commit_req(&begin.transaction_id))
            .await
            .expect_err("commit must fail");
        assert_eq!(err.code(), tonic::Code::FailedPrecondition);
    }

    #[tokio::test]
    async fn get_status_returns_current_state() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;

        let status = status_of(&svc, &begin.transaction_id).await;
        assert_eq!(status.transaction_id, begin.transaction_id);
        assert_eq!(status.state, STATE_ACTIVE);
        assert_eq!(status.started_at, begin.started_at);
    }

    #[tokio::test]
    async fn commit_clears_buffered_ops_and_records_oplog_entry() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;

        // Seed two buffered op names directly on the sentinel. This service
        // never appends to `buffered_ops` itself; NOTE(review): presumably
        // the CLI does via `append_op_to_active_for_thread` — confirm.
        let path = sentinel_path(svc.inner.repo(), &begin.transaction_id);
        let mut sentinel = load_sentinel(&path).expect("load");
        sentinel.buffered_ops = vec!["capture".into(), "merge".into()];
        save_sentinel(&path, &sentinel).expect("save");

        // Remember the tail length so the entry commit appends can be
        // distinguished from what was already there.
        let before_tail_len = oplog_tail_len(&svc);

        let commit = svc
            .commit_transaction(commit_req(&begin.transaction_id))
            .await
            .expect("commit")
            .into_inner();
        assert_eq!(commit.op_count, 2, "wire response carries the count");

        // Sentinel: terminal state, buffered list emptied.
        let after = load_sentinel(&path).expect("load after commit");
        assert_eq!(after.state, STATE_COMMITTED);
        assert!(
            after.buffered_ops.is_empty(),
            "commit must drain buffered_ops so a re-run cannot double-replay"
        );

        // Oplog: newest entry is a TransactionCommit for this id and count.
        let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
        assert!(
            tail.len() > before_tail_len,
            "commit should have appended at least one oplog entry"
        );
        let last = tail.last().expect("non-empty tail");
        let oplog::OpRecord::TransactionCommit {
            transaction_id,
            op_count,
        } = &last.operation
        else {
            panic!(
                "expected TransactionCommit at oplog tail, got {:?}",
                last.operation
            );
        };
        assert_eq!(transaction_id, &begin.transaction_id);
        assert_eq!(*op_count, 2);
    }

    #[tokio::test]
    async fn abort_records_oplog_entry_with_reason() {
        let (_tmp, svc) = build_service();
        let begin = begin_txn(&svc).await;

        let before_tail_len = oplog_tail_len(&svc);

        svc.abort_transaction(abort_req(&begin.transaction_id, "rollback please"))
            .await
            .expect("abort");

        let tail = svc.inner.repo().oplog().recent(64).expect("oplog recent");
        assert!(
            tail.len() > before_tail_len,
            "abort should have appended at least one oplog entry"
        );
        let last = tail.last().expect("non-empty tail");
        let oplog::OpRecord::TransactionAbort {
            transaction_id,
            reason,
        } = &last.operation
        else {
            panic!(
                "expected TransactionAbort at oplog tail, got {:?}",
                last.operation
            );
        };
        assert_eq!(transaction_id, &begin.transaction_id);
        assert_eq!(reason, "rollback please");
    }

    #[tokio::test]
    async fn begin_idempotent_returns_same_transaction_id() {
        let (_tmp, svc) = build_service();
        // Same client operation id on both calls: the second must replay.
        let req = BeginTransactionRequest {
            client_operation_id: OperationId::new().to_string(),
            ..begin_req()
        };

        let first = svc
            .begin_transaction(Request::new(req.clone()))
            .await
            .expect("begin1")
            .into_inner();
        let second = svc
            .begin_transaction(Request::new(req))
            .await
            .expect("begin2")
            .into_inner();
        assert_eq!(first.transaction_id, second.transaction_id);
        assert_eq!(first.started_at, second.started_at);
    }
}