1use crate::error::SyncError;
2use contextdb_core::{Lsn, RowId, Value, VectorIndexRef};
3use contextdb_engine::sync_types::{
4 ApplyResult, ChangeSet, Conflict, DdlChange, EdgeChange, NaturalKey, RowChange, VectorChange,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Wire protocol version written into every envelope by `encode`.
///
/// `decode` accepts only envelopes whose version equals this value exactly;
/// anything else is reported as `SyncError::ProtocolVersionMismatch`.
pub const PROTOCOL_VERSION: u8 = 2;
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
12pub struct Envelope {
13 pub version: u8,
14 pub message_type: MessageType,
15 pub payload: Vec<u8>,
16}
17
18impl Envelope {
19 pub fn default_pull_request() -> Self {
22 Self {
23 version: PROTOCOL_VERSION,
24 message_type: MessageType::PullRequest,
25 payload: Vec::new(),
26 }
27 }
28}
29
30#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub enum MessageType {
32 PushRequest,
33 PushResponse,
34 #[default]
35 PullRequest,
36 PullResponse,
37 Chunk,
38 ChunkAck,
39}
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct PushRequest {
43 pub changeset: WireChangeSet,
44}
45
46#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
47pub struct PushResponse {
48 pub result: Option<WireApplyResult>,
49 pub error: Option<String>,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
53pub struct PullRequest {
54 pub since_lsn: Lsn,
55 pub max_entries: Option<u32>,
56}
57
58impl<'de> serde::Deserialize<'de> for PullRequest {
59 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
60 where
61 D: serde::Deserializer<'de>,
62 {
63 use serde::de::{SeqAccess, Visitor};
64
65 struct PullRequestVisitor;
66
67 impl<'de> Visitor<'de> for PullRequestVisitor {
68 type Value = PullRequest;
69
70 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.write_str("PullRequest with 1 or 2 elements")
72 }
73
74 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
75 where
76 A: SeqAccess<'de>,
77 {
78 let since_lsn: Lsn = seq
79 .next_element()?
80 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
81 let max_entries: Option<u32> = seq.next_element()?.unwrap_or(None);
82 Ok(PullRequest {
83 since_lsn,
84 max_entries,
85 })
86 }
87 }
88
89 deserializer.deserialize_tuple(2, PullRequestVisitor)
90 }
91}
92
93#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
94pub struct PullResponse {
95 pub changeset: WireChangeSet,
96 pub has_more: bool,
97 pub cursor: Option<Lsn>,
98}
99
100#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
101pub struct ChunkMessage {
102 pub chunk_id: uuid::Uuid,
103 pub sequence: u32,
104 pub total_chunks: u32,
105 #[serde(with = "serde_bytes")]
106 pub payload: Vec<u8>,
107}
108
109#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
110pub struct ChunkAck {
111 pub chunk_id: uuid::Uuid,
112 pub total_chunks: u32,
113 pub reply_inbox: String,
114}
115
116#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
117pub struct WireChangeSet {
118 pub ddl: Vec<WireDdlChange>,
119 pub rows: Vec<WireRowChange>,
120 pub edges: Vec<WireEdgeChange>,
121 pub vectors: Vec<WireVectorChange>,
122}
123
124#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
125pub struct WireRowChange {
126 pub table: String,
127 pub natural_key: WireNaturalKey,
128 pub values: HashMap<String, Value>,
129 #[serde(default)]
130 pub deleted: bool,
131 pub lsn: Lsn,
132}
133
134#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
135pub struct WireEdgeChange {
136 pub source: uuid::Uuid,
137 pub target: uuid::Uuid,
138 pub edge_type: String,
139 pub properties: HashMap<String, Value>,
140 pub lsn: Lsn,
141}
142
143#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
144pub struct WireVectorChange {
145 pub index: VectorIndexRef,
146 pub row_id: RowId,
147 pub vector: Vec<f32>,
148 pub lsn: Lsn,
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
152pub enum WireDdlChange {
153 CreateTable {
154 name: String,
155 columns: Vec<(String, String)>,
156 constraints: Vec<String>,
157 },
158 DropTable {
159 name: String,
160 },
161 AlterTable {
162 name: String,
163 columns: Vec<(String, String)>,
164 constraints: Vec<String>,
165 },
166 CreateIndex {
169 table: String,
170 name: String,
171 columns: Vec<(String, String)>,
172 },
173 DropIndex {
174 table: String,
175 name: String,
176 },
177}
178
179#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
180pub struct WireNaturalKey {
181 pub column: String,
182 pub value: Value,
183}
184
185impl Default for WireNaturalKey {
186 fn default() -> Self {
187 Self {
188 column: String::new(),
189 value: Value::Null,
190 }
191 }
192}
193
194#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
195pub struct WireApplyResult {
196 pub applied_rows: usize,
197 pub skipped_rows: usize,
198 pub conflicts: Vec<WireConflict>,
199 pub new_lsn: Lsn,
200}
201
202#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
203pub struct WireConflict {
204 pub natural_key: WireNaturalKey,
205 pub resolution: String,
206 pub reason: Option<String>,
207}
208
209pub fn encode<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>, SyncError> {
210 let payload = rmp_serde::to_vec(msg).map_err(|e| SyncError::Serde(e.to_string()))?;
211 let envelope = Envelope {
212 version: PROTOCOL_VERSION,
213 message_type: msg_type,
214 payload,
215 };
216 rmp_serde::to_vec(&envelope).map_err(|e| SyncError::Serde(e.to_string()))
217}
218
219pub fn decode(data: &[u8]) -> Result<Envelope, SyncError> {
220 let envelope: Envelope =
221 rmp_serde::from_slice(data).map_err(|e| SyncError::Serde(e.to_string()))?;
222 if envelope.version != PROTOCOL_VERSION {
223 return Err(SyncError::ProtocolVersionMismatch {
224 received: envelope.version,
225 supported: PROTOCOL_VERSION,
226 });
227 }
228 Ok(envelope)
229}
230
231impl From<ChangeSet> for WireChangeSet {
232 fn from(value: ChangeSet) -> Self {
233 Self {
234 ddl: value.ddl.into_iter().map(Into::into).collect(),
235 rows: value.rows.into_iter().map(Into::into).collect(),
236 edges: value.edges.into_iter().map(Into::into).collect(),
237 vectors: value.vectors.into_iter().map(Into::into).collect(),
238 }
239 }
240}
241
242impl From<WireChangeSet> for ChangeSet {
243 fn from(value: WireChangeSet) -> Self {
244 Self {
245 ddl: value.ddl.into_iter().map(Into::into).collect(),
246 rows: value.rows.into_iter().map(Into::into).collect(),
247 edges: value.edges.into_iter().map(Into::into).collect(),
248 vectors: value.vectors.into_iter().map(Into::into).collect(),
249 }
250 }
251}
252
253impl From<RowChange> for WireRowChange {
254 fn from(value: RowChange) -> Self {
255 Self {
256 table: value.table,
257 natural_key: value.natural_key.into(),
258 values: value.values,
259 deleted: value.deleted,
260 lsn: value.lsn,
261 }
262 }
263}
264
265impl From<WireRowChange> for RowChange {
266 fn from(value: WireRowChange) -> Self {
267 Self {
268 table: value.table,
269 natural_key: value.natural_key.into(),
270 values: value.values,
271 deleted: value.deleted,
272 lsn: value.lsn,
273 }
274 }
275}
276
277impl From<EdgeChange> for WireEdgeChange {
278 fn from(value: EdgeChange) -> Self {
279 Self {
280 source: value.source,
281 target: value.target,
282 edge_type: value.edge_type,
283 properties: value.properties,
284 lsn: value.lsn,
285 }
286 }
287}
288
289impl From<WireEdgeChange> for EdgeChange {
290 fn from(value: WireEdgeChange) -> Self {
291 Self {
292 source: value.source,
293 target: value.target,
294 edge_type: value.edge_type,
295 properties: value.properties,
296 lsn: value.lsn,
297 }
298 }
299}
300
301impl From<VectorChange> for WireVectorChange {
302 fn from(value: VectorChange) -> Self {
303 Self {
304 index: value.index,
305 row_id: value.row_id,
306 vector: value.vector,
307 lsn: value.lsn,
308 }
309 }
310}
311
312impl From<WireVectorChange> for VectorChange {
313 fn from(value: WireVectorChange) -> Self {
314 Self {
315 index: value.index,
316 row_id: value.row_id,
317 vector: value.vector,
318 lsn: value.lsn,
319 }
320 }
321}
322
323impl From<DdlChange> for WireDdlChange {
324 fn from(value: DdlChange) -> Self {
325 match value {
326 DdlChange::CreateTable {
327 name,
328 columns,
329 constraints,
330 } => Self::CreateTable {
331 name,
332 columns,
333 constraints,
334 },
335 DdlChange::DropTable { name } => Self::DropTable { name },
336 DdlChange::AlterTable {
337 name,
338 columns,
339 constraints,
340 } => Self::AlterTable {
341 name,
342 columns,
343 constraints,
344 },
345 DdlChange::CreateIndex {
346 table,
347 name,
348 columns,
349 } => {
350 let wire_cols = columns
351 .into_iter()
352 .map(|(c, dir)| {
353 let dir_str = match dir {
354 contextdb_core::SortDirection::Asc => "ASC".to_string(),
355 contextdb_core::SortDirection::Desc => "DESC".to_string(),
356 };
357 (c, dir_str)
358 })
359 .collect();
360 Self::CreateIndex {
361 table,
362 name,
363 columns: wire_cols,
364 }
365 }
366 DdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
367 }
368 }
369}
370
371impl From<WireDdlChange> for DdlChange {
372 fn from(value: WireDdlChange) -> Self {
373 match value {
374 WireDdlChange::CreateTable {
375 name,
376 columns,
377 constraints,
378 } => Self::CreateTable {
379 name,
380 columns,
381 constraints,
382 },
383 WireDdlChange::DropTable { name } => Self::DropTable { name },
384 WireDdlChange::AlterTable {
385 name,
386 columns,
387 constraints,
388 } => Self::AlterTable {
389 name,
390 columns,
391 constraints,
392 },
393 WireDdlChange::CreateIndex {
394 table,
395 name,
396 columns,
397 } => {
398 let engine_cols = columns
399 .into_iter()
400 .map(|(c, dir_str)| {
401 let dir = if dir_str.eq_ignore_ascii_case("DESC") {
402 contextdb_core::SortDirection::Desc
403 } else {
404 contextdb_core::SortDirection::Asc
405 };
406 (c, dir)
407 })
408 .collect();
409 Self::CreateIndex {
410 table,
411 name,
412 columns: engine_cols,
413 }
414 }
415 WireDdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
416 }
417 }
418}
419
420impl From<NaturalKey> for WireNaturalKey {
421 fn from(value: NaturalKey) -> Self {
422 Self {
423 column: value.column,
424 value: value.value,
425 }
426 }
427}
428
429impl From<WireNaturalKey> for NaturalKey {
430 fn from(value: WireNaturalKey) -> Self {
431 Self {
432 column: value.column,
433 value: value.value,
434 }
435 }
436}
437
438impl From<ApplyResult> for WireApplyResult {
439 fn from(value: ApplyResult) -> Self {
440 Self {
441 applied_rows: value.applied_rows,
442 skipped_rows: value.skipped_rows,
443 conflicts: value.conflicts.into_iter().map(Into::into).collect(),
444 new_lsn: value.new_lsn,
445 }
446 }
447}
448
449impl From<WireApplyResult> for ApplyResult {
450 fn from(value: WireApplyResult) -> Self {
451 Self {
452 applied_rows: value.applied_rows,
453 skipped_rows: value.skipped_rows,
454 conflicts: value.conflicts.into_iter().map(Into::into).collect(),
455 new_lsn: value.new_lsn,
456 }
457 }
458}
459
460impl From<Conflict> for WireConflict {
461 fn from(value: Conflict) -> Self {
462 Self {
463 natural_key: value.natural_key.into(),
464 resolution: format!("{:?}", value.resolution),
465 reason: value.reason,
466 }
467 }
468}
469
470impl From<WireConflict> for Conflict {
471 fn from(value: WireConflict) -> Self {
472 let resolution = match value.resolution.as_str() {
473 "InsertIfNotExists" => contextdb_engine::sync_types::ConflictPolicy::InsertIfNotExists,
474 "ServerWins" => contextdb_engine::sync_types::ConflictPolicy::ServerWins,
475 "EdgeWins" => contextdb_engine::sync_types::ConflictPolicy::EdgeWins,
476 _ => contextdb_engine::sync_types::ConflictPolicy::LatestWins,
477 };
478 Self {
479 natural_key: value.natural_key.into(),
480 resolution,
481 reason: value.reason,
482 }
483 }
484}