Skip to main content

contextdb_server/
protocol.rs

1use crate::error::SyncError;
2use contextdb_core::{Lsn, RowId, Value};
3use contextdb_engine::sync_types::{
4    ApplyResult, ChangeSet, Conflict, DdlChange, EdgeChange, NaturalKey, RowChange, VectorChange,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9const PROTOCOL_VERSION: u8 = 1;
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
12pub struct Envelope {
13    pub version: u8,
14    pub message_type: MessageType,
15    pub payload: Vec<u8>,
16}
17
18impl Envelope {
19    /// Constructs an Envelope pre-populated for a pull request with the current
20    /// protocol version and empty payload.
21    pub fn default_pull_request() -> Self {
22        Self {
23            version: PROTOCOL_VERSION,
24            message_type: MessageType::PullRequest,
25            payload: Vec::new(),
26        }
27    }
28}
29
30#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub enum MessageType {
32    PushRequest,
33    PushResponse,
34    #[default]
35    PullRequest,
36    PullResponse,
37    Chunk,
38    ChunkAck,
39}
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct PushRequest {
43    pub changeset: WireChangeSet,
44}
45
46#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
47pub struct PushResponse {
48    pub result: Option<WireApplyResult>,
49    pub error: Option<String>,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
53pub struct PullRequest {
54    pub since_lsn: Lsn,
55    pub max_entries: Option<u32>,
56}
57
58impl<'de> serde::Deserialize<'de> for PullRequest {
59    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        use serde::de::{SeqAccess, Visitor};
64
65        struct PullRequestVisitor;
66
67        impl<'de> Visitor<'de> for PullRequestVisitor {
68            type Value = PullRequest;
69
70            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71                f.write_str("PullRequest with 1 or 2 elements")
72            }
73
74            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
75            where
76                A: SeqAccess<'de>,
77            {
78                let since_lsn: Lsn = seq
79                    .next_element()?
80                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
81                let max_entries: Option<u32> = seq.next_element()?.unwrap_or(None);
82                Ok(PullRequest {
83                    since_lsn,
84                    max_entries,
85                })
86            }
87        }
88
89        deserializer.deserialize_tuple(2, PullRequestVisitor)
90    }
91}
92
93#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
94pub struct PullResponse {
95    pub changeset: WireChangeSet,
96    pub has_more: bool,
97    pub cursor: Option<Lsn>,
98}
99
100#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
101pub struct ChunkMessage {
102    pub chunk_id: uuid::Uuid,
103    pub sequence: u32,
104    pub total_chunks: u32,
105    #[serde(with = "serde_bytes")]
106    pub payload: Vec<u8>,
107}
108
109#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
110pub struct ChunkAck {
111    pub chunk_id: uuid::Uuid,
112    pub total_chunks: u32,
113    pub reply_inbox: String,
114}
115
116#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
117pub struct WireChangeSet {
118    pub rows: Vec<WireRowChange>,
119    pub edges: Vec<WireEdgeChange>,
120    pub vectors: Vec<WireVectorChange>,
121    pub ddl: Vec<WireDdlChange>,
122}
123
124#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
125pub struct WireRowChange {
126    pub table: String,
127    pub natural_key: WireNaturalKey,
128    pub values: HashMap<String, Value>,
129    #[serde(default)]
130    pub deleted: bool,
131    pub lsn: Lsn,
132}
133
134#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
135pub struct WireEdgeChange {
136    pub source: uuid::Uuid,
137    pub target: uuid::Uuid,
138    pub edge_type: String,
139    pub properties: HashMap<String, Value>,
140    pub lsn: Lsn,
141}
142
143#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
144pub struct WireVectorChange {
145    pub row_id: RowId,
146    pub vector: Vec<f32>,
147    pub lsn: Lsn,
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
151pub enum WireDdlChange {
152    CreateTable {
153        name: String,
154        columns: Vec<(String, String)>,
155        constraints: Vec<String>,
156    },
157    DropTable {
158        name: String,
159    },
160    AlterTable {
161        name: String,
162        columns: Vec<(String, String)>,
163        constraints: Vec<String>,
164    },
165    // Direction is encoded as a string ("ASC" / "DESC") on the wire so the
166    // server does not depend on contextdb-core's SortDirection type shape.
167    CreateIndex {
168        table: String,
169        name: String,
170        columns: Vec<(String, String)>,
171    },
172    DropIndex {
173        table: String,
174        name: String,
175    },
176}
177
178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub struct WireNaturalKey {
180    pub column: String,
181    pub value: Value,
182}
183
184impl Default for WireNaturalKey {
185    fn default() -> Self {
186        Self {
187            column: String::new(),
188            value: Value::Null,
189        }
190    }
191}
192
193#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
194pub struct WireApplyResult {
195    pub applied_rows: usize,
196    pub skipped_rows: usize,
197    pub conflicts: Vec<WireConflict>,
198    pub new_lsn: Lsn,
199}
200
201#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
202pub struct WireConflict {
203    pub natural_key: WireNaturalKey,
204    pub resolution: String,
205    pub reason: Option<String>,
206}
207
208pub fn encode<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>, SyncError> {
209    let payload = rmp_serde::to_vec(msg).map_err(|e| SyncError::Serde(e.to_string()))?;
210    let envelope = Envelope {
211        version: PROTOCOL_VERSION,
212        message_type: msg_type,
213        payload,
214    };
215    rmp_serde::to_vec(&envelope).map_err(|e| SyncError::Serde(e.to_string()))
216}
217
218pub fn decode(data: &[u8]) -> Result<Envelope, SyncError> {
219    let envelope: Envelope =
220        rmp_serde::from_slice(data).map_err(|e| SyncError::Serde(e.to_string()))?;
221    if envelope.version > PROTOCOL_VERSION {
222        return Err(SyncError::UnsupportedVersion(envelope.version));
223    }
224    Ok(envelope)
225}
226
227impl From<ChangeSet> for WireChangeSet {
228    fn from(value: ChangeSet) -> Self {
229        Self {
230            rows: value.rows.into_iter().map(Into::into).collect(),
231            edges: value.edges.into_iter().map(Into::into).collect(),
232            vectors: value.vectors.into_iter().map(Into::into).collect(),
233            ddl: value.ddl.into_iter().map(Into::into).collect(),
234        }
235    }
236}
237
238impl From<WireChangeSet> for ChangeSet {
239    fn from(value: WireChangeSet) -> Self {
240        Self {
241            rows: value.rows.into_iter().map(Into::into).collect(),
242            edges: value.edges.into_iter().map(Into::into).collect(),
243            vectors: value.vectors.into_iter().map(Into::into).collect(),
244            ddl: value.ddl.into_iter().map(Into::into).collect(),
245        }
246    }
247}
248
249impl From<RowChange> for WireRowChange {
250    fn from(value: RowChange) -> Self {
251        Self {
252            table: value.table,
253            natural_key: value.natural_key.into(),
254            values: value.values,
255            deleted: value.deleted,
256            lsn: value.lsn,
257        }
258    }
259}
260
261impl From<WireRowChange> for RowChange {
262    fn from(value: WireRowChange) -> Self {
263        Self {
264            table: value.table,
265            natural_key: value.natural_key.into(),
266            values: value.values,
267            deleted: value.deleted,
268            lsn: value.lsn,
269        }
270    }
271}
272
273impl From<EdgeChange> for WireEdgeChange {
274    fn from(value: EdgeChange) -> Self {
275        Self {
276            source: value.source,
277            target: value.target,
278            edge_type: value.edge_type,
279            properties: value.properties,
280            lsn: value.lsn,
281        }
282    }
283}
284
285impl From<WireEdgeChange> for EdgeChange {
286    fn from(value: WireEdgeChange) -> Self {
287        Self {
288            source: value.source,
289            target: value.target,
290            edge_type: value.edge_type,
291            properties: value.properties,
292            lsn: value.lsn,
293        }
294    }
295}
296
297impl From<VectorChange> for WireVectorChange {
298    fn from(value: VectorChange) -> Self {
299        Self {
300            row_id: value.row_id,
301            vector: value.vector,
302            lsn: value.lsn,
303        }
304    }
305}
306
307impl From<WireVectorChange> for VectorChange {
308    fn from(value: WireVectorChange) -> Self {
309        Self {
310            row_id: value.row_id,
311            vector: value.vector,
312            lsn: value.lsn,
313        }
314    }
315}
316
317impl From<DdlChange> for WireDdlChange {
318    fn from(value: DdlChange) -> Self {
319        match value {
320            DdlChange::CreateTable {
321                name,
322                columns,
323                constraints,
324            } => Self::CreateTable {
325                name,
326                columns,
327                constraints,
328            },
329            DdlChange::DropTable { name } => Self::DropTable { name },
330            DdlChange::AlterTable {
331                name,
332                columns,
333                constraints,
334            } => Self::AlterTable {
335                name,
336                columns,
337                constraints,
338            },
339            DdlChange::CreateIndex {
340                table,
341                name,
342                columns,
343            } => {
344                let wire_cols = columns
345                    .into_iter()
346                    .map(|(c, dir)| {
347                        let dir_str = match dir {
348                            contextdb_core::SortDirection::Asc => "ASC".to_string(),
349                            contextdb_core::SortDirection::Desc => "DESC".to_string(),
350                        };
351                        (c, dir_str)
352                    })
353                    .collect();
354                Self::CreateIndex {
355                    table,
356                    name,
357                    columns: wire_cols,
358                }
359            }
360            DdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
361        }
362    }
363}
364
365impl From<WireDdlChange> for DdlChange {
366    fn from(value: WireDdlChange) -> Self {
367        match value {
368            WireDdlChange::CreateTable {
369                name,
370                columns,
371                constraints,
372            } => Self::CreateTable {
373                name,
374                columns,
375                constraints,
376            },
377            WireDdlChange::DropTable { name } => Self::DropTable { name },
378            WireDdlChange::AlterTable {
379                name,
380                columns,
381                constraints,
382            } => Self::AlterTable {
383                name,
384                columns,
385                constraints,
386            },
387            WireDdlChange::CreateIndex {
388                table,
389                name,
390                columns,
391            } => {
392                let engine_cols = columns
393                    .into_iter()
394                    .map(|(c, dir_str)| {
395                        let dir = if dir_str.eq_ignore_ascii_case("DESC") {
396                            contextdb_core::SortDirection::Desc
397                        } else {
398                            contextdb_core::SortDirection::Asc
399                        };
400                        (c, dir)
401                    })
402                    .collect();
403                Self::CreateIndex {
404                    table,
405                    name,
406                    columns: engine_cols,
407                }
408            }
409            WireDdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
410        }
411    }
412}
413
414impl From<NaturalKey> for WireNaturalKey {
415    fn from(value: NaturalKey) -> Self {
416        Self {
417            column: value.column,
418            value: value.value,
419        }
420    }
421}
422
423impl From<WireNaturalKey> for NaturalKey {
424    fn from(value: WireNaturalKey) -> Self {
425        Self {
426            column: value.column,
427            value: value.value,
428        }
429    }
430}
431
432impl From<ApplyResult> for WireApplyResult {
433    fn from(value: ApplyResult) -> Self {
434        Self {
435            applied_rows: value.applied_rows,
436            skipped_rows: value.skipped_rows,
437            conflicts: value.conflicts.into_iter().map(Into::into).collect(),
438            new_lsn: value.new_lsn,
439        }
440    }
441}
442
443impl From<WireApplyResult> for ApplyResult {
444    fn from(value: WireApplyResult) -> Self {
445        Self {
446            applied_rows: value.applied_rows,
447            skipped_rows: value.skipped_rows,
448            conflicts: value.conflicts.into_iter().map(Into::into).collect(),
449            new_lsn: value.new_lsn,
450        }
451    }
452}
453
454impl From<Conflict> for WireConflict {
455    fn from(value: Conflict) -> Self {
456        Self {
457            natural_key: value.natural_key.into(),
458            resolution: format!("{:?}", value.resolution),
459            reason: value.reason,
460        }
461    }
462}
463
464impl From<WireConflict> for Conflict {
465    fn from(value: WireConflict) -> Self {
466        let resolution = match value.resolution.as_str() {
467            "InsertIfNotExists" => contextdb_engine::sync_types::ConflictPolicy::InsertIfNotExists,
468            "ServerWins" => contextdb_engine::sync_types::ConflictPolicy::ServerWins,
469            "EdgeWins" => contextdb_engine::sync_types::ConflictPolicy::EdgeWins,
470            _ => contextdb_engine::sync_types::ConflictPolicy::LatestWins,
471        };
472        Self {
473            natural_key: value.natural_key.into(),
474            resolution,
475            reason: value.reason,
476        }
477    }
478}