1#![allow(unused_imports)]
20use anyhow::Context;
21use async_trait::async_trait;
22use derive_builder::Builder;
23use rust_decimal::prelude::*;
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26use std::{collections::BTreeMap, sync::Arc};
27
28use crate::common::{
29 errors::WebsocketError,
30 models::{ParamBuildError, WebsocketApiResponse},
31 utils::remove_empty_value,
32 websocket::{WebsocketApi, WebsocketMessageSendOptions},
33};
34use crate::spot::websocket_api::models;
35
36#[async_trait]
37pub trait AuthApi: Send + Sync {
38 async fn session_logon(
39 &self,
40 params: SessionLogonParams,
41 ) -> anyhow::Result<Vec<WebsocketApiResponse<Box<models::SessionLogonResponseResult>>>>;
42 async fn session_logout(
43 &self,
44 params: SessionLogoutParams,
45 ) -> anyhow::Result<Vec<WebsocketApiResponse<Box<models::SessionLogoutResponseResult>>>>;
46 async fn session_status(
47 &self,
48 params: SessionStatusParams,
49 ) -> anyhow::Result<WebsocketApiResponse<Box<models::SessionStatusResponseResult>>>;
50}
51
52#[derive(Clone)]
53pub struct AuthApiClient {
54 websocket_api_base: Arc<WebsocketApi>,
55}
56
57impl AuthApiClient {
58 pub fn new(websocket_api_base: Arc<WebsocketApi>) -> Self {
59 Self { websocket_api_base }
60 }
61}
62
63#[derive(Clone, Debug, Builder, Default)]
68#[builder(pattern = "owned", build_fn(error = "ParamBuildError"))]
69pub struct SessionLogonParams {
70 #[builder(setter(into), default)]
74 pub id: Option<String>,
75 #[builder(setter(into), default)]
79 pub recv_window: Option<rust_decimal::Decimal>,
80}
81
82impl SessionLogonParams {
83 #[must_use]
86 pub fn builder() -> SessionLogonParamsBuilder {
87 SessionLogonParamsBuilder::default()
88 }
89}
90#[derive(Clone, Debug, Builder, Default)]
95#[builder(pattern = "owned", build_fn(error = "ParamBuildError"))]
96pub struct SessionLogoutParams {
97 #[builder(setter(into), default)]
101 pub id: Option<String>,
102}
103
104impl SessionLogoutParams {
105 #[must_use]
108 pub fn builder() -> SessionLogoutParamsBuilder {
109 SessionLogoutParamsBuilder::default()
110 }
111}
112#[derive(Clone, Debug, Builder, Default)]
117#[builder(pattern = "owned", build_fn(error = "ParamBuildError"))]
118pub struct SessionStatusParams {
119 #[builder(setter(into), default)]
123 pub id: Option<String>,
124}
125
126impl SessionStatusParams {
127 #[must_use]
130 pub fn builder() -> SessionStatusParamsBuilder {
131 SessionStatusParamsBuilder::default()
132 }
133}
134
135#[async_trait]
136impl AuthApi for AuthApiClient {
137 async fn session_logon(
138 &self,
139 params: SessionLogonParams,
140 ) -> anyhow::Result<Vec<WebsocketApiResponse<Box<models::SessionLogonResponseResult>>>> {
141 let SessionLogonParams { id, recv_window } = params;
142
143 let mut payload: BTreeMap<String, Value> = BTreeMap::new();
144 if let Some(value) = id {
145 payload.insert("id".to_string(), serde_json::json!(value));
146 }
147 if let Some(value) = recv_window {
148 payload.insert("recvWindow".to_string(), serde_json::json!(value));
149 }
150 let payload = remove_empty_value(payload);
151
152 let response = self
153 .websocket_api_base
154 .send_message::<Box<models::SessionLogonResponseResult>>(
155 "/session.logon".trim_start_matches('/'),
156 payload,
157 WebsocketMessageSendOptions::new().signed().session_logon(),
158 )
159 .await
160 .map_err(anyhow::Error::from)?
161 .into_iter()
162 .collect();
163
164 Ok(response)
165 }
166
167 async fn session_logout(
168 &self,
169 params: SessionLogoutParams,
170 ) -> anyhow::Result<Vec<WebsocketApiResponse<Box<models::SessionLogoutResponseResult>>>> {
171 let SessionLogoutParams { id } = params;
172
173 let mut payload: BTreeMap<String, Value> = BTreeMap::new();
174 if let Some(value) = id {
175 payload.insert("id".to_string(), serde_json::json!(value));
176 }
177 let payload = remove_empty_value(payload);
178
179 let response = self
180 .websocket_api_base
181 .send_message::<Box<models::SessionLogoutResponseResult>>(
182 "/session.logout".trim_start_matches('/'),
183 payload,
184 WebsocketMessageSendOptions::new().session_logout(),
185 )
186 .await
187 .map_err(anyhow::Error::from)?
188 .into_iter()
189 .collect();
190
191 Ok(response)
192 }
193
194 async fn session_status(
195 &self,
196 params: SessionStatusParams,
197 ) -> anyhow::Result<WebsocketApiResponse<Box<models::SessionStatusResponseResult>>> {
198 let SessionStatusParams { id } = params;
199
200 let mut payload: BTreeMap<String, Value> = BTreeMap::new();
201 if let Some(value) = id {
202 payload.insert("id".to_string(), serde_json::json!(value));
203 }
204 let payload = remove_empty_value(payload);
205
206 self.websocket_api_base
207 .send_message::<Box<models::SessionStatusResponseResult>>(
208 "/session.status".trim_start_matches('/'),
209 payload,
210 WebsocketMessageSendOptions::new(),
211 )
212 .await
213 .map_err(anyhow::Error::from)?
214 .into_iter()
215 .next()
216 .ok_or(WebsocketError::NoResponse)
217 .map_err(anyhow::Error::from)
218 }
219}
220
/// Unit tests for [`AuthApiClient`], driven through an in-memory connection:
/// outgoing frames are captured on a channel and responses are injected by
/// calling the handler's `on_message` directly.
#[cfg(all(test, feature = "spot"))]
mod tests {
    use super::*;
    use crate::TOKIO_SHARED_RT;
    use crate::common::websocket::{WebsocketApi, WebsocketConnection, WebsocketHandler};
    use crate::config::ConfigurationWebsocketApi;
    use crate::errors::WebsocketError;
    use crate::models::WebsocketApiRateLimit;
    use serde_json::{Value, json};
    use tokio::spawn;
    use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
    use tokio::task::JoinHandle;
    use tokio::time::{Duration, timeout};
    use tokio_tungstenite::tungstenite::Message;

