1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 str::FromStr,
5 sync::Arc,
6 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
7};
8
9use futures_util::{
10 SinkExt, StreamExt, TryFutureExt,
11 stream::{SplitSink, SplitStream},
12};
13use leaky_bucket::RateLimiter;
14use longport_proto::control::{AuthRequest, AuthResponse, ReconnectRequest, ReconnectResponse};
15use num_enum::IntoPrimitive;
16use prost::Message as _;
17use tokio::{
18 net::TcpStream,
19 sync::{mpsc, oneshot},
20};
21use tokio_tungstenite::{
22 MaybeTlsStream, WebSocketStream,
23 tungstenite::{Message, client::IntoClientRequest, http::Uri},
24};
25use url::Url;
26
27use crate::{
28 WsClientError, WsClientResult, WsCloseReason, WsEvent, WsResponseErrorDetail, codec::Packet,
29};
30
/// Upper bound on the WebSocket handshake performed by `do_connect`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Default request timeout used by `WsClient::request_raw`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// The connection is considered dead when no ping has arrived for this long.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Timeout attached to the auth control request.
const AUTH_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout attached to the reconnect control request.
const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Command code of the auth control request.
const COMMAND_CODE_AUTH: u8 = 2;
/// Command code of the reconnect control request.
const COMMAND_CODE_RECONNECT: u8 = 3;
39
40#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
42#[repr(i32)]
43pub enum ProtocolVersion {
44 Version1 = 1,
46}
47
48#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
50#[repr(i32)]
51pub enum CodecType {
52 Protobuf = 1,
54}
55
56#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
58#[repr(i32)]
59pub enum Platform {
60 OpenAPI = 9,
62}
63
/// A message sent from a `WsClient` handle to the spawned connection task.
enum Command {
    /// Send a request packet and deliver the matching response through `reply_tx`.
    Request {
        /// Command code written into the request packet.
        command_code: u8,
        /// Timeout in milliseconds written into the request packet.
        timeout_millis: u16,
        /// Already-encoded request body.
        body: Vec<u8>,
        /// Receives the response body, or the error for this request.
        reply_tx: oneshot::Sender<WsClientResult<Vec<u8>>>,
    },
}
72
/// Token-bucket settings applied to a single command code.
#[derive(Clone, Copy, Debug)]
pub struct RateLimit {
    /// How often the bucket is refilled.
    pub interval: Duration,
    /// Initial number of tokens.
    pub initial: usize,
    /// Maximum number of tokens the bucket can hold.
    pub max: usize,
    /// Number of tokens added on each refill.
    pub refill: usize,
}
85
86impl From<RateLimit> for RateLimiter {
87 fn from(config: RateLimit) -> Self {
88 RateLimiter::builder()
89 .interval(config.interval)
90 .refill(config.refill)
91 .max(config.max)
92 .initial(0)
93 .build()
94 }
95}
96
/// Per-connection state owned by the connection task (`client_loop`).
struct Context<'a> {
    /// Last request id handed out by `get_request_id`; starts at 0.
    request_id: u32,
    /// Reply channels of requests that were sent but not yet answered, keyed by request id.
    inflight_requests: HashMap<u32, oneshot::Sender<WsClientResult<Vec<u8>>>>,
    /// Write half of the WebSocket connection.
    sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    /// Read half of the WebSocket connection.
    stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Commands coming from the `WsClient` handle; borrowed from `client_loop`.
    command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
    /// Destination for push messages and connection errors.
    event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
}
105
106impl<'a> Context<'a> {
107 fn new(
108 conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
109 command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
110 event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
111 ) -> Self {
112 let (sink, stream) = conn.split();
113 Context {
114 request_id: 0,
115 inflight_requests: Default::default(),
116 sink,
117 stream,
118 command_rx,
119 event_sender,
120 }
121 }
122
123 #[inline]
124 fn get_request_id(&mut self) -> u32 {
125 self.request_id += 1;
126 self.request_id
127 }
128
129 fn send_event(&mut self, event: WsEvent) {
130 let _ = self.event_sender.send(event);
131 }
132
133 async fn process_loop(&mut self) -> WsClientResult<()> {
134 let mut ping_time = Instant::now();
135 let mut checkout_timeout = tokio::time::interval(Duration::from_secs(1));
136
137 loop {
138 tokio::select! {
139 item = self.stream.next() => {
140 match item.transpose()? {
141 Some(msg) => {
142 if msg.is_ping() {
143 tracing::debug!("ping");
144 ping_time = Instant::now();
145 }
146 self.handle_message(msg).await?;
147 },
148 None => return Err(WsClientError::ConnectionClosed { reason: None }),
149 }
150 }
151 item = self.command_rx.recv() => {
152 match item {
153 Some(command) => self.handle_command(command).await?,
154 None => return Ok(()),
155 }
156 }
157 _ = checkout_timeout.tick() => {
158 if (Instant::now() - ping_time) > HEARTBEAT_TIMEOUT {
159 tracing::info!("heartbeat timeout");
160 return Err(WsClientError::ConnectionClosed { reason: None });
161 }
162 }
163 }
164 }
165 }
166
167 async fn handle_command(&mut self, command: Command) -> WsClientResult<()> {
168 match command {
169 Command::Request {
170 command_code,
171 timeout_millis: timeout,
172 body,
173 reply_tx,
174 } => {
175 let request_id = self.get_request_id();
176 let msg = Message::Binary(
177 Packet::Request {
178 command_code,
179 request_id,
180 timeout_millis: timeout,
181 body,
182 signature: None,
183 }
184 .encode()
185 .into(),
186 );
187 self.inflight_requests.insert(request_id, reply_tx);
188 self.sink.send(msg).await?;
189 Ok(())
190 }
191 }
192 }
193
194 async fn handle_message(&mut self, msg: Message) -> WsClientResult<()> {
195 match msg {
196 Message::Ping(data) => {
197 self.sink.send(Message::Pong(data)).await?;
198 }
199 Message::Binary(data) => match Packet::decode(&data)? {
200 Packet::Response {
201 request_id,
202 status,
203 body,
204 ..
205 } => {
206 if let Some(sender) = self.inflight_requests.remove(&request_id) {
207 if status == 0 {
208 let _ = sender.send(Ok(body));
209 } else {
210 let detail = longport_proto::Error::decode(&*body).ok().map(
211 |longport_proto::Error { code, msg }| WsResponseErrorDetail {
212 code,
213 msg,
214 },
215 );
216 let _ =
217 sender.send(Err(WsClientError::ResponseError { status, detail }));
218 }
219 }
220 }
221 Packet::Push {
222 command_code, body, ..
223 } => {
224 let _ = self.event_sender.send(WsEvent::Push { command_code, body });
225 }
226 _ => return Err(WsClientError::UnexpectedResponse),
227 },
228 Message::Close(Some(close_frame)) => {
229 return Err(WsClientError::ConnectionClosed {
230 reason: Some(WsCloseReason {
231 code: close_frame.code,
232 message: close_frame.reason.to_string(),
233 }),
234 });
235 }
236 _ => return Err(WsClientError::UnexpectedResponse),
237 }
238
239 Ok(())
240 }
241}
242
/// A session granted by the server.
#[derive(Debug)]
pub struct WsSession {
    /// Session identifier, as accepted by `WsClient::request_reconnect`.
    pub session_id: String,
    /// Local wall-clock time after which the session counts as expired.
    pub deadline: SystemTime,
}

