longport_wscli/
client.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    str::FromStr,
5    sync::Arc,
6    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
7};
8
9use futures_util::{
10    SinkExt, StreamExt, TryFutureExt,
11    stream::{SplitSink, SplitStream},
12};
13use leaky_bucket::RateLimiter;
14use longport_proto::control::{AuthRequest, AuthResponse, ReconnectRequest, ReconnectResponse};
15use num_enum::IntoPrimitive;
16use prost::Message as _;
17use tokio::{
18    net::TcpStream,
19    sync::{mpsc, oneshot},
20};
21use tokio_tungstenite::{
22    MaybeTlsStream, WebSocketStream,
23    tungstenite::{Message, client::IntoClientRequest, http::Uri},
24};
25use url::Url;
26
27use crate::{
28    WsClientError, WsClientResult, WsCloseReason, WsEvent, WsResponseErrorDetail, codec::Packet,
29};
30
/// Maximum time allowed for establishing the websocket connection.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Request timeout used when the caller does not supply one.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// The connection is treated as dead once no ping has arrived for this long.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Timeout of the authentication control request.
const AUTH_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout of the reconnect control request.
const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Command code of the authentication control request.
const COMMAND_CODE_AUTH: u8 = 2;
/// Command code of the reconnect control request.
const COMMAND_CODE_RECONNECT: u8 = 3;
39
40/// LongPort websocket protocol version
41#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
42#[repr(i32)]
43pub enum ProtocolVersion {
44    /// Version 1
45    Version1 = 1,
46}
47
48/// LongPort websocket codec type
49#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
50#[repr(i32)]
51pub enum CodecType {
52    /// Protobuf
53    Protobuf = 1,
54}
55
56/// LongPort websocket platform type
57#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
58#[repr(i32)]
59pub enum Platform {
60    /// OpenAPI
61    OpenAPI = 9,
62}
63
64enum Command {
65    Request {
66        command_code: u8,
67        timeout_millis: u16,
68        body: Vec<u8>,
69        reply_tx: oneshot::Sender<WsClientResult<Vec<u8>>>,
70    },
71}
72
/// Rate limiter config
///
/// Describes a token bucket. It is converted into a `leaky_bucket::RateLimiter`
/// and applied per command code by `WsClient`.
#[derive(Debug, Copy, Clone)]
pub struct RateLimit {
    /// The time between two refills of the bucket
    pub interval: Duration,
    /// The number of tokens in the bucket when it is created
    pub initial: usize,
    /// The maximum number of tokens the bucket can hold
    pub max: usize,
    /// The number of tokens added to the bucket at each interval
    pub refill: usize,
}
85
86impl From<RateLimit> for RateLimiter {
87    fn from(config: RateLimit) -> Self {
88        RateLimiter::builder()
89            .interval(config.interval)
90            .refill(config.refill)
91            .max(config.max)
92            .initial(0)
93            .build()
94    }
95}
96
97struct Context<'a> {
98    request_id: u32,
99    inflight_requests: HashMap<u32, oneshot::Sender<WsClientResult<Vec<u8>>>>,
100    sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
101    stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
102    command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
103    event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
104}
105
106impl<'a> Context<'a> {
107    fn new(
108        conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
109        command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
110        event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
111    ) -> Self {
112        let (sink, stream) = conn.split();
113        Context {
114            request_id: 0,
115            inflight_requests: Default::default(),
116            sink,
117            stream,
118            command_rx,
119            event_sender,
120        }
121    }
122
123    #[inline]
124    fn get_request_id(&mut self) -> u32 {
125        self.request_id += 1;
126        self.request_id
127    }
128
129    fn send_event(&mut self, event: WsEvent) {
130        let _ = self.event_sender.send(event);
131    }
132
133    async fn process_loop(&mut self) -> WsClientResult<()> {
134        let mut ping_time = Instant::now();
135        let mut checkout_timeout = tokio::time::interval(Duration::from_secs(1));
136
137        loop {
138            tokio::select! {
139                item = self.stream.next() => {
140                    match item.transpose()? {
141                        Some(msg) => {
142                            if msg.is_ping() {
143                                tracing::debug!("ping");
144                                ping_time = Instant::now();
145                            }
146                            self.handle_message(msg).await?;
147                        },
148                        None => return Err(WsClientError::ConnectionClosed { reason: None }),
149                    }
150                }
151                item = self.command_rx.recv() => {
152                    match item {
153                        Some(command) => self.handle_command(command).await?,
154                        None => return Ok(()),
155                    }
156                }
157                _ = checkout_timeout.tick() => {
158                    if (Instant::now() - ping_time) > HEARTBEAT_TIMEOUT {
159                        tracing::info!("heartbeat timeout");
160                        return Err(WsClientError::ConnectionClosed { reason: None });
161                    }
162                }
163            }
164        }
165    }
166
167    async fn handle_command(&mut self, command: Command) -> WsClientResult<()> {
168        match command {
169            Command::Request {
170                command_code,
171                timeout_millis: timeout,
172                body,
173                reply_tx,
174            } => {
175                let request_id = self.get_request_id();
176                let msg = Message::Binary(
177                    Packet::Request {
178                        command_code,
179                        request_id,
180                        timeout_millis: timeout,
181                        body,
182                        signature: None,
183                    }
184                    .encode()
185                    .into(),
186                );
187                self.inflight_requests.insert(request_id, reply_tx);
188                self.sink.send(msg).await?;
189                Ok(())
190            }
191        }
192    }
193
194    async fn handle_message(&mut self, msg: Message) -> WsClientResult<()> {
195        match msg {
196            Message::Ping(data) => {
197                self.sink.send(Message::Pong(data)).await?;
198            }
199            Message::Binary(data) => match Packet::decode(&data)? {
200                Packet::Response {
201                    request_id,
202                    status,
203                    body,
204                    ..
205                } => {
206                    if let Some(sender) = self.inflight_requests.remove(&request_id) {
207                        if status == 0 {
208                            let _ = sender.send(Ok(body));
209                        } else {
210                            let detail = longport_proto::Error::decode(&*body).ok().map(
211                                |longport_proto::Error { code, msg }| WsResponseErrorDetail {
212                                    code,
213                                    msg,
214                                },
215                            );
216                            let _ =
217                                sender.send(Err(WsClientError::ResponseError { status, detail }));
218                        }
219                    }
220                }
221                Packet::Push {
222                    command_code, body, ..
223                } => {
224                    let _ = self.event_sender.send(WsEvent::Push { command_code, body });
225                }
226                _ => return Err(WsClientError::UnexpectedResponse),
227            },
228            Message::Close(Some(close_frame)) => {
229                return Err(WsClientError::ConnectionClosed {
230                    reason: Some(WsCloseReason {
231                        code: close_frame.code,
232                        message: close_frame.reason.to_string(),
233                    }),
234                });
235            }
236            _ => return Err(WsClientError::UnexpectedResponse),
237        }
238
239        Ok(())
240    }
241}
242
/// The session for the Websocket connection
///
/// Returned by `WsClient::request_auth` and `WsClient::request_reconnect`.
#[derive(Debug)]
pub struct WsSession {
    /// Session id
    pub session_id: String,
    /// The expiration time of the session id.
    pub deadline: SystemTime,
}

