Skip to main content

contextdb_server/
sync_client.rs

1use crate::protocol::{
2    ChunkAck, MessageType, PullRequest, PullResponse, PushRequest, PushResponse, WireChangeSet,
3    WireRowChange, decode, encode,
4};
5use crate::subjects::{pull_subject, push_subject};
6use contextdb_core::{AtomicLsn, Error, Lsn};
7use contextdb_engine::Database;
8use contextdb_engine::sync_types::{
9    ApplyResult, ChangeSet, ConflictPolicies, ConflictPolicy, SyncDirection,
10};
11use futures_util::StreamExt;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::Ordering;
15use std::time::Duration;
16
/// How long a chunked push waits for the server's reply; also the per-attempt
/// wait for a pull reply once the short initial attempts have been used up.
const SYNC_TIMEOUT: Duration = Duration::from_secs(60);
/// Overall deadline for collecting all chunks in a chunked pull response.
const CHUNK_COLLECT_TIMEOUT: Duration = Duration::from_secs(60);
/// Per-attempt wait for the reply to a non-chunked push request.
const PUSH_REQUEST_TIMEOUT: Duration = Duration::from_secs(4);
/// `max_entries` sent with every pull request.
const PULL_PAGE_SIZE: u32 = 500;
/// Largest serialized change set that is still sent as a single push batch.
const MAX_BATCH_BYTES: usize = 800 * 1024;
/// Headroom subtracted from `MAX_BATCH_BYTES` because the fast splitter only
/// sums per-item size estimates.
const BATCH_ESTIMATE_SAFETY_MARGIN: usize = 32 * 1024;
/// Estimated batch size at which the fast splitter starts a new batch.
const TARGET_BATCH_BYTES: usize = MAX_BATCH_BYTES - BATCH_ESTIMATE_SAFETY_MARGIN;
25
26pub struct SyncClient {
27    db: Arc<Database>,
28    nats: tokio::sync::Mutex<Option<async_nats::Client>>,
29    nats_url: String,
30    tenant_id: String,
31    push_watermark: AtomicLsn,
32    pull_watermark: AtomicLsn,
33    table_directions: std::sync::RwLock<HashMap<String, SyncDirection>>,
34    conflict_policies: std::sync::RwLock<ConflictPolicies>,
35}
36
37impl SyncClient {
38    pub fn new(db: Arc<Database>, nats_url: &str, tenant_id: &str) -> Self {
39        assert!(
40            !tenant_id.is_empty()
41                && tenant_id
42                    .chars()
43                    .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
44            "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
45        );
46        let (push_watermark, pull_watermark) = db
47            .persisted_sync_watermarks(tenant_id)
48            .unwrap_or_else(|err| {
49                tracing::warn!(%tenant_id, error = %err, "failed to load persisted sync watermarks");
50                (Lsn(0), Lsn(0))
51            });
52        Self {
53            db,
54            nats: tokio::sync::Mutex::new(None),
55            nats_url: nats_url.to_string(),
56            tenant_id: tenant_id.to_string(),
57            push_watermark: AtomicLsn::new(push_watermark),
58            pull_watermark: AtomicLsn::new(pull_watermark),
59            table_directions: std::sync::RwLock::new(HashMap::new()),
60            conflict_policies: std::sync::RwLock::new(ConflictPolicies {
61                per_table: HashMap::new(),
62                default: ConflictPolicy::ServerWins,
63            }),
64        }
65    }
66
67    /// Lazily connect to NATS, reuse existing connection.
68    /// Returns cloned Client (cheap — Arc internally) so the mutex is not held during NATS ops.
69    /// Returns Err with connection error message if NATS is unreachable.
70    pub async fn ensure_connected(&self) -> Result<async_nats::Client, String> {
71        let mut guard = self.nats.lock().await;
72        if guard.is_none() {
73            let mut last_err = None;
74            for attempt in 0..10u32 {
75                if attempt > 0 {
76                    tokio::time::sleep(Duration::from_millis(200 * u64::from(attempt))).await;
77                }
78                match async_nats::connect(&self.nats_url).await {
79                    Ok(client) => {
80                        *guard = Some(client);
81                        break;
82                    }
83                    Err(e) => last_err = Some(e.to_string()),
84                }
85            }
86            if guard.is_none() {
87                let err = last_err.unwrap_or_else(|| "unknown error".to_string());
88                return Err(format!(
89                    "cannot connect to NATS at {}: {err}",
90                    self.nats_url
91                ));
92            }
93        }
94        guard
95            .clone()
96            .ok_or_else(|| format!("cannot connect to NATS at {}", self.nats_url))
97    }
98
99    /// Drop existing connection and reconnect.
100    pub async fn reconnect(&self) {
101        let mut guard = self.nats.lock().await;
102        *guard = None;
103        *guard = async_nats::connect(&self.nats_url).await.ok();
104    }
105
106    pub async fn is_connected(&self) -> bool {
107        let guard = self.nats.lock().await;
108        guard.is_some()
109    }
110
111    pub fn db(&self) -> &Database {
112        &self.db
113    }
114
115    pub fn has_pending_push_changes(&self) -> Result<bool, Error> {
116        let since = self.push_watermark.load(Ordering::SeqCst);
117        let directions = self.table_directions()?;
118        let changes = self
119            .db
120            .changes_since(since)
121            .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
122        Ok(!changes.rows.is_empty()
123            || !changes.edges.is_empty()
124            || !changes.vectors.is_empty()
125            || !changes.ddl.is_empty())
126    }
127
128    pub async fn push(&self) -> Result<ApplyResult, Error> {
129        // Verify NATS connectivity early so users get a clear error even for empty pushes.
130        let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
131
132        let since = self.push_watermark.load(Ordering::SeqCst);
133        // Clone directions out of RwLock BEFORE any .await
134        let directions = self.table_directions()?;
135        let changeset = self
136            .db
137            .changes_since(since)
138            .filter_by_direction(&directions, &[SyncDirection::Push, SyncDirection::Both]);
139
140        if changeset.rows.is_empty()
141            && changeset.edges.is_empty()
142            && changeset.vectors.is_empty()
143            && changeset.ddl.is_empty()
144        {
145            return Ok(ApplyResult {
146                applied_rows: 0,
147                skipped_rows: 0,
148                conflicts: Vec::new(),
149                new_lsn: self.db.current_lsn(),
150            });
151        }
152
153        let mut total = ApplyResult {
154            applied_rows: 0,
155            skipped_rows: 0,
156            conflicts: Vec::new(),
157            new_lsn: since,
158        };
159
160        let mut last_successful_lsn = since;
161        for batch in split_changeset(changeset) {
162            let batch_max_lsn = [
163                batch.rows.last().map(|r| r.lsn),
164                batch.edges.last().map(|e| e.lsn),
165                batch.vectors.last().map(|v| v.lsn),
166            ]
167            .into_iter()
168            .flatten()
169            .max()
170            .unwrap_or(since);
171
172            let request = PushRequest {
173                changeset: batch.clone().into(),
174            };
175            let encoded = encode(MessageType::PushRequest, &request)
176                .map_err(|e| Error::SyncError(e.to_string()))?;
177
178            let result: ApplyResult = if crate::chunking::needs_chunking(&encoded) {
179                use crate::chunking::chunk;
180
181                tracing::info!(
182                    payload_size = encoded.len(),
183                    "push payload exceeds chunking threshold, using chunked send"
184                );
185
186                let inbox = nats_client.new_inbox();
187                let mut inbox_sub = nats_client
188                    .subscribe(inbox.clone())
189                    .await
190                    .map_err(|e| Error::SyncError(e.to_string()))?;
191
192                let subject = push_subject(&self.tenant_id);
193                let chunks = chunk(&encoded);
194                let chunk_id = chunks[0].chunk_id;
195                let total_chunks = chunks[0].total_chunks;
196
197                tracing::debug!(
198                    %chunk_id,
199                    total_chunks,
200                    "sending {} chunks for push request",
201                    total_chunks
202                );
203
204                for chunk_msg in &chunks {
205                    let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
206                        .map_err(|e| Error::SyncError(e.to_string()))?;
207                    nats_client
208                        .publish(subject.clone(), chunk_encoded.into())
209                        .await
210                        .map_err(|e| Error::SyncError(e.to_string()))?;
211                }
212
213                let ack = ChunkAck {
214                    chunk_id,
215                    total_chunks,
216                    reply_inbox: inbox.clone(),
217                };
218                let ack_encoded = encode(MessageType::ChunkAck, &ack)
219                    .map_err(|e| Error::SyncError(e.to_string()))?;
220                nats_client
221                    .publish(subject, ack_encoded.into())
222                    .await
223                    .map_err(|e| Error::SyncError(e.to_string()))?;
224                nats_client
225                    .flush()
226                    .await
227                    .map_err(|e| Error::SyncError(e.to_string()))?;
228
229                let msg = tokio::time::timeout(SYNC_TIMEOUT, inbox_sub.next())
230                    .await
231                    .map_err(|_| Error::SyncError("chunked push timed out".to_string()))?
232                    .ok_or_else(|| {
233                        Error::SyncError("inbox closed before push response".to_string())
234                    })?;
235                let envelope = decode(&msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
236                let response: PushResponse = rmp_serde::from_slice(&envelope.payload)
237                    .map_err(|e| Error::SyncError(e.to_string()))?;
238                if let Some(err) = response.error {
239                    return Err(Error::SyncError(err));
240                }
241                response
242                    .result
243                    .ok_or_else(|| Error::SyncError("push response missing result".to_string()))?
244                    .into()
245            } else {
246                let mut push_result = None;
247                for attempt in 0..5u32 {
248                    if attempt > 0 {
249                        tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
250                    }
251                    let inbox = nats_client.new_inbox();
252                    let mut inbox_sub = nats_client
253                        .subscribe(inbox.clone())
254                        .await
255                        .map_err(|e| Error::SyncError(e.to_string()))?;
256
257                    nats_client
258                        .publish_with_reply(
259                            push_subject(&self.tenant_id),
260                            inbox.clone(),
261                            encoded.clone().into(),
262                        )
263                        .await
264                        .map_err(|e| Error::SyncError(e.to_string()))?;
265
266                    match tokio::time::timeout(PUSH_REQUEST_TIMEOUT, inbox_sub.next()).await {
267                        Ok(Some(msg)) => {
268                            if let Some(status) = msg.status {
269                                if status == async_nats::StatusCode::NO_RESPONDERS && attempt < 4 {
270                                    tracing::debug!(attempt, "push got no responders, retrying");
271                                    continue;
272                                }
273                                if attempt < 4 {
274                                    tracing::debug!(
275                                        attempt,
276                                        ?status,
277                                        "push got status reply, retrying"
278                                    );
279                                    continue;
280                                }
281                                return Err(Error::SyncError(format!(
282                                    "push failed with NATS status reply: {status:?}"
283                                )));
284                            }
285
286                            let envelope = match decode(&msg.payload) {
287                                Ok(envelope) => envelope,
288                                Err(err) if attempt < 4 => {
289                                    tracing::debug!(attempt, error = %err, "push got malformed reply envelope, retrying");
290                                    continue;
291                                }
292                                Err(err) => return Err(Error::SyncError(err.to_string())),
293                            };
294                            let response: PushResponse = match rmp_serde::from_slice(
295                                &envelope.payload,
296                            ) {
297                                Ok(response) => response,
298                                Err(err) if attempt < 4 => {
299                                    tracing::debug!(attempt, error = %err, "push got malformed reply payload, retrying");
300                                    continue;
301                                }
302                                Err(err) => return Err(Error::SyncError(err.to_string())),
303                            };
304                            if let Some(err) = response.error {
305                                return Err(Error::SyncError(err));
306                            }
307                            push_result = Some(
308                                response
309                                    .result
310                                    .ok_or_else(|| {
311                                        Error::SyncError("push response missing result".to_string())
312                                    })?
313                                    .into(),
314                            );
315                            break;
316                        }
317                        Ok(None) => {
318                            return Err(Error::SyncError("push inbox closed".to_string()));
319                        }
320                        Err(_) if attempt < 4 => {
321                            tracing::debug!(attempt, "push timed out, retrying");
322                            continue;
323                        }
324                        Err(_) => {
325                            return Err(Error::SyncError(
326                                "NATS request timed out waiting for push response".to_string(),
327                            ));
328                        }
329                    }
330                }
331                push_result.ok_or_else(|| {
332                    Error::SyncError(
333                        "push failed after retries: no response from server".to_string(),
334                    )
335                })?
336            };
337            last_successful_lsn = batch_max_lsn;
338            total.applied_rows += result.applied_rows;
339            total.skipped_rows += result.skipped_rows;
340            total.conflicts.extend(result.conflicts);
341            total.new_lsn = result.new_lsn;
342        }
343
344        self.push_watermark
345            .store(last_successful_lsn, Ordering::SeqCst);
346        self.db
347            .persist_sync_push_watermark(&self.tenant_id, last_successful_lsn)
348            .map_err(|err| Error::SyncError(err.to_string()))?;
349        Ok(total)
350    }
351
352    /// Pull with explicit policies (frozen test contract, library consumers).
353    pub async fn pull(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
354        let nats_client = self.ensure_connected().await.map_err(Error::SyncError)?;
355        let directions = self.table_directions()?;
356
357        let mut since_lsn = self.pull_watermark.load(Ordering::SeqCst);
358        #[allow(unused_assignments)]
359        let mut last_server_lsn = since_lsn;
360        let mut total = ApplyResult {
361            applied_rows: 0,
362            skipped_rows: 0,
363            conflicts: vec![],
364            new_lsn: since_lsn,
365        };
366
367        loop {
368            let request = PullRequest {
369                since_lsn,
370                max_entries: Some(PULL_PAGE_SIZE),
371            };
372
373            let (changes, has_more, cursor) = {
374                let encoded = encode(MessageType::PullRequest, &request)
375                    .map_err(|e| Error::SyncError(e.to_string()))?;
376
377                let mut first_attempt_response = None;
378                for attempt in 0..5u32 {
379                    if attempt > 0 {
380                        tokio::time::sleep(Duration::from_millis(500 * u64::from(attempt))).await;
381                    }
382                    // Use a fresh inbox per attempt so late responses from earlier attempts
383                    // cannot be mistaken for the current pull reply or chunk stream.
384                    let inbox = nats_client.new_inbox();
385                    let mut inbox_sub = nats_client
386                        .subscribe(inbox.clone())
387                        .await
388                        .map_err(|e| Error::SyncError(e.to_string()))?;
389                    let timeout = if attempt < 2 {
390                        Duration::from_secs(2)
391                    } else {
392                        SYNC_TIMEOUT
393                    };
394
395                    nats_client
396                        .publish_with_reply(
397                            pull_subject(&self.tenant_id),
398                            inbox.clone(),
399                            encoded.clone().into(),
400                        )
401                        .await
402                        .map_err(|e| Error::SyncError(e.to_string()))?;
403
404                    match tokio::time::timeout(timeout, inbox_sub.next()).await {
405                        Ok(Some(msg)) => {
406                            first_attempt_response = Some((msg, inbox_sub));
407                            break;
408                        }
409                        Ok(None) => {
410                            return Err(Error::SyncError("pull inbox closed".to_string()));
411                        }
412                        Err(_) if attempt < 4 => {
413                            tracing::debug!(attempt, "pull timed out, retrying");
414                            continue;
415                        }
416                        Err(_) => {}
417                    }
418                }
419
420                let (first_msg, mut inbox_sub) = first_attempt_response.ok_or_else(|| {
421                    Error::SyncError("NATS request timed out waiting for pull response".to_string())
422                })?;
423
424                let first_envelope =
425                    decode(&first_msg.payload).map_err(|e| Error::SyncError(e.to_string()))?;
426
427                let response_envelope = match first_envelope.message_type {
428                    MessageType::PullResponse => first_envelope,
429                    MessageType::Chunk => {
430                        let first_chunk: crate::protocol::ChunkMessage =
431                            rmp_serde::from_slice(&first_envelope.payload)
432                                .map_err(|e| Error::SyncError(e.to_string()))?;
433                        let total = first_chunk.total_chunks;
434                        let mut collected = vec![first_chunk];
435
436                        tracing::debug!(
437                            total_chunks = total,
438                            "pull response is chunked, collecting chunks"
439                        );
440
441                        let deadline = tokio::time::Instant::now() + CHUNK_COLLECT_TIMEOUT;
442
443                        while collected.len() < total as usize {
444                            let remaining = deadline.duration_since(tokio::time::Instant::now());
445                            if remaining.is_zero() {
446                                return Err(Error::SyncError(format!(
447                                    "overall chunk collection deadline exceeded after {}/{} chunks",
448                                    collected.len(),
449                                    total
450                                )));
451                            }
452                            let chunk_msg = tokio::time::timeout_at(deadline, inbox_sub.next())
453                                .await
454                                .map_err(|_| {
455                                    Error::SyncError(format!(
456                                        "timeout collecting pull chunks ({}/{})",
457                                        collected.len(),
458                                        total
459                                    ))
460                                })?
461                                .ok_or_else(|| {
462                                    Error::SyncError("pull chunk stream ended".to_string())
463                                })?;
464                            let env = decode(&chunk_msg.payload)
465                                .map_err(|e| Error::SyncError(e.to_string()))?;
466                            if matches!(env.message_type, MessageType::Chunk) {
467                                let c: crate::protocol::ChunkMessage =
468                                    rmp_serde::from_slice(&env.payload)
469                                        .map_err(|e| Error::SyncError(e.to_string()))?;
470                                collected.push(c);
471                            } else {
472                                return Err(Error::SyncError(format!(
473                                    "unexpected message type {:?} while collecting pull chunks",
474                                    env.message_type
475                                )));
476                            }
477                        }
478                        let reassembled = crate::chunking::reassemble(&mut collected);
479                        decode(&reassembled).map_err(|e| Error::SyncError(e.to_string()))?
480                    }
481                    _ => {
482                        return Err(Error::SyncError(
483                            "unexpected message type in pull response".to_string(),
484                        ));
485                    }
486                };
487
488                let response: PullResponse = rmp_serde::from_slice(&response_envelope.payload)
489                    .map_err(|e| Error::SyncError(e.to_string()))?;
490                (
491                    ChangeSet::from(response.changeset),
492                    response.has_more,
493                    response.cursor,
494                )
495            };
496
497            // Extract server-side max LSN BEFORE filtering/applying
498            let server_lsn = [
499                changes.rows.last().map(|r| r.lsn),
500                changes.edges.last().map(|e| e.lsn),
501                changes.vectors.last().map(|v| v.lsn),
502            ]
503            .into_iter()
504            .flatten()
505            .max()
506            .unwrap_or(since_lsn);
507
508            let filtered = changes
509                .filter_by_direction(&directions, &[SyncDirection::Pull, SyncDirection::Both]);
510            let result = self
511                .db
512                .apply_changes(filtered, &remap_pull_policies(policies))?;
513            total.applied_rows += result.applied_rows;
514            total.skipped_rows += result.skipped_rows;
515            total.conflicts.extend(result.conflicts);
516            total.new_lsn = result.new_lsn;
517            last_server_lsn = server_lsn;
518
519            if !has_more {
520                break;
521            }
522            since_lsn = cursor.unwrap_or(since_lsn);
523        }
524
525        self.pull_watermark.store(last_server_lsn, Ordering::SeqCst);
526        self.db
527            .persist_sync_pull_watermark(&self.tenant_id, last_server_lsn)
528            .map_err(|err| Error::SyncError(err.to_string()))?;
529        Ok(total)
530    }
531
532    /// Pull using internally configured conflict policies (used by CLI).
533    pub async fn pull_default(&self) -> Result<ApplyResult, Error> {
534        let policies = self.conflict_policies()?;
535        self.pull(&policies).await
536    }
537
538    /// Initial sync using explicit policies (frozen test contract).
539    pub async fn initial_sync(&self, policies: &ConflictPolicies) -> Result<ApplyResult, Error> {
540        self.pull(policies).await
541    }
542
543    pub fn push_watermark(&self) -> Lsn {
544        self.push_watermark.load(Ordering::SeqCst)
545    }
546
547    pub fn pull_watermark(&self) -> Lsn {
548        self.pull_watermark.load(Ordering::SeqCst)
549    }
550
551    pub fn tenant_id(&self) -> &str {
552        &self.tenant_id
553    }
554
555    pub fn nats_url(&self) -> &str {
556        &self.nats_url
557    }
558
559    pub fn set_table_direction(&self, table: &str, direction: SyncDirection) {
560        match self.table_directions.write() {
561            Ok(mut directions) => {
562                directions.insert(table.to_string(), direction);
563            }
564            Err(_) => tracing::warn!("sync table_directions lock poisoned; ignoring update"),
565        }
566    }
567
568    pub fn set_conflict_policy(&self, table: &str, policy: ConflictPolicy) {
569        match self.conflict_policies.write() {
570            Ok(mut policies) => {
571                policies.per_table.insert(table.to_string(), policy);
572            }
573            Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
574        }
575    }
576
577    pub fn set_default_conflict_policy(&self, policy: ConflictPolicy) {
578        match self.conflict_policies.write() {
579            Ok(mut policies) => {
580                policies.default = policy;
581            }
582            Err(_) => tracing::warn!("sync conflict_policies lock poisoned; ignoring update"),
583        }
584    }
585
586    fn table_directions(&self) -> Result<HashMap<String, SyncDirection>, Error> {
587        self.table_directions
588            .read()
589            .map(|directions| directions.clone())
590            .map_err(|_| Error::SyncError("sync table directions lock poisoned".to_string()))
591    }
592
593    fn conflict_policies(&self) -> Result<ConflictPolicies, Error> {
594        self.conflict_policies
595            .read()
596            .map(|policies| policies.clone())
597            .map_err(|_| Error::SyncError("sync conflict policies lock poisoned".to_string()))
598    }
599}
600
601pub(crate) fn split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
602    let wire = WireChangeSet::from(changeset.clone());
603    let estimated = rmp_serde::to_vec(&wire).map(|v| v.len()).unwrap_or(0);
604    if estimated <= MAX_BATCH_BYTES {
605        return vec![changeset];
606    }
607
608    let batches = fast_split_changeset(changeset.clone());
609    if batches
610        .iter()
611        .all(|batch| batch_wire_size(batch) <= MAX_BATCH_BYTES)
612    {
613        return batches;
614    }
615
616    precise_split_changeset(changeset)
617}
618
619fn batch_wire_size(changeset: &ChangeSet) -> usize {
620    rmp_serde::to_vec(&WireChangeSet::from(changeset.clone()))
621        .map(|v| v.len())
622        .unwrap_or(usize::MAX)
623}
624
625fn fast_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
626    let row_sizes: Vec<usize> = changeset
627        .rows
628        .iter()
629        .map(|r| {
630            let wire_row = WireRowChange::from(r.clone());
631            rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
632        })
633        .collect();
634    let vector_sizes: Vec<usize> = changeset
635        .vectors
636        .iter()
637        .map(|v| {
638            let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
639            rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
640        })
641        .collect();
642
643    let mut batches = Vec::new();
644    let mut batch_rows = Vec::new();
645    let mut batch_vectors = Vec::new();
646    let mut batch_size = 0usize;
647    let changeset_edges = changeset.edges;
648    let changeset_vectors = changeset.vectors;
649    let changeset_ddl = changeset.ddl;
650
651    let edges_size: usize = {
652        let edges_wire: Vec<crate::protocol::WireEdgeChange> =
653            changeset_edges.iter().cloned().map(Into::into).collect();
654        rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
655    };
656    let ddl_size: usize = {
657        let ddl_wire: Vec<crate::protocol::WireDdlChange> =
658            changeset_ddl.iter().cloned().map(Into::into).collect();
659        rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
660    };
661    let first_batch_overhead = edges_size + ddl_size;
662
663    for (i, row) in changeset.rows.into_iter().enumerate() {
664        let row_size = row_sizes.get(i).copied().unwrap_or(128);
665        let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
666        let item_size = row_size + vec_size_for_i;
667        let first_item_overhead = if batch_rows.is_empty() && batches.is_empty() {
668            first_batch_overhead
669        } else {
670            0
671        };
672
673        if !batch_rows.is_empty() && batch_size + item_size > TARGET_BATCH_BYTES {
674            batches.push(ChangeSet {
675                rows: std::mem::take(&mut batch_rows),
676                edges: if batches.is_empty() {
677                    changeset_edges.clone()
678                } else {
679                    Vec::new()
680                },
681                vectors: std::mem::take(&mut batch_vectors),
682                ddl: if batches.is_empty() {
683                    changeset_ddl.clone()
684                } else {
685                    Vec::new()
686                },
687            });
688            batch_size = 0;
689        }
690
691        if i < changeset_vectors.len() {
692            batch_vectors.push(changeset_vectors[i].clone());
693            batch_size += vec_size_for_i;
694        }
695        batch_rows.push(row);
696        batch_size += row_size + first_item_overhead;
697    }
698
699    if !batch_rows.is_empty() {
700        batches.push(ChangeSet {
701            rows: batch_rows,
702            edges: if batches.is_empty() {
703                changeset_edges
704            } else {
705                Vec::new()
706            },
707            vectors: batch_vectors,
708            ddl: if batches.is_empty() {
709                changeset_ddl
710            } else {
711                Vec::new()
712            },
713        });
714    } else if batches.is_empty() {
715        batches.push(ChangeSet {
716            rows: Vec::new(),
717            edges: changeset_edges,
718            vectors: Vec::new(),
719            ddl: changeset_ddl,
720        });
721    }
722
723    batches
724}
725
726fn precise_split_changeset(changeset: ChangeSet) -> Vec<ChangeSet> {
727    // Estimate per-row sizes by serializing each WireRowChange individually
728    let row_sizes: Vec<usize> = changeset
729        .rows
730        .iter()
731        .map(|r| {
732            let wire_row = WireRowChange::from(r.clone());
733            rmp_serde::to_vec(&wire_row).map(|v| v.len()).unwrap_or(128)
734        })
735        .collect();
736    let vector_sizes: Vec<usize> = changeset
737        .vectors
738        .iter()
739        .map(|v| {
740            let wire_vec = crate::protocol::WireVectorChange::from(v.clone());
741            rmp_serde::to_vec(&wire_vec).map(|v| v.len()).unwrap_or(64)
742        })
743        .collect();
744
745    let mut batches = Vec::new();
746    let mut batch_rows = Vec::new();
747    let mut batch_vectors = Vec::new();
748    let mut batch_size = 0usize;
749    // Extract edges, vectors, and ddl BEFORE consuming rows via into_iter()
750    let changeset_edges = changeset.edges;
751    let changeset_vectors = changeset.vectors;
752    let changeset_ddl = changeset.ddl;
753
754    // Overhead for edges + DDL in first batch
755    let edges_size: usize = {
756        let edges_wire: Vec<crate::protocol::WireEdgeChange> =
757            changeset_edges.iter().cloned().map(Into::into).collect();
758        rmp_serde::to_vec(&edges_wire).map(|v| v.len()).unwrap_or(0)
759    };
760    let ddl_size: usize = {
761        let ddl_wire: Vec<crate::protocol::WireDdlChange> =
762            changeset_ddl.iter().cloned().map(Into::into).collect();
763        rmp_serde::to_vec(&ddl_wire).map(|v| v.len()).unwrap_or(0)
764    };
765    let first_batch_overhead = edges_size + ddl_size;
766
767    for (i, row) in changeset.rows.into_iter().enumerate() {
768        let row_size = row_sizes.get(i).copied().unwrap_or(128);
769        let vec_size_for_i = vector_sizes.get(i).copied().unwrap_or(64);
770        let overhead = if batches.is_empty() {
771            first_batch_overhead
772        } else {
773            0
774        };
775
776        let should_flush = if batch_rows.is_empty() {
777            false
778        } else {
779            let mut trial_rows = batch_rows.clone();
780            trial_rows.push(row.clone());
781            let mut trial_vectors = batch_vectors.clone();
782            if i < changeset_vectors.len() {
783                trial_vectors.push(changeset_vectors[i].clone());
784            }
785            let trial = ChangeSet {
786                rows: trial_rows.clone(),
787                edges: if batches.is_empty() {
788                    changeset_edges.clone()
789                } else {
790                    Vec::new()
791                },
792                vectors: trial_vectors,
793                ddl: if batches.is_empty() {
794                    changeset_ddl.clone()
795                } else {
796                    Vec::new()
797                },
798            };
799            let actual_size = rmp_serde::to_vec(&WireChangeSet::from(trial))
800                .map(|v| v.len())
801                .unwrap_or(usize::MAX);
802            batch_size + row_size + vec_size_for_i + overhead > MAX_BATCH_BYTES
803                || actual_size > MAX_BATCH_BYTES
804        };
805
806        if should_flush {
807            batches.push(ChangeSet {
808                rows: std::mem::take(&mut batch_rows),
809                edges: if batches.is_empty() {
810                    changeset_edges.clone()
811                } else {
812                    Vec::new()
813                },
814                vectors: std::mem::take(&mut batch_vectors),
815                ddl: if batches.is_empty() {
816                    changeset_ddl.clone()
817                } else {
818                    Vec::new()
819                },
820            });
821            batch_size = 0;
822        }
823
824        // Pair with vector at same index if available
825        if i < changeset_vectors.len() {
826            batch_vectors.push(changeset_vectors[i].clone());
827            batch_size += vector_sizes.get(i).copied().unwrap_or(64);
828        }
829        batch_rows.push(row);
830        batch_size += row_size;
831    }
832
833    if !batch_rows.is_empty() {
834        batches.push(ChangeSet {
835            rows: batch_rows,
836            edges: if batches.is_empty() {
837                changeset_edges
838            } else {
839                Vec::new()
840            },
841            vectors: batch_vectors,
842            ddl: if batches.is_empty() {
843                changeset_ddl
844            } else {
845                Vec::new()
846            },
847        });
848    } else if batches.is_empty() && (!changeset_edges.is_empty() || !changeset_ddl.is_empty()) {
849        batches.push(ChangeSet {
850            rows: Vec::new(),
851            edges: changeset_edges,
852            vectors: Vec::new(),
853            ddl: changeset_ddl,
854        });
855    }
856
857    batches
858}
859
860fn remap_pull_policies(policies: &ConflictPolicies) -> ConflictPolicies {
861    let remap = |policy: ConflictPolicy| match policy {
862        ConflictPolicy::ServerWins => ConflictPolicy::EdgeWins,
863        ConflictPolicy::EdgeWins => ConflictPolicy::ServerWins,
864        other => other,
865    };
866
867    ConflictPolicies {
868        per_table: policies
869            .per_table
870            .iter()
871            .map(|(table, policy)| (table.clone(), remap(*policy)))
872            .collect(),
873        default: remap(policies.default),
874    }
875}
876
#[cfg(test)]
mod tests {
    use super::*;
    use contextdb_core::{RowId, Value};
    use contextdb_engine::Database;
    use contextdb_engine::sync_types::{
        DdlChange, EdgeChange, NaturalKey, RowChange, VectorChange,
    };
    use std::sync::Arc;
    use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
    use testcontainers::runners::AsyncRunner;
    use testcontainers::{ContainerAsync, GenericImage, ImageExt};
    use uuid::Uuid;

