Skip to main content

celestia_grpc/
tx_client_impl.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use lumina_utils::executor::spawn;
7use prost::Message;
8use tokio::sync::{Mutex, RwLock, oneshot};
9use tokio_util::sync::CancellationToken;
10
11use celestia_types::any::IntoProtobufAny;
12use celestia_types::blob::{MsgPayForBlobs, RawBlobTx, RawMsgPayForBlobs};
13use celestia_types::hash::Hash;
14use celestia_types::state::ErrorCode;
15use celestia_types::state::RawTxBody;
16use celestia_types::state::auth::BaseAccount;
17
18use crate::grpc::{BroadcastMode, GasEstimate, TxStatus as GrpcTxStatus, TxStatusResponse};
19use crate::signer::sign_tx;
20
21use crate::tx_client_v2::{
22    ConfirmResult, NodeId, RejectionReason, SigningError, SigningFailure, StopError, SubmitError,
23    SubmitFailure, Transaction, TransactionWorker, TxCallbacks, TxConfirmResult, TxPayload,
24    TxRequest, TxServer, TxStatus, TxStatusKind, TxSubmitResult, TxSubmitter,
25};
26use crate::{Error, GrpcClient, Result, TxConfig, TxInfo};
27
/// Default for [`TxServiceConfig::max_status_batch`].
const DEFAULT_MAX_STATUS_BATCH: usize = 16;
/// Default for [`TxServiceConfig::queue_capacity`].
const DEFAULT_QUEUE_CAPACITY: usize = 128;
/// `type_id` written into every `RawBlobTx` envelope built by this module.
const BLOB_TX_TYPE_ID: &str = "BLOB";
/// Text that immediately precedes the expected sequence number in a node's
/// sequence-mismatch error message.
const SEQUENCE_ERROR_PAT: &str = "account sequence mismatch, expected ";
32
33/// Handle returned after submitting a transaction, used to await confirmation.
34pub struct ConfirmHandle<ConfirmInfo, ConfirmResponse> {
35    /// Hash of the submitted transaction.
36    pub hash: Hash,
37    /// Receiver for the confirmation result.
38    pub confirmed: oneshot::Receiver<ConfirmResult<ConfirmInfo, Arc<Error>, ConfirmResponse>>,
39}
40
41impl ConfirmHandle<TxConfirmInfo, TxStatusResponse> {
42    /// Await confirmation of the transaction, returning the final [`TxInfo`] or a stop error.
43    pub async fn confirm(
44        self,
45    ) -> std::result::Result<TxInfo, StopError<Arc<Error>, TxConfirmInfo, TxStatusResponse>> {
46        // receiver errors or internal errors are invariant violations
47        let info = self.confirmed.await.expect("confirm receiver dropped")?;
48        Ok(info.info)
49    }
50}
51
52/// Configuration for the [`TransactionService`].
53///
54/// Warning: [`TransactionService`] is experimental and not recommended for use yet.
55pub struct TxServiceConfig {
56    /// List of nodes (id, client) to submit and confirm transactions through.
57    pub nodes: Vec<(NodeId, GrpcClient)>,
58    /// Interval between confirmation polling attempts.
59    pub confirm_interval: Duration,
60    /// Maximum number of transactions to query in a single status batch.
61    pub max_status_batch: usize,
62    /// Capacity of the pending transaction queue.
63    pub queue_capacity: usize,
64}
65
66impl TxServiceConfig {
67    /// Create a new config with defaults and the given node list.
68    pub fn new(nodes: Vec<(NodeId, GrpcClient)>) -> Self {
69        Self {
70            nodes,
71            confirm_interval: Duration::from_millis(TxConfig::default().confirmation_interval_ms),
72            max_status_batch: DEFAULT_MAX_STATUS_BATCH,
73            queue_capacity: DEFAULT_QUEUE_CAPACITY,
74        }
75    }
76}
77
78/// High-level service for submitting and confirming transactions across multiple nodes.
79///
80/// Warning: this service is experimental and not recommended for use yet.
81///
82/// Wraps a [`TransactionWorker`] and provides a simple async API for submitting blobs
83/// or raw transactions and awaiting their on-chain confirmation.
84#[derive(Clone)]
85pub struct TransactionService {
86    inner: Arc<TransactionServiceInner>,
87}
88
89struct TransactionServiceInner {
90    submitter: RwLock<TxSubmitter<Hash, TxConfirmInfo, TxStatusResponse, Arc<Error>, TxRequest>>,
91    worker: Mutex<Option<WorkerHandle>>,
92    clients: HashMap<NodeId, GrpcClient>,
93    primary_client: GrpcClient,
94    account: BaseAccount,
95    confirm_interval: Duration,
96    max_status_batch: usize,
97    queue_capacity: usize,
98}
99
100struct WorkerHandle {
101    done_rx: oneshot::Receiver<Result<()>>,
102}
103
104impl WorkerHandle {
105    fn is_finished(&mut self) -> bool {
106        match self.done_rx.try_recv() {
107            Ok(_) => true,
108            Err(oneshot::error::TryRecvError::Closed) => true,
109            Err(oneshot::error::TryRecvError::Empty) => false,
110        }
111    }
112}
113
114impl TransactionService {
115    /// Create a new transaction service with the given configuration.
116    pub async fn new(config: TxServiceConfig) -> Result<Self> {
117        let Some((_, client)) = config.nodes.first() else {
118            return Err(Error::UnexpectedResponseType(
119                "no grpc clients provided".to_string(),
120            ));
121        };
122        let client = client.clone();
123        let address = client.get_account_address().ok_or(Error::MissingSigner)?;
124        let account = client.get_account(&address).await?;
125        let account = BaseAccount::from(account);
126
127        let clients: HashMap<NodeId, GrpcClient> = config.nodes.into_iter().collect();
128        let (submitter, worker_handle) = Self::spawn_worker(
129            account.clone(),
130            &clients,
131            client.clone(),
132            config.confirm_interval,
133            config.max_status_batch,
134            config.queue_capacity,
135        )
136        .await?;
137
138        Ok(Self {
139            inner: Arc::new(TransactionServiceInner {
140                submitter: RwLock::new(submitter),
141                worker: Mutex::new(Some(worker_handle)),
142                clients,
143                primary_client: client,
144                account,
145                confirm_interval: config.confirm_interval,
146                max_status_batch: config.max_status_batch,
147                queue_capacity: config.queue_capacity,
148            }),
149        })
150    }
151
152    /// Submit a transaction and return a handle to await its confirmation.
153    pub async fn submit(
154        &self,
155        request: TxRequest,
156    ) -> Result<ConfirmHandle<TxConfirmInfo, TxStatusResponse>> {
157        let submitter = self.inner.submitter.read().await.clone();
158        let handle = submitter.add_tx(request).await?;
159        match handle.signed.await {
160            Ok(Ok(())) => {}
161            Ok(Err(err)) => return Err(err),
162            Err(_) => return Err(Error::TxWorkerStopped),
163        }
164        match handle.submitted.await {
165            Ok(Ok(hash)) => Ok(ConfirmHandle {
166                hash,
167                confirmed: handle.confirmed,
168            }),
169            Ok(Err(err)) => Err(err),
170            Err(_) => Err(Error::TxWorkerStopped),
171        }
172    }
173
174    /// Submit a transaction, restarting the worker if it has stopped.
175    pub async fn submit_restart(
176        &self,
177        request: TxRequest,
178    ) -> Result<ConfirmHandle<TxConfirmInfo, TxStatusResponse>> {
179        let retry_request = request.clone();
180        match self.submit(request).await {
181            Ok(handle) => Ok(handle),
182            Err(Error::TxWorkerStopped) => {
183                self.recreate_worker().await?;
184                self.submit(retry_request).await
185            }
186            Err(err) => Err(err),
187        }
188    }
189
190    /// Recreate the internal worker if it has stopped, allowing new submissions.
191    pub async fn recreate_worker(&self) -> Result<()> {
192        {
193            let mut worker_guard = self.inner.worker.lock().await;
194            if let Some(worker) = worker_guard.as_mut()
195                && !worker.is_finished()
196            {
197                return Err(Error::TxWorkerRunning);
198            }
199        }
200        let (submitter, worker_handle) = Self::spawn_worker(
201            self.inner.account.clone(),
202            &self.inner.clients,
203            self.inner.primary_client.clone(),
204            self.inner.confirm_interval,
205            self.inner.max_status_batch,
206            self.inner.queue_capacity,
207        )
208        .await?;
209        let mut submitter_guard = self.inner.submitter.write().await;
210        let mut worker_guard = self.inner.worker.lock().await;
211        *submitter_guard = submitter;
212        *worker_guard = Some(worker_handle);
213
214        Ok(())
215    }
216
217    async fn spawn_worker(
218        account: BaseAccount,
219        clients: &HashMap<NodeId, GrpcClient>,
220        primary_client: GrpcClient,
221        confirm_interval: Duration,
222        max_status_batch: usize,
223        queue_capacity: usize,
224    ) -> Result<(
225        TxSubmitter<Hash, TxConfirmInfo, TxStatusResponse, Arc<Error>, TxRequest>,
226        WorkerHandle,
227    )> {
228        let next_sequence = current_sequence(&primary_client).await?;
229        let confirmed_sequence = next_sequence.checked_sub(1);
230        let nodes = clients
231            .iter()
232            .map(|(node_id, client)| {
233                (
234                    node_id.clone(),
235                    Arc::new(NodeClient {
236                        node_id: node_id.clone(),
237                        client: client.clone(),
238                        account: account.clone(),
239                    }),
240                )
241            })
242            .collect::<HashMap<_, _>>();
243        let (submitter, mut worker) = TransactionWorker::new(
244            nodes,
245            confirm_interval,
246            max_status_batch,
247            confirmed_sequence,
248            queue_capacity,
249        );
250
251        let worker_shutdown = CancellationToken::new();
252        let (done_tx, done_rx) = oneshot::channel();
253        spawn(async move {
254            let result = worker.process(worker_shutdown).await;
255            let _ = done_tx.send(result);
256        });
257
258        Ok((submitter, WorkerHandle { done_rx }))
259    }
260}
261
262#[derive(Clone)]
263struct NodeClient {
264    node_id: NodeId,
265    client: GrpcClient,
266    account: BaseAccount,
267}
268
269#[async_trait]
270impl TxServer for NodeClient {
271    type TxId = Hash;
272    type ConfirmInfo = TxConfirmInfo;
273    type TxRequest = TxRequest;
274    type SubmitError = Arc<Error>;
275    type ConfirmResponse = TxStatusResponse;
276
277    async fn submit(
278        &self,
279        tx_bytes: Arc<Vec<u8>>,
280        _sequence: u64,
281    ) -> TxSubmitResult<Self::TxId, Self::SubmitError> {
282        let resp = match self
283            .client
284            .broadcast_tx(tx_bytes.to_vec(), BroadcastMode::Sync)
285            .await
286        {
287            Ok(resp) => resp,
288            Err(err) => {
289                return Err(SubmitFailure {
290                    mapped_error: map_submit_error_from_client_error(&err),
291                    original_error: Arc::new(err),
292                });
293            }
294        };
295
296        if resp.code == ErrorCode::Success {
297            return Ok(resp.txhash);
298        }
299
300        let original_error = Arc::new(Error::TxBroadcastFailed(
301            resp.txhash,
302            resp.code,
303            resp.raw_log.clone(),
304        ));
305        Err(SubmitFailure {
306            mapped_error: map_submit_error(resp.code, &resp.raw_log),
307            original_error,
308        })
309    }
310
311    async fn status_batch(
312        &self,
313        ids: Vec<Self::TxId>,
314    ) -> TxConfirmResult<
315        Vec<(
316            Self::TxId,
317            TxStatus<Self::ConfirmInfo, Self::ConfirmResponse>,
318        )>,
319    > {
320        let response = self.client.tx_status_batch(ids.clone()).await?;
321        let mut response_map = HashMap::new();
322        for result in response.statuses {
323            response_map.insert(result.hash, result.status);
324        }
325
326        let mut statuses = Vec::with_capacity(ids.len());
327        for hash in ids {
328            match response_map.remove(&hash) {
329                Some(status) => {
330                    let mapped = map_status_response(hash, status, self.node_id.as_ref())?;
331                    statuses.push((hash, mapped));
332                }
333                None => {
334                    return Err(Error::UnexpectedResponseType(format!(
335                        "missing status for tx {:?}",
336                        hash
337                    )));
338                }
339            }
340        }
341
342        Ok(statuses)
343    }
344
345    async fn status(
346        &self,
347        id: Self::TxId,
348    ) -> TxConfirmResult<TxStatus<Self::ConfirmInfo, Self::ConfirmResponse>> {
349        let response = self.client.tx_status(id).await?;
350        map_status_response(id, response, self.node_id.as_ref())
351    }
352
353    async fn current_sequence(&self) -> Result<u64> {
354        current_sequence(&self.client).await
355    }
356
357    async fn simulate_and_sign(
358        &self,
359        req: Arc<Self::TxRequest>,
360        sequence: u64,
361    ) -> std::result::Result<
362        Transaction<Self::TxId, Self::ConfirmInfo, Self::ConfirmResponse, Self::SubmitError>,
363        SigningFailure<Self::SubmitError>,
364    > {
365        sign_with_client(self.account.clone(), &self.client, req.as_ref(), sequence)
366            .await
367            .map_err(map_signing_failure)
368    }
369}
370
371async fn sign_with_client(
372    mut account: BaseAccount,
373    client: &GrpcClient,
374    request: &TxRequest,
375    sequence: u64,
376) -> Result<Transaction<Hash, TxConfirmInfo, TxStatusResponse, Arc<Error>>> {
377    let (pubkey, signer) = client.signer()?;
378    account.sequence = sequence;
379
380    let chain_id = client.chain_id().await?;
381    let cfg = &request.cfg;
382    let (tx_body, blobs) = match &request.tx {
383        TxPayload::Blobs(blobs) => {
384            let pfb =
385                MsgPayForBlobs::new(blobs, account.address).map_err(Error::CelestiaTypesError)?;
386            let tx_body = RawTxBody {
387                messages: vec![RawMsgPayForBlobs::from(pfb).into_any()],
388                memo: cfg.memo.clone().unwrap_or_default(),
389                ..RawTxBody::default()
390            };
391            (tx_body, Some(blobs))
392        }
393        TxPayload::Tx(body) => (body.clone(), None),
394    };
395
396    let (gas_limit, gas_price) = match cfg.gas_limit {
397        Some(gas_limit) => {
398            let gas_price = match cfg.gas_price {
399                Some(price) => price,
400                None => client.estimate_gas_price(cfg.priority).await?,
401            };
402            (gas_limit, gas_price)
403        }
404        None => {
405            let probe_tx = sign_tx(
406                tx_body.clone(),
407                chain_id.clone(),
408                &account,
409                &pubkey,
410                &signer,
411                0,
412                1,
413            )
414            .await?;
415            let GasEstimate { price, usage } = client
416                .estimate_gas_price_and_usage(cfg.priority, probe_tx.encode_to_vec())
417                .await?;
418            let gas_price = cfg.gas_price.unwrap_or(price);
419            (usage, gas_price)
420        }
421    };
422    let fee = (gas_limit as f64 * gas_price).ceil() as u64;
423
424    let tx = sign_tx(
425        tx_body, chain_id, &account, &pubkey, &signer, gas_limit, fee,
426    )
427    .await?;
428    let bytes = match blobs {
429        Some(blobs) => {
430            let blob_tx = RawBlobTx {
431                tx: tx.encode_to_vec(),
432                blobs: blobs.iter().cloned().map(Into::into).collect(),
433                type_id: BLOB_TX_TYPE_ID.to_string(),
434            };
435            blob_tx.encode_to_vec()
436        }
437        None => tx.encode_to_vec(),
438    };
439    Ok(Transaction {
440        sequence,
441        bytes: Arc::new(bytes),
442        callbacks: TxCallbacks::default(),
443        id: None,
444    })
445}
446
447/// Confirmation information returned when a transaction is committed.
448#[derive(Debug, Clone)]
449pub struct TxConfirmInfo {
450    /// Basic transaction info (hash and block height).
451    pub info: TxInfo,
452    /// Execution result code from the chain.
453    pub execution_code: ErrorCode,
454}
455
456fn map_status_response(
457    hash: Hash,
458    response: TxStatusResponse,
459    node_id: &str,
460) -> Result<TxStatus<TxConfirmInfo, TxStatusResponse>> {
461    let original_response = response.clone();
462    match response.status {
463        GrpcTxStatus::Committed => Ok(TxStatus::new(
464            TxStatusKind::Confirmed {
465                info: TxConfirmInfo {
466                    info: TxInfo {
467                        hash,
468                        height: response.height.value(),
469                    },
470                    execution_code: response.execution_code,
471                },
472            },
473            original_response,
474        )),
475        GrpcTxStatus::Rejected => {
476            if is_wrong_sequence(response.execution_code) {
477                let Some(expected) = extract_sequence_on_mismatch(&response.error) else {
478                    return Ok(TxStatus::new(
479                        TxStatusKind::Rejected {
480                            reason: RejectionReason::OtherReason {
481                                error_code: response.execution_code,
482                                message: response.error.clone(),
483                                node_id: Arc::from(node_id),
484                            },
485                        },
486                        original_response,
487                    ));
488                };
489                Ok(TxStatus::new(
490                    TxStatusKind::Rejected {
491                        reason: RejectionReason::SequenceMismatch {
492                            expected,
493                            node_id: Arc::from(node_id),
494                        },
495                    },
496                    original_response,
497                ))
498            } else {
499                Ok(TxStatus::new(
500                    TxStatusKind::Rejected {
501                        reason: RejectionReason::OtherReason {
502                            error_code: response.execution_code,
503                            message: response.error.clone(),
504                            node_id: Arc::from(node_id),
505                        },
506                    },
507                    original_response,
508                ))
509            }
510        }
511        GrpcTxStatus::Evicted => Ok(TxStatus::new(TxStatusKind::Evicted, original_response)),
512        GrpcTxStatus::Pending => Ok(TxStatus::new(TxStatusKind::Pending, original_response)),
513        GrpcTxStatus::Unknown => Ok(TxStatus::new(TxStatusKind::Unknown, original_response)),
514    }
515}
516
517async fn current_sequence(client: &GrpcClient) -> Result<u64> {
518    let address = client.get_account_address().ok_or(Error::MissingSigner)?;
519    let account = client.get_account(&address).await?;
520    Ok(account.sequence)
521}
522
523fn map_submit_error(code: ErrorCode, message: &str) -> SubmitError {
524    if is_wrong_sequence(code)
525        && let Some(expected) = extract_sequence_on_mismatch(message)
526    {
527        return SubmitError::SequenceMismatch { expected };
528    }
529    if code == ErrorCode::InsufficientFee
530        && let Some(expected_fee) = extract_expected_fee(message)
531    {
532        return SubmitError::InsufficientFee {
533            expected_fee,
534            message: message.to_string(),
535        };
536    }
537
538    match code {
539        ErrorCode::MempoolIsFull => SubmitError::MempoolIsFull,
540        ErrorCode::TxInMempoolCache => SubmitError::TxInMempoolCache,
541        _ => SubmitError::Other {
542            error_code: code,
543            message: message.to_string(),
544        },
545    }
546}
547
548fn map_submit_error_from_client_error(err: &Error) -> SubmitError {
549    if let Error::TonicError(status) = err {
550        let message = status.message();
551        let lower = message.to_ascii_lowercase();
552        if let Some(expected) = extract_sequence_on_mismatch(message) {
553            return SubmitError::SequenceMismatch { expected };
554        }
555        if let Some(expected_fee) = extract_expected_fee(message) {
556            return SubmitError::InsufficientFee {
557                expected_fee,
558                message: message.to_string(),
559            };
560        }
561        if lower.contains("tx already exists in mempool") {
562            return SubmitError::TxInMempoolCache;
563        }
564        if lower.contains("mempool is full") {
565            return SubmitError::MempoolIsFull;
566        }
567    }
568
569    SubmitError::NetworkError
570}
571
/// Pull the required fee amount out of an insufficient-fee error message.
///
/// Looks for `"required fee:"` first, then `"required:"`. Only the first marker found is
/// used: if the text after it does not start (after whitespace) with digits, the result is
/// `None` even when the other marker is also present.
fn extract_expected_fee(message: &str) -> Option<u64> {
    ["required fee:", "required:"]
        .iter()
        .find_map(|marker| {
            message
                .find(*marker)
                .map(|start| &message[start + marker.len()..])
        })
        .and_then(parse_leading_digits)
}

