Skip to main content

contextdb_server/
protocol.rs

1use crate::error::SyncError;
2use contextdb_core::{Lsn, RowId, Value, VectorIndexRef};
3use contextdb_engine::sync_types::{
4    ApplyResult, ChangeSet, Conflict, DdlChange, EdgeChange, NaturalKey, RowChange, VectorChange,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Wire-protocol revision written into every [`Envelope`] by [`encode`].
///
/// [`decode`] rejects any envelope carrying a different value with
/// `SyncError::ProtocolVersionMismatch`. Bump this whenever the encoded shape
/// of an envelope or message changes incompatibly.
pub const PROTOCOL_VERSION: u8 = 0x02;
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
12pub struct Envelope {
13    pub version: u8,
14    pub message_type: MessageType,
15    pub payload: Vec<u8>,
16}
17
18impl Envelope {
19    /// Constructs an Envelope pre-populated for a pull request with the current
20    /// protocol version and empty payload.
21    pub fn default_pull_request() -> Self {
22        Self {
23            version: PROTOCOL_VERSION,
24            message_type: MessageType::PullRequest,
25            payload: Vec::new(),
26        }
27    }
28}
29
30#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub enum MessageType {
32    PushRequest,
33    PushResponse,
34    #[default]
35    PullRequest,
36    PullResponse,
37    Chunk,
38    ChunkAck,
39}
40
41#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct PushRequest {
43    pub changeset: WireChangeSet,
44}
45
46#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
47pub struct PushResponse {
48    pub result: Option<WireApplyResult>,
49    pub error: Option<String>,
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
53pub struct PullRequest {
54    pub since_lsn: Lsn,
55    pub max_entries: Option<u32>,
56}
57
58impl<'de> serde::Deserialize<'de> for PullRequest {
59    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        use serde::de::{SeqAccess, Visitor};
64
65        struct PullRequestVisitor;
66
67        impl<'de> Visitor<'de> for PullRequestVisitor {
68            type Value = PullRequest;
69
70            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71                f.write_str("PullRequest with 1 or 2 elements")
72            }
73
74            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
75            where
76                A: SeqAccess<'de>,
77            {
78                let since_lsn: Lsn = seq
79                    .next_element()?
80                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
81                let max_entries: Option<u32> = seq.next_element()?.unwrap_or(None);
82                Ok(PullRequest {
83                    since_lsn,
84                    max_entries,
85                })
86            }
87        }
88
89        deserializer.deserialize_tuple(2, PullRequestVisitor)
90    }
91}
92
93#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
94pub struct PullResponse {
95    pub changeset: WireChangeSet,
96    pub has_more: bool,
97    pub cursor: Option<Lsn>,
98}
99
100#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
101pub struct ChunkMessage {
102    pub chunk_id: uuid::Uuid,
103    pub sequence: u32,
104    pub total_chunks: u32,
105    #[serde(with = "serde_bytes")]
106    pub payload: Vec<u8>,
107}
108
109#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
110pub struct ChunkAck {
111    pub chunk_id: uuid::Uuid,
112    pub total_chunks: u32,
113    pub reply_inbox: String,
114}
115
116#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
117pub struct WireChangeSet {
118    pub ddl: Vec<WireDdlChange>,
119    pub rows: Vec<WireRowChange>,
120    pub edges: Vec<WireEdgeChange>,
121    pub vectors: Vec<WireVectorChange>,
122}
123
124#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
125pub struct WireRowChange {
126    pub table: String,
127    pub natural_key: WireNaturalKey,
128    pub values: HashMap<String, Value>,
129    #[serde(default)]
130    pub deleted: bool,
131    pub lsn: Lsn,
132}
133
134#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
135pub struct WireEdgeChange {
136    pub source: uuid::Uuid,
137    pub target: uuid::Uuid,
138    pub edge_type: String,
139    pub properties: HashMap<String, Value>,
140    pub lsn: Lsn,
141}
142
143#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
144pub struct WireVectorChange {
145    pub index: VectorIndexRef,
146    pub row_id: RowId,
147    pub vector: Vec<f32>,
148    pub lsn: Lsn,
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
152pub enum WireDdlChange {
153    CreateTable {
154        name: String,
155        columns: Vec<(String, String)>,
156        constraints: Vec<String>,
157    },
158    DropTable {
159        name: String,
160    },
161    AlterTable {
162        name: String,
163        columns: Vec<(String, String)>,
164        constraints: Vec<String>,
165    },
166    // Direction is encoded as a string ("ASC" / "DESC") on the wire so the
167    // server does not depend on contextdb-core's SortDirection type shape.
168    CreateIndex {
169        table: String,
170        name: String,
171        columns: Vec<(String, String)>,
172    },
173    DropIndex {
174        table: String,
175        name: String,
176    },
177}
178
179#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
180pub struct WireNaturalKey {
181    pub column: String,
182    pub value: Value,
183}
184
185impl Default for WireNaturalKey {
186    fn default() -> Self {
187        Self {
188            column: String::new(),
189            value: Value::Null,
190        }
191    }
192}
193
194#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
195pub struct WireApplyResult {
196    pub applied_rows: usize,
197    pub skipped_rows: usize,
198    pub conflicts: Vec<WireConflict>,
199    pub new_lsn: Lsn,
200}
201
202#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
203pub struct WireConflict {
204    pub natural_key: WireNaturalKey,
205    pub resolution: String,
206    pub reason: Option<String>,
207}
208
209pub fn encode<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>, SyncError> {
210    let payload = rmp_serde::to_vec(msg).map_err(|e| SyncError::Serde(e.to_string()))?;
211    let envelope = Envelope {
212        version: PROTOCOL_VERSION,
213        message_type: msg_type,
214        payload,
215    };
216    rmp_serde::to_vec(&envelope).map_err(|e| SyncError::Serde(e.to_string()))
217}
218
219pub fn decode(data: &[u8]) -> Result<Envelope, SyncError> {
220    let envelope: Envelope =
221        rmp_serde::from_slice(data).map_err(|e| SyncError::Serde(e.to_string()))?;
222    if envelope.version != PROTOCOL_VERSION {
223        return Err(SyncError::ProtocolVersionMismatch {
224            received: envelope.version,
225            supported: PROTOCOL_VERSION,
226        });
227    }
228    Ok(envelope)
229}
230
231impl From<ChangeSet> for WireChangeSet {
232    fn from(value: ChangeSet) -> Self {
233        Self {
234            ddl: value.ddl.into_iter().map(Into::into).collect(),
235            rows: value.rows.into_iter().map(Into::into).collect(),
236            edges: value.edges.into_iter().map(Into::into).collect(),
237            vectors: value.vectors.into_iter().map(Into::into).collect(),
238        }
239    }
240}
241
242impl From<WireChangeSet> for ChangeSet {
243    fn from(value: WireChangeSet) -> Self {
244        Self {
245            ddl: value.ddl.into_iter().map(Into::into).collect(),
246            rows: value.rows.into_iter().map(Into::into).collect(),
247            edges: value.edges.into_iter().map(Into::into).collect(),
248            vectors: value.vectors.into_iter().map(Into::into).collect(),
249        }
250    }
251}
252
253impl From<RowChange> for WireRowChange {
254    fn from(value: RowChange) -> Self {
255        Self {
256            table: value.table,
257            natural_key: value.natural_key.into(),
258            values: value.values,
259            deleted: value.deleted,
260            lsn: value.lsn,
261        }
262    }
263}
264
265impl From<WireRowChange> for RowChange {
266    fn from(value: WireRowChange) -> Self {
267        Self {
268            table: value.table,
269            natural_key: value.natural_key.into(),
270            values: value.values,
271            deleted: value.deleted,
272            lsn: value.lsn,
273        }
274    }
275}
276
277impl From<EdgeChange> for WireEdgeChange {
278    fn from(value: EdgeChange) -> Self {
279        Self {
280            source: value.source,
281            target: value.target,
282            edge_type: value.edge_type,
283            properties: value.properties,
284            lsn: value.lsn,
285        }
286    }
287}
288
289impl From<WireEdgeChange> for EdgeChange {
290    fn from(value: WireEdgeChange) -> Self {
291        Self {
292            source: value.source,
293            target: value.target,
294            edge_type: value.edge_type,
295            properties: value.properties,
296            lsn: value.lsn,
297        }
298    }
299}
300
301impl From<VectorChange> for WireVectorChange {
302    fn from(value: VectorChange) -> Self {
303        Self {
304            index: value.index,
305            row_id: value.row_id,
306            vector: value.vector,
307            lsn: value.lsn,
308        }
309    }
310}
311
312impl From<WireVectorChange> for VectorChange {
313    fn from(value: WireVectorChange) -> Self {
314        Self {
315            index: value.index,
316            row_id: value.row_id,
317            vector: value.vector,
318            lsn: value.lsn,
319        }
320    }
321}
322
323impl From<DdlChange> for WireDdlChange {
324    fn from(value: DdlChange) -> Self {
325        match value {
326            DdlChange::CreateTable {
327                name,
328                columns,
329                constraints,
330            } => Self::CreateTable {
331                name,
332                columns,
333                constraints,
334            },
335            DdlChange::DropTable { name } => Self::DropTable { name },
336            DdlChange::AlterTable {
337                name,
338                columns,
339                constraints,
340            } => Self::AlterTable {
341                name,
342                columns,
343                constraints,
344            },
345            DdlChange::CreateIndex {
346                table,
347                name,
348                columns,
349            } => {
350                let wire_cols = columns
351                    .into_iter()
352                    .map(|(c, dir)| {
353                        let dir_str = match dir {
354                            contextdb_core::SortDirection::Asc => "ASC".to_string(),
355                            contextdb_core::SortDirection::Desc => "DESC".to_string(),
356                        };
357                        (c, dir_str)
358                    })
359                    .collect();
360                Self::CreateIndex {
361                    table,
362                    name,
363                    columns: wire_cols,
364                }
365            }
366            DdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
367        }
368    }
369}
370
371impl From<WireDdlChange> for DdlChange {
372    fn from(value: WireDdlChange) -> Self {
373        match value {
374            WireDdlChange::CreateTable {
375                name,
376                columns,
377                constraints,
378            } => Self::CreateTable {
379                name,
380                columns,
381                constraints,
382            },
383            WireDdlChange::DropTable { name } => Self::DropTable { name },
384            WireDdlChange::AlterTable {
385                name,
386                columns,
387                constraints,
388            } => Self::AlterTable {
389                name,
390                columns,
391                constraints,
392            },
393            WireDdlChange::CreateIndex {
394                table,
395                name,
396                columns,
397            } => {
398                let engine_cols = columns
399                    .into_iter()
400                    .map(|(c, dir_str)| {
401                        let dir = if dir_str.eq_ignore_ascii_case("DESC") {
402                            contextdb_core::SortDirection::Desc
403                        } else {
404                            contextdb_core::SortDirection::Asc
405                        };
406                        (c, dir)
407                    })
408                    .collect();
409                Self::CreateIndex {
410                    table,
411                    name,
412                    columns: engine_cols,
413                }
414            }
415            WireDdlChange::DropIndex { table, name } => Self::DropIndex { table, name },
416        }
417    }
418}
419
420impl From<NaturalKey> for WireNaturalKey {
421    fn from(value: NaturalKey) -> Self {
422        Self {
423            column: value.column,
424            value: value.value,
425        }
426    }
427}
428
429impl From<WireNaturalKey> for NaturalKey {
430    fn from(value: WireNaturalKey) -> Self {
431        Self {
432            column: value.column,
433            value: value.value,
434        }
435    }
436}
437
438impl From<ApplyResult> for WireApplyResult {
439    fn from(value: ApplyResult) -> Self {
440        Self {
441            applied_rows: value.applied_rows,
442            skipped_rows: value.skipped_rows,
443            conflicts: value.conflicts.into_iter().map(Into::into).collect(),
444            new_lsn: value.new_lsn,
445        }
446    }
447}
448
449impl From<WireApplyResult> for ApplyResult {
450    fn from(value: WireApplyResult) -> Self {
451        Self {
452            applied_rows: value.applied_rows,
453            skipped_rows: value.skipped_rows,
454            conflicts: value.conflicts.into_iter().map(Into::into).collect(),
455            new_lsn: value.new_lsn,
456        }
457    }
458}
459
460impl From<Conflict> for WireConflict {
461    fn from(value: Conflict) -> Self {
462        Self {
463            natural_key: value.natural_key.into(),
464            resolution: format!("{:?}", value.resolution),
465            reason: value.reason,
466        }
467    }
468}
469
470impl From<WireConflict> for Conflict {
471    fn from(value: WireConflict) -> Self {
472        let resolution = match value.resolution.as_str() {
473            "InsertIfNotExists" => contextdb_engine::sync_types::ConflictPolicy::InsertIfNotExists,
474            "ServerWins" => contextdb_engine::sync_types::ConflictPolicy::ServerWins,
475            "EdgeWins" => contextdb_engine::sync_types::ConflictPolicy::EdgeWins,
476            _ => contextdb_engine::sync_types::ConflictPolicy::LatestWins,
477        };
478        Self {
479            natural_key: value.natural_key.into(),
480            resolution,
481            reason: value.reason,
482        }
483    }
484}