deriv_api/
subscription.rs1use crate::error::{DerivError, Result};
2use futures_util::Stream;
3use log::{debug, error, warn};
4use serde::de::DeserializeOwned;
5use serde_json::Value;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11
12lazy_static::lazy_static! {
14 static ref SUBSCRIPTION_REGISTRY: Arc<Mutex<SubscriptionRegistry>> =
15 Arc::new(Mutex::new(SubscriptionRegistry::new()));
16}
17
18struct SubscriptionRegistry {
20 subscriptions: HashMap<String, SubscriptionSender>,
21 msg_type_map: HashMap<String, String>, }
23
24impl SubscriptionRegistry {
25 fn new() -> Self {
26 Self {
27 subscriptions: HashMap::new(),
28 msg_type_map: HashMap::new(),
29 }
30 }
31
32 fn register<T>(&mut self, subscription_id: String, sender: mpsc::Sender<T>, msg_type: &str)
33 where
34 T: DeserializeOwned + Send + 'static
35 {
36 let sender_box = Box::new(move |data: &[u8]| {
37 match serde_json::from_slice::<T>(data) {
38 Ok(parsed) => {
39 if sender.try_send(parsed).is_err() {
41 debug!("Failed to send subscription update - receiver dropped");
42 return false; }
44 true
45 }
46 Err(e) => {
47 error!("Failed to parse subscription data: {}", e);
48 true }
50 }
51 }) as SubscriptionSender;
52
53 self.subscriptions.insert(subscription_id.clone(), sender_box);
54 self.msg_type_map.insert(msg_type.to_string(), subscription_id.clone());
55
56 debug!("Registered subscription: {} for msg_type: {}", subscription_id, msg_type);
57 }
58
59 fn dispatch(&mut self, data: &[u8]) -> bool {
60 if let Ok(json) = serde_json::from_slice::<Value>(data) {
62 if let Some(id) = json.get("id").and_then(|v| v.as_str()).or_else(|| {
64 json.get("subscription")
66 .and_then(|s| s.get("id"))
67 .and_then(|id| id.as_str())
68 }) {
69 if let Some(sender) = self.subscriptions.get_mut(id) {
70 debug!("Found subscription handler for ID: {}", id);
71 return sender(data);
72 }
73 }
74
75 if let Some(msg_type) = json.get("msg_type").and_then(|v| v.as_str()) {
77 if let Some(subscription_id) = self.msg_type_map.get(msg_type) {
78 if let Some(sender) = self.subscriptions.get_mut(subscription_id) {
79 debug!("Found subscription handler for msg_type: {}", msg_type);
80 return sender(data);
81 }
82 }
83 }
84 }
85
86 debug!("No handler found for subscription update");
87 true }
89
90 fn unregister(&mut self, subscription_id: &str) -> bool {
91 let msg_types_to_remove: Vec<String> = self.msg_type_map
93 .iter()
94 .filter_map(|(msg_type, id)| {
95 if id == subscription_id {
96 Some(msg_type.clone())
97 } else {
98 None
99 }
100 })
101 .collect();
102
103 for msg_type in msg_types_to_remove {
104 self.msg_type_map.remove(&msg_type);
105 }
106
107 self.subscriptions.remove(subscription_id).is_some()
109 }
110}
111
/// Type-erased delivery callback stored in the subscription registry.
///
/// It is handed the raw bytes of one incoming message and returns `false` to
/// ask for its registry entry to be dropped, `true` to keep it.
type SubscriptionSender = Box<dyn FnMut(&[u8]) -> bool + Send + 'static>;

115pub(crate) fn handle_subscription_message(data: &[u8]) {
117 let mut registry = SUBSCRIPTION_REGISTRY.lock().unwrap();
118 if !registry.dispatch(data) {
119 if let Ok(json) = serde_json::from_slice::<Value>(data) {
121 if let Some(id) = json.get("id")
122 .and_then(|v| v.as_str())
123 .or_else(|| json.get("subscription")
124 .and_then(|s| s.get("id"))
125 .and_then(|id| id.as_str()))
126 {
127 debug!("Removing dropped subscription: {}", id);
128 registry.unregister(id);
129 }
130 }
131 }
132}
133
134pub struct Subscription<T> {
135 receiver: mpsc::Receiver<T>,
136 subscription_id: String,
137 client: Arc<crate::client::DerivClient>,
138}
139
140impl<T> Subscription<T>
141where
142 T: DeserializeOwned + Send + 'static
143{
144 pub(crate) fn new(
145 receiver: mpsc::Receiver<T>,
146 subscription_id: String,
147 client: Arc<crate::client::DerivClient>,
148 msg_type: &str
149 ) -> Self {
150 let (tx, _) = mpsc::channel::<T>(100); SUBSCRIPTION_REGISTRY.lock().unwrap().register(
153 subscription_id.clone(),
154 tx, msg_type
156 );
157
158 Self {
159 receiver,
160 subscription_id,
161 client,
162 }
163 }
164
165 pub fn subscription_id(&self) -> &str {
166 &self.subscription_id
167 }
168
169 pub async fn forget(&mut self) -> Result<()> {
170 let forget_request = deriv_api_schema::ForgetRequest {
172 forget: self.subscription_id.clone(),
173 passthrough: None,
174 req_id: None,
175 };
176
177 let _forget_response = self.client.forget(forget_request).await?;
179
180 SUBSCRIPTION_REGISTRY.lock().unwrap().unregister(&self.subscription_id);
182
183 Ok(())
184 }
185}
186
187impl<T> Stream for Subscription<T> {
188 type Item = T;
189
190 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191 Pin::new(&mut self.receiver).poll_recv(cx)
192 }
193}
194
195#[derive(serde::Deserialize, Debug)]
196pub(crate) struct SubscriptionResponse {
197 pub subscription: Option<SubscriptionInfo>,
198}
199
200#[derive(serde::Deserialize, Debug)]
201pub(crate) struct SubscriptionInfo {
202 pub id: String,
203}
204
205pub(crate) fn parse_subscription_response(response: &[u8]) -> Result<String> {
206 let subscription_response: SubscriptionResponse = serde_json::from_slice(response)?;
207
208 subscription_response
209 .subscription
210 .map(|s| s.id)
211 .ok_or(DerivError::EmptySubscriptionId)
212}