1use crate::protocol::{
2 MessageType, PullRequest, PullResponse, PushRequest, PushResponse, decode, encode,
3};
4use crate::subjects::{pull_subject, push_subject};
5use contextdb_engine::Database;
6use contextdb_engine::sync_types::ConflictPolicies;
7use futures_util::StreamExt;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12const MAX_CHUNK_SESSIONS: usize = 64;
14
15pub struct SyncServer {
16 db: Arc<Database>,
17 nats_url: String,
18 tenant_id: String,
19 policies: ConflictPolicies,
20 chunk_buffer:
21 std::sync::Mutex<HashMap<uuid::Uuid, (Instant, Vec<crate::protocol::ChunkMessage>)>>,
22}
23
24impl SyncServer {
25 pub fn new(
26 db: Arc<Database>,
27 nats_url: &str,
28 tenant_id: &str,
29 policies: ConflictPolicies,
30 ) -> Self {
31 assert!(
32 !tenant_id.is_empty()
33 && tenant_id
34 .chars()
35 .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
36 "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
37 );
38 Self {
39 db,
40 nats_url: nats_url.to_string(),
41 tenant_id: tenant_id.to_string(),
42 policies,
43 chunk_buffer: std::sync::Mutex::new(HashMap::new()),
44 }
45 }
46
47 pub fn db(&self) -> &Database {
48 &self.db
49 }
50
51 pub async fn run(&self) {
52 let client = loop {
53 match async_nats::connect(&self.nats_url).await {
54 Ok(client) => break client,
55 Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
56 }
57 };
58
59 let (mut push_sub, mut pull_sub) = loop {
60 match self.bootstrap_subscriptions(&client).await {
61 Ok(subscriptions) => break subscriptions,
62 Err(err) => {
63 tracing::error!(
64 error = %err,
65 tenant_id = %self.tenant_id,
66 "sync server bootstrap failed; retrying"
67 );
68 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
69 }
70 }
71 };
72 let mut cleanup_interval = tokio::time::interval(std::time::Duration::from_secs(10));
73
74 loop {
75 tokio::select! {
76 maybe_msg = push_sub.next() => {
77 if let Some(msg) = maybe_msg
78 && let Err(e) = self.handle_push(&client, msg).await {
79 tracing::error!(error = %e, "handle_push failed");
80 }
81 }
82 maybe_msg = pull_sub.next() => {
83 if let Some(msg) = maybe_msg
84 && let Err(e) = self.handle_pull(&client, msg).await {
85 tracing::error!(error = %e, "handle_pull failed");
86 }
87 }
88 _ = cleanup_interval.tick() => {
89 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
90 let before = buf.len();
91 buf.retain(|id, (ts, _chunks)| {
92 let keep = ts.elapsed() < std::time::Duration::from_secs(30);
93 if !keep {
94 tracing::warn!(%id, "evicting stale chunk session (30s TTL expired)");
95 }
96 keep
97 });
98 if buf.len() < before {
99 tracing::info!(evicted = before - buf.len(), remaining = buf.len(), "chunk buffer cleanup");
100 }
101 }
102 }
103 }
104 }
105
106 async fn bootstrap_subscriptions(
107 &self,
108 client: &async_nats::Client,
109 ) -> contextdb_core::Result<(async_nats::Subscriber, async_nats::Subscriber)> {
110 let push_sub = client
111 .subscribe(push_subject(&self.tenant_id))
112 .await
113 .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe push: {e}")))?;
114 let pull_sub = client
115 .subscribe(pull_subject(&self.tenant_id))
116 .await
117 .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe pull: {e}")))?;
118 client
119 .flush()
120 .await
121 .map_err(|e| contextdb_core::Error::SyncError(format!("flush subscriptions: {e}")))?;
122 Ok((push_sub, pull_sub))
123 }
124
125 async fn handle_push(
126 &self,
127 client: &async_nats::Client,
128 msg: async_nats::Message,
129 ) -> contextdb_core::Result<()> {
130 let envelope =
131 decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
132
133 match envelope.message_type {
134 MessageType::Chunk => {
135 let chunk_msg: crate::protocol::ChunkMessage =
136 rmp_serde::from_slice(&envelope.payload)
137 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
138 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
139 if !buf.contains_key(&chunk_msg.chunk_id) && buf.len() >= MAX_CHUNK_SESSIONS {
140 tracing::warn!(
141 chunk_id = %chunk_msg.chunk_id,
142 active_sessions = buf.len(),
143 "chunk buffer full, rejecting new chunk session"
144 );
145 return Err(contextdb_core::Error::SyncError(
146 "chunk buffer full".to_string(),
147 ));
148 }
149 buf.entry(chunk_msg.chunk_id)
150 .or_insert_with(|| (std::time::Instant::now(), Vec::new()))
151 .1
152 .push(chunk_msg);
153 Ok(())
154 }
155 MessageType::ChunkAck => {
156 let ack: crate::protocol::ChunkAck = rmp_serde::from_slice(&envelope.payload)
157 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
158
159 tracing::info!(
160 chunk_id = %ack.chunk_id,
161 total_chunks = ack.total_chunks,
162 "received ChunkAck, attempting reassembly"
163 );
164
165 let process_result: contextdb_core::Result<Vec<u8>> = (|| {
166 let mut chunks = {
167 let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
168 let (_ts, chunks) = buf.remove(&ack.chunk_id).ok_or_else(|| {
169 contextdb_core::Error::SyncError(format!(
170 "no chunks buffered for chunk_id {}",
171 ack.chunk_id
172 ))
173 })?;
174 chunks
175 };
176
177 if chunks.len() != ack.total_chunks as usize {
178 return Err(contextdb_core::Error::SyncError(format!(
179 "expected {} chunks for {}, got {}",
180 ack.total_chunks,
181 ack.chunk_id,
182 chunks.len()
183 )));
184 }
185
186 let assembled = crate::chunking::reassemble(&mut chunks);
187 let inner_envelope = decode(&assembled)
188 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
189 let request: PushRequest = rmp_serde::from_slice(&inner_envelope.payload)
190 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
191
192 match self
193 .db
194 .apply_changes(request.changeset.into(), &self.policies)
195 {
196 Ok(result) => {
197 let response = PushResponse {
198 result: Some(result.into()),
199 error: None,
200 };
201 encode(MessageType::PushResponse, &response)
202 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
203 }
204 Err(err) => {
205 let response = PushResponse {
206 result: None,
207 error: Some(err.to_string()),
208 };
209 encode(MessageType::PushResponse, &response)
210 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
211 }
212 }
213 })();
214
215 match process_result {
216 Ok(payload) => {
217 client
218 .publish(ack.reply_inbox, payload.into())
219 .await
220 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
221 Ok(())
222 }
223 Err(e) => {
224 tracing::error!(
225 chunk_id = %ack.chunk_id,
226 error = %e,
227 "chunked push processing failed, client will timeout"
228 );
229 Err(e)
230 }
231 }
232 }
233 MessageType::PushRequest => {
234 let request: PushRequest = rmp_serde::from_slice(&envelope.payload)
235 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
236 let response = match self
237 .db
238 .apply_changes(request.changeset.into(), &self.policies)
239 {
240 Ok(result) => PushResponse {
241 result: Some(result.into()),
242 error: None,
243 },
244 Err(err) => PushResponse {
245 result: None,
246 error: Some(err.to_string()),
247 },
248 };
249 let payload = encode(MessageType::PushResponse, &response)
250 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
251
252 if let Some(reply) = msg.reply {
253 client
254 .publish(reply, payload.into())
255 .await
256 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
257 }
258 Ok(())
259 }
260 _ => Err(contextdb_core::Error::SyncError(
261 "unexpected message type on push subject".to_string(),
262 )),
263 }
264 }
265
266 async fn handle_pull(
267 &self,
268 client: &async_nats::Client,
269 msg: async_nats::Message,
270 ) -> contextdb_core::Result<()> {
271 let envelope =
272 decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
273 if !matches!(envelope.message_type, MessageType::PullRequest) {
274 return Err(contextdb_core::Error::SyncError(
275 "unexpected message type on pull subject".to_string(),
276 ));
277 }
278
279 let request: PullRequest = rmp_serde::from_slice(&envelope.payload)
280 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
281 let mut changes = self.db.changes_since(request.since_lsn);
282
283 let mut has_more = false;
284 let mut cursor = None;
285 if let Some(max_entries) = request.max_entries {
286 let max = max_entries as usize;
287 if changes.rows.len() > max {
288 let mut remainder = changes.rows.split_off(max);
289 if let (Some(last_returned), Some(first_remainder)) =
292 (changes.rows.last(), remainder.first())
293 && last_returned.lsn == first_remainder.lsn
294 {
295 let boundary_lsn = last_returned.lsn;
296 let split_idx = remainder.partition_point(|r| r.lsn == boundary_lsn);
297 let moved: Vec<_> = remainder.drain(..split_idx).collect();
298 changes.rows.extend(moved);
299 }
300 cursor = changes.rows.last().map(|r| r.lsn);
302 has_more = !remainder.is_empty();
303 }
304 }
305
306 let response = PullResponse {
307 changeset: changes.into(),
308 has_more,
309 cursor,
310 };
311 let payload = encode(MessageType::PullResponse, &response)
312 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
313
314 if let Some(reply) = msg.reply {
315 if crate::chunking::needs_chunking(&payload) {
316 let chunks = crate::chunking::chunk(&payload);
317 tracing::info!(
318 total_chunks = chunks.len(),
319 payload_size = payload.len(),
320 "sending chunked pull response"
321 );
322 for chunk_msg in &chunks {
323 let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
324 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
325 client
326 .publish(reply.clone(), chunk_encoded.into())
327 .await
328 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
329 }
330 client
331 .flush()
332 .await
333 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
334 } else {
335 client
336 .publish(reply, payload.into())
337 .await
338 .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
339 }
340 }
341 Ok(())
342 }
343}