1use std::sync::Arc;
2
3use anyhow::Context as _;
4use clawdb::{prelude::MergeStrategy, ClawDBError, ClawDBSession};
5use tonic::{metadata::MetadataValue, Code, Request, Response, Status};
6use uuid::Uuid;
7
8use crate::state::{AppState, PendingTransaction};
9
/// Protobuf message types and gRPC service traits generated by tonic for the
/// `clawdb.v1` package.
pub mod proto {
    // Pulls in the tonic-generated code for the `clawdb.v1` proto package.
    tonic::include_proto!("clawdb.v1");

    /// Encoded protobuf file descriptor set named `clawdb_descriptor`.
    /// NOTE(review): presumably consumed by a gRPC reflection service — confirm at the use site.
    pub const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("clawdb_descriptor");
}
14
15use proto::claw_db_service_server::ClawDbService;
16
/// Handler type for the `ClawDbService` gRPC service.
pub struct ClawDbServiceImpl {
    /// Shared application state. The handlers use its database handle (`db`),
    /// its metrics sink (`metrics`) and its pending-transaction map (`transactions`).
    state: Arc<AppState>,
}
20
21impl ClawDbServiceImpl {
22 pub fn new(state: Arc<AppState>) -> Self {
23 Self { state }
24 }
25
26 async fn session_from_request<T>(
27 &self,
28 request: &Request<T>,
29 ) -> Result<(String, ClawDBSession), Status> {
30 let token = request
31 .metadata()
32 .get("x-claw-session")
33 .and_then(|value| value.to_str().ok())
34 .ok_or_else(|| Status::unauthenticated("missing session token"))?
35 .to_string();
36 let session = self
37 .state
38 .db
39 .validate_session(&token)
40 .await
41 .map_err(|_| Status::unauthenticated("invalid session token"))?;
42 Ok((token, session))
43 }
44
45 fn request_id<T>(request: &Request<T>) -> String {
46 request
47 .metadata()
48 .get("x-request-id")
49 .and_then(|value| value.to_str().ok())
50 .map(ToOwned::to_owned)
51 .unwrap_or_else(|| Uuid::new_v4().to_string())
52 }
53
54 fn set_request_id<T>(response: &mut Response<T>, request_id: &str) {
55 if let Ok(value) = MetadataValue::try_from(request_id) {
56 response.metadata_mut().insert("x-request-id", value);
57 }
58 }
59
60 fn observe(&self, method: &str, status: Code) {
61 let status_name = match status {
62 Code::Ok => "OK",
63 Code::Unauthenticated => "UNAUTHENTICATED",
64 Code::PermissionDenied => "PERMISSION_DENIED",
65 Code::FailedPrecondition => "FAILED_PRECONDITION",
66 Code::ResourceExhausted => "RESOURCE_EXHAUSTED",
67 Code::InvalidArgument => "INVALID_ARGUMENT",
68 Code::NotFound => "NOT_FOUND",
69 _ => "INTERNAL",
70 };
71 self.state.metrics.observe_grpc(method, status_name);
72 }
73
74 fn status_with_request_id(mut status: Status, request_id: &str) -> Status {
75 if let Ok(value) = MetadataValue::try_from(request_id) {
76 status.metadata_mut().insert("x-request-id", value);
77 }
78 status
79 }
80
81 fn map_error(&self, error: ClawDBError, request_id: &str) -> Status {
82 let status = match error {
83 ClawDBError::PermissionDenied(reason) => Status::permission_denied(reason),
84 ClawDBError::SessionInvalid => Status::unauthenticated("session_invalid"),
85 ClawDBError::ComponentDisabled(name) => {
86 Status::failed_precondition(format!("component_disabled:{name}"))
87 }
88 ClawDBError::Config(_) | ClawDBError::ComponentInit(_, _) => {
89 Status::internal(format!("internal error; request_id={request_id}"))
90 }
91 other => {
92 tracing::error!(request_id, error = %other, "gRPC handler failed");
93 Status::internal(format!("internal error; request_id={request_id}"))
94 }
95 };
96 Self::status_with_request_id(status, request_id)
97 }
98
99 fn response_with_request_id<T>(
100 &self,
101 method: &str,
102 mut response: Response<T>,
103 request_id: &str,
104 ) -> Response<T> {
105 Self::set_request_id(&mut response, request_id);
106 self.observe(method, Code::Ok);
107 response
108 }
109
110 fn parse_merge_strategy(strategy: &str) -> MergeStrategy {
111 match strategy.to_ascii_lowercase().as_str() {
112 "ours" => MergeStrategy::Ours,
113 "union" => MergeStrategy::Union,
114 "manual" => MergeStrategy::Manual,
115 _ => MergeStrategy::Theirs,
116 }
117 }
118}
119
120#[tonic::async_trait]
121impl ClawDbService for ClawDbServiceImpl {
122 async fn health(
123 &self,
124 request: Request<proto::HealthRequest>,
125 ) -> Result<Response<proto::HealthResponse>, Status> {
126 let request_id = Self::request_id(&request);
127 match self.state.db.health().await {
128 Ok(health) => Ok(self.response_with_request_id(
129 "Health",
130 Response::new(proto::HealthResponse {
131 ok: health.ok,
132 components: health.components,
133 uptime_secs: self.state.db.uptime_secs(),
134 request_id: request_id.clone(),
135 }),
136 &request_id,
137 )),
138 Err(error) => {
139 let status = self.map_error(error, &request_id);
140 self.observe("Health", status.code());
141 Err(status)
142 }
143 }
144 }
145
146 async fn create_session(
147 &self,
148 request: Request<proto::CreateSessionRequest>,
149 ) -> Result<Response<proto::CreateSessionResponse>, Status> {
150 let request_id = Self::request_id(&request);
151 if let Err(status) = self.session_from_request(&request).await {
152 self.observe("CreateSession", status.code());
153 return Err(Self::status_with_request_id(status, &request_id));
154 }
155 let inner = request.into_inner();
156 let agent_id = match Uuid::parse_str(&inner.agent_id) {
157 Ok(agent_id) => agent_id,
158 Err(_) => {
159 let status = Self::status_with_request_id(
160 Status::invalid_argument("invalid agent_id"),
161 &request_id,
162 );
163 self.observe("CreateSession", status.code());
164 return Err(status);
165 }
166 };
167 match self
168 .state
169 .db
170 .session(agent_id, &inner.role, inner.scopes)
171 .await
172 {
173 Ok(session) => Ok(self.response_with_request_id(
174 "CreateSession",
175 Response::new(proto::CreateSessionResponse {
176 id: session.id.to_string(),
177 token: session.token,
178 expires_at: session.expires_at.to_rfc3339(),
179 scopes: session.scopes,
180 request_id: request_id.clone(),
181 }),
182 &request_id,
183 )),
184 Err(error) => {
185 let status = self.map_error(error, &request_id);
186 self.observe("CreateSession", status.code());
187 Err(status)
188 }
189 }
190 }
191
192 async fn validate_session(
193 &self,
194 request: Request<proto::ValidateSessionRequest>,
195 ) -> Result<Response<proto::ValidateSessionResponse>, Status> {
196 let request_id = Self::request_id(&request);
197 match self.session_from_request(&request).await {
198 Ok((_, session)) => Ok(self.response_with_request_id(
199 "ValidateSession",
200 Response::new(proto::ValidateSessionResponse {
201 session_id: session.id.to_string(),
202 agent_id: session.agent_id.to_string(),
203 workspace_id: session.workspace_id.to_string(),
204 role: session.role,
205 scopes: session.scopes,
206 expires_at: session.expires_at.to_rfc3339(),
207 request_id: request_id.clone(),
208 }),
209 &request_id,
210 )),
211 Err(status) => {
212 let status = Self::status_with_request_id(status, &request_id);
213 self.observe("ValidateSession", status.code());
214 Err(status)
215 }
216 }
217 }
218
219 async fn revoke_session(
220 &self,
221 request: Request<proto::RevokeSessionRequest>,
222 ) -> Result<Response<proto::RevokeSessionResponse>, Status> {
223 let request_id = Self::request_id(&request);
224 if let Err(status) = self.session_from_request(&request).await {
225 self.observe("RevokeSession", status.code());
226 return Err(Self::status_with_request_id(status, &request_id));
227 }
228 let session_id = match Uuid::parse_str(&request.get_ref().session_id) {
229 Ok(session_id) => session_id,
230 Err(_) => {
231 let status = Self::status_with_request_id(
232 Status::invalid_argument("invalid session_id"),
233 &request_id,
234 );
235 self.observe("RevokeSession", status.code());
236 return Err(status);
237 }
238 };
239 match self.state.db.revoke_session(session_id).await {
240 Ok(()) => Ok(self.response_with_request_id(
241 "RevokeSession",
242 Response::new(proto::RevokeSessionResponse {
243 revoked: true,
244 request_id: request_id.clone(),
245 }),
246 &request_id,
247 )),
248 Err(error) => {
249 let status = self.map_error(error, &request_id);
250 self.observe("RevokeSession", status.code());
251 Err(status)
252 }
253 }
254 }
255
256 async fn remember(
257 &self,
258 request: Request<proto::RememberRequest>,
259 ) -> Result<Response<proto::RememberResponse>, Status> {
260 let request_id = Self::request_id(&request);
261 let session = match self.session_from_request(&request).await {
262 Ok((_, session)) => session,
263 Err(status) => {
264 let status = Self::status_with_request_id(status, &request_id);
265 self.observe("Remember", status.code());
266 return Err(status);
267 }
268 };
269 match self
270 .state
271 .db
272 .remember(&session, &request.get_ref().content)
273 .await
274 {
275 Ok(remembered) => Ok(self.response_with_request_id(
276 "Remember",
277 Response::new(proto::RememberResponse {
278 memory_id: remembered.memory_id.to_string(),
279 indexed: remembered.indexed,
280 request_id: request_id.clone(),
281 }),
282 &request_id,
283 )),
284 Err(error) => {
285 let status = self.map_error(error, &request_id);
286 self.observe("Remember", status.code());
287 Err(status)
288 }
289 }
290 }
291
292 async fn remember_typed(
293 &self,
294 request: Request<proto::RememberTypedRequest>,
295 ) -> Result<Response<proto::RememberResponse>, Status> {
296 let request_id = Self::request_id(&request);
297 let session = match self.session_from_request(&request).await {
298 Ok((_, session)) => session,
299 Err(status) => {
300 let status = Self::status_with_request_id(status, &request_id);
301 self.observe("RememberTyped", status.code());
302 return Err(status);
303 }
304 };
305 let inner = request.into_inner();
306 let metadata = if inner.metadata_json.trim().is_empty() {
307 serde_json::Value::Null
308 } else {
309 match serde_json::from_str(&inner.metadata_json) {
310 Ok(metadata) => metadata,
311 Err(_) => {
312 let status = Self::status_with_request_id(
313 Status::invalid_argument("invalid metadata_json"),
314 &request_id,
315 );
316 self.observe("RememberTyped", status.code());
317 return Err(status);
318 }
319 }
320 };
321 match self
322 .state
323 .db
324 .remember_typed(
325 &session,
326 &inner.content,
327 &inner.r#type,
328 &inner.tags,
329 metadata,
330 )
331 .await
332 {
333 Ok(remembered) => Ok(self.response_with_request_id(
334 "RememberTyped",
335 Response::new(proto::RememberResponse {
336 memory_id: remembered.memory_id.to_string(),
337 indexed: remembered.indexed,
338 request_id: request_id.clone(),
339 }),
340 &request_id,
341 )),
342 Err(error) => {
343 let status = self.map_error(error, &request_id);
344 self.observe("RememberTyped", status.code());
345 Err(status)
346 }
347 }
348 }
349
350 async fn search(
351 &self,
352 request: Request<proto::SearchRequest>,
353 ) -> Result<Response<proto::SearchResponse>, Status> {
354 let request_id = Self::request_id(&request);
355 let session = match self.session_from_request(&request).await {
356 Ok((_, session)) => session,
357 Err(status) => {
358 let status = Self::status_with_request_id(status, &request_id);
359 self.observe("Search", status.code());
360 return Err(status);
361 }
362 };
363 let inner = request.into_inner();
364 let filter = if inner.filter_json.trim().is_empty() {
365 None
366 } else {
367 match serde_json::from_str(&inner.filter_json) {
368 Ok(filter) => Some(filter),
369 Err(_) => {
370 let status = Self::status_with_request_id(
371 Status::invalid_argument("invalid filter_json"),
372 &request_id,
373 );
374 self.observe("Search", status.code());
375 return Err(status);
376 }
377 }
378 };
379 match self
380 .state
381 .db
382 .search_with_options(
383 &session,
384 &inner.query,
385 inner.top_k.max(1) as usize,
386 inner.semantic,
387 filter,
388 )
389 .await
390 {
391 Ok(hits) => Ok(self.response_with_request_id(
392 "Search",
393 Response::new(proto::SearchResponse {
394 hits: hits
395 .into_iter()
396 .map(|hit| proto::SearchHit {
397 id: hit.id.to_string(),
398 score: hit.score,
399 content: hit.content,
400 memory_type: hit.memory_type,
401 tags: hit.tags,
402 metadata_json: hit.metadata.to_string(),
403 })
404 .collect(),
405 request_id: request_id.clone(),
406 }),
407 &request_id,
408 )),
409 Err(error) => {
410 let status = self.map_error(error, &request_id);
411 self.observe("Search", status.code());
412 Err(status)
413 }
414 }
415 }
416
417 async fn recall(
418 &self,
419 request: Request<proto::RecallRequest>,
420 ) -> Result<Response<proto::RecallResponse>, Status> {
421 let request_id = Self::request_id(&request);
422 let session = match self.session_from_request(&request).await {
423 Ok((_, session)) => session,
424 Err(status) => {
425 let status = Self::status_with_request_id(status, &request_id);
426 self.observe("Recall", status.code());
427 return Err(status);
428 }
429 };
430 let mut ids = Vec::with_capacity(request.get_ref().memory_ids.len());
431 for id in &request.get_ref().memory_ids {
432 match Uuid::parse_str(id) {
433 Ok(parsed) => ids.push(parsed),
434 Err(_) => {
435 let status = Self::status_with_request_id(
436 Status::invalid_argument("invalid memory_id"),
437 &request_id,
438 );
439 self.observe("Recall", status.code());
440 return Err(status);
441 }
442 }
443 }
444 match self.state.db.recall(&session, &ids).await {
445 Ok(memories) => Ok(self.response_with_request_id(
446 "Recall",
447 Response::new(proto::RecallResponse {
448 memories: memories
449 .into_iter()
450 .map(|memory| proto::MemoryRecord {
451 id: memory.id.to_string(),
452 content: memory.content,
453 memory_type: memory.memory_type.as_str().to_string(),
454 tags: memory.tags,
455 })
456 .collect(),
457 request_id: request_id.clone(),
458 }),
459 &request_id,
460 )),
461 Err(error) => {
462 let status = self.map_error(error, &request_id);
463 self.observe("Recall", status.code());
464 Err(status)
465 }
466 }
467 }
468
469 async fn branch(
470 &self,
471 request: Request<proto::BranchRequest>,
472 ) -> Result<Response<proto::BranchResponse>, Status> {
473 let request_id = Self::request_id(&request);
474 let session = match self.session_from_request(&request).await {
475 Ok((_, session)) => session,
476 Err(status) => {
477 let status = Self::status_with_request_id(status, &request_id);
478 self.observe("Branch", status.code());
479 return Err(status);
480 }
481 };
482 let inner = request.into_inner();
483 let branch_id = if inner.from.is_empty() {
484 self.state.db.branch(&session, &inner.name).await
485 } else {
486 match Uuid::parse_str(&inner.from) {
487 Ok(parent) => {
488 self.state
489 .db
490 .fork_branch(&session, parent, &inner.name)
491 .await
492 }
493 Err(_) => {
494 let status = Self::status_with_request_id(
495 Status::invalid_argument("invalid from branch"),
496 &request_id,
497 );
498 self.observe("Branch", status.code());
499 return Err(status);
500 }
501 }
502 };
503 match branch_id {
504 Ok(branch_id) => Ok(self.response_with_request_id(
505 "Branch",
506 Response::new(proto::BranchResponse {
507 branch_id: branch_id.to_string(),
508 name: inner.name,
509 request_id: request_id.clone(),
510 }),
511 &request_id,
512 )),
513 Err(error) => {
514 let status = self.map_error(error, &request_id);
515 self.observe("Branch", status.code());
516 Err(status)
517 }
518 }
519 }
520
521 async fn merge(
522 &self,
523 request: Request<proto::MergeRequest>,
524 ) -> Result<Response<proto::MergeResponse>, Status> {
525 let request_id = Self::request_id(&request);
526 let session = match self.session_from_request(&request).await {
527 Ok((_, session)) => session,
528 Err(status) => {
529 let status = Self::status_with_request_id(status, &request_id);
530 self.observe("Merge", status.code());
531 return Err(status);
532 }
533 };
534 let source = match Uuid::parse_str(&request.get_ref().source) {
535 Ok(source) => source,
536 Err(_) => {
537 let status = Self::status_with_request_id(
538 Status::invalid_argument("invalid source"),
539 &request_id,
540 );
541 self.observe("Merge", status.code());
542 return Err(status);
543 }
544 };
545 let target = match Uuid::parse_str(&request.get_ref().target) {
546 Ok(target) => target,
547 Err(_) => {
548 let status = Self::status_with_request_id(
549 Status::invalid_argument("invalid target"),
550 &request_id,
551 );
552 self.observe("Merge", status.code());
553 return Err(status);
554 }
555 };
556 match self
557 .state
558 .db
559 .merge_with_strategy(
560 &session,
561 source,
562 target,
563 Self::parse_merge_strategy(&request.get_ref().strategy),
564 )
565 .await
566 {
567 Ok(result) => Ok(self.response_with_request_id(
568 "Merge",
569 Response::new(proto::MergeResponse {
570 success: result.success,
571 applied: result.applied,
572 skipped: result.skipped,
573 conflicts: result.conflicts.len() as u32,
574 duration_ms: result.duration_ms,
575 request_id: request_id.clone(),
576 }),
577 &request_id,
578 )),
579 Err(error) => {
580 let status = self.map_error(error, &request_id);
581 self.observe("Merge", status.code());
582 Err(status)
583 }
584 }
585 }
586
587 async fn diff(
588 &self,
589 request: Request<proto::DiffRequest>,
590 ) -> Result<Response<proto::DiffResponse>, Status> {
591 let request_id = Self::request_id(&request);
592 let session = match self.session_from_request(&request).await {
593 Ok((_, session)) => session,
594 Err(status) => {
595 let status = Self::status_with_request_id(status, &request_id);
596 self.observe("Diff", status.code());
597 return Err(status);
598 }
599 };
600 let branch_id = match Uuid::parse_str(&request.get_ref().branch_id) {
601 Ok(branch_id) => branch_id,
602 Err(_) => {
603 let status = Self::status_with_request_id(
604 Status::invalid_argument("invalid branch_id"),
605 &request_id,
606 );
607 self.observe("Diff", status.code());
608 return Err(status);
609 }
610 };
611 let target = match Uuid::parse_str(&request.get_ref().target) {
612 Ok(target) => target,
613 Err(_) => {
614 let status = Self::status_with_request_id(
615 Status::invalid_argument("invalid target"),
616 &request_id,
617 );
618 self.observe("Diff", status.code());
619 return Err(status);
620 }
621 };
622 match self.state.db.diff(&session, branch_id, target).await {
623 Ok(diff) => {
624 let diff_json = match serde_json::to_string(&diff)
625 .context("failed to serialize diff")
626 {
627 Ok(diff_json) => diff_json,
628 Err(_) => {
629 let status = Self::status_with_request_id(
630 Status::internal(format!("internal error; request_id={request_id}")),
631 &request_id,
632 );
633 self.observe("Diff", status.code());
634 return Err(status);
635 }
636 };
637 Ok(self.response_with_request_id(
638 "Diff",
639 Response::new(proto::DiffResponse {
640 added: diff.stats.added,
641 removed: diff.stats.removed,
642 modified: diff.stats.modified,
643 unchanged: diff.stats.unchanged,
644 divergence_score: diff.divergence_score as f32,
645 diff_json,
646 request_id: request_id.clone(),
647 }),
648 &request_id,
649 ))
650 }
651 Err(error) => {
652 let status = self.map_error(error, &request_id);
653 self.observe("Diff", status.code());
654 Err(status)
655 }
656 }
657 }
658
659 async fn sync(
660 &self,
661 request: Request<proto::SyncRequest>,
662 ) -> Result<Response<proto::SyncResponse>, Status> {
663 let request_id = Self::request_id(&request);
664 let session = match self.session_from_request(&request).await {
665 Ok((_, session)) => session,
666 Err(status) => {
667 let status = Self::status_with_request_id(status, &request_id);
668 self.observe("Sync", status.code());
669 return Err(status);
670 }
671 };
672 match self.state.db.sync(&session).await {
673 Ok(result) => Ok(self.response_with_request_id(
674 "Sync",
675 Response::new(proto::SyncResponse {
676 pushed: result.pushed,
677 pulled: result.pulled,
678 conflicts: result.conflicts,
679 duration_ms: result.duration_ms,
680 request_id: request_id.clone(),
681 }),
682 &request_id,
683 )),
684 Err(error) => {
685 let status = self.map_error(error, &request_id);
686 self.observe("Sync", status.code());
687 Err(status)
688 }
689 }
690 }
691
692 async fn reflect(
693 &self,
694 request: Request<proto::ReflectRequest>,
695 ) -> Result<Response<proto::ReflectResponse>, Status> {
696 let request_id = Self::request_id(&request);
697 let session = match self.session_from_request(&request).await {
698 Ok((_, session)) => session,
699 Err(status) => {
700 let status = Self::status_with_request_id(status, &request_id);
701 self.observe("Reflect", status.code());
702 return Err(status);
703 }
704 };
705 match self.state.db.reflect(&session).await {
706 Ok(result) => Ok(self.response_with_request_id(
707 "Reflect",
708 Response::new(proto::ReflectResponse {
709 job_id: result.job_id.unwrap_or_default(),
710 status: result.status,
711 message: result.message,
712 skipped: result.skipped,
713 request_id: request_id.clone(),
714 }),
715 &request_id,
716 )),
717 Err(error) => {
718 let status = self.map_error(error, &request_id);
719 self.observe("Reflect", status.code());
720 Err(status)
721 }
722 }
723 }
724
725 async fn begin_tx(
726 &self,
727 request: Request<proto::BeginTxRequest>,
728 ) -> Result<Response<proto::BeginTxResponse>, Status> {
729 let request_id = Self::request_id(&request);
730 let session = match self.session_from_request(&request).await {
731 Ok((_, session)) => session,
732 Err(status) => {
733 let status = Self::status_with_request_id(status, &request_id);
734 self.observe("BeginTx", status.code());
735 return Err(status);
736 }
737 };
738 let tx_id = Uuid::new_v4();
739 self.state
740 .transactions
741 .lock()
742 .await
743 .insert(tx_id, PendingTransaction { id: tx_id, session });
744 Ok(self.response_with_request_id(
745 "BeginTx",
746 Response::new(proto::BeginTxResponse {
747 tx_id: tx_id.to_string(),
748 request_id: request_id.clone(),
749 }),
750 &request_id,
751 ))
752 }
753
754 async fn commit_tx(
755 &self,
756 request: Request<proto::CommitTxRequest>,
757 ) -> Result<Response<proto::CommitTxResponse>, Status> {
758 let request_id = Self::request_id(&request);
759 let tx_id = match Uuid::parse_str(&request.get_ref().tx_id) {
760 Ok(tx_id) => tx_id,
761 Err(_) => {
762 let status = Self::status_with_request_id(
763 Status::invalid_argument("invalid tx_id"),
764 &request_id,
765 );
766 self.observe("CommitTx", status.code());
767 return Err(status);
768 }
769 };
770 let pending = match self.state.transactions.lock().await.remove(&tx_id) {
771 Some(pending) => pending,
772 None => {
773 let status = Self::status_with_request_id(
774 Status::not_found("transaction not found"),
775 &request_id,
776 );
777 self.observe("CommitTx", status.code());
778 return Err(status);
779 }
780 };
781 match self.state.db.transaction(&pending.session).await {
782 Ok(tx) => match tx.commit().await {
783 Ok(()) => Ok(self.response_with_request_id(
784 "CommitTx",
785 Response::new(proto::CommitTxResponse {
786 committed: true,
787 request_id: request_id.clone(),
788 }),
789 &request_id,
790 )),
791 Err(error) => {
792 let status = self.map_error(error, &request_id);
793 self.observe("CommitTx", status.code());
794 Err(status)
795 }
796 },
797 Err(error) => {
798 let status = self.map_error(error, &request_id);
799 self.observe("CommitTx", status.code());
800 Err(status)
801 }
802 }
803 }
804
805 async fn rollback_tx(
806 &self,
807 request: Request<proto::RollbackTxRequest>,
808 ) -> Result<Response<proto::RollbackTxResponse>, Status> {
809 let request_id = Self::request_id(&request);
810 let tx_id = match Uuid::parse_str(&request.get_ref().tx_id) {
811 Ok(tx_id) => tx_id,
812 Err(_) => {
813 let status = Self::status_with_request_id(
814 Status::invalid_argument("invalid tx_id"),
815 &request_id,
816 );
817 self.observe("RollbackTx", status.code());
818 return Err(status);
819 }
820 };
821 let pending = match self.state.transactions.lock().await.remove(&tx_id) {
822 Some(pending) => pending,
823 None => {
824 let status = Self::status_with_request_id(
825 Status::not_found("transaction not found"),
826 &request_id,
827 );
828 self.observe("RollbackTx", status.code());
829 return Err(status);
830 }
831 };
832 match self.state.db.transaction(&pending.session).await {
833 Ok(tx) => match tx.rollback().await {
834 Ok(()) => Ok(self.response_with_request_id(
835 "RollbackTx",
836 Response::new(proto::RollbackTxResponse {
837 rolled_back: true,
838 request_id: request_id.clone(),
839 }),
840 &request_id,
841 )),
842 Err(error) => {
843 let status = self.map_error(error, &request_id);
844 self.observe("RollbackTx", status.code());
845 Err(status)
846 }
847 },
848 Err(error) => {
849 let status = self.map_error(error, &request_id);
850 self.observe("RollbackTx", status.code());
851 Err(status)
852 }
853 }
854 }
855}