    /// A NATS container started for one test, plus the URL to reach it.
    struct NatsFixture {
        // Never read; held only so the container handle lives as long as the fixture.
        _container: ContainerAsync<GenericImage>,
        nats_url: String,
    }

    /// Starts a NATS container configured from `tests/nats.conf` and returns a fixture
    /// pointing at the host port mapped to the container's port 4222.
    async fn start_nats() -> NatsFixture {
        let nats_conf = format!("{}/tests/nats.conf", env!("CARGO_MANIFEST_DIR"));

        let container: ContainerAsync<GenericImage> = GenericImage::new("nats", "latest")
            .with_exposed_port(4222.tcp())
            .with_wait_for(WaitFor::message_on_stderr("Server is ready"))
            .with_mount(Mount::bind_mount(&nats_conf, "/etc/nats/nats.conf"))
            .with_cmd(["--js", "--config", "/etc/nats/nats.conf"])
            .start()
            .await
            .unwrap();
        let nats_port = container.get_host_port_ipv4(4222.tcp()).await.unwrap();

        NatsFixture {
            _container: container,
            nats_url: format!("nats://127.0.0.1:{nats_port}"),
        }
    }

    /// Builds a live (non-deleted) row for `table`, keyed by a fresh UUID in column `id`
    /// and carrying `payload` in a text column named `data`.
    fn text_row(table: &str, payload: String, lsn: u64) -> RowChange {
        let id = Uuid::new_v4();
        let values = HashMap::from([
            ("id".to_string(), Value::Uuid(id)),
            ("data".to_string(), Value::Text(payload)),
        ]);
        RowChange {
            table: table.to_string(),
            natural_key: NaturalKey {
                column: "id".to_string(),
                value: Value::Uuid(id),
            },
            values,
            deleted: false,
            lsn: Lsn(lsn),
        }
    }

