tqsdk_rs/
client.rs

1//! 客户端模块
2//!
3//! 统一的客户端入口
4
5use crate::auth::{Authenticator, TqAuth};
6use crate::datamanager::{DataManager, DataManagerConfig};
7use crate::errors::{Result, TqError};
8use crate::quote::QuoteSubscription;
9use crate::series::SeriesAPI;
10use crate::trade_session::TradeSession;
11use crate::websocket::{TqQuoteWebsocket, WebSocketConfig};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// Client configuration.
///
/// Usually built through [`ClientBuilder`]; see its `Default` impl for the
/// baseline values.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Log level string handed to the logger when the client is built.
    pub log_level: String,
    /// Default view width forwarded to the data manager.
    pub view_width: usize,
    /// Development-mode flag (stored on the client; not consumed elsewhere in this module).
    pub development: bool,
}
26
27impl Default for ClientConfig {
28    fn default() -> Self {
29        ClientConfig {
30            log_level: "info".to_string(),
31            view_width: 10000,
32            development: false,
33        }
34    }
35}
36
37/// 客户端选项
38pub type ClientOption = Box<dyn Fn(&mut ClientConfig)>;
39
40/// 客户端构建器
41pub struct ClientBuilder {
42    username: String,
43    password: String,
44    config: ClientConfig,
45    auth: Option<Arc<RwLock<dyn Authenticator>>>,
46}
47
48impl ClientBuilder {
49    /// 创建新的客户端构建器
50    ///
51    /// # 参数
52    ///
53    /// * `username` - 用户名
54    /// * `password` - 密码
55    ///
56    /// # 示例
57    ///
58    /// ```no_run
59    /// # use tqsdk_rs::*;
60    /// # async fn example() -> Result<()> {
61    /// let client = ClientBuilder::new("username", "password")
62    ///     .log_level("debug")
63    ///     .view_width(5000)
64    ///     .build()
65    ///     .await?;
66    /// # Ok(())
67    /// # }
68    /// ```
69    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
70        ClientBuilder {
71            username: username.into(),
72            password: password.into(),
73            config: ClientConfig::default(),
74            auth: None,
75        }
76    }
77
78    /// 设置日志级别
79    pub fn log_level(mut self, level: impl Into<String>) -> Self {
80        self.config.log_level = level.into();
81        self
82    }
83
84    /// 设置默认视图宽度
85    pub fn view_width(mut self, width: usize) -> Self {
86        self.config.view_width = width;
87        self
88    }
89
90    /// 设置开发模式
91    pub fn development(mut self, dev: bool) -> Self {
92        self.config.development = dev;
93        self
94    }
95
96    /// 设置完整配置
97    pub fn config(mut self, config: ClientConfig) -> Self {
98        self.config = config;
99        self
100    }
101
102    /// 使用自定义认证器(高级用法)
103    ///
104    /// 如果不设置,将使用默认的 TqAuth
105    ///
106    /// # 示例
107    ///
108    /// ```no_run
109    /// # use tqsdk_rs::*;
110    /// # use tqsdk_rs::auth::TqAuth;
111    /// # async fn example() -> Result<()> {
112    /// let mut auth = TqAuth::new("username".to_string(), "password".to_string());
113    /// auth.login().await?;
114    ///
115    /// let client = Client::builder("username", "password")
116    ///     .auth(auth)
117    ///     .build()
118    ///     .await?;
119    /// # Ok(())
120    /// # }
121    /// ```
122    pub fn auth<A: Authenticator + 'static>(mut self, auth: A) -> Self {
123        self.auth = Some(Arc::new(RwLock::new(auth)));
124        self
125    }
126
127    /// 构建客户端
128    ///
129    /// # 错误
130    ///
131    /// 如果认证失败,返回错误
132    pub async fn build(self) -> Result<Client> {
133        // 初始化日志
134        crate::logger::init_logger(&self.config.log_level, true);
135
136        // 创建或使用提供的认证器
137        let auth: Arc<RwLock<dyn Authenticator>> = if let Some(custom_auth) = self.auth {
138            custom_auth
139        } else {
140            let mut auth = TqAuth::new(self.username.clone(), self.password.clone());
141            auth.login().await?;
142            Arc::new(RwLock::new(auth))
143        };
144
145        // 创建数据管理器
146        let dm_config = DataManagerConfig {
147            default_view_width: self.config.view_width,
148            enable_auto_cleanup: true,
149        };
150        let initial_data = HashMap::new();
151        let dm = Arc::new(DataManager::new(initial_data, dm_config));
152
153        Ok(Client {
154            _username: self.username,
155            _config: self.config,
156            auth,
157            dm,
158            quotes_ws: None,
159            series_api: None,
160            trade_sessions: Arc::new(RwLock::new(HashMap::new())),
161        })
162    }
163}
164
165/// 客户端
166pub struct Client {
167    _username: String,
168    _config: ClientConfig,
169    auth: Arc<RwLock<dyn Authenticator>>,
170    dm: Arc<DataManager>,
171    quotes_ws: Option<Arc<TqQuoteWebsocket>>,
172    series_api: Option<Arc<SeriesAPI>>,
173    trade_sessions: Arc<RwLock<HashMap<String, Arc<TradeSession>>>>,
174}
175
176impl Client {
177    /// 创建新的客户端(使用默认配置)
178    ///
179    /// 这是一个便捷方法,等同于:
180    /// ```no_run
181    /// # use tqsdk_rs::*;
182    /// # async fn example() -> Result<()> {
183    /// let client = ClientBuilder::new("username", "password")
184    ///     .config(ClientConfig::default())
185    ///     .build()
186    ///     .await?;
187    /// # Ok(())
188    /// # }
189    /// ```
190    ///
191    /// 如需更多配置选项,请使用 `ClientBuilder`
192    pub async fn new(username: &str, password: &str, config: ClientConfig) -> Result<Self> {
193        ClientBuilder::new(username, password)
194            .config(config)
195            .build()
196            .await
197    }
198
199    /// 创建客户端构建器
200    ///
201    /// # 示例
202    ///
203    /// ```no_run
204    /// # use tqsdk_rs::*;
205    /// # async fn example() -> Result<()> {
206    /// let client = Client::builder("username", "password")
207    ///     .log_level("debug")
208    ///     .view_width(5000)
209    ///     .development(true)
210    ///     .build()
211    ///     .await?;
212    /// # Ok(())
213    /// # }
214    /// ```
215    pub fn builder(username: impl Into<String>, password: impl Into<String>) -> ClientBuilder {
216        ClientBuilder::new(username, password)
217    }
218
219    /// 初始化行情功能
220    pub async fn init_market(&mut self) -> Result<()> {
221        let auth = self.auth.read().await;
222        let md_url = auth.get_md_url(false, false).await?;
223
224        let mut ws_config = WebSocketConfig::default();
225        ws_config.headers = auth.base_header();
226
227        let quotes_ws = Arc::new(TqQuoteWebsocket::new(
228            md_url,
229            Arc::clone(&self.dm),
230            ws_config,
231        ));
232
233        quotes_ws.init(false).await?;
234
235        self.quotes_ws = Some(Arc::clone(&quotes_ws));
236
237        // 创建 SeriesAPI(传入 auth)
238        let series_api = Arc::new(SeriesAPI::new(
239            Arc::clone(&self.dm),
240            quotes_ws,
241            Arc::clone(&self.auth),
242        ));
243        self.series_api = Some(series_api);
244
245        Ok(())
246    }
247
248    /// 设置认证器
249    ///
250    /// 允许在运行时更换认证器,例如切换账号或更新 token
251    ///
252    /// # 注意
253    ///
254    /// - 更换认证器后,需要重新调用 `init_market()` 来使用新的认证信息
255    /// - 已创建的 `SeriesAPI` 和 `TradeSession` 仍使用旧的认证器
256    ///
257    /// # 示例
258    ///
259    /// ```no_run
260    /// # use tqsdk_rs::*;
261    /// # use tqsdk_rs::auth::TqAuth;
262    /// # async fn example() -> Result<()> {
263    /// let mut client = Client::new("user1", "pass1", ClientConfig::default()).await?;
264    ///
265    /// // 切换到另一个账号
266    /// let mut new_auth = TqAuth::new("user2".to_string(), "pass2".to_string());
267    /// new_auth.login().await?;
268    /// client.set_auth(new_auth).await;
269    ///
270    /// // 重新初始化行情功能
271    /// client.init_market().await?;
272    /// # Ok(())
273    /// # }
274    /// ```
275    pub async fn set_auth<A: Authenticator + 'static>(&mut self, auth: A) {
276        self.auth = Arc::new(RwLock::new(auth));
277    }
278
279    /// 获取认证器的只读引用
280    ///
281    /// 用于检查当前的认证状态或权限
282    ///
283    /// # 示例
284    ///
285    /// ```no_run
286    /// # use tqsdk_rs::*;
287    /// # async fn example(client: &Client) -> Result<()> {
288    /// let auth = client.get_auth().await;
289    /// if auth.has_feature("futr") {
290    ///     println!("有期货权限");
291    /// }
292    /// # Ok(())
293    /// # }
294    /// ```
295    pub async fn get_auth(&self) -> tokio::sync::RwLockReadGuard<'_, dyn Authenticator> {
296        self.auth.read().await
297    }
298
299    /// 获取 Series API
300    pub fn series(&self) -> Result<Arc<SeriesAPI>> {
301        self.series_api
302            .clone()
303            .ok_or_else(|| TqError::InternalError("Series API 未初始化".to_string()))
304    }
305
306    /// 订阅 Quote
307    pub async fn subscribe_quote(&self, symbols: &[&str]) -> Result<Arc<QuoteSubscription>> {
308        if self.quotes_ws.is_none() {
309            return Err(TqError::InternalError(
310                "行情 WebSocket 未初始化".to_string(),
311            ));
312        }
313        {
314            let auth = self.auth.read().await;
315            auth.has_md_grants(symbols)? 
316        }
317        let symbol_list: Vec<String> = symbols.iter().map(|s| s.to_string()).collect();
318        let qs = Arc::new(QuoteSubscription::new(
319            Arc::clone(&self.dm),
320            self.quotes_ws.as_ref().unwrap().clone(),
321            symbol_list,
322        ));
323
324        // 启动订阅
325        qs.start().await?;
326
327        Ok(qs)
328    }
329
330    /// 创建交易会话(不自动连接)
331    ///
332    /// # 重要提示
333    ///
334    /// 由于 TradeSession 使用 broadcast 队列,建议按以下顺序使用:
335    ///
336    /// ```no_run
337    /// # use tqsdk_rs::*;
338    /// # async fn example() -> Result<()> {
339    /// # let client = Client::new("user", "pass", ClientConfig::default()).await?;
340    /// // 1. 创建会话(不连接)
341    /// let session = client.create_trade_session("simnow", "user_id", "password").await?;
342    ///
343    /// // 2. 先注册回调或订阅 channel(避免丢失消息)
344    /// session.on_account(|account| {
345    ///     println!("账户: {}", account.balance);
346    /// }).await;
347    ///
348    /// // 3. 最后连接
349    /// session.connect().await?;
350    /// # Ok(())
351    /// # }
352    /// ```
353    ///
354    /// # 参数
355    ///
356    /// * `broker` - 期货公司代码(如 "simnow")
357    /// * `user_id` - 用户账号
358    /// * `password` - 密码
359    ///
360    /// # 返回
361    ///
362    /// 返回 `TradeSession` 实例,需要手动调用 `connect()` 连接
363    pub async fn create_trade_session(
364        &self,
365        broker: &str,
366        user_id: &str,
367        password: &str,
368    ) -> Result<Arc<TradeSession>> {
369        // 获取交易服务器地址
370        let auth = self.auth.read().await;
371        let broker_info = auth.get_td_url(broker, user_id).await?;
372
373        let mut ws_config = WebSocketConfig::default();
374        ws_config.headers = auth.base_header();
375        drop(auth);
376
377        // 创建交易会话(不自动连接)
378        let session = Arc::new(TradeSession::new(
379            broker.to_string(),
380            user_id.to_string(),
381            password.to_string(),
382            Arc::clone(&self.dm),
383            broker_info.url,
384            ws_config,
385        ));
386
387        // 保存会话
388        let key = format!("{}:{}", broker, user_id);
389        let mut sessions = self.trade_sessions.write().await;
390        sessions.insert(key, Arc::clone(&session));
391
392        Ok(session)
393    }
394
395    /// 注册交易会话
396    pub async fn register_trade_session(&self, key: &str, session: Arc<TradeSession>) {
397        let mut sessions = self.trade_sessions.write().await;
398        sessions.insert(key.to_string(), session);
399    }
400
401    /// 获取交易会话
402    pub async fn get_trade_session(&self, key: &str) -> Option<Arc<TradeSession>> {
403        let sessions = self.trade_sessions.read().await;
404        sessions.get(key).cloned()
405    }
406
407    /// 关闭客户端
408    pub async fn close(&self) -> Result<()> {
409        if let Some(ws) = &self.quotes_ws {
410            ws.close().await?;
411        }
412
413        let sessions = self.trade_sessions.read().await;
414        for (_key, trader) in sessions.iter() {
415            trader.close().await?;
416        }
417
418        Ok(())
419    }
420}