deriv_api/
client.rs

1use crate::error::{DerivError, Result};
2use crate::utils::{validate_app_id, validate_language, validate_schema};
3use futures_util::{SinkExt, StreamExt};
4use log::debug;
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use serde::de::Error as SerdeError;
8use std::future::Future;
9use std::sync::atomic::{AtomicI64, Ordering};
10use std::sync::Arc;
11use tokio::sync::{mpsc, oneshot, RwLock};
12use tokio_tungstenite::connect_async;
13use tokio_tungstenite::tungstenite::protocol::Message;
14use url::Url;
15use crate::subscription::Subscription;
16
/// Capacity of each bounded `mpsc` channel created by `DerivClient::connect`.
const DEFAULT_BUFFER_SIZE: usize = 1 << 10;
18
/// Configuration options for the Deriv API client.
///
/// Both flags default to `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClientConfig {
    /// Keep-alive flag.
    ///
    /// NOTE(review): stored by `DerivClient::new` but not read anywhere in this file.
    pub keep_alive: bool,
    /// Debug flag.
    ///
    /// NOTE(review): not read anywhere in this file; logging goes through the `log` crate.
    pub debug: bool,
}
34
35/// The main client for interacting with the Deriv API
36#[derive(Debug, Clone)]
37pub struct DerivClient {
38    endpoint: Url,
39    origin: Url,
40    app_id: i32,
41    language: String,
42    config: ClientConfig,
43    last_request_id: Arc<AtomicI64>,
44    request_sender: Option<mpsc::Sender<ApiRequest>>,
45    pending_request_registrar: Option<mpsc::Sender<PendingRequestInfo>>,
46    connection_status: Arc<RwLock<bool>>,
47}
48
/// A serialized request waiting to be written to the socket by the writer task.
#[derive(Debug)]
struct ApiRequest {
    /// JSON-encoded request body, already tagged with its `req_id`.
    message: Vec<u8>,
    /// The `req_id` embedded in `message`; kept separately for logging.
    request_id: i32,
}
54
55#[derive(Debug)]
56struct PendingRequestInfo {
57    req_id: i32,
58    response_sender: oneshot::Sender<Result<Vec<u8>>>,
59}
60
61#[derive(Debug, Deserialize)]
62struct ApiResponseReqId {
63    req_id: Option<i32>,
64    error: Option<serde_json::Value>,
65}
66
67impl DerivClient {
68    /// Creates a new instance of DerivClient
69    pub fn new(
70        endpoint: &str,
71        app_id: i32,
72        language: &str,
73        origin: &str,
74        config: Option<ClientConfig>,
75    ) -> Result<Self> {
76        let endpoint_url = Url::parse(endpoint)?;
77        let origin_url = Url::parse(origin)?;
78
79        validate_schema(&endpoint_url)?;
80        validate_app_id(app_id)?;
81        validate_language(language)?;
82
83        // Build the endpoint URL with query parameters
84        let mut endpoint_url = endpoint_url;
85        endpoint_url
86            .query_pairs_mut()
87            .append_pair("app_id", &app_id.to_string())
88            .append_pair("l", language);
89
90        Ok(Self {
91            endpoint: endpoint_url,
92            origin: origin_url,
93            app_id,
94            language: language.to_string(),
95            config: config.unwrap_or_default(),
96            last_request_id: Arc::new(AtomicI64::new(0)),
97            request_sender: None,
98            pending_request_registrar: None,
99            connection_status: Arc::new(RwLock::new(false)),
100        })
101    }
102
103    /// Connects to the Deriv WebSocket API
104    pub async fn connect(&mut self) -> Result<()> {
105        if *self.connection_status.read().await {
106            return Ok(());
107        }
108
109        debug!("Connecting to {}", self.endpoint);
110
111        let ws_stream = connect_async(&self.endpoint).await?.0;
112        let (mut write, mut read) = ws_stream.split();
113
114        let (request_sender, mut request_receiver) = mpsc::channel(DEFAULT_BUFFER_SIZE);
115        let (incoming_msg_sender, mut incoming_msg_receiver) = mpsc::channel(DEFAULT_BUFFER_SIZE);
116        let (pending_request_sender, mut pending_request_receiver) = mpsc::channel::<PendingRequestInfo>(DEFAULT_BUFFER_SIZE);
117
118        self.request_sender = Some(request_sender);
119        self.pending_request_registrar = Some(pending_request_sender);
120        *self.connection_status.write().await = true;
121        let connection_status_write = self.connection_status.clone();
122        let _connection_status_read = self.connection_status.clone();
123
124        let _write_task = tokio::spawn(async move {
125            while let Some(request) = request_receiver.recv().await {
126                let msg_str = String::from_utf8_lossy(&request.message).into_owned();
127                debug!("Sending message: {}", msg_str);
128                if let Err(e) = write.send(Message::Text(msg_str)).await {
129                    debug!("Failed to send message for req_id {}: {}", request.request_id, e);
130                    break;
131                }
132                debug!("Message sent successfully for req_id: {}", request.request_id);
133            }
134            debug!("Sender task finished.");
135        });
136
137        let _read_task = tokio::spawn(async move {
138            while let Some(message_result) = read.next().await {
139                debug!("Received raw message: {:?}", message_result);
140                match message_result {
141                    Ok(Message::Text(text)) => {
142                        debug!("Received text message: {}", text);
143                        if incoming_msg_sender.send(text.into_bytes()).await.is_err() {
144                            debug!("Failed to forward response to handler, receiver dropped.");
145                            break;
146                        }
147                    }
148                    Ok(Message::Close(close_frame)) => {
149                        debug!("Received Close frame: {:?}", close_frame);
150                        break;
151                    }
152                    Ok(Message::Ping(ping_data)) => {
153                        debug!("Received Ping: {:?}", ping_data);
154                    }
155                    Ok(_) => {
156                        debug!("Received other message type");
157                    }
158                    Err(e) => {
159                        debug!("WebSocket read error: {}", e);
160                        break;
161                    }
162                }
163            }
164            debug!("Receiver task finished.");
165            *connection_status_write.write().await = false;
166        });
167
168        let _response_handler = tokio::spawn(async move {
169            let mut pending_requests: std::collections::HashMap<i32, oneshot::Sender<Result<Vec<u8>>>> =
170                std::collections::HashMap::new();
171
172            loop {
173                tokio::select! {
174                    Some(pending_info) = pending_request_receiver.recv() => {
175                        debug!("Registering pending request: {}", pending_info.req_id);
176                        pending_requests.insert(pending_info.req_id, pending_info.response_sender);
177                    }
178                    Some(response_bytes) = incoming_msg_receiver.recv() => {
179                        match serde_json::from_slice::<ApiResponseReqId>(&response_bytes) {
180                            Ok(api_response) => {
181                                if let Some(req_id) = api_response.req_id {
182                                    if let Some(sender) = pending_requests.remove(&req_id) {
183                                        debug!("Routing response for req_id: {}", req_id);
184                                        if api_response.error.is_some() {
185                                            debug!("API Error found for req_id: {}", req_id);
186                                            match crate::error::parse_error(&response_bytes) {
187                                                Ok(_) => {
188                                                    debug!("Warning: API error flag set, but parse_error succeeded for req_id: {}", req_id);
189                                                    let _ = sender.send(Ok(response_bytes));
190                                                }
191                                                Err(e) => {
192                                                    debug!("Sending parsed error for req_id: {}: {:?}", req_id, e);
193                                                    let _ = sender.send(Err(e));
194                                                }
195                                            }
196                                        } else {
197                                            debug!("Success response for req_id: {}", req_id);
198                                            let _ = sender.send(Ok(response_bytes));
199                                        }
200                                    } else {
201                                        debug!("Received response for unknown req_id: {}", req_id);
202                                    }
203                                } else {
204                                    debug!("Received message without req_id (likely subscription): {:?}", String::from_utf8_lossy(&response_bytes));
205                                    // Process subscription message
206                                    crate::subscription::handle_subscription_message(&response_bytes);
207                                }
208                            }
209                            Err(e) => {
210                                debug!("Failed to parse incoming message JSON for req_id: {}", e);
211                            }
212                        }
213                    }
214                    else => {
215                        debug!("Response handler loop exiting.");
216                        break;
217                    }
218                }
219            }
220            debug!("Response handler task finished.");
221            for (_, sender) in pending_requests.drain() {
222                let _ = sender.send(Err(DerivError::ConnectionClosed));
223            }
224        });
225
226        Ok(())
227    }
228
229    /// Disconnects from the Deriv WebSocket API
230    pub async fn disconnect(&mut self) {
231        if !*self.connection_status.read().await {
232            return;
233        }
234
235        debug!("Disconnecting from {}", self.endpoint);
236
237        self.request_sender = None;
238        self.pending_request_registrar = None;
239        *self.connection_status.write().await = false;
240    }
241
242    /// Sends a request to the Deriv API and receives a response
243    pub async fn send_request<T, R>(&self, request: &T) -> Result<R>
244    where
245        T: Serialize + std::fmt::Debug,
246        R: DeserializeOwned,
247    {
248        if !*self.connection_status.read().await {
249            debug!("Attempted send_request while not connected.");
250            return Err(DerivError::ConnectionClosed);
251        }
252
253        let request_id = self.get_next_request_id();
254        let (response_sender, response_receiver) = oneshot::channel();
255
256        let mut request_value = serde_json::to_value(request)?;
257        if let Some(obj) = request_value.as_object_mut() {
258            obj.insert("req_id".to_string(), serde_json::json!(request_id));
259        } else {
260            return Err(DerivError::SerializationError(serde_json::Error::custom("Request is not a JSON object")));
261        }
262        let message = serde_json::to_vec(&request_value)?;
263
264        debug!("Serialized JSON being sent: {}", String::from_utf8_lossy(&message));
265
266        debug!("Preparing request req_id: {}, payload: {:?}", request_id, request);
267
268        if let Some(registrar) = &self.pending_request_registrar {
269            let pending_info = PendingRequestInfo {
270                req_id: request_id,
271                response_sender,
272            };
273            if registrar.send(pending_info).await.is_err() {
274                debug!("Failed to register pending request {}, response handler likely dead.", request_id);
275                return Err(DerivError::ConnectionClosed);
276            }
277            debug!("Pending request {} registered.", request_id);
278        } else {
279            debug!("Attempted send_request but registrar is None (not connected?).");
280            return Err(DerivError::ConnectionClosed);
281        }
282
283        let api_request = ApiRequest {
284            message,
285            request_id,
286        };
287
288        if let Some(sender) = &self.request_sender {
289            debug!("Sending request {} to writer task.", request_id);
290            if sender.send(api_request).await.is_err() {
291                debug!("Failed to send request {} to writer task (channel closed).", request_id);
292                return Err(DerivError::ConnectionClosed);
293            }
294            debug!("Request {} sent to writer task.", request_id);
295        } else {
296            debug!("Attempted send_request but sender is None (not connected?).");
297            return Err(DerivError::ConnectionClosed);
298        }
299
300        debug!("Waiting for response for req_id: {}", request_id);
301        match response_receiver.await {
302            Ok(Ok(response_bytes)) => {
303                debug!("Received successful response bytes for req_id: {}", request_id);
304                crate::error::parse_error(&response_bytes)?;
305                debug!("Deserializing successful response for req_id: {}", request_id);
306                Ok(serde_json::from_slice(&response_bytes)?)
307            }
308            Ok(Err(e)) => {
309                debug!("Received error from response handler for req_id: {}: {:?}", request_id, e);
310                Err(e)
311            }
312            Err(_) => {
313                debug!("Oneshot channel closed for req_id: {} (handler died?).", request_id);
314                Err(DerivError::ConnectionClosed)
315            }
316        }
317    }
318
319    fn get_next_request_id(&self) -> i32 {
320        self.last_request_id.fetch_add(1, Ordering::SeqCst) as i32
321    }
322    
323    /// Creates a subscription from a subscription-enabled API
324    pub async fn create_subscription<T, R, S>(&self, request: &mut T, msg_type: &str) -> Result<(R, Subscription<S>)>
325    where
326        T: Serialize + std::fmt::Debug,
327        R: DeserializeOwned + Serialize,
328        S: DeserializeOwned + Send + 'static,
329    {
330        if !*self.connection_status.read().await {
331            debug!("Attempted create_subscription while not connected.");
332            return Err(DerivError::ConnectionClosed);
333        }
334        
335        // Send the initial request to start the subscription
336        let initial_response: R = self.send_request(request).await?;
337        
338        // Extract the subscription ID from the response
339        let response_value = serde_json::to_vec(&initial_response)?;
340        let subscription_id = crate::subscription::parse_subscription_response(&response_value)?;
341        
342        // Create a channel for the subscription updates
343        let (_sender, receiver) = mpsc::channel::<S>(100);
344        
345        // Create the subscription with an Arc to self
346        let client_arc = Arc::new(self.clone());
347        let subscription = Subscription::new(receiver, subscription_id, client_arc, msg_type);
348        
349        Ok((initial_response, subscription))
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    const ENDPOINT: &str = "wss://ws.binaryws.com/websockets/v3";
    const ORIGIN: &str = "https://binary.com";

    /// Builds a client with the standard test origin and default configuration.
    fn build(endpoint: &str, app_id: i32, language: &str) -> Result<DerivClient> {
        DerivClient::new(endpoint, app_id, language, ORIGIN, None)
    }

    // `DerivClient::new` is synchronous, so these tests need no async runtime.
    #[test]
    fn test_client_creation() {
        assert!(build(ENDPOINT, 1234, "en").is_ok());
    }

    #[test]
    fn test_invalid_schema() {
        let client = build("http://ws.binaryws.com/websockets/v3", 1234, "en");
        assert!(matches!(client, Err(DerivError::InvalidSchema(_))));
    }

    #[test]
    fn test_invalid_app_id() {
        let client = build(ENDPOINT, 0, "en");
        assert!(matches!(client, Err(DerivError::InvalidAppId(_))));
    }

    #[test]
    fn test_invalid_language() {
        let client = build(ENDPOINT, 1234, "eng");
        assert!(matches!(client, Err(DerivError::InvalidLanguage(_))));
    }

    #[test]
    fn test_invalid_origin() {
        let client = DerivClient::new(ENDPOINT, 1234, "en", "not a url", None);
        assert!(client.is_err());
    }

    #[test]
    fn test_endpoint_carries_app_id_and_language() {
        let client = build(ENDPOINT, 1234, "en").expect("valid client");
        assert_eq!(client.endpoint.query(), Some("app_id=1234&l=en"));
    }

    #[tokio::test]
    async fn test_send_request_requires_connection() {
        let client = build(ENDPOINT, 1234, "en").expect("valid client");
        let result: Result<serde_json::Value> =
            client.send_request(&serde_json::json!({ "ping": 1 })).await;
        assert!(matches!(result, Err(DerivError::ConnectionClosed)));
    }
}
405
406trait ReceiverExt<T> {
407    fn recv_next(&mut self) -> impl Future<Output = Option<T>>;
408}
409
410impl<T> ReceiverExt<T> for mpsc::Receiver<T> {
411    async fn recv_next(&mut self) -> Option<T> {
412        self.recv().await
413    }
414}
415