Skip to main content

ctrader_rs/
client.rs

//! cTrader Open API client.
//!
//!   - [`Client::start`] connects, authenticates the *application*, and spawns
//!     the keepalive heartbeat as a Tokio task.
//!   - [`Client::command`] is the generic helper that sends a request and
//!     awaits the matched response.
//!   - Unsolicited events (spot prices, execution events, …) are delivered via
//!     the optional `event_handler` closure passed at construction time
//!     (see [`Client::start_with_handler`]).
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
16use prost::Message;
17use tokio::sync::{mpsc, oneshot, Mutex};
18use tokio::time;
19use uuid::Uuid;
20
21use crate::client_helper::{dispatch_loop, keepalive};
22use crate::error::Error;
23use crate::payload;
24use crate::proto::common::{ProtoMessage, ProtoOaErrorRes, ProtoOaVersionReq, ProtoOaVersionRes};
25use crate::transport::Transport;
26
27/// Request registry
28///
29///
30///
31///
32///
33pub type Registry = Arc<Mutex<HashMap<String, oneshot::Sender<ProtoMessage>>>>;
34
/// Configuration for the cTrader client.
///
/// Create one with [`Config::new`], then refine it with the builder-style
/// [`Config::live`] and [`Config::deadline`] methods.
///
/// `Debug` is implemented by hand so that `client_secret` is never written to
/// logs. It is rendered as `"<redacted>"`.
#[derive(Clone)]
pub struct Config {
    /// `clientId` from openapi.ctrader.com → your app → Credentials.
    pub client_id: String,
    /// `clientSecret` from the same location.
    pub client_secret: String,
    /// Use live servers (`live.ctraderapi.com`) when `true`, demo otherwise.
    pub live: bool,
    /// Per-request deadline. Defaults to 5 seconds.
    pub deadline: Duration,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the application secret: Debug output commonly ends up
        // in logs and error reports.
        f.debug_struct("Config")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("live", &self.live)
            .field("deadline", &self.deadline)
            .finish()
    }
}