/// Parse the run of ASCII digits at the start of `input`, ignoring leading whitespace.
///
/// Returns `None` when there are no leading digits or the number overflows `u64`.
fn parse_leading_digits(input: &str) -> Option<u64> {
    let trimmed = input.trim_start();
    // ASCII digits are single bytes, so the byte count is a valid char boundary.
    let digit_count = trimmed.bytes().take_while(u8::is_ascii_digit).count();
    trimmed[..digit_count].parse().ok()
}
595
596fn map_signing_failure(err: Error) -> SigningFailure<Arc<Error>> {
597    if err.is_network_error() {
598        return SigningFailure {
599            mapped_error: SigningError::NetworkError,
600            original_error: Arc::new(err),
601        };
602    }
603    if let Some(expected) = extract_sequence_on_mismatch(&err.to_string()) {
604        return SigningFailure {
605            mapped_error: SigningError::SequenceMismatch { expected },
606            original_error: Arc::new(err),
607        };
608    }
609    SigningFailure {
610        mapped_error: SigningError::Other {
611            message: err.to_string(),
612        },
613        original_error: Arc::new(err),
614    }
615}
616
617fn is_wrong_sequence(code: ErrorCode) -> bool {
618    code == ErrorCode::InvalidSequence || code == ErrorCode::WrongSequence
619}
620
621fn extract_sequence_on_mismatch(msg: &str) -> Option<u64> {
622    msg.contains(SEQUENCE_ERROR_PAT)
623        .then(|| extract_sequence(msg))
624        .and_then(|res| res.ok())
625}
626
627fn extract_sequence(msg: &str) -> Result<u64> {
628    let (_, msg_with_sequence) = msg
629        .split_once(SEQUENCE_ERROR_PAT)
630        .ok_or_else(|| Error::SequenceParsingFailed(msg.into()))?;
631    let (sequence, _) = msg_with_sequence
632        .split_once(',')
633        .ok_or_else(|| Error::SequenceParsingFailed(msg.into()))?;
634    sequence
635        .parse()
636        .map_err(|_| Error::SequenceParsingFailed(msg.into()))
637}
638
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::ops::RangeInclusive;
    use std::sync::Arc;

    use super::*;
    use crate::GrpcClient;
    use crate::test_utils::{CELESTIA_GRPC_URL, load_account, new_tx_client};
    use celestia_types::Blob;
    use celestia_types::nmt::Namespace;
    use lumina_utils::test_utils::async_test;
    use rand::{Rng, RngCore};

    /// End-to-end: submit one random blob through a single-node service and await commit.
    #[async_test]
    async fn submit_with_worker_and_confirm() {
        // `_lock` stays alive for the whole test (presumably serializing tests that share the
        // funded account).
        let (_lock, _client) = new_tx_client().await;
        let account = load_account();
        let client = GrpcClient::builder()
            .url(CELESTIA_GRPC_URL)
            .signer_keypair(account.signing_key)
            .build()
            .unwrap();

        let service =
            TransactionService::new(TxServiceConfig::new(vec![(Arc::from("default"), client)]))
                .await
                .unwrap();
        let handle = service
            .submit(TxRequest::blobs(
                vec![random_blob(10..=1000)],
                TxConfig::default(),
            ))
            .await
            .unwrap();
        let info = handle.confirm().await.unwrap();
        assert!(info.height > 0);
    }

    /// Build a blob with a random v0 namespace and a random payload whose length is drawn
    /// from `size`.
    fn random_blob(size: RangeInclusive<usize>) -> Blob {
        let rng = &mut rand::thread_rng();

        let mut ns_bytes = vec![0u8; 10];
        rng.fill_bytes(&mut ns_bytes);
        let namespace = Namespace::new_v0(&ns_bytes).unwrap();

        let len = rng.gen_range(size);
        let mut blob = vec![0; len];
        rng.fill_bytes(&mut blob);

        Blob::new(namespace, blob, None).unwrap()
    }

    #[test]
    fn extract_expected_fee_parses_required_fee() {
        let message = "insufficient fee; got: 123utest required fee: 456utest";
        assert_eq!(super::extract_expected_fee(message), Some(456));
    }

    #[test]
    fn extract_expected_fee_parses_required_fallback() {
        let message = "insufficient fee; got: 123utest required: 789utest";
        assert_eq!(super::extract_expected_fee(message), Some(789));
    }

    #[test]
    fn extract_expected_fee_returns_none_when_missing() {
        let message = "insufficient fee; got: 123utest";
        assert_eq!(super::extract_expected_fee(message), None);
    }

    #[test]
    fn parse_leading_digits_handles_edges() {
        assert_eq!(super::parse_leading_digits("  123utia"), Some(123));
        assert_eq!(super::parse_leading_digits("utia"), None);
        assert_eq!(super::parse_leading_digits(""), None);
    }

    #[test]
    fn extract_sequence_on_mismatch_parses_expected_sequence() {
        let message = "account sequence mismatch, expected 42, got 41: incorrect account sequence";
        assert_eq!(super::extract_sequence_on_mismatch(message), Some(42));
    }

    #[test]
    fn extract_sequence_on_mismatch_rejects_unrelated_or_malformed() {
        assert_eq!(super::extract_sequence_on_mismatch("out of gas"), None);
        // Phrase present but no terminating comma: cannot be parsed.
        assert_eq!(
            super::extract_sequence_on_mismatch("account sequence mismatch, expected 42"),
            None
        );
    }

    #[test]
    fn map_submit_error_maps_known_codes() {
        assert!(matches!(
            super::map_submit_error(ErrorCode::MempoolIsFull, "mempool is full"),
            SubmitError::MempoolIsFull
        ));
        assert!(matches!(
            super::map_submit_error(
                ErrorCode::WrongSequence,
                "account sequence mismatch, expected 7, got 6: incorrect account sequence"
            ),
            SubmitError::SequenceMismatch { expected: 7 }
        ));
        assert!(matches!(
            super::map_submit_error(
                ErrorCode::InsufficientFee,
                "insufficient fee; got: 1utia required: 20utia"
            ),
            SubmitError::InsufficientFee {
                expected_fee: 20,
                ..
            }
        ));
    }
}