1use std::{
2 collections::HashMap,
3 str::FromStr,
4 sync::Arc,
5 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
6};
7
8use futures_util::{
9 SinkExt, StreamExt, TryFutureExt,
10 stream::{SplitSink, SplitStream},
11};
12use leaky_bucket::RateLimiter;
13use longport_proto::control::{AuthRequest, AuthResponse, ReconnectRequest, ReconnectResponse};
14use num_enum::IntoPrimitive;
15use prost::Message as _;
16use tokio::{
17 net::TcpStream,
18 sync::{mpsc, oneshot},
19};
20use tokio_tungstenite::{
21 MaybeTlsStream, WebSocketStream,
22 tungstenite::{Message, client::IntoClientRequest, http::Uri},
23};
24use url::Url;
25
26use crate::{
27 WsClientError, WsClientResult, WsCloseReason, WsEvent, WsResponseErrorDetail, codec::Packet,
28};
29
/// Upper bound on how long the WebSocket handshake in `do_connect` may take.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default request timeout sent to the server, and the longest the client
/// waits locally for any reply.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// The connection is considered dead when no ping has been received for this long.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(120_000);
/// Timeout used for the auth control command.
const AUTH_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Timeout used for the reconnect control command.
const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Command code of the auth control command.
const COMMAND_CODE_AUTH: u8 = 2;
/// Command code of the reconnect control command.
const COMMAND_CODE_RECONNECT: u8 = 3;
38
39#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
41#[repr(i32)]
42pub enum ProtocolVersion {
43 Version1 = 1,
45}
46
47#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
49#[repr(i32)]
50pub enum CodecType {
51 Protobuf = 1,
53}
54
55#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
57#[repr(i32)]
58pub enum Platform {
59 OpenAPI = 9,
61}
62
/// Work item sent from [`WsClient`] to the background connection task.
enum Command {
    /// Send a request packet and deliver its outcome through `reply_tx`.
    Request {
        /// Command code written into the request packet.
        command_code: u8,
        /// Timeout (milliseconds) written into the request packet.
        timeout_millis: u16,
        /// Already-encoded request body.
        body: Vec<u8>,
        /// Receives the response body, or the error, for this request.
        reply_tx: oneshot::Sender<WsClientResult<Vec<u8>>>,
    },
}
71
/// Token-bucket rate-limit configuration applied to a single command code.
#[derive(Clone, Copy, Debug)]
pub struct RateLimit {
    /// How often the bucket is refilled.
    pub interval: Duration,
    /// Number of tokens the bucket is meant to start with.
    pub initial: usize,
    /// Maximum number of tokens the bucket can hold.
    pub max: usize,
    /// Number of tokens added every `interval`.
    pub refill: usize,
}
84
85impl From<RateLimit> for RateLimiter {
86 fn from(config: RateLimit) -> Self {
87 RateLimiter::builder()
88 .interval(config.interval)
89 .refill(config.refill)
90 .max(config.max)
91 .initial(0)
92 .build()
93 }
94}
95
/// State owned by the connection task spawned from `client_loop`.
struct Context<'a> {
    /// Last request id handed out by `get_request_id`; `0` means none yet.
    request_id: u32,
    /// Reply channels of requests that were sent but not yet answered, keyed by request id.
    inflight_requests: HashMap<u32, oneshot::Sender<WsClientResult<Vec<u8>>>>,
    /// Write half of the WebSocket connection.
    sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    /// Read half of the WebSocket connection.
    stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Commands submitted by the owning [`WsClient`].
    command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
    /// Destination for server pushes and the error that ends the connection.
    event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
}
104
105impl<'a> Context<'a> {
106 fn new(
107 conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
108 command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
109 event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
110 ) -> Self {
111 let (sink, stream) = conn.split();
112 Context {
113 request_id: 0,
114 inflight_requests: Default::default(),
115 sink,
116 stream,
117 command_rx,
118 event_sender,
119 }
120 }
121
122 #[inline]
123 fn get_request_id(&mut self) -> u32 {
124 self.request_id += 1;
125 self.request_id
126 }
127
128 fn send_event(&mut self, event: WsEvent) {
129 let _ = self.event_sender.send(event);
130 }
131
132 async fn process_loop(&mut self) -> WsClientResult<()> {
133 let mut ping_time = Instant::now();
134 let mut checkout_timeout = tokio::time::interval(Duration::from_secs(1));
135
136 loop {
137 tokio::select! {
138 item = self.stream.next() => {
139 match item.transpose()? {
140 Some(msg) => {
141 if msg.is_ping() {
142 tracing::debug!("ping");
143 ping_time = Instant::now();
144 }
145 self.handle_message(msg).await?;
146 },
147 None => return Err(WsClientError::ConnectionClosed { reason: None }),
148 }
149 }
150 item = self.command_rx.recv() => {
151 match item {
152 Some(command) => self.handle_command(command).await?,
153 None => return Ok(()),
154 }
155 }
156 _ = checkout_timeout.tick() => {
157 if (Instant::now() - ping_time) > HEARTBEAT_TIMEOUT {
158 tracing::info!("heartbeat timeout");
159 return Err(WsClientError::ConnectionClosed { reason: None });
160 }
161 }
162 }
163 }
164 }
165
166 async fn handle_command(&mut self, command: Command) -> WsClientResult<()> {
167 match command {
168 Command::Request {
169 command_code,
170 timeout_millis: timeout,
171 body,
172 reply_tx,
173 } => {
174 let request_id = self.get_request_id();
175 let msg = Message::Binary(
176 Packet::Request {
177 command_code,
178 request_id,
179 timeout_millis: timeout,
180 body,
181 signature: None,
182 }
183 .encode()
184 .into(),
185 );
186 self.inflight_requests.insert(request_id, reply_tx);
187 self.sink.send(msg).await?;
188 Ok(())
189 }
190 }
191 }
192
193 async fn handle_message(&mut self, msg: Message) -> WsClientResult<()> {
194 match msg {
195 Message::Ping(data) => {
196 self.sink.send(Message::Pong(data)).await?;
197 }
198 Message::Binary(data) => match Packet::decode(&data)? {
199 Packet::Response {
200 request_id,
201 status,
202 body,
203 ..
204 } => {
205 if let Some(sender) = self.inflight_requests.remove(&request_id) {
206 if status == 0 {
207 let _ = sender.send(Ok(body));
208 } else {
209 let detail = longport_proto::Error::decode(&*body).ok().map(
210 |longport_proto::Error { code, msg }| WsResponseErrorDetail {
211 code,
212 msg,
213 },
214 );
215 let _ =
216 sender.send(Err(WsClientError::ResponseError { status, detail }));
217 }
218 }
219 }
220 Packet::Push {
221 command_code, body, ..
222 } => {
223 let _ = self.event_sender.send(WsEvent::Push { command_code, body });
224 }
225 _ => return Err(WsClientError::UnexpectedResponse),
226 },
227 Message::Close(Some(close_frame)) => {
228 return Err(WsClientError::ConnectionClosed {
229 reason: Some(WsCloseReason {
230 code: close_frame.code,
231 message: close_frame.reason.to_string(),
232 }),
233 });
234 }
235 _ => return Err(WsClientError::UnexpectedResponse),
236 }
237
238 Ok(())
239 }
240}
241
/// Session handed back by the auth and reconnect handshakes.
#[derive(Debug)]
pub struct WsSession {
    /// Session identifier issued by the server.
    pub session_id: String,
    /// Wall-clock time after which the session counts as expired.
    pub deadline: SystemTime,
}

