Skip to main content

acuity_index_api_rs/
client.rs

1use crate::error::{IndexerApiError, ServerError, SubscriptionTerminated};
2use crate::types::{
3    EmptyPayload, Envelope, ErrorPayload, EventNotification, EventNotificationPayload, EventsResponse,
4    GetEventsPayload, Key, PalletMeta, RequestMessage, Span, StatusUpdate,
5    SubscribeEventsPayload, SubscriptionStatusPayload,
6    SubscriptionTerminatedPayload,
7};
8use futures::{SinkExt, StreamExt};
9use std::collections::HashMap;
10use std::sync::{
11    Arc,
12    atomic::{AtomicU64, Ordering},
13};
14use tokio::sync::{Mutex, mpsc, oneshot};
15use tokio_tungstenite::{connect_async, tungstenite::Message};
16
17type PendingSender = oneshot::Sender<Result<Envelope, IndexerApiError>>;
18type StatusSubscribers = Arc<Mutex<HashMap<u64, mpsc::Sender<Result<StatusUpdate, IndexerApiError>>>>>;
19type EventSubscribers = Arc<Mutex<HashMap<u64, EventSubscriber>>>;
20
/// Registry entry for one event subscription handle, stored in `EventSubscribers`.
#[derive(Clone)]
struct EventSubscriber {
    /// Key this subscription was created for; incoming notifications are forwarded only when
    /// their key compares equal to this one.
    key: Key,
    /// Sending half of the subscription's channel; the receiving half lives in `EventSubscription`.
    sender: mpsc::Sender<Result<EventNotification, IndexerApiError>>,
}
26
27#[derive(Clone)]
28pub struct IndexerClient {
29    writer: Arc<Mutex<futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
30    pending: Arc<Mutex<HashMap<u64, PendingSender>>>,
31    status_subscribers: StatusSubscribers,
32    event_subscribers: EventSubscribers,
33    next_id: Arc<AtomicU64>,
34}
35
36pub struct StatusSubscription {
37    client: IndexerClient,
38    id: u64,
39    receiver: mpsc::Receiver<Result<StatusUpdate, IndexerApiError>>,
40}
41
42pub struct EventSubscription {
43    client: IndexerClient,
44    id: u64,
45    key: Key,
46    receiver: mpsc::Receiver<Result<EventNotification, IndexerApiError>>,
47}
48
49impl IndexerClient {
50    pub async fn connect(url: &str) -> Result<Self, IndexerApiError> {
51        let (stream, _) = connect_async(url).await?;
52        let (writer, reader) = stream.split();
53
54        let client = Self {
55            writer: Arc::new(Mutex::new(writer)),
56            pending: Arc::new(Mutex::new(HashMap::new())),
57            status_subscribers: Arc::new(Mutex::new(HashMap::new())),
58            event_subscribers: Arc::new(Mutex::new(HashMap::new())),
59            next_id: Arc::new(AtomicU64::new(1)),
60        };
61
62        tokio::spawn(run_reader(
63            reader,
64            Arc::clone(&client.pending),
65            Arc::clone(&client.status_subscribers),
66            Arc::clone(&client.event_subscribers),
67        ));
68
69        Ok(client)
70    }
71
72    pub async fn close(&self) -> Result<(), IndexerApiError> {
73        self.writer.lock().await.close().await?;
74        Ok(())
75    }
76
77    pub async fn status(&self) -> Result<Vec<Span>, IndexerApiError> {
78        let envelope = self.request("Status", EmptyPayload::default()).await?;
79        expect_payload::<Vec<Span>>(envelope, "status")
80    }
81
82    pub async fn variants(&self) -> Result<Vec<PalletMeta>, IndexerApiError> {
83        let envelope = self.request("Variants", EmptyPayload::default()).await?;
84        expect_payload::<Vec<PalletMeta>>(envelope, "variants")
85    }
86
87    pub async fn size_on_disk(&self) -> Result<u64, IndexerApiError> {
88        let envelope = self.request("SizeOnDisk", EmptyPayload::default()).await?;
89        expect_payload::<u64>(envelope, "sizeOnDisk")
90    }
91
92    pub async fn get_events(
93        &self,
94        key: Key,
95        limit: Option<u16>,
96        before: Option<crate::types::EventRef>,
97    ) -> Result<EventsResponse, IndexerApiError> {
98        let envelope = self
99            .request("GetEvents", GetEventsPayload { key, limit, before })
100            .await?;
101        expect_payload::<EventsResponse>(envelope, "events")
102    }
103
104    pub async fn subscribe_status(&self) -> Result<StatusSubscription, IndexerApiError> {
105        let (tx, rx) = mpsc::channel(32);
106        let subscription_id = self.next_id.fetch_add(1, Ordering::Relaxed);
107        self.status_subscribers.lock().await.insert(subscription_id, tx);
108
109        let envelope = self.request("SubscribeStatus", EmptyPayload::default()).await?;
110        let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
111
112        Ok(StatusSubscription {
113            client: self.clone(),
114            id: subscription_id,
115            receiver: rx,
116        })
117    }
118
119    pub async fn unsubscribe_status(&self) -> Result<(), IndexerApiError> {
120        let envelope = self.request("UnsubscribeStatus", EmptyPayload::default()).await?;
121        let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
122        self.status_subscribers.lock().await.clear();
123        Ok(())
124    }
125
126    pub async fn subscribe_events(&self, key: Key) -> Result<EventSubscription, IndexerApiError> {
127        let (tx, rx) = mpsc::channel(32);
128        let subscription_id = self.next_id.fetch_add(1, Ordering::Relaxed);
129        self.event_subscribers.lock().await.insert(
130            subscription_id,
131            EventSubscriber {
132                key: key.clone(),
133                sender: tx,
134            },
135        );
136
137        let envelope = self
138            .request("SubscribeEvents", SubscribeEventsPayload { key: key.clone() })
139            .await?;
140        let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
141
142        Ok(EventSubscription {
143            client: self.clone(),
144            id: subscription_id,
145            key,
146            receiver: rx,
147        })
148    }
149
150    pub async fn unsubscribe_events(&self, key: Key) -> Result<(), IndexerApiError> {
151        let envelope = self
152            .request("UnsubscribeEvents", SubscribeEventsPayload { key: key.clone() })
153            .await?;
154        let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
155        self.event_subscribers
156            .lock()
157            .await
158            .retain(|_, subscriber| subscriber.key != key);
159        Ok(())
160    }
161
162    async fn unregister_status_subscription(&self, id: u64) {
163        self.status_subscribers.lock().await.remove(&id);
164    }
165
166    async fn unregister_event_subscription(&self, id: u64) {
167        self.event_subscribers.lock().await.remove(&id);
168    }
169
170    async fn request<T>(&self, message_type: &'static str, payload: T) -> Result<Envelope, IndexerApiError>
171    where
172        T: serde::Serialize,
173    {
174        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
175        let request = RequestMessage {
176            id,
177            message_type,
178            payload,
179        };
180        let json = serde_json::to_string(&request)?;
181        let (tx, rx) = oneshot::channel();
182        self.pending.lock().await.insert(id, tx);
183
184        let send_result = self.writer.lock().await.send(Message::Text(json.into())).await;
185        if let Err(error) = send_result {
186            self.pending.lock().await.remove(&id);
187            return Err(error.into());
188        }
189
190        match rx.await {
191            Ok(result) => result,
192            Err(_) => Err(IndexerApiError::ResponseChannelClosed { request_id: id }),
193        }
194    }
195}
196
197impl StatusSubscription {
198    pub async fn next(&mut self) -> Option<Result<StatusUpdate, IndexerApiError>> {
199        self.receiver.recv().await
200    }
201
202    pub async fn unsubscribe(self) -> Result<(), IndexerApiError> {
203        let client = self.client.clone();
204        let id = self.id;
205        let result = client.unsubscribe_status().await;
206        client.unregister_status_subscription(id).await;
207        result
208    }
209}
210
211impl EventSubscription {
212    pub async fn next(&mut self) -> Option<Result<EventNotification, IndexerApiError>> {
213        self.receiver.recv().await
214    }
215
216    pub async fn unsubscribe(self) -> Result<(), IndexerApiError> {
217        let client = self.client.clone();
218        let id = self.id;
219        let key = self.key.clone();
220        let result = client.unsubscribe_events(key).await;
221        client.unregister_event_subscription(id).await;
222        result
223    }
224}
225
226impl Drop for StatusSubscription {
227    fn drop(&mut self) {
228        let client = self.client.clone();
229        let id = self.id;
230        tokio::spawn(async move {
231            client.unregister_status_subscription(id).await;
232        });
233    }
234}
235
236impl Drop for EventSubscription {
237    fn drop(&mut self) {
238        let client = self.client.clone();
239        let id = self.id;
240        tokio::spawn(async move {
241            client.unregister_event_subscription(id).await;
242        });
243    }
244}
245
246async fn run_reader(
247    mut reader: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
248    pending: Arc<Mutex<HashMap<u64, PendingSender>>>,
249    status_subscribers: StatusSubscribers,
250    event_subscribers: EventSubscribers,
251) {
252    while let Some(message) = reader.next().await {
253        match handle_message(message, &pending, &status_subscribers, &event_subscribers).await {
254            Ok(()) => {}
255            Err(error) => {
256                fail_all_pending(&pending, &error).await;
257                broadcast_status_error(&status_subscribers, &error).await;
258                broadcast_event_error(&event_subscribers, &error).await;
259                return;
260            }
261        }
262    }
263
264    let error = IndexerApiError::ConnectionClosed;
265    fail_all_pending(&pending, &error).await;
266    broadcast_status_error(&status_subscribers, &error).await;
267    broadcast_event_error(&event_subscribers, &error).await;
268}
269
270async fn handle_message(
271    message: Result<Message, tokio_tungstenite::tungstenite::Error>,
272    pending: &Arc<Mutex<HashMap<u64, PendingSender>>>,
273    status_subscribers: &StatusSubscribers,
274    event_subscribers: &EventSubscribers,
275) -> Result<(), IndexerApiError> {
276    let payload = match message? {
277        Message::Text(text) => text.to_string(),
278        Message::Binary(bytes) => {
279            String::from_utf8(bytes.to_vec()).map_err(|_| IndexerApiError::NonUtf8Binary)?
280        }
281        Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => return Ok(()),
282        Message::Close(_) => return Err(IndexerApiError::ConnectionClosed),
283    };
284
285    let envelope: Envelope = serde_json::from_str(&payload)?;
286
287    if let Some(id) = envelope.id {
288        if let Some(sender) = pending.lock().await.remove(&id) {
289            let result = if envelope.message_type == "error" {
290                let error = parse_server_error(&envelope)?;
291                Err(error.into())
292            } else {
293                Ok(envelope)
294            };
295            let _ = sender.send(result);
296            return Ok(());
297        }
298    }
299
300    match envelope.message_type.as_str() {
301        "status" => {
302            let spans = envelope_data::<Vec<Span>>(&envelope)?;
303            broadcast_status_update(status_subscribers, StatusUpdate { spans }).await;
304        }
305        "eventNotification" => {
306            let payload = envelope_data::<EventNotificationPayload>(&envelope)?;
307            broadcast_event_update(
308                event_subscribers,
309                EventNotification {
310                    key: payload.key,
311                    event: payload.event,
312                    decoded_event: payload.decoded_event,
313                },
314            )
315            .await;
316        }
317        "subscriptionTerminated" => {
318            let termination = envelope_data::<SubscriptionTerminatedPayload>(&envelope)?;
319            let subscription_error = SubscriptionTerminated {
320                reason: termination.reason,
321                message: termination.message,
322            };
323            let status_error = IndexerApiError::StatusSubscriptionTerminated {
324                reason: subscription_error.reason.clone(),
325                message: subscription_error.message.clone(),
326            };
327            let event_error = IndexerApiError::EventSubscriptionTerminated {
328                reason: subscription_error.reason,
329                message: subscription_error.message,
330            };
331            broadcast_status_error(status_subscribers, &status_error).await;
332            broadcast_event_error(event_subscribers, &event_error).await;
333        }
334        "error" => {
335            let error = parse_server_error(&envelope)?;
336            let error = IndexerApiError::from(error);
337            broadcast_status_error(status_subscribers, &error).await;
338            broadcast_event_error(event_subscribers, &error).await;
339        }
340        _ => {}
341    }
342
343    Ok(())
344}
345
346fn expect_payload<T>(envelope: Envelope, expected_type: &'static str) -> Result<T, IndexerApiError>
347where
348    T: for<'de> serde::Deserialize<'de>,
349{
350    if envelope.message_type != expected_type {
351        return Err(IndexerApiError::UnexpectedResponseType {
352            request_id: envelope.id.unwrap_or_default(),
353            message_type: envelope.message_type,
354        });
355    }
356
357    envelope_data(&envelope)
358}
359
360fn envelope_data<T>(envelope: &Envelope) -> Result<T, IndexerApiError>
361where
362    T: for<'de> serde::Deserialize<'de>,
363{
364    serde_json::from_value(
365        envelope
366            .data
367            .clone()
368            .ok_or(IndexerApiError::Json(serde_json::Error::io(std::io::Error::new(
369                std::io::ErrorKind::InvalidData,
370                "missing data field",
371            ))))?,
372    )
373    .map_err(IndexerApiError::from)
374}
375
376fn parse_server_error(envelope: &Envelope) -> Result<ServerError, IndexerApiError> {
377    let payload = envelope_data::<ErrorPayload>(envelope)?;
378    Ok(ServerError {
379        code: payload.code,
380        message: payload.message,
381    })
382}
383
384async fn fail_all_pending(
385    pending: &Arc<Mutex<HashMap<u64, PendingSender>>>,
386    error: &IndexerApiError,
387) {
388    let mut pending = pending.lock().await;
389    for (request_id, sender) in pending.drain() {
390        let _ = sender.send(Err(match error {
391            IndexerApiError::ConnectionClosed => IndexerApiError::RequestCancelled { request_id },
392            _ => IndexerApiError::BackgroundTaskEnded,
393        }));
394    }
395}
396
397async fn broadcast_status_update(
398    subscribers: &StatusSubscribers,
399    update: StatusUpdate,
400) {
401    let mut subscribers = subscribers.lock().await;
402    let ids: Vec<u64> = subscribers.keys().copied().collect();
403    for id in ids {
404        let Some(subscriber) = subscribers.get(&id).cloned() else {
405            continue;
406        };
407        if subscriber.send(Ok(update.clone())).await.is_err() {
408            subscribers.remove(&id);
409        }
410    }
411}
412
413async fn broadcast_event_update(
414    subscribers: &EventSubscribers,
415    update: EventNotification,
416) {
417    let mut subscribers = subscribers.lock().await;
418    let ids: Vec<u64> = subscribers.keys().copied().collect();
419    for id in ids {
420        let Some(subscriber) = subscribers.get(&id).cloned() else {
421            continue;
422        };
423        if subscriber.key == update.key && subscriber.sender.send(Ok(update.clone())).await.is_err() {
424            subscribers.remove(&id);
425        }
426    }
427}
428
429async fn broadcast_status_error(
430    subscribers: &StatusSubscribers,
431    error: &IndexerApiError,
432) {
433    let mut subscribers = subscribers.lock().await;
434    let ids: Vec<u64> = subscribers.keys().copied().collect();
435    for id in ids {
436        let Some(subscriber) = subscribers.get(&id).cloned() else {
437            continue;
438        };
439        if subscriber.send(Err(clone_error(error))).await.is_err() {
440            subscribers.remove(&id);
441        }
442    }
443}
444
445async fn broadcast_event_error(
446    subscribers: &EventSubscribers,
447    error: &IndexerApiError,
448) {
449    let mut subscribers = subscribers.lock().await;
450    let ids: Vec<u64> = subscribers.keys().copied().collect();
451    for id in ids {
452        let Some(subscriber) = subscribers.get(&id).cloned() else {
453            continue;
454        };
455        if subscriber.sender.send(Err(clone_error(error))).await.is_err() {
456            subscribers.remove(&id);
457        }
458    }
459}
460
461    fn clone_error(error: &IndexerApiError) -> IndexerApiError {
462    match error {
463        IndexerApiError::Url(error) => IndexerApiError::Url(*error),
464        IndexerApiError::WebSocket(_) => IndexerApiError::BackgroundTaskEnded,
465        IndexerApiError::Json(error) => IndexerApiError::Json(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, error.to_string()))),
466        IndexerApiError::RequestCancelled { request_id } => IndexerApiError::RequestCancelled { request_id: *request_id },
467        IndexerApiError::ResponseChannelClosed { request_id } => IndexerApiError::ResponseChannelClosed { request_id: *request_id },
468        IndexerApiError::Server { code, message } => IndexerApiError::Server { code: code.clone(), message: message.clone() },
469        IndexerApiError::StatusSubscriptionTerminated { reason, message } => IndexerApiError::StatusSubscriptionTerminated { reason: reason.clone(), message: message.clone() },
470        IndexerApiError::EventSubscriptionTerminated { reason, message } => IndexerApiError::EventSubscriptionTerminated { reason: reason.clone(), message: message.clone() },
471        IndexerApiError::UnexpectedResponseType { request_id, message_type } => IndexerApiError::UnexpectedResponseType { request_id: *request_id, message_type: message_type.clone() },
472        IndexerApiError::NonUtf8Binary => IndexerApiError::NonUtf8Binary,
473        IndexerApiError::ConnectionClosed => IndexerApiError::ConnectionClosed,
474        IndexerApiError::BackgroundTaskEnded => IndexerApiError::BackgroundTaskEnded,
475    }
476    }
477
478    #[cfg(test)]
479    mod tests {
480    use super::*;
481    use crate::types::{
482        CustomKey, CustomScalarValue, CustomValue, DecodedEvent, Envelope, EventRef,
483    };
484    use serde_json::json;
485    use tokio::net::TcpListener;
486    use tokio::sync::mpsc;
487    use tokio_tungstenite::accept_async;
488
489    fn custom_u32_key(name: &str, value: u32) -> Key {
490        Key::Custom(CustomKey {
491            name: name.into(),
492            value: CustomValue::U32(value),
493        })
494    }
495
496    fn composite_key(name: &str, bytes: u8, value: u32) -> Key {
497        Key::Custom(CustomKey {
498            name: name.into(),
499            value: CustomValue::Composite(vec![
500                CustomScalarValue::Bytes32(crate::types::Bytes32([bytes; 32])),
501                CustomScalarValue::U32(value),
502            ]),
503        })
504    }
505
506    #[test]
507    fn parses_status_payload() {
508        let envelope = Envelope {
509            id: Some(2),
510            message_type: "status".into(),
511            data: Some(json!([{"start": 1, "end": 8}])),
512        };
513
514        let spans = expect_payload::<Vec<Span>>(envelope, "status").unwrap();
515        assert_eq!(spans, vec![Span { start: 1, end: 8 }]);
516    }
517
518    #[test]
519    fn expect_payload_rejects_unexpected_response_type() {
520        let envelope = Envelope {
521            id: Some(2),
522            message_type: "variants".into(),
523            data: Some(json!([])),
524        };
525
526        let error = expect_payload::<Vec<Span>>(envelope, "status").unwrap_err();
527        match error {
528            IndexerApiError::UnexpectedResponseType {
529                request_id,
530                message_type,
531            } => {
532                assert_eq!(request_id, 2);
533                assert_eq!(message_type, "variants");
534            }
535            _ => panic!("unexpected error variant"),
536        }
537    }
538
539    #[test]
540    fn envelope_data_rejects_missing_data() {
541        let envelope = Envelope {
542            id: Some(2),
543            message_type: "status".into(),
544            data: None,
545        };
546
547        let error = envelope_data::<Vec<Span>>(&envelope).unwrap_err();
548        assert!(error.to_string().contains("missing data field"));
549    }
550
551    #[test]
552    fn parses_events_payload() {
553        let envelope = Envelope {
554            id: Some(3),
555            message_type: "events".into(),
556            data: Some(json!({
557                "key": {"type": "Custom", "value": {"name": "ref_index", "kind": "u32", "value": 42}},
558                "events": [{"blockNumber": 50, "eventIndex": 3}],
559                "decodedEvents": [{
560                    "blockNumber": 50,
561                    "eventIndex": 3,
562                    "event": {
563                        "specVersion": 1234,
564                        "palletName": "Referenda",
565                        "eventName": "Submitted",
566                        "palletIndex": 42,
567                        "variantIndex": 0,
568                        "eventIndex": 3,
569                        "fields": {"index": 42}
570                    }
571                }]
572            })),
573        };
574
575        let response = expect_payload::<EventsResponse>(envelope, "events").unwrap();
576        assert_eq!(response.events.len(), 1);
577        assert_eq!(response.decoded_events.len(), 1);
578    }
579
580    #[test]
581    fn parses_server_error_payload() {
582        let envelope = Envelope {
583            id: Some(9),
584            message_type: "error".into(),
585            data: Some(json!({"code": "invalid_request", "message": "missing field `id`"})),
586        };
587
588        let error = parse_server_error(&envelope).unwrap();
589        assert_eq!(error.code, "invalid_request");
590        assert_eq!(error.message, "missing field `id`");
591    }
592
593    #[test]
594    fn handle_message_routes_response_to_matching_pending_request() {
595        let runtime = tokio::runtime::Builder::new_current_thread()
596            .enable_all()
597            .build()
598            .unwrap();
599
600        runtime.block_on(async {
601            let pending = Arc::new(Mutex::new(HashMap::new()));
602            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
603            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
604            let (tx, rx) = oneshot::channel();
605            pending.lock().await.insert(7, tx);
606
607            handle_message(
608                Ok(Message::Text(
609                    serde_json::to_string(&json!({
610                        "id": 7,
611                        "type": "status",
612                        "data": [{"start": 1, "end": 9}]
613                    }))
614                    .unwrap()
615                    .into(),
616                )),
617                &pending,
618                &status_subscribers,
619                &event_subscribers,
620            )
621            .await
622            .unwrap();
623
624            let response = rx.await.unwrap().unwrap();
625            assert_eq!(response.id, Some(7));
626            assert_eq!(response.message_type, "status");
627        });
628    }
629
630    #[test]
631    fn handle_message_routes_server_error_to_matching_pending_request() {
632        let runtime = tokio::runtime::Builder::new_current_thread()
633            .enable_all()
634            .build()
635            .unwrap();
636
637        runtime.block_on(async {
638            let pending = Arc::new(Mutex::new(HashMap::new()));
639            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
640            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
641            let (tx, rx) = oneshot::channel();
642            pending.lock().await.insert(9, tx);
643
644            handle_message(
645                Ok(Message::Text(
646                    serde_json::to_string(&json!({
647                        "id": 9,
648                        "type": "error",
649                        "data": {"code": "invalid_request", "message": "missing field `id`"}
650                    }))
651                    .unwrap()
652                    .into(),
653                )),
654                &pending,
655                &status_subscribers,
656                &event_subscribers,
657            )
658            .await
659            .unwrap();
660
661            let error = rx.await.unwrap().unwrap_err();
662            match error {
663                IndexerApiError::Server { code, message } => {
664                    assert_eq!(code, "invalid_request");
665                    assert_eq!(message, "missing field `id`");
666                }
667                _ => panic!("unexpected error variant"),
668            }
669        });
670    }
671
672    #[test]
673    fn handle_message_broadcasts_status_update_to_subscribers() {
674        let runtime = tokio::runtime::Builder::new_current_thread()
675            .enable_all()
676            .build()
677            .unwrap();
678
679        runtime.block_on(async {
680            let pending = Arc::new(Mutex::new(HashMap::new()));
681            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
682            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
683            let (tx, mut rx) = mpsc::channel(1);
684            status_subscribers.lock().await.insert(1, tx);
685
686            handle_message(
687                Ok(Message::Text(
688                    serde_json::to_string(&json!({
689                        "type": "status",
690                        "data": [{"start": 1, "end": 8}]
691                    }))
692                    .unwrap()
693                    .into(),
694                )),
695                &pending,
696                &status_subscribers,
697                &event_subscribers,
698            )
699            .await
700            .unwrap();
701
702            let update = rx.recv().await.unwrap().unwrap();
703            assert_eq!(update, StatusUpdate { spans: vec![Span { start: 1, end: 8 }] });
704        });
705    }
706
707    #[test]
708    fn handle_message_broadcasts_event_notification_to_subscribers() {
709        let runtime = tokio::runtime::Builder::new_current_thread()
710            .enable_all()
711            .build()
712            .unwrap();
713
714        runtime.block_on(async {
715            let pending = Arc::new(Mutex::new(HashMap::new()));
716            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
717            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
718            let (tx, mut rx) = mpsc::channel(1);
719            event_subscribers.lock().await.insert(1, EventSubscriber { key: custom_u32_key("ref_index", 42), sender: tx });
720
721            handle_message(
722                Ok(Message::Text(
723                    serde_json::to_string(&json!({
724                        "type": "eventNotification",
725                        "data": {
726                            "key": {"type": "Custom", "value": {"name": "ref_index", "kind": "u32", "value": 42}},
727                            "event": {"blockNumber": 50, "eventIndex": 3},
728                            "decodedEvent": null
729                        }
730                    }))
731                    .unwrap()
732                    .into(),
733                )),
734                &pending,
735                &status_subscribers,
736                &event_subscribers,
737            )
738            .await
739            .unwrap();
740
741            let update = rx.recv().await.unwrap().unwrap();
742            assert_eq!(update.key, custom_u32_key("ref_index", 42));
743            assert_eq!(update.event, EventRef { block_number: 50, event_index: 3 });
744            assert!(update.decoded_event.is_none());
745        });
746    }
747
748    #[test]
749    fn handle_message_broadcasts_subscription_termination_to_subscribers() {
750        let runtime = tokio::runtime::Builder::new_current_thread()
751            .enable_all()
752            .build()
753            .unwrap();
754
755        runtime.block_on(async {
756            let pending = Arc::new(Mutex::new(HashMap::new()));
757            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
758            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
759            let (status_tx, mut status_rx) = mpsc::channel(1);
760            let (event_tx, mut event_rx) = mpsc::channel(1);
761            status_subscribers.lock().await.insert(1, status_tx);
762            event_subscribers.lock().await.insert(2, EventSubscriber { key: custom_u32_key("ref_index", 42), sender: event_tx });
763
764            handle_message(
765                Ok(Message::Text(
766                    serde_json::to_string(&json!({
767                        "type": "subscriptionTerminated",
768                        "data": {
769                            "reason": "backpressure",
770                            "message": "subscriber disconnected due to backpressure"
771                        }
772                    }))
773                    .unwrap()
774                    .into(),
775                )),
776                &pending,
777                &status_subscribers,
778                &event_subscribers,
779            )
780            .await
781            .unwrap();
782
783            match status_rx.recv().await.unwrap().unwrap_err() {
784                IndexerApiError::StatusSubscriptionTerminated { reason, message } => {
785                    assert_eq!(reason, "backpressure");
786                    assert_eq!(message, "subscriber disconnected due to backpressure");
787                }
788                _ => panic!("unexpected status error variant"),
789            }
790
791            match event_rx.recv().await.unwrap().unwrap_err() {
792                IndexerApiError::EventSubscriptionTerminated { reason, message } => {
793                    assert_eq!(reason, "backpressure");
794                    assert_eq!(message, "subscriber disconnected due to backpressure");
795                }
796                _ => panic!("unexpected event error variant"),
797            }
798        });
799    }
800
801    #[test]
802    fn handle_message_rejects_invalid_binary_payload() {
803        let runtime = tokio::runtime::Builder::new_current_thread()
804            .enable_all()
805            .build()
806            .unwrap();
807
808        runtime.block_on(async {
809            let pending = Arc::new(Mutex::new(HashMap::new()));
810            let status_subscribers = Arc::new(Mutex::new(HashMap::new()));
811            let event_subscribers = Arc::new(Mutex::new(HashMap::new()));
812
813            let error = handle_message(
814                Ok(Message::Binary(vec![0xFF, 0xFE].into())),
815                &pending,
816                &status_subscribers,
817                &event_subscribers,
818            )
819            .await
820            .unwrap_err();
821
822            assert!(matches!(error, IndexerApiError::NonUtf8Binary));
823        });
824    }
825
826    #[test]
827    fn fail_all_pending_marks_connection_closed_requests_as_cancelled() {
828        let runtime = tokio::runtime::Builder::new_current_thread()
829            .enable_all()
830            .build()
831            .unwrap();
832
833        runtime.block_on(async {
834            let pending = Arc::new(Mutex::new(HashMap::new()));
835            let (tx, rx) = oneshot::channel();
836            pending.lock().await.insert(12, tx);
837
838            fail_all_pending(&pending, &IndexerApiError::ConnectionClosed).await;
839
840            match rx.await.unwrap().unwrap_err() {
841                IndexerApiError::RequestCancelled { request_id } => assert_eq!(request_id, 12),
842                _ => panic!("unexpected error variant"),
843            }
844        });
845    }
846
847    #[test]
848    fn fail_all_pending_marks_other_failures_as_background_task_ended() {
849        let runtime = tokio::runtime::Builder::new_current_thread()
850            .enable_all()
851            .build()
852            .unwrap();
853
854        runtime.block_on(async {
855            let pending = Arc::new(Mutex::new(HashMap::new()));
856            let (tx, rx) = oneshot::channel();
857            pending.lock().await.insert(13, tx);
858
859            fail_all_pending(
860                &pending,
861                &IndexerApiError::Server {
862                    code: "internal_error".into(),
863                    message: "boom".into(),
864                },
865            )
866            .await;
867
868            assert!(matches!(
869                rx.await.unwrap().unwrap_err(),
870                IndexerApiError::BackgroundTaskEnded
871            ));
872        });
873    }
874
875    #[test]
876    fn clone_error_preserves_server_payload() {
877        let cloned = clone_error(&IndexerApiError::Server {
878            code: "invalid_request".into(),
879            message: "missing field `id`".into(),
880        });
881
882        match cloned {
883            IndexerApiError::Server { code, message } => {
884                assert_eq!(code, "invalid_request");
885                assert_eq!(message, "missing field `id`");
886            }
887            _ => panic!("unexpected error variant"),
888        }
889    }
890
891    #[test]
892    fn event_notification_payload_matches_server_shape() {
893        let payload = serde_json::from_value::<EventNotificationPayload>(json!({
894            "key": {"type": "Custom", "value": {"name": "item_id", "kind": "bytes32", "value": format!("0x{}", "11".repeat(32))}},
895            "event": {"blockNumber": 50, "eventIndex": 3},
896            "decodedEvent": {
897                "blockNumber": 50,
898                "eventIndex": 3,
899                "event": {
900                    "specVersion": 1234,
901                    "palletName": "Content",
902                    "eventName": "PublishRevision",
903                    "palletIndex": 42,
904                    "variantIndex": 1,
905                    "eventIndex": 3,
906                    "fields": {}
907                }
908            }
909        }))
910        .unwrap();
911
912        assert_eq!(payload.event, EventRef { block_number: 50, event_index: 3 });
913        assert_eq!(payload.key, Key::Custom(CustomKey { name: "item_id".into(), value: CustomValue::Bytes32(crate::types::Bytes32([0x11; 32])) }));
914        assert_eq!(
915            payload.decoded_event,
916            Some(DecodedEvent {
917                block_number: 50,
918                event_index: 3,
919                event: crate::types::StoredEvent {
920                    spec_version: 1234,
921                    pallet_name: "Content".into(),
922                    event_name: "PublishRevision".into(),
923                    pallet_index: 42,
924                    variant_index: 1,
925                    event_index: 3,
926                    fields: json!({}),
927                },
928            })
929        );
930    }
931
932    #[test]
933    fn broadcast_event_update_only_notifies_matching_keys() {
934        let runtime = tokio::runtime::Builder::new_current_thread()
935            .enable_all()
936            .build()
937            .unwrap();
938
939        runtime.block_on(async {
940            let subscribers = Arc::new(Mutex::new(HashMap::new()));
941            let (match_tx, mut match_rx) = mpsc::channel(1);
942            let (other_tx, mut other_rx) = mpsc::channel(1);
943
944            subscribers.lock().await.insert(
945                1,
946                EventSubscriber {
947                    key: custom_u32_key("ref_index", 42),
948                    sender: match_tx,
949                },
950            );
951            subscribers.lock().await.insert(
952                2,
953                EventSubscriber {
954                    key: custom_u32_key("ref_index", 7),
955                    sender: other_tx,
956                },
957            );
958
959            broadcast_event_update(
960                &subscribers,
961                EventNotification {
962                    key: custom_u32_key("ref_index", 42),
963                    event: EventRef {
964                        block_number: 10,
965                        event_index: 1,
966                    },
967                    decoded_event: None,
968                },
969            )
970            .await;
971
972            assert!(match_rx.recv().await.is_some());
973            assert!(other_rx.try_recv().is_err());
974        });
975    }
976
977    #[test]
978    fn broadcast_event_update_matches_composite_keys() {
979        let runtime = tokio::runtime::Builder::new_current_thread()
980            .enable_all()
981            .build()
982            .unwrap();
983
984        runtime.block_on(async {
985            let subscribers = Arc::new(Mutex::new(HashMap::new()));
986            let (match_tx, mut match_rx) = mpsc::channel(1);
987            let (other_tx, mut other_rx) = mpsc::channel(1);
988
989            subscribers.lock().await.insert(
990                1,
991                EventSubscriber {
992                    key: composite_key("item_revision", 0x11, 7),
993                    sender: match_tx,
994                },
995            );
996            subscribers.lock().await.insert(
997                2,
998                EventSubscriber {
999                    key: composite_key("item_revision", 0x11, 8),
1000                    sender: other_tx,
1001                },
1002            );
1003
1004            broadcast_event_update(
1005                &subscribers,
1006                EventNotification {
1007                    key: composite_key("item_revision", 0x11, 7),
1008                    event: EventRef {
1009                        block_number: 10,
1010                        event_index: 1,
1011                    },
1012                    decoded_event: None,
1013                },
1014            )
1015            .await;
1016
1017            assert!(match_rx.recv().await.is_some());
1018            assert!(other_rx.try_recv().is_err());
1019        });
1020    }
1021
1022    #[test]
1023    fn clone_error_preserves_response_channel_closed_payload() {
1024        let cloned = clone_error(&IndexerApiError::ResponseChannelClosed { request_id: 44 });
1025
1026        match cloned {
1027            IndexerApiError::ResponseChannelClosed { request_id } => assert_eq!(request_id, 44),
1028            _ => panic!("unexpected error variant"),
1029        }
1030    }
1031
1032    #[tokio::test(flavor = "current_thread")]
1033    async fn close_sends_websocket_close_frame() {
1034        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1035        let addr = listener.local_addr().unwrap();
1036
1037        let server = tokio::spawn(async move {
1038            let (stream, _) = listener.accept().await.unwrap();
1039            let mut websocket = accept_async(stream).await.unwrap();
1040
1041            match websocket.next().await {
1042                Some(Ok(Message::Close(_))) => {}
1043                other => panic!("expected websocket close frame, got {other:?}"),
1044            }
1045        });
1046
1047        let client = IndexerClient::connect(&format!("ws://{addr}")).await.unwrap();
1048        client.close().await.unwrap();
1049
1050        server.await.unwrap();
1051    }
1052}