1use crate::protocol::{
2 MessageType, PullRequest, PullResponse, PushRequest, PushResponse, decode, encode,
3};
4use crate::subjects::{pull_subject, push_subject};
5use contextdb_engine::Database;
6use contextdb_engine::sync_types::ConflictPolicies;
7use futures_util::StreamExt;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Instant;
12
/// Maximum number of chunked-push sessions (distinct `chunk_id`s) buffered at once.
/// A chunk that would open a new session while this many are already buffered is rejected
/// with a "chunk buffer full" error; chunks for an already-open session are still accepted.
const MAX_CHUNK_SESSIONS: usize = 64;
15
16async fn maybe_wait_for_test_push_barrier(row_count: usize) {
17 let Some(min_rows) = std::env::var("CONTEXTDB_TEST_PUSH_BARRIER_MIN_ROWS")
18 .ok()
19 .and_then(|value| value.parse::<usize>().ok())
20 else {
21 return;
22 };
23 if row_count < min_rows {
24 return;
25 }
26
27 let Some(barrier_path) = std::env::var_os("CONTEXTDB_TEST_PUSH_BARRIER_FILE") else {
28 return;
29 };
30 let Some(release_path) = std::env::var_os("CONTEXTDB_TEST_PUSH_RELEASE_FILE") else {
31 return;
32 };
33
34 let barrier_path = std::path::PathBuf::from(barrier_path);
35 let release_path = std::path::PathBuf::from(release_path);
36 let _ = std::fs::write(&barrier_path, b"push-handler-started");
37
38 while !release_path.exists() {
39 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
40 }
41}
42
43pub struct SyncServer {
44 db: Arc<Database>,
45 nats_url: String,
46 tenant_id: String,
47 policies: ConflictPolicies,
48 chunk_buffer:
49 std::sync::Mutex<HashMap<uuid::Uuid, (Instant, Vec<crate::protocol::ChunkMessage>)>>,
50}
51
52impl SyncServer {
53 pub fn new(
54 db: Arc<Database>,
55 nats_url: &str,
56 tenant_id: &str,
57 policies: ConflictPolicies,
58 ) -> Self {
59 assert!(
60 !tenant_id.is_empty()
61 && tenant_id
62 .chars()
63 .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
64 "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
65 );
66 Self {
67 db,
68 nats_url: nats_url.to_string(),
69 tenant_id: tenant_id.to_string(),
70 policies,
71 chunk_buffer: std::sync::Mutex::new(HashMap::new()),
72 }
73 }
74
75 pub fn db(&self) -> &Database {
76 &self.db
77 }
78
79 pub async fn run(&self) {
80 self.run_until(Arc::new(AtomicBool::new(false))).await;
81 }
82
83 pub async fn run_until(&self, shutdown: Arc<AtomicBool>) {
84 let client = loop {
85 if shutdown.load(Ordering::SeqCst) {
86 return;
87 }
88 match async_nats::connect(&self.nats_url).await {
89 Ok(client) => break client,
90 Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
91 }
92 };
93
94 let (mut push_sub, mut pull_sub) = loop {
95 if shutdown.load(Ordering::SeqCst) {
96 return;
97 }
98 match self.bootstrap_subscriptions(&client).await {
99 Ok(subscriptions) => break subscriptions,
100 Err(err) => {
101 tracing::error!(
102 error = %err,
103 tenant_id = %self.tenant_id,
104 "sync server bootstrap failed; retrying"
105 );
106 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
107 }
108 }
109 };
110 let mut cleanup_interval = tokio::time::interval(std::time::Duration::from_secs(10));
111 let mut shutdown_poll = tokio::time::interval(std::time::Duration::from_millis(50));
112
113 while !shutdown.load(Ordering::SeqCst) {
114 tokio::select! {
115 maybe_msg = push_sub.next() => {
116 if let Some(msg) = maybe_msg
117 && let Err(e) = self.handle_push(&client, msg).await {
118 tracing::error!(error = %e, "handle_push failed");
119 }
120 }
121 maybe_msg = pull_sub.next() => {
122 if let Some(msg) = maybe_msg
123 && let Err(e) = self.handle_pull(&client, msg).await {
124 tracing::error!(error = %e, "handle_pull failed");
125 }
126 }
127 _ = cleanup_interval.tick() => {
128 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
129 let before = buf.len();
130 buf.retain(|id, (ts, _chunks)| {
131 let keep = ts.elapsed() < std::time::Duration::from_secs(30);
132 if !keep {
133 tracing::warn!(%id, "evicting stale chunk session (30s TTL expired)");
134 }
135 keep
136 });
137 if buf.len() < before {
138 tracing::info!(evicted = before - buf.len(), remaining = buf.len(), "chunk buffer cleanup");
139 }
140 }
141 _ = shutdown_poll.tick() => {}
142 }
143 }
144 }
145
146 async fn bootstrap_subscriptions(
147 &self,
148 client: &async_nats::Client,
149 ) -> contextdb_core::Result<(async_nats::Subscriber, async_nats::Subscriber)> {
150 let push_sub = client
151 .subscribe(push_subject(&self.tenant_id))
152 .await
153 .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe push: {e}")))?;
154 let pull_sub = client
155 .subscribe(pull_subject(&self.tenant_id))
156 .await
157 .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe pull: {e}")))?;
158 client
159 .flush()
160 .await
161 .map_err(|e| contextdb_core::Error::SyncError(format!("flush subscriptions: {e}")))?;
162 Ok((push_sub, pull_sub))
163 }
164
165 async fn handle_push(
166 &self,
167 client: &async_nats::Client,
168 msg: async_nats::Message,
169 ) -> contextdb_core::Result<()> {
170 let envelope =
171 decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
172
173 match envelope.message_type {
174 MessageType::Chunk => {
175 let chunk_msg: crate::protocol::ChunkMessage =
176 rmp_serde::from_slice(&envelope.payload)
177 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
178 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
179 if !buf.contains_key(&chunk_msg.chunk_id) && buf.len() >= MAX_CHUNK_SESSIONS {
180 tracing::warn!(
181 chunk_id = %chunk_msg.chunk_id,
182 active_sessions = buf.len(),
183 "chunk buffer full, rejecting new chunk session"
184 );
185 return Err(contextdb_core::Error::SyncError(
186 "chunk buffer full".to_string(),
187 ));
188 }
189 buf.entry(chunk_msg.chunk_id)
190 .or_insert_with(|| (std::time::Instant::now(), Vec::new()))
191 .1
192 .push(chunk_msg);
193 Ok(())
194 }
195 MessageType::ChunkAck => {
196 let ack: crate::protocol::ChunkAck = rmp_serde::from_slice(&envelope.payload)
197 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
198
199 tracing::info!(
200 chunk_id = %ack.chunk_id,
201 total_chunks = ack.total_chunks,
202 "received ChunkAck, attempting reassembly"
203 );
204
205 let process_result: contextdb_core::Result<Vec<u8>> = (|| {
206 let mut chunks = {
207 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
208 let (_ts, chunks) = buf.remove(&ack.chunk_id).ok_or_else(|| {
209 contextdb_core::Error::SyncError(format!(
210 "no chunks buffered for chunk_id {}",
211 ack.chunk_id
212 ))
213 })?;
214 chunks
215 };
216
217 if chunks.len() != ack.total_chunks as usize {
218 return Err(contextdb_core::Error::SyncError(format!(
219 "expected {} chunks for {}, got {}",
220 ack.total_chunks,
221 ack.chunk_id,
222 chunks.len()
223 )));
224 }
225
226 let assembled = crate::chunking::reassemble(&mut chunks);
227 let inner_envelope = decode(&assembled)
228 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
229 let request: PushRequest = rmp_serde::from_slice(&inner_envelope.payload)
230 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
231
232 match self
233 .db
234 .apply_changes(request.changeset.into(), &self.policies)
235 {
236 Ok(result) => {
237 let response = PushResponse {
238 result: Some(result.into()),
239 error: None,
240 };
241 encode(MessageType::PushResponse, &response)
242 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
243 }
244 Err(err) => {
245 let response = PushResponse {
246 result: None,
247 error: Some(err.to_string()),
248 };
249 encode(MessageType::PushResponse, &response)
250 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
251 }
252 }
253 })();
254
255 match process_result {
256 Ok(payload) => {
257 client
258 .publish(ack.reply_inbox, payload.into())
259 .await
260 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
261 Ok(())
262 }
263 Err(e) => {
264 tracing::error!(
265 chunk_id = %ack.chunk_id,
266 error = %e,
267 "chunked push processing failed, client will timeout"
268 );
269 Err(e)
270 }
271 }
272 }
273 MessageType::PushRequest => {
274 let request: PushRequest = rmp_serde::from_slice(&envelope.payload)
275 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
276 maybe_wait_for_test_push_barrier(request.changeset.rows.len()).await;
277 let response = match self
278 .db
279 .apply_changes(request.changeset.into(), &self.policies)
280 {
281 Ok(result) => PushResponse {
282 result: Some(result.into()),
283 error: None,
284 },
285 Err(err) => PushResponse {
286 result: None,
287 error: Some(err.to_string()),
288 },
289 };
290 let payload = encode(MessageType::PushResponse, &response)
291 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
292
293 if let Some(reply) = msg.reply {
294 client
295 .publish(reply, payload.into())
296 .await
297 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
298 client
299 .flush()
300 .await
301 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
302 }
303 Ok(())
304 }
305 _ => Err(contextdb_core::Error::SyncError(
306 "unexpected message type on push subject".to_string(),
307 )),
308 }
309 }
310
311 async fn handle_pull(
312 &self,
313 client: &async_nats::Client,
314 msg: async_nats::Message,
315 ) -> contextdb_core::Result<()> {
316 let envelope =
317 decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
318 if !matches!(envelope.message_type, MessageType::PullRequest) {
319 return Err(contextdb_core::Error::SyncError(
320 "unexpected message type on pull subject".to_string(),
321 ));
322 }
323
324 let request: PullRequest = rmp_serde::from_slice(&envelope.payload)
325 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
326 let mut changes = self.db.changes_since(request.since_lsn);
327
328 let mut has_more = false;
329 let mut cursor = None;
330 if let Some(max_entries) = request.max_entries {
331 let max = max_entries as usize;
332 if changes.rows.len() > max {
333 let mut remainder = changes.rows.split_off(max);
334 if let (Some(last_returned), Some(first_remainder)) =
337 (changes.rows.last(), remainder.first())
338 && last_returned.lsn == first_remainder.lsn
339 {
340 let boundary_lsn = last_returned.lsn;
341 let split_idx = remainder.partition_point(|r| r.lsn == boundary_lsn);
342 let moved: Vec<_> = remainder.drain(..split_idx).collect();
343 changes.rows.extend(moved);
344 }
345 cursor = changes.rows.last().map(|r| r.lsn);
347 has_more = !remainder.is_empty();
348 }
349 }
350
351 let response = PullResponse {
352 changeset: changes.into(),
353 has_more,
354 cursor,
355 };
356 let payload = encode(MessageType::PullResponse, &response)
357 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
358
359 if let Some(reply) = msg.reply {
360 if crate::chunking::needs_chunking(&payload) {
361 let chunks = crate::chunking::chunk(&payload);
362 tracing::info!(
363 total_chunks = chunks.len(),
364 payload_size = payload.len(),
365 "sending chunked pull response"
366 );
367 for chunk_msg in &chunks {
368 let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
369 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
370 client
371 .publish(reply.clone(), chunk_encoded.into())
372 .await
373 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
374 }
375 client
376 .flush()
377 .await
378 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
379 } else {
380 client
381 .publish(reply, payload.into())
382 .await
383 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
384 }
385 }
386 Ok(())
387 }
388}