Skip to main content

aimdb_ws_protocol/
lib.rs

1//! # aimdb-ws-protocol
2//!
3//! Shared wire protocol types for the AimDB WebSocket connector ecosystem.
4//!
5//! Used by:
6//!
7//! - **`aimdb-websocket-connector`** — the server side (Axum/Tokio)
8//! - **`aimdb-wasm-adapter`** — the browser client (`WsBridge`)
9//!
10//! # Wire Protocol
11//!
12//! All messages are JSON-encoded with a `"type"` discriminant tag:
13//!
14//! ## Server → Client ([`ServerMessage`])
15//!
16//! - `data` — live record push with timestamp
17//! - `snapshot` — late-join current value
18//! - `subscribed` — subscription acknowledgement
19//! - `error` — per-operation error
20//! - `pong` — response to client ping
//! - `query_result` — response to a client `query` request
//! - `topic_list` — response to a client `list_topics` request
22//!
23//! ## Client → Server ([`ClientMessage`])
24//!
25//! - `subscribe` — subscribe to one or more topics (supports MQTT wildcards)
26//! - `unsubscribe` — cancel subscriptions
27//! - `write` — inbound value for a `link_from("ws://…")` record
28//! - `ping` — keepalive ping
//! - `query` — query historical / persisted records
//! - `list_topics` — request the list of topics served by this endpoint
30//!
31//! # Topic Matching
32//!
//! [`topic_matches`] implements MQTT-style wildcard matching: `#` is the
//! multi-level wildcard and `*` (used in place of MQTT's `+`) is the
//! single-level wildcard.
35
36use serde::{Deserialize, Serialize};
37
38// ════════════════════════════════════════════════════════════════════
39// Server → Client
40// ════════════════════════════════════════════════════════════════════
41
42/// A message sent from the server to a connected WebSocket client.
43#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
44#[serde(tag = "type", rename_all = "snake_case")]
45pub enum ServerMessage {
46    /// Live data push from an outbound route.
47    Data {
48        topic: String,
49        #[serde(skip_serializing_if = "Option::is_none")]
50        payload: Option<serde_json::Value>,
51        /// Server-side dispatch timestamp (milliseconds since Unix epoch).
52        ts: u64,
53    },
54
55    /// Late-join snapshot — current value sent when a client subscribes.
56    Snapshot {
57        topic: String,
58        #[serde(skip_serializing_if = "Option::is_none")]
59        payload: Option<serde_json::Value>,
60    },
61
62    /// Confirmation sent once subscriptions are recorded.
63    Subscribed { topics: Vec<String> },
64
65    /// Per-operation error.
66    Error {
67        code: ErrorCode,
68        #[serde(skip_serializing_if = "Option::is_none")]
69        topic: Option<String>,
70        message: String,
71    },
72
73    /// Response to a client `ping` message.
74    Pong,
75
76    /// Response to a client `query` request.
77    ///
78    /// Contains the matching historical records and metadata.
79    QueryResult {
80        /// Correlation ID echoed from the client request.
81        id: String,
82        /// Matching records, ordered by timestamp ascending.
83        records: Vec<QueryRecord>,
84        /// Total number of records matched (before any limit).
85        total: usize,
86    },
87
88    /// Response to a client `list_topics` request.
89    TopicList {
90        /// Correlation ID echoed from the client request.
91        id: String,
92        /// All outbound topics served by this endpoint.
93        topics: Vec<TopicInfo>,
94    },
95}
96
97/// A single record returned in a [`ServerMessage::QueryResult`].
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct QueryRecord {
100    /// Topic / record name (e.g. `"temp.vienna"`).
101    pub topic: String,
102    /// Deserialized record value.
103    pub payload: serde_json::Value,
104    /// Storage timestamp (milliseconds since Unix epoch).
105    pub ts: u64,
106}
107
108/// Metadata for a single outbound topic served by a WebSocket endpoint.
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110pub struct TopicInfo {
111    /// Record key / topic name (e.g. `"temp.vienna"`).
112    pub name: String,
113    /// Schema type name (e.g. `"temperature"`), if known by the server.
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub schema_type: Option<String>,
116    /// Entity / node identifier (e.g. `"vienna"`), extracted server-side from the
117    /// topic name. The server is the authority on naming conventions — clients
118    /// should use this field directly rather than parsing the topic name.
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub entity: Option<String>,
121}
122
123/// Machine-readable error codes sent in `ServerMessage::Error`.
124#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
125#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
126pub enum ErrorCode {
127    Unauthorized,
128    Forbidden,
129    UnknownTopic,
130    SerializationError,
131    WriteError,
132    ServerError,
133}
134
135// ════════════════════════════════════════════════════════════════════
136// Client → Server
137// ════════════════════════════════════════════════════════════════════
138
139/// A message received from a WebSocket client.
140#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
141#[serde(tag = "type", rename_all = "snake_case")]
142pub enum ClientMessage {
143    /// Subscribe to one or more topics (wildcards supported).
144    Subscribe { topics: Vec<String> },
145
146    /// Unsubscribe from one or more topics.
147    Unsubscribe { topics: Vec<String> },
148
149    /// Write a value to an inbound record (`link_from("ws://…")`).
150    Write {
151        topic: String,
152        payload: serde_json::Value,
153    },
154
155    /// Keepalive ping.
156    Ping,
157
158    /// Query historical / persisted records.
159    ///
160    /// The server responds with [`ServerMessage::QueryResult`] carrying the
161    /// same `id` for correlation.
162    Query {
163        /// Client-generated correlation ID (echoed in the response).
164        id: String,
165        /// Topic pattern to match (MQTT wildcards supported, `"*"` for all).
166        pattern: String,
167        /// Start of time range (milliseconds since Unix epoch), inclusive. Defaults to 1 h ago.
168        #[serde(skip_serializing_if = "Option::is_none")]
169        from: Option<u64>,
170        /// End of time range (milliseconds since Unix epoch), inclusive. Defaults to now.
171        #[serde(skip_serializing_if = "Option::is_none")]
172        to: Option<u64>,
173        /// Maximum number of records to return per matching topic.
174        #[serde(skip_serializing_if = "Option::is_none")]
175        limit: Option<usize>,
176    },
177
178    /// Request the list of topics served by this WebSocket endpoint.
179    ///
180    /// The server responds with [`ServerMessage::TopicList`] carrying the
181    /// same `id` for correlation.
182    ListTopics {
183        /// Client-generated correlation ID (echoed in the response).
184        id: String,
185    },
186}
187
188// ════════════════════════════════════════════════════════════════════
189// Topic matching
190// ════════════════════════════════════════════════════════════════════
191
/// Reports whether `topic` is selected by the subscription `pattern`.
///
/// Both strings are split into `/`-separated levels and compared level by
/// level, following MQTT conventions:
///
/// | Pattern  | Semantics                                        |
/// |----------|--------------------------------------------------|
/// | `#`      | Every topic                                      |
/// | `a/#`    | `a` itself and everything below `a/`             |
/// | `a/*/c`  | `*` stands for exactly one level (MQTT's `+`)    |
/// | `a/b/c`  | Exact match only                                 |
///
/// A `#` level matches whatever remains of the topic, wherever it appears in
/// the pattern; pattern levels after it are not inspected.
pub fn topic_matches(pattern: &str, topic: &str) -> bool {
    // Identical strings and the bare match-all wildcard need no splitting.
    if pattern == topic || pattern == "#" {
        return true;
    }

    // `literal/#`: when the prefix holds no wildcards a plain prefix test is
    // enough. The topic must equal the prefix or continue with `/` right
    // after it, so `a/#` does not match `ab`.
    if let Some(prefix) = pattern.strip_suffix("/#") {
        let literal_prefix = !(prefix.contains('*') || prefix.contains('#'));
        if literal_prefix {
            return match topic.strip_prefix(prefix) {
                Some(rest) => rest.is_empty() || rest.starts_with('/'),
                None => false,
            };
        }
    }

    // General case: walk both level sequences in lock-step.
    let mut wanted_levels = pattern.split('/');
    let mut topic_levels = topic.split('/');
    loop {
        let wanted = wanted_levels.next();
        let actual = topic_levels.next();
        match wanted {
            // Multi-level wildcard swallows the rest of the topic.
            Some("#") => return true,
            // Pattern used up: succeed only if the topic is used up as well.
            None => return actual.is_none(),
            // `*` consumes any one level; anything else must match literally.
            Some(level) => match actual {
                Some(seen) if level == "*" || level == seen => {}
                _ => return false,
            },
        }
    }
}
238
/// Returns the current wall-clock time in milliseconds since the Unix epoch,
/// the unit used by every `ts` field in this protocol.
///
/// A system clock set before the epoch yields `0` rather than an error, and
/// a value too large for `u64` saturates to `u64::MAX` instead of being
/// silently truncated.
///
/// # Panics
///
/// `std::time::SystemTime::now` is unimplemented on targets without a system
/// clock (e.g. `wasm32-unknown-unknown`) and panics there, so browser-side
/// callers should not rely on this helper.
pub fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // `as_millis` yields `u128`; clamp first so the narrowing cast is lossless.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
247
248// ════════════════════════════════════════════════════════════════════
249// Tests
250// ════════════════════════════════════════════════════════════════════
251
#[cfg(test)]
mod tests {
    use super::*;