    /// Asserts that every batch serializes (as a `WireChangeSet`) to at most 800KB.
    /// `serialize_context` is the panic message used if a batch fails to serialize.
    fn assert_batches_fit_limit(batches: &[ChangeSet], serialize_context: &str) {
        for (i, batch) in batches.iter().enumerate() {
            let size = rmp_serde::to_vec(&WireChangeSet::from(batch.clone()))
                .expect(serialize_context)
                .len();
            assert!(
                size <= 800 * 1024,
                "batch {} serialized to {} bytes, exceeds 800KB limit",
                i,
                size
            );
        }
    }

    #[tokio::test]
    async fn sync_01_client_push_survives_poisoned_direction_lock() {
        let nats = start_nats().await;
        let database = Arc::new(Database::open_memory());
        let client = Arc::new(SyncClient::new(database, &nats.nats_url, "sync-01"));
        client.ensure_connected().await.expect("connect NATS");

        // Panic while holding the write guard so the std RwLock becomes poisoned.
        let poison_client = Arc::clone(&client);
        let poisoner = std::thread::spawn(move || {
            let _guard = poison_client.table_directions.write().unwrap();
            panic!("poison sync_client directions lock");
        });
        let _ = poisoner.join();

        // Run the push in its own task so a panic would surface as a JoinError.
        let pusher = Arc::clone(&client);
        let join = tokio::spawn(async move { pusher.push().await }).await;

        assert!(
            matches!(join, Ok(Err(Error::SyncError(_)))),
            "push should return a sync error instead of panicking on poisoned table_directions, got {join:?}"
        );
    }

