1use crate::protocol::{
2 ChunkAck, MessageType, PullRequest, PullResponse, PushRequest, PushResponse, WireChangeSet,
3 WireRowChange, decode, encode,
4};
5use crate::subjects::{pull_subject, push_subject};
6use contextdb_core::Error;
7use contextdb_engine::Database;
8use contextdb_engine::sync_types::{
9 ApplyResult, ChangeSet, ConflictPolicies, ConflictPolicy, SyncDirection,
10};
11use futures_util::StreamExt;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16
/// Wait applied to the reply of a chunked push and to the later attempts of a
/// pull request.
const SYNC_TIMEOUT: Duration = Duration::from_secs(60);
/// Overall deadline for collecting every chunk of a chunked pull response.
const CHUNK_COLLECT_TIMEOUT: Duration = Duration::from_secs(60);
/// Per-attempt wait for the reply to a non-chunked push request.
const PUSH_REQUEST_TIMEOUT: Duration = Duration::from_secs(4);
/// Maximum number of entries requested per pull page.
const PULL_PAGE_SIZE: u32 = 500;
/// Ceiling on the serialized size of a single push batch.
const MAX_BATCH_BYTES: usize = 800 * 1024;
/// Headroom kept below the ceiling by the estimate-based splitter.
const BATCH_ESTIMATE_SAFETY_MARGIN: usize = 32 * 1024;
/// Size the estimate-based splitter aims for: the ceiling minus the safety margin.
const TARGET_BATCH_BYTES: usize = MAX_BATCH_BYTES - BATCH_ESTIMATE_SAFETY_MARGIN;
25
/// Client that synchronizes a local [`Database`] with a server over NATS for
/// one tenant.
pub struct SyncClient {
    /// Local database whose changes are pushed and into which pulled changes are applied.
    db: Arc<Database>,
    /// Cached NATS client; `None` until a connection has been established.
    nats: tokio::sync::Mutex<Option<async_nats::Client>>,
    /// Address used whenever a NATS connection is (re)established.
    nats_url: String,
    /// Tenant identifier used to derive the push and pull subjects.
    tenant_id: String,
    /// LSN after which local changes are collected for the next push.
    push_watermark: AtomicU64,
    /// LSN sent as `since_lsn` in the first request of the next pull.
    pull_watermark: AtomicU64,
    /// Per-table sync direction used to filter changesets on push and pull.
    table_directions: std::sync::RwLock<HashMap<String, SyncDirection>>,
    /// Conflict policies used by `pull_default`.
    conflict_policies: std::sync::RwLock<ConflictPolicies>,
}
36
37impl SyncClient {
38 pub fn new(db: Arc<Database>, nats_url: &str, tenant_id: &str) -> Self {
39 assert!(
40 !tenant_id.is_empty()
41 && tenant_id
42 .chars()
43 .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
44 "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
45 );
46 let (push_watermark, pull_watermark) = db
47 .persisted_sync_watermarks(tenant_id)
48 .unwrap_or_else(|err| {
49 tracing::warn!(%tenant_id, error = %err, "failed to load persisted sync watermarks");
50 (0, 0)
51 });
52 Self {
53 db,
54 nats: tokio::sync::Mutex::new(None),
55 nats_url: nats_url.to_string(),
56 tenant_id: tenant_id.to_string(),
57 push_watermark: AtomicU64::new(push_watermark),
58 pull_watermark: AtomicU64::new(pull_watermark),
59 table_directions: std::sync::RwLock::new(HashMap::new()),
60 conflict_policies: std::sync::RwLock::new(ConflictPolicies {
61 per_table: HashMap::new(),
62 default: ConflictPolicy::ServerWins,
63 }),
64 }
65 }
66
67 pub async fn ensure_connected(&self) -> Result<async_nats::Client, String> {
71 let mut guard = self.nats.lock().await;
72 if guard.is_none() {
73 let mut last_err = None;
74 for attempt in 0..10u32 {
75 if attempt > 0 {
76 tokio::time::sleep(Duration::from_millis(200 * u64::from(attempt))).await;
77 }
78 match async_nats::connect(&self.nats_url).await {
79 Ok(client) => {
80 *guard = Some(client);
81 break;
82 }
83 Err(e) => last_err = Some(e.to_string()),
84 }
85 }
86 if guard.is_none() {
87 let err = last_err.unwrap_or_else(|| "unknown error".to_string());
88 return Err(format!(
89 "cannot connect to NATS at {}: {err}",
90 self.nats_url
91 ));
92 }
93 }
94 guard
95 .clone()
96 .ok_or_else(|| format!("cannot connect to NATS at {}", self.nats_url))
97 }
98
99 pub async fn reconnect(&self) {
101 let mut guard = self.nats.lock().await;
102 *guard = None;
103 *guard = async_nats::connect(&self.nats_url).await.ok();
104 }
105
106 pub async fn is_connected(&self) -> bool {
107 let guard = self.nats.lock().await;
108 guard.is_some()
109 }
110
111 pub fn db(&self) -> &Database {
112 &self.db
113 }
114
115 pub fn has_pending_push_changes(&self) -> Result<bool, Error> {
116 let since = self.push_watermark.load(Ordering::SeqCst);
117 let directions = self.table_directions()?;
118 let changes = self
119 .db
120 .changes_since(since)
121 .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
122 Ok(!changes.rows.is_empty()
123 || !changes.edges.is_empty()
124 || !changes.vectors.is_empty()
125 || !changes.ddl.is_empty())
126 }
127
128 pub async fn push(&self) -> Result<ApplyResult, Error> {
129 let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
131
132 let since = self.push_watermark.load(Ordering::SeqCst);
133 let directions = self.table_directions()?;
135 let changeset = self
136 .db
137 .changes_since(since)
138 .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
139
140 if changeset.rows.is_empty()
141 && changeset.edges.is_empty()
142 && changeset.vectors.is_empty()
143 && changeset.ddl.is_empty()
144 {
145 return Ok(ApplyResult {
146 applied_rows: 0,
147 skipped_rows: 0,
148 conflicts: Vec::new(),
149 new_lsn: self.db.current_lsn(),
150 });
151 }
152
153 let mut total = ApplyResult {
154 applied_rows: 0,
155 skipped_rows: 0,
156 conflicts: Vec::new(),
157 new_lsn: since,
158 };
159
160 let mut last_successful_lsn = since;
161 for batch in split_changeset(changeset) {
162 let batch_max_lsn = [
163 batch.rows.last().map(|r| r.lsn),
164 batch.edges.last().map(|e| e.lsn),
165 batch.vectors.last().map(|v| v.lsn),
166 ]
167 .into_iter()
168 .flatten()
169 .max()
170 .unwrap_or(since);
171
172 let request = PushRequest {
173 changeset: batch.clone().into(),
174 };
175 let encoded = encode(MessageType::PushRequest, &request)
176 .map_err(|e| Error::SyncError(e.to_string()))?;
177
178 let result: ApplyResult = if crate::chunking::needs_chunking(&encoded) {
179 use crate::chunking::chunk;
180
181 tracing::info!(
182 payload_size = encoded.len(),
183 "push payload exceeds chunking threshold, using chunked send"
184 );
185
186 let inbox = nats_client.new_inbox();
187 let mut inbox_sub = nats_client
188 .subscribe(inbox.clone())
189 .await
190 .map_err(|e| Error::SyncError(e.to_string()))?;
191
192 let subject = push_subject(&self.tenant_id);
193 let chunks = chunk(&encoded);
194 let chunk_id = chunks[0].chunk_id;
195 let total_chunks = chunks[0].total_chunks;
196
197 tracing::debug!(
198 %chunk_id,
199 total_chunks,
200 "sending {} chunks for push request",
201 total_chunks
202 );
203
204 for chunk_msg in &chunks {
205 let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
206 .map_err(|e| Error::SyncError(e.to_string()))?;
207 nats_client
208 .publish(subject.clone(), chunk_encoded.into())
209 .await
210 .map_err(|e| Error::SyncError(e.to_string()))?;
211 }
212
213 let ack = ChunkAck {
214 chunk_id,
215 total_chunks,
216 reply_inbox: inbox.clone(),
217 };
218 let ack_encoded = encode(MessageType::ChunkAck, &ack)
219 .map_err(|e| Error::SyncError(e.to_string()))?;
220 nats_client
221 .publish(subject, ack_encoded.into())
222 .await
223 .map_err(|e| Error::SyncError(e.to_string()))?;
224 nats_client
225 .flush()
226 .await
227 .map_err(|e| Error::SyncError(e.to_string()))?;
228
229 let msg = tokio::time::timeout(SYNC_TIMEOUT, inbox_sub.next())
230 .await
231 .map_err(|_| Error::SyncError("chunked push timed out".to_string()))?
232 .ok_or_else(|| {
233 Error::SyncError("inbox closed before push response".to_string())
234 })?;
235 let envelope = decode(&msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
236 let response: PushResponse = rmp_serde::from_slice(&envelope.payload)
237 .map_err(|e| Error::SyncError(e.to_string()))?;
238 if let Some(err) = response.error {
239 return Err(Error::SyncError(err));
240 }
241 response
242 .result
243 .ok_or_else(|| Error::SyncError("push response missing result".to_string()))?
244 .into()
245 } else {
246 let mut push_result = None;
247 for attempt in 0..5u32 {
248 if attempt > 0 {
249 tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
250 }
251 let inbox = nats_client.new_inbox();
252 let mut inbox_sub = nats_client
253 .subscribe(inbox.clone())
254 .await
255 .map_err(|e| Error::SyncError(e.to_string()))?;
256
257 nats_client
258 .publish_with_reply(
259 push_subject(&self.tenant_id),
260 inbox.clone(),
261 encoded.clone().into(),
262 )
263 .await
264 .map_err(|e| Error::SyncError(e.to_string()))?;
265
266 match tokio::time::timeout(PUSH_REQUEST_TIMEOUT, inbox_sub.next()).await {
267 Ok(Some(msg)) => {
268 if let Some(status) = msg.status {
269 if status == async_nats::StatusCode::NO_RESPONDERS && attempt < 4 {
270 tracing::debug!(attempt, "push got no responders, retrying");
271 continue;
272 }
273 if attempt < 4 {
274 tracing::debug!(
275 attempt,
276 ?status,
277 "push got status reply, retrying"
278 );
279 continue;
280 }
281 return Err(Error::SyncError(format!(
282 "push failed with NATS status reply: {status:?}"
283 )));
284 }
285
286 let envelope = match decode(&msg.payload) {
287 Ok(envelope) => envelope,
288 Err(err) if attempt < 4 => {
289 tracing::debug!(attempt, error = %err, "push got malformed reply envelope, retrying");
290 continue;
291 }
292 Err(err) => return Err(Error::SyncError(err.to_string())),
293 };
294 let response: PushResponse = match rmp_serde::from_slice(
295 &envelope.payload,
296 ) {
297 Ok(response) => response,
298 Err(err) if attempt < 4 => {
299 tracing::debug!(attempt, error = %err, "push got malformed reply payload, retrying");
300 continue;
301 }
302 Err(err) => return Err(Error::SyncError(err.to_string())),
303 };
304 if let Some(err) = response.error {
305 return Err(Error::SyncError(err));
306 }
307 push_result = Some(
308 response
309 .result
310 .ok_or_else(|| {
311 Error::SyncError("push response missing result".to_string())
312 })?
313 .into(),
314 );
315 break;
316 }
317 Ok(None) => {
318 return Err(Error::SyncError("push inbox closed".to_string()));
319 }
320 Err(_) if attempt < 4 => {
321 tracing::debug!(attempt, "push timed out, retrying");
322 continue;
323 }
324 Err(_) => {
325 return Err(Error::SyncError(
326 "NATS request timed out waiting for push response".to_string(),
327 ));
328 }
329 }
330 }
331 push_result.ok_or_else(|| {
332 Error::SyncError(
333 "push failed after retries: no response from server".to_string(),
334 )
335 })?
336 };
337 last_successful_lsn = batch_max_lsn;
338 total.applied_rows += result.applied_rows;
339 total.skipped_rows += result.skipped_rows;
340 total.conflicts.extend(result.conflicts);
341 total.new_lsn = result.new_lsn;
342 }
343
344 self.push_watermark
345 .store(last_successful_lsn, Ordering::SeqCst);
346 self.db
347 .persist_sync_push_watermark(&self.tenant_id, last_successful_lsn)
348 .map_err(|err| Error::SyncError(err.to_string()))?;
349 Ok(total)
350 }
351
352 pub async fn pull(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
354 let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
355 let directions = self.table_directions()?;
356
357 let mut since_lsn = self.pull_watermark.load(Ordering::SeqCst);
358 #[allow(unused_assignments)]
359 let mut last_server_lsn = since_lsn;
360 let mut total = ApplyResult {
361 applied_rows: 0,
362 skipped_rows: 0,
363 conflicts: vec![],
364 new_lsn: since_lsn,
365 };
366
367 loop {
368 let request = PullRequest {
369 since_lsn,
370 max_entries: Some(PULL_PAGE_SIZE),
371 };
372
373 let (changes, has_more, cursor) = {
374 let encoded = encode(MessageType::PullRequest, &request)
375 .map_err(|e| Error::SyncError(e.to_string()))?;
376
377 let mut first_attempt_response = None;
378 for attempt in 0..5u32 {
379 if attempt > 0 {
380 tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
381 }
382 let inbox = nats_client.new_inbox();
385 let mut inbox_sub = nats_client
386 .subscribe(inbox.clone())
387 .await
388 .map_err(|e| Error::SyncError(e.to_string()))?;
389 let timeout = if attempt < 2 {
390 Duration::from_secs(2)
391 } else {
392 SYNC_TIMEOUT
393 };
394
395 nats_client
396 .publish_with_reply(
397 pull_subject(&self.tenant_id),
398 inbox.clone(),
399 encoded.clone().into(),
400 )
401 .await
402 .map_err(|e| Error::SyncError(e.to_string()))?;
403
404 match tokio::time::timeout(timeout, inbox_sub.next()).await {
405 Ok(Some(msg)) => {
406 first_attempt_response = Some((msg, inbox_sub));
407 break;
408 }
409 Ok(None) => {
410 return Err(Error::SyncError("pull inbox closed".to_string()));
411 }
412 Err(_) if attempt < 4 => {
413 tracing::debug!(attempt, "pull timed out, retrying");
414 continue;
415 }
416 Err(_) => {}
417 }
418 }
419
420 let (first_msg, mut inbox_sub) = first_attempt_response.ok_or_else(|| {
421 Error::SyncError("NATS request timed out waiting for pull response".to_string())
422 })?;
423
424 let first_envelope =
425 decode(&first_msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
426
427 let response_envelope = match first_envelope.message_type {
428 MessageType::PullResponse => first_envelope,
429 MessageType::Chunk => {
430 let first_chunk: crate::protocol::ChunkMessage =
431 rmp_serde::from_slice(&first_envelope.payload)
432 .map_err(|e| Error::SyncError(e.to_string()))?;
433 let total = first_chunk.total_chunks;
434 let mut collected = vec![first_chunk];
435
436 tracing::debug!(
437 total_chunks = total,
438 "pull response is chunked, collecting chunks"
439 );
440
441 let deadline = tokio::time::Instant::now() + CHUNK_COLLECT_TIMEOUT;
442
443 while collected.len() < total as usize {
444 let remaining = deadline.duration_since(tokio::time::Instant::now());
445 if remaining.is_zero() {
446 return Err(Error::SyncError(format!(
447 "overall chunk collection deadline exceeded after {}/{} chunks",
448 collected.len(),
449 total
450 )));
451 }
452 let chunk_msg = tokio::time::timeout_at(deadline, inbox_sub.next())
453 .await
454 .map_err(|_| {
455 Error::SyncError(format!(
456 "timeout collecting pull chunks ({}/{})",
457 collected.len(),
458 total
459 ))
460 })?
461 .ok_or_else(|| {
462 Error::SyncError("pull chunk stream ended".to_string())
463 })?;
464 let env = decode(&chunk_msg.payload)
465 .map_err(|e| Error::SyncError(e.to_string()))?;
466 if matches!(env.message_type, MessageType::Chunk) {
467 let c: crate::protocol::ChunkMessage =
468 rmp_serde::from_slice(&env.payload)
469 .map_err(|e| Error::SyncError(e.to_string()))?;
470 collected.push(c);
471 } else {
472 return Err(Error::SyncError(format!(
473 "unexpected message type {:?} while collecting pull chunks",
474 env.message_type
475 )));
476 }
477 }
478 let reassembled = crate::chunking::reassemble(&mut collected);
479 decode(&reassembled).map_err(|e| Error::SyncError(e.to_string()))?
480 }
481 _ => {
482 return Err(Error::SyncError(
483 "unexpected message type in pull response".to_string(),
484 ));
485 }
486 };
487
488 let response: PullResponse = rmp_serde::from_slice(&response_envelope.payload)
489 .map_err(|e| Error::SyncError(e.to_string()))?;
490 (
491 ChangeSet::from(response.changeset),
492 response.has_more,
493 response.cursor,
494 )
495 };
496
497 let server_lsn = [
499 changes.rows.last().map(|r| r.lsn),
500 changes.edges.last().map(|e| e.lsn),
501 changes.vectors.last().map(|v| v.lsn),
502 ]
503 .into_iter()
504 .flatten()
505 .max()
506 .unwrap_or(since_lsn);
507
508 let filtered = changes
509 .filter_by_direction(&directions, &[SyncDirection::Pull, SyncDirection::Both]);
510 let result = self
511 .db
512 .apply_changes(filtered, &remap_pull_policies(policies))?;
513 total.applied_rows += result.applied_rows;
514 total.skipped_rows += result.skipped_rows;
515 total.conflicts.extend(result.conflicts);
516 total.new_lsn = result.new_lsn;
517 last_server_lsn = server_lsn;
518
519 if !has_more {
520 break;
521 }
522 since_lsn = cursor.unwrap_or(since_lsn);
523 }
524
525 self.pull_watermark.store(last_server_lsn, Ordering::SeqCst);
526 self.db
527 .persist_sync_pull_watermark(&self.tenant_id, last_server_lsn)
528 .map_err(|err| Error::SyncError(err.to_string()))?;
529 Ok(total)
530 }
531
532 pub async fn pull_default(&self) -> Result<ApplyResult, Error> {
534 let policies = self.conflict_policies()?;
535 self.pull(&policies).await
536 }
537
538 pub async fn initial_sync(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
540 self.pull(policies).await
541 }
542
543 pub fn push_watermark(&self) -> u64 {
544 self.push_watermark.load(Ordering::SeqCst)
545 }
546
547 pub fn pull_watermark(&self) -> u64 {
548 self.pull_watermark.load(Ordering::SeqCst)
549 }
550
551 pub fn tenant_id(&self) -> &str {
552 &self.tenant_id
553 }
554
555 pub fn nats_url(&self) -> &str {
556 &self.nats_url
557 }
558
559 pub fn set_table_direction(&self, table: &str, direction: SyncDirection) {
560 match self.table_directions.write() {
561 Ok(mut directions) => {
562 directions.insert(table.to_string(), direction);
563 }
564 Err(_) => tracing::warn!("sync table_directions lock poisoned; ignoring update"),
565 }
566 }
567
568 pub fn set_conflict_policy(&self, table: &str, policy: ConflictPolicy) {
569 match self.conflict_policies.write() {
570 Ok(mut policies) => {
571 policies.per_table.insert(table.to_string(), policy);
572 }
573 Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
574 }
575 }
576
577 pub fn set_default_conflict_policy(&self, policy: ConflictPolicy) {
578 match self.conflict_policies.write() {
579 Ok(mut policies) => {
580 policies.default = policy;
581 }
582 Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
583 }
584 }
585
586 fn table_directions(&self) -> Result<HashMap<String, SyncDirection>, Error> {
587 self.table_directions
588 .read()
589 .map(|directions| directions.clone())
590 .map_err(|_| Error::SyncError("sync table directions lock poisoned".to_string()))
591 }
592
593 fn conflict_policies(&self) -> Result<ConflictPolicies, Error> {
594 self.conflict_policies
595 .read()
596 .map(|policies| policies.clone())
597 .map_err(|_| Error::SyncError("sync conflict policies lock poisoned".to_string()))
598 }
599}
600
601pub(crate) fn split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
602 let wire = WireChangeSet::from(changeset.clone());
603 let estimated = rmp_serde::to_vec(&wire).map(|v| v.len()).unwrap_or(0);
604 if estimated <= MAX_BATCH_BYTES {
605 return vec![changeset];
606 }
607
608 let batches = fast_split_changeset(changeset.clone());
609 if batches
610 .iter()
611 .all(|batch| batch_wire_size(batch) <= MAX_BATCH_BYTES)
612 {
613 return batches;
614 }
615
616 precise_split_changeset(changeset)
617}
618
619fn batch_wire_size(changeset: &ChangeSet) -> usize {
620 rmp_serde::to_vec(&WireChangeSet::from(changeset.clone()))
621 .map(|v| v.len())
622 .unwrap_or(usize::MAX)
623}
624
625fn fast_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
626 let row_sizes: Vec<usize> = changeset
627 .rows
628 .iter()
629 .map(|r| {
630 let wire_row = WireRowChange::from(r.clone());
631 rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
632 })
633 .collect();
634 let vector_sizes: Vec<usize> = changeset
635 .vectors
636 .iter()
637 .map(|v| {
638 let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
639 rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
640 })
641 .collect();
642
643 let mut batches = Vec::new();
644 let mut batch_rows = Vec::new();
645 let mut batch_vectors = Vec::new();
646 let mut batch_size = 0usize;
647 let changeset_edges = changeset.edges;
648 let changeset_vectors = changeset.vectors;
649 let changeset_ddl = changeset.ddl;
650
651 let edges_size: usize = {
652 let edges_wire: Vec<crate::protocol::WireEdgeChange> =
653 changeset_edges.iter().cloned().map(Into::into).collect();
654 rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
655 };
656 let ddl_size: usize = {
657 let ddl_wire: Vec<crate::protocol::WireDdlChange> =
658 changeset_ddl.iter().cloned().map(Into::into).collect();
659 rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
660 };
661 let first_batch_overhead = edges_size + ddl_size;
662
663 for (i, row) in changeset.rows.into_iter().enumerate() {
664 let row_size = row_sizes.get(i).copied().unwrap_or(128);
665 let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
666 let item_size = row_size + vec_size_for_i;
667 let first_item_overhead = if batch_rows.is_empty() && batches.is_empty() {
668 first_batch_overhead
669 } else {
670 0
671 };
672
673 if !batch_rows.is_empty() && batch_size + item_size > TARGET_BATCH_BYTES {
674 batches.push(ChangeSet {
675 rows: std::mem::take(&mut batch_rows),
676 edges: if batches.is_empty() {
677 changeset_edges.clone()
678 } else {
679 Vec::new()
680 },
681 vectors: std::mem::take(&mut batch_vectors),
682 ddl: if batches.is_empty() {
683 changeset_ddl.clone()
684 } else {
685 Vec::new()
686 },
687 });
688 batch_size = 0;
689 }
690
691 if i < changeset_vectors.len() {
692 batch_vectors.push(changeset_vectors[i].clone());
693 batch_size += vec_size_for_i;
694 }
695 batch_rows.push(row);
696 batch_size += row_size + first_item_overhead;
697 }
698
699 if !batch_rows.is_empty() {
700 batches.push(ChangeSet {
701 rows: batch_rows,
702 edges: if batches.is_empty() {
703 changeset_edges
704 } else {
705 Vec::new()
706 },
707 vectors: batch_vectors,
708 ddl: if batches.is_empty() {
709 changeset_ddl
710 } else {
711 Vec::new()
712 },
713 });
714 } else if batches.is_empty() {
715 batches.push(ChangeSet {
716 rows: Vec::new(),
717 edges: changeset_edges,
718 vectors: Vec::new(),
719 ddl: changeset_ddl,
720 });
721 }
722
723 batches
724}
725
726fn precise_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
727 let row_sizes: Vec<usize> = changeset
729 .rows
730 .iter()
731 .map(|r| {
732 let wire_row = WireRowChange::from(r.clone());
733 rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
734 })
735 .collect();
736 let vector_sizes: Vec<usize> = changeset
737 .vectors
738 .iter()
739 .map(|v| {
740 let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
741 rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
742 })
743 .collect();
744
745 let mut batches = Vec::new();
746 let mut batch_rows = Vec::new();
747 let mut batch_vectors = Vec::new();
748 let mut batch_size = 0usize;
749 let changeset_edges = changeset.edges;
751 let changeset_vectors = changeset.vectors;
752 let changeset_ddl = changeset.ddl;
753
754 let edges_size: usize = {
756 let edges_wire: Vec<crate::protocol::WireEdgeChange> =
757 changeset_edges.iter().cloned().map(Into::into).collect();
758 rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
759 };
760 let ddl_size: usize = {
761 let ddl_wire: Vec<crate::protocol::WireDdlChange> =
762 changeset_ddl.iter().cloned().map(Into::into).collect();
763 rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
764 };
765 let first_batch_overhead = edges_size + ddl_size;
766
767 for (i, row) in changeset.rows.into_iter().enumerate() {
768 let row_size = row_sizes.get(i).copied().unwrap_or(128);
769 let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
770 let overhead = if batches.is_empty() {
771 first_batch_overhead
772 } else {
773 0
774 };
775
776 let should_flush = if batch_rows.is_empty() {
777 false
778 } else {
779 let mut trial_rows = batch_rows.clone();
780 trial_rows.push(row.clone());
781 let mut trial_vectors = batch_vectors.clone();
782 if i < changeset_vectors.len() {
783 trial_vectors.push(changeset_vectors[i].clone());
784 }
785 let trial = ChangeSet {
786 rows: trial_rows.clone(),
787 edges: if batches.is_empty() {
788 changeset_edges.clone()
789 } else {
790 Vec::new()
791 },
792 vectors: trial_vectors,
793 ddl: if batches.is_empty() {
794 changeset_ddl.clone()
795 } else {
796 Vec::new()
797 },
798 };
799 let actual_size = rmp_serde::to_vec(&WireChangeSet::from(trial))
800 .map(|v| v.len())
801 .unwrap_or(usize::MAX);
802 batch_size + row_size + vec_size_for_i + overhead > MAX_BATCH_BYTES
803 || actual_size > MAX_BATCH_BYTES
804 };
805
806 if should_flush {
807 batches.push(ChangeSet {
808 rows: std::mem::take(&mut batch_rows),
809 edges: if batches.is_empty() {
810 changeset_edges.clone()
811 } else {
812 Vec::new()
813 },
814 vectors: std::mem::take(&mut batch_vectors),
815 ddl: if batches.is_empty() {
816 changeset_ddl.clone()
817 } else {
818 Vec::new()
819 },
820 });
821 batch_size = 0;
822 }
823
824 if i < changeset_vectors.len() {
826 batch_vectors.push(changeset_vectors[i].clone());
827 batch_size += vector_sizes.get(i).copied().unwrap_or(64);
828 }
829 batch_rows.push(row);
830 batch_size += row_size;
831 }
832
833 if !batch_rows.is_empty() {
834 batches.push(ChangeSet {
835 rows: batch_rows,
836 edges: if batches.is_empty() {
837 changeset_edges
838 } else {
839 Vec::new()
840 },
841 vectors: batch_vectors,
842 ddl: if batches.is_empty() {
843 changeset_ddl
844 } else {
845 Vec::new()
846 },
847 });
848 } else if batches.is_empty() && (!changeset_edges.is_empty() || !changeset_ddl.is_empty()) {
849 batches.push(ChangeSet {
850 rows: Vec::new(),
851 edges: changeset_edges,
852 vectors: Vec::new(),
853 ddl: changeset_ddl,
854 });
855 }
856
857 batches
858}
859
860fn remap_pull_policies(policies: &ConflictPolicies) -> ConflictPolicies {
861 let remap = |policy: ConflictPolicy| match policy {
862 ConflictPolicy::ServerWins => ConflictPolicy::EdgeWins,
863 ConflictPolicy::EdgeWins => ConflictPolicy::ServerWins,
864 other => other,
865 };
866
867 ConflictPolicies {
868 per_table: policies
869 .per_table
870 .iter()
871 .map(|(table, policy)| (table.clone(), remap(*policy)))
872 .collect(),
873 default: remap(policies.default),
874 }
875}
876
877#[cfg(test)]
878mod tests {
879 use super::*;
880 use contextdb_core::Value;
881 use contextdb_engine::Database;
882 use contextdb_engine::sync_types::{NaturalKey, RowChange, VectorChange};
883 use std::sync::Arc;
884 use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
885 use testcontainers::runners::AsyncRunner;
886 use testcontainers::{ContainerAsync, GenericImage, ImageExt};
887 use uuid::Uuid;
888
889 struct NatsFixture {
890 _container: ContainerAsync<GenericImage>,
891 nats_url: String,
892 }
893
894 async fn start_nats() -> NatsFixture {
895 let nats_conf = format!("{}/tests/nats.conf", env!("CARGO_MANIFEST_DIR"));
896
897 let image = GenericImage::new("nats", "latest")
898 .with_exposed_port(4222.tcp())
899 .with_wait_for(WaitFor::message_on_stderr("Server is ready"));
900
901 let request = image
902 .with_mount(Mount::bind_mount(&nats_conf, "/etc/nats/nats.conf"))
903 .with_cmd(["--js", "--config", "/etc/nats/nats.conf"]);
904
905 let container: ContainerAsync<GenericImage> = request.start().await.unwrap();
906 let nats_port = container.get_host_port_ipv4(4222.tcp()).await.unwrap();
907
908 NatsFixture {
909 _container: container,
910 nats_url: format!("nats://127.0.0.1:{nats_port}"),
911 }
912 }
913
914 #[tokio::test]
915 async fn sync_01_client_push_survives_poisoned_direction_lock() {
916 let nats = start_nats().await;
917 let client = Arc::new(SyncClient::new(
918 Arc::new(Database::open_memory()),
919 &nats.nats_url,
920 "sync-01",
921 ));
922
923 client.ensure_connected().await.expect("connect NATS");
924 let poison_client = client.clone();
925 let _ = std::thread::spawn(move || {
926 let _guard = poison_client.table_directions.write().unwrap();
927 panic!("poison sync_client directions lock");
928 })
929 .join();
930
931 let join = tokio::spawn({
932 let client = client.clone();
933 async move { client.push().await }
934 })
935 .await;
936
937 assert!(
938 matches!(join, Ok(Err(Error::SyncError(_)))),
939 "push should return a sync error instead of panicking on poisoned table_directions, got {join:?}"
940 );
941 }
942
943 #[tokio::test]
944 async fn sync_02_client_pull_default_survives_poisoned_policy_lock() {
945 let nats = start_nats().await;
946 let client = Arc::new(SyncClient::new(
947 Arc::new(Database::open_memory()),
948 &nats.nats_url,
949 "sync-02",
950 ));
951
952 client.ensure_connected().await.expect("connect NATS");
953 let poison_client = client.clone();
954 let _ = std::thread::spawn(move || {
955 let _guard = poison_client.conflict_policies.write().unwrap();
956 panic!("poison sync_client policies lock");
957 })
958 .join();
959
960 let join = tokio::spawn({
961 let client = client.clone();
962 async move { client.pull_default().await }
963 })
964 .await;
965
966 assert!(
967 matches!(join, Ok(Err(Error::SyncError(_)))),
968 "pull_default should return a sync error instead of panicking on poisoned conflict_policies, got {join:?}"
969 );
970 }
971
972 #[test]
974 fn a14_batch_splitting_respects_byte_limits() {
975 let large_text = "x".repeat(100 * 1024); let mut rows = Vec::new();
978 for _ in 0..10 {
979 let id = Uuid::new_v4();
980 let mut values = HashMap::new();
981 values.insert("id".to_string(), Value::Uuid(id));
982 values.insert("data".to_string(), Value::Text(large_text.clone()));
983 rows.push(RowChange {
984 table: "t".to_string(),
985 natural_key: NaturalKey {
986 column: "id".to_string(),
987 value: Value::Uuid(id),
988 },
989 values,
990 deleted: false,
991 lsn: 1,
992 });
993 }
994
995 let changeset = ChangeSet {
996 rows,
997 edges: Vec::new(),
998 vectors: Vec::new(),
999 ddl: vec![contextdb_engine::sync_types::DdlChange::CreateTable {
1000 name: "t".to_string(),
1001 columns: vec![
1002 ("id".to_string(), "UUID".to_string()),
1003 ("data".to_string(), "TEXT".to_string()),
1004 ],
1005 constraints: vec!["PRIMARY KEY (id)".to_string()],
1006 }],
1007 };
1008
1009 let batches = split_changeset(changeset);
1010
1011 assert!(
1013 batches.len() >= 2,
1014 "10 rows of ~100KB each (~1MB total) must split into at least 2 batches, got {}",
1015 batches.len()
1016 );
1017
1018 for (i, batch) in batches.iter().enumerate() {
1020 let wire = WireChangeSet::from(batch.clone());
1021 let size = rmp_serde::to_vec(&wire)
1022 .expect("a14 batch should serialize for byte-size accounting")
1023 .len();
1024 assert!(
1025 size <= 800 * 1024,
1026 "batch {} serialized to {} bytes, exceeds 800KB limit",
1027 i,
1028 size
1029 );
1030 }
1031
1032 assert!(!batches[0].ddl.is_empty(), "DDL must be in first batch");
1034 for batch in &batches[1..] {
1035 assert!(
1036 batch.ddl.is_empty(),
1037 "DDL must NOT be in subsequent batches"
1038 );
1039 assert!(
1040 batch.edges.is_empty(),
1041 "edges must NOT be in subsequent batches"
1042 );
1043 }
1044 }
1045
1046 #[test]
1047 fn a14b_batch_splitting_accounts_for_vector_sizes() {
1048 let mut rows = Vec::new();
1049 let mut vectors = Vec::new();
1050 for _ in 0..200 {
1051 let id = Uuid::new_v4();
1052 let mut values = HashMap::new();
1053 values.insert("id".to_string(), Value::Uuid(id));
1054 values.insert("data".to_string(), Value::Text("x".repeat(3000)));
1055 rows.push(RowChange {
1056 table: "t".to_string(),
1057 natural_key: NaturalKey {
1058 column: "id".to_string(),
1059 value: Value::Uuid(id),
1060 },
1061 values,
1062 deleted: false,
1063 lsn: 1,
1064 });
1065 vectors.push(VectorChange {
1066 row_id: 0,
1067 vector: (0..384).map(|j| j as f32).collect(),
1068 lsn: 1,
1069 });
1070 }
1071 let changeset = ChangeSet {
1072 rows,
1073 edges: Vec::new(),
1074 vectors,
1075 ddl: vec![],
1076 };
1077 let batches = split_changeset(changeset);
1078 assert!(
1079 batches.len() >= 2,
1080 "200 rows with 384-dim vectors must split into 2+ batches with correct accounting, got {}",
1081 batches.len()
1082 );
1083 for (i, batch) in batches.iter().enumerate() {
1084 let wire = WireChangeSet::from(batch.clone());
1085 let size = rmp_serde::to_vec(&wire)
1086 .expect("a14b batch should serialize for byte-size accounting")
1087 .len();
1088 assert!(
1089 size <= 800 * 1024,
1090 "batch {} serialized to {} bytes, exceeds 800KB limit",
1091 i,
1092 size
1093 );
1094 }
1095 }
1096
#[test]
/// A change set holding exactly one row with a 600 KiB text payload must come
/// back from `split_changeset` as at least one batch, and the row must be
/// counted exactly once when summing rows over all returned batches.
fn a15_split_changeset_single_oversized_row() {
    // NOTE(review): 600 KiB is below the 800 KiB MAX_BATCH_BYTES declared at the
    // top of this file — confirm this payload still exercises the "oversized" path.
    let id = Uuid::new_v4();
    let values = HashMap::from([
        ("id".to_string(), Value::Uuid(id)),
        ("data".to_string(), Value::Text("x".repeat(600 * 1024))),
    ]);
    let changeset = ChangeSet {
        rows: vec![RowChange {
            table: "observations".to_string(),
            natural_key: NaturalKey {
                column: "id".to_string(),
                value: Value::Uuid(id),
            },
            values,
            deleted: false,
            lsn: 1,
        }],
        edges: Vec::new(),
        vectors: Vec::new(),
        ddl: Vec::new(),
    };

    let batches = split_changeset(changeset);

    // The row must not vanish: something has to be returned.
    assert!(
        !batches.is_empty(),
        "split_changeset must return at least one batch, got {}",
        batches.len()
    );
    // ...and it must be neither dropped nor duplicated across batches.
    let row_count = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.rows.len());
    assert_eq!(
        row_count, 1,
        "the single oversized row must appear in exactly one batch, got {}",
        row_count
    );
}
1136
#[test]
/// Ten rows of roughly 100 KiB each, every one accompanied by a vector that
/// shares its LSN, must split into two or more batches while keeping each
/// row next to its vector: equal counts per batch and matching LSNs at every
/// position.
fn a16_split_changeset_preserves_row_vector_pairing() {
    use contextdb_engine::sync_types::VectorChange;

    // Build the rows and their companion vectors in lockstep; the LSN (1..=10)
    // is the key the assertions below use to verify the pairing.
    let (rows, vectors): (Vec<RowChange>, Vec<VectorChange>) = (0..10usize)
        .map(|i| {
            let seq = (i + 1) as u64;
            let id = Uuid::new_v4();
            let values = HashMap::from([
                ("id".to_string(), Value::Uuid(id)),
                ("data".to_string(), Value::Text("x".repeat(100 * 1024))),
            ]);
            let row = RowChange {
                table: "observations".to_string(),
                natural_key: NaturalKey {
                    column: "id".to_string(),
                    value: Value::Uuid(id),
                },
                values,
                deleted: false,
                lsn: seq,
            };
            let vector = VectorChange {
                row_id: seq,
                vector: vec![i as f32; 3],
                lsn: seq,
            };
            (row, vector)
        })
        .unzip();

    let batches = split_changeset(ChangeSet {
        rows,
        edges: Vec::new(),
        vectors,
        ddl: Vec::new(),
    });

    assert!(
        batches.len() >= 2,
        "10 rows * ~100KB each must split into at least 2 batches, got {}",
        batches.len()
    );

    // Nothing may be lost or duplicated across the split.
    let (total_rows, total_vecs) = batches.iter().fold((0usize, 0usize), |(r, v), batch| {
        (r + batch.rows.len(), v + batch.vectors.len())
    });
    assert_eq!(total_rows, 10, "all 10 rows must be present across batches");
    assert_eq!(
        total_vecs, 10,
        "all 10 vectors must be present across batches"
    );

    for (i, batch) in batches.iter().enumerate() {
        assert_eq!(
            batch.rows.len(),
            batch.vectors.len(),
            "batch {} must have equal row and vector counts: rows={}, vectors={}",
            i,
            batch.rows.len(),
            batch.vectors.len()
        );
        // Lengths are equal past this point, so zipping visits every position.
        for (j, (row, vector)) in batch.rows.iter().zip(batch.vectors.iter()).enumerate() {
            assert_eq!(
                row.lsn, vector.lsn,
                "batch {} position {}: row.lsn={} != vector.lsn={} — pairing is broken",
                i, j, row.lsn, vector.lsn
            );
        }
    }
}
1204
#[test]
/// Splitting a change set with no rows, edges, vectors, or DDL must yield
/// exactly one batch, and that batch must contain no rows.
fn a17_split_changeset_empty_input_returns_one_batch() {
    let batches = split_changeset(ChangeSet {
        rows: vec![],
        edges: vec![],
        vectors: vec![],
        ddl: vec![],
    });

    let batch_count = batches.len();
    assert_eq!(
        batch_count, 1,
        "empty changeset must produce exactly 1 batch (not 0), got {}",
        batch_count
    );

    // Safe to index: the assertion above guarantees one element.
    let only_batch = &batches[0];
    assert!(
        only_batch.rows.is_empty(),
        "the single batch for an empty input must have no rows"
    );
}
1228
#[test]
/// A change set made only of edges (200 of them, each with a 5,000-character
/// `edge_type`) must produce at least one batch, and the edges summed over
/// all batches must still number 200.
fn a18_split_changeset_edge_only_not_dropped() {
    use contextdb_engine::sync_types::EdgeChange;

    let edges: Vec<EdgeChange> = (0..200)
        .map(|_| EdgeChange {
            source: Uuid::new_v4(),
            target: Uuid::new_v4(),
            edge_type: "x".repeat(5_000),
            properties: HashMap::new(),
            lsn: 1,
        })
        .collect();

    let batches = split_changeset(ChangeSet {
        rows: Vec::new(),
        edges,
        vectors: Vec::new(),
        ddl: Vec::new(),
    });

    assert!(
        !batches.is_empty(),
        "edge-only changeset must produce at least 1 batch, got {} — edges silently dropped",
        batches.len()
    );

    // Every edge has to survive the split, regardless of how batches are cut.
    let edge_count = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.edges.len());
    assert_eq!(
        edge_count, 200,
        "all 200 edges must be present across batches, got {}",
        edge_count
    );
}
1265
#[test]
/// A change set made only of DDL (20 `CreateTable` entries, each with 100
/// long-named columns and one long constraint) must produce at least one
/// batch, and the DDL entries summed over all batches must still number 20.
fn a19_split_changeset_ddl_only_not_dropped() {
    use contextdb_engine::sync_types::DdlChange;

    // 500 filler characters appended to every column and constraint name so
    // that each DDL entry carries a sizeable payload.
    let padding = "x".repeat(500);

    let ddl: Vec<DdlChange> = (0..20)
        .map(|i| DdlChange::CreateTable {
            name: format!("table_{}", i),
            columns: (0..100)
                .map(|j| (format!("col_{}_{}", j, padding), "TEXT".to_string()))
                .collect(),
            constraints: vec![format!("PRIMARY KEY (col_{})", padding)],
        })
        .collect();

    let batches = split_changeset(ChangeSet {
        rows: Vec::new(),
        edges: Vec::new(),
        vectors: Vec::new(),
        ddl,
    });

    assert!(
        !batches.is_empty(),
        "DDL-only changeset must produce at least 1 batch, got {} — DDL silently dropped",
        batches.len()
    );

    // No DDL statement may be lost or duplicated by the split.
    let ddl_count = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.ddl.len());
    assert_eq!(
        ddl_count, 20,
        "all 20 DDL entries must be present across batches, got {}",
        ddl_count
    );
}
1303}