1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::{Arc, Mutex};
4
5use tokio::sync::mpsc;
6use tokio_stream::StreamExt;
7use tokio_stream::wrappers::ReceiverStream;
8use tonic::{Request, Response, Status, Streaming};
9use tracing::{debug, info, warn};
10
11use crate::events::StateProviderEvent;
12use crate::proto::replicator_data_server::ReplicatorData;
13use crate::proto::{
14 CopyItem, CopyStreamResponse, GetCopyContextRequest, GetCopyContextResponse, RawOperation,
15 ReplicationAck, ReplicationItem,
16};
17use crate::types::{Epoch, Lsn, Operation};
18
19pub struct SecondaryReceiver {
26 state: Arc<SecondaryState>,
27 partition_state: Option<Arc<crate::handles::PartitionState>>,
29 operation_tx: Option<mpsc::Sender<Operation>>,
32 copy_stream_tx: Option<Mutex<Option<mpsc::Sender<Operation>>>>,
34 state_provider_tx: Option<mpsc::UnboundedSender<StateProviderEvent>>,
36}
37
38pub struct SecondaryState {
39 current_epoch: Mutex<Epoch>,
40 log: Mutex<HashMap<Lsn, bytes::Bytes>>,
42 received_lsn: AtomicI64,
44 committed_lsn: AtomicI64,
46}
47
48impl SecondaryState {
49 pub fn new() -> Self {
50 Self {
51 current_epoch: Mutex::new(Epoch::default()),
52 log: Mutex::new(HashMap::new()),
53 received_lsn: AtomicI64::new(0),
54 committed_lsn: AtomicI64::new(0),
55 }
56 }
57
58 pub fn received_lsn(&self) -> Lsn {
59 self.received_lsn.load(Ordering::Acquire)
60 }
61
62 pub fn committed_lsn(&self) -> Lsn {
63 self.committed_lsn.load(Ordering::Acquire)
64 }
65
66 pub fn set_committed_lsn(&self, lsn: Lsn) {
67 self.committed_lsn.store(lsn, Ordering::Release);
68 }
69
70 pub fn update_epoch(&self, new_epoch: Epoch) {
71 let mut epoch = self.current_epoch.lock().unwrap();
72 *epoch = new_epoch;
73
74 let committed = self.committed_lsn.load(Ordering::Acquire);
76 let mut log = self.log.lock().unwrap();
77 log.retain(|lsn, _| *lsn <= committed);
78
79 let new_received = committed.max(self.received_lsn.load(Ordering::Acquire).min(committed));
80 self.received_lsn.store(new_received, Ordering::Release);
81 }
82
83 pub fn log_len(&self) -> usize {
84 self.log.lock().unwrap().len()
85 }
86
87 pub fn get(&self, lsn: Lsn) -> Option<bytes::Bytes> {
88 self.log.lock().unwrap().get(&lsn).cloned()
89 }
90
91 fn accept_item(&self, item: &ReplicationItem) -> Result<(), Status> {
92 let epoch = self.current_epoch.lock().unwrap();
93 let item_epoch = Epoch::new(item.epoch_data_loss, item.epoch_config);
94
95 if item_epoch < *epoch {
96 return Err(Status::failed_precondition(format!(
97 "stale epoch: got {:?}, current {:?}",
98 item_epoch, *epoch
99 )));
100 }
101
102 drop(epoch);
103
104 let mut log = self.log.lock().unwrap();
105 log.insert(item.lsn, bytes::Bytes::copy_from_slice(&item.data));
106
107 let prev = self.received_lsn.load(Ordering::Acquire);
108 if item.lsn > prev {
109 self.received_lsn.store(item.lsn, Ordering::Release);
110 }
111
112 if item.committed_lsn > self.committed_lsn.load(Ordering::Acquire) {
115 self.committed_lsn
116 .store(item.committed_lsn, Ordering::Release);
117 }
118
119 Ok(())
120 }
121}
122
123impl Default for SecondaryState {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl SecondaryReceiver {
130 pub fn new(state: Arc<SecondaryState>) -> Self {
132 Self {
133 state,
134 partition_state: None,
135 operation_tx: None,
136 copy_stream_tx: None,
137 state_provider_tx: None,
138 }
139 }
140
141 pub fn with_streams(
143 state: Arc<SecondaryState>,
144 partition_state: Arc<crate::handles::PartitionState>,
145 operation_tx: mpsc::Sender<Operation>,
146 copy_stream_tx: mpsc::Sender<Operation>,
147 state_provider_tx: mpsc::UnboundedSender<StateProviderEvent>,
148 ) -> Self {
149 Self {
150 state,
151 partition_state: Some(partition_state),
152 operation_tx: Some(operation_tx),
153 copy_stream_tx: Some(Mutex::new(Some(copy_stream_tx))),
154 state_provider_tx: Some(state_provider_tx),
155 }
156 }
157}
158
159#[tonic::async_trait]
160impl ReplicatorData for SecondaryReceiver {
161 type ReplicationStreamStream = ReceiverStream<Result<ReplicationAck, Status>>;
162
163 async fn replication_stream(
164 &self,
165 request: Request<Streaming<ReplicationItem>>,
166 ) -> Result<Response<Self::ReplicationStreamStream>, Status> {
167 let mut inbound = request.into_inner();
168 let state = self.state.clone();
169 let partition_state = self.partition_state.clone();
170 let (ack_tx, ack_rx) = mpsc::channel(256);
171 let operation_tx = self.operation_tx.clone();
172
173 tokio::spawn(async move {
174 while let Some(result) = inbound.next().await {
175 match result {
176 Ok(item) => {
177 let lsn = item.lsn;
178 match state.accept_item(&item) {
179 Ok(()) => {
180 debug!(lsn, "accepted replication item");
181
182 if let Some(ref ps) = partition_state
186 && item.committed_lsn > ps.committed_lsn()
187 {
188 ps.set_committed_lsn(item.committed_lsn);
189 }
190
191 if let Some(ref op_tx) = operation_tx {
192 let (user_ack_tx, user_ack_rx) =
194 tokio::sync::oneshot::channel();
195 let op = Operation::new(
196 lsn,
197 bytes::Bytes::copy_from_slice(&item.data),
198 Some(user_ack_tx),
199 );
200 if op_tx.send(op).await.is_err() {
201 warn!(lsn, "operation stream closed");
202 break;
203 }
204 let ack_tx = ack_tx.clone();
205 tokio::spawn(async move {
206 if user_ack_rx.await.is_ok() {
207 let _ = ack_tx.send(Ok(ReplicationAck { lsn })).await;
208 }
209 });
210 } else {
211 if ack_tx.send(Ok(ReplicationAck { lsn })).await.is_err() {
213 break;
214 }
215 }
216 }
217 Err(status) => {
218 warn!(
219 lsn,
220 error = %status.message(),
221 "rejected replication item"
222 );
223 break;
224 }
225 }
226 }
227 Err(e) => {
228 warn!(error = %e, "replication stream error");
229 break;
230 }
231 }
232 }
233 });
234
235 Ok(Response::new(ReceiverStream::new(ack_rx)))
236 }
237
238 async fn get_copy_context(
239 &self,
240 _request: Request<GetCopyContextRequest>,
241 ) -> Result<Response<GetCopyContextResponse>, Status> {
242 let Some(ref sp_tx) = self.state_provider_tx else {
243 return Ok(Response::new(GetCopyContextResponse { operations: vec![] }));
245 };
246
247 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
249 sp_tx
250 .send(StateProviderEvent::GetCopyContext { reply: reply_tx })
251 .map_err(|_| Status::internal("state provider closed"))?;
252
253 let mut stream = reply_rx
254 .await
255 .map_err(|_| Status::internal("state provider reply dropped"))?
256 .map_err(|e| Status::internal(e.to_string()))?;
257
258 let mut ops = Vec::new();
260 while let Some(op) = stream.get_operation().await {
261 ops.push(RawOperation {
262 lsn: op.lsn,
263 data: op.data.to_vec(),
264 });
265 op.acknowledge();
266 }
267
268 info!(count = ops.len(), "GetCopyContext: sent context");
269 Ok(Response::new(GetCopyContextResponse { operations: ops }))
270 }
271
272 async fn copy_stream(
273 &self,
274 request: Request<Streaming<CopyItem>>,
275 ) -> Result<Response<CopyStreamResponse>, Status> {
276 let tx = self
278 .copy_stream_tx
279 .as_ref()
280 .and_then(|m| m.lock().unwrap().take())
281 .ok_or_else(|| {
282 Status::failed_precondition("copy stream not available or already used")
283 })?;
284
285 let mut inbound = request.into_inner();
286 let mut count: i64 = 0;
287
288 while let Some(result) = inbound.next().await {
289 match result {
290 Ok(item) => {
291 let op = Operation::new(item.lsn, bytes::Bytes::from(item.data), None);
292 if tx.send(op).await.is_err() {
293 warn!("copy stream receiver closed");
294 break;
295 }
296 count += 1;
297 }
298 Err(e) => {
299 warn!(error = %e, "copy stream error");
300 return Err(Status::internal(e.to_string()));
301 }
302 }
303 }
304
305 drop(tx);
307 info!(count, "CopyStream: received all copy data");
308
309 Ok(Response::new(CopyStreamResponse {
310 items_received: count,
311 }))
312 }
313}