    // ── topic matching ──────────────────────────────────────────────

    #[test]
    fn exact_match() {
        assert!(topic_matches("a/b/c", "a/b/c"));
        assert!(!topic_matches("a/b/c", "a/b/d"));
    }

    #[test]
    fn hash_wildcard() {
        assert!(topic_matches("#", "anything/goes/here"));
        assert!(topic_matches("#", "a"));
    }

    #[test]
    fn prefix_hash_wildcard() {
        assert!(topic_matches("sensors/#", "sensors/temperature/vienna"));
        assert!(topic_matches("sensors/#", "sensors/humidity/berlin"));
        assert!(!topic_matches("sensors/#", "commands/setpoint"));
        // Edge: prefix itself
        assert!(topic_matches("sensors/#", "sensors"));
    }

    #[test]
    fn prefix_hash_requires_level_boundary() {
        // `sensors/#` must not match a topic that merely starts with "sensors".
        assert!(!topic_matches("sensors/#", "sensorsX/temperature"));
        assert!(!topic_matches("sensors/#", "sensorsX"));
    }

    #[test]
    fn star_wildcard() {
        assert!(topic_matches(
            "sensors/temperature/*",
            "sensors/temperature/vienna"
        ));
        assert!(topic_matches(
            "sensors/temperature/*",
            "sensors/temperature/berlin"
        ));
        assert!(!topic_matches(
            "sensors/temperature/*",
            "sensors/humidity/vienna"
        ));
        assert!(!topic_matches(
            "sensors/temperature/*",
            "sensors/temperature/a/b"
        ));
    }

