longport_wscli/
client.rs

1use std::{
2    collections::HashMap,
3    str::FromStr,
4    sync::Arc,
5    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
6};
7
8use futures_util::{
9    SinkExt, StreamExt, TryFutureExt,
10    stream::{SplitSink, SplitStream},
11};
12use leaky_bucket::RateLimiter;
13use longport_proto::control::{AuthRequest, AuthResponse, ReconnectRequest, ReconnectResponse};
14use num_enum::IntoPrimitive;
15use prost::Message as _;
16use tokio::{
17    net::TcpStream,
18    sync::{mpsc, oneshot},
19};
20use tokio_tungstenite::{
21    MaybeTlsStream, WebSocketStream,
22    tungstenite::{Message, client::IntoClientRequest, http::Uri},
23};
24use url::Url;
25
26use crate::{
27    WsClientError, WsClientResult, WsCloseReason, WsEvent, WsResponseErrorDetail, codec::Packet,
28};
29
/// Maximum time `do_connect` waits for `connect_async` to finish.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Default request timeout, used when a caller does not provide one.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// The connection is treated as dead if no ping arrives within this window.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Timeout for the auth control request.
const AUTH_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout for the reconnect control request.
const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Control command code of the auth request.
const COMMAND_CODE_AUTH: u8 = 2;
/// Control command code of the reconnect request.
const COMMAND_CODE_RECONNECT: u8 = 3;
38
39/// LongPort websocket protocol version
40#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
41#[repr(i32)]
42pub enum ProtocolVersion {
43    /// Version 1
44    Version1 = 1,
45}
46
47/// LongPort websocket codec type
48#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
49#[repr(i32)]
50pub enum CodecType {
51    /// Protobuf
52    Protobuf = 1,
53}
54
55/// LongPort websocket platform type
56#[derive(Debug, IntoPrimitive, Copy, Clone, Eq, PartialEq, Hash)]
57#[repr(i32)]
58pub enum Platform {
59    /// OpenAPI
60    OpenAPI = 9,
61}
62
63enum Command {
64    Request {
65        command_code: u8,
66        timeout_millis: u16,
67        body: Vec<u8>,
68        reply_tx: oneshot::Sender<WsClientResult<Vec<u8>>>,
69    },
70}
71
/// Token-bucket rate limit configuration for one command code.
///
/// Converted into a [`RateLimiter`] via `From<RateLimit>`.
#[derive(Debug, Copy, Clone)]
pub struct RateLimit {
    /// Interval at which `refill` tokens are added back to the bucket.
    pub interval: Duration,
    /// The initial number of tokens in the bucket.
    pub initial: usize,
    /// The maximum number of tokens the bucket can hold.
    pub max: usize,
    /// The number of tokens added to the bucket every `interval`.
    pub refill: usize,
}
84
85impl From<RateLimit> for RateLimiter {
86    fn from(config: RateLimit) -> Self {
87        RateLimiter::builder()
88            .interval(config.interval)
89            .refill(config.refill)
90            .max(config.max)
91            .initial(0)
92            .build()
93    }
94}
95
96struct Context<'a> {
97    request_id: u32,
98    inflight_requests: HashMap<u32, oneshot::Sender<WsClientResult<Vec<u8>>>>,
99    sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
100    stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
101    command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
102    event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
103}
104
105impl<'a> Context<'a> {
106    fn new(
107        conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
108        command_rx: &'a mut mpsc::UnboundedReceiver<Command>,
109        event_sender: &'a mut mpsc::UnboundedSender<WsEvent>,
110    ) -> Self {
111        let (sink, stream) = conn.split();
112        Context {
113            request_id: 0,
114            inflight_requests: Default::default(),
115            sink,
116            stream,
117            command_rx,
118            event_sender,
119        }
120    }
121
122    #[inline]
123    fn get_request_id(&mut self) -> u32 {
124        self.request_id += 1;
125        self.request_id
126    }
127
128    fn send_event(&mut self, event: WsEvent) {
129        let _ = self.event_sender.send(event);
130    }
131
132    async fn process_loop(&mut self) -> WsClientResult<()> {
133        let mut ping_time = Instant::now();
134        let mut checkout_timeout = tokio::time::interval(Duration::from_secs(1));
135
136        loop {
137            tokio::select! {
138                item = self.stream.next() => {
139                    match item.transpose()? {
140                        Some(msg) => {
141                            if msg.is_ping() {
142                                tracing::debug!("ping");
143                                ping_time = Instant::now();
144                            }
145                            self.handle_message(msg).await?;
146                        },
147                        None => return Err(WsClientError::ConnectionClosed { reason: None }),
148                    }
149                }
150                item = self.command_rx.recv() => {
151                    match item {
152                        Some(command) => self.handle_command(command).await?,
153                        None => return Ok(()),
154                    }
155                }
156                _ = checkout_timeout.tick() => {
157                    if (Instant::now() - ping_time) > HEARTBEAT_TIMEOUT {
158                        tracing::info!("heartbeat timeout");
159                        return Err(WsClientError::ConnectionClosed { reason: None });
160                    }
161                }
162            }
163        }
164    }
165
166    async fn handle_command(&mut self, command: Command) -> WsClientResult<()> {
167        match command {
168            Command::Request {
169                command_code,
170                timeout_millis: timeout,
171                body,
172                reply_tx,
173            } => {
174                let request_id = self.get_request_id();
175                let msg = Message::Binary(
176                    Packet::Request {
177                        command_code,
178                        request_id,
179                        timeout_millis: timeout,
180                        body,
181                        signature: None,
182                    }
183                    .encode()
184                    .into(),
185                );
186                self.inflight_requests.insert(request_id, reply_tx);
187                self.sink.send(msg).await?;
188                Ok(())
189            }
190        }
191    }
192
193    async fn handle_message(&mut self, msg: Message) -> WsClientResult<()> {
194        match msg {
195            Message::Ping(data) => {
196                self.sink.send(Message::Pong(data)).await?;
197            }
198            Message::Binary(data) => match Packet::decode(&data)? {
199                Packet::Response {
200                    request_id,
201                    status,
202                    body,
203                    ..
204                } => {
205                    if let Some(sender) = self.inflight_requests.remove(&request_id) {
206                        if status == 0 {
207                            let _ = sender.send(Ok(body));
208                        } else {
209                            let detail = longport_proto::Error::decode(&*body).ok().map(
210                                |longport_proto::Error { code, msg }| WsResponseErrorDetail {
211                                    code,
212                                    msg,
213                                },
214                            );
215                            let _ =
216                                sender.send(Err(WsClientError::ResponseError { status, detail }));
217                        }
218                    }
219                }
220                Packet::Push {
221                    command_code, body, ..
222                } => {
223                    let _ = self.event_sender.send(WsEvent::Push { command_code, body });
224                }
225                _ => return Err(WsClientError::UnexpectedResponse),
226            },
227            Message::Close(Some(close_frame)) => {
228                return Err(WsClientError::ConnectionClosed {
229                    reason: Some(WsCloseReason {
230                        code: close_frame.code,
231                        message: close_frame.reason.to_string(),
232                    }),
233                });
234            }
235            _ => return Err(WsClientError::UnexpectedResponse),
236        }
237
238        Ok(())
239    }
240}
241
/// A session on the websocket connection, as returned by an auth or reconnect
/// request.
#[derive(Debug)]
pub struct WsSession {
    /// Session id
    pub session_id: String,
    /// The point in time at which the session id expires.
    pub deadline: SystemTime,
}