    #[tokio::test]
    async fn sync_02_client_pull_default_survives_poisoned_policy_lock() {
        let nats = start_nats().await;
        let database = Arc::new(Database::open_memory());
        let client = Arc::new(SyncClient::new(database, &nats.nats_url, "sync-02"));
        client.ensure_connected().await.expect("connect NATS");

        // Panic while holding the write guard so the std RwLock becomes poisoned.
        let poison_client = Arc::clone(&client);
        let poisoner = std::thread::spawn(move || {
            let _guard = poison_client.conflict_policies.write().unwrap();
            panic!("poison sync_client policies lock");
        });
        let _ = poisoner.join();

        // Run the pull in its own task so a panic would surface as a JoinError.
        let puller = Arc::clone(&client);
        let join = tokio::spawn(async move { puller.pull_default().await }).await;

        assert!(
            matches!(join, Ok(Err(Error::SyncError(_)))),
            "pull_default should return a sync error instead of panicking on poisoned conflict_policies, got {join:?}"
        );
    }

    // A14: Batch splitting respects byte size limits
    #[test]
    fn a14_batch_splitting_respects_byte_limits() {
        // 10 rows of ~100KB each, ~1MB in total.
        let large_text = "x".repeat(100 * 1024);
        let rows: Vec<RowChange> = (0..10)
            .map(|_| text_row("t", large_text.clone(), 1))
            .collect();
        let create_table = DdlChange::CreateTable {
            name: "t".to_string(),
            columns: vec![
                ("id".to_string(), "UUID".to_string()),
                ("data".to_string(), "TEXT".to_string()),
            ],
            constraints: vec!["PRIMARY KEY (id)".to_string()],
        };

        let batches = split_changeset(ChangeSet {
            rows,
            edges: Vec::new(),
            vectors: Vec::new(),
            ddl: vec![create_table],
        });

        // 10 rows * ~100KB > 800KB, so one batch cannot hold everything.
        assert!(
            batches.len() >= 2,
            "10 rows of ~100KB each (~1MB total) must split into at least 2 batches, got {}",
            batches.len()
        );

        assert_batches_fit_limit(
            &batches,
            "a14 batch should serialize for byte-size accounting",
        );

        // DDL (and edges) belong to the first batch only.
        assert!(!batches[0].ddl.is_empty(), "DDL must be in first batch");
        for batch in batches.iter().skip(1) {
            assert!(
                batch.ddl.is_empty(),
                "DDL must NOT be in subsequent batches"
            );
            assert!(
                batch.edges.is_empty(),
                "edges must NOT be in subsequent batches"
            );
        }
    }

