starknet_devnet_server/
subscribe.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::ws::{Message, WebSocket};
5use futures::SinkExt;
6use futures::stream::SplitSink;
7use serde::{self, Deserialize, Serialize};
8use starknet_core::starknet::events::check_if_filter_applies_for_event;
9use starknet_rs_core::types::Felt;
10use starknet_types::contract_address::ContractAddress;
11use starknet_types::emitted_event::SubscriptionEmittedEvent;
12use starknet_types::felt::TransactionHash;
13use starknet_types::rpc::block::{BlockHeader, ReorgData};
14use starknet_types::rpc::transaction_receipt::TransactionReceipt;
15use starknet_types::rpc::transactions::{
16    TransactionFinalityStatus, TransactionStatus, TransactionWithHash,
17};
18use tokio::sync::Mutex;
19
20use crate::api::error::ApiError;
21use crate::api::models::SubscriptionId;
22use crate::rpc_core::request::Id;
23
/// Identifier of a connected websocket within a [`SocketCollection`]; it is generated
/// randomly when the socket is registered.
pub type SocketId = u64;
25
26#[derive(Default)]
27pub struct SocketCollection {
28    sockets: HashMap<SocketId, SocketContext>,
29}
30
31impl SocketCollection {
32    pub fn get_mut(&mut self, socket_id: &SocketId) -> Result<&mut SocketContext, ApiError> {
33        self.sockets.get_mut(socket_id).ok_or(ApiError::StarknetDevnetError(
34            starknet_core::error::Error::UnexpectedInternalError {
35                msg: format!("Unregistered socket ID: {socket_id}"),
36            },
37        ))
38    }
39
40    /// Assigns a random socket ID to the socket whose `socket_writer` is provided. Returns the ID.
41    pub fn insert(&mut self, socket_writer: Arc<Mutex<SplitSink<WebSocket, Message>>>) -> SocketId {
42        let socket_id = rand::random();
43        self.sockets.insert(socket_id, SocketContext::from_sender(socket_writer));
44        socket_id
45    }
46
47    pub fn remove(&mut self, socket_id: &SocketId) {
48        self.sockets.remove(socket_id);
49    }
50
51    pub async fn notify_subscribers(&self, notifications: &[NotificationData]) {
52        for (_, socket_context) in self.sockets.iter() {
53            for notification in notifications {
54                socket_context.notify_subscribers(notification).await;
55            }
56        }
57    }
58
59    pub fn clear(&mut self) {
60        self.sockets
61            .iter_mut()
62            .for_each(|(_, socket_context)| socket_context.subscriptions.clear());
63        tracing::info!("Websocket memory cleared. No subscribers.");
64    }
65}
66
67#[derive(Debug)]
68pub struct AddressFilter {
69    address_container: Vec<ContractAddress>,
70}
71
72impl AddressFilter {
73    pub(crate) fn new(address_container: Vec<ContractAddress>) -> Self {
74        Self { address_container }
75    }
76    pub(crate) fn passes(&self, address: &ContractAddress) -> bool {
77        self.address_container.is_empty() || self.address_container.contains(address)
78    }
79}
80
81#[derive(Debug, Clone)]
82pub struct StatusFilter {
83    status_container: Vec<TransactionFinalityStatus>,
84}
85
86impl StatusFilter {
87    pub(crate) fn new(status_container: Vec<TransactionFinalityStatus>) -> Self {
88        Self { status_container }
89    }
90
91    pub(crate) fn passes(&self, status: &TransactionFinalityStatus) -> bool {
92        self.status_container.is_empty() || self.status_container.contains(status)
93    }
94}
95
/// A single subscription registered by a websocket client, together with the filters that
/// decide which notifications it receives (see `Subscription::matches`).
#[derive(Debug)]
pub enum Subscription {
    /// New block headers.
    NewHeads,
    /// Status updates of one specific transaction.
    TransactionStatus {
        transaction_hash: TransactionHash,
    },
    /// New transactions, filtered by sender address and finality status.
    NewTransactions {
        address_filter: AddressFilter,
        status_filter: StatusFilter,
    },
    /// New transaction receipts, filtered by sender address and finality status.
    NewTransactionReceipts {
        address_filter: AddressFilter,
        status_filter: StatusFilter,
    },
    /// Emitted events, filtered through `check_if_filter_applies_for_event` (on `address` and
    /// `keys_filter`) and by finality status.
    Events {
        /// NOTE(review): presumably `None` means "any emitting contract" — confirm against
        /// `check_if_filter_applies_for_event`.
        address: Option<ContractAddress>,
        keys_filter: Option<Vec<Vec<Felt>>>,
        status_filter: StatusFilter,
    },
}
116
117impl Subscription {
118    fn confirm(&self, id: SubscriptionId) -> SubscriptionConfirmation {
119        match self {
120            Subscription::NewHeads => SubscriptionConfirmation::NewSubscription(id),
121            Subscription::TransactionStatus { .. } => SubscriptionConfirmation::NewSubscription(id),
122            Subscription::NewTransactions { .. } => SubscriptionConfirmation::NewSubscription(id),
123            Subscription::NewTransactionReceipts { .. } => {
124                SubscriptionConfirmation::NewSubscription(id)
125            }
126            Subscription::Events { .. } => SubscriptionConfirmation::NewSubscription(id),
127        }
128    }
129
130    pub fn matches(&self, notification: &NotificationData) -> bool {
131        match (self, notification) {
132            (Subscription::NewHeads, NotificationData::NewHeads(_)) => true,
133            (
134                Subscription::TransactionStatus { transaction_hash: subscription_hash },
135                NotificationData::TransactionStatus(notification),
136            ) => subscription_hash == &notification.transaction_hash,
137            (
138                Subscription::NewTransactions { address_filter, status_filter },
139                NotificationData::NewTransaction(NewTransactionNotification {
140                    tx,
141                    finality_status,
142                }),
143            ) => match tx.get_sender_address() {
144                Some(address) => {
145                    address_filter.passes(&address) && status_filter.passes(finality_status)
146                }
147                None => true,
148            },
149            (
150                Subscription::NewTransactionReceipts { address_filter, status_filter },
151                NotificationData::NewTransactionReceipt(NewTransactionReceiptNotification {
152                    tx_receipt,
153                    sender_address,
154                }),
155            ) => {
156                status_filter.passes(tx_receipt.finality_status())
157                    && match sender_address {
158                        Some(address) => address_filter.passes(address),
159                        None => true,
160                    }
161            }
162            (
163                Subscription::Events { address, keys_filter, status_filter },
164                NotificationData::Event(event_with_finality_status),
165            ) => {
166                let event = (&event_with_finality_status.emitted_event).into();
167                check_if_filter_applies_for_event(address, keys_filter, &event)
168                    && status_filter.passes(&event_with_finality_status.finality_status)
169            }
170            (
171                Subscription::NewHeads
172                | Subscription::TransactionStatus { .. }
173                | Subscription::Events { .. }
174                | Subscription::NewTransactions { .. }
175                | Subscription::NewTransactionReceipts { .. },
176                NotificationData::Reorg(_),
177            ) => true, // All subscriptions require a reorg notification
178            _ => false,
179        }
180    }
181}
182
/// The `result` payload of the reply to a subscribe or unsubscribe request.
///
/// Serialized untagged: the client sees the inner value directly, without a variant name.
#[derive(Debug, Serialize)]
#[serde(untagged)]
#[cfg_attr(test, derive(Deserialize))]
pub(crate) enum SubscriptionConfirmation {
    /// Confirms a new subscription by returning its ID.
    NewSubscription(SubscriptionId),
    /// Outcome of an unsubscription request.
    Unsubscription(bool),
}
190
/// Payload of a transaction-status notification.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(test, derive(Deserialize))]
pub struct NewTransactionStatus {
    /// Hash of the transaction; compared against `Subscription::TransactionStatus`.
    pub transaction_hash: TransactionHash,
    /// The transaction's status being reported.
    pub status: TransactionStatus,
}
197
198#[derive(Debug, Clone)]
199pub struct TransactionHashWrapper {
200    pub hash: TransactionHash,
201    pub sender_address: Option<ContractAddress>,
202}
203
204impl Serialize for TransactionHashWrapper {
205    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: serde::Serializer,
208    {
209        self.hash.serialize(serializer)
210    }
211}
212
213impl<'de> Deserialize<'de> for TransactionHashWrapper {
214    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215    where
216        D: serde::Deserializer<'de>,
217    {
218        let hash = Felt::deserialize(deserializer)?;
219
220        Ok(TransactionHashWrapper { hash, sender_address: None })
221    }
222}
223
/// A finality status limited to the values that do not involve L1 acceptance.
///
/// (De)serialized in SCREAMING_SNAKE_CASE (`PRE_CONFIRMED`, `ACCEPTED_ON_L2`) and convertible
/// into [`TransactionFinalityStatus`]. NOTE(review): presumably used to parse client-supplied
/// status filters — confirm with callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TransactionFinalityStatusWithoutL1 {
    PreConfirmed,
    AcceptedOnL2,
}
230
/// A transaction status limited to the values that do not involve L1 acceptance.
///
/// (De)serialized in SCREAMING_SNAKE_CASE (`RECEIVED`, `CANDIDATE`, `PRE_CONFIRMED`,
/// `ACCEPTED_ON_L2`). Note that it converts into [`TransactionFinalityStatus`], not
/// [`TransactionStatus`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TransactionStatusWithoutL1 {
    Received,
    Candidate,
    PreConfirmed,
    AcceptedOnL2,
}
239
240impl From<TransactionFinalityStatusWithoutL1> for TransactionFinalityStatus {
241    fn from(status: TransactionFinalityStatusWithoutL1) -> Self {
242        match status {
243            TransactionFinalityStatusWithoutL1::PreConfirmed => Self::PreConfirmed,
244            TransactionFinalityStatusWithoutL1::AcceptedOnL2 => Self::AcceptedOnL2,
245        }
246    }
247}
248
249impl From<TransactionStatusWithoutL1> for TransactionFinalityStatus {
250    fn from(status: TransactionStatusWithoutL1) -> Self {
251        match status {
252            TransactionStatusWithoutL1::Received => Self::Received,
253            TransactionStatusWithoutL1::Candidate => Self::Candidate,
254            TransactionStatusWithoutL1::PreConfirmed => Self::PreConfirmed,
255            TransactionStatusWithoutL1::AcceptedOnL2 => Self::AcceptedOnL2,
256        }
257    }
258}
259
/// Payload of a new-transaction notification.
///
/// `tx` is flattened, so its fields appear at the top level of the JSON object next to
/// `finality_status`.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(test, derive(Deserialize))]
pub struct NewTransactionNotification {
    #[serde(flatten)]
    pub tx: TransactionWithHash,
    /// Checked against the subscription's status filter in `Subscription::matches`.
    pub finality_status: TransactionFinalityStatus,
}
267
/// A new transaction receipt together with the transaction's sender address, if any.
///
/// `sender_address` is only used for address filtering in `Subscription::matches`; just
/// `tx_receipt` is sent to the client (see `SocketContext::notify`).
#[derive(Debug, Clone)]
pub struct NewTransactionReceiptNotification {
    pub tx_receipt: TransactionReceipt,
    pub sender_address: Option<ContractAddress>,
}
273
/// A notification to be dispatched to subscribers.
///
/// Each notification is matched against subscriptions by `Subscription::matches` and turned
/// into a [`SubscriptionNotification`] by `SocketContext::notify`.
#[derive(Debug, Clone)]
pub enum NotificationData {
    NewHeads(BlockHeader),
    TransactionStatus(NewTransactionStatus),
    NewTransaction(NewTransactionNotification),
    NewTransactionReceipt(NewTransactionReceiptNotification),
    Event(SubscriptionEmittedEvent),
    /// Delivered to every subscription regardless of its kind.
    Reorg(ReorgData),
}
283
/// A subscription-related message sent to a websocket client.
///
/// Serialized untagged, so the JSON object contains only the variant's contents; the
/// `"jsonrpc": "2.0"` member is added by `to_serialized_rpc_response`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
#[cfg_attr(test, derive(Deserialize))]
pub(crate) enum SubscriptionResponse {
    /// Reply to a subscribe/unsubscribe request, echoing the request's ID as `id`.
    Confirmation {
        #[serde(rename = "id")]
        rpc_request_id: Id,
        result: SubscriptionConfirmation,
    },
    /// Server-initiated notification for one subscription.
    /// NOTE(review): presumably boxed to keep the enum small — confirm.
    Notification(Box<SubscriptionNotification>),
}
295
/// A subscription notification in JSON-RPC form.
///
/// Adjacently tagged: serializes as
/// `{"method": "<renamed variant>", "params": {"subscription_id": ..., "result": ...}}`.
#[derive(Serialize, Debug)]
#[cfg_attr(test, derive(Deserialize))]
#[serde(tag = "method", content = "params")]
pub(crate) enum SubscriptionNotification {
    #[serde(rename = "starknet_subscriptionNewHeads")]
    NewHeads { subscription_id: SubscriptionId, result: BlockHeader },
    #[serde(rename = "starknet_subscriptionTransactionStatus")]
    TransactionStatus { subscription_id: SubscriptionId, result: NewTransactionStatus },
    #[serde(rename = "starknet_subscriptionNewTransaction")]
    NewTransaction { subscription_id: SubscriptionId, result: NewTransactionNotification },
    #[serde(rename = "starknet_subscriptionNewTransactionReceipts")]
    NewTransactionReceipt { subscription_id: SubscriptionId, result: TransactionReceipt },
    #[serde(rename = "starknet_subscriptionEvents")]
    Event { subscription_id: SubscriptionId, result: SubscriptionEmittedEvent },
    #[serde(rename = "starknet_subscriptionReorg")]
    Reorg { subscription_id: SubscriptionId, result: ReorgData },
}
313
314impl SubscriptionResponse {
315    fn to_serialized_rpc_response(&self) -> serde_json::Value {
316        let mut resp = serde_json::json!(self);
317
318        resp["jsonrpc"] = "2.0".into();
319        resp
320    }
321}
322
/// Per-connection state: the socket's write half and the subscriptions registered on it.
pub struct SocketContext {
    /// The sender part of the socket's own channel: the websocket's write half, shared via
    /// `Arc` and guarded by an async (`tokio::sync`) `Mutex`.
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    /// Active subscriptions of this socket, keyed by the ID reported to the client.
    subscriptions: HashMap<SubscriptionId, Subscription>,
}
328
329impl SocketContext {
330    pub fn from_sender(sender: Arc<Mutex<SplitSink<WebSocket, Message>>>) -> Self {
331        Self { sender, subscriptions: HashMap::new() }
332    }
333
334    async fn send_serialized(&self, resp: String) {
335        if let Err(e) = self.sender.lock().await.send(Message::Text(resp.into())).await {
336            tracing::error!("Failed writing to socket: {}", e.to_string());
337        }
338    }
339
340    pub async fn send_rpc_response(&self, result: serde_json::Value, id: Id) {
341        let resp_serialized = serde_json::json!({
342            "jsonrpc": "2.0",
343            "id": id,
344            "result": result,
345        })
346        .to_string();
347
348        tracing::trace!(target: "ws.json-rpc-api", response = %resp_serialized, "JSON-RPC response");
349        self.send_serialized(resp_serialized).await;
350    }
351
352    async fn send_subscription_response(&self, resp: SubscriptionResponse) {
353        let resp_serialized = resp.to_serialized_rpc_response().to_string();
354
355        tracing::trace!(target: "ws.subscriptions", response = %resp_serialized, "subscription response");
356        self.send_serialized(resp_serialized).await;
357    }
358
359    pub async fn subscribe(
360        &mut self,
361        rpc_request_id: Id,
362        subscription: Subscription,
363    ) -> SubscriptionId {
364        loop {
365            let subscription_id: SubscriptionId = rand::random::<u64>().into();
366            if self.subscriptions.contains_key(&subscription_id) {
367                continue;
368            }
369
370            let confirmation = subscription.confirm(subscription_id);
371            self.subscriptions.insert(subscription_id, subscription);
372
373            self.send_subscription_response(SubscriptionResponse::Confirmation {
374                rpc_request_id,
375                result: confirmation,
376            })
377            .await;
378
379            return subscription_id;
380        }
381    }
382
383    pub async fn unsubscribe(
384        &mut self,
385        rpc_request_id: Id,
386        subscription_id: SubscriptionId,
387    ) -> Result<(), ApiError> {
388        self.subscriptions.remove(&subscription_id).ok_or(ApiError::InvalidSubscriptionId)?;
389        self.send_subscription_response(SubscriptionResponse::Confirmation {
390            rpc_request_id,
391            result: SubscriptionConfirmation::Unsubscription(true),
392        })
393        .await;
394        Ok(())
395    }
396
397    pub async fn notify(&self, subscription_id: SubscriptionId, data: NotificationData) {
398        let notification_data = match data {
399            NotificationData::NewHeads(block_header) => {
400                SubscriptionNotification::NewHeads { subscription_id, result: block_header }
401            }
402
403            NotificationData::TransactionStatus(new_transaction_status) => {
404                SubscriptionNotification::TransactionStatus {
405                    subscription_id,
406                    result: new_transaction_status,
407                }
408            }
409
410            NotificationData::NewTransaction(tx_notification) => {
411                SubscriptionNotification::NewTransaction {
412                    subscription_id,
413                    result: tx_notification,
414                }
415            }
416
417            NotificationData::NewTransactionReceipt(tx_receipt_notification) => {
418                SubscriptionNotification::NewTransactionReceipt {
419                    subscription_id,
420                    result: tx_receipt_notification.tx_receipt,
421                }
422            }
423
424            NotificationData::Event(emitted_event) => {
425                SubscriptionNotification::Event { subscription_id, result: emitted_event }
426            }
427
428            NotificationData::Reorg(reorg_data) => {
429                SubscriptionNotification::Reorg { subscription_id, result: reorg_data }
430            }
431        };
432
433        self.send_subscription_response(SubscriptionResponse::Notification(Box::new(
434            notification_data,
435        )))
436        .await;
437    }
438
439    pub async fn notify_subscribers(&self, notification: &NotificationData) {
440        for (subscription_id, subscription) in self.subscriptions.iter() {
441            if subscription.matches(notification) {
442                self.notify(*subscription_id, notification.clone()).await;
443            }
444        }
445    }
446}