impl WsSession {
    /// Returns `true` if the session id has expired, otherwise returns `false`.
    ///
    /// A deadline equal to the current time is not yet considered expired.
    #[inline]
    pub fn is_expired(&self) -> bool {
        self.deadline < SystemTime::now()
    }
}
258
259/// LongPort Websocket client
260pub struct WsClient {
261    command_tx: mpsc::UnboundedSender<Command>,
262    rate_limit: Arc<HashMap<u8, RateLimiter>>,
263}
264
265impl WsClient {
266    /// Connect to `url` and returns a `WsClient` object
267    pub async fn open(
268        request: impl IntoClientRequest,
269        version: ProtocolVersion,
270        codec: CodecType,
271        platform: Platform,
272        event_sender: mpsc::UnboundedSender<WsEvent>,
273        rate_limit: Vec<(u8, RateLimit)>,
274    ) -> WsClientResult<Self> {
275        let (command_tx, command_rx) = mpsc::unbounded_channel();
276        let conn = do_connect(request, version, codec, platform).await?;
277        tokio::spawn(client_loop(conn, command_rx, event_sender));
278        Ok(Self {
279            command_tx,
280            rate_limit: Arc::new(
281                rate_limit
282                    .into_iter()
283                    .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
284                    .collect(),
285            ),
286        })
287    }
288
289    /// Set the rate limit
290    pub fn set_rate_limit(&mut self, rate_limit: Vec<(u8, RateLimit)>) {
291        self.rate_limit = Arc::new(
292            rate_limit
293                .into_iter()
294                .map(|(cmd, rate_limit)| (cmd, rate_limit.into()))
295                .collect(),
296        );
297    }
298
299    /// Send an authentication request to get a [`WsSession`]
300    ///
301    /// Reference: <https://open.longportapp.com/en/docs/socket-token-api>
302    /// Reference: <https://open.longportapp.com/en/docs/socket/control-command#auth>
303    pub async fn request_auth(
304        &self,
305        otp: impl Into<String>,
306        metadata: HashMap<String, String>,
307    ) -> WsClientResult<WsSession> {
308        let resp: AuthResponse = self
309            .request(
310                COMMAND_CODE_AUTH,
311                Some(AUTH_TIMEOUT),
312                AuthRequest {
313                    token: otp.into(),
314                    metadata,
315                },
316            )
317            .await?;
318        let expires_mills = resp.expires.saturating_sub(
319            SystemTime::now()
320                .duration_since(UNIX_EPOCH)
321                .unwrap()
322                .as_millis() as i64,
323        ) as u64;
324        let deadline = SystemTime::now() + Duration::from_millis(expires_mills);
325        Ok(WsSession {
326            session_id: resp.session_id,
327            deadline,
328        })
329    }
330
331    /// Send a reconnect request to get a [`WsSession`]
332    ///
333    /// Reference: <https://open.longportapp.com/en/docs/socket/control-command#reconnect>
334    pub async fn request_reconnect(
335        &self,
336        session_id: impl Into<String>,
337        metadata: HashMap<String, String>,
338    ) -> WsClientResult<WsSession> {
339        let resp: ReconnectResponse = self
340            .request(
341                COMMAND_CODE_RECONNECT,
342                Some(RECONNECT_TIMEOUT),
343                ReconnectRequest {
344                    session_id: session_id.into(),
345                    metadata,
346                },
347            )
348            .await?;
349        Ok(WsSession {
350            session_id: resp.session_id,
351            deadline: SystemTime::now() + Duration::from_millis(resp.expires as u64),
352        })
353    }
354
355    /// Send a raw request
356    pub async fn request_raw(
357        &self,
358        command_code: u8,
359        timeout: Option<Duration>,
360        body: Vec<u8>,
361    ) -> WsClientResult<Vec<u8>> {
362        if let Some(rate_limit) = self.rate_limit.get(&command_code) {
363            rate_limit.acquire_one().await;
364        }
365
366        let (reply_tx, reply_rx) = oneshot::channel();
367        self.command_tx
368            .send(Command::Request {
369                command_code,
370                timeout_millis: timeout.unwrap_or(REQUEST_TIMEOUT).as_millis().min(60000) as u16,
371                body,
372                reply_tx,
373            })
374            .map_err(|_| WsClientError::ClientClosed)?;
375        let resp = tokio::time::timeout(
376            REQUEST_TIMEOUT,
377            reply_rx.map_err(|_| WsClientError::ClientClosed),
378        )
379        .map_err(|_| WsClientError::RequestTimeout)
380        .await???;
381        Ok(resp)
382    }
383
384    /// Send a request `T` to get a response `R`
385    pub async fn request<T, R>(
386        &self,
387        command_code: u8,
388        timeout: Option<Duration>,
389        req: T,
390    ) -> WsClientResult<R>
391    where
392        T: prost::Message,
393        R: prost::Message + Default,
394    {
395        tracing::info!(message = ?req, "ws request");
396        let resp = self
397            .request_raw(command_code, timeout, req.encode_to_vec())
398            .await?;
399        let resp = R::decode(&*resp)?;
400        tracing::info!(message = ?resp, "ws response");
401        Ok(resp)
402    }
403}
404
405async fn do_connect(
406    request: impl IntoClientRequest,
407    version: ProtocolVersion,
408    codec: CodecType,
409    platform: Platform,
410) -> WsClientResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
411    let mut request = request.into_client_request()?;
412    let mut url_obj = Url::parse(&request.uri().to_string())?;
413    url_obj.query_pairs_mut().extend_pairs(&[
414        ("version", i32::from(version).to_string()),
415        ("codec", i32::from(codec).to_string()),
416        ("platform", i32::from(platform).to_string()),
417    ]);
418    *request.uri_mut() = Uri::from_str(url_obj.as_ref()).expect("valid url");
419
420    let conn = match tokio::time::timeout(
421        CONNECT_TIMEOUT,
422        tokio_tungstenite::connect_async(request).map_err(WsClientError::from),
423    )
424    .map_err(|_| WsClientError::ConnectTimeout)
425    .await
426    .and_then(std::convert::identity)
427    {
428        Ok((conn, _)) => conn,
429        Err(err) => return Err(err),
430    };
431
432    Ok(conn)
433}
434
435async fn client_loop(
436    conn: WebSocketStream<MaybeTlsStream<TcpStream>>,
437    mut command_tx: mpsc::UnboundedReceiver<Command>,
438    mut event_sender: mpsc::UnboundedSender<WsEvent>,
439) {
440    let mut ctx = Context::new(conn, &mut command_tx, &mut event_sender);
441
442    let res = ctx.process_loop().await;
443    match res {
444        Ok(()) => return,
445        Err(err) => {
446            ctx.send_event(WsEvent::Error(err));
447        }
448    };
449
450    for sender in ctx.inflight_requests.into_values() {
451        let _ = sender.send(Err(WsClientError::Cancelled));
452    }
453}