    #[test]
    fn a14b_batch_splitting_accounts_for_vector_sizes() {
        let rows: Vec<RowChange> = (0..200)
            .map(|_| text_row("t", "x".repeat(3000), 1))
            .collect();
        let vectors: Vec<VectorChange> = (0..200)
            .map(|_| VectorChange {
                row_id: RowId(0),
                vector: (0..384).map(|j| j as f32).collect(),
                lsn: Lsn(1),
            })
            .collect();

        let batches = split_changeset(ChangeSet {
            rows,
            edges: Vec::new(),
            vectors,
            ddl: vec![],
        });

        assert!(
            batches.len() >= 2,
            "200 rows with 384-dim vectors must split into 2+ batches with correct accounting, got {}",
            batches.len()
        );
        assert_batches_fit_limit(
            &batches,
            "a14b batch should serialize for byte-size accounting",
        );
    }

    // A15: split_changeset handles a single row that alone exceeds MAX_BATCH_BYTES
    #[test]
    fn a15_split_changeset_single_oversized_row() {
        let row = text_row("observations", "x".repeat(600 * 1024), 1);

        let batches = split_changeset(ChangeSet {
            rows: vec![row],
            edges: Vec::new(),
            vectors: Vec::new(),
            ddl: Vec::new(),
        });

        assert!(
            !batches.is_empty(),
            "split_changeset must return at least one batch, got {}",
            batches.len()
        );
        let total_rows: usize = batches.iter().map(|batch| batch.rows.len()).sum();
        assert_eq!(
            total_rows, 1,
            "the single oversized row must appear in exactly one batch, got {}",
            total_rows
        );
    }