    #[test]
    fn star_consumes_exactly_one_level() {
        assert!(!topic_matches("a/*", "a"));
        assert!(topic_matches("*", "temp.vienna"));
        assert!(!topic_matches("*", "a/b"));
    }

    #[test]
    fn mixed_wildcards() {
        assert!(topic_matches("a/*/c/#", "a/b/c/d/e/f"));
        assert!(topic_matches("a/*/c/#", "a/b/c"));
        assert!(!topic_matches("a/*/c/#", "a/b/x/d"));
    }

    // ── wire format ─────────────────────────────────────────────────

    #[test]
    fn serde_server_message_roundtrip() {
        let msg = ServerMessage::Data {
            topic: "sensors/temp".into(),
            payload: Some(serde_json::json!({"celsius": 21.5})),
            ts: 1234567890,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: ServerMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            ServerMessage::Data { topic, ts, .. } => {
                assert_eq!(topic, "sensors/temp");
                assert_eq!(ts, 1234567890);
            }
            _ => panic!("Expected Data variant"),
        }
    }

    #[test]
    fn data_without_payload_omits_field() {
        let msg = ServerMessage::Data {
            topic: "t".into(),
            payload: None,
            ts: 1,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "data", "topic": "t", "ts": 1})
        );
        let parsed: ServerMessage = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn pong_wire_format() {
        let value = serde_json::to_value(&ServerMessage::Pong).unwrap();
        assert_eq!(value, serde_json::json!({"type": "pong"}));
        let parsed: ServerMessage = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, ServerMessage::Pong);
    }

    #[test]
    fn serde_client_message_roundtrip() {
        let msg = ClientMessage::Subscribe {
            topics: vec!["sensors/#".into()],
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: ClientMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            ClientMessage::Subscribe { topics } => {
                assert_eq!(topics, vec!["sensors/#".to_string()]);
            }
            _ => panic!("Expected Subscribe variant"),
        }
    }

    #[test]
    fn query_optional_fields_default_to_none() {
        let parsed: ClientMessage = serde_json::from_value(serde_json::json!({
            "type": "query",
            "id": "q1",
            "pattern": "#"
        }))
        .unwrap();
        let expected = ClientMessage::Query {
            id: "q1".into(),
            pattern: "#".into(),
            from: None,
            to: None,
            limit: None,
        };
        assert_eq!(parsed, expected);
        // And the absent options are not emitted on the way back out.
        assert_eq!(
            serde_json::to_value(&expected).unwrap(),
            serde_json::json!({"type": "query", "id": "q1", "pattern": "#"})
        );
    }

    #[test]
    fn query_result_roundtrip() {
        let msg = ServerMessage::QueryResult {
            id: "q1".into(),
            records: vec![QueryRecord {
                topic: "temp.vienna".into(),
                payload: serde_json::json!({"celsius": 21.5}),
                ts: 42,
            }],
            total: 7,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["type"], "query_result");
        assert_eq!(value["total"], 7);
        let parsed: ServerMessage = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn list_topics_roundtrip() {
        let request = ClientMessage::ListTopics { id: "req-1".into() };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "list_topics", "id": "req-1"})
        );
        let parsed: ClientMessage = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, request);

        let response = ServerMessage::TopicList {
            id: "req-1".into(),
            topics: vec![TopicInfo {
                name: "temp.vienna".into(),
                schema_type: None,
                entity: Some("vienna".into()),
            }],
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "topic_list",
                "id": "req-1",
                "topics": [{"name": "temp.vienna", "entity": "vienna"}]
            })
        );
        let parsed: ServerMessage = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn serde_error_code_roundtrip() {
        let msg = ServerMessage::Error {
            code: ErrorCode::UnknownTopic,
            topic: Some("foo/bar".into()),
            message: "not found".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("UNKNOWN_TOPIC"));
        let parsed: ServerMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, msg);
        match parsed {
            ServerMessage::Error { code, .. } => {
                assert_eq!(code, ErrorCode::UnknownTopic);
            }
            _ => panic!("Expected Error variant"),
        }
    }
}