1use std::cell::OnceCell;
2use std::io::Write;
3use std::sync::atomic::{AtomicI32, Ordering};
4use std::time::Duration;
5
6use byteorder::{BigEndian, WriteBytesExt};
7use derivative::Derivative;
8use time::macros::format_description;
9use time::OffsetDateTime;
10use time_tz::{OffsetResult, PrimitiveDateTimeExt, timezones, Tz};
11use tokio::sync::mpsc::{channel, Receiver, Sender, unbounded_channel, UnboundedSender};
12use tracing::{error, info};
13
14use crate::{Error, server_versions};
15use crate::client::transport::{Item, MessageBus, TcpMessageBus};
16use crate::client::transport::message_bus::{ResponseStream, Signal};
17use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage};
18
19pub mod market_data;
20mod transport;
21
22const MIN_SERVER_VERSION: i32 = 100;
23const MAX_SERVER_VERSION: i32 = server_versions::HISTORICAL_SCHEDULE;
24
25#[derive(Derivative)]
26#[derivative(Debug)]
27pub struct Client {
28 pub(crate) server_version: i32,
30 pub(crate) connection_time: Option<OffsetDateTime>,
33 #[derivative(Debug = "ignore")]
34 pub(crate) time_zone: Option<&'static Tz>,
35
36 address: String,
37 managed_accounts: String,
38 client_id: i32, #[derivative(Debug = "ignore")]
40 pub(crate) message_bus: Option<TcpMessageBus>,
41 order_id: i32, #[derivative(Debug = "ignore")]
43 receiver: Option<Receiver<Item>>,
44}
45
46pub struct ClientRef {
47 sender: Sender<Item>,
48 signals_send: Option<UnboundedSender<Signal>>,
49 time_zone: Option<&'static Tz>,
50 pub(crate) server_version: i32,
51 next_request_id: AtomicI32, }
53
54impl ClientRef {
55 pub async fn send(
56 &self,
57 request_id: i32,
58 msg: RequestMessage,
59 ) -> Result<ResponseStream, Error> {
60 let (sender, receiver) = unbounded_channel();
61 self.sender
62 .send((sender, request_id, msg))
63 .await
64 .map_err(|e| Error::NotImplemented)?;
65 Ok(ResponseStream::new(
66 receiver,
67 self.signals_send.clone().unwrap(),
68 Some(request_id),
69 None,
70 Some(Duration::from_secs(10)),
71 ))
72 }
73
74 pub(crate) fn check_server_version(&self, version: i32, message: &str) -> Result<(), Error> {
75 if version <= self.server_version {
76 Ok(())
77 } else {
78 Err(Error::ServerVersion(
79 version,
80 self.server_version,
81 message.into(),
82 ))
83 }
84 }
85
86 pub(crate) fn server_version(&self) -> i32 {
87 self.server_version
88 }
89
90 pub(crate) fn next_request_id(&self) -> i32 {
91 self.next_request_id.fetch_add(1, Ordering::Relaxed)
92 }
93}
94
95impl Client {
96 pub fn new(address: &str, client_id: i32) -> Self {
97 Self {
98 server_version: 0,
99 connection_time: None,
100 time_zone: None,
101 address: String::from(address),
102 managed_accounts: "".to_string(),
103 client_id,
104 message_bus: None,
105 order_id: -1,
106 receiver: None,
107 }
108 }
109
110 pub async fn connect(&mut self) -> Result<ClientRef, Error> {
111 let bus = TcpMessageBus::connect(self.address.as_str()).await?;
112 let signals_send = bus.signals_send.clone();
113 self.message_bus = Some(bus);
114 self.handshake().await?;
115 self.start_api().await?;
116 self.receive_account_info().await?;
117
118 let (sender, receiver) = channel(2048);
119 self.receiver = Some(receiver);
120 info!("{:?}", self);
121 Ok(ClientRef {
122 sender,
123 signals_send: Some(signals_send),
124 time_zone: self.time_zone.clone(),
125 server_version: self.server_version,
126 next_request_id: AtomicI32::from(9000),
127 })
128 }
129
130 pub async fn blocking_process(&mut self) -> Result<(), Error> {
131 self.message_bus
132 .as_mut()
133 .unwrap()
134 .process_messages(self.receiver.take().unwrap(), self.server_version)
135 .await?;
136 Ok(())
137 }
138 async fn handshake(&mut self) -> Result<(), Error> {
139 let prefix = "API\0";
140 let version = format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}");
141
142 let packet = prefix.to_owned() + &encode_packet(&version);
143 self.message_bus.as_mut().unwrap().write(&packet).await?;
144
145 let ack = self.message_bus.as_mut().unwrap().read_message().await;
146
147 return match ack {
148 Ok(mut response_message) => {
149 self.server_version = response_message.next_int()?;
150
151 let time = response_message.next_string()?;
152 (self.connection_time, self.time_zone) = parse_connection_time(time.as_str());
153 Ok(())
154 }
155 Err(Error::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
156 Err(Error::Simple(format!(
157 "The server may be rejecting connections from this host: {err}"
158 )))
159 }
160 Err(err) => Err(err),
161 };
162 }
163
164 async fn start_api(&mut self) -> Result<(), Error> {
166 const VERSION: i32 = 2;
167
168 let prelude = &mut RequestMessage::default();
169
170 prelude.push_field(&OutgoingMessages::StartApi);
171 prelude.push_field(&VERSION);
172 prelude.push_field(&self.client_id);
173
174 if self.server_version > server_versions::OPTIONAL_CAPABILITIES {
175 prelude.push_field(&"");
176 }
177
178 self.message_bus
179 .as_mut()
180 .unwrap()
181 .write_message(prelude)
182 .await?;
183
184 Ok(())
185 }
186
187 async fn receive_account_info(&mut self) -> Result<(), Error> {
189 let mut saw_next_order_id: bool = false;
190 let mut saw_managed_accounts: bool = false;
191
192 let mut attempts = 0;
193 const MAX_ATTEMPTS: i32 = 100;
194 loop {
195 let mut message = self.message_bus.as_mut().unwrap().read_message().await?;
196
197 match message.message_type() {
198 IncomingMessages::NextValidId => {
199 saw_next_order_id = true;
200
201 message.skip(); message.skip(); self.order_id = message.next_int()?;
205 }
206 IncomingMessages::ManagedAccounts => {
207 saw_managed_accounts = true;
208
209 message.skip(); message.skip(); self.managed_accounts = message.next_string()?;
213 }
214 IncomingMessages::Error => {
215 error!("message: {message:?}")
216 }
217 _ => info!("message: {message:?}"),
218 }
219
220 attempts += 1;
221 if (saw_next_order_id && saw_managed_accounts) || attempts > MAX_ATTEMPTS {
222 break;
223 }
224 }
225
226 Ok(())
227 }
228}
229
230fn parse_connection_time(connection_time: &str) -> (Option<OffsetDateTime>, Option<&'static Tz>) {
232 let parts: Vec<&str> = connection_time.split(' ').collect();
233
234 let mut zones = timezones::find_by_name(parts[2]);
235 if zones.is_empty() {
236 if parts[2] == "中国标准时间" {
237 zones = timezones::find_by_name("China Standard Time")
238 } else {
239 error!("time zone not found for {}", parts[2]);
240 return (None, None);
241 }
242 }
243
244 let timezone = zones[0];
245
246 let format = format_description!("[year][month][day] [hour]:[minute]:[second]");
247 let date_str = format!("{} {}", parts[0], parts[1]);
248 let date = time::PrimitiveDateTime::parse(date_str.as_str(), format);
249 match date {
250 Ok(connected_at) => match connected_at.assume_timezone(timezone) {
251 OffsetResult::Some(date) => (Some(date), Some(timezone)),
252 _ => {
253 error!("error setting timezone");
254 (None, Some(timezone))
255 }
256 },
257 Err(err) => {
258 error!("could not parse connection time from {date_str}: {err}");
259 return (None, Some(timezone));
260 }
261 }
262}
263
/// Frames `message` as a length-prefixed packet: a 4-byte big-endian byte length
/// followed by the UTF-8 bytes of `message`.
///
/// The frame is returned as a `String`, so the length prefix must itself be
/// valid UTF-8. That always holds for messages shorter than 128 bytes, such as
/// the handshake version string, but not for every longer length.
///
/// # Panics
///
/// Panics if `message` is longer than `u32::MAX` bytes, or if the big-endian
/// length prefix is not valid UTF-8.
fn encode_packet(message: &str) -> String {
    let data = message.as_bytes();
    let len = u32::try_from(data.len()).expect("packet length must fit in a u32 prefix");

    let mut packet: Vec<u8> = Vec::with_capacity(data.len() + 4);
    packet.extend_from_slice(&len.to_be_bytes());
    packet.extend_from_slice(data);

    // `String::from_utf8` reuses the buffer instead of copying it again.
    String::from_utf8(packet).expect("length prefix of the packet is not valid UTF-8")
}
274
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::task::LocalSet;
    use tokio::time::Instant;
    use tokio::time::sleep;
    use tracing::info;
    use tracing_test::traced_test;

    use crate::client::Client;
    use crate::client::market_data::historical::{
        BarSize, historical_data, TWSDuration, WhatToShow,
    };
    use crate::contracts::Contract;
    use crate::Error;

    /// End-to-end smoke test against a live server: connects, then requests three
    /// days of 3-minute TSLA trade bars while the message bus is being driven.
    ///
    /// Ignored by default because it needs a server on `127.0.0.1:14001`; run it
    /// with `cargo test -- --ignored`.
    #[traced_test]
    #[tokio::test]
    #[ignore = "requires a running TWS / gateway on 127.0.0.1:14001"]
    async fn it_works() -> Result<(), Error> {
        let mut client = Client::new("127.0.0.1:14001", 4322);
        let client_ref = client.connect().await?;

        // Keep the handle: a panic in this task must fail the test, not vanish.
        let request = tokio::spawn(async move {
            sleep(Duration::from_secs(5)).await;
            let now = Instant::now();
            let bars = historical_data(
                &client_ref,
                &Contract::stock("TSLA"),
                None,
                TWSDuration::days(3),
                BarSize::Min3,
                Some(WhatToShow::Trades),
                true,
                true,
            )
            .await
            .unwrap();
            info!("cost {:?}, bars: {:?}", now.elapsed(), bars);
        });

        let local = LocalSet::new();
        let res = local
            .run_until(async move {
                // Finish as soon as either the bus stops or the request completes.
                tokio::select! {
                    processed = client.blocking_process() => processed,
                    fetched = request => fetched.map_err(|err| {
                        Error::Simple(format!("historical data request failed: {err}"))
                    }),
                }
            })
            .await;
        info!("{:?}", res);
        res
    }
}