impl WsSession {
    /// Returns `true` if the session id is expired, otherwise returns `false`.
    ///
    /// The session counts as expired once `deadline` lies strictly before the
    /// current local system time.
    #[inline]
    pub fn is_expired(&self) -> bool {
        self.deadline < SystemTime::now()
    }
}
259
260/// LongPort Websocket client
261pub struct WsClient {
262    command_tx: mpsc::UnboundedSender<Command>,
263    rate_limit: Arc<HashMap<u8, RateLimiter>>,
264}
265
266impl WsClient {
267    /// Connect to `url` and returns a `WsClient` object
268    pub async fn open(
269        request: impl IntoClientRequest,
270        version: ProtocolVersion,
271        codec: CodecType,
272        platform: Platform,
273        event_sender: mpsc::UnboundedSender<WsEvent>,
274        rate_limit: Vec<(u8, RateLimit)>,
275    ) -> WsClientResult<Self> {
276        let (command_tx, command_rx) = mpsc::unbounded_channel();
277        let conn = do_connect(request, version, codec, platform).await?;
278        tokio::spawn(client_loop(conn, command_rx, event_sender));
279        Ok(Self {
280            command_tx,
281            rate_limit: Arc::new(
282                rate_limit
283                    .into_iter()
284                    .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
285                    .collect(),
286            ),
287        })
288    }
289
290    /// Set the rate limit
291    pub fn set_rate_limit(&mut self, rate_limit: Vec<(u8, RateLimit)>) {
292        self.rate_limit = Arc::new(
293            rate_limit
294                .into_iter()
295                .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
296                .collect(),
297        );
298    }
299
300    /// Send an authentication request to get a [`WsSession`]
301    ///
302    /// Reference: <https://open.longportapp.com/en/docs/socket-token-api>
303    /// Reference: <https://open.longportapp.com/en/docs/socket/control-command#auth>
304    pub async fn request_auth(
305        &self,
306        otp: impl Into<String>,
307        metadata: HashMap<String, String>,
308    ) -> WsClientResult<WsSession> {
309        let resp: AuthResponse = self
310            .request(
311                COMMAND_CODE_AUTH,
312                Some(AUTH_TIMEOUT),
313                AuthRequest {
314                    token: otp.into(),
315                    metadata,
316                },
317            )
318            .await?;
319        let expires_mills = resp.expires.saturating_sub(
320            SystemTime::now()
321                .duration_since(UNIX_EPOCH)
322                .unwrap()
323                .as_millis() as i64,
324        ) as u64;
325        let deadline = SystemTime::now() + Duration::from_millis(expires_mills);
326        Ok(WsSession {
327            session_id: resp.session_id,
328            deadline,
329        })
330    }
331
332    /// Send a reconnect request to get a [`WsSession`]
333    ///
334    /// Reference: <https://open.longportapp.com/en/docs/socket/control-command#reconnect>
335    pub async fn request_reconnect(
336        &self,
337        session_id: impl Into<String>,
338        metadata: HashMap<String, String>,
339    ) -> WsClientResult<WsSession> {
340        let resp: ReconnectResponse = self
341            .request(
342                COMMAND_CODE_RECONNECT,
343                Some(RECONNECT_TIMEOUT),
344                ReconnectRequest {
345                    session_id: session_id.into(),
346                    metadata,
347                },
348            )
349            .await?;
350        Ok(WsSession {
351            session_id: resp.session_id,
352            deadline: SystemTime::now() + Duration::from_millis(resp.expires as u64),
353        })
354    }
355
356    /// Send a raw request
357    pub async fn request_raw(
358        &self,
359        command_code: u8,
360        timeout: Option<Duration>,
361        body: Vec<u8>,
362    ) -> WsClientResult<Vec<u8>> {
363        if let Some(rate_limit) = self.rate_limit.get(&command_code) {
364            rate_limit.acquire_one().await;
365        }
366
367        let (reply_tx, reply_rx) = oneshot::channel();
368        self.command_tx
369            .send(Command::Request {
370                command_code,
371                timeout_millis: timeout.unwrap_or(REQUEST_TIMEOUT).as_millis().min(60000) as u16,
372                body,
373                reply_tx,
374            })
375            .map_err(|_| WsClientError::ClientClosed)?;
376        let resp = tokio::time::timeout(
377            REQUEST_TIMEOUT,
378            reply_rx.map_err(|_| WsClientError::ClientClosed),
379        )
380        .map_err(|_| WsClientError::RequestTimeout)
381        .await???;
382        Ok(resp)
383    }
384
385    /// Send a request `T` to get a response `R`
386    pub async fn request<T, R>(
387        &self,
388        command_code: u8,
389        timeout: Option<Duration>,
390        req: T,
391    ) -> WsClientResult<R>
392    where
393        T: prost::Message + Debug,
394        R: prost::Message + Default + Debug,
395    {
396        tracing::info!(message = ?req, "ws request");
397        let resp = self
398            .request_raw(command_code, timeout, req.encode_to_vec())
399            .await?;
400        let resp = R::decode(&*resp)?;
401        tracing::info!(message = ?resp, "ws response");
402        Ok(resp)
403    }
404}
405
406async fn do_connect(
407    request: impl IntoClientRequest,
408    version: ProtocolVersion,
409    codec: CodecType,
410    platform: Platform,
411) -> WsClientResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
412    let mut request = request.into_client_request()?;
413    let mut url_obj = Url::parse(&request.uri().to_string())?;
414    url_obj.query_pairs_mut().extend_pairs(&[
415        ("version", i32::from(version).to_string()),
416        ("codec", i32::from(codec).to_string()),
417        ("platform", i32::from(platform).to_string()),
418    ]);
419    *request.uri_mut() = Uri::from_str(url_obj.as_ref()).expect("valid url");
420
421    let conn = match tokio::time::timeout(
422        CONNECT_TIMEOUT,
423        tokio_tungstenite::connect_async(request).map_err(WsClientError::from),
424    )
425    .map_err(|_| WsClientError::ConnectTimeout)
426    .await
427    .and_then(std::convert::identity)
428    {
429        Ok((conn, _)) => conn,
430        Err(err) => return Err(err),
431    };
432
433    Ok(conn)
434}
435
436async fn client_loop(
437    conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
438    mut command_tx: mpsc::UnboundedReceiver<Command>,
439    mut event_sender: mpsc::UnboundedSender<WsEvent>,
440) {
441    let mut ctx = Context::new(conn, &mut command_tx, &mut event_sender);
442
443    let res = ctx.process_loop().await;
444    match res {
445        Ok(()) => return,
446        Err(err) => {
447            ctx.send_event(WsEvent::Error(err));
448        }
449    };
450
451    for sender in ctx.inflight_requests.into_values() {
452        let _ = sender.send(Err(WsClientError::Cancelled));
453    }
454}