1use crate::error::{DerivError, Result};
2use crate::utils::{validate_app_id, validate_language, validate_schema};
3use futures_util::{SinkExt, StreamExt};
4use log::debug;
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use serde::de::Error as SerdeError;
8use std::future::Future;
9use std::sync::atomic::{AtomicI64, Ordering};
10use std::sync::Arc;
11use tokio::sync::{mpsc, oneshot, RwLock};
12use tokio_tungstenite::connect_async;
13use tokio_tungstenite::tungstenite::protocol::Message;
14use url::Url;
15use crate::subscription::Subscription;
16
/// Capacity of each internal `mpsc` channel that `DerivClient::connect` creates
/// (outgoing requests, incoming frames and pending-request registrations).
const DEFAULT_BUFFER_SIZE: usize = 1 << 10;
18
/// Optional behaviour switches for a `DerivClient`.
///
/// Both flags default to `false`; the derived `Default` is equivalent to the
/// previous hand-written impl (Clippy `derivable_impls`).
#[derive(Debug, Clone, Default)]
pub struct ClientConfig {
    /// Keep-alive switch.
    /// NOTE(review): not read anywhere in the client module — confirm it is used elsewhere.
    pub keep_alive: bool,
    /// Debug switch.
    /// NOTE(review): not read anywhere in the client module — confirm it is used elsewhere.
    pub debug: bool,
}
34
35#[derive(Debug, Clone)]
37pub struct DerivClient {
38 endpoint: Url,
39 origin: Url,
40 app_id: i32,
41 language: String,
42 config: ClientConfig,
43 last_request_id: Arc<AtomicI64>,
44 request_sender: Option<mpsc::Sender<ApiRequest>>,
45 pending_request_registrar: Option<mpsc::Sender<PendingRequestInfo>>,
46 connection_status: Arc<RwLock<bool>>,
47}
48
/// A serialized request queued for the writer task.
#[derive(Debug)]
struct ApiRequest {
    /// JSON payload (already containing `req_id`), sent as a text frame.
    message: Vec<u8>,
    /// The `req_id` embedded in `message`, kept for logging.
    request_id: i32,
}
54
55#[derive(Debug)]
56struct PendingRequestInfo {
57 req_id: i32,
58 response_sender: oneshot::Sender<Result<Vec<u8>>>,
59}
60
61#[derive(Debug, Deserialize)]
62struct ApiResponseReqId {
63 req_id: Option<i32>,
64 error: Option<serde_json::Value>,
65}
66
67impl DerivClient {
68 pub fn new(
70 endpoint: &str,
71 app_id: i32,
72 language: &str,
73 origin: &str,
74 config: Option<ClientConfig>,
75 ) -> Result<Self> {
76 let endpoint_url = Url::parse(endpoint)?;
77 let origin_url = Url::parse(origin)?;
78
79 validate_schema(&endpoint_url)?;
80 validate_app_id(app_id)?;
81 validate_language(language)?;
82
83 let mut endpoint_url = endpoint_url;
85 endpoint_url
86 .query_pairs_mut()
87 .append_pair("app_id", &app_id.to_string())
88 .append_pair("l", language);
89
90 Ok(Self {
91 endpoint: endpoint_url,
92 origin: origin_url,
93 app_id,
94 language: language.to_string(),
95 config: config.unwrap_or_default(),
96 last_request_id: Arc::new(AtomicI64::new(0)),
97 request_sender: None,
98 pending_request_registrar: None,
99 connection_status: Arc::new(RwLock::new(false)),
100 })
101 }
102
103 pub async fn connect(&mut self) -> Result<()> {
105 if *self.connection_status.read().await {
106 return Ok(());
107 }
108
109 debug!("Connecting to {}", self.endpoint);
110
111 let ws_stream = connect_async(&self.endpoint).await?.0;
112 let (mut write, mut read) = ws_stream.split();
113
114 let (request_sender, mut request_receiver) = mpsc::channel(DEFAULT_BUFFER_SIZE);
115 let (incoming_msg_sender, mut incoming_msg_receiver) = mpsc::channel(DEFAULT_BUFFER_SIZE);
116 let (pending_request_sender, mut pending_request_receiver) = mpsc::channel::<PendingRequestInfo>(DEFAULT_BUFFER_SIZE);
117
118 self.request_sender = Some(request_sender);
119 self.pending_request_registrar = Some(pending_request_sender);
120 *self.connection_status.write().await = true;
121 let connection_status_write = self.connection_status.clone();
122 let _connection_status_read = self.connection_status.clone();
123
124 let _write_task = tokio::spawn(async move {
125 while let Some(request) = request_receiver.recv().await {
126 let msg_str = String::from_utf8_lossy(&request.message).into_owned();
127 debug!("Sending message: {}", msg_str);
128 if let Err(e) = write.send(Message::Text(msg_str)).await {
129 debug!("Failed to send message for req_id {}: {}", request.request_id, e);
130 break;
131 }
132 debug!("Message sent successfully for req_id: {}", request.request_id);
133 }
134 debug!("Sender task finished.");
135 });
136
137 let _read_task = tokio::spawn(async move {
138 while let Some(message_result) = read.next().await {
139 debug!("Received raw message: {:?}", message_result);
140 match message_result {
141 Ok(Message::Text(text)) => {
142 debug!("Received text message: {}", text);
143 if incoming_msg_sender.send(text.into_bytes()).await.is_err() {
144 debug!("Failed to forward response to handler, receiver dropped.");
145 break;
146 }
147 }
148 Ok(Message::Close(close_frame)) => {
149 debug!("Received Close frame: {:?}", close_frame);
150 break;
151 }
152 Ok(Message::Ping(ping_data)) => {
153 debug!("Received Ping: {:?}", ping_data);
154 }
155 Ok(_) => {
156 debug!("Received other message type");
157 }
158 Err(e) => {
159 debug!("WebSocket read error: {}", e);
160 break;
161 }
162 }
163 }
164 debug!("Receiver task finished.");
165 *connection_status_write.write().await = false;
166 });
167
168 let _response_handler = tokio::spawn(async move {
169 let mut pending_requests: std::collections::HashMap<i32, oneshot::Sender<Result<Vec<u8>>>> =
170 std::collections::HashMap::new();
171
172 loop {
173 tokio::select! {
174 Some(pending_info) = pending_request_receiver.recv() => {
175 debug!("Registering pending request: {}", pending_info.req_id);
176 pending_requests.insert(pending_info.req_id, pending_info.response_sender);
177 }
178 Some(response_bytes) = incoming_msg_receiver.recv() => {
179 match serde_json::from_slice::<ApiResponseReqId>(&response_bytes) {
180 Ok(api_response) => {
181 if let Some(req_id) = api_response.req_id {
182 if let Some(sender) = pending_requests.remove(&req_id) {
183 debug!("Routing response for req_id: {}", req_id);
184 if api_response.error.is_some() {
185 debug!("API Error found for req_id: {}", req_id);
186 match crate::error::parse_error(&response_bytes) {
187 Ok(_) => {
188 debug!("Warning: API error flag set, but parse_error succeeded for req_id: {}", req_id);
189 let _ = sender.send(Ok(response_bytes));
190 }
191 Err(e) => {
192 debug!("Sending parsed error for req_id: {}: {:?}", req_id, e);
193 let _ = sender.send(Err(e));
194 }
195 }
196 } else {
197 debug!("Success response for req_id: {}", req_id);
198 let _ = sender.send(Ok(response_bytes));
199 }
200 } else {
201 debug!("Received response for unknown req_id: {}", req_id);
202 }
203 } else {
204 debug!("Received message without req_id (likely subscription): {:?}", String::from_utf8_lossy(&response_bytes));
205 crate::subscription::handle_subscription_message(&response_bytes);
207 }
208 }
209 Err(e) => {
210 debug!("Failed to parse incoming message JSON for req_id: {}", e);
211 }
212 }
213 }
214 else => {
215 debug!("Response handler loop exiting.");
216 break;
217 }
218 }
219 }
220 debug!("Response handler task finished.");
221 for (_, sender) in pending_requests.drain() {
222 let _ = sender.send(Err(DerivError::ConnectionClosed));
223 }
224 });
225
226 Ok(())
227 }
228
229 pub async fn disconnect(&mut self) {
231 if !*self.connection_status.read().await {
232 return;
233 }
234
235 debug!("Disconnecting from {}", self.endpoint);
236
237 self.request_sender = None;
238 self.pending_request_registrar = None;
239 *self.connection_status.write().await = false;
240 }
241
242 pub async fn send_request<T, R>(&self, request: &T) -> Result<R>
244 where
245 T: Serialize + std::fmt::Debug,
246 R: DeserializeOwned,
247 {
248 if !*self.connection_status.read().await {
249 debug!("Attempted send_request while not connected.");
250 return Err(DerivError::ConnectionClosed);
251 }
252
253 let request_id = self.get_next_request_id();
254 let (response_sender, response_receiver) = oneshot::channel();
255
256 let mut request_value = serde_json::to_value(request)?;
257 if let Some(obj) = request_value.as_object_mut() {
258 obj.insert("req_id".to_string(), serde_json::json!(request_id));
259 } else {
260 return Err(DerivError::SerializationError(serde_json::Error::custom("Request is not a JSON object")));
261 }
262 let message = serde_json::to_vec(&request_value)?;
263
264 debug!("Serialized JSON being sent: {}", String::from_utf8_lossy(&message));
265
266 debug!("Preparing request req_id: {}, payload: {:?}", request_id, request);
267
268 if let Some(registrar) = &self.pending_request_registrar {
269 let pending_info = PendingRequestInfo {
270 req_id: request_id,
271 response_sender,
272 };
273 if registrar.send(pending_info).await.is_err() {
274 debug!("Failed to register pending request {}, response handler likely dead.", request_id);
275 return Err(DerivError::ConnectionClosed);
276 }
277 debug!("Pending request {} registered.", request_id);
278 } else {
279 debug!("Attempted send_request but registrar is None (not connected?).");
280 return Err(DerivError::ConnectionClosed);
281 }
282
283 let api_request = ApiRequest {
284 message,
285 request_id,
286 };
287
288 if let Some(sender) = &self.request_sender {
289 debug!("Sending request {} to writer task.", request_id);
290 if sender.send(api_request).await.is_err() {
291 debug!("Failed to send request {} to writer task (channel closed).", request_id);
292 return Err(DerivError::ConnectionClosed);
293 }
294 debug!("Request {} sent to writer task.", request_id);
295 } else {
296 debug!("Attempted send_request but sender is None (not connected?).");
297 return Err(DerivError::ConnectionClosed);
298 }
299
300 debug!("Waiting for response for req_id: {}", request_id);
301 match response_receiver.await {
302 Ok(Ok(response_bytes)) => {
303 debug!("Received successful response bytes for req_id: {}", request_id);
304 crate::error::parse_error(&response_bytes)?;
305 debug!("Deserializing successful response for req_id: {}", request_id);
306 Ok(serde_json::from_slice(&response_bytes)?)
307 }
308 Ok(Err(e)) => {
309 debug!("Received error from response handler for req_id: {}: {:?}", request_id, e);
310 Err(e)
311 }
312 Err(_) => {
313 debug!("Oneshot channel closed for req_id: {} (handler died?).", request_id);
314 Err(DerivError::ConnectionClosed)
315 }
316 }
317 }
318
319 fn get_next_request_id(&self) -> i32 {
320 self.last_request_id.fetch_add(1, Ordering::SeqCst) as i32
321 }
322
323 pub async fn create_subscription<T, R, S>(&self, request: &mut T, msg_type: &str) -> Result<(R, Subscription<S>)>
325 where
326 T: Serialize + std::fmt::Debug,
327 R: DeserializeOwned + Serialize,
328 S: DeserializeOwned + Send + 'static,
329 {
330 if !*self.connection_status.read().await {
331 debug!("Attempted create_subscription while not connected.");
332 return Err(DerivError::ConnectionClosed);
333 }
334
335 let initial_response: R = self.send_request(request).await?;
337
338 let response_value = serde_json::to_vec(&initial_response)?;
340 let subscription_id = crate::subscription::parse_subscription_response(&response_value)?;
341
342 let (_sender, receiver) = mpsc::channel::<S>(100);
344
345 let client_arc = Arc::new(self.clone());
347 let subscription = Subscription::new(receiver, subscription_id, client_arc, msg_type);
348
349 Ok((initial_response, subscription))
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    const ENDPOINT: &str = "wss://ws.binaryws.com/websockets/v3";
    const ORIGIN: &str = "https://binary.com";

    /// Builds a client against the shared test origin; `new` does no network I/O.
    fn build(endpoint: &str, app_id: i32, language: &str) -> Result<DerivClient> {
        DerivClient::new(endpoint, app_id, language, ORIGIN, None)
    }

    // `DerivClient::new` is synchronous, so plain `#[test]` is enough — no async runtime needed.
    #[test]
    fn test_client_creation() {
        assert!(build(ENDPOINT, 1234, "en").is_ok());
    }

    #[test]
    fn test_endpoint_gets_app_id_and_language_query() {
        let client = build(ENDPOINT, 1234, "en").expect("valid arguments");
        let pairs: Vec<(String, String)> = client
            .endpoint
            .query_pairs()
            .map(|(k, v)| (k.into_owned(), v.into_owned()))
            .collect();
        assert!(pairs.contains(&("app_id".to_string(), "1234".to_string())));
        assert!(pairs.contains(&("l".to_string(), "en".to_string())));
    }

    #[test]
    fn test_invalid_schema() {
        let client = build("http://ws.binaryws.com/websockets/v3", 1234, "en");
        assert!(matches!(client, Err(DerivError::InvalidSchema(_))));
    }

    #[test]
    fn test_invalid_app_id() {
        let client = build(ENDPOINT, 0, "en");
        assert!(matches!(client, Err(DerivError::InvalidAppId(_))));
    }

    #[test]
    fn test_invalid_language() {
        let client = build(ENDPOINT, 1234, "eng");
        assert!(matches!(client, Err(DerivError::InvalidLanguage(_))));
    }
}
405
406trait ReceiverExt<T> {
407 fn recv_next(&mut self) -> impl Future<Output = Option<T>>;
408}
409
410impl<T> ReceiverExt<T> for mpsc::Receiver<T> {
411 async fn recv_next(&mut self) -> Option<T> {
412 self.recv().await
413 }
414}
415