1use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16use prost::Message;
17use tokio::sync::{mpsc, oneshot, Mutex};
18use tokio::time;
19use uuid::Uuid;
20
21use crate::client_helper::{dispatch_loop, keepalive};
22use crate::error::Error;
23use crate::payload;
24use crate::proto::common::{ProtoMessage, ProtoOaErrorRes, ProtoOaVersionReq, ProtoOaVersionRes};
25use crate::transport::Transport;
26
27pub type Registry = Arc<Mutex<HashMap<String, oneshot::Sender<ProtoMessage>>>>;
34
/// Connection settings for a [`Client`].
///
/// Create one with [`Config::new`], then refine it with the chained builder
/// methods [`Config::live`] and [`Config::deadline`]. Each builder method
/// consumes `self` and returns the updated value, so its result must be used.
#[derive(Debug, Clone)]
pub struct Config {
    /// Application client id, as supplied to [`Config::new`].
    pub client_id: String,
    /// Application client secret, as supplied to [`Config::new`].
    pub client_secret: String,
    /// `true` selects the live endpoint; `false` (the default) selects the demo endpoint.
    pub live: bool,
    /// Maximum time a single request waits for its response before it fails
    /// with a timeout error.
    pub deadline: Duration,
}

impl Config {
    /// Creates a config for the demo environment with a 5-second request deadline.
    #[must_use]
    pub fn new(client_id: impl Into<String>, client_secret: impl Into<String>) -> Self {
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            live: false,
            deadline: Duration::from_secs(5),
        }
    }

    /// Switches this config to the live environment.
    // `self` is taken by value: calling this without using the result is a no-op bug.
    #[must_use = "`Config::live` returns the updated config; bind or chain the result"]
    pub fn live(mut self) -> Self {
        self.live = true;
        self
    }

    /// Replaces the per-request response deadline with `d`.
    #[must_use = "`Config::deadline` returns the updated config; bind or chain the result"]
    pub fn deadline(mut self, d: Duration) -> Self {
        self.deadline = d;
        self
    }
}
94
95pub struct Client {
104 pub transport: Arc<Transport>,
105 pub registry: Registry,
106 pub config: Config,
107 pub event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>>,
109}
110
111impl Client {
112 fn host(live: bool) -> &'static str {
120 if live {
121 "live.ctraderapi.com"
122 } else {
123 "demo.ctraderapi.com"
124 }
125 }
126
127 pub async fn start(config: Config) -> Result<Self, Error> {
136 Self::start_with_handler(config, None::<fn(ProtoMessage)>).await
137 }
138
139 pub async fn start_with_handler(
152 config: Config,
153 handler: Option<impl Fn(ProtoMessage) + Send + Sync + 'static>,
154 ) -> Result<Self, Error> {
155 let host = Self::host(config.live);
156 let registry: Registry = Arc::new(Mutex::new(HashMap::new()));
157
158 let (frame_tx, frame_rx) = mpsc::unbounded_channel::<Vec<u8>>();
159 let transport = Arc::new(Transport::connect(host, 5035, frame_tx).await?);
160
161 let event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>> =
162 handler.map(|h| Arc::new(h) as _);
163
164 {
166 let registry = registry.clone();
167 let event_handler = event_handler.clone();
168 tokio::spawn(dispatch_loop(frame_rx, registry, event_handler));
169 }
170
171 let client = Self {
172 transport,
173 registry,
174 config: config.clone(),
175 event_handler,
176 };
177
178 client.application_auth().await?;
180
181 {
183 let transport = client.transport.clone();
184 tokio::spawn(async move {
185 keepalive(transport).await;
186 });
187 }
188
189 Ok(client)
190 }
191
192 pub async fn command<Q, R>(&self, req_type: u32, req: Q, res_type: u32) -> Result<R, Error>
203 where
204 Q: Message,
205 R: Message + Default,
206 {
207 let id = Uuid::new_v4().to_string();
208
209 let mut payload_bytes = Vec::new();
211 req.encode(&mut payload_bytes)?;
212
213 let envelope = ProtoMessage {
215 payload_type: req_type,
216 payload: Some(payload_bytes),
217 client_msg_id: Some(id.clone()),
218 };
219 let mut frame = Vec::new();
220 envelope.encode(&mut frame)?;
221
222 let (tx, rx) = oneshot::channel::<ProtoMessage>();
224 {
225 let mut reg = self.registry.lock().await;
226 reg.insert(id.clone(), tx);
227 }
228
229 self.transport.send(&frame).await?;
230
231 let response_envelope = time::timeout(self.config.deadline, rx)
233 .await
234 .map_err(|_| Error::Timeout)?
235 .map_err(|_| Error::Disconnected)?;
236
237 {
239 let mut reg = self.registry.lock().await;
240 reg.remove(&id);
241 }
242
243 let pt = response_envelope.payload_type;
244
245 if pt == payload::OA_ERROR_RES || pt == payload::ERROR_RES {
247 let err =
248 ProtoOaErrorRes::decode(response_envelope.payload.as_deref().unwrap_or_default())?;
249 return Err(Error::Api {
250 error_code: err.error_code,
251 description: err.description.clone().unwrap_or_default(),
252 });
253 }
254
255 if pt != res_type {
256 return Err(Error::UnexpectedPayload(pt));
257 }
258
259 Ok(R::decode(
260 response_envelope.payload.as_deref().unwrap_or_default(),
261 )?)
262 }
263
264 pub async fn version(&self) -> Result<ProtoOaVersionRes, Error> {
278 let req = ProtoOaVersionReq {
279 payload_type: Some(payload::OA_VERSION_REQ as i32),
280 };
281 self.command(payload::OA_VERSION_REQ, req, payload::OA_VERSION_RES)
282 .await
283 }
284}
285
#[cfg(test)]
mod tests {
    //! Synchronous unit tests for the pure parts of the client.
    //!
    //! Endpoint selection and config building need no async runtime or
    //! network access, so they are tested with plain `#[test]`.

    use super::*;
    use std::time::Duration;

    #[test]
    fn host_follows_live_flag() {
        assert_eq!(Client::host(false), "demo.ctraderapi.com");
        assert_eq!(Client::host(true), "live.ctraderapi.com");
    }

    #[test]
    fn config_defaults_to_demo_with_five_second_deadline() {
        let cfg = Config::new("id", "secret");
        assert_eq!(cfg.client_id, "id");
        assert_eq!(cfg.client_secret, "secret");
        assert!(!cfg.live);
        assert_eq!(cfg.deadline, Duration::from_secs(5));
    }

    #[test]
    fn config_builder_methods_chain() {
        let cfg = Config::new("id", "secret")
            .live()
            .deadline(Duration::from_millis(250));
        assert!(cfg.live);
        assert_eq!(cfg.deadline, Duration::from_millis(250));
        assert_eq!(Client::host(cfg.live), "live.ctraderapi.com");
    }
}