    // A16: split_changeset preserves row/vector pairing across batch boundaries
    #[test]
    fn a16_split_changeset_preserves_row_vector_pairing() {
        // Row i and vector i share the LSN i + 1, which makes a broken pairing detectable.
        let rows: Vec<RowChange> = (0..10usize)
            .map(|i| text_row("observations", "x".repeat(100 * 1024), (i + 1) as u64))
            .collect();
        let vectors: Vec<VectorChange> = (0..10usize)
            .map(|i| VectorChange {
                row_id: RowId((i + 1) as u64),
                vector: vec![i as f32; 3],
                lsn: Lsn((i + 1) as u64),
            })
            .collect();

        let batches = split_changeset(ChangeSet {
            rows,
            edges: Vec::new(),
            vectors,
            ddl: Vec::new(),
        });

        assert!(
            batches.len() >= 2,
            "10 rows * ~100KB each must split into at least 2 batches, got {}",
            batches.len()
        );
        let total_rows: usize = batches.iter().map(|batch| batch.rows.len()).sum();
        let total_vecs: usize = batches.iter().map(|batch| batch.vectors.len()).sum();
        assert_eq!(total_rows, 10, "all 10 rows must be present across batches");
        assert_eq!(
            total_vecs, 10,
            "all 10 vectors must be present across batches"
        );
        for (i, batch) in batches.iter().enumerate() {
            assert_eq!(
                batch.rows.len(),
                batch.vectors.len(),
                "batch {} must have equal row and vector counts: rows={}, vectors={}",
                i,
                batch.rows.len(),
                batch.vectors.len()
            );
            // Lengths are equal at this point, so zipping visits every position.
            for (j, (row, vector)) in batch.rows.iter().zip(&batch.vectors).enumerate() {
                assert_eq!(
                    row.lsn, vector.lsn,
                    "batch {} position {}: row.lsn={} != vector.lsn={} — pairing is broken",
                    i, j, row.lsn, vector.lsn
                );
            }
        }
    }

