1use crate::error::{IndexerApiError, ServerError, SubscriptionTerminated};
2use crate::types::{
3 EmptyPayload, Envelope, ErrorPayload, EventNotification, EventNotificationPayload, EventsResponse,
4 GetEventsPayload, Key, PalletMeta, RequestMessage, Span, StatusUpdate,
5 SubscribeEventsPayload, SubscriptionStatusPayload,
6 SubscriptionTerminatedPayload,
7};
8use futures::{SinkExt, StreamExt};
9use std::collections::HashMap;
10use std::sync::{
11 Arc,
12 atomic::{AtomicU64, Ordering},
13};
14use tokio::sync::{Mutex, mpsc, oneshot};
15use tokio_tungstenite::{connect_async, tungstenite::Message};
16
/// One-shot channel the reader task uses to hand a request's response (or failure)
/// back to the caller awaiting it in `IndexerClient::request`.
type PendingSender = oneshot::Sender<Result<Envelope, IndexerApiError>>;
/// Status-update subscribers keyed by subscription id; shared between every
/// `IndexerClient` clone and the background reader task.
type StatusSubscribers = Arc<Mutex<HashMap<u64, mpsc::Sender<Result<StatusUpdate, IndexerApiError>>>>>;
/// Event-notification subscribers keyed by subscription id; shared between every
/// `IndexerClient` clone and the background reader task.
type EventSubscribers = Arc<Mutex<HashMap<u64, EventSubscriber>>>;
20
/// Local registration for one event subscription.
#[derive(Clone)]
struct EventSubscriber {
    /// Only notifications whose key equals this one are forwarded to `sender`.
    key: Key,
    /// Channel feeding the matching `EventSubscription` handle.
    sender: mpsc::Sender<Result<EventNotification, IndexerApiError>>,
}
26
/// Handle to one WebSocket connection to the indexer.
///
/// Every field is an `Arc`, so cloning the client is cheap and all clones share
/// the same connection, in-flight request table, and subscriber tables.
#[derive(Clone)]
pub struct IndexerClient {
    /// Write half of the WebSocket; the lock serializes outgoing frames from all clones.
    writer: Arc<Mutex<futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
    /// In-flight requests keyed by request id; completed by the reader task.
    pending: Arc<Mutex<HashMap<u64, PendingSender>>>,
    /// Local status subscribers, fed by the reader task.
    status_subscribers: StatusSubscribers,
    /// Local event subscribers, fed by the reader task.
    event_subscribers: EventSubscribers,
    /// Counter shared by request ids and subscription ids, so the two never collide.
    next_id: Arc<AtomicU64>,
}
35
/// Receiving end of a status subscription created by `IndexerClient::subscribe_status`.
///
/// Dropping the handle only removes its local registration; it does not send
/// `UnsubscribeStatus` to the server. Call `unsubscribe` for that.
pub struct StatusSubscription {
    /// Client used to unregister / unsubscribe.
    client: IndexerClient,
    /// Key of this subscription in the client's status subscriber table.
    id: u64,
    /// Updates (or terminal errors) pushed by the reader task.
    receiver: mpsc::Receiver<Result<StatusUpdate, IndexerApiError>>,
}
41
/// Receiving end of an event subscription created by `IndexerClient::subscribe_events`.
///
/// Dropping the handle only removes its local registration; it does not send
/// `UnsubscribeEvents` to the server. Call `unsubscribe` for that.
pub struct EventSubscription {
    /// Client used to unregister / unsubscribe.
    client: IndexerClient,
    /// Key of this subscription in the client's event subscriber table.
    id: u64,
    /// Event key this subscription was opened for.
    key: Key,
    /// Notifications (or terminal errors) pushed by the reader task.
    receiver: mpsc::Receiver<Result<EventNotification, IndexerApiError>>,
}
48
49impl IndexerClient {
50 pub async fn connect(url: &str) -> Result<Self, IndexerApiError> {
51 let (stream, _) = connect_async(url).await?;
52 let (writer, reader) = stream.split();
53
54 let client = Self {
55 writer: Arc::new(Mutex::new(writer)),
56 pending: Arc::new(Mutex::new(HashMap::new())),
57 status_subscribers: Arc::new(Mutex::new(HashMap::new())),
58 event_subscribers: Arc::new(Mutex::new(HashMap::new())),
59 next_id: Arc::new(AtomicU64::new(1)),
60 };
61
62 tokio::spawn(run_reader(
63 reader,
64 Arc::clone(&client.pending),
65 Arc::clone(&client.status_subscribers),
66 Arc::clone(&client.event_subscribers),
67 ));
68
69 Ok(client)
70 }
71
72 pub async fn close(&self) -> Result<(), IndexerApiError> {
73 self.writer.lock().await.close().await?;
74 Ok(())
75 }
76
77 pub async fn status(&self) -> Result<Vec<Span>, IndexerApiError> {
78 let envelope = self.request("Status", EmptyPayload::default()).await?;
79 expect_payload::<Vec<Span>>(envelope, "status")
80 }
81
82 pub async fn variants(&self) -> Result<Vec<PalletMeta>, IndexerApiError> {
83 let envelope = self.request("Variants", EmptyPayload::default()).await?;
84 expect_payload::<Vec<PalletMeta>>(envelope, "variants")
85 }
86
87 pub async fn size_on_disk(&self) -> Result<u64, IndexerApiError> {
88 let envelope = self.request("SizeOnDisk", EmptyPayload::default()).await?;
89 expect_payload::<u64>(envelope, "sizeOnDisk")
90 }
91
92 pub async fn get_events(
93 &self,
94 key: Key,
95 limit: Option<u16>,
96 before: Option<crate::types::EventRef>,
97 ) -> Result<EventsResponse, IndexerApiError> {
98 let envelope = self
99 .request("GetEvents", GetEventsPayload { key, limit, before })
100 .await?;
101 expect_payload::<EventsResponse>(envelope, "events")
102 }
103
104 pub async fn subscribe_status(&self) -> Result<StatusSubscription, IndexerApiError> {
105 let (tx, rx) = mpsc::channel(32);
106 let subscription_id = self.next_id.fetch_add(1, Ordering::Relaxed);
107 self.status_subscribers.lock().await.insert(subscription_id, tx);
108
109 let envelope = self.request("SubscribeStatus", EmptyPayload::default()).await?;
110 let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
111
112 Ok(StatusSubscription {
113 client: self.clone(),
114 id: subscription_id,
115 receiver: rx,
116 })
117 }
118
119 pub async fn unsubscribe_status(&self) -> Result<(), IndexerApiError> {
120 let envelope = self.request("UnsubscribeStatus", EmptyPayload::default()).await?;
121 let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
122 self.status_subscribers.lock().await.clear();
123 Ok(())
124 }
125
126 pub async fn subscribe_events(&self, key: Key) -> Result<EventSubscription, IndexerApiError> {
127 let (tx, rx) = mpsc::channel(32);
128 let subscription_id = self.next_id.fetch_add(1, Ordering::Relaxed);
129 self.event_subscribers.lock().await.insert(
130 subscription_id,
131 EventSubscriber {
132 key: key.clone(),
133 sender: tx,
134 },
135 );
136
137 let envelope = self
138 .request("SubscribeEvents", SubscribeEventsPayload { key: key.clone() })
139 .await?;
140 let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
141
142 Ok(EventSubscription {
143 client: self.clone(),
144 id: subscription_id,
145 key,
146 receiver: rx,
147 })
148 }
149
150 pub async fn unsubscribe_events(&self, key: Key) -> Result<(), IndexerApiError> {
151 let envelope = self
152 .request("UnsubscribeEvents", SubscribeEventsPayload { key: key.clone() })
153 .await?;
154 let _ = expect_payload::<SubscriptionStatusPayload>(envelope, "subscriptionStatus")?;
155 self.event_subscribers
156 .lock()
157 .await
158 .retain(|_, subscriber| subscriber.key != key);
159 Ok(())
160 }
161
162 async fn unregister_status_subscription(&self, id: u64) {
163 self.status_subscribers.lock().await.remove(&id);
164 }
165
166 async fn unregister_event_subscription(&self, id: u64) {
167 self.event_subscribers.lock().await.remove(&id);
168 }
169
170 async fn request<T>(&self, message_type: &'static str, payload: T) -> Result<Envelope, IndexerApiError>
171 where
172 T: serde::Serialize,
173 {
174 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
175 let request = RequestMessage {
176 id,
177 message_type,
178 payload,
179 };
180 let json = serde_json::to_string(&request)?;
181 let (tx, rx) = oneshot::channel();
182 self.pending.lock().await.insert(id, tx);
183
184 let send_result = self.writer.lock().await.send(Message::Text(json.into())).await;
185 if let Err(error) = send_result {
186 self.pending.lock().await.remove(&id);
187 return Err(error.into());
188 }
189
190 match rx.await {
191 Ok(result) => result,
192 Err(_) => Err(IndexerApiError::ResponseChannelClosed { request_id: id }),
193 }
194 }
195}
196
197impl StatusSubscription {
198 pub async fn next(&mut self) -> Option<Result<StatusUpdate, IndexerApiError>> {
199 self.receiver.recv().await
200 }
201
202 pub async fn unsubscribe(self) -> Result<(), IndexerApiError> {
203 let client = self.client.clone();
204 let id = self.id;
205 let result = client.unsubscribe_status().await;
206 client.unregister_status_subscription(id).await;
207 result
208 }
209}
210
211impl EventSubscription {
212 pub async fn next(&mut self) -> Option<Result<EventNotification, IndexerApiError>> {
213 self.receiver.recv().await
214 }
215
216 pub async fn unsubscribe(self) -> Result<(), IndexerApiError> {
217 let client = self.client.clone();
218 let id = self.id;
219 let key = self.key.clone();
220 let result = client.unsubscribe_events(key).await;
221 client.unregister_event_subscription(id).await;
222 result
223 }
224}
225
226impl Drop for StatusSubscription {
227 fn drop(&mut self) {
228 let client = self.client.clone();
229 let id = self.id;
230 tokio::spawn(async move {
231 client.unregister_status_subscription(id).await;
232 });
233 }
234}
235
236impl Drop for EventSubscription {
237 fn drop(&mut self) {
238 let client = self.client.clone();
239 let id = self.id;
240 tokio::spawn(async move {
241 client.unregister_event_subscription(id).await;
242 });
243 }
244}
245
246async fn run_reader(
247 mut reader: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
248 pending: Arc<Mutex<HashMap<u64, PendingSender>>>,
249 status_subscribers: StatusSubscribers,
250 event_subscribers: EventSubscribers,
251) {
252 while let Some(message) = reader.next().await {
253 match handle_message(message, &pending, &status_subscribers, &event_subscribers).await {
254 Ok(()) => {}
255 Err(error) => {
256 fail_all_pending(&pending, &error).await;
257 broadcast_status_error(&status_subscribers, &error).await;
258 broadcast_event_error(&event_subscribers, &error).await;
259 return;
260 }
261 }
262 }
263
264 let error = IndexerApiError::ConnectionClosed;
265 fail_all_pending(&pending, &error).await;
266 broadcast_status_error(&status_subscribers, &error).await;
267 broadcast_event_error(&event_subscribers, &error).await;
268}
269
270async fn handle_message(
271 message: Result<Message, tokio_tungstenite::tungstenite::Error>,
272 pending: &Arc<Mutex<HashMap<u64, PendingSender>>>,
273 status_subscribers: &StatusSubscribers,
274 event_subscribers: &EventSubscribers,
275) -> Result<(), IndexerApiError> {
276 let payload = match message? {
277 Message::Text(text) => text.to_string(),
278 Message::Binary(bytes) => {
279 String::from_utf8(bytes.to_vec()).map_err(|_| IndexerApiError::NonUtf8Binary)?
280 }
281 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => return Ok(()),
282 Message::Close(_) => return Err(IndexerApiError::ConnectionClosed),
283 };
284
285 let envelope: Envelope = serde_json::from_str(&payload)?;
286
287 if let Some(id) = envelope.id {
288 if let Some(sender) = pending.lock().await.remove(&id) {
289 let result = if envelope.message_type == "error" {
290 let error = parse_server_error(&envelope)?;
291 Err(error.into())
292 } else {
293 Ok(envelope)
294 };
295 let _ = sender.send(result);
296 return Ok(());
297 }
298 }
299
300 match envelope.message_type.as_str() {
301 "status" => {
302 let spans = envelope_data::<Vec<Span>>(&envelope)?;
303 broadcast_status_update(status_subscribers, StatusUpdate { spans }).await;
304 }
305 "eventNotification" => {
306 let payload = envelope_data::<EventNotificationPayload>(&envelope)?;
307 broadcast_event_update(
308 event_subscribers,
309 EventNotification {
310 key: payload.key,
311 event: payload.event,
312 decoded_event: payload.decoded_event,
313 },
314 )
315 .await;
316 }
317 "subscriptionTerminated" => {
318 let termination = envelope_data::<SubscriptionTerminatedPayload>(&envelope)?;
319 let subscription_error = SubscriptionTerminated {
320 reason: termination.reason,
321 message: termination.message,
322 };
323 let status_error = IndexerApiError::StatusSubscriptionTerminated {
324 reason: subscription_error.reason.clone(),
325 message: subscription_error.message.clone(),
326 };
327 let event_error = IndexerApiError::EventSubscriptionTerminated {
328 reason: subscription_error.reason,
329 message: subscription_error.message,
330 };
331 broadcast_status_error(status_subscribers, &status_error).await;
332 broadcast_event_error(event_subscribers, &event_error).await;
333 }
334 "error" => {
335 let error = parse_server_error(&envelope)?;
336 let error = IndexerApiError::from(error);
337 broadcast_status_error(status_subscribers, &error).await;
338 broadcast_event_error(event_subscribers, &error).await;
339 }
340 _ => {}
341 }
342
343 Ok(())
344}
345
346fn expect_payload<T>(envelope: Envelope, expected_type: &'static str) -> Result<T, IndexerApiError>
347where
348 T: for<'de> serde::Deserialize<'de>,
349{
350 if envelope.message_type != expected_type {
351 return Err(IndexerApiError::UnexpectedResponseType {
352 request_id: envelope.id.unwrap_or_default(),
353 message_type: envelope.message_type,
354 });
355 }
356
357 envelope_data(&envelope)
358}
359
360fn envelope_data<T>(envelope: &Envelope) -> Result<T, IndexerApiError>
361where
362 T: for<'de> serde::Deserialize<'de>,
363{
364 serde_json::from_value(
365 envelope
366 .data
367 .clone()
368 .ok_or(IndexerApiError::Json(serde_json::Error::io(std::io::Error::new(
369 std::io::ErrorKind::InvalidData,
370 "missing data field",
371 ))))?,
372 )
373 .map_err(IndexerApiError::from)
374}
375
376fn parse_server_error(envelope: &Envelope) -> Result<ServerError, IndexerApiError> {
377 let payload = envelope_data::<ErrorPayload>(envelope)?;
378 Ok(ServerError {
379 code: payload.code,
380 message: payload.message,
381 })
382}
383
384async fn fail_all_pending(
385 pending: &Arc<Mutex<HashMap<u64, PendingSender>>>,
386 error: &IndexerApiError,
387) {
388 let mut pending = pending.lock().await;
389 for (request_id, sender) in pending.drain() {
390 let _ = sender.send(Err(match error {
391 IndexerApiError::ConnectionClosed => IndexerApiError::RequestCancelled { request_id },
392 _ => IndexerApiError::BackgroundTaskEnded,
393 }));
394 }
395}
396
397async fn broadcast_status_update(
398 subscribers: &StatusSubscribers,
399 update: StatusUpdate,
400) {
401 let mut subscribers = subscribers.lock().await;
402 let ids: Vec<u64> = subscribers.keys().copied().collect();
403 for id in ids {
404 let Some(subscriber) = subscribers.get(&id).cloned() else {
405 continue;
406 };
407 if subscriber.send(Ok(update.clone())).await.is_err() {
408 subscribers.remove(&id);
409 }
410 }
411}
412
413async fn broadcast_event_update(
414 subscribers: &EventSubscribers,
415 update: EventNotification,
416) {
417 let mut subscribers = subscribers.lock().await;
418 let ids: Vec<u64> = subscribers.keys().copied().collect();
419 for id in ids {
420 let Some(subscriber) = subscribers.get(&id).cloned() else {
421 continue;
422 };
423 if subscriber.key == update.key && subscriber.sender.send(Ok(update.clone())).await.is_err() {
424 subscribers.remove(&id);
425 }
426 }
427}
428
429async fn broadcast_status_error(
430 subscribers: &StatusSubscribers,
431 error: &IndexerApiError,
432) {
433 let mut subscribers = subscribers.lock().await;
434 let ids: Vec<u64> = subscribers.keys().copied().collect();
435 for id in ids {
436 let Some(subscriber) = subscribers.get(&id).cloned() else {
437 continue;
438 };
439 if subscriber.send(Err(clone_error(error))).await.is_err() {
440 subscribers.remove(&id);
441 }
442 }
443}
444
445async fn broadcast_event_error(
446 subscribers: &EventSubscribers,
447 error: &IndexerApiError,
448) {
449 let mut subscribers = subscribers.lock().await;
450 let ids: Vec<u64> = subscribers.keys().copied().collect();
451 for id in ids {
452 let Some(subscriber) = subscribers.get(&id).cloned() else {
453 continue;
454 };
455 if subscriber.sender.send(Err(clone_error(error))).await.is_err() {
456 subscribers.remove(&id);
457 }
458 }
459}
460
461 fn clone_error(error: &IndexerApiError) -> IndexerApiError {
462 match error {
463 IndexerApiError::Url(error) => IndexerApiError::Url(*error),
464 IndexerApiError::WebSocket(_) => IndexerApiError::BackgroundTaskEnded,
465 IndexerApiError::Json(error) => IndexerApiError::Json(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, error.to_string()))),
466 IndexerApiError::RequestCancelled { request_id } => IndexerApiError::RequestCancelled { request_id: *request_id },
467 IndexerApiError::ResponseChannelClosed { request_id } => IndexerApiError::ResponseChannelClosed { request_id: *request_id },
468 IndexerApiError::Server { code, message } => IndexerApiError::Server { code: code.clone(), message: message.clone() },
469 IndexerApiError::StatusSubscriptionTerminated { reason, message } => IndexerApiError::StatusSubscriptionTerminated { reason: reason.clone(), message: message.clone() },
470 IndexerApiError::EventSubscriptionTerminated { reason, message } => IndexerApiError::EventSubscriptionTerminated { reason: reason.clone(), message: message.clone() },
471 IndexerApiError::UnexpectedResponseType { request_id, message_type } => IndexerApiError::UnexpectedResponseType { request_id: *request_id, message_type: message_type.clone() },
472 IndexerApiError::NonUtf8Binary => IndexerApiError::NonUtf8Binary,
473 IndexerApiError::ConnectionClosed => IndexerApiError::ConnectionClosed,
474 IndexerApiError::BackgroundTaskEnded => IndexerApiError::BackgroundTaskEnded,
475 }
476 }
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        CustomKey, CustomScalarValue, CustomValue, DecodedEvent, Envelope, EventRef,
    };
    use serde_json::json;
    use tokio::net::TcpListener;
    use tokio::sync::mpsc;
    use tokio_tungstenite::accept_async;

    type PendingMap = Arc<Mutex<HashMap<u64, PendingSender>>>;

    /// Fresh, empty routing tables of the kind the reader task owns.
    fn empty_state() -> (PendingMap, StatusSubscribers, EventSubscribers) {
        (
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    /// Serializes `value` into the text-frame shape `handle_message` consumes.
    fn text_frame(
        value: serde_json::Value,
    ) -> Result<Message, tokio_tungstenite::tungstenite::Error> {
        Ok(Message::Text(serde_json::to_string(&value).unwrap().into()))
    }

    /// Builds an envelope carrying request id `id`.
    fn envelope(id: u64, message_type: &str, data: Option<serde_json::Value>) -> Envelope {
        Envelope {
            id: Some(id),
            message_type: message_type.into(),
            data,
        }
    }

    /// Event notification for `key` at a fixed position, with no decoded event.
    fn notification(key: Key) -> EventNotification {
        EventNotification {
            key,
            event: EventRef {
                block_number: 10,
                event_index: 1,
            },
            decoded_event: None,
        }
    }

    fn custom_u32_key(name: &str, value: u32) -> Key {
        Key::Custom(CustomKey {
            name: name.into(),
            value: CustomValue::U32(value),
        })
    }

    fn composite_key(name: &str, bytes: u8, value: u32) -> Key {
        Key::Custom(CustomKey {
            name: name.into(),
            value: CustomValue::Composite(vec![
                CustomScalarValue::Bytes32(crate::types::Bytes32([bytes; 32])),
                CustomScalarValue::U32(value),
            ]),
        })
    }

    #[test]
    fn parses_status_payload() {
        let spans = expect_payload::<Vec<Span>>(
            envelope(2, "status", Some(json!([{"start": 1, "end": 8}]))),
            "status",
        )
        .unwrap();
        assert_eq!(spans, vec![Span { start: 1, end: 8 }]);
    }

    #[test]
    fn expect_payload_rejects_unexpected_response_type() {
        let error =
            expect_payload::<Vec<Span>>(envelope(2, "variants", Some(json!([]))), "status")
                .unwrap_err();

        let IndexerApiError::UnexpectedResponseType {
            request_id,
            message_type,
        } = error
        else {
            panic!("unexpected error variant");
        };
        assert_eq!(request_id, 2);
        assert_eq!(message_type, "variants");
    }

    #[test]
    fn envelope_data_rejects_missing_data() {
        let error = envelope_data::<Vec<Span>>(&envelope(2, "status", None)).unwrap_err();
        assert!(error.to_string().contains("missing data field"));
    }

    #[test]
    fn parses_events_payload() {
        let data = json!({
            "key": {"type": "Custom", "value": {"name": "ref_index", "kind": "u32", "value": 42}},
            "events": [{"blockNumber": 50, "eventIndex": 3}],
            "decodedEvents": [{
                "blockNumber": 50,
                "eventIndex": 3,
                "event": {
                    "specVersion": 1234,
                    "palletName": "Referenda",
                    "eventName": "Submitted",
                    "palletIndex": 42,
                    "variantIndex": 0,
                    "eventIndex": 3,
                    "fields": {"index": 42}
                }
            }]
        });

        let response =
            expect_payload::<EventsResponse>(envelope(3, "events", Some(data)), "events").unwrap();
        assert_eq!(response.events.len(), 1);
        assert_eq!(response.decoded_events.len(), 1);
    }

    #[test]
    fn parses_server_error_payload() {
        let source = envelope(
            9,
            "error",
            Some(json!({"code": "invalid_request", "message": "missing field `id`"})),
        );

        let error = parse_server_error(&source).unwrap();
        assert_eq!(error.code, "invalid_request");
        assert_eq!(error.message, "missing field `id`");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_routes_response_to_matching_pending_request() {
        let (pending, status_subscribers, event_subscribers) = empty_state();
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert(7, tx);

        handle_message(
            text_frame(json!({
                "id": 7,
                "type": "status",
                "data": [{"start": 1, "end": 9}]
            })),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap();

        let response = rx.await.unwrap().unwrap();
        assert_eq!(response.id, Some(7));
        assert_eq!(response.message_type, "status");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_routes_server_error_to_matching_pending_request() {
        let (pending, status_subscribers, event_subscribers) = empty_state();
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert(9, tx);

        handle_message(
            text_frame(json!({
                "id": 9,
                "type": "error",
                "data": {"code": "invalid_request", "message": "missing field `id`"}
            })),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap();

        let IndexerApiError::Server { code, message } = rx.await.unwrap().unwrap_err() else {
            panic!("unexpected error variant");
        };
        assert_eq!(code, "invalid_request");
        assert_eq!(message, "missing field `id`");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_broadcasts_status_update_to_subscribers() {
        let (pending, status_subscribers, event_subscribers) = empty_state();
        let (tx, mut rx) = mpsc::channel(1);
        status_subscribers.lock().await.insert(1, tx);

        handle_message(
            text_frame(json!({
                "type": "status",
                "data": [{"start": 1, "end": 8}]
            })),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap();

        let update = rx.recv().await.unwrap().unwrap();
        assert_eq!(update, StatusUpdate { spans: vec![Span { start: 1, end: 8 }] });
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_broadcasts_event_notification_to_subscribers() {
        let (pending, status_subscribers, event_subscribers) = empty_state();
        let (tx, mut rx) = mpsc::channel(1);
        event_subscribers.lock().await.insert(
            1,
            EventSubscriber {
                key: custom_u32_key("ref_index", 42),
                sender: tx,
            },
        );

        handle_message(
            text_frame(json!({
                "type": "eventNotification",
                "data": {
                    "key": {"type": "Custom", "value": {"name": "ref_index", "kind": "u32", "value": 42}},
                    "event": {"blockNumber": 50, "eventIndex": 3},
                    "decodedEvent": null
                }
            })),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap();

        let update = rx.recv().await.unwrap().unwrap();
        assert_eq!(update.key, custom_u32_key("ref_index", 42));
        assert_eq!(update.event, EventRef { block_number: 50, event_index: 3 });
        assert!(update.decoded_event.is_none());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_broadcasts_subscription_termination_to_subscribers() {
        let (pending, status_subscribers, event_subscribers) = empty_state();
        let (status_tx, mut status_rx) = mpsc::channel(1);
        let (event_tx, mut event_rx) = mpsc::channel(1);
        status_subscribers.lock().await.insert(1, status_tx);
        event_subscribers.lock().await.insert(
            2,
            EventSubscriber {
                key: custom_u32_key("ref_index", 42),
                sender: event_tx,
            },
        );

        handle_message(
            text_frame(json!({
                "type": "subscriptionTerminated",
                "data": {
                    "reason": "backpressure",
                    "message": "subscriber disconnected due to backpressure"
                }
            })),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap();

        let IndexerApiError::StatusSubscriptionTerminated { reason, message } =
            status_rx.recv().await.unwrap().unwrap_err()
        else {
            panic!("unexpected status error variant");
        };
        assert_eq!(reason, "backpressure");
        assert_eq!(message, "subscriber disconnected due to backpressure");

        let IndexerApiError::EventSubscriptionTerminated { reason, message } =
            event_rx.recv().await.unwrap().unwrap_err()
        else {
            panic!("unexpected event error variant");
        };
        assert_eq!(reason, "backpressure");
        assert_eq!(message, "subscriber disconnected due to backpressure");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_message_rejects_invalid_binary_payload() {
        let (pending, status_subscribers, event_subscribers) = empty_state();

        let error = handle_message(
            Ok(Message::Binary(vec![0xFF, 0xFE].into())),
            &pending,
            &status_subscribers,
            &event_subscribers,
        )
        .await
        .unwrap_err();

        assert!(matches!(error, IndexerApiError::NonUtf8Binary));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn fail_all_pending_marks_connection_closed_requests_as_cancelled() {
        let (pending, _, _) = empty_state();
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert(12, tx);

        fail_all_pending(&pending, &IndexerApiError::ConnectionClosed).await;

        let IndexerApiError::RequestCancelled { request_id } = rx.await.unwrap().unwrap_err()
        else {
            panic!("unexpected error variant");
        };
        assert_eq!(request_id, 12);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn fail_all_pending_marks_other_failures_as_background_task_ended() {
        let (pending, _, _) = empty_state();
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert(13, tx);

        let cause = IndexerApiError::Server {
            code: "internal_error".into(),
            message: "boom".into(),
        };
        fail_all_pending(&pending, &cause).await;

        assert!(matches!(
            rx.await.unwrap().unwrap_err(),
            IndexerApiError::BackgroundTaskEnded
        ));
    }

    #[test]
    fn clone_error_preserves_server_payload() {
        let cloned = clone_error(&IndexerApiError::Server {
            code: "invalid_request".into(),
            message: "missing field `id`".into(),
        });

        let IndexerApiError::Server { code, message } = cloned else {
            panic!("unexpected error variant");
        };
        assert_eq!(code, "invalid_request");
        assert_eq!(message, "missing field `id`");
    }

    #[test]
    fn event_notification_payload_matches_server_shape() {
        let payload = serde_json::from_value::<EventNotificationPayload>(json!({
            "key": {"type": "Custom", "value": {"name": "item_id", "kind": "bytes32", "value": format!("0x{}", "11".repeat(32))}},
            "event": {"blockNumber": 50, "eventIndex": 3},
            "decodedEvent": {
                "blockNumber": 50,
                "eventIndex": 3,
                "event": {
                    "specVersion": 1234,
                    "palletName": "Content",
                    "eventName": "PublishRevision",
                    "palletIndex": 42,
                    "variantIndex": 1,
                    "eventIndex": 3,
                    "fields": {}
                }
            }
        }))
        .unwrap();

        let expected_key = Key::Custom(CustomKey {
            name: "item_id".into(),
            value: CustomValue::Bytes32(crate::types::Bytes32([0x11; 32])),
        });
        let expected_decoded = DecodedEvent {
            block_number: 50,
            event_index: 3,
            event: crate::types::StoredEvent {
                spec_version: 1234,
                pallet_name: "Content".into(),
                event_name: "PublishRevision".into(),
                pallet_index: 42,
                variant_index: 1,
                event_index: 3,
                fields: json!({}),
            },
        };

        assert_eq!(payload.event, EventRef { block_number: 50, event_index: 3 });
        assert_eq!(payload.key, expected_key);
        assert_eq!(payload.decoded_event, Some(expected_decoded));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn broadcast_event_update_only_notifies_matching_keys() {
        let (_, _, subscribers) = empty_state();
        let (match_tx, mut match_rx) = mpsc::channel(1);
        let (other_tx, mut other_rx) = mpsc::channel(1);
        {
            let mut table = subscribers.lock().await;
            table.insert(
                1,
                EventSubscriber {
                    key: custom_u32_key("ref_index", 42),
                    sender: match_tx,
                },
            );
            table.insert(
                2,
                EventSubscriber {
                    key: custom_u32_key("ref_index", 7),
                    sender: other_tx,
                },
            );
        }

        broadcast_event_update(&subscribers, notification(custom_u32_key("ref_index", 42))).await;

        assert!(match_rx.recv().await.is_some());
        assert!(other_rx.try_recv().is_err());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn broadcast_event_update_matches_composite_keys() {
        let (_, _, subscribers) = empty_state();
        let (match_tx, mut match_rx) = mpsc::channel(1);
        let (other_tx, mut other_rx) = mpsc::channel(1);
        {
            let mut table = subscribers.lock().await;
            table.insert(
                1,
                EventSubscriber {
                    key: composite_key("item_revision", 0x11, 7),
                    sender: match_tx,
                },
            );
            table.insert(
                2,
                EventSubscriber {
                    key: composite_key("item_revision", 0x11, 8),
                    sender: other_tx,
                },
            );
        }

        broadcast_event_update(&subscribers, notification(composite_key("item_revision", 0x11, 7)))
            .await;

        assert!(match_rx.recv().await.is_some());
        assert!(other_rx.try_recv().is_err());
    }

    #[test]
    fn clone_error_preserves_response_channel_closed_payload() {
        let cloned = clone_error(&IndexerApiError::ResponseChannelClosed { request_id: 44 });

        let IndexerApiError::ResponseChannelClosed { request_id } = cloned else {
            panic!("unexpected error variant");
        };
        assert_eq!(request_id, 44);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn close_sends_websocket_close_frame() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut websocket = accept_async(stream).await.unwrap();

            match websocket.next().await {
                Some(Ok(Message::Close(_))) => {}
                other => panic!("expected websocket close frame, got {other:?}"),
            }
        });

        let client = IndexerClient::connect(&format!("ws://{addr}")).await.unwrap();
        client.close().await.unwrap();

        server.await.unwrap();
    }
}