Skip to main content

contextdb_server/
sync_server.rs

1use crate::protocol::{
2    MessageType, PullRequest, PullResponse, PushRequest, PushResponse, decode, encode,
3};
4use crate::subjects::{pull_subject, push_subject};
5use contextdb_engine::Database;
6use contextdb_engine::sync_types::ConflictPolicies;
7use futures_util::StreamExt;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12/// Max concurrent in-flight chunked push sessions. Rejects new chunk_ids past this limit.
13const MAX_CHUNK_SESSIONS: usize = 64;
14
15pub struct SyncServer {
16    db: Arc<Database>,
17    nats_url: String,
18    tenant_id: String,
19    policies: ConflictPolicies,
20    chunk_buffer:
21        std::sync::Mutex<HashMap<uuid::Uuid, (Instant, Vec<crate::protocol::ChunkMessage>)>>,
22}
23
24impl SyncServer {
25    pub fn new(
26        db: Arc<Database>,
27        nats_url: &str,
28        tenant_id: &str,
29        policies: ConflictPolicies,
30    ) -> Self {
31        assert!(
32            !tenant_id.is_empty()
33                && tenant_id
34                    .chars()
35                    .all(|c| c.is_alphanumeric() || c == '-' || c == '_'),
36            "tenant_id must be non-empty and alphanumeric (hyphens and underscores allowed): {tenant_id}"
37        );
38        Self {
39            db,
40            nats_url: nats_url.to_string(),
41            tenant_id: tenant_id.to_string(),
42            policies,
43            chunk_buffer: std::sync::Mutex::new(HashMap::new()),
44        }
45    }
46
47    pub fn db(&self) -> &Database {
48        &self.db
49    }
50
51    pub async fn run(&self) {
52        let client = loop {
53            match async_nats::connect(&self.nats_url).await {
54                Ok(client) => break client,
55                Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
56            }
57        };
58
59        let (mut push_sub, mut pull_sub) = loop {
60            match self.bootstrap_subscriptions(&client).await {
61                Ok(subscriptions) => break subscriptions,
62                Err(err) => {
63                    tracing::error!(
64                        error = %err,
65                        tenant_id = %self.tenant_id,
66                        "sync server bootstrap failed; retrying"
67                    );
68                    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
69                }
70            }
71        };
72        let mut cleanup_interval = tokio::time::interval(std::time::Duration::from_secs(10));
73
74        loop {
75            tokio::select! {
76                maybe_msg = push_sub.next() => {
77                    if let Some(msg) = maybe_msg
78                        && let Err(e) = self.handle_push(&client, msg).await {
79                        tracing::error!(error = %e, "handle_push failed");
80                    }
81                }
82                maybe_msg = pull_sub.next() => {
83                    if let Some(msg) = maybe_msg
84                        && let Err(e) = self.handle_pull(&client, msg).await {
85                        tracing::error!(error = %e, "handle_pull failed");
86                    }
87                }
88                _ = cleanup_interval.tick() => {
89                    let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
90                    let before = buf.len();
91                    buf.retain(|id, (ts, _chunks)| {
92                        let keep = ts.elapsed() < std::time::Duration::from_secs(30);
93                        if !keep {
94                            tracing::warn!(%id, "evicting stale chunk session (30s TTL expired)");
95                        }
96                        keep
97                    });
98                    if buf.len() < before {
99                        tracing::info!(evicted = before - buf.len(), remaining = buf.len(), "chunk buffer cleanup");
100                    }
101                }
102            }
103        }
104    }
105
106    async fn bootstrap_subscriptions(
107        &self,
108        client: &async_nats::Client,
109    ) -> contextdb_core::Result<(async_nats::Subscriber, async_nats::Subscriber)> {
110        let push_sub = client
111            .subscribe(push_subject(&self.tenant_id))
112            .await
113            .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe push: {e}")))?;
114        let pull_sub = client
115            .subscribe(pull_subject(&self.tenant_id))
116            .await
117            .map_err(|e| contextdb_core::Error::SyncError(format!("subscribe pull: {e}")))?;
118        client
119            .flush()
120            .await
121            .map_err(|e| contextdb_core::Error::SyncError(format!("flush subscriptions: {e}")))?;
122        Ok((push_sub, pull_sub))
123    }
124
125    async fn handle_push(
126        &self,
127        client: &async_nats::Client,
128        msg: async_nats::Message,
129    ) -> contextdb_core::Result<()> {
130        let envelope =
131            decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
132
133        match envelope.message_type {
134            MessageType::Chunk => {
135                let chunk_msg: crate::protocol::ChunkMessage =
136                    rmp_serde::from_slice(&envelope.payload)
137                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
138                let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
139                if !buf.contains_key(&chunk_msg.chunk_id) && buf.len() >= MAX_CHUNK_SESSIONS {
140                    tracing::warn!(
141                        chunk_id = %chunk_msg.chunk_id,
142                        active_sessions = buf.len(),
143                        "chunk buffer full, rejecting new chunk session"
144                    );
145                    return Err(contextdb_core::Error::SyncError(
146                        "chunk buffer full".to_string(),
147                    ));
148                }
149                buf.entry(chunk_msg.chunk_id)
150                    .or_insert_with(|| (std::time::Instant::now(), Vec::new()))
151                    .1
152                    .push(chunk_msg);
153                Ok(())
154            }
155            MessageType::ChunkAck => {
156                let ack: crate::protocol::ChunkAck = rmp_serde::from_slice(&envelope.payload)
157                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
158
159                tracing::info!(
160                    chunk_id = %ack.chunk_id,
161                    total_chunks = ack.total_chunks,
162                    "received ChunkAck, attempting reassembly"
163                );
164
165                let process_result: contextdb_core::Result<Vec<u8>> = (|| {
166                    let mut chunks = {
167                        let mut buf = self.chunk_buffer.lock().unwrap_or_else(|e| e.into_inner());
168                        let (_ts, chunks) = buf.remove(&ack.chunk_id).ok_or_else(|| {
169                            contextdb_core::Error::SyncError(format!(
170                                "no chunks buffered for chunk_id {}",
171                                ack.chunk_id
172                            ))
173                        })?;
174                        chunks
175                    };
176
177                    if chunks.len() != ack.total_chunks as usize {
178                        return Err(contextdb_core::Error::SyncError(format!(
179                            "expected {} chunks for {}, got {}",
180                            ack.total_chunks,
181                            ack.chunk_id,
182                            chunks.len()
183                        )));
184                    }
185
186                    let assembled = crate::chunking::reassemble(&mut chunks);
187                    let inner_envelope = decode(&assembled)
188                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
189                    let request: PushRequest = rmp_serde::from_slice(&inner_envelope.payload)
190                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
191
192                    match self
193                        .db
194                        .apply_changes(request.changeset.into(), &self.policies)
195                    {
196                        Ok(result) => {
197                            let response = PushResponse {
198                                result: Some(result.into()),
199                                error: None,
200                            };
201                            encode(MessageType::PushResponse, &response)
202                                .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
203                        }
204                        Err(err) => {
205                            let response = PushResponse {
206                                result: None,
207                                error: Some(err.to_string()),
208                            };
209                            encode(MessageType::PushResponse, &response)
210                                .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))
211                        }
212                    }
213                })();
214
215                match process_result {
216                    Ok(payload) => {
217                        client
218                            .publish(ack.reply_inbox, payload.into())
219                            .await
220                            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
221                        Ok(())
222                    }
223                    Err(e) => {
224                        tracing::error!(
225                            chunk_id = %ack.chunk_id,
226                            error = %e,
227                            "chunked push processing failed, client will timeout"
228                        );
229                        Err(e)
230                    }
231                }
232            }
233            MessageType::PushRequest => {
234                let request: PushRequest = rmp_serde::from_slice(&envelope.payload)
235                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
236                let response = match self
237                    .db
238                    .apply_changes(request.changeset.into(), &self.policies)
239                {
240                    Ok(result) => PushResponse {
241                        result: Some(result.into()),
242                        error: None,
243                    },
244                    Err(err) => PushResponse {
245                        result: None,
246                        error: Some(err.to_string()),
247                    },
248                };
249                let payload = encode(MessageType::PushResponse, &response)
250                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
251
252                if let Some(reply) = msg.reply {
253                    client
254                        .publish(reply, payload.into())
255                        .await
256                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
257                }
258                Ok(())
259            }
260            _ => Err(contextdb_core::Error::SyncError(
261                "unexpected message type on push subject".to_string(),
262            )),
263        }
264    }
265
266    async fn handle_pull(
267        &self,
268        client: &async_nats::Client,
269        msg: async_nats::Message,
270    ) -> contextdb_core::Result<()> {
271        let envelope =
272            decode(&msg.payload).map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
273        if !matches!(envelope.message_type, MessageType::PullRequest) {
274            return Err(contextdb_core::Error::SyncError(
275                "unexpected message type on pull subject".to_string(),
276            ));
277        }
278
279        let request: PullRequest = rmp_serde::from_slice(&envelope.payload)
280            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
281        let mut changes = self.db.changes_since(request.since_lsn);
282
283        let mut has_more = false;
284        let mut cursor = None;
285        if let Some(max_entries) = request.max_entries {
286            let max = max_entries as usize;
287            if changes.rows.len() > max {
288                let mut remainder = changes.rows.split_off(max);
289                // Don't split in the middle of a same-LSN group: extend returned set
290                // to include all rows sharing the LSN at the split boundary.
291                if let (Some(last_returned), Some(first_remainder)) =
292                    (changes.rows.last(), remainder.first())
293                    && last_returned.lsn == first_remainder.lsn
294                {
295                    let boundary_lsn = last_returned.lsn;
296                    let split_idx = remainder.partition_point(|r| r.lsn == boundary_lsn);
297                    let moved: Vec<_> = remainder.drain(..split_idx).collect();
298                    changes.rows.extend(moved);
299                }
300                // Cursor = last LSN of returned set (not remainder)
301                cursor = changes.rows.last().map(|r| r.lsn);
302                has_more = !remainder.is_empty();
303            }
304        }
305
306        let response = PullResponse {
307            changeset: changes.into(),
308            has_more,
309            cursor,
310        };
311        let payload = encode(MessageType::PullResponse, &response)
312            .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
313
314        if let Some(reply) = msg.reply {
315            if crate::chunking::needs_chunking(&payload) {
316                let chunks = crate::chunking::chunk(&payload);
317                tracing::info!(
318                    total_chunks = chunks.len(),
319                    payload_size = payload.len(),
320                    "sending chunked pull response"
321                );
322                for chunk_msg in &chunks {
323                    let chunk_encoded = encode(MessageType::Chunk, chunk_msg)
324                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
325                    client
326                        .publish(reply.clone(), chunk_encoded.into())
327                        .await
328                        .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
329                }
330                client
331                    .flush()
332                    .await
333                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
334            } else {
335                client
336                    .publish(reply, payload.into())
337                    .await
338                    .map_err(|e| contextdb_core::Error::SyncError(e.to_string()))?;
339            }
340        }
341        Ok(())
342    }
343}