impl WsSession {
    /// Returns `true` once the current system time has passed `deadline`.
    #[inline]
    pub fn is_expired(&self) -> bool {
        let now = SystemTime::now();
        now > self.deadline
    }
}
258
/// Handle to a WebSocket connection whose I/O runs on a spawned background task.
pub struct WsClient {
    /// Queue of commands consumed by the background connection task.
    command_tx: mpsc::UnboundedSender<Command>,
    /// Per-command-code limiter; one token is acquired before each matching request.
    rate_limit: Arc<HashMap<u8, RateLimiter>>,
}
264
265impl WsClient {
266 pub async fn open(
268 request: impl IntoClientRequest,
269 version: ProtocolVersion,
270 codec: CodecType,
271 platform: Platform,
272 event_sender: mpsc::UnboundedSender<WsEvent>,
273 rate_limit: Vec<(u8, RateLimit)>,
274 ) -> WsClientResult<Self> {
275 let (command_tx, command_rx) = mpsc::unbounded_channel();
276 let conn = do_connect(request, version, codec, platform).await?;
277 tokio::spawn(client_loop(conn, command_rx, event_sender));
278 Ok(Self {
279 command_tx,
280 rate_limit: Arc::new(
281 rate_limit
282 .into_iter()
283 .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
284 .collect(),
285 ),
286 })
287 }
288
289 pub fn set_rate_limit(&mut self, rate_limit: Vec<(u8, RateLimit)>) {
291 self.rate_limit = Arc::new(
292 rate_limit
293 .into_iter()
294 .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
295 .collect(),
296 );
297 }
298
299 pub async fn request_auth(
304 &self,
305 otp: impl Into<String>,
306 metadata: HashMap<String, String>,
307 ) -> WsClientResult<WsSession> {
308 let resp: AuthResponse = self
309 .request(
310 COMMAND_CODE_AUTH,
311 Some(AUTH_TIMEOUT),
312 AuthRequest {
313 token: otp.into(),
314 metadata,
315 },
316 )
317 .await?;
318 let expires_mills = resp.expires.saturating_sub(
319 SystemTime::now()
320 .duration_since(UNIX_EPOCH)
321 .unwrap()
322 .as_millis() as i64,
323 ) as u64;
324 let deadline = SystemTime::now() + Duration::from_millis(expires_mills);
325 Ok(WsSession {
326 session_id: resp.session_id,
327 deadline,
328 })
329 }
330
331 pub async fn request_reconnect(
335 &self,
336 session_id: impl Into<String>,
337 metadata: HashMap<String, String>,
338 ) -> WsClientResult<WsSession> {
339 let resp: ReconnectResponse = self
340 .request(
341 COMMAND_CODE_RECONNECT,
342 Some(RECONNECT_TIMEOUT),
343 ReconnectRequest {
344 session_id: session_id.into(),
345 metadata,
346 },
347 )
348 .await?;
349 Ok(WsSession {
350 session_id: resp.session_id,
351 deadline: SystemTime::now() + Duration::from_millis(resp.expires as u64),
352 })
353 }
354
355 pub async fn request_raw(
357 &self,
358 command_code: u8,
359 timeout: Option<Duration>,
360 body: Vec<u8>,
361 ) -> WsClientResult<Vec<u8>> {
362 if let Some(rate_limit) = self.rate_limit.get(&command_code) {
363 rate_limit.acquire_one().await;
364 }
365
366 let (reply_tx, reply_rx) = oneshot::channel();
367 self.command_tx
368 .send(Command::Request {
369 command_code,
370 timeout_millis: timeout.unwrap_or(REQUEST_TIMEOUT).as_millis().min(60000) as u16,
371 body,
372 reply_tx,
373 })
374 .map_err(|_| WsClientError::ClientClosed)?;
375 let resp = tokio::time::timeout(
376 REQUEST_TIMEOUT,
377 reply_rx.map_err(|_| WsClientError::ClientClosed),
378 )
379 .map_err(|_| WsClientError::RequestTimeout)
380 .await???;
381 Ok(resp)
382 }
383
384 pub async fn request<T, R>(
386 &self,
387 command_code: u8,
388 timeout: Option<Duration>,
389 req: T,
390 ) -> WsClientResult<R>
391 where
392 T: prost::Message,
393 R: prost::Message + Default,
394 {
395 tracing::info!(message = ?req, "ws request");
396 let resp = self
397 .request_raw(command_code, timeout, req.encode_to_vec())
398 .await?;
399 let resp = R::decode(&*resp)?;
400 tracing::info!(message = ?resp, "ws response");
401 Ok(resp)
402 }
403}
404
405async fn do_connect(
406 request: impl IntoClientRequest,
407 version: ProtocolVersion,
408 codec: CodecType,
409 platform: Platform,
410) -> WsClientResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
411 let mut request = request.into_client_request()?;
412 let mut url_obj = Url::parse(&request.uri().to_string())?;
413 url_obj.query_pairs_mut().extend_pairs(&[
414 ("version", i32::from(version).to_string()),
415 ("codec", i32::from(codec).to_string()),
416 ("platform", i32::from(platform).to_string()),
417 ]);
418 *request.uri_mut() = Uri::from_str(url_obj.as_ref()).expect("valid url");
419
420 let conn = match tokio::time::timeout(
421 CONNECT_TIMEOUT,
422 tokio_tungstenite::connect_async(request).map_err(WsClientError::from),
423 )
424 .map_err(|_| WsClientError::ConnectTimeout)
425 .await
426 .and_then(std::convert::identity)
427 {
428 Ok((conn, _)) => conn,
429 Err(err) => return Err(err),
430 };
431
432 Ok(conn)
433}
434
435async fn client_loop(
436 conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
437 mut command_tx: mpsc::UnboundedReceiver<Command>,
438 mut event_sender: mpsc::UnboundedSender<WsEvent>,
439) {
440 let mut ctx = Context::new(conn, &mut command_tx, &mut event_sender);
441
442 let res = ctx.process_loop().await;
443 match res {
444 Ok(()) => return,
445 Err(err) => {
446 ctx.send_event(WsEvent::Error(err));
447 }
448 };
449
450 for sender in ctx.inflight_requests.into_values() {
451 let _ = sender.send(Err(WsClientError::Cancelled));
452 }
453}