Skip to main content

oxigdal_ws/
protocol.rs

1//! WebSocket protocol definitions and message types.
2
3use serde::{Deserialize, Serialize};
4use std::ops::Range;
5
/// Version number of the wire protocol defined in this module.
///
/// Presumably carried in the `version` field of handshake messages so peers
/// can detect incompatible implementations — confirm against the handshake
/// handling code.
pub const PROTOCOL_VERSION: u32 = 1;
8
9/// Message encoding format.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11pub enum MessageFormat {
12    /// JSON text format (human-readable, larger size)
13    Json,
14    /// Binary MessagePack format (compact, efficient)
15    #[default]
16    MessagePack,
17    /// Binary format with optional compression
18    Binary,
19}
20
21/// Compression algorithm for messages.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
23pub enum Compression {
24    /// No compression
25    None,
26    /// Zstandard compression
27    #[default]
28    Zstd,
29}
30
31/// WebSocket message types exchanged between client and server.
32///
33/// Custom [`serde::Deserialize`] is implemented to work around the
34/// `serde_json/arbitrary_precision` issue: when that feature is active,
35/// serde's internal `Content` type represents numbers as `Map`, causing
36/// `[f64; 4]` arrays to fail deserialization through the normal
37/// `#[serde(tag = "type")]` machinery.  The fix routes JSON through
38/// `serde_json::Value` (which handles `arbitrary_precision` natively) and
39/// only uses the derived tagged-enum path for non-JSON formats such as
40/// MessagePack.
41#[derive(Debug, Clone, Serialize)]
42#[serde(tag = "type", rename_all = "snake_case")]
43pub enum Message {
44    /// Handshake message to negotiate protocol
45    Handshake {
46        /// Protocol version
47        version: u32,
48        /// Preferred message format
49        format: MessageFormat,
50        /// Preferred compression
51        compression: Compression,
52    },
53
54    /// Handshake acknowledgement
55    HandshakeAck {
56        /// Accepted protocol version
57        version: u32,
58        /// Accepted message format
59        format: MessageFormat,
60        /// Accepted compression
61        compression: Compression,
62    },
63
64    /// Subscribe to tile updates
65    SubscribeTiles {
66        /// Subscription ID
67        subscription_id: String,
68        /// Bounding box [min_x, min_y, max_x, max_y]
69        bbox: [f64; 4],
70        /// Zoom level range
71        zoom_range: Range<u8>,
72        /// Tile size (default 256)
73        tile_size: Option<u32>,
74    },
75
76    /// Subscribe to feature updates
77    SubscribeFeatures {
78        /// Subscription ID
79        subscription_id: String,
80        /// Bounding box filter (optional)
81        bbox: Option<[f64; 4]>,
82        /// Attribute filters (key-value pairs)
83        filters: Option<Vec<(String, String)>>,
84        /// Layer name filter
85        layer: Option<String>,
86    },
87
88    /// Subscribe to events
89    SubscribeEvents {
90        /// Subscription ID
91        subscription_id: String,
92        /// Event types to subscribe to
93        event_types: Vec<EventType>,
94    },
95
96    /// Unsubscribe from updates
97    Unsubscribe {
98        /// Subscription ID to cancel
99        subscription_id: String,
100    },
101
102    /// Tile data response
103    TileData {
104        /// Subscription ID
105        subscription_id: String,
106        /// Tile coordinates (x, y, zoom)
107        tile: (u32, u32, u8),
108        /// Tile data (MVT, PNG, etc.)
109        data: Vec<u8>,
110        /// MIME type
111        mime_type: String,
112    },
113
114    /// Feature data response (GeoJSON)
115    FeatureData {
116        /// Subscription ID
117        subscription_id: String,
118        /// GeoJSON feature or feature collection
119        geojson: String,
120        /// Change type (added, updated, deleted)
121        change_type: ChangeType,
122    },
123
124    /// Event notification
125    Event {
126        /// Subscription ID
127        subscription_id: String,
128        /// Event type
129        event_type: EventType,
130        /// Event payload
131        payload: serde_json::Value,
132        /// Event timestamp (RFC3339)
133        timestamp: String,
134    },
135
136    /// Error message
137    Error {
138        /// Error code
139        code: String,
140        /// Error message
141        message: String,
142        /// Request ID that caused the error (if applicable)
143        request_id: Option<String>,
144    },
145
146    /// Ping message for keep-alive
147    Ping {
148        /// Ping ID
149        id: u64,
150    },
151
152    /// Pong response to ping
153    Pong {
154        /// Ping ID being acknowledged
155        id: u64,
156    },
157
158    /// Acknowledgement message
159    Ack {
160        /// Request ID being acknowledged
161        request_id: String,
162        /// Success status
163        success: bool,
164        /// Optional message
165        message: Option<String>,
166    },
167}
168
169// ---------------------------------------------------------------------------
170// Custom Deserialize for Message
171// ---------------------------------------------------------------------------
172//
173// The derived `#[serde(tag = "type")]` implementation routes all formats
174// through serde's internal `Content` type, which represents numbers as
175// `Content::Map` when `serde_json/arbitrary_precision` is active.  That
176// breaks `[f64; 4]` fields.
177//
178// Fix: for human-readable formats (JSON) we first deserialize into a
179// `serde_json::Value` — which always handles `arbitrary_precision`
180// correctly — and then dispatch per the `"type"` field.  For non-human-
181// readable formats (MessagePack, Binary) the `arbitrary_precision` issue
182// does not apply, so we use a private mirror enum with derived
183// `Deserialize` instead.
184
185impl<'de> serde::Deserialize<'de> for Message {
186    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
187    where
188        D: serde::Deserializer<'de>,
189    {
190        use serde::de::Error as _;
191
192        if deserializer.is_human_readable() {
193            // ----------------------------------------------------------------
194            // JSON path — use serde_json::Value as an intermediate to bypass
195            // the Content/arbitrary_precision mismatch.
196            // ----------------------------------------------------------------
197            let value = serde_json::Value::deserialize(deserializer).map_err(D::Error::custom)?;
198
199            let type_str = value
200                .get("type")
201                .and_then(|t| t.as_str())
202                .ok_or_else(|| D::Error::custom("missing 'type' field in Message"))?;
203
204            match type_str {
205                "handshake" => {
206                    #[derive(serde::Deserialize)]
207                    struct HandshakeData {
208                        version: u32,
209                        format: MessageFormat,
210                        compression: Compression,
211                    }
212                    let d: HandshakeData =
213                        serde_json::from_value(value).map_err(D::Error::custom)?;
214                    Ok(Message::Handshake {
215                        version: d.version,
216                        format: d.format,
217                        compression: d.compression,
218                    })
219                }
220                "handshake_ack" => {
221                    #[derive(serde::Deserialize)]
222                    struct HandshakeAckData {
223                        version: u32,
224                        format: MessageFormat,
225                        compression: Compression,
226                    }
227                    let d: HandshakeAckData =
228                        serde_json::from_value(value).map_err(D::Error::custom)?;
229                    Ok(Message::HandshakeAck {
230                        version: d.version,
231                        format: d.format,
232                        compression: d.compression,
233                    })
234                }
235                "subscribe_tiles" => {
236                    #[derive(serde::Deserialize)]
237                    struct SubscribeTilesData {
238                        subscription_id: String,
239                        bbox: [f64; 4],
240                        zoom_range: Range<u8>,
241                        tile_size: Option<u32>,
242                    }
243                    let d: SubscribeTilesData =
244                        serde_json::from_value(value).map_err(D::Error::custom)?;
245                    Ok(Message::SubscribeTiles {
246                        subscription_id: d.subscription_id,
247                        bbox: d.bbox,
248                        zoom_range: d.zoom_range,
249                        tile_size: d.tile_size,
250                    })
251                }
252                "subscribe_features" => {
253                    #[derive(serde::Deserialize)]
254                    struct SubscribeFeaturesData {
255                        subscription_id: String,
256                        bbox: Option<[f64; 4]>,
257                        filters: Option<Vec<(String, String)>>,
258                        layer: Option<String>,
259                    }
260                    let d: SubscribeFeaturesData =
261                        serde_json::from_value(value).map_err(D::Error::custom)?;
262                    Ok(Message::SubscribeFeatures {
263                        subscription_id: d.subscription_id,
264                        bbox: d.bbox,
265                        filters: d.filters,
266                        layer: d.layer,
267                    })
268                }
269                "subscribe_events" => {
270                    #[derive(serde::Deserialize)]
271                    struct SubscribeEventsData {
272                        subscription_id: String,
273                        event_types: Vec<EventType>,
274                    }
275                    let d: SubscribeEventsData =
276                        serde_json::from_value(value).map_err(D::Error::custom)?;
277                    Ok(Message::SubscribeEvents {
278                        subscription_id: d.subscription_id,
279                        event_types: d.event_types,
280                    })
281                }
282                "unsubscribe" => {
283                    #[derive(serde::Deserialize)]
284                    struct UnsubscribeData {
285                        subscription_id: String,
286                    }
287                    let d: UnsubscribeData =
288                        serde_json::from_value(value).map_err(D::Error::custom)?;
289                    Ok(Message::Unsubscribe {
290                        subscription_id: d.subscription_id,
291                    })
292                }
293                "tile_data" => {
294                    #[derive(serde::Deserialize)]
295                    struct TileDataData {
296                        subscription_id: String,
297                        tile: (u32, u32, u8),
298                        data: Vec<u8>,
299                        mime_type: String,
300                    }
301                    let d: TileDataData =
302                        serde_json::from_value(value).map_err(D::Error::custom)?;
303                    Ok(Message::TileData {
304                        subscription_id: d.subscription_id,
305                        tile: d.tile,
306                        data: d.data,
307                        mime_type: d.mime_type,
308                    })
309                }
310                "feature_data" => {
311                    #[derive(serde::Deserialize)]
312                    struct FeatureDataData {
313                        subscription_id: String,
314                        geojson: String,
315                        change_type: ChangeType,
316                    }
317                    let d: FeatureDataData =
318                        serde_json::from_value(value).map_err(D::Error::custom)?;
319                    Ok(Message::FeatureData {
320                        subscription_id: d.subscription_id,
321                        geojson: d.geojson,
322                        change_type: d.change_type,
323                    })
324                }
325                "event" => {
326                    #[derive(serde::Deserialize)]
327                    struct EventData {
328                        subscription_id: String,
329                        event_type: EventType,
330                        payload: serde_json::Value,
331                        timestamp: String,
332                    }
333                    let d: EventData = serde_json::from_value(value).map_err(D::Error::custom)?;
334                    Ok(Message::Event {
335                        subscription_id: d.subscription_id,
336                        event_type: d.event_type,
337                        payload: d.payload,
338                        timestamp: d.timestamp,
339                    })
340                }
341                "error" => {
342                    #[derive(serde::Deserialize)]
343                    struct ErrorData {
344                        code: String,
345                        message: String,
346                        request_id: Option<String>,
347                    }
348                    let d: ErrorData = serde_json::from_value(value).map_err(D::Error::custom)?;
349                    Ok(Message::Error {
350                        code: d.code,
351                        message: d.message,
352                        request_id: d.request_id,
353                    })
354                }
355                "ping" => {
356                    #[derive(serde::Deserialize)]
357                    struct PingData {
358                        id: u64,
359                    }
360                    let d: PingData = serde_json::from_value(value).map_err(D::Error::custom)?;
361                    Ok(Message::Ping { id: d.id })
362                }
363                "pong" => {
364                    #[derive(serde::Deserialize)]
365                    struct PongData {
366                        id: u64,
367                    }
368                    let d: PongData = serde_json::from_value(value).map_err(D::Error::custom)?;
369                    Ok(Message::Pong { id: d.id })
370                }
371                "ack" => {
372                    #[derive(serde::Deserialize)]
373                    struct AckData {
374                        request_id: String,
375                        success: bool,
376                        message: Option<String>,
377                    }
378                    let d: AckData = serde_json::from_value(value).map_err(D::Error::custom)?;
379                    Ok(Message::Ack {
380                        request_id: d.request_id,
381                        success: d.success,
382                        message: d.message,
383                    })
384                }
385                other => Err(D::Error::custom(format!("unknown Message type: {other}"))),
386            }
387        } else {
388            // ----------------------------------------------------------------
389            // Non-JSON path (MessagePack, Binary) — arbitrary_precision does
390            // not apply here, so the normal derived tagged-enum path works.
391            // ----------------------------------------------------------------
392            #[derive(serde::Deserialize)]
393            #[serde(tag = "type", rename_all = "snake_case")]
394            enum MessageInner {
395                Handshake {
396                    version: u32,
397                    format: MessageFormat,
398                    compression: Compression,
399                },
400                HandshakeAck {
401                    version: u32,
402                    format: MessageFormat,
403                    compression: Compression,
404                },
405                SubscribeTiles {
406                    subscription_id: String,
407                    bbox: [f64; 4],
408                    zoom_range: Range<u8>,
409                    tile_size: Option<u32>,
410                },
411                SubscribeFeatures {
412                    subscription_id: String,
413                    bbox: Option<[f64; 4]>,
414                    filters: Option<Vec<(String, String)>>,
415                    layer: Option<String>,
416                },
417                SubscribeEvents {
418                    subscription_id: String,
419                    event_types: Vec<EventType>,
420                },
421                Unsubscribe {
422                    subscription_id: String,
423                },
424                TileData {
425                    subscription_id: String,
426                    tile: (u32, u32, u8),
427                    data: Vec<u8>,
428                    mime_type: String,
429                },
430                FeatureData {
431                    subscription_id: String,
432                    geojson: String,
433                    change_type: ChangeType,
434                },
435                Event {
436                    subscription_id: String,
437                    event_type: EventType,
438                    payload: serde_json::Value,
439                    timestamp: String,
440                },
441                Error {
442                    code: String,
443                    message: String,
444                    request_id: Option<String>,
445                },
446                Ping {
447                    id: u64,
448                },
449                Pong {
450                    id: u64,
451                },
452                Ack {
453                    request_id: String,
454                    success: bool,
455                    message: Option<String>,
456                },
457            }
458
459            let inner = MessageInner::deserialize(deserializer)?;
460            Ok(match inner {
461                MessageInner::Handshake {
462                    version,
463                    format,
464                    compression,
465                } => Message::Handshake {
466                    version,
467                    format,
468                    compression,
469                },
470                MessageInner::HandshakeAck {
471                    version,
472                    format,
473                    compression,
474                } => Message::HandshakeAck {
475                    version,
476                    format,
477                    compression,
478                },
479                MessageInner::SubscribeTiles {
480                    subscription_id,
481                    bbox,
482                    zoom_range,
483                    tile_size,
484                } => Message::SubscribeTiles {
485                    subscription_id,
486                    bbox,
487                    zoom_range,
488                    tile_size,
489                },
490                MessageInner::SubscribeFeatures {
491                    subscription_id,
492                    bbox,
493                    filters,
494                    layer,
495                } => Message::SubscribeFeatures {
496                    subscription_id,
497                    bbox,
498                    filters,
499                    layer,
500                },
501                MessageInner::SubscribeEvents {
502                    subscription_id,
503                    event_types,
504                } => Message::SubscribeEvents {
505                    subscription_id,
506                    event_types,
507                },
508                MessageInner::Unsubscribe { subscription_id } => {
509                    Message::Unsubscribe { subscription_id }
510                }
511                MessageInner::TileData {
512                    subscription_id,
513                    tile,
514                    data,
515                    mime_type,
516                } => Message::TileData {
517                    subscription_id,
518                    tile,
519                    data,
520                    mime_type,
521                },
522                MessageInner::FeatureData {
523                    subscription_id,
524                    geojson,
525                    change_type,
526                } => Message::FeatureData {
527                    subscription_id,
528                    geojson,
529                    change_type,
530                },
531                MessageInner::Event {
532                    subscription_id,
533                    event_type,
534                    payload,
535                    timestamp,
536                } => Message::Event {
537                    subscription_id,
538                    event_type,
539                    payload,
540                    timestamp,
541                },
542                MessageInner::Error {
543                    code,
544                    message,
545                    request_id,
546                } => Message::Error {
547                    code,
548                    message,
549                    request_id,
550                },
551                MessageInner::Ping { id } => Message::Ping { id },
552                MessageInner::Pong { id } => Message::Pong { id },
553                MessageInner::Ack {
554                    request_id,
555                    success,
556                    message,
557                } => Message::Ack {
558                    request_id,
559                    success,
560                    message,
561                },
562            })
563        }
564    }
565}
566
567/// Change type for feature updates.
568#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
569#[serde(rename_all = "lowercase")]
570pub enum ChangeType {
571    /// Feature was added
572    Added,
573    /// Feature was updated
574    Updated,
575    /// Feature was deleted
576    Deleted,
577}
578
579/// Event types that can be subscribed to.
580#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
581#[serde(rename_all = "snake_case")]
582pub enum EventType {
583    /// File change notification
584    FileChange,
585    /// Processing status update
586    ProcessingStatus,
587    /// Error notification
588    Error,
589    /// Progress update
590    Progress,
591    /// Custom event
592    Custom,
593}
594
595/// Serialization helpers for messages.
596impl Message {
597    /// Serialize message to JSON.
598    pub fn to_json(&self) -> crate::error::Result<String> {
599        serde_json::to_string(self).map_err(Into::into)
600    }
601
602    /// Deserialize message from JSON.
603    pub fn from_json(s: &str) -> crate::error::Result<Self> {
604        serde_json::from_str(s).map_err(Into::into)
605    }
606
607    /// Serialize message to MessagePack.
608    pub fn to_msgpack(&self) -> crate::error::Result<Vec<u8>> {
609        rmp_serde::to_vec(self).map_err(Into::into)
610    }
611
612    /// Deserialize message from MessagePack.
613    pub fn from_msgpack(data: &[u8]) -> crate::error::Result<Self> {
614        rmp_serde::from_slice(data).map_err(Into::into)
615    }
616
617    /// Compress data using zstd.
618    pub fn compress(data: &[u8], level: i32) -> crate::error::Result<Vec<u8>> {
619        oxiarc_zstd::encode_all(data, level)
620            .map_err(|e| crate::error::Error::Compression(e.to_string()))
621    }
622
623    /// Decompress zstd data.
624    pub fn decompress(data: &[u8]) -> crate::error::Result<Vec<u8>> {
625        oxiarc_zstd::decode_all(data).map_err(|e| crate::error::Error::Decompression(e.to_string()))
626    }
627
628    /// Encode message with specified format and compression.
629    pub fn encode(
630        &self,
631        format: MessageFormat,
632        compression: Compression,
633    ) -> crate::error::Result<Vec<u8>> {
634        let data = match format {
635            MessageFormat::Json => self.to_json()?.into_bytes(),
636            MessageFormat::MessagePack | MessageFormat::Binary => self.to_msgpack()?,
637        };
638
639        match compression {
640            Compression::None => Ok(data),
641            Compression::Zstd => Self::compress(&data, 3),
642        }
643    }
644
645    /// Decode message with specified format and compression.
646    pub fn decode(
647        data: &[u8],
648        format: MessageFormat,
649        compression: Compression,
650    ) -> crate::error::Result<Self> {
651        let decompressed = match compression {
652            Compression::None => data.to_vec(),
653            Compression::Zstd => Self::decompress(data)?,
654        };
655
656        match format {
657            MessageFormat::Json => {
658                let s = String::from_utf8(decompressed)
659                    .map_err(|e| crate::error::Error::Deserialization(e.to_string()))?;
660                Self::from_json(&s)
661            }
662            MessageFormat::MessagePack | MessageFormat::Binary => Self::from_msgpack(&decompressed),
663        }
664    }
665}
666
667/// Subscription filter for spatial queries.
668#[derive(Debug, Clone, Serialize, Deserialize)]
669pub struct SpatialFilter {
670    /// Bounding box [min_x, min_y, max_x, max_y]
671    pub bbox: [f64; 4],
672    /// Coordinate reference system (EPSG code)
673    pub crs: Option<String>,
674}
675
676/// Subscription filter for temporal queries.
677#[derive(Debug, Clone, Serialize, Deserialize)]
678pub struct TemporalFilter {
679    /// Start time (RFC3339)
680    pub start: Option<String>,
681    /// End time (RFC3339)
682    pub end: Option<String>,
683}
684
685/// Subscription filter combining spatial, temporal, and attribute filters.
686#[derive(Debug, Clone, Serialize, Deserialize)]
687pub struct SubscriptionFilter {
688    /// Spatial filter
689    pub spatial: Option<SpatialFilter>,
690    /// Temporal filter
691    pub temporal: Option<TemporalFilter>,
692    /// Attribute filters (key-value pairs)
693    pub attributes: Option<Vec<(String, String)>>,
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one instance of every `Message` variant. This exercises both the
    /// hand-written JSON dispatch and the derived binary path for every
    /// `"type"` tag. Payloads avoid JSON numbers so the test is independent of
    /// `serde_json/arbitrary_precision`.
    fn sample_messages() -> Vec<Message> {
        vec![
            Message::Handshake {
                version: PROTOCOL_VERSION,
                format: MessageFormat::Json,
                compression: Compression::None,
            },
            Message::HandshakeAck {
                version: PROTOCOL_VERSION,
                format: MessageFormat::MessagePack,
                compression: Compression::Zstd,
            },
            Message::SubscribeTiles {
                subscription_id: "tiles-1".to_string(),
                bbox: [-180.0, -90.0, 180.0, 90.0],
                zoom_range: 0..14,
                tile_size: None,
            },
            Message::SubscribeFeatures {
                subscription_id: "features-1".to_string(),
                bbox: Some([0.5, 1.25, 2.0, 3.75]),
                filters: Some(vec![("kind".to_string(), "road".to_string())]),
                layer: Some("roads".to_string()),
            },
            Message::SubscribeEvents {
                subscription_id: "events-1".to_string(),
                event_types: vec![EventType::FileChange, EventType::Progress, EventType::Custom],
            },
            Message::Unsubscribe {
                subscription_id: "tiles-1".to_string(),
            },
            Message::TileData {
                subscription_id: "tiles-1".to_string(),
                tile: (3, 5, 4),
                data: vec![0, 1, 254, 255],
                mime_type: "application/vnd.mapbox-vector-tile".to_string(),
            },
            Message::FeatureData {
                subscription_id: "features-1".to_string(),
                geojson: r#"{"type":"Feature"}"#.to_string(),
                change_type: ChangeType::Updated,
            },
            Message::Event {
                subscription_id: "events-1".to_string(),
                event_type: EventType::ProcessingStatus,
                payload: serde_json::json!({ "stage": "tiling", "ok": true }),
                timestamp: "2024-01-01T00:00:00Z".to_string(),
            },
            Message::Error {
                code: "E42".to_string(),
                message: "bad request".to_string(),
                request_id: None,
            },
            Message::Ping { id: 7 },
            Message::Pong { id: 7 },
            Message::Ack {
                request_id: "req-1".to_string(),
                success: true,
                message: Some("ok".to_string()),
            },
        ]
    }

    #[test]
    fn test_message_json_roundtrip() {
        let msg = Message::Ping { id: 42 };
        let json_str = msg.to_json().expect("Failed to serialize message to JSON");
        let decoded =
            Message::from_json(&json_str).expect("Failed to deserialize message from JSON");

        assert!(matches!(decoded, Message::Ping { id: 42 }));
    }

    #[test]
    fn test_message_msgpack_roundtrip() {
        let msg = Message::Ping { id: 42 };
        let msgpack_bytes = msg
            .to_msgpack()
            .expect("Failed to serialize message to MessagePack");
        let decoded = Message::from_msgpack(&msgpack_bytes)
            .expect("Failed to deserialize message from MessagePack");

        assert!(matches!(decoded, Message::Ping { id: 42 }));
    }

    #[test]
    fn test_compression_roundtrip() {
        let data = b"Hello, WebSocket!";
        let compressed = Message::compress(data, 3).expect("Failed to compress data");
        let decompressed = Message::decompress(&compressed).expect("Failed to decompress data");

        assert_eq!(data, decompressed.as_slice());
    }

    #[test]
    fn test_message_encode_decode() {
        let msg = Message::SubscribeTiles {
            subscription_id: "test-123".to_string(),
            bbox: [-180.0, -90.0, 180.0, 90.0],
            zoom_range: 0..14,
            tile_size: Some(256),
        };

        // Test JSON encoding
        let encoded = msg
            .encode(MessageFormat::Json, Compression::None)
            .expect("Failed to encode message as JSON");
        let decoded = Message::decode(&encoded, MessageFormat::Json, Compression::None)
            .expect("Failed to decode message from JSON");

        assert!(
            matches!(
                decoded,
                Message::SubscribeTiles {
                    subscription_id,
                    bbox,
                    zoom_range,
                    tile_size,
                } if subscription_id == "test-123"
                    && bbox == [-180.0, -90.0, 180.0, 90.0]
                    && zoom_range == (0..14)
                    && tile_size == Some(256)
            ),
            "Decoded message does not match expected values"
        );

        // Test MessagePack with compression
        let encoded = msg
            .encode(MessageFormat::MessagePack, Compression::Zstd)
            .expect("Failed to encode message as MessagePack with Zstd");
        let decoded = Message::decode(&encoded, MessageFormat::MessagePack, Compression::Zstd)
            .expect("Failed to decode message from MessagePack with Zstd");

        assert!(
            matches!(
                decoded,
                Message::SubscribeTiles {
                    subscription_id,
                    bbox,
                    zoom_range,
                    tile_size,
                } if subscription_id == "test-123"
                    && bbox == [-180.0, -90.0, 180.0, 90.0]
                    && zoom_range == (0..14)
                    && tile_size == Some(256)
            ),
            "Decoded message does not match expected values"
        );
    }

    #[test]
    fn test_every_variant_roundtrips_in_every_format() {
        let formats = [MessageFormat::Json, MessageFormat::MessagePack, MessageFormat::Binary];
        let compressions = [Compression::None, Compression::Zstd];

        for msg in sample_messages() {
            // `Message` has no `PartialEq`, so compare canonical JSON renderings.
            let expected = msg.to_json().expect("Failed to serialize reference JSON");
            for format in formats {
                for compression in compressions {
                    let encoded = msg.encode(format, compression).unwrap_or_else(|e| {
                        panic!("Failed to encode {expected} as {format:?}/{compression:?}: {e:?}")
                    });
                    let decoded =
                        Message::decode(&encoded, format, compression).unwrap_or_else(|e| {
                            panic!(
                                "Failed to decode {expected} from {format:?}/{compression:?}: {e:?}"
                            )
                        });
                    let actual = decoded
                        .to_json()
                        .expect("Failed to re-serialize decoded message");
                    assert_eq!(
                        actual, expected,
                        "Roundtrip mismatch for {format:?}/{compression:?}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_json_type_tags_are_snake_case() {
        let json = Message::HandshakeAck {
            version: PROTOCOL_VERSION,
            format: MessageFormat::Json,
            compression: Compression::None,
        }
        .to_json()
        .expect("Failed to serialize message to JSON");

        assert!(json.contains(r#""type":"handshake_ack""#), "Unexpected JSON: {json}");
    }

    #[test]
    fn test_json_unknown_type_is_rejected() {
        let err = serde_json::from_str::<Message>(r#"{"type":"bogus","id":1}"#)
            .expect_err("An unknown type tag must be rejected");

        assert!(
            err.to_string().contains("unknown Message type: bogus"),
            "Unexpected error: {err}"
        );
    }

    #[test]
    fn test_json_missing_type_is_rejected() {
        let err = serde_json::from_str::<Message>(r#"{"id":1}"#)
            .expect_err("A message without a type tag must be rejected");

        assert!(
            err.to_string().contains("missing 'type' field"),
            "Unexpected error: {err}"
        );
    }

    #[test]
    fn test_json_missing_field_is_rejected() {
        assert!(serde_json::from_str::<Message>(r#"{"type":"ping"}"#).is_err());
    }

    #[test]
    fn test_json_decode_rejects_invalid_utf8() {
        let result = Message::decode(&[0xff, 0xfe, 0xfd], MessageFormat::Json, Compression::None);

        assert!(result.is_err());
    }
}