    /// Upper bound for each wait on an outgoing frame or a finished task.
    const STEP_TIMEOUT: Duration = Duration::from_secs(1);

    /// Text the error returned for the canned server error must contain.
    const SERVER_ERROR_TEXT: &str = "Server‐side response error (code -2010): Account has insufficient balance for requested action.";

    /// Builds a connected `WebsocketApi` whose single connection writes into
    /// the returned receiver instead of a real socket.
    async fn setup() -> (
        Arc<WebsocketApi>,
        Arc<WebsocketConnection>,
        UnboundedReceiver<Message>,
    ) {
        let conn = WebsocketConnection::new("test-conn");
        let (tx, rx) = unbounded_channel::<Message>();
        {
            let mut conn_state = conn.state.lock().await;
            conn_state.ws_write_tx = Some(tx);
        }

        let config = ConfigurationWebsocketApi::builder()
            .api_key("key")
            .api_secret("secret")
            .build()
            .expect("Failed to build configuration");
        let ws_api = WebsocketApi::new(config, vec![conn.clone()]);
        conn.set_handler(ws_api.clone() as Arc<dyn WebsocketHandler>)
            .await;
        ws_api.clone().connect().await.unwrap();

        (ws_api, conn, rx)
    }

    /// Waits for the next outgoing frame and parses it as JSON.
    ///
    /// Panics when nothing is sent in time, the channel is closed, the frame
    /// is not text, or the text is not valid JSON.
    async fn next_request(rx: &mut UnboundedReceiver<Message>) -> Value {
        let frame = timeout(STEP_TIMEOUT, rx.recv())
            .await
            .expect("send should occur")
            .expect("channel closed");
        let Message::Text(text) = frame else {
            panic!("expected Message Text")
        };
        serde_json::from_str(&text).expect("request should be valid JSON")
    }

    /// Parses a canned success response and rewrites its `id` so that it
    /// matches the request that was actually sent.
    fn success_response(template: &str, id: &str) -> Value {
        let mut response: Value =
            serde_json::from_str(template).expect("fixture should be valid JSON");
        response["id"] = id.into();
        response
    }

