bitcoin_rpc_midas/transport/batch_transport.rs
1// transport/src/batch_transport.rs
2
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

use serde_json::{json, Value};
use thiserror::Error;

use super::{TransportError, TransportTrait};
10
11/// Errors that can occur during batch operations
12#[derive(Debug, Error)]
13pub enum BatchError {
14 /// Error from the underlying transport
15 #[error("Transport error: {0}")]
16 Transport(#[from] crate::TransportError),
17
18 /// Error parsing batch response
19 #[error("Invalid batch response: {0}")]
20 InvalidResponse(String),
21
22 /// Error when trying to end a batch that hasn't been started
23 #[error("No batch in progress")]
24 NoBatchInProgress,
25
26 /// Error from a specific RPC call in the batch
27 #[error("RPC error in batch: {0}")]
28 Rpc(Value),
29}
30
31/// A transport wrapper that supports batching multiple RPC calls into a single request.
32///
33/// # Usage
34/// 1. Call [`begin_batch`] to start collecting requests.
35/// 2. Use [`send_request`] to queue each RPC call (these will not be sent immediately).
36/// 3. Call [`end_batch`] to send all queued requests as a single JSON-RPC batch and receive results.
37///
38/// This struct is thread-safe and can be shared between threads. Each batch is associated with a unique set of request IDs.
39pub struct BatchTransport {
40 inner: Arc<dyn TransportTrait>,
41 batch: Arc<Mutex<Option<Vec<BatchRequest>>>>,
42 /// Used to generate unique IDs for each request in a batch.
43 next_id: AtomicU64,
44}
45
46/// Represents a single RPC request that has been queued for batch processing.
47///
48/// This struct holds the method name, parameters, and a unique identifier for each
49/// request that will be sent as part of a JSON-RPC batch. The `id` field ensures
50/// that responses can be matched back to their corresponding requests when the
51/// batch is executed.
52pub struct BatchRequest {
53 method: String,
54 params: Vec<Value>,
55 id: usize,
56}
57
58impl BatchTransport {
59 /// Create a new batch transport that wraps the given transport
60 pub fn new(inner: Arc<dyn TransportTrait>) -> Self {
61 Self { inner, batch: Arc::new(Mutex::new(None)), next_id: AtomicU64::new(0) }
62 }
63
64 /// Begin collecting requests into a batch.
65 ///
66 /// Any subsequent calls to [`send_request`] will be queued until [`end_batch`] is called.
67 /// If a batch is already in progress, it will be replaced (and any queued requests will be lost).
68 pub fn begin_batch(&self) {
69 let mut batch = self.batch.lock().unwrap();
70 *batch = Some(Vec::new());
71 }
72
73 /// End the current batch and send all collected requests as a single JSON-RPC batch.
74 ///
75 /// Returns a vector of results in the same order as the requests were queued.
76 /// If any request in the batch fails, the entire batch fails.
77 ///
78 /// # Errors
79 /// - Returns [`BatchError::NoBatchInProgress`] if no batch was started.
80 /// - Returns [`BatchError::Transport`] if the underlying transport fails.
81 /// - Returns [`BatchError::Rpc`] if any RPC call in the batch returns an error.
82 pub async fn end_batch(&self) -> Result<Vec<Value>, BatchError> {
83 // 1) Take the queued calls
84 let requests = {
85 let mut b = self.batch.lock().unwrap();
86 b.take().ok_or(BatchError::NoBatchInProgress)?
87 };
88 if requests.is_empty() {
89 return Ok(vec![]);
90 }
91
92 // 2) Build the JSON-RPC batch array
93 let batch_json: Vec<Value> = requests
94 .iter()
95 .map(|req| {
96 json!({
97 "jsonrpc": "2.0",
98 "id": req.id,
99 "method": req.method,
100 "params": req.params,
101 })
102 })
103 .collect();
104
105 // 3) Fire the HTTP request (this re-uses your auth'd DefaultTransport behind the scenes,
106 // so you don't need to think about headers or basic_auth here)
107 let resp = self.inner.send_batch(&batch_json).await.map_err(BatchError::Transport)?;
108
109 // 4) Parse the array of responses
110 let arr: Vec<Value> = resp;
111
112 // 5) Extract each "result" or bail on the first error
113 let mut results = Vec::with_capacity(arr.len());
114 for obj in arr {
115 if let Some(err) = obj.get("error") {
116 return Err(BatchError::Rpc(err.clone()));
117 }
118 // assume "result" is present
119 results.push(obj.get("result").cloned().unwrap_or(Value::Null));
120 }
121 Ok(results)
122 }
123
124 /// Check if a batch is currently in progress.
125 pub fn is_batching(&self) -> bool { self.batch.lock().unwrap().is_some() }
126}
127
128impl TransportTrait for BatchTransport {
129 /// Queue a request if batching, or send immediately if not batching.
130 ///
131 /// # Returns
132 /// - If batching, always returns an error future, since results are only available after [`end_batch`].
133 /// - If not batching, delegates to the inner transport and returns the result.
134 fn send_request<'a>(
135 &'a self,
136 method: &'a str,
137 params: &'a [Value],
138 ) -> std::pin::Pin<
139 Box<dyn std::future::Future<Output = Result<Value, crate::TransportError>> + Send + 'a>,
140 > {
141 let mut batch = self.batch.lock().unwrap();
142
143 // If we're not in a batch, send immediately
144 if batch.is_none() {
145 drop(batch);
146 return Box::pin(self.inner.send_request(method, params));
147 }
148
149 // Add to batch without channel
150 let id = self.next_id.fetch_add(1, Ordering::SeqCst) as usize;
151 batch.as_mut().unwrap().push(BatchRequest {
152 method: method.to_string(),
153 params: params.to_vec(),
154 id,
155 });
156
157 // Return a future that immediately returns an error since we can't wait for the batch
158 // This is by design: end_batch() must be called to get results for all queued requests.
159 Box::pin(async move {
160 Err(TransportError::Rpc(
161 "Cannot wait for individual request result in batch mode. Use end_batch() to get all results.".to_string()
162 ))
163 })
164 }
165
166 fn send_batch<'a>(
167 &'a self,
168 bodies: &'a [Value],
169 ) -> std::pin::Pin<
170 Box<
171 dyn std::future::Future<Output = Result<Vec<Value>, crate::TransportError>> + Send + 'a,
172 >,
173 > {
174 // Delegate to the inner transport's send_batch method
175 Box::pin(self.inner.send_batch(bodies))
176 }
177
178 fn url(&self) -> &str { self.inner.url() }
179}