1use crate::protocol::{
2 ChunkAck, MessageType, PullRequest, PullResponse, PushRequest, PushResponse, WireChangeSet,
3 WireRowChange, decode, encode,
4};
5use crate::subjects::{pull_subject, push_subject};
6use contextdb_core::{AtomicLsn, Error, Lsn};
7use contextdb_engine::Database;
8use contextdb_engine::sync_types::{
9 ApplyResult, ChangeSet, ConflictPolicies, ConflictPolicy, SyncDirection,
10};
11use futures_util::StreamExt;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::Ordering;
15use std::time::Duration;
16
/// Wait for a chunked push reply, and per-attempt wait for the later (long) pull attempts.
const SYNC_TIMEOUT: Duration = Duration::from_secs(60);
/// Overall deadline for gathering every chunk of a chunked pull response.
const CHUNK_COLLECT_TIMEOUT: Duration = Duration::from_secs(60);
/// Per-attempt wait for a single (non-chunked) push reply before retrying.
const PUSH_REQUEST_TIMEOUT: Duration = Duration::from_secs(4);
/// Maximum number of entries requested in one pull page.
const PULL_PAGE_SIZE: u32 = 500;

/// Hard ceiling on the encoded size of one push batch (800 KiB).
const MAX_BATCH_BYTES: usize = 800 * 1024;
/// Headroom kept below the ceiling; presumably absorbs the gap between summed
/// per-item size estimates and the real encoded batch size.
const BATCH_ESTIMATE_SAFETY_MARGIN: usize = 32 * 1024;
/// Size the estimate-based splitter packs batches up to.
const TARGET_BATCH_BYTES: usize = MAX_BATCH_BYTES - BATCH_ESTIMATE_SAFETY_MARGIN;
25
26pub struct SyncClient {
27 db: Arc<Database>,
28 nats: tokio::sync::Mutex<Option<async_nats::Client>>,
29 nats_url: String,
30 tenant_id: String,
31 push_watermark: AtomicLsn,
32 pull_watermark: AtomicLsn,
33 table_directions: std::sync::RwLock<HashMap<String, SyncDirection>>,
34 conflict_policies: std::sync::RwLock<ConflictPolicies>,
35}
36
37impl SyncClient {
38 pub fn new(db: Arc<Database>, nats_url: &str, tenant_id: &str) -> Self {
39 assert!(
40 !tenant_id.is_empty()
41 && tenant_id
42 .chars()
43 .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
44 "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
45 );
46 let (push_watermark, pull_watermark) = db
47 .persisted_sync_watermarks(tenant_id)
48 .unwrap_or_else(|err| {
49 tracing::warn!(%tenant_id, error = %err, "failed to load persisted sync watermarks");
50 (Lsn(0), Lsn(0))
51 });
52 Self {
53 db,
54 nats: tokio::sync::Mutex::new(None),
55 nats_url: nats_url.to_string(),
56 tenant_id: tenant_id.to_string(),
57 push_watermark: AtomicLsn::new(push_watermark),
58 pull_watermark: AtomicLsn::new(pull_watermark),
59 table_directions: std::sync::RwLock::new(HashMap::new()),
60 conflict_policies: std::sync::RwLock::new(ConflictPolicies {
61 per_table: HashMap::new(),
62 default: ConflictPolicy::ServerWins,
63 }),
64 }
65 }
66
67 pub async fn ensure_connected(&self) -> Result<async_nats::Client, String> {
71 let mut guard = self.nats.lock().await;
72 if guard.is_none() {
73 let mut last_err = None;
74 for attempt in 0..10u32 {
75 if attempt > 0 {
76 tokio::time::sleep(Duration::from_millis(200 * u64::from(attempt))).await;
77 }
78 match async_nats::connect(&self.nats_url).await {
79 Ok(client) => {
80 *guard = Some(client);
81 break;
82 }
83 Err(e) => last_err = Some(e.to_string()),
84 }
85 }
86 if guard.is_none() {
87 let err = last_err.unwrap_or_else(|| "unknown error".to_string());
88 return Err(format!(
89 "cannot connect to NATS at {}: {err}",
90 self.nats_url
91 ));
92 }
93 }
94 guard
95 .clone()
96 .ok_or_else(|| format!("cannot connect to NATS at {}", self.nats_url))
97 }
98
99 pub async fn reconnect(&self) {
101 let mut guard = self.nats.lock().await;
102 *guard = None;
103 *guard = async_nats::connect(&self.nats_url).await.ok();
104 }
105
106 pub async fn is_connected(&self) -> bool {
107 let guard = self.nats.lock().await;
108 guard.is_some()
109 }
110
111 pub fn db(&self) -> &Database {
112 &self.db
113 }
114
115 pub fn has_pending_push_changes(&self) -> Result<bool, Error> {
116 let since = self.push_watermark.load(Ordering::SeqCst);
117 let directions = self.table_directions()?;
118 let changes = self
119 .db
120 .changes_since(since)
121 .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
122 Ok(!changes.rows.is_empty()
123 || !changes.edges.is_empty()
124 || !changes.vectors.is_empty()
125 || !changes.ddl.is_empty())
126 }
127
128 pub async fn push(&self) -> Result<ApplyResult, Error> {
129 let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
131
132 let since = self.push_watermark.load(Ordering::SeqCst);
133 let directions = self.table_directions()?;
135 let changeset = self
136 .db
137 .changes_since(since)
138 .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
139
140 if changeset.rows.is_empty()
141 && changeset.edges.is_empty()
142 && changeset.vectors.is_empty()
143 && changeset.ddl.is_empty()
144 {
145 return Ok(ApplyResult {
146 applied_rows: 0,
147 skipped_rows: 0,
148 conflicts: Vec::new(),
149 new_lsn: self.db.current_lsn(),
150 });
151 }
152
153 let mut total = ApplyResult {
154 applied_rows: 0,
155 skipped_rows: 0,
156 conflicts: Vec::new(),
157 new_lsn: since,
158 };
159
160 let mut last_successful_lsn = since;
161 for batch in split_changeset(changeset) {
162 let batch_max_lsn = [
163 batch.rows.last().map(|r| r.lsn),
164 batch.edges.last().map(|e| e.lsn),
165 batch.vectors.last().map(|v| v.lsn),
166 ]
167 .into_iter()
168 .flatten()
169 .max()
170 .unwrap_or(since);
171
172 let request = PushRequest {
173 changeset: batch.clone().into(),
174 };
175 let encoded = encode(MessageType::PushRequest, &request)
176 .map_err(|e| Error::SyncError(e.to_string()))?;
177
178 let result: ApplyResult = if crate::chunking::needs_chunking(&encoded) {
179 use crate::chunking::chunk;
180
181 tracing::info!(
182 payload_size = encoded.len(),
183 "push payload exceeds chunking threshold, using chunked send"
184 );
185
186 let inbox = nats_client.new_inbox();
187 let mut inbox_sub = nats_client
188 .subscribe(inbox.clone())
189 .await
190 .map_err(|e| Error::SyncError(e.to_string()))?;
191
192 let subject = push_subject(&self.tenant_id);
193 let chunks = chunk(&encoded);
194 let chunk_id = chunks[0].chunk_id;
195 let total_chunks = chunks[0].total_chunks;
196
197 tracing::debug!(
198 %chunk_id,
199 total_chunks,
200 "sending {} chunks for push request",
201 total_chunks
202 );
203
204 for chunk_msg in &chunks {
205 let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
206 .map_err(|e| Error::SyncError(e.to_string()))?;
207 nats_client
208 .publish(subject.clone(), chunk_encoded.into())
209 .await
210 .map_err(|e| Error::SyncError(e.to_string()))?;
211 }
212
213 let ack = ChunkAck {
214 chunk_id,
215 total_chunks,
216 reply_inbox: inbox.clone(),
217 };
218 let ack_encoded = encode(MessageType::ChunkAck, &ack)
219 .map_err(|e| Error::SyncError(e.to_string()))?;
220 nats_client
221 .publish(subject, ack_encoded.into())
222 .await
223 .map_err(|e| Error::SyncError(e.to_string()))?;
224 nats_client
225 .flush()
226 .await
227 .map_err(|e| Error::SyncError(e.to_string()))?;
228
229 let msg = tokio::time::timeout(SYNC_TIMEOUT, inbox_sub.next())
230 .await
231 .map_err(|_| Error::SyncError("chunked push timed out".to_string()))?
232 .ok_or_else(|| {
233 Error::SyncError("inbox closed before push response".to_string())
234 })?;
235 let envelope = decode(&msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
236 let response: PushResponse = rmp_serde::from_slice(&envelope.payload)
237 .map_err(|e| Error::SyncError(e.to_string()))?;
238 if let Some(err) = response.error {
239 return Err(Error::SyncError(err));
240 }
241 response
242 .result
243 .ok_or_else(|| Error::SyncError("push response missing result".to_string()))?
244 .into()
245 } else {
246 let mut push_result = None;
247 for attempt in 0..5u32 {
248 if attempt > 0 {
249 tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
250 }
251 let inbox = nats_client.new_inbox();
252 let mut inbox_sub = nats_client
253 .subscribe(inbox.clone())
254 .await
255 .map_err(|e| Error::SyncError(e.to_string()))?;
256
257 nats_client
258 .publish_with_reply(
259 push_subject(&self.tenant_id),
260 inbox.clone(),
261 encoded.clone().into(),
262 )
263 .await
264 .map_err(|e| Error::SyncError(e.to_string()))?;
265
266 match tokio::time::timeout(PUSH_REQUEST_TIMEOUT, inbox_sub.next()).await {
267 Ok(Some(msg)) => {
268 if let Some(status) = msg.status {
269 if status == async_nats::StatusCode::NO_RESPONDERS && attempt < 4 {
270 tracing::debug!(attempt, "push got no responders, retrying");
271 continue;
272 }
273 if attempt < 4 {
274 tracing::debug!(
275 attempt,
276 ?status,
277 "push got status reply, retrying"
278 );
279 continue;
280 }
281 return Err(Error::SyncError(format!(
282 "push failed with NATS status reply: {status:?}"
283 )));
284 }
285
286 let envelope = match decode(&msg.payload) {
287 Ok(envelope) => envelope,
288 Err(err) if attempt < 4 => {
289 tracing::debug!(attempt, error = %err, "push got malformed reply envelope, retrying");
290 continue;
291 }
292 Err(err) => return Err(Error::SyncError(err.to_string())),
293 };
294 let response: PushResponse = match rmp_serde::from_slice(
295 &envelope.payload,
296 ) {
297 Ok(response) => response,
298 Err(err) if attempt < 4 => {
299 tracing::debug!(attempt, error = %err, "push got malformed reply payload, retrying");
300 continue;
301 }
302 Err(err) => return Err(Error::SyncError(err.to_string())),
303 };
304 if let Some(err) = response.error {
305 return Err(Error::SyncError(err));
306 }
307 push_result = Some(
308 response
309 .result
310 .ok_or_else(|| {
311 Error::SyncError("push response missing result".to_string())
312 })?
313 .into(),
314 );
315 break;
316 }
317 Ok(None) => {
318 return Err(Error::SyncError("push inbox closed".to_string()));
319 }
320 Err(_) if attempt < 4 => {
321 tracing::debug!(attempt, "push timed out, retrying");
322 continue;
323 }
324 Err(_) => {
325 return Err(Error::SyncError(
326 "NATS request timed out waiting for push response".to_string(),
327 ));
328 }
329 }
330 }
331 push_result.ok_or_else(|| {
332 Error::SyncError(
333 "push failed after retries: no response from server".to_string(),
334 )
335 })?
336 };
337 last_successful_lsn = batch_max_lsn;
338 total.applied_rows += result.applied_rows;
339 total.skipped_rows += result.skipped_rows;
340 total.conflicts.extend(result.conflicts);
341 total.new_lsn = result.new_lsn;
342 }
343
344 self.push_watermark
345 .store(last_successful_lsn, Ordering::SeqCst);
346 self.db
347 .persist_sync_push_watermark(&self.tenant_id, last_successful_lsn)
348 .map_err(|err| Error::SyncError(err.to_string()))?;
349 Ok(total)
350 }
351
352 pub async fn pull(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
354 let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
355 let directions = self.table_directions()?;
356
357 let mut since_lsn = self.pull_watermark.load(Ordering::SeqCst);
358 #[allow(unused_assignments)]
359 let mut last_server_lsn = since_lsn;
360 let mut total = ApplyResult {
361 applied_rows: 0,
362 skipped_rows: 0,
363 conflicts: vec![],
364 new_lsn: since_lsn,
365 };
366
367 loop {
368 let request = PullRequest {
369 since_lsn,
370 max_entries: Some(PULL_PAGE_SIZE),
371 };
372
373 let (changes, has_more, cursor) = {
374 let encoded = encode(MessageType::PullRequest, &request)
375 .map_err(|e| Error::SyncError(e.to_string()))?;
376
377 let mut first_attempt_response = None;
378 for attempt in 0..5u32 {
379 if attempt > 0 {
380 tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
381 }
382 let inbox = nats_client.new_inbox();
385 let mut inbox_sub = nats_client
386 .subscribe(inbox.clone())
387 .await
388 .map_err(|e| Error::SyncError(e.to_string()))?;
389 let timeout = if attempt < 2 {
390 Duration::from_secs(2)
391 } else {
392 SYNC_TIMEOUT
393 };
394
395 nats_client
396 .publish_with_reply(
397 pull_subject(&self.tenant_id),
398 inbox.clone(),
399 encoded.clone().into(),
400 )
401 .await
402 .map_err(|e| Error::SyncError(e.to_string()))?;
403
404 match tokio::time::timeout(timeout, inbox_sub.next()).await {
405 Ok(Some(msg)) => {
406 first_attempt_response = Some((msg, inbox_sub));
407 break;
408 }
409 Ok(None) => {
410 return Err(Error::SyncError("pull inbox closed".to_string()));
411 }
412 Err(_) if attempt < 4 => {
413 tracing::debug!(attempt, "pull timed out, retrying");
414 continue;
415 }
416 Err(_) => {}
417 }
418 }
419
420 let (first_msg, mut inbox_sub) = first_attempt_response.ok_or_else(|| {
421 Error::SyncError("NATS request timed out waiting for pull response".to_string())
422 })?;
423
424 let first_envelope =
425 decode(&first_msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
426
427 let response_envelope = match first_envelope.message_type {
428 MessageType::PullResponse => first_envelope,
429 MessageType::Chunk => {
430 let first_chunk: crate::protocol::ChunkMessage =
431 rmp_serde::from_slice(&first_envelope.payload)
432 .map_err(|e| Error::SyncError(e.to_string()))?;
433 let total = first_chunk.total_chunks;
434 let mut collected = vec![first_chunk];
435
436 tracing::debug!(
437 total_chunks = total,
438 "pull response is chunked, collecting chunks"
439 );
440
441 let deadline = tokio::time::Instant::now() + CHUNK_COLLECT_TIMEOUT;
442
443 while collected.len() < total as usize {
444 let remaining = deadline.duration_since(tokio::time::Instant::now());
445 if remaining.is_zero() {
446 return Err(Error::SyncError(format!(
447 "overall chunk collection deadline exceeded after {}/{} chunks",
448 collected.len(),
449 total
450 )));
451 }
452 let chunk_msg = tokio::time::timeout_at(deadline, inbox_sub.next())
453 .await
454 .map_err(|_| {
455 Error::SyncError(format!(
456 "timeout collecting pull chunks ({}/{})",
457 collected.len(),
458 total
459 ))
460 })?
461 .ok_or_else(|| {
462 Error::SyncError("pull chunk stream ended".to_string())
463 })?;
464 let env = decode(&chunk_msg.payload)
465 .map_err(|e| Error::SyncError(e.to_string()))?;
466 if matches!(env.message_type, MessageType::Chunk) {
467 let c: crate::protocol::ChunkMessage =
468 rmp_serde::from_slice(&env.payload)
469 .map_err(|e| Error::SyncError(e.to_string()))?;
470 collected.push(c);
471 } else {
472 return Err(Error::SyncError(format!(
473 "unexpected message type {:?} while collecting pull chunks",
474 env.message_type
475 )));
476 }
477 }
478 let reassembled = crate::chunking::reassemble(&mut collected);
479 decode(&reassembled).map_err(|e| Error::SyncError(e.to_string()))?
480 }
481 _ => {
482 return Err(Error::SyncError(
483 "unexpected message type in pull response".to_string(),
484 ));
485 }
486 };
487
488 let response: PullResponse = rmp_serde::from_slice(&response_envelope.payload)
489 .map_err(|e| Error::SyncError(e.to_string()))?;
490 (
491 ChangeSet::from(response.changeset),
492 response.has_more,
493 response.cursor,
494 )
495 };
496
497 let server_lsn = [
499 changes.rows.last().map(|r| r.lsn),
500 changes.edges.last().map(|e| e.lsn),
501 changes.vectors.last().map(|v| v.lsn),
502 ]
503 .into_iter()
504 .flatten()
505 .max()
506 .unwrap_or(since_lsn);
507
508 let filtered = changes
509 .filter_by_direction(&directions, &[SyncDirection::Pull, SyncDirection::Both]);
510 let result = self
511 .db
512 .apply_changes(filtered, &remap_pull_policies(policies))?;
513 total.applied_rows += result.applied_rows;
514 total.skipped_rows += result.skipped_rows;
515 total.conflicts.extend(result.conflicts);
516 total.new_lsn = result.new_lsn;
517 last_server_lsn = server_lsn;
518
519 if !has_more {
520 break;
521 }
522 since_lsn = cursor.unwrap_or(since_lsn);
523 }
524
525 self.pull_watermark.store(last_server_lsn, Ordering::SeqCst);
526 self.db
527 .persist_sync_pull_watermark(&self.tenant_id, last_server_lsn)
528 .map_err(|err| Error::SyncError(err.to_string()))?;
529 Ok(total)
530 }
531
532 pub async fn pull_default(&self) -> Result<ApplyResult, Error> {
534 let policies = self.conflict_policies()?;
535 self.pull(&policies).await
536 }
537
538 pub async fn initial_sync(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
540 self.pull(policies).await
541 }
542
543 pub fn push_watermark(&self) -> Lsn {
544 self.push_watermark.load(Ordering::SeqCst)
545 }
546
547 pub fn pull_watermark(&self) -> Lsn {
548 self.pull_watermark.load(Ordering::SeqCst)
549 }
550
551 pub fn tenant_id(&self) -> &str {
552 &self.tenant_id
553 }
554
555 pub fn nats_url(&self) -> &str {
556 &self.nats_url
557 }
558
559 pub fn set_table_direction(&self, table: &str, direction: SyncDirection) {
560 match self.table_directions.write() {
561 Ok(mut directions) => {
562 directions.insert(table.to_string(), direction);
563 }
564 Err(_) => tracing::warn!("sync table_directions lock poisoned; ignoring update"),
565 }
566 }
567
568 pub fn set_conflict_policy(&self, table: &str, policy: ConflictPolicy) {
569 match self.conflict_policies.write() {
570 Ok(mut policies) => {
571 policies.per_table.insert(table.to_string(), policy);
572 }
573 Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
574 }
575 }
576
577 pub fn set_default_conflict_policy(&self, policy: ConflictPolicy) {
578 match self.conflict_policies.write() {
579 Ok(mut policies) => {
580 policies.default = policy;
581 }
582 Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
583 }
584 }
585
586 fn table_directions(&self) -> Result<HashMap<String, SyncDirection>, Error> {
587 self.table_directions
588 .read()
589 .map(|directions| directions.clone())
590 .map_err(|_| Error::SyncError("sync table directions lock poisoned".to_string()))
591 }
592
593 fn conflict_policies(&self) -> Result<ConflictPolicies, Error> {
594 self.conflict_policies
595 .read()
596 .map(|policies| policies.clone())
597 .map_err(|_| Error::SyncError("sync conflict policies lock poisoned".to_string()))
598 }
599}
600
601pub(crate) fn split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
602 let wire = WireChangeSet::from(changeset.clone());
603 let estimated = rmp_serde::to_vec(&wire).map(|v| v.len()).unwrap_or(0);
604 if estimated <= MAX_BATCH_BYTES {
605 return vec![changeset];
606 }
607
608 let batches = fast_split_changeset(changeset.clone());
609 if batches
610 .iter()
611 .all(|batch| batch_wire_size(batch) <= MAX_BATCH_BYTES)
612 {
613 return batches;
614 }
615
616 precise_split_changeset(changeset)
617}
618
619fn batch_wire_size(changeset: &ChangeSet) -> usize {
620 rmp_serde::to_vec(&WireChangeSet::from(changeset.clone()))
621 .map(|v| v.len())
622 .unwrap_or(usize::MAX)
623}
624
625fn fast_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
626 let row_sizes: Vec<usize> = changeset
627 .rows
628 .iter()
629 .map(|r| {
630 let wire_row = WireRowChange::from(r.clone());
631 rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
632 })
633 .collect();
634 let vector_sizes: Vec<usize> = changeset
635 .vectors
636 .iter()
637 .map(|v| {
638 let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
639 rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
640 })
641 .collect();
642
643 let mut batches = Vec::new();
644 let mut batch_rows = Vec::new();
645 let mut batch_vectors = Vec::new();
646 let mut batch_size = 0usize;
647 let changeset_edges = changeset.edges;
648 let changeset_vectors = changeset.vectors;
649 let changeset_ddl = changeset.ddl;
650
651 let edges_size: usize = {
652 let edges_wire: Vec<crate::protocol::WireEdgeChange> =
653 changeset_edges.iter().cloned().map(Into::into).collect();
654 rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
655 };
656 let ddl_size: usize = {
657 let ddl_wire: Vec<crate::protocol::WireDdlChange> =
658 changeset_ddl.iter().cloned().map(Into::into).collect();
659 rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
660 };
661 let first_batch_overhead = edges_size + ddl_size;
662
663 for (i, row) in changeset.rows.into_iter().enumerate() {
664 let row_size = row_sizes.get(i).copied().unwrap_or(128);
665 let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
666 let item_size = row_size + vec_size_for_i;
667 let first_item_overhead = if batch_rows.is_empty() && batches.is_empty() {
668 first_batch_overhead
669 } else {
670 0
671 };
672
673 if !batch_rows.is_empty() && batch_size + item_size > TARGET_BATCH_BYTES {
674 batches.push(ChangeSet {
675 rows: std::mem::take(&mut batch_rows),
676 edges: if batches.is_empty() {
677 changeset_edges.clone()
678 } else {
679 Vec::new()
680 },
681 vectors: std::mem::take(&mut batch_vectors),
682 ddl: if batches.is_empty() {
683 changeset_ddl.clone()
684 } else {
685 Vec::new()
686 },
687 });
688 batch_size = 0;
689 }
690
691 if i < changeset_vectors.len() {
692 batch_vectors.push(changeset_vectors[i].clone());
693 batch_size += vec_size_for_i;
694 }
695 batch_rows.push(row);
696 batch_size += row_size + first_item_overhead;
697 }
698
699 if !batch_rows.is_empty() {
700 batches.push(ChangeSet {
701 rows: batch_rows,
702 edges: if batches.is_empty() {
703 changeset_edges
704 } else {
705 Vec::new()
706 },
707 vectors: batch_vectors,
708 ddl: if batches.is_empty() {
709 changeset_ddl
710 } else {
711 Vec::new()
712 },
713 });
714 } else if batches.is_empty() {
715 batches.push(ChangeSet {
716 rows: Vec::new(),
717 edges: changeset_edges,
718 vectors: Vec::new(),
719 ddl: changeset_ddl,
720 });
721 }
722
723 batches
724}
725
726fn precise_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
727 let row_sizes: Vec<usize> = changeset
729 .rows
730 .iter()
731 .map(|r| {
732 let wire_row = WireRowChange::from(r.clone());
733 rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
734 })
735 .collect();
736 let vector_sizes: Vec<usize> = changeset
737 .vectors
738 .iter()
739 .map(|v| {
740 let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
741 rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
742 })
743 .collect();
744
745 let mut batches = Vec::new();
746 let mut batch_rows = Vec::new();
747 let mut batch_vectors = Vec::new();
748 let mut batch_size = 0usize;
749 let changeset_edges = changeset.edges;
751 let changeset_vectors = changeset.vectors;
752 let changeset_ddl = changeset.ddl;
753
754 let edges_size: usize = {
756 let edges_wire: Vec<crate::protocol::WireEdgeChange> =
757 changeset_edges.iter().cloned().map(Into::into).collect();
758 rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
759 };
760 let ddl_size: usize = {
761 let ddl_wire: Vec<crate::protocol::WireDdlChange> =
762 changeset_ddl.iter().cloned().map(Into::into).collect();
763 rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
764 };
765 let first_batch_overhead = edges_size + ddl_size;
766
767 for (i, row) in changeset.rows.into_iter().enumerate() {
768 let row_size = row_sizes.get(i).copied().unwrap_or(128);
769 let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
770 let overhead = if batches.is_empty() {
771 first_batch_overhead
772 } else {
773 0
774 };
775
776 let should_flush = if batch_rows.is_empty() {
777 false
778 } else {
779 let mut trial_rows = batch_rows.clone();
780 trial_rows.push(row.clone());
781 let mut trial_vectors = batch_vectors.clone();
782 if i < changeset_vectors.len() {
783 trial_vectors.push(changeset_vectors[i].clone());
784 }
785 let trial = ChangeSet {
786 rows: trial_rows.clone(),
787 edges: if batches.is_empty() {
788 changeset_edges.clone()
789 } else {
790 Vec::new()
791 },
792 vectors: trial_vectors,
793 ddl: if batches.is_empty() {
794 changeset_ddl.clone()
795 } else {
796 Vec::new()
797 },
798 };
799 let actual_size = rmp_serde::to_vec(&WireChangeSet::from(trial))
800 .map(|v| v.len())
801 .unwrap_or(usize::MAX);
802 batch_size + row_size + vec_size_for_i + overhead > MAX_BATCH_BYTES
803 || actual_size > MAX_BATCH_BYTES
804 };
805
806 if should_flush {
807 batches.push(ChangeSet {
808 rows: std::mem::take(&mut batch_rows),
809 edges: if batches.is_empty() {
810 changeset_edges.clone()
811 } else {
812 Vec::new()
813 },
814 vectors: std::mem::take(&mut batch_vectors),
815 ddl: if batches.is_empty() {
816 changeset_ddl.clone()
817 } else {
818 Vec::new()
819 },
820 });
821 batch_size = 0;
822 }
823
824 if i < changeset_vectors.len() {
826 batch_vectors.push(changeset_vectors[i].clone());
827 batch_size += vector_sizes.get(i).copied().unwrap_or(64);
828 }
829 batch_rows.push(row);
830 batch_size += row_size;
831 }
832
833 if !batch_rows.is_empty() {
834 batches.push(ChangeSet {
835 rows: batch_rows,
836 edges: if batches.is_empty() {
837 changeset_edges
838 } else {
839 Vec::new()
840 },
841 vectors: batch_vectors,
842 ddl: if batches.is_empty() {
843 changeset_ddl
844 } else {
845 Vec::new()
846 },
847 });
848 } else if batches.is_empty() && (!changeset_edges.is_empty() || !changeset_ddl.is_empty()) {
849 batches.push(ChangeSet {
850 rows: Vec::new(),
851 edges: changeset_edges,
852 vectors: Vec::new(),
853 ddl: changeset_ddl,
854 });
855 }
856
857 batches
858}
859
860fn remap_pull_policies(policies: &ConflictPolicies) -> ConflictPolicies {
861 let remap = |policy: ConflictPolicy| match policy {
862 ConflictPolicy::ServerWins => ConflictPolicy::EdgeWins,
863 ConflictPolicy::EdgeWins => ConflictPolicy::ServerWins,
864 other => other,
865 };
866
867 ConflictPolicies {
868 per_table: policies
869 .per_table
870 .iter()
871 .map(|(table, policy)| (table.clone(), remap(*policy)))
872 .collect(),
873 default: remap(policies.default),
874 }
875}
876
877#[cfg(test)]
878mod tests {
879 use super::*;
880 use contextdb_core::{RowId, Value};
881 use contextdb_engine::Database;
882 use contextdb_engine::sync_types::{NaturalKey, RowChange, VectorChange};
883 use std::sync::Arc;
884 use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
885 use testcontainers::runners::AsyncRunner;
886 use testcontainers::{ContainerAsync, GenericImage, ImageExt};
887 use uuid::Uuid;
888
889 struct NatsFixture {
890 _container: ContainerAsync<GenericImage>,
891 nats_url: String,
892 }
893
894 async fn start_nats() -> NatsFixture {
895 let nats_conf = format!("{}/tests/nats.conf", env!("CARGO_MANIFEST_DIR"));
896
897 let image = GenericImage::new("nats", "latest")
898 .with_exposed_port(4222.tcp())
899 .with_wait_for(WaitFor::message_on_stderr("Server is ready"));
900
901 let request = image
902 .with_mount(Mount::bind_mount(&nats_conf, "/etc/nats/nats.conf"))
903 .with_cmd(["--js", "--config", "/etc/nats/nats.conf"]);
904
905 let container: ContainerAsync<GenericImage> = request.start().await.unwrap();
906 let nats_port = container.get_host_port_ipv4(4222.tcp()).await.unwrap();
907
908 NatsFixture {
909 _container: container,
910 nats_url: format!("nats://127.0.0.1:{nats_port}"),
911 }
912 }
913
914 #[tokio::test]
915 async fn sync_01_client_push_survives_poisoned_direction_lock() {
916 let nats = start_nats().await;
917 let client = Arc::new(SyncClient::new(
918 Arc::new(Database::open_memory()),
919 &nats.nats_url,
920 "sync-01",
921 ));
922
923 client.ensure_connected().await.expect("connect NATS");
924 let poison_client = client.clone();
925 let _ = std::thread::spawn(move || {
926 let _guard = poison_client.table_directions.write().unwrap();
927 panic!("poison sync_client directions lock");
928 })
929 .join();
930
931 let join = tokio::spawn({
932 let client = client.clone();
933 async move { client.push().await }
934 })
935 .await;
936
937 assert!(
938 matches!(join, Ok(Err(Error::SyncError(_)))),
939 "push should return a sync error instead of panicking on poisoned table_directions, got {join:?}"
940 );
941 }
942
943 #[tokio::test]
944 async fn sync_02_client_pull_default_survives_poisoned_policy_lock() {
945 let nats = start_nats().await;
946 let client = Arc::new(SyncClient::new(
947 Arc::new(Database::open_memory()),
948 &nats.nats_url,
949 "sync-02",
950 ));
951
952 client.ensure_connected().await.expect("connect NATS");
953 let poison_client = client.clone();
954 let _ = std::thread::spawn(move || {
955 let _guard = poison_client.conflict_policies.write().unwrap();
956 panic!("poison sync_client policies lock");
957 })
958 .join();
959
960 let join = tokio::spawn({
961 let client = client.clone();
962 async move { client.pull_default().await }
963 })
964 .await;
965
966 assert!(
967 matches!(join, Ok(Err(Error::SyncError(_)))),
968 "pull_default should return a sync error instead of panicking on poisoned conflict_policies, got {join:?}"
969 );
970 }
971
972 #[test]
974 fn a14_batch_splitting_respects_byte_limits() {
975 let large_text = "x".repeat(100 * 1024); let mut rows = Vec::new();
978 for _ in 0..10 {
979 let id = Uuid::new_v4();
980 let mut values = HashMap::new();
981 values.insert("id".to_string(), Value::Uuid(id));
982 values.insert("data".to_string(), Value::Text(large_text.clone()));
983 rows.push(RowChange {
984 table: "t".to_string(),
985 natural_key: NaturalKey {
986 column: "id".to_string(),
987 value: Value::Uuid(id),
988 },
989 values,
990 deleted: false,
991 lsn: Lsn(1),
992 });
993 }
994
995 let changeset = ChangeSet {
996 rows,
997 edges: Vec::new(),
998 vectors: Vec::new(),
999 ddl: vec![contextdb_engine::sync_types::DdlChange::CreateTable {
1000 name: "t".to_string(),
1001 columns: vec![
1002 ("id".to_string(), "UUID".to_string()),
1003 ("data".to_string(), "TEXT".to_string()),
1004 ],
1005 constraints: vec!["PRIMARY KEY (id)".to_string()],
1006 }],
1007 };
1008
1009 let batches = split_changeset(changeset);
1010
1011 assert!(
1013 batches.len() >= 2,
1014 "10 rows of ~100KB each (~1MB total) must split into at least 2 batches, got {}",
1015 batches.len()
1016 );
1017
1018 for (i, batch) in batches.iter().enumerate() {
1020 let wire = WireChangeSet::from(batch.clone());
1021 let size = rmp_serde::to_vec(&wire)
1022 .expect("a14 batch should serialize for byte-size accounting")
1023 .len();
1024 assert!(
1025 size <= 800 * 1024,
1026 "batch {} serialized to {} bytes, exceeds 800KB limit",
1027 i,
1028 size
1029 );
1030 }
1031
1032 assert!(!batches[0].ddl.is_empty(), "DDL must be in first batch");
1034 for batch in &batches[1..] {
1035 assert!(
1036 batch.ddl.is_empty(),
1037 "DDL must NOT be in subsequent batches"
1038 );
1039 assert!(
1040 batch.edges.is_empty(),
1041 "edges must NOT be in subsequent batches"
1042 );
1043 }
1044 }
1045
1046 #[test]
1047 fn nv_snapshot_split_batches_emit_ddl_before_vector_batches_and_apply_cleanly() {
1048 let table = "snapshot_split_evidence";
1049 let ddl_only_marker = "ddl_order_marker_column";
1050 let vector_column = "vector_later_marker_text";
1051 let row_count = 10usize;
1052 let large_payload = "x".repeat(120 * 1024);
1053 let mut ids = Vec::new();
1054 let mut rows = Vec::new();
1055 let mut vectors = Vec::new();
1056
1057 for i in 0..row_count {
1058 let id = Uuid::new_v4();
1059 ids.push(id);
1060 let mut values = HashMap::new();
1061 values.insert("id".to_string(), Value::Uuid(id));
1062 values.insert("payload".to_string(), Value::Text(large_payload.clone()));
1063 rows.push(RowChange {
1064 table: table.to_string(),
1065 natural_key: NaturalKey {
1066 column: "id".to_string(),
1067 value: Value::Uuid(id),
1068 },
1069 values,
1070 deleted: false,
1071 lsn: Lsn((i + 1) as u64),
1072 });
1073 vectors.push(VectorChange {
1074 index: contextdb_core::VectorIndexRef::new(table, vector_column),
1075 row_id: RowId((i + 1) as u64),
1076 vector: if i == 0 {
1077 vec![1.0, 0.0, 0.0, 0.0]
1078 } else {
1079 vec![0.0, 1.0, 0.0, 0.0]
1080 },
1081 lsn: Lsn((i + 1) as u64),
1082 });
1083 }
1084
1085 let changeset = ChangeSet {
1086 rows,
1087 edges: Vec::new(),
1088 vectors,
1089 ddl: vec![contextdb_engine::sync_types::DdlChange::CreateTable {
1090 name: table.to_string(),
1091 columns: vec![
1092 ("id".to_string(), "UUID PRIMARY KEY".to_string()),
1093 ("payload".to_string(), "TEXT".to_string()),
1094 (ddl_only_marker.to_string(), "TEXT".to_string()),
1095 (vector_column.to_string(), "VECTOR(4)".to_string()),
1096 ],
1097 constraints: Vec::new(),
1098 }],
1099 };
1100
1101 let batches = split_changeset(changeset);
1102 assert!(
1103 batches.len() >= 2,
1104 "snapshot-shaped changeset with large rows must exercise real split path; got {} batch(es)",
1105 batches.len()
1106 );
1107 let first_ddl_idx = batches
1108 .iter()
1109 .position(|batch| !batch.ddl.is_empty())
1110 .expect("split stream must include schema DDL");
1111 let first_vector_idx = batches
1112 .iter()
1113 .position(|batch| !batch.vectors.is_empty())
1114 .expect("split stream must include vector changes");
1115 assert!(
1116 first_ddl_idx <= first_vector_idx,
1117 "first vector batch must not be emitted before schema DDL; first_ddl_idx={first_ddl_idx}, first_vector_idx={first_vector_idx}"
1118 );
1119 for (idx, batch) in batches.iter().enumerate().skip(first_ddl_idx + 1) {
1120 assert!(
1121 batch.ddl.is_empty(),
1122 "schema DDL must appear once before vector replay, not again in batch {idx}"
1123 );
1124 }
1125
1126 if first_ddl_idx == first_vector_idx {
1127 fn byte_pos(haystack: &[u8], needle: &str) -> usize {
1128 haystack
1129 .windows(needle.len())
1130 .position(|window| window == needle.as_bytes())
1131 .unwrap_or_else(|| panic!("encoded batch must contain sentinel {needle:?}"))
1132 }
1133
1134 let bytes = rmp_serde::to_vec(&WireChangeSet::from(batches[first_vector_idx].clone()))
1135 .expect("encode split vector-bearing batch");
1136 let ddl_marker_pos = byte_pos(&bytes, ddl_only_marker);
1137 let vector_marker_pos = byte_pos(&bytes, vector_column);
1138 assert!(
1139 ddl_marker_pos < vector_marker_pos,
1140 "vector-bearing split batch must serialize schema bytes before vector index bytes; \
1141 ddl_marker_pos={ddl_marker_pos}, vector_marker_pos={vector_marker_pos}, encoded_len={}",
1142 bytes.len()
1143 );
1144 }
1145
1146 let receiver = Database::open_memory();
1147 let policies = ConflictPolicies::uniform(ConflictPolicy::LatestWins);
1148 for (idx, batch) in batches.into_iter().enumerate() {
1149 receiver
1150 .apply_changes(batch, &policies)
1151 .unwrap_or_else(|err| panic!("receiver must apply split batch {idx}: {err}"));
1152 }
1153
1154 let rows = receiver
1155 .execute(&format!("SELECT id FROM {table}"), &HashMap::new())
1156 .expect("receiver must expose replayed rows after split apply");
1157 assert_eq!(
1158 rows.rows.len(),
1159 row_count,
1160 "fresh receiver must contain every row after applying split snapshot batches"
1161 );
1162
1163 let mut params = HashMap::new();
1164 params.insert("q".to_string(), Value::Vector(vec![1.0, 0.0, 0.0, 0.0]));
1165 let nearest = receiver
1166 .execute(
1167 &format!("SELECT id FROM {table} ORDER BY {vector_column} <=> $q LIMIT 1"),
1168 ¶ms,
1169 )
1170 .expect("receiver must expose replayed vector index after split apply");
1171 let id_idx = nearest
1172 .columns
1173 .iter()
1174 .position(|column| column == "id")
1175 .expect("nearest query must project id");
1176 assert_eq!(
1177 nearest.rows[0][id_idx],
1178 Value::Uuid(ids[0]),
1179 "replayed vector index must route to the declared table+column after split apply"
1180 );
1181 }
1182
#[test]
fn a14b_batch_splitting_accounts_for_vector_sizes() {
    // 200 rows with 3000-byte text payloads (~600 KB of text) plus 200 dense
    // 384-dim vectors. The row text alone is below the 800 KiB batch limit, so
    // a split only happens if the splitter also counts vector bytes.
    const ROW_COUNT: usize = 200;
    const VECTOR_DIM: usize = 384;

    let mut rows = Vec::with_capacity(ROW_COUNT);
    let mut vectors = Vec::with_capacity(ROW_COUNT);
    for _ in 0..ROW_COUNT {
        let id = Uuid::new_v4();
        let mut values = HashMap::new();
        values.insert("id".to_string(), Value::Uuid(id));
        values.insert("data".to_string(), Value::Text("x".repeat(3000)));
        rows.push(RowChange {
            table: "t".to_string(),
            natural_key: NaturalKey {
                column: "id".to_string(),
                value: Value::Uuid(id),
            },
            values,
            deleted: false,
            lsn: Lsn(1),
        });
        vectors.push(VectorChange {
            index: contextdb_core::VectorIndexRef::default(),
            row_id: RowId(0),
            vector: (0..VECTOR_DIM).map(|j| j as f32).collect(),
            lsn: Lsn(1),
        });
    }
    let changeset = ChangeSet {
        rows,
        edges: Vec::new(),
        vectors,
        ddl: Vec::new(),
    };
    let batches = split_changeset(changeset);
    assert!(
        batches.len() >= 2,
        "200 rows with 384-dim vectors must split into 2+ batches with correct accounting, got {}",
        batches.len()
    );

    // Shrinking batches must never come at the cost of losing data.
    let total_rows: usize = batches.iter().map(|b| b.rows.len()).sum();
    let total_vecs: usize = batches.iter().map(|b| b.vectors.len()).sum();
    assert_eq!(
        total_rows, ROW_COUNT,
        "all {ROW_COUNT} rows must be present across batches"
    );
    assert_eq!(
        total_vecs, ROW_COUNT,
        "all {ROW_COUNT} vectors must be present across batches"
    );

    // Measure each batch exactly as it goes on the wire and compare it against
    // the same hard limit the production splitter is configured with.
    let limit = super::MAX_BATCH_BYTES;
    for (i, batch) in batches.iter().enumerate() {
        let wire = WireChangeSet::from(batch.clone());
        let size = rmp_serde::to_vec(&wire)
            .expect("a14b batch should serialize for byte-size accounting")
            .len();
        assert!(
            size <= limit,
            "batch {} serialized to {} bytes, exceeds {}-byte limit",
            i,
            size,
            limit
        );
    }
}
1234
#[test]
fn a15_split_changeset_single_oversized_row() {
    // One row carrying a 600 KiB text payload. Whatever the splitter decides,
    // the row must come out exactly once: neither dropped nor duplicated.
    let row_id = Uuid::new_v4();
    let values = HashMap::from([
        ("id".to_string(), Value::Uuid(row_id)),
        ("data".to_string(), Value::Text("x".repeat(600 * 1024))),
    ]);

    let changeset = ChangeSet {
        rows: vec![RowChange {
            table: "observations".to_string(),
            natural_key: NaturalKey {
                column: "id".to_string(),
                value: Value::Uuid(row_id),
            },
            values,
            deleted: false,
            lsn: Lsn(1),
        }],
        edges: Vec::new(),
        vectors: Vec::new(),
        ddl: Vec::new(),
    };

    let batches = split_changeset(changeset);

    assert!(
        !batches.is_empty(),
        "split_changeset must return at least one batch, got {}",
        batches.len()
    );
    let total_rows = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.rows.len());
    assert_eq!(
        total_rows, 1,
        "the single oversized row must appear in exactly one batch, got {}",
        total_rows
    );
}
1274
#[test]
fn a16_split_changeset_preserves_row_vector_pairing() {
    use contextdb_engine::sync_types::VectorChange;

    // Ten rows of ~100 KiB each (~1000 KiB total) force a split. Row i and its
    // vector share LSN i + 1, so pairing can be verified position by position.
    let mut rows = Vec::new();
    let mut vectors = Vec::new();
    for i in 0..10usize {
        let id = Uuid::new_v4();
        let mut values = HashMap::new();
        values.insert("id".to_string(), Value::Uuid(id));
        values.insert("data".to_string(), Value::Text("x".repeat(100 * 1024)));
        rows.push(RowChange {
            table: "observations".to_string(),
            natural_key: NaturalKey {
                column: "id".to_string(),
                value: Value::Uuid(id),
            },
            values,
            deleted: false,
            lsn: Lsn((i + 1) as u64),
        });
        vectors.push(VectorChange {
            index: contextdb_core::VectorIndexRef::default(),
            row_id: RowId((i + 1) as u64),
            vector: vec![i as f32; 3],
            lsn: Lsn((i + 1) as u64),
        });
    }
    let changeset = ChangeSet {
        rows,
        edges: Vec::new(),
        vectors,
        ddl: Vec::new(),
    };

    let batches = split_changeset(changeset);

    assert!(
        batches.len() >= 2,
        "10 rows * ~100KB each must split into at least 2 batches, got {}",
        batches.len()
    );
    let total_rows: usize = batches.iter().map(|b| b.rows.len()).sum();
    let total_vecs: usize = batches.iter().map(|b| b.vectors.len()).sum();
    assert_eq!(total_rows, 10, "all 10 rows must be present across batches");
    assert_eq!(
        total_vecs, 10,
        "all 10 vectors must be present across batches"
    );
    for (batch_idx, batch) in batches.iter().enumerate() {
        assert_eq!(
            batch.rows.len(),
            batch.vectors.len(),
            "batch {} must have equal row and vector counts: rows={}, vectors={}",
            batch_idx,
            batch.rows.len(),
            batch.vectors.len()
        );
        // Lengths were asserted equal above, so `zip` visits every pair and
        // never silently truncates.
        for (pos, (row, vector)) in batch.rows.iter().zip(&batch.vectors).enumerate() {
            assert_eq!(
                row.lsn, vector.lsn,
                "batch {} position {}: row.lsn={} != vector.lsn={} — pairing is broken",
                batch_idx, pos, row.lsn, vector.lsn
            );
        }
    }
}
1343
#[test]
fn a17_split_changeset_empty_input_returns_one_batch() {
    // An empty changeset must still yield one (empty) batch, so callers always
    // have something to send, and that batch must not contain invented content.
    let changeset = ChangeSet {
        rows: Vec::new(),
        edges: Vec::new(),
        vectors: Vec::new(),
        ddl: Vec::new(),
    };

    let batches = split_changeset(changeset);

    assert_eq!(
        batches.len(),
        1,
        "empty changeset must produce exactly 1 batch (not 0), got {}",
        batches.len()
    );
    let batch = &batches[0];
    assert!(
        batch.rows.is_empty(),
        "the single batch for an empty input must have no rows"
    );
    assert!(
        batch.edges.is_empty(),
        "the single batch for an empty input must have no edges"
    );
    assert!(
        batch.vectors.is_empty(),
        "the single batch for an empty input must have no vectors"
    );
    assert!(
        batch.ddl.is_empty(),
        "the single batch for an empty input must have no DDL"
    );
}
1367
#[test]
fn a18_split_changeset_edge_only_not_dropped() {
    use contextdb_engine::sync_types::EdgeChange;

    // 200 edges with 5,000-byte edge types and nothing else. The splitter must
    // not discard a changeset just because it has no rows.
    let edges: Vec<EdgeChange> = (0..200)
        .map(|_| EdgeChange {
            source: Uuid::new_v4(),
            target: Uuid::new_v4(),
            edge_type: "x".repeat(5_000),
            properties: HashMap::new(),
            lsn: Lsn(1),
        })
        .collect();

    let batches = split_changeset(ChangeSet {
        rows: Vec::new(),
        edges,
        vectors: Vec::new(),
        ddl: Vec::new(),
    });

    assert!(
        !batches.is_empty(),
        "edge-only changeset must produce at least 1 batch, got {} — edges silently dropped",
        batches.len()
    );
    let total_edges = batches
        .iter()
        .map(|batch| batch.edges.len())
        .sum::<usize>();
    assert_eq!(
        total_edges, 200,
        "all 200 edges must be present across batches, got {}",
        total_edges
    );
}
1404
#[test]
fn a19_split_changeset_ddl_only_not_dropped() {
    use contextdb_engine::sync_types::DdlChange;

    // 20 CREATE TABLE statements, each with 100 columns whose names carry a
    // 500-char padding suffix. No rows, edges or vectors: the DDL alone must
    // survive splitting.
    let padding = "x".repeat(500);
    let ddl: Vec<DdlChange> = (0..20)
        .map(|table_idx| DdlChange::CreateTable {
            name: format!("table_{}", table_idx),
            columns: (0..100)
                .map(|col_idx| (format!("col_{}_{}", col_idx, padding), "TEXT".to_string()))
                .collect(),
            constraints: vec![format!("PRIMARY KEY (col_{})", padding)],
        })
        .collect();

    let batches = split_changeset(ChangeSet {
        rows: Vec::new(),
        edges: Vec::new(),
        vectors: Vec::new(),
        ddl,
    });

    assert!(
        !batches.is_empty(),
        "DDL-only changeset must produce at least 1 batch, got {} — DDL silently dropped",
        batches.len()
    );
    let total_ddl = batches
        .iter()
        .map(|batch| batch.ddl.len())
        .sum::<usize>();
    assert_eq!(
        total_ddl, 20,
        "all 20 DDL entries must be present across batches, got {}",
        total_ddl
    );
}
1441 }
1442}