impl WsSession {
    /// Returns `true` once the current time has passed `deadline`.
    #[inline]
    pub fn is_expired(&self) -> bool {
        SystemTime::now() > self.deadline
    }
}
259
/// Handle to a WebSocket connection driven by a spawned background task.
///
/// Requests are forwarded to that task over `command_tx`; `rate_limit` holds
/// optional per-command-code limiters consulted before each request.
pub struct WsClient {
    /// Sender half of the command channel consumed by the connection task.
    command_tx: mpsc::UnboundedSender<Command>,
    /// Per-command-code rate limiters; command codes without an entry are not limited.
    rate_limit: Arc<HashMap<u8, RateLimiter>>,
}
265
266impl WsClient {
267 pub async fn open(
269 request: impl IntoClientRequest,
270 version: ProtocolVersion,
271 codec: CodecType,
272 platform: Platform,
273 event_sender: mpsc::UnboundedSender<WsEvent>,
274 rate_limit: Vec<(u8, RateLimit)>,
275 ) -> WsClientResult<Self> {
276 let (command_tx, command_rx) = mpsc::unbounded_channel();
277 let conn = do_connect(request, version, codec, platform).await?;
278 tokio::spawn(client_loop(conn, command_rx, event_sender));
279 Ok(Self {
280 command_tx,
281 rate_limit: Arc::new(
282 rate_limit
283 .into_iter()
284 .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
285 .collect(),
286 ),
287 })
288 }
289
290 pub fn set_rate_limit(&mut self, rate_limit: Vec<(u8, RateLimit)>) {
292 self.rate_limit = Arc::new(
293 rate_limit
294 .into_iter()
295 .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
296 .collect(),
297 );
298 }
299
300 pub async fn request_auth(
305 &self,
306 otp: impl Into<String>,
307 metadata: HashMap<String, String>,
308 ) -> WsClientResult<WsSession> {
309 let resp: AuthResponse = self
310 .request(
311 COMMAND_CODE_AUTH,
312 Some(AUTH_TIMEOUT),
313 AuthRequest {
314 token: otp.into(),
315 metadata,
316 },
317 )
318 .await?;
319 let expires_mills = resp.expires.saturating_sub(
320 SystemTime::now()
321 .duration_since(UNIX_EPOCH)
322 .unwrap()
323 .as_millis() as i64,
324 ) as u64;
325 let deadline = SystemTime::now() + Duration::from_millis(expires_mills);
326 Ok(WsSession {
327 session_id: resp.session_id,
328 deadline,
329 })
330 }
331
332 pub async fn request_reconnect(
336 &self,
337 session_id: impl Into<String>,
338 metadata: HashMap<String, String>,
339 ) -> WsClientResult<WsSession> {
340 let resp: ReconnectResponse = self
341 .request(
342 COMMAND_CODE_RECONNECT,
343 Some(RECONNECT_TIMEOUT),
344 ReconnectRequest {
345 session_id: session_id.into(),
346 metadata,
347 },
348 )
349 .await?;
350 Ok(WsSession {
351 session_id: resp.session_id,
352 deadline: SystemTime::now() + Duration::from_millis(resp.expires as u64),
353 })
354 }
355
356 pub async fn request_raw(
358 &self,
359 command_code: u8,
360 timeout: Option<Duration>,
361 body: Vec<u8>,
362 ) -> WsClientResult<Vec<u8>> {
363 if let Some(rate_limit) = self.rate_limit.get(&command_code) {
364 rate_limit.acquire_one().await;
365 }
366
367 let (reply_tx, reply_rx) = oneshot::channel();
368 self.command_tx
369 .send(Command::Request {
370 command_code,
371 timeout_millis: timeout.unwrap_or(REQUEST_TIMEOUT).as_millis().min(60000) as u16,
372 body,
373 reply_tx,
374 })
375 .map_err(|_| WsClientError::ClientClosed)?;
376 let resp = tokio::time::timeout(
377 REQUEST_TIMEOUT,
378 reply_rx.map_err(|_| WsClientError::ClientClosed),
379 )
380 .map_err(|_| WsClientError::RequestTimeout)
381 .await???;
382 Ok(resp)
383 }
384
385 pub async fn request<T, R>(
387 &self,
388 command_code: u8,
389 timeout: Option<Duration>,
390 req: T,
391 ) -> WsClientResult<R>
392 where
393 T: prost::Message + Debug,
394 R: prost::Message + Default + Debug,
395 {
396 tracing::info!(message = ?req, "ws request");
397 let resp = self
398 .request_raw(command_code, timeout, req.encode_to_vec())
399 .await?;
400 let resp = R::decode(&*resp)?;
401 tracing::info!(message = ?resp, "ws response");
402 Ok(resp)
403 }
404}
405
406async fn do_connect(
407 request: impl IntoClientRequest,
408 version: ProtocolVersion,
409 codec: CodecType,
410 platform: Platform,
411) -> WsClientResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
412 let mut request = request.into_client_request()?;
413 let mut url_obj = Url::parse(&request.uri().to_string())?;
414 url_obj.query_pairs_mut().extend_pairs(&[
415 ("version", i32::from(version).to_string()),
416 ("codec", i32::from(codec).to_string()),
417 ("platform", i32::from(platform).to_string()),
418 ]);
419 *request.uri_mut() = Uri::from_str(url_obj.as_ref()).expect("valid url");
420
421 let conn = match tokio::time::timeout(
422 CONNECT_TIMEOUT,
423 tokio_tungstenite::connect_async(request).map_err(WsClientError::from),
424 )
425 .map_err(|_| WsClientError::ConnectTimeout)
426 .await
427 .and_then(std::convert::identity)
428 {
429 Ok((conn, _)) => conn,
430 Err(err) => return Err(err),
431 };
432
433 Ok(conn)
434}
435
436async fn client_loop(
437 conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
438 mut command_tx: mpsc::UnboundedReceiver<Command>,
439 mut event_sender: mpsc::UnboundedSender<WsEvent>,
440) {
441 let mut ctx = Context::new(conn, &mut command_tx, &mut event_sender);
442
443 let res = ctx.process_loop().await;
444 match res {
445 Ok(()) => return,
446 Err(err) => {
447 ctx.send_event(WsEvent::Error(err));
448 }
449 };
450
451 for sender in ctx.inflight_requests.into_values() {
452 let _ = sender.send(Err(WsClientError::Cancelled));
453 }
454}