    /// Extracts the payload (`result`, falling back to `response`) of a canned response.
    fn raw_result(response: &Value) -> Value {
        response
            .get("result")
            .or_else(|| response.get("response"))
            .expect("no response in JSON")
            .clone()
    }

    /// Rate limits the client is expected to report for `response`:
    /// `None` when the `rateLimits` field is absent, not an array, or empty.
    fn rate_limits_in(response: &Value) -> Option<Vec<WebsocketApiRateLimit>> {
        let raw = response.get("rateLimits")?;
        if raw.as_array()?.is_empty() {
            return None;
        }
        Some(serde_json::from_value(raw.clone()).expect("should parse rateLimits array"))
    }

    /// Canned status-400 server error addressed to request `id`.
    fn server_error_response(id: &str) -> Value {
        json!({
            "id": id,
            "status": 400,
            "error": {
                "code": -2010,
                "msg": "Account has insufficient balance for requested action.",
            },
            "rateLimits": [
                {
                    "rateLimitType": "ORDERS",
                    "interval": "SECOND",
                    "intervalNum": 10,
                    "limit": 50,
                    "count": 13
                },
            ],
        })
    }

    /// Asserts that the spawned request finished with the canned server error.
    async fn assert_server_error<T>(handle: JoinHandle<anyhow::Result<T>>) {
        let joined = timeout(STEP_TIMEOUT, handle)
            .await
            .expect("task should finish in time");
        match joined {
            Ok(Err(e)) => {
                let msg = e.to_string();
                assert!(
                    msg.contains(SERVER_ERROR_TEXT),
                    "Expected error msg to contain server error, got: {msg}"
                );
            }
            Ok(Ok(_)) => panic!("Expected error"),
            Err(_) => panic!("Task panicked"),
        }
    }

    /// Asserts that the spawned request failed with `WebsocketError::Timeout`.
    ///
    /// No outer timeout is applied: the test waits for the client to give up
    /// on its own.
    async fn assert_request_timeout<T>(handle: JoinHandle<anyhow::Result<T>>) {
        match handle.await.expect("task completed") {
            Err(e) => match e.downcast_ref::<WebsocketError>() {
                Some(inner) => assert!(matches!(inner, WebsocketError::Timeout)),
                None => panic!("Unexpected error type: {:?}", e),
            },
            Ok(_) => panic!("Expected timeout error"),
        }
    }