impl Config {
    /// Create a configuration for the demo servers with a 5-second deadline.
    ///
    /// * `client_id` – the application's `clientId`.
    /// * `client_secret` – the application's `clientSecret`.
    #[must_use]
    pub fn new(client_id: impl Into<String>, client_secret: impl Into<String>) -> Self {
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            live: false,
            deadline: Duration::from_secs(5),
        }
    }

    /// Switch this configuration to the live servers and return it.
    ///
    /// This consumes and returns `self`, so the result must be used. A bare
    /// `cfg.live();` has no effect on anything still in scope.
    #[must_use]
    pub fn live(mut self) -> Self {
        self.live = true;
        self
    }

    /// Replace the per-request deadline with `d` and return the configuration.
    #[must_use]
    pub fn deadline(mut self, d: Duration) -> Self {
        self.deadline = d;
        self
    }
}
94
95/// The main client.  Construct with [`Client::start`].
96///
97///
98///
99///
100///
101///
102///
103pub struct Client {
104    pub transport: Arc<Transport>,
105    pub registry: Registry,
106    pub config: Config,
107    /// Called for every unsolicited event (spot, execution, …).
108    pub event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>>,
109}
110
111impl Client {
112    /// Build the host string based on live/demo setting.
113    ///
114    ///
115    ///
116    ///
117    ///
118    ///
119    fn host(live: bool) -> &'static str {
120        if live {
121            "live.ctraderapi.com"
122        } else {
123            "demo.ctraderapi.com"
124        }
125    }
126
127    /// Connect, authenticate the application, and start the keepalive task.
128    ///
129    ///
130    ///
131    ///
132    ///
133    ///
134    ///
135    pub async fn start(config: Config) -> Result<Self, Error> {
136        Self::start_with_handler(config, None::<fn(ProtoMessage)>).await
137    }
138
139    /// Same as [`start`] but with an event handler for unsolicited messages.
140    ///
141    ///
142    ///
143    ///
144    ///
145    ///
146    ///
147    ///
148    ///
149    ///
150    ///
151    pub async fn start_with_handler(
152        config: Config,
153        handler: Option<impl Fn(ProtoMessage) + Send + Sync + 'static>,
154    ) -> Result<Self, Error> {
155        let host = Self::host(config.live);
156        let registry: Registry = Arc::new(Mutex::new(HashMap::new()));
157
158        let (frame_tx, frame_rx) = mpsc::unbounded_channel::<Vec<u8>>();
159        let transport = Arc::new(Transport::connect(host, 5035, frame_tx).await?);
160
161        let event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>> =
162            handler.map(|h| Arc::new(h) as _);
163
164        // Spawn message dispatcher
165        {
166            let registry = registry.clone();
167            let event_handler = event_handler.clone();
168            tokio::spawn(dispatch_loop(frame_rx, registry, event_handler));
169        }
170
171        let client = Self {
172            transport,
173            registry,
174            config: config.clone(),
175            event_handler,
176        };
177
178        // Authenticate the application (blocking — must succeed before use)
179        client.application_auth().await?;
180
181        // Heartbeat keepalive every 10 s (matches Go SDK)
182        {
183            let transport = client.transport.clone();
184            tokio::spawn(async move {
185                keepalive(transport).await;
186            });
187        }
188
189        Ok(client)
190    }
191
192    // ── Generic request/response helper ──────────────────────────────────────
193
194    /// Encode `req`, send it, await the response envelope whose `payloadType`
195    /// matches `res_type`, and decode it as `R`.
196    ///
197    ///
198    ///
199    ///
200    ///
201    ///
202    pub async fn command<Q, R>(&self, req_type: u32, req: Q, res_type: u32) -> Result<R, Error>
203    where
204        Q: Message,
205        R: Message + Default,
206    {
207        let id = Uuid::new_v4().to_string();
208
209        // Encode inner message
210        let mut payload_bytes = Vec::new();
211        req.encode(&mut payload_bytes)?;
212
213        // Wrap in ProtoMessage envelope
214        let envelope = ProtoMessage {
215            payload_type: req_type,
216            payload: Some(payload_bytes),
217            client_msg_id: Some(id.clone()),
218        };
219        let mut frame = Vec::new();
220        envelope.encode(&mut frame)?;
221
222        // Register callback channel before sending (avoid race)
223        let (tx, rx) = oneshot::channel::<ProtoMessage>();
224        {
225            let mut reg = self.registry.lock().await;
226            reg.insert(id.clone(), tx);
227        }
228
229        self.transport.send(&frame).await?;
230
231        // Await response with deadline
232        let response_envelope = time::timeout(self.config.deadline, rx)
233            .await
234            .map_err(|_| Error::Timeout)?
235            .map_err(|_| Error::Disconnected)?;
236
237        // Cleanup registry
238        {
239            let mut reg = self.registry.lock().await;
240            reg.remove(&id);
241        }
242
243        let pt = response_envelope.payload_type;
244
245        // Check for OA-level tracing::error response
246        if pt == payload::OA_ERROR_RES || pt == payload::ERROR_RES {
247            let err =
248                ProtoOaErrorRes::decode(response_envelope.payload.as_deref().unwrap_or_default())?;
249            return Err(Error::Api {
250                error_code: err.error_code,
251                description: err.description.clone().unwrap_or_default(),
252            });
253        }
254
255        if pt != res_type {
256            return Err(Error::UnexpectedPayload(pt));
257        }
258
259        Ok(R::decode(
260            response_envelope.payload.as_deref().unwrap_or_default(),
261        )?)
262    }
263
264    /// Get the API version from the server.
265    ///
266    ///
267    ///
268    ///
269    ///
270    ///
271    ///
272    ///
273    ///
274    ///
275    ///
276    ///
277    pub async fn version(&self) -> Result<ProtoOaVersionRes, Error> {
278        let req = ProtoOaVersionReq {
279            payload_type: Some(payload::OA_VERSION_REQ as i32),
280        };
281        self.command(payload::OA_VERSION_REQ, req, payload::OA_VERSION_RES)
282            .await
283    }
284}
285
#[cfg(test)]
mod tests {
    //! Synchronous unit tests for the parts of the client that need no
    //! network or async runtime.

    use super::{Client, Config};
    use std::time::Duration;

    #[test]
    fn config_defaults_to_demo_with_five_second_deadline() {
        let cfg = Config::new("id", "secret");
        assert_eq!(cfg.client_id, "id");
        assert_eq!(cfg.client_secret, "secret");
        assert!(!cfg.live);
        assert_eq!(cfg.deadline, Duration::from_secs(5));
    }

    #[test]
    fn config_builders_override_defaults() {
        let cfg = Config::new("id", "secret")
            .live()
            .deadline(Duration::from_millis(250));
        assert!(cfg.live);
        assert_eq!(cfg.deadline, Duration::from_millis(250));
    }

    #[test]
    fn host_matches_environment() {
        assert_eq!(Client::host(true), "live.ctraderapi.com");
        assert_eq!(Client::host(false), "demo.ctraderapi.com");
    }
}