deriv_api/
subscription.rs

1use crate::error::{DerivError, Result};
2use futures_util::Stream;
3use log::{debug, error, warn};
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11
12// Global subscription registry to track active subscriptions
13lazy_static::lazy_static! {
14    static ref SUBSCRIPTION_REGISTRY: Arc<Mutex<SubscriptionRegistry>> =
15        Arc::new(Mutex::new(SubscriptionRegistry::new()));
16}
17
18// Registry to map subscription IDs to subscription channels
19struct SubscriptionRegistry {
20    subscriptions: HashMap<String, SubscriptionSender>,
21    msg_type_map: HashMap<String, String>, // Maps msg_type to subscription_id
22}
23
24impl SubscriptionRegistry {
25    fn new() -> Self {
26        Self {
27            subscriptions: HashMap::new(),
28            msg_type_map: HashMap::new(),
29        }
30    }
31
32    fn register<T>(&mut self, subscription_id: String, sender: mpsc::Sender<T>, msg_type: &str) 
33    where 
34        T: DeserializeOwned + Send + 'static 
35    {
36        let sender_box = Box::new(move |data: &[u8]| {
37            match serde_json::from_slice::<T>(data) {
38                Ok(parsed) => {
39                    // Try sending the parsed data
40                    if sender.try_send(parsed).is_err() {
41                        debug!("Failed to send subscription update - receiver dropped");
42                        return false; // Return false to indicate we should remove this subscription
43                    }
44                    true
45                }
46                Err(e) => {
47                    error!("Failed to parse subscription data: {}", e);
48                    true // Keep subscription even if parsing fails
49                }
50            }
51        }) as SubscriptionSender;
52
53        self.subscriptions.insert(subscription_id.clone(), sender_box);
54        self.msg_type_map.insert(msg_type.to_string(), subscription_id.clone());
55        
56        debug!("Registered subscription: {} for msg_type: {}", subscription_id, msg_type);
57    }
58
59    fn dispatch(&mut self, data: &[u8]) -> bool {
60        // Extract the subscription ID or msg_type from the data
61        if let Ok(json) = serde_json::from_slice::<Value>(data) {
62            // First try to find subscription ID directly
63            if let Some(id) = json.get("id").and_then(|v| v.as_str()).or_else(|| {
64                // If not, look for subscription.id
65                json.get("subscription")
66                    .and_then(|s| s.get("id"))
67                    .and_then(|id| id.as_str())
68            }) {
69                if let Some(sender) = self.subscriptions.get_mut(id) {
70                    debug!("Found subscription handler for ID: {}", id);
71                    return sender(data);
72                }
73            }
74
75            // If no ID found or no handler for that ID, try using msg_type
76            if let Some(msg_type) = json.get("msg_type").and_then(|v| v.as_str()) {
77                if let Some(subscription_id) = self.msg_type_map.get(msg_type) {
78                    if let Some(sender) = self.subscriptions.get_mut(subscription_id) {
79                        debug!("Found subscription handler for msg_type: {}", msg_type);
80                        return sender(data);
81                    }
82                }
83            }
84        }
85        
86        debug!("No handler found for subscription update");
87        true // Keep checking future messages
88    }
89
90    fn unregister(&mut self, subscription_id: &str) -> bool {
91        // Find and remove any msg_type mappings for this subscription
92        let msg_types_to_remove: Vec<String> = self.msg_type_map
93            .iter()
94            .filter_map(|(msg_type, id)| {
95                if id == subscription_id {
96                    Some(msg_type.clone())
97                } else {
98                    None
99                }
100            })
101            .collect();
102            
103        for msg_type in msg_types_to_remove {
104            self.msg_type_map.remove(&msg_type);
105        }
106        
107        // Remove the subscription itself
108        self.subscriptions.remove(subscription_id).is_some()
109    }
110}
111
/// Type-erased delivery callback stored in the registry.
///
/// It receives a raw JSON payload and returns `false` when the subscription
/// should be removed, `true` to keep it.
type SubscriptionSender = Box<dyn FnMut(&[u8]) -> bool + Send + 'static>;
114
115// Public function to handle incoming subscription messages
116pub(crate) fn handle_subscription_message(data: &[u8]) {
117    let mut registry = SUBSCRIPTION_REGISTRY.lock().unwrap();
118    if !registry.dispatch(data) {
119        // If dispatch returns false, the receiver was dropped, so remove the subscription
120        if let Ok(json) = serde_json::from_slice::<Value>(data) {
121            if let Some(id) = json.get("id")
122                .and_then(|v| v.as_str())
123                .or_else(|| json.get("subscription")
124                    .and_then(|s| s.get("id"))
125                    .and_then(|id| id.as_str())) 
126            {
127                debug!("Removing dropped subscription: {}", id);
128                registry.unregister(id);
129            }
130        }
131    }
132}
133
134pub struct Subscription<T> {
135    receiver: mpsc::Receiver<T>,
136    subscription_id: String,
137    client: Arc<crate::client::DerivClient>,
138}
139
140impl<T> Subscription<T> 
141where
142    T: DeserializeOwned + Send + 'static
143{
144    pub(crate) fn new(
145        receiver: mpsc::Receiver<T>, 
146        subscription_id: String,
147        client: Arc<crate::client::DerivClient>,
148        msg_type: &str
149    ) -> Self {
150        // Register the subscription in the registry
151        let (tx, _) = mpsc::channel::<T>(100); // Create a dummy sender just for type inference
152        SUBSCRIPTION_REGISTRY.lock().unwrap().register(
153            subscription_id.clone(),
154            tx, // This won't be used since we're providing the receiver directly
155            msg_type
156        );
157        
158        Self {
159            receiver,
160            subscription_id,
161            client,
162        }
163    }
164
165    pub fn subscription_id(&self) -> &str {
166        &self.subscription_id
167    }
168
169    pub async fn forget(&mut self) -> Result<()> {
170        // Send a forget request to the server
171        let forget_request = deriv_api_schema::ForgetRequest {
172            forget: self.subscription_id.clone(),
173            passthrough: None,
174            req_id: None,
175        };
176        
177        // Send the forget request
178        let _forget_response = self.client.forget(forget_request).await?;
179        
180        // Remove from local registry
181        SUBSCRIPTION_REGISTRY.lock().unwrap().unregister(&self.subscription_id);
182        
183        Ok(())
184    }
185}
186
187impl<T> Stream for Subscription<T> {
188    type Item = T;
189
190    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191        Pin::new(&mut self.receiver).poll_recv(cx)
192    }
193}
194
195#[derive(serde::Deserialize, Debug)]
196pub(crate) struct SubscriptionResponse {
197    pub subscription: Option<SubscriptionInfo>,
198}
199
200#[derive(serde::Deserialize, Debug)]
201pub(crate) struct SubscriptionInfo {
202    pub id: String,
203}
204
205pub(crate) fn parse_subscription_response(response: &[u8]) -> Result<String> {
206    let subscription_response: SubscriptionResponse = serde_json::from_slice(response)?;
207
208    subscription_response
209        .subscription
210        .map(|s| s.id)
211        .ok_or(DerivError::EmptySubscriptionId)
212}