Skip to main content

allsource_core/infrastructure/resp/
commands.rs

1//! Redis command handler for AllSource Core.
2//!
3//! Maps Redis Streams commands to EventStore operations:
4//!
5//! | Redis Command | AllSource Operation |
6//! |---------------|---------------------|
7//! | `XADD stream * field value ...` | `EventStore::ingest()` |
8//! | `XRANGE stream - +` | `EventStore::query()` |
//! | `XLEN stream` | event count for the stream key (tenant, `entity:<id>`, or `type:<type>`) |
10//! | `SUBSCRIBE channel` | WebSocket broadcast subscribe |
11//! | `PING` | health check |
12//! | `COMMAND` / `COMMAND DOCS` | command metadata |
13//! | `INFO` | server info |
14
15use crate::{application::dto::QueryEventsRequest, domain::entities::Event, store::EventStore};
16use std::sync::Arc;
17use tokio::sync::broadcast;
18
19use super::protocol::RespValue;
20
/// Subscription info returned by SUBSCRIBE command.
///
/// Produced by `handle_subscribe` and handed back to the caller of
/// [`execute`] alongside the confirmation reply.
pub struct SubscriptionInfo {
    /// Broadcast receiver obtained from the store's WebSocket manager
    /// (`subscribe_events()`); yields shared `Arc<Event>` values.
    pub rx: broadcast::Receiver<Arc<Event>>,
    /// Event type prefix filters (e.g. `["scheduler.*"]`). Empty = all events.
    ///
    /// Built from the SUBSCRIBE arguments with the bare `*` pattern removed.
    pub filters: Vec<String>,
}
27
28/// Execute a parsed RESP command against the EventStore.
29///
30/// Returns `(response, Option<subscription>)` — subscription is `Some` only for
31/// SUBSCRIBE, giving the caller a broadcast receiver and filters.
32pub fn execute(
33    args: &[RespValue],
34    store: &Arc<EventStore>,
35) -> (RespValue, Option<SubscriptionInfo>) {
36    if args.is_empty() {
37        return (RespValue::err("empty command"), None);
38    }
39
40    let Some(s) = args[0].as_str() else {
41        return (RespValue::err("invalid command"), None);
42    };
43    let cmd = s.to_ascii_uppercase();
44
45    match cmd.as_str() {
46        "PING" => handle_ping(&args[1..]),
47        "XADD" => handle_xadd(&args[1..], store),
48        "XRANGE" => handle_xrange(&args[1..], store),
49        "XLEN" => handle_xlen(&args[1..], store),
50        "SUBSCRIBE" => handle_subscribe(&args[1..], store),
51        "COMMAND" => handle_command(&args[1..]),
52        "INFO" => handle_info(store),
53        "QUIT" => (RespValue::ok(), None),
54        _ => (RespValue::err(format!("unknown command '{cmd}'")), None),
55    }
56}
57
58// ── PING ────────────────────────────────────────────────────────────────────
59
60fn handle_ping(args: &[RespValue]) -> (RespValue, Option<SubscriptionInfo>) {
61    if let Some(msg) = args.first().and_then(|v| v.as_str()) {
62        (RespValue::bulk_string(msg), None)
63    } else {
64        (RespValue::SimpleString("PONG".to_string()), None)
65    }
66}
67
68// ── XADD ────────────────────────────────────────────────────────────────────
69// Usage: XADD <stream_key> * event_type <type> entity_id <id> [payload <json>] [metadata <json>]
70//
71// The stream_key is treated as the tenant_id. The `*` auto-generates an ID.
72// Field-value pairs must include `event_type` and `entity_id`.
73
74fn handle_xadd(
75    args: &[RespValue],
76    store: &Arc<EventStore>,
77) -> (RespValue, Option<SubscriptionInfo>) {
78    // Minimum: stream_key, id_arg, field, value (4 args for one field pair — but we need 2 fields minimum)
79    if args.len() < 6 {
80        return (
81            RespValue::err(
82                "wrong number of arguments for 'XADD' command. Usage: XADD <tenant> * event_type <type> entity_id <id> [payload <json>] [metadata <json>]",
83            ),
84            None,
85        );
86    }
87
88    let Some(s) = args[0].as_str() else {
89        return (RespValue::err("invalid stream key"), None);
90    };
91    let tenant_id = s.to_string();
92
93    // args[1] should be "*" or an explicit ID (we only support "*")
94    match args[1].as_str() {
95        Some("*") => {}
96        _ => {
97            return (
98                RespValue::err("only '*' auto-ID is supported for XADD"),
99                None,
100            );
101        }
102    }
103
104    // Parse field-value pairs
105    let pairs = &args[2..];
106    if !pairs.len().is_multiple_of(2) {
107        return (RespValue::err("odd number of field-value pairs"), None);
108    }
109
110    let mut event_type: Option<String> = None;
111    let mut entity_id: Option<String> = None;
112    let mut payload: serde_json::Value = serde_json::json!({});
113    let mut metadata: Option<serde_json::Value> = None;
114
115    for chunk in pairs.chunks(2) {
116        let Some(field) = chunk[0].as_str() else {
117            return (RespValue::err("field name must be a string"), None);
118        };
119        let Some(value) = chunk[1].as_str() else {
120            return (RespValue::err("field value must be a string"), None);
121        };
122
123        match field {
124            "event_type" => event_type = Some(value.to_string()),
125            "entity_id" => entity_id = Some(value.to_string()),
126            "payload" => {
127                payload = serde_json::from_str(value).unwrap_or_else(|_| {
128                    // If not valid JSON, wrap as a string value
129                    serde_json::Value::String(value.to_string())
130                });
131            }
132            "metadata" => {
133                metadata = Some(
134                    serde_json::from_str(value)
135                        .unwrap_or_else(|_| serde_json::Value::String(value.to_string())),
136                );
137            }
138            _ => {
139                // Unknown fields go into payload
140                if let serde_json::Value::Object(ref mut map) = payload {
141                    map.insert(
142                        field.to_string(),
143                        serde_json::Value::String(value.to_string()),
144                    );
145                }
146            }
147        }
148    }
149
150    let Some(event_type) = event_type else {
151        return (RespValue::err("missing required field 'event_type'"), None);
152    };
153    let Some(entity_id) = entity_id else {
154        return (RespValue::err("missing required field 'entity_id'"), None);
155    };
156
157    // Create and ingest the event
158    let event = match Event::from_strings(event_type, entity_id, tenant_id, payload, metadata) {
159        Ok(e) => e,
160        Err(e) => return (RespValue::err(format!("event creation failed: {e}")), None),
161    };
162
163    let event_id = event.id.to_string();
164    let timestamp = event.timestamp.timestamp_millis();
165
166    match store.ingest(&event) {
167        Ok(()) => {
168            // Return a Redis-style stream ID: "<timestamp>-0"
169            let stream_id = format!("{timestamp}-0");
170            (RespValue::bulk_string(&stream_id), None)
171        }
172        Err(e) => (RespValue::err(format!("ingest failed: {e}")), None),
173    }
174}
175
176// ── XRANGE ──────────────────────────────────────────────────────────────────
177// Usage: XRANGE <stream_key> <start> <end> [COUNT <n>]
178//
179// stream_key = tenant_id (or "entity:<id>" / "type:<type>" for filtered queries)
180// start/end: "-"/"+", or timestamps in milliseconds
181
182fn handle_xrange(
183    args: &[RespValue],
184    store: &Arc<EventStore>,
185) -> (RespValue, Option<SubscriptionInfo>) {
186    if args.len() < 3 {
187        return (
188            RespValue::err("wrong number of arguments for 'XRANGE' command"),
189            None,
190        );
191    }
192
193    let Some(stream_key) = args[0].as_str() else {
194        return (RespValue::err("invalid stream key"), None);
195    };
196    let Some(_start) = args[1].as_str() else {
197        return (RespValue::err("invalid start"), None);
198    };
199    let Some(_end) = args[2].as_str() else {
200        return (RespValue::err("invalid end"), None);
201    };
202
203    // Parse optional COUNT
204    let mut limit: Option<usize> = None;
205    let mut i = 3;
206    while i < args.len() {
207        if let Some(kw) = args[i].as_str()
208            && kw.eq_ignore_ascii_case("COUNT")
209            && i + 1 < args.len()
210        {
211            if let Some(n) = args[i + 1].as_str().and_then(|s| s.parse::<usize>().ok()) {
212                limit = Some(n);
213            }
214            i += 2;
215            continue;
216        }
217        i += 1;
218    }
219
220    // Parse stream_key into query filters
221    let mut request = QueryEventsRequest {
222        entity_id: None,
223        event_type: None,
224        tenant_id: None,
225        as_of: None,
226        since: None,
227        until: None,
228        limit,
229        event_type_prefix: None,
230        payload_filter: None,
231    };
232
233    if let Some(rest) = stream_key.strip_prefix("entity:") {
234        request.entity_id = Some(rest.to_string());
235    } else if let Some(rest) = stream_key.strip_prefix("type:") {
236        request.event_type = Some(rest.to_string());
237    } else {
238        request.tenant_id = Some(stream_key.to_string());
239    }
240
241    // Parse time range (start/end can be "-"/"+", or millisecond timestamps)
242    if _start != "-"
243        && let Ok(ms) = _start.split('-').next().unwrap_or(_start).parse::<i64>()
244    {
245        request.since = chrono::DateTime::from_timestamp_millis(ms);
246    }
247    if _end != "+"
248        && let Ok(ms) = _end.split('-').next().unwrap_or(_end).parse::<i64>()
249    {
250        request.until = chrono::DateTime::from_timestamp_millis(ms);
251    }
252
253    match store.query(&request) {
254        Ok(events) => {
255            // Return as array of [stream_id, [field, value, ...]] pairs (Redis XRANGE format)
256            let entries: Vec<RespValue> = events
257                .iter()
258                .map(|e| {
259                    let stream_id = format!("{}-0", e.timestamp.timestamp_millis());
260                    let fields = vec![
261                        RespValue::bulk_string("event_id"),
262                        RespValue::bulk_string(&e.id.to_string()),
263                        RespValue::bulk_string("event_type"),
264                        RespValue::bulk_string(e.event_type_str()),
265                        RespValue::bulk_string("entity_id"),
266                        RespValue::bulk_string(e.entity_id_str()),
267                        RespValue::bulk_string("tenant_id"),
268                        RespValue::bulk_string(e.tenant_id_str()),
269                        RespValue::bulk_string("payload"),
270                        RespValue::bulk_string(&e.payload.to_string()),
271                        RespValue::bulk_string("timestamp"),
272                        RespValue::bulk_string(&e.timestamp.to_rfc3339()),
273                    ];
274                    RespValue::Array(vec![
275                        RespValue::bulk_string(&stream_id),
276                        RespValue::Array(fields),
277                    ])
278                })
279                .collect();
280            (RespValue::Array(entries), None)
281        }
282        Err(e) => (RespValue::err(format!("query failed: {e}")), None),
283    }
284}
285
286// ── XLEN ────────────────────────────────────────────────────────────────────
287// Usage: XLEN <stream_key>
288
289fn handle_xlen(
290    args: &[RespValue],
291    store: &Arc<EventStore>,
292) -> (RespValue, Option<SubscriptionInfo>) {
293    if args.is_empty() {
294        return (
295            RespValue::err("wrong number of arguments for 'XLEN' command"),
296            None,
297        );
298    }
299
300    let Some(stream_key) = args[0].as_str() else {
301        return (RespValue::err("invalid stream key"), None);
302    };
303
304    // Query all events for this stream and return count
305    let mut request = QueryEventsRequest {
306        entity_id: None,
307        event_type: None,
308        tenant_id: None,
309        as_of: None,
310        since: None,
311        until: None,
312        limit: None,
313        event_type_prefix: None,
314        payload_filter: None,
315    };
316    if let Some(rest) = stream_key.strip_prefix("entity:") {
317        request.entity_id = Some(rest.to_string());
318    } else if let Some(rest) = stream_key.strip_prefix("type:") {
319        request.event_type = Some(rest.to_string());
320    } else {
321        request.tenant_id = Some(stream_key.to_string());
322    }
323
324    match store.query(&request) {
325        Ok(events) => (RespValue::Integer(events.len() as i64), None),
326        Err(e) => (RespValue::err(format!("query failed: {e}")), None),
327    }
328}
329
330// ── SUBSCRIBE ───────────────────────────────────────────────────────────────
331// Usage: SUBSCRIBE <pattern> [<pattern> ...]
332//
333// Subscribes to real-time event broadcasts with server-side prefix filtering.
334// Patterns use `prefix.*` syntax (e.g. `scheduler.*` matches `scheduler.started`).
335// `SUBSCRIBE *` receives all events (backwards compatible).
336
337fn handle_subscribe(
338    args: &[RespValue],
339    store: &Arc<EventStore>,
340) -> (RespValue, Option<SubscriptionInfo>) {
341    if args.is_empty() {
342        return (
343            RespValue::err("wrong number of arguments for 'SUBSCRIBE' command"),
344            None,
345        );
346    }
347
348    let ws_manager = store.websocket_manager();
349
350    // Get a broadcast receiver from the WebSocket manager's event channel
351    let rx = ws_manager.subscribe_events();
352
353    // Collect channel patterns as prefix filters
354    let filters: Vec<String> = args
355        .iter()
356        .filter_map(|a| a.as_str().map(String::from))
357        .filter(|f| f != "*") // "*" means all events — no filter needed
358        .collect();
359
360    // Build subscription confirmation per Redis protocol (one per channel)
361    let mut confirmations = Vec::new();
362    for (i, arg) in args.iter().enumerate() {
363        let channel = arg.as_str().unwrap_or("unknown");
364        confirmations.push(RespValue::Array(vec![
365            RespValue::bulk_string("subscribe"),
366            RespValue::bulk_string(channel),
367            RespValue::Integer((i + 1) as i64),
368        ]));
369    }
370
371    let sub_info = SubscriptionInfo { rx, filters };
372
373    // Return the first confirmation; the server loop will send the rest
374    // and then enter subscription mode
375    if confirmations.len() == 1 {
376        (confirmations.into_iter().next().unwrap(), Some(sub_info))
377    } else {
378        // For multiple channels, wrap all confirmations
379        // The server loop handles writing each one
380        (RespValue::Array(confirmations), Some(sub_info))
381    }
382}
383
384// ── COMMAND ─────────────────────────────────────────────────────────────────
385
386fn handle_command(args: &[RespValue]) -> (RespValue, Option<SubscriptionInfo>) {
387    // redis-cli sends `COMMAND DOCS` on connect — return empty array
388    (RespValue::Array(vec![]), None)
389}
390
391// ── INFO ────────────────────────────────────────────────────────────────────
392
393fn handle_info(store: &Arc<EventStore>) -> (RespValue, Option<SubscriptionInfo>) {
394    let info = format!(
395        "# Server\r\n\
396         redis_version:7.0.0-allsource\r\n\
397         allsource_version:{}\r\n\
398         # Keyspace\r\n",
399        env!("CARGO_PKG_VERSION"),
400    );
401    (RespValue::bulk_string(&info), None)
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh store for each test.
    fn make_store() -> Arc<EventStore> {
        Arc::new(EventStore::new())
    }

    /// Build a RESP command (array of bulk strings) from string parts.
    fn cmd(parts: &[&str]) -> Vec<RespValue> {
        parts.iter().map(|s| RespValue::bulk_string(s)).collect()
    }

    /// Return the `[field, value, ...]` array of the first XRANGE entry,
    /// panicking (instead of silently passing) when the reply has an
    /// unexpected shape.
    fn first_entry_fields(resp: &RespValue) -> &[RespValue] {
        let RespValue::Array(entries) = resp else {
            panic!("expected array of entries, got {resp:?}");
        };
        let Some(RespValue::Array(inner)) = entries.first() else {
            panic!("expected at least one [id, fields] entry, got {resp:?}");
        };
        // Each entry is [stream_id, [field, value, ...]]
        assert_eq!(inner.len(), 2);
        let RespValue::Array(fields) = &inner[1] else {
            panic!("expected fields array, got {:?}", inner[1]);
        };
        fields.as_slice()
    }

    /// Look up the value that follows `name` in a flat field/value list.
    fn field_value<'a>(fields: &'a [RespValue], name: &str) -> &'a str {
        let idx = fields
            .iter()
            .position(|f| f.as_str() == Some(name))
            .unwrap_or_else(|| panic!("field '{name}' not found"));
        fields[idx + 1]
            .as_str()
            .expect("field value should be a string")
    }

    #[test]
    fn test_ping() {
        let store = make_store();
        let (resp, sub) = execute(&cmd(&["PING"]), &store);
        assert_eq!(resp, RespValue::SimpleString("PONG".to_string()));
        assert!(sub.is_none());
    }

    #[test]
    fn test_ping_with_message() {
        let store = make_store();
        let (resp, _) = execute(&cmd(&["PING", "hello"]), &store);
        assert_eq!(resp, RespValue::bulk_string("hello"));
    }

    #[test]
    fn test_xadd_and_xrange() {
        let store = make_store();

        // XADD
        let (resp, _) = execute(
            &cmd(&[
                "XADD",
                "default",
                "*",
                "event_type",
                "user.created",
                "entity_id",
                "user-1",
                "payload",
                r#"{"name":"Alice"}"#,
            ]),
            &store,
        );
        // Should return a stream ID like "<timestamp>-0"
        assert!(resp.as_str().unwrap().ends_with("-0"), "got: {resp:?}");

        // XRANGE
        let (resp, _) = execute(&cmd(&["XRANGE", "default", "-", "+"]), &store);
        match &resp {
            RespValue::Array(entries) => assert_eq!(entries.len(), 1),
            other => panic!("expected array, got {other:?}"),
        }
        let fields = first_entry_fields(&resp);
        assert_eq!(field_value(fields, "event_type"), "user.created");
    }

    #[test]
    fn test_xadd_missing_fields() {
        let store = make_store();
        let (resp, _) = execute(&cmd(&["XADD", "default", "*"]), &store);
        match resp {
            RespValue::Error(_) => {}
            _ => panic!("expected error"),
        }
    }

    #[test]
    fn test_xlen() {
        let store = make_store();

        // Ingest 3 events
        for i in 0..3 {
            execute(
                &cmd(&[
                    "XADD",
                    "default",
                    "*",
                    "event_type",
                    "test.event",
                    "entity_id",
                    &format!("entity-{i}"),
                ]),
                &store,
            );
        }

        let (resp, _) = execute(&cmd(&["XLEN", "default"]), &store);
        assert_eq!(resp, RespValue::Integer(3));
    }

    #[test]
    fn test_xrange_with_count() {
        let store = make_store();

        for i in 0..5 {
            execute(
                &cmd(&[
                    "XADD",
                    "default",
                    "*",
                    "event_type",
                    "test.event",
                    "entity_id",
                    &format!("entity-{i}"),
                ]),
                &store,
            );
        }

        let (resp, _) = execute(&cmd(&["XRANGE", "default", "-", "+", "COUNT", "2"]), &store);
        match resp {
            RespValue::Array(entries) => assert_eq!(entries.len(), 2),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn test_xrange_entity_filter() {
        let store = make_store();

        execute(
            &cmd(&[
                "XADD",
                "default",
                "*",
                "event_type",
                "user.created",
                "entity_id",
                "user-1",
            ]),
            &store,
        );
        execute(
            &cmd(&[
                "XADD",
                "default",
                "*",
                "event_type",
                "order.placed",
                "entity_id",
                "order-1",
            ]),
            &store,
        );

        // Query by entity
        let (resp, _) = execute(&cmd(&["XRANGE", "entity:user-1", "-", "+"]), &store);
        match resp {
            RespValue::Array(entries) => assert_eq!(entries.len(), 1),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn test_subscribe() {
        let store = make_store();
        let (resp, sub) = execute(&cmd(&["SUBSCRIBE", "events"]), &store);
        assert!(sub.is_some(), "subscribe should return a receiver");
        // Confirmation message
        match resp {
            RespValue::Array(items) => {
                assert_eq!(items.len(), 3);
                assert_eq!(items[0].as_str(), Some("subscribe"));
                assert_eq!(items[1].as_str(), Some("events"));
            }
            _ => panic!("expected array confirmation"),
        }
    }

    #[test]
    fn test_subscribe_with_prefix_filters() {
        let store = make_store();
        let (_, sub) = execute(&cmd(&["SUBSCRIBE", "scheduler.*", "index.*"]), &store);
        let sub_info = sub.unwrap();
        assert_eq!(sub_info.filters, vec!["scheduler.*", "index.*"]);
    }

    #[test]
    fn test_subscribe_wildcard_has_no_filters() {
        let store = make_store();
        let (_, sub) = execute(&cmd(&["SUBSCRIBE", "*"]), &store);
        let sub_info = sub.unwrap();
        assert!(
            sub_info.filters.is_empty(),
            "wildcard should produce no filters"
        );
    }

    #[test]
    fn test_unknown_command() {
        let store = make_store();
        let (resp, _) = execute(&cmd(&["FLUSHALL"]), &store);
        match resp {
            RespValue::Error(e) => assert!(e.contains("unknown command")),
            _ => panic!("expected error"),
        }
    }

    #[test]
    fn test_info() {
        let store = make_store();
        let (resp, _) = execute(&cmd(&["INFO"]), &store);
        let s = resp.as_str().unwrap();
        assert!(s.contains("allsource_version"));
    }

    #[test]
    fn test_xadd_extra_fields_go_to_payload() {
        let store = make_store();
        let (resp, _) = execute(
            &cmd(&[
                "XADD",
                "default",
                "*",
                "event_type",
                "user.created",
                "entity_id",
                "user-1",
                "name",
                "Alice",
                "age",
                "30",
            ]),
            &store,
        );
        assert!(resp.as_str().unwrap().ends_with("-0"));

        // Query and check payload contains extra fields. The helpers panic on
        // an unexpected reply shape, so this test can no longer pass
        // vacuously.
        let (resp, _) = execute(&cmd(&["XRANGE", "default", "-", "+"]), &store);
        let fields = first_entry_fields(&resp);
        let payload_str = field_value(fields, "payload");
        let payload: serde_json::Value = serde_json::from_str(payload_str).unwrap();
        assert_eq!(payload["name"], "Alice");
        assert_eq!(payload["age"], "30");
    }
}
663}