    #[test]
    fn session_logon_success() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogonParams::builder().build().unwrap();
                client.session_logon(params).await
            });

            let request = next_request(&mut rx).await;
            assert_eq!(request["method"], "session.logon");

            let resp_json = success_response(
                r#"{"id":"c174a2b1-3f51-4580-b200-8528bd237cb7","status":200,"result":{"apiKey":"vmPUZE6mv9SD5VNHk4HlWFsOr6aKE2zvsw0MuIgwCIPy6utIco14y7Ju91duEh8A","authorizedSince":1649729878532,"connectedSince":1649729873021,"returnRateLimits":false,"serverTime":1649729878630,"userDataStream":false}}"#,
                request["id"].as_str().unwrap(),
            );
            let expected_data: Box<models::SessionLogonResponseResult> =
                serde_json::from_value(raw_result(&resp_json)).expect("should parse raw response");
            let expected_rate_limits = rate_limits_in(&resp_json);

            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            let response = timeout(STEP_TIMEOUT, handle)
                .await
                .expect("task done")
                .expect("no panic")
                .expect("no error")
                .into_iter()
                .next()
                .expect("should have response");

            let response_rate_limits = response.rate_limits.clone();
            let response_data = response.data().expect("deserialize data");

            assert_eq!(response_rate_limits, expected_rate_limits);
            assert_eq!(response_data, expected_data);
        });
    }

    #[test]
    fn session_logon_error_response() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogonParams::builder().build().unwrap();
                client.session_logon(params).await
            });

            let request = next_request(&mut rx).await;
            let resp_json = server_error_response(request["id"].as_str().unwrap());
            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            assert_server_error(handle).await;
        });
    }

    #[test]
    fn session_logon_request_timeout() {
        TOKIO_SHARED_RT.block_on(async {
            // `_conn` is held so the connection outlives the request.
            let (ws_api, _conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogonParams::builder().build().unwrap();
                client.session_logon(params).await
            });

            // The request must go out; no response is ever injected.
            let _request = next_request(&mut rx).await;

            assert_request_timeout(handle).await;
        });
    }

    #[test]
    fn session_logout_success() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogoutParams::builder().build().unwrap();
                client.session_logout(params).await
            });

            let request = next_request(&mut rx).await;
            assert_eq!(request["method"], "session.logout");

            let resp_json = success_response(
                r#"{"id":"c174a2b1-3f51-4580-b200-8528bd237cb7","status":200,"result":{"apiKey":"CAvIjXy3F44yW6Pou5k8Dy1swsYDWJZLeoK2r8G4cFDnE9nosRppc2eKc1T8TRTQ","authorizedSince":1649729878532,"connectedSince":1649729873021,"returnRateLimits":false,"serverTime":1649730611671,"userDataStream":false}}"#,
                request["id"].as_str().unwrap(),
            );
            let expected_data: Box<models::SessionLogoutResponseResult> =
                serde_json::from_value(raw_result(&resp_json)).expect("should parse raw response");
            let expected_rate_limits = rate_limits_in(&resp_json);

            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            let response = timeout(STEP_TIMEOUT, handle)
                .await
                .expect("task done")
                .expect("no panic")
                .expect("no error")
                .into_iter()
                .next()
                .expect("should have response");

            let response_rate_limits = response.rate_limits.clone();
            let response_data = response.data().expect("deserialize data");

            assert_eq!(response_rate_limits, expected_rate_limits);
            assert_eq!(response_data, expected_data);
        });
    }

    #[test]
    fn session_logout_error_response() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogoutParams::builder().build().unwrap();
                client.session_logout(params).await
            });

            let request = next_request(&mut rx).await;
            let resp_json = server_error_response(request["id"].as_str().unwrap());
            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            assert_server_error(handle).await;
        });
    }

    #[test]
    fn session_logout_request_timeout() {
        TOKIO_SHARED_RT.block_on(async {
            // `_conn` is held so the connection outlives the request.
            let (ws_api, _conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionLogoutParams::builder().build().unwrap();
                client.session_logout(params).await
            });

            // The request must go out; no response is ever injected.
            let _request = next_request(&mut rx).await;

            assert_request_timeout(handle).await;
        });
    }

    #[test]
    fn session_status_success() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionStatusParams::builder().build().unwrap();
                client.session_status(params).await
            });

            let request = next_request(&mut rx).await;
            assert_eq!(request["method"], "session.status");

            let resp_json = success_response(
                r#"{"id":"b50c16cd-62c9-4e29-89e4-37f10111f5bf","status":200,"result":{"apiKey":"vmPUZE6mv9SD5VNHk4HlWFsOr6aKE2zvsw0MuIgwCIPy6utIco14y7Ju91duEh8A","authorizedSince":1649729878532,"connectedSince":1649729873021,"returnRateLimits":false,"serverTime":1649730611671,"userDataStream":true}}"#,
                request["id"].as_str().unwrap(),
            );
            let expected_data: Box<models::SessionStatusResponseResult> =
                serde_json::from_value(raw_result(&resp_json)).expect("should parse raw response");
            let expected_rate_limits = rate_limits_in(&resp_json);

            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            // `session_status` already yields a single response, not a list.
            let response = timeout(STEP_TIMEOUT, handle)
                .await
                .expect("task done")
                .expect("no panic")
                .expect("no error");

            let response_rate_limits = response.rate_limits.clone();
            let response_data = response.data().expect("deserialize data");

            assert_eq!(response_rate_limits, expected_rate_limits);
            assert_eq!(response_data, expected_data);
        });
    }

    #[test]
    fn session_status_error_response() {
        TOKIO_SHARED_RT.block_on(async {
            let (ws_api, conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionStatusParams::builder().build().unwrap();
                client.session_status(params).await
            });

            let request = next_request(&mut rx).await;
            let resp_json = server_error_response(request["id"].as_str().unwrap());
            WebsocketHandler::on_message(&*ws_api, resp_json.to_string(), conn.clone()).await;

            assert_server_error(handle).await;
        });
    }

    #[test]
    fn session_status_request_timeout() {
        TOKIO_SHARED_RT.block_on(async {
            // `_conn` is held so the connection outlives the request.
            let (ws_api, _conn, mut rx) = setup().await;
            let client = AuthApiClient::new(ws_api.clone());

            let handle = spawn(async move {
                let params = SessionStatusParams::builder().build().unwrap();
                client.session_status(params).await
            });

            // The request must go out; no response is ever injected.
            let _request = next_request(&mut rx).await;

            assert_request_timeout(handle).await;
        });
    }
}