    // A17: split_changeset on empty input returns exactly one empty batch
    #[test]
    fn a17_split_changeset_empty_input_returns_one_batch() {
        let batches = split_changeset(ChangeSet {
            rows: Vec::new(),
            edges: Vec::new(),
            vectors: Vec::new(),
            ddl: Vec::new(),
        });

        assert_eq!(
            batches.len(),
            1,
            "empty changeset must produce exactly 1 batch (not 0), got {}",
            batches.len()
        );
        assert!(
            batches[0].rows.is_empty(),
            "the single batch for an empty input must have no rows"
        );
    }

    // A18: split_changeset with edge-only changeset must not return vec![]
    #[test]
    fn a18_split_changeset_edge_only_not_dropped() {
        let edges: Vec<EdgeChange> = (0..200)
            .map(|_| EdgeChange {
                source: Uuid::new_v4(),
                target: Uuid::new_v4(),
                edge_type: "x".repeat(5_000),
                properties: HashMap::new(),
                lsn: Lsn(1),
            })
            .collect();

        let batches = split_changeset(ChangeSet {
            rows: Vec::new(),
            edges,
            vectors: Vec::new(),
            ddl: Vec::new(),
        });

        assert!(
            !batches.is_empty(),
            "edge-only changeset must produce at least 1 batch, got {} — edges silently dropped",
            batches.len()
        );
        let total_edges: usize = batches.iter().map(|batch| batch.edges.len()).sum();
        assert_eq!(
            total_edges, 200,
            "all 200 edges must be present across batches, got {}",
            total_edges
        );
    }

    // A19: split_changeset with DDL-only changeset must not return vec![]
    // Column names are padded to force estimated size > MAX_BATCH_BYTES
    #[test]
    fn a19_split_changeset_ddl_only_not_dropped() {
        let ddl: Vec<DdlChange> = (0..20)
            .map(|i| DdlChange::CreateTable {
                name: format!("table_{}", i),
                columns: (0..100)
                    .map(|j| (format!("col_{}_{}", j, "x".repeat(500)), "TEXT".to_string()))
                    .collect(),
                constraints: vec![format!("PRIMARY KEY (col_{})", "x".repeat(500))],
            })
            .collect();

        let batches = split_changeset(ChangeSet {
            rows: Vec::new(),
            edges: Vec::new(),
            vectors: Vec::new(),
            ddl,
        });

        assert!(
            !batches.is_empty(),
            "DDL-only changeset must produce at least 1 batch, got {} — DDL silently dropped",
            batches.len()
        );
        let total_ddl: usize = batches.iter().map(|batch| batch.ddl.len()).sum();
        assert_eq!(
            total_ddl, 20,
            "all 20 DDL entries must be present across batches, got {}",
            total_ddl
        );
    }
}