// contextdb_server/sync_server.rs

1use crate::protocol::{
2    MessageType, PullRequest, PullResponse, PushRequest, PushResponse, decode, encode,
3};
4use crate::subjects::{pull_subject, push_subject};
5use contextdb_engine::Database;
6use contextdb_engine::sync_types::ConflictPolicies;
7use futures_util::StreamExt;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Instant;
12
/// Upper bound on the number of chunked push sessions buffered at the same time.
/// A chunk that would open a new session (a `chunk_id` not yet buffered) is rejected
/// once this many sessions are already in flight.
const MAX_CHUNK_SESSIONS: usize = 64;
15
16async fn maybe_wait_for_test_push_barrier(row_count: usize) {
17    let Some(min_rows) = std::env::var("CONTEXTDB_TEST_PUSH_BARRIER_MIN_ROWS")
18        .ok()
19        .and_then(|value| value.parse::<usize>().ok())
20    else {
21        return;
22    };
23    if row_count < min_rows {
24        return;
25    }
26
27    let Some(barrier_path) = std::env::var_os("CONTEXTDB_TEST_PUSH_BARRIER_FILE") else {
28        return;
29    };
30    let Some(release_path) = std::env::var_os("CONTEXTDB_TEST_PUSH_RELEASE_FILE") else {
31        return;
32    };
33
34    let barrier_path = std::path::PathBuf::from(barrier_path);
35    let release_path = std::path::PathBuf::from(release_path);
36    let _ = std::fs::write(&barrier_path, b"push-handler-started");
37
38    while !release_path.exists() {
39        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
40    }
41}
42
43pub struct SyncServer {
44    db: Arc<Database>,
45    nats_url: String,
46    tenant_id: String,
47    policies: ConflictPolicies,
48    chunk_buffer:
49        std::sync::Mutex<HashMap<uuid::Uuid, (Instant, Vec<crate::protocol::ChunkMessage>)>>,
50}
51
52impl SyncServer {
53    pub fn new(
54        db: Arc<Database>,
55        nats_url: &str,
56        tenant_id: &str,
57        policies: ConflictPolicies,
58    ) -> Self {
59        assert!(
60            !tenant_id.is_empty()
61                && tenant_id
62                    .chars()
63                    .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
64            "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
65        );
66        Self {
67            db,
68            nats_url: nats_url.to_string(),
69            tenant_id: tenant_id.to_string(),
70            policies,
71            chunk_buffer: std::sync::Mutex::new(HashMap::new()),
72        }
73    }
74
75    pub fn db(&self) -> &Database {
76        &self.db
77    }
78
79    pub async fn run(&self) {
80        self.run_until(Arc::new(AtomicBool::new(false))).await;
81    }
82
83    pub async fn run_until(&self, shutdown: Arc<AtomicBool>) {
84        let client = loop {
85            if shutdown.load(Ordering::SeqCst) {
86                return;
87            }
88            match async_nats::connect(&self.nats_url).await {
89                Ok(client) => break client,
90                Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
91            }
92        };
93
94        let (mut push_sub, mut pull_sub) = loop {
95            if shutdown.load(Ordering::SeqCst) {
96                return;
97            }
98            match self.bootstrap_subscriptions(&client).await {
99                Ok(subscriptions) => break subscriptions,
100                Err(err) => {
101                    tracing::error!(
102                        error = %err,
103                        tenant_id = %self.tenant_id,
104                        "sync server bootstrap failed; retrying"
105                    );
106                    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
107                }
108            }
109        };
110        let mut cleanup_interval = tokio::time::interval(std::time::Duration::from_secs(10));
111        let mut shutdown_poll = tokio::time::interval(std::time::Duration::from_millis(50));
112
113        while !shutdown.load(Ordering::SeqCst) {
114            tokio::select! {
115                maybe_msg = push_sub.next() => {
116                    if let Some(msg) = maybe_msg
117                        && let Err(e) = self.handle_push(&client, msg).await {
118                        tracing::error!(error = %e, "handle_push failed");
119                    }
120                }
121                maybe_msg = pull_sub.next() => {
122                    if let Some(msg) = maybe_msg
123                        && let Err(e) = self.handle_pull(&client, msg).await {
124                        tracing::error!(error = %e, "handle_pull failed");
125                    }
126                }
127                _ = cleanup_interval.tick() => {
128                    let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
129                    let before = buf.len();
130                    buf.retain(|id, (ts, _chunks)| {
131                        let keep = ts.elapsed() < std::time::Duration::from_secs(30);
132                        if !keep {
133                            tracing::warn!(%id, "evicting stale chunk session (30s TTL expired)");
134                        }
135                        keep
136                    });
137                    if buf.len() < before {
138                        tracing::info!(evicted = before - buf.len(), remaining = buf.len(), "chunk buffer cleanup");
139                    }
140                }
141                _ = shutdown_poll.tick() => {}
142            }
143        }
144    }
145
146    async fn bootstrap_subscriptions(
147        &self,
148        client: &async_nats::Client,
149    ) -> contextdb_core::Result<(async_nats::Subscriber, async_nats::Subscriber)> {
150        let push_sub = client
151            .subscribe(push_subject(&self.tenant_id))
152            .await
153            .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe push: {e}")))?;
154        let pull_sub = client
155            .subscribe(pull_subject(&self.tenant_id))
156            .await
157            .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe pull: {e}")))?;
158        client
159            .flush()
160            .await
161            .map_err(|e| contextdb_core::Error::SyncError(format!("flush subscriptions: {e}")))?;
162        Ok((push_sub, pull_sub))
163    }
164
165    async fn handle_push(
166        &self,
167        client: &async_nats::Client,
168        msg: async_nats::Message,
169    ) -> contextdb_core::Result<()> {
170        let envelope =
171            decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
172
173        match envelope.message_type {
174            MessageType::Chunk => {
175                let chunk_msg: crate::protocol::ChunkMessage =
176                    rmp_serde::from_slice(&envelope.payload)
177                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
178                let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
179                if !buf.contains_key(&chunk_msg.chunk_id) && buf.len() >= MAX_CHUNK_SESSIONS {
180                    tracing::warn!(
181                        chunk_id = %chunk_msg.chunk_id,
182                        active_sessions = buf.len(),
183                        "chunk buffer full, rejecting new chunk session"
184                    );
185                    return Err(contextdb_core::Error::SyncError(
186                        "chunk buffer full".to_string(),
187                    ));
188                }
189                buf.entry(chunk_msg.chunk_id)
190                    .or_insert_with(|| (std::time::Instant::now(), Vec::new()))
191                    .1
192                    .push(chunk_msg);
193                Ok(())
194            }
195            MessageType::ChunkAck => {
196                let ack: crate::protocol::ChunkAck = rmp_serde::from_slice(&envelope.payload)
197                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
198
199                tracing::info!(
200                    chunk_id = %ack.chunk_id,
201                    total_chunks = ack.total_chunks,
202                    "received ChunkAck, attempting reassembly"
203                );
204
205                let process_result: contextdb_core::Result<Vec<u8>> = (|| {
206                    let mut chunks = {
207                        let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
208                        let (_ts, chunks) = buf.remove(&ack.chunk_id).ok_or_else(|| {
209                            contextdb_core::Error::SyncError(format!(
210                                "no chunks buffered for chunk_id {}",
211                                ack.chunk_id
212                            ))
213                        })?;
214                        chunks
215                    };
216
217                    if chunks.len() != ack.total_chunks as usize {
218                        return Err(contextdb_core::Error::SyncError(format!(
219                            "expected {} chunks for {}, got {}",
220                            ack.total_chunks,
221                            ack.chunk_id,
222                            chunks.len()
223                        )));
224                    }
225
226                    let assembled = crate::chunking::reassemble(&mut chunks);
227                    let inner_envelope = decode(&assembled)
228                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
229                    let request: PushRequest = rmp_serde::from_slice(&inner_envelope.payload)
230                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
231
232                    match self
233                        .db
234                        .apply_changes(request.changeset.into(), &self.policies)
235                    {
236                        Ok(result) => {
237                            let response = PushResponse {
238                                result: Some(result.into()),
239                                error: None,
240                            };
241                            encode(MessageType::PushResponse, &response)
242                                .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
243                        }
244                        Err(err) => {
245                            let response = PushResponse {
246                                result: None,
247                                error: Some(err.to_string()),
248                            };
249                            encode(MessageType::PushResponse, &response)
250                                .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
251                        }
252                    }
253                })();
254
255                match process_result {
256                    Ok(payload) => {
257                        client
258                            .publish(ack.reply_inbox, payload.into())
259                            .await
260                            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
261                        Ok(())
262                    }
263                    Err(e) => {
264                        tracing::error!(
265                            chunk_id = %ack.chunk_id,
266                            error = %e,
267                            "chunked push processing failed, client will timeout"
268                        );
269                        Err(e)
270                    }
271                }
272            }
273            MessageType::PushRequest => {
274                let request: PushRequest = rmp_serde::from_slice(&envelope.payload)
275                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
276                maybe_wait_for_test_push_barrier(request.changeset.rows.len()).await;
277                let response = match self
278                    .db
279                    .apply_changes(request.changeset.into(), &self.policies)
280                {
281                    Ok(result) => PushResponse {
282                        result: Some(result.into()),
283                        error: None,
284                    },
285                    Err(err) => PushResponse {
286                        result: None,
287                        error: Some(err.to_string()),
288                    },
289                };
290                let payload = encode(MessageType::PushResponse, &response)
291                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
292
293                if let Some(reply) = msg.reply {
294                    client
295                        .publish(reply, payload.into())
296                        .await
297                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
298                    client
299                        .flush()
300                        .await
301                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
302                }
303                Ok(())
304            }
305            _ => Err(contextdb_core::Error::SyncError(
306                "unexpected message type on push subject".to_string(),
307            )),
308        }
309    }
310
311    async fn handle_pull(
312        &self,
313        client: &async_nats::Client,
314        msg: async_nats::Message,
315    ) -> contextdb_core::Result<()> {
316        let envelope =
317            decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
318        if !matches!(envelope.message_type, MessageType::PullRequest) {
319            return Err(contextdb_core::Error::SyncError(
320                "unexpected message type on pull subject".to_string(),
321            ));
322        }
323
324        let request: PullRequest = rmp_serde::from_slice(&envelope.payload)
325            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
326        let mut changes = self.db.changes_since(request.since_lsn);
327
328        let mut has_more = false;
329        let mut cursor = None;
330        if let Some(max_entries) = request.max_entries {
331            let max = max_entries as usize;
332            if changes.rows.len() > max {
333                let mut remainder = changes.rows.split_off(max);
334                // Don't split in the middle of a same-LSN group: extend returned set
335                // to include all rows sharing the LSN at the split boundary.
336                if let (Some(last_returned), Some(first_remainder)) =
337                    (changes.rows.last(), remainder.first())
338                    && last_returned.lsn == first_remainder.lsn
339                {
340                    let boundary_lsn = last_returned.lsn;
341                    let split_idx = remainder.partition_point(|r| r.lsn == boundary_lsn);
342                    let moved: Vec<_> = remainder.drain(..split_idx).collect();
343                    changes.rows.extend(moved);
344                }
345                // Cursor = last LSN of returned set (not remainder)
346                cursor = changes.rows.last().map(|r| r.lsn);
347                has_more = !remainder.is_empty();
348            }
349        }
350
351        let response = PullResponse {
352            changeset: changes.into(),
353            has_more,
354            cursor,
355        };
356        let payload = encode(MessageType::PullResponse, &response)
357            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
358
359        if let Some(reply) = msg.reply {
360            if crate::chunking::needs_chunking(&payload) {
361                let chunks = crate::chunking::chunk(&payload);
362                tracing::info!(
363                    total_chunks = chunks.len(),
364                    payload_size = payload.len(),
365                    "sending chunked pull response"
366                );
367                for chunk_msg in &chunks {
368                    let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
369                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
370                    client
371                        .publish(reply.clone(), chunk_encoded.into())
372                        .await
373                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
374                }
375                client
376                    .flush()
377                    .await
378                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
379            } else {
380                client
381                    .publish(reply, payload.into())
382                    .await
383                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
384            }
385        }
386        Ok(())
387    }
388}