Skip to main content

syncular_testkit/
transport.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::sync::{Arc, Mutex};
3use std::thread;
4use std::time::Duration;
5
6use serde_json::{Map, Value};
7use syncular_runtime::binary_snapshot::SnapshotChunkRows;
8use syncular_runtime::error::{ErrorKind, Result, SyncularError};
9use syncular_runtime::protocol::{
10    BlobRef, CombinedRequest, CombinedResponse, OperationResult, PullResponse, PushBatchResponse,
11    PushCommitRequest, PushCommitResponse, ScopeValues, ScopedSnapshotArtifactRef,
12    SnapshotChunkRef, SubscriptionResponse,
13};
14use syncular_runtime::transport::{
15    BlobTransport, RealtimeEvent, RealtimeTransport, SyncAuthHeaderStore, SyncAuthHeaders,
16    SyncTransport,
17};
18
19#[derive(Debug, Clone)]
20pub struct SnapshotChunkFetch {
21    pub chunk: SnapshotChunkRef,
22    pub scopes: ScopeValues,
23}
24
25#[derive(Debug, Clone)]
26pub struct SnapshotArtifactFetch {
27    pub artifact: ScopedSnapshotArtifactRef,
28    pub scopes: ScopeValues,
29}
30
31#[derive(Debug, Clone)]
32pub struct BlobUploadRecord {
33    pub blob: BlobRef,
34    pub bytes: Vec<u8>,
35}
36
37type HttpResponseFn = dyn Fn(&CombinedRequest) -> Result<CombinedResponse> + Send + Sync + 'static;
38
39enum QueuedHttpResponse {
40    Static(CombinedResponse),
41    Dynamic(Box<HttpResponseFn>),
42}
43
44#[derive(Default)]
45struct TestTransportState {
46    requests: Vec<CombinedRequest>,
47    ws_pushes: Vec<PushCommitRequest>,
48    chunk_fetches: Vec<SnapshotChunkFetch>,
49    artifact_fetches: Vec<SnapshotArtifactFetch>,
50    auth_headers: Vec<SyncAuthHeaders>,
51    realtime_events: VecDeque<RealtimeEvent>,
52    http_responses: VecDeque<QueuedHttpResponse>,
53    ws_push_responses: VecDeque<PushCommitResponse>,
54    chunk_rows: VecDeque<SnapshotChunkRows>,
55    artifact_bytes: VecDeque<Vec<u8>>,
56    blob_uploads: Vec<BlobUploadRecord>,
57    blobs: BTreeMap<String, Vec<u8>>,
58    closed_realtime_count: usize,
59}
60
61#[derive(Clone, Default)]
62pub struct TestTransport {
63    state: Arc<Mutex<TestTransportState>>,
64}
65
66#[derive(Clone)]
67pub struct TestTransportHandle {
68    state: Arc<Mutex<TestTransportState>>,
69}
70
71#[derive(Clone)]
72pub struct TestRealtime {
73    state: Arc<Mutex<TestTransportState>>,
74}
75
76impl TestTransport {
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    pub fn handle(&self) -> TestTransportHandle {
82        TestTransportHandle {
83            state: self.state.clone(),
84        }
85    }
86
87    pub fn push_http_response(&self, response: CombinedResponse) {
88        self.state
89            .lock()
90            .expect("test transport state")
91            .http_responses
92            .push_back(QueuedHttpResponse::Static(response));
93    }
94
95    pub fn push_http_response_fn<F>(&self, response_fn: F)
96    where
97        F: Fn(&CombinedRequest) -> Result<CombinedResponse> + Send + Sync + 'static,
98    {
99        self.state
100            .lock()
101            .expect("test transport state")
102            .http_responses
103            .push_back(QueuedHttpResponse::Dynamic(Box::new(response_fn)));
104    }
105
106    pub fn push_ws_push_response(&self, response: PushCommitResponse) {
107        self.state
108            .lock()
109            .expect("test transport state")
110            .ws_push_responses
111            .push_back(response);
112    }
113
114    pub fn push_realtime_event(&self, event: RealtimeEvent) {
115        self.state
116            .lock()
117            .expect("test transport state")
118            .realtime_events
119            .push_back(event);
120    }
121
122    pub fn push_snapshot_chunk_rows(&self, rows: Vec<Value>) {
123        self.state
124            .lock()
125            .expect("test transport state")
126            .chunk_rows
127            .push_back(SnapshotChunkRows::Json(rows));
128    }
129
130    pub fn push_snapshot_artifact_bytes(&self, bytes: Vec<u8>) {
131        self.state
132            .lock()
133            .expect("test transport state")
134            .artifact_bytes
135            .push_back(bytes);
136    }
137
138    pub fn seed_blob(&self, blob: &BlobRef, bytes: Vec<u8>) {
139        self.state
140            .lock()
141            .expect("test transport state")
142            .blobs
143            .insert(blob.hash.clone(), bytes);
144    }
145}
146
147impl TestTransportHandle {
148    pub fn requests(&self) -> Vec<CombinedRequest> {
149        self.state
150            .lock()
151            .expect("test transport state")
152            .requests
153            .clone()
154    }
155
156    pub fn request_count(&self) -> usize {
157        self.state
158            .lock()
159            .expect("test transport state")
160            .requests
161            .len()
162    }
163
164    pub fn last_request(&self) -> Option<CombinedRequest> {
165        self.state
166            .lock()
167            .expect("test transport state")
168            .requests
169            .last()
170            .cloned()
171    }
172
173    pub fn ws_pushes(&self) -> Vec<PushCommitRequest> {
174        self.state
175            .lock()
176            .expect("test transport state")
177            .ws_pushes
178            .clone()
179    }
180
181    pub fn chunk_fetches(&self) -> Vec<SnapshotChunkFetch> {
182        self.state
183            .lock()
184            .expect("test transport state")
185            .chunk_fetches
186            .clone()
187    }
188
189    pub fn artifact_fetches(&self) -> Vec<SnapshotArtifactFetch> {
190        self.state
191            .lock()
192            .expect("test transport state")
193            .artifact_fetches
194            .clone()
195    }
196
197    pub fn auth_headers(&self) -> Vec<SyncAuthHeaders> {
198        self.state
199            .lock()
200            .expect("test transport state")
201            .auth_headers
202            .clone()
203    }
204
205    pub fn blob_uploads(&self) -> Vec<BlobUploadRecord> {
206        self.state
207            .lock()
208            .expect("test transport state")
209            .blob_uploads
210            .clone()
211    }
212
213    pub fn closed_realtime_count(&self) -> usize {
214        self.state
215            .lock()
216            .expect("test transport state")
217            .closed_realtime_count
218    }
219}
220
221impl SyncAuthHeaderStore for TestTransport {
222    fn set_auth_headers(&mut self, headers: SyncAuthHeaders) {
223        self.state
224            .lock()
225            .expect("test transport state")
226            .auth_headers
227            .push(headers);
228    }
229}
230
231impl SyncTransport for TestTransport {
232    type Realtime = TestRealtime;
233
234    fn post_sync(&self, request: &CombinedRequest) -> Result<CombinedResponse> {
235        let response = {
236            let mut state = self.state.lock().expect("test transport state");
237            state.requests.push(request.clone());
238            state.http_responses.pop_front()
239        };
240        if let Some(response) = response {
241            return match response {
242                QueuedHttpResponse::Static(response) => Ok(response),
243                QueuedHttpResponse::Dynamic(response_fn) => response_fn(request),
244            };
245        }
246        Ok(default_combined_response(request))
247    }
248
249    fn fetch_snapshot_chunk_rows(
250        &self,
251        chunk: &SnapshotChunkRef,
252        scopes: &Map<String, Value>,
253    ) -> Result<SnapshotChunkRows> {
254        let mut state = self.state.lock().expect("test transport state");
255        state.chunk_fetches.push(SnapshotChunkFetch {
256            chunk: chunk.clone(),
257            scopes: scopes.clone(),
258        });
259        Ok(state
260            .chunk_rows
261            .pop_front()
262            .unwrap_or_else(|| SnapshotChunkRows::Json(Vec::new())))
263    }
264
265    fn fetch_snapshot_artifact_bytes(
266        &self,
267        artifact: &ScopedSnapshotArtifactRef,
268        scopes: &Map<String, Value>,
269    ) -> Result<Vec<u8>> {
270        let mut state = self.state.lock().expect("test transport state");
271        state.artifact_fetches.push(SnapshotArtifactFetch {
272            artifact: artifact.clone(),
273            scopes: scopes.clone(),
274        });
275        state.artifact_bytes.pop_front().ok_or_else(|| {
276            SyncularError::protocol_message("no snapshot artifact bytes queued in TestTransport")
277        })
278    }
279
280    fn connect_realtime(&self) -> Result<Self::Realtime> {
281        Ok(TestRealtime {
282            state: self.state.clone(),
283        })
284    }
285}
286
287impl BlobTransport for TestTransport {
288    fn upload_blob(&self, blob: &BlobRef, bytes: &[u8]) -> Result<()> {
289        let mut state = self.state.lock().expect("test transport state");
290        state.blob_uploads.push(BlobUploadRecord {
291            blob: blob.clone(),
292            bytes: bytes.to_vec(),
293        });
294        state.blobs.insert(blob.hash.clone(), bytes.to_vec());
295        Ok(())
296    }
297
298    fn download_blob(&self, blob: &BlobRef) -> Result<Vec<u8>> {
299        self.state
300            .lock()
301            .expect("test transport state")
302            .blobs
303            .get(&blob.hash)
304            .cloned()
305            .ok_or_else(|| {
306                SyncularError::message(
307                    ErrorKind::Transport,
308                    format!("test blob not found: {}", blob.hash),
309                )
310            })
311    }
312}
313
314impl RealtimeTransport for TestRealtime {
315    fn push_commit(&mut self, commit: PushCommitRequest) -> Result<PushCommitResponse> {
316        let mut state = self.state.lock().expect("test transport state");
317        let response = state
318            .ws_push_responses
319            .pop_front()
320            .unwrap_or_else(|| default_push_commit_response(&commit));
321        state.ws_pushes.push(commit);
322        Ok(response)
323    }
324
325    fn read_event(&mut self) -> Result<Option<RealtimeEvent>> {
326        Ok(self
327            .state
328            .lock()
329            .expect("test transport state")
330            .realtime_events
331            .pop_front())
332    }
333
334    fn close(&mut self) {
335        self.state
336            .lock()
337            .expect("test transport state")
338            .closed_realtime_count += 1;
339    }
340}
341
342pub fn empty_combined_response() -> CombinedResponse {
343    CombinedResponse {
344        ok: true,
345        required_schema_version: None,
346        latest_schema_version: None,
347        push: None,
348        pull: Some(PullResponse {
349            ok: true,
350            subscriptions: Vec::new(),
351        }),
352    }
353}
354
355pub fn default_combined_response(request: &CombinedRequest) -> CombinedResponse {
356    CombinedResponse {
357        ok: true,
358        required_schema_version: None,
359        latest_schema_version: None,
360        push: request.push.as_ref().map(|push| PushBatchResponse {
361            ok: true,
362            commits: push
363                .commits
364                .iter()
365                .map(default_push_commit_response)
366                .collect(),
367        }),
368        pull: request.pull.as_ref().map(|pull| PullResponse {
369            ok: true,
370            subscriptions: pull
371                .subscriptions
372                .iter()
373                .map(|subscription| SubscriptionResponse {
374                    id: subscription.id.clone(),
375                    status: "active".to_string(),
376                    scopes: subscription.scopes.clone(),
377                    bootstrap: false,
378                    bootstrap_state: None,
379                    next_cursor: subscription.cursor.max(0),
380                    integrity: None,
381                    commits: Vec::new(),
382                    snapshots: None,
383                })
384                .collect(),
385        }),
386    }
387}
388
389pub fn default_push_commit_response(commit: &PushCommitRequest) -> PushCommitResponse {
390    PushCommitResponse {
391        client_commit_id: commit.client_commit_id.clone(),
392        status: "applied".to_string(),
393        commit_seq: Some(1),
394        results: commit
395            .operations
396            .iter()
397            .enumerate()
398            .map(|(index, _)| OperationResult {
399                op_index: index as i32,
400                status: "applied".to_string(),
401                message: None,
402                error: None,
403                code: None,
404                retriable: None,
405                server_version: Some(1),
406                server_row: None,
407            })
408            .collect(),
409    }
410}
411
/// Point, relative to the wrapped transport call, at which a fault step applies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FaultPhase {
    /// Applied before the wrapped transport is called.
    Before,
    /// Applied after the wrapped transport call has returned.
    After,
}
417
/// Kind of transport call a fault step targets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FaultOperation {
    /// As a step target, matches every operation. It is also the classification of
    /// a sync request that has neither a push nor a pull section.
    AnySync,
    /// A sync request that carries a push section.
    Push,
    /// A sync request that carries a pull section but no push section.
    Pull,
    /// A snapshot-chunk fetch.
    SnapshotChunk,
    /// Opening a realtime connection.
    RealtimeConnect,
}
426
/// What happens when a fault step fires.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FaultAction {
    /// Fail the call with a transport error carrying `message`.
    Fail { message: String },
    /// Block the calling thread for `duration`, then let the call continue.
    Delay { duration: Duration },
}
432
433#[derive(Debug, Clone, PartialEq, Eq)]
434pub struct FaultStep {
435    pub phase: FaultPhase,
436    pub operation: FaultOperation,
437    pub action: FaultAction,
438    pub remaining: usize,
439}
440
441impl FaultStep {
442    pub fn fail(phase: FaultPhase, operation: FaultOperation, message: impl Into<String>) -> Self {
443        Self {
444            phase,
445            operation,
446            action: FaultAction::Fail {
447                message: message.into(),
448            },
449            remaining: 1,
450        }
451    }
452
453    pub fn delay(phase: FaultPhase, operation: FaultOperation, duration: Duration) -> Self {
454        Self {
455            phase,
456            operation,
457            action: FaultAction::Delay { duration },
458            remaining: 1,
459        }
460    }
461
462    pub fn repeat(mut self, remaining: usize) -> Self {
463        self.remaining = remaining;
464        self
465    }
466}
467
468#[derive(Debug, Default)]
469struct FaultState {
470    steps: VecDeque<FaultStep>,
471    failures: usize,
472    delays: usize,
473}
474
475#[derive(Debug, Clone)]
476pub struct FaultTransport<T> {
477    inner: T,
478    state: Arc<Mutex<FaultState>>,
479}
480
481#[derive(Debug, Clone)]
482pub struct FaultHandle {
483    state: Arc<Mutex<FaultState>>,
484}
485
486impl<T> FaultTransport<T> {
487    pub fn new(inner: T, steps: impl IntoIterator<Item = FaultStep>) -> Self {
488        Self {
489            inner,
490            state: Arc::new(Mutex::new(FaultState {
491                steps: steps.into_iter().collect(),
492                failures: 0,
493                delays: 0,
494            })),
495        }
496    }
497
498    pub fn handle(&self) -> FaultHandle {
499        FaultHandle {
500            state: self.state.clone(),
501        }
502    }
503
504    pub fn into_inner(self) -> T {
505        self.inner
506    }
507}
508
509impl FaultHandle {
510    pub fn failures(&self) -> usize {
511        self.state.lock().expect("fault state").failures
512    }
513
514    pub fn delays(&self) -> usize {
515        self.state.lock().expect("fault state").delays
516    }
517
518    pub fn remaining_steps(&self) -> usize {
519        self.state.lock().expect("fault state").steps.len()
520    }
521}
522
523impl<T> SyncTransport for FaultTransport<T>
524where
525    T: SyncTransport,
526{
527    type Realtime = T::Realtime;
528
529    fn post_sync(&self, request: &CombinedRequest) -> Result<CombinedResponse> {
530        let operation = request_fault_operation(request);
531        apply_fault(&self.state, FaultPhase::Before, operation)?;
532        let response = self.inner.post_sync(request);
533        apply_fault(&self.state, FaultPhase::After, operation)?;
534        response
535    }
536
537    fn fetch_snapshot_chunk_rows(
538        &self,
539        chunk: &SnapshotChunkRef,
540        scopes: &Map<String, Value>,
541    ) -> Result<SnapshotChunkRows> {
542        apply_fault(
543            &self.state,
544            FaultPhase::Before,
545            FaultOperation::SnapshotChunk,
546        )?;
547        let rows = self.inner.fetch_snapshot_chunk_rows(chunk, scopes);
548        apply_fault(
549            &self.state,
550            FaultPhase::After,
551            FaultOperation::SnapshotChunk,
552        )?;
553        rows
554    }
555
556    fn connect_realtime(&self) -> Result<Self::Realtime> {
557        apply_fault(
558            &self.state,
559            FaultPhase::Before,
560            FaultOperation::RealtimeConnect,
561        )?;
562        let realtime = self.inner.connect_realtime();
563        apply_fault(
564            &self.state,
565            FaultPhase::After,
566            FaultOperation::RealtimeConnect,
567        )?;
568        realtime
569    }
570}
571
572impl<T> BlobTransport for FaultTransport<T>
573where
574    T: BlobTransport,
575{
576    fn upload_blob(&self, blob: &BlobRef, bytes: &[u8]) -> Result<()> {
577        self.inner.upload_blob(blob, bytes)
578    }
579
580    fn download_blob(&self, blob: &BlobRef) -> Result<Vec<u8>> {
581        self.inner.download_blob(blob)
582    }
583}
584
585impl<T> SyncAuthHeaderStore for FaultTransport<T>
586where
587    T: SyncAuthHeaderStore,
588{
589    fn set_auth_headers(&mut self, headers: SyncAuthHeaders) {
590        self.inner.set_auth_headers(headers);
591    }
592}
593
594fn request_fault_operation(request: &CombinedRequest) -> FaultOperation {
595    if request.push.is_some() {
596        FaultOperation::Push
597    } else if request.pull.is_some() {
598        FaultOperation::Pull
599    } else {
600        FaultOperation::AnySync
601    }
602}
603
604fn operation_matches(expected: FaultOperation, actual: FaultOperation) -> bool {
605    expected == FaultOperation::AnySync || expected == actual
606}
607
608fn apply_fault(
609    state: &Arc<Mutex<FaultState>>,
610    phase: FaultPhase,
611    operation: FaultOperation,
612) -> Result<()> {
613    let action =
614        {
615            let mut state = state.lock().expect("fault state");
616            let Some(index) = state.steps.iter().position(|step| {
617                step.phase == phase && operation_matches(step.operation, operation)
618            }) else {
619                return Ok(());
620            };
621
622            let mut step = state.steps.remove(index).expect("fault step");
623            let action = step.action.clone();
624            if step.remaining > 1 {
625                step.remaining -= 1;
626                state.steps.insert(index, step);
627            }
628            match action {
629                FaultAction::Fail { .. } => state.failures += 1,
630                FaultAction::Delay { .. } => state.delays += 1,
631            }
632            action
633        };
634
635    match action {
636        FaultAction::Fail { message } => Err(SyncularError::message(ErrorKind::Transport, message)),
637        FaultAction::Delay { duration } => {
638            thread::sleep(duration);
639            Ok(())
640        }
641    }
642}