tws_rs/
client.rs

1use std::cell::OnceCell;
2use std::io::Write;
3use std::sync::atomic::{AtomicI32, Ordering};
4use std::time::Duration;
5
6use byteorder::{BigEndian, WriteBytesExt};
7use derivative::Derivative;
8use time::macros::format_description;
9use time::OffsetDateTime;
10use time_tz::{OffsetResult, PrimitiveDateTimeExt, timezones, Tz};
11use tokio::sync::mpsc::{channel, Receiver, Sender, unbounded_channel, UnboundedSender};
12use tracing::{error, info};
13
14use crate::{Error, server_versions};
15use crate::client::transport::{Item, MessageBus, TcpMessageBus};
16use crate::client::transport::message_bus::{ResponseStream, Signal};
17use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage};
18
19pub mod market_data;
20mod transport;
21
22const MIN_SERVER_VERSION: i32 = 100;
23const MAX_SERVER_VERSION: i32 = server_versions::HISTORICAL_SCHEDULE;
24
25#[derive(Derivative)]
26#[derivative(Debug)]
27pub struct Client {
28    /// IB server version
29    pub(crate) server_version: i32,
30    /// IB Server time
31    //    pub server_time: OffsetDateTime,
32    pub(crate) connection_time: Option<OffsetDateTime>,
33    #[derivative(Debug = "ignore")]
34    pub(crate) time_zone: Option<&'static Tz>,
35
36    address: String,
37    managed_accounts: String,
38    client_id: i32, // ID of client.
39    #[derivative(Debug = "ignore")]
40    pub(crate) message_bus: Option<TcpMessageBus>,
41    order_id: i32, // Next available order_id. Starts with value returned on connection.
42    #[derivative(Debug = "ignore")]
43    receiver: Option<Receiver<Item>>,
44}
45
46pub struct ClientRef {
47    sender: Sender<Item>,
48    signals_send: Option<UnboundedSender<Signal>>,
49    time_zone: Option<&'static Tz>,
50    pub(crate) server_version: i32,
51    next_request_id: AtomicI32, // Next available request_id.
52}
53
54impl ClientRef {
55    pub async fn send(
56        &self,
57        request_id: i32,
58        msg: RequestMessage,
59    ) -> Result<ResponseStream, Error> {
60        let (sender, receiver) = unbounded_channel();
61        self.sender
62            .send((sender, request_id, msg))
63            .await
64            .map_err(|e| Error::NotImplemented)?;
65        Ok(ResponseStream::new(
66            receiver,
67            self.signals_send.clone().unwrap(),
68            Some(request_id),
69            None,
70            Some(Duration::from_secs(10)),
71        ))
72    }
73
74    pub(crate) fn check_server_version(&self, version: i32, message: &str) -> Result<(), Error> {
75        if version <= self.server_version {
76            Ok(())
77        } else {
78            Err(Error::ServerVersion(
79                version,
80                self.server_version,
81                message.into(),
82            ))
83        }
84    }
85
86    pub(crate) fn server_version(&self) -> i32 {
87        self.server_version
88    }
89
90    pub(crate) fn next_request_id(&self) -> i32 {
91        self.next_request_id.fetch_add(1, Ordering::Relaxed)
92    }
93}
94
95impl Client {
96    pub fn new(address: &str, client_id: i32) -> Self {
97        Self {
98            server_version: 0,
99            connection_time: None,
100            time_zone: None,
101            address: String::from(address),
102            managed_accounts: "".to_string(),
103            client_id,
104            message_bus: None,
105            order_id: -1,
106            receiver: None,
107        }
108    }
109
110    pub async fn connect(&mut self) -> Result<ClientRef, Error> {
111        let bus = TcpMessageBus::connect(self.address.as_str()).await?;
112        let signals_send = bus.signals_send.clone();
113        self.message_bus = Some(bus);
114        self.handshake().await?;
115        self.start_api().await?;
116        self.receive_account_info().await?;
117
118        let (sender, receiver) = channel(2048);
119        self.receiver = Some(receiver);
120        info!("{:?}", self);
121        Ok(ClientRef {
122            sender,
123            signals_send: Some(signals_send),
124            time_zone: self.time_zone.clone(),
125            server_version: self.server_version,
126            next_request_id: AtomicI32::from(9000),
127        })
128    }
129
130    pub async fn blocking_process(&mut self) -> Result<(), Error> {
131        self.message_bus
132            .as_mut()
133            .unwrap()
134            .process_messages(self.receiver.take().unwrap(), self.server_version)
135            .await?;
136        Ok(())
137    }
138    async fn handshake(&mut self) -> Result<(), Error> {
139        let prefix = "API\0";
140        let version = format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}");
141
142        let packet = prefix.to_owned() + &encode_packet(&version);
143        self.message_bus.as_mut().unwrap().write(&packet).await?;
144
145        let ack = self.message_bus.as_mut().unwrap().read_message().await;
146
147        return match ack {
148            Ok(mut response_message) => {
149                self.server_version = response_message.next_int()?;
150
151                let time = response_message.next_string()?;
152                (self.connection_time, self.time_zone) = parse_connection_time(time.as_str());
153                Ok(())
154            }
155            Err(Error::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
156                Err(Error::Simple(format!(
157                    "The server may be rejecting connections from this host: {err}"
158                )))
159            }
160            Err(err) => Err(err),
161        };
162    }
163
164    // asks server to start processing messages
165    async fn start_api(&mut self) -> Result<(), Error> {
166        const VERSION: i32 = 2;
167
168        let prelude = &mut RequestMessage::default();
169
170        prelude.push_field(&OutgoingMessages::StartApi);
171        prelude.push_field(&VERSION);
172        prelude.push_field(&self.client_id);
173
174        if self.server_version > server_versions::OPTIONAL_CAPABILITIES {
175            prelude.push_field(&"");
176        }
177
178        self.message_bus
179            .as_mut()
180            .unwrap()
181            .write_message(prelude)
182            .await?;
183
184        Ok(())
185    }
186
187    // Fetches next order id and managed accounts.
188    async fn receive_account_info(&mut self) -> Result<(), Error> {
189        let mut saw_next_order_id: bool = false;
190        let mut saw_managed_accounts: bool = false;
191
192        let mut attempts = 0;
193        const MAX_ATTEMPTS: i32 = 100;
194        loop {
195            let mut message = self.message_bus.as_mut().unwrap().read_message().await?;
196
197            match message.message_type() {
198                IncomingMessages::NextValidId => {
199                    saw_next_order_id = true;
200
201                    message.skip(); // message type
202                    message.skip(); // message version
203
204                    self.order_id = message.next_int()?;
205                }
206                IncomingMessages::ManagedAccounts => {
207                    saw_managed_accounts = true;
208
209                    message.skip(); // message type
210                    message.skip(); // message version
211
212                    self.managed_accounts = message.next_string()?;
213                }
214                IncomingMessages::Error => {
215                    error!("message: {message:?}")
216                }
217                _ => info!("message: {message:?}"),
218            }
219
220            attempts += 1;
221            if (saw_next_order_id && saw_managed_accounts) || attempts > MAX_ATTEMPTS {
222                break;
223            }
224        }
225
226        Ok(())
227    }
228}
229
230// Parses following format: 20230405 22:20:39 PST
231fn parse_connection_time(connection_time: &str) -> (Option<OffsetDateTime>, Option<&'static Tz>) {
232    let parts: Vec<&str> = connection_time.split(' ').collect();
233
234    let mut zones = timezones::find_by_name(parts[2]);
235    if zones.is_empty() {
236        if parts[2] == "中国标准时间" {
237            zones = timezones::find_by_name("China Standard Time")
238        } else {
239            error!("time zone not found for {}", parts[2]);
240            return (None, None);
241        }
242    }
243
244    let timezone = zones[0];
245
246    let format = format_description!("[year][month][day] [hour]:[minute]:[second]");
247    let date_str = format!("{} {}", parts[0], parts[1]);
248    let date = time::PrimitiveDateTime::parse(date_str.as_str(), format);
249    match date {
250        Ok(connected_at) => match connected_at.assume_timezone(timezone) {
251            OffsetResult::Some(date) => (Some(date), Some(timezone)),
252            _ => {
253                error!("error setting timezone");
254                (None, Some(timezone))
255            }
256        },
257        Err(err) => {
258            error!("could not parse connection time from {date_str}: {err}");
259            return (None, Some(timezone));
260        }
261    }
262}
263
/// Frames `message` as a length-prefixed packet: the message's byte length as a 4-byte
/// big-endian integer, followed by the message bytes.
///
/// # Panics
///
/// Panics when the message is longer than `u32::MAX` bytes, and when the four length bytes
/// are not valid UTF-8 (for example lengths 128..=255). The latter is a consequence of
/// returning the packet as a `String`; the handshake message framed here is far shorter.
fn encode_packet(message: &str) -> String {
    let data = message.as_bytes();
    let length =
        u32::try_from(data.len()).expect("message does not fit a 4-byte length prefix");

    let mut packet: Vec<u8> = Vec::with_capacity(data.len() + 4);
    packet.extend_from_slice(&length.to_be_bytes());
    packet.extend_from_slice(data);

    String::from_utf8(packet)
        .expect("length prefix is not valid UTF-8; packet cannot be represented as a String")
}
274
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::task::LocalSet;
    use tokio::time::Instant;
    use tokio::time::sleep;
    use tracing::info;
    use tracing_test::traced_test;

    use crate::client::Client;
    use crate::client::market_data::historical::{
        BarSize, historical_data, TWSDuration, WhatToShow,
    };
    use crate::contracts::Contract;
    use crate::Error;

    /// Manual smoke test: connects to a server on 127.0.0.1:14001, requests historical bars
    /// from a background task and drives the message bus on the current task.
    ///
    /// Ignored by default because it needs a live server; run it explicitly with
    /// `cargo test -- --ignored`.
    #[traced_test]
    #[tokio::test]
    #[ignore = "requires a running IB server listening on 127.0.0.1:14001"]
    async fn it_works() -> Result<(), Error> {
        let mut client = Client::new("127.0.0.1:14001", 4322);
        let client_ref = client.connect().await?;

        // The request is issued from a separate task after a delay, while message
        // processing runs below.
        tokio::spawn(async move {
            sleep(Duration::from_secs(5)).await;
            let now = Instant::now();
            let bars = historical_data(
                &client_ref,
                &Contract::stock("TSLA"),
                None,
                TWSDuration::days(3),
                BarSize::Min3,
                Some(WhatToShow::Trades),
                true,
                true,
            )
            .await
            .unwrap();
            info!("cost {:?}, bars: {:?}", now.elapsed(), bars);
        });
        let local = LocalSet::new();
        let res = local
            .run_until(async move {
                client.blocking_process().await?;
                sleep(Duration::from_secs(5)).await;
                Result::<(), Error>::Ok(())
            })
            .await;
        info!("{:?}", res);
        Ok(())
    }
}