1use crate::error::SyncError;
2use contextdb_core::{Lsn, RowId, Value};
3use contextdb_engine::sync_types::{
4 ApplyResult, ChangeSet, Conflict, DdlChange, EdgeChange, NaturalKey, RowChange, VectorChange,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Wire-protocol version stamped into every envelope produced by `encode`.
///
/// `decode` rejects envelopes whose `version` is greater than this value.
const PROTOCOL_VERSION: u8 = 1;
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
12pub struct Envelope {
13 pub version: u8,
14 pub message_type: MessageType,
15 pub payload: Vec<u8>,
16}
17
18impl Envelope {
19 pub fn default_pull_request() -> Self {
22 Self {
23 version: PROTOCOL_VERSION,
24 message_type: MessageType::PullRequest,
25 payload: Vec::new(),
26 }
27 }
28}
29
30#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub enum MessageType {
32 PushRequest,
33 PushResponse,
34 #[default]
35 PullRequest,
36 PullResponse,
37 Chunk,
38 ChunkAck,
39}
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct PushRequest {
43 pub changeset: WireChangeSet,
44}
45
46#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
47pub struct PushResponse {
48 pub result: Option<WireApplyResult>,
49 pub error: Option<String>,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
53pub struct PullRequest {
54 pub since_lsn: Lsn,
55 pub max_entries: Option<u32>,
56}
57
58impl<'de> serde::Deserialize<'de> for PullRequest {
59 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
60 where
61 D: serde::Deserializer<'de>,
62 {
63 use serde::de::{SeqAccess, Visitor};
64
65 struct PullRequestVisitor;
66
67 impl<'de> Visitor<'de> for PullRequestVisitor {
68 type Value = PullRequest;
69
70 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.write_str("PullRequest with 1 or 2 elements")
72 }
73
74 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
75 where
76 A: SeqAccess<'de>,
77 {
78 let since_lsn: Lsn = seq
79 .next_element()?
80 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
81 let max_entries: Option<u32> = seq.next_element()?.unwrap_or(None);
82 Ok(PullRequest {
83 since_lsn,
84 max_entries,
85 })
86 }
87 }
88
89 deserializer.deserialize_tuple(2, PullRequestVisitor)
90 }
91}
92
93#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
94pub struct PullResponse {
95 pub changeset: WireChangeSet,
96 pub has_more: bool,
97 pub cursor: Option<Lsn>,
98}
99
100#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
101pub struct ChunkMessage {
102 pub chunk_id: uuid::Uuid,
103 pub sequence: u32,
104 pub total_chunks: u32,
105 #[serde(with = "serde_bytes")]
106 pub payload: Vec<u8>,
107}
108
109#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
110pub struct ChunkAck {
111 pub chunk_id: uuid::Uuid,
112 pub total_chunks: u32,
113 pub reply_inbox: String,
114}
115
116#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
117pub struct WireChangeSet {
118 pub rows: Vec<WireRowChange>,
119 pub edges: Vec<WireEdgeChange>,
120 pub vectors: Vec<WireVectorChange>,
121 pub ddl: Vec<WireDdlChange>,
122}
123
124#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
125pub struct WireRowChange {
126 pub table: String,
127 pub natural_key: WireNaturalKey,
128 pub values: HashMap<String, Value>,
129 #[serde(default)]
130 pub deleted: bool,
131 pub lsn: Lsn,
132}
133
134#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
135pub struct WireEdgeChange {
136 pub source: uuid::Uuid,
137 pub target: uuid::Uuid,
138 pub edge_type: String,
139 pub properties: HashMap<String, Value>,
140 pub lsn: Lsn,
141}
142
143#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
144pub struct WireVectorChange {
145 pub row_id: RowId,
146 pub vector: Vec<f32>,
147 pub lsn: Lsn,
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
151pub enum WireDdlChange {
152 CreateTable {
153 name: String,
154 columns: Vec<(String, String)>,
155 constraints: Vec<String>,
156 },
157 DropTable {
158 name: String,
159 },
160 AlterTable {
161 name: String,
162 columns: Vec<(String, String)>,
163 constraints: Vec<String>,
164 },
165 CreateIndex {
168 table: String,
169 name: String,
170 columns: Vec<(String, String)>,
171 },
172 DropIndex {
173 table: String,
174 name: String,
175 },
176}
177
178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub struct WireNaturalKey {
180 pub column: String,
181 pub value: Value,
182}
183
184impl Default for WireNaturalKey {
185 fn default() -> Self {
186 Self {
187 column: String::new(),
188 value: Value::Null,
189 }
190 }
191}
192
193#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
194pub struct WireApplyResult {
195 pub applied_rows: usize,
196 pub skipped_rows: usize,
197 pub conflicts: Vec<WireConflict>,
198 pub new_lsn: Lsn,
199}
200
201#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
202pub struct WireConflict {
203 pub natural_key: WireNaturalKey,
204 pub resolution: String,
205 pub reason: Option<String>,
206}
207
208pub fn encode<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>, SyncError> {
209 let payload = rmp_serde::to_vec(msg).map_err(|e| SyncError::Serde(e.to_string()))?;
210 let envelope = Envelope {
211 version: PROTOCOL_VERSION,
212 message_type: msg_type,
213 payload,
214 };
215 rmp_serde::to_vec(&envelope).map_err(|e| SyncError::Serde(e.to_string()))
216}
217
218pub fn decode(data: &[u8]) -> Result<Envelope, SyncError> {
219 let envelope: Envelope =
220 rmp_serde::from_slice(data).map_err(|e| SyncError::Serde(e.to_string()))?;
221 if envelope.version > PROTOCOL_VERSION {
222 return Err(SyncError::UnsupportedVersion(envelope.version));
223 }
224 Ok(envelope)
225}
226
227impl From<ChangeSet> for WireChangeSet {
228 fn from(value: ChangeSet) -> Self {
229 Self {
230 rows: value.rows.into_iter().map(Into::into).collect(),
231 edges: value.edges.into_iter().map(Into::into).collect(),
232 vectors: value.vectors.into_iter().map(Into::into).collect(),
233 ddl: value.ddl.into_iter().map(Into::into).collect(),
234 }
235 }
236}
237
238impl From<WireChangeSet> for ChangeSet {
239 fn from(value: WireChangeSet) -> Self {
240 Self {
241 rows: value.rows.into_iter().map(Into::into).collect(),
242 edges: value.edges.into_iter().map(Into::into).collect(),
243 vectors: value.vectors.into_iter().map(Into::into).collect(),
244 ddl: value.ddl.into_iter().map(Into::into).collect(),
245 }
246 }
247}
248
249impl From<RowChange> for WireRowChange {
250 fn from(value: RowChange) -> Self {
251 Self {
252 table: value.table,
253 natural_key: value.natural_key.into(),
254 values: value.values,
255 deleted: value.deleted,
256 lsn: value.lsn,
257 }
258 }
259}
260
261impl From<WireRowChange> for RowChange {
262 fn from(value: WireRowChange) -> Self {
263 Self {
264 table: value.table,
265 natural_key: value.natural_key.into(),
266 values: value.values,
267 deleted: value.deleted,
268 lsn: value.lsn,
269 }
270 }
271}
272
273impl From<EdgeChange> for WireEdgeChange {
274 fn from(value: EdgeChange) -> Self {
275 Self {
276 source: value.source,
277 target: value.target,
278 edge_type: value.edge_type,
279 properties: value.properties,
280 lsn: value.lsn,
281 }
282 }
283}
284
285impl From<WireEdgeChange> for EdgeChange {
286 fn from(value: WireEdgeChange) -> Self {
287 Self {
288 source: value.source,
289 target: value.target,
290 edge_type: value.edge_type,
291 properties: value.properties,
292 lsn: value.lsn,
293 }
294 }
295}
296
297impl From<VectorChange> for WireVectorChange {
298 fn from(value: VectorChange) -> Self {
299 Self {
300 row_id: value.row_id,
301 vector: value.vector,
302 lsn: value.lsn,
303 }
304 }
305}
306
307impl From<WireVectorChange> for VectorChange {
308 fn from(value: WireVectorChange) -> Self {
309 Self {
310 row_id: value.row_id,
311 vector: value.vector,
312 lsn: value.lsn,
313 }
314 }
315}
316
317impl From<DdlChange> for WireDdlChange {
318 fn from(value: DdlChange) -> Self {
319 match value {
320 DdlChange::CreateTable {
321 name,
322 columns,
323 constraints,
324 } => Self::CreateTable {
325 name,
326 columns,
327 constraints,
328 },
329 DdlChange::DropTable { name } => Self::DropTable { name },
330 DdlChange::AlterTable {
331 name,
332 columns,
333 constraints,
334 } => Self::AlterTable {
335 name,
336 columns,
337 constraints,
338 },
339 DdlChange::CreateIndex {
340 table,
341 name,
342 columns,
343 } => {
344 let wire_cols = columns
345 .into_iter()
346 .map(|(c, dir)| {
347 let dir_str = match dir {
348 contextdb_core::SortDirection::Asc => "ASC".to_string(),
349 contextdb_core::SortDirection::Desc => "DESC".to_string(),
350 };
351 (c, dir_str)
352 })
353 .collect();
354 Self::CreateIndex {
355 table,
356 name,
357 columns: wire_cols,
358 }
359 }
360 DdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
361 }
362 }
363}
364
365impl From<WireDdlChange> for DdlChange {
366 fn from(value: WireDdlChange) -> Self {
367 match value {
368 WireDdlChange::CreateTable {
369 name,
370 columns,
371 constraints,
372 } => Self::CreateTable {
373 name,
374 columns,
375 constraints,
376 },
377 WireDdlChange::DropTable { name } => Self::DropTable { name },
378 WireDdlChange::AlterTable {
379 name,
380 columns,
381 constraints,
382 } => Self::AlterTable {
383 name,
384 columns,
385 constraints,
386 },
387 WireDdlChange::CreateIndex {
388 table,
389 name,
390 columns,
391 } => {
392 let engine_cols = columns
393 .into_iter()
394 .map(|(c, dir_str)| {
395 let dir = if dir_str.eq_ignore_ascii_case("DESC") {
396 contextdb_core::SortDirection::Desc
397 } else {
398 contextdb_core::SortDirection::Asc
399 };
400 (c, dir)
401 })
402 .collect();
403 Self::CreateIndex {
404 table,
405 name,
406 columns: engine_cols,
407 }
408 }
409 WireDdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
410 }
411 }
412}
413
414impl From<NaturalKey> for WireNaturalKey {
415 fn from(value: NaturalKey) -> Self {
416 Self {
417 column: value.column,
418 value: value.value,
419 }
420 }
421}
422
423impl From<WireNaturalKey> for NaturalKey {
424 fn from(value: WireNaturalKey) -> Self {
425 Self {
426 column: value.column,
427 value: value.value,
428 }
429 }
430}
431
432impl From<ApplyResult> for WireApplyResult {
433 fn from(value: ApplyResult) -> Self {
434 Self {
435 applied_rows: value.applied_rows,
436 skipped_rows: value.skipped_rows,
437 conflicts: value.conflicts.into_iter().map(Into::into).collect(),
438 new_lsn: value.new_lsn,
439 }
440 }
441}
442
443impl From<WireApplyResult> for ApplyResult {
444 fn from(value: WireApplyResult) -> Self {
445 Self {
446 applied_rows: value.applied_rows,
447 skipped_rows: value.skipped_rows,
448 conflicts: value.conflicts.into_iter().map(Into::into).collect(),
449 new_lsn: value.new_lsn,
450 }
451 }
452}
453
454impl From<Conflict> for WireConflict {
455 fn from(value: Conflict) -> Self {
456 Self {
457 natural_key: value.natural_key.into(),
458 resolution: format!("{:?}", value.resolution),
459 reason: value.reason,
460 }
461 }
462}
463
464impl From<WireConflict> for Conflict {
465 fn from(value: WireConflict) -> Self {
466 let resolution = match value.resolution.as_str() {
467 "InsertIfNotExists" => contextdb_engine::sync_types::ConflictPolicy::InsertIfNotExists,
468 "ServerWins" => contextdb_engine::sync_types::ConflictPolicy::ServerWins,
469 "EdgeWins" => contextdb_engine::sync_types::ConflictPolicy::EdgeWins,
470 _ => contextdb_engine::sync_types::ConflictPolicy::LatestWins,
471 };
472 Self {
473 natural_key: value.natural_key.into(),
474 resolution,
475 reason: value.reason,
476 }
477 }
478}