ctrader-rs 0.1.2

Rust SDK for the cTrader Open API
Documentation
//! cTrader Open API client.
//!
//! - [`Client::start`] connects, authenticates the *application*, and spawns
//!   the keepalive heartbeat as a Tokio task.
//! - [`Client::command`] is the generic helper that sends a request and awaits
//!   the matching response.
//! - Unsolicited events (spot prices, execution events, …) are delivered via
//!   the optional event handler passed to [`Client::start_with_handler`].
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use prost::Message;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time;
use uuid::Uuid;

use crate::client_helper::{dispatch_loop, keepalive};
use crate::error::Error;
use crate::payload;
use crate::proto::common::{ProtoMessage, ProtoOaErrorRes, ProtoOaVersionReq, ProtoOaVersionRes};
use crate::transport::Transport;

/// Request registry: in-flight requests keyed by the `client_msg_id` sent with them.
///
/// [`Client::command`] generates a fresh UUID for each request, stores a
/// `oneshot::Sender` under it *before* the frame is written, and then awaits the
/// paired receiver. The map is wrapped in `Arc<Mutex<…>>` because a clone of it is
/// handed to the dispatch task spawned in [`Client::start_with_handler`].
/// NOTE(review): presumably `dispatch_loop` looks up responses by `client_msg_id`
/// and completes the matching sender — confirm in `client_helper`.
pub type Registry = Arc<Mutex<HashMap<String, oneshot::Sender<ProtoMessage>>>>;

/// Configuration for the cTrader client.
///
/// Create one with [`Config::new`] (demo environment, 5-second deadline) and
/// refine it with the consuming builder methods [`Config::live`] and
/// [`Config::deadline`]:
///
/// ```ignore
/// let config = Config::new("my-client-id", "my-client-secret")
///     .live()
///     .deadline(Duration::from_secs(10));
/// ```
///
/// The `Debug` output redacts `client_secret` so the credential does not end up
/// in logs.
#[derive(Clone)]
pub struct Config {
    /// `clientId` from openapi.ctrader.com → your app → Credentials.
    pub client_id: String,
    /// `clientSecret` from the same location.
    pub client_secret: String,
    /// Use live servers (`live.ctraderapi.com`) when `true`, demo otherwise.
    pub live: bool,
    /// Per-request deadline. Defaults to 5 seconds.
    pub deadline: Duration,
}

impl std::fmt::Debug for Config {
    /// Formats every field except `client_secret`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("live", &self.live)
            .field("deadline", &self.deadline)
            .finish()
    }
}

impl Config {
    /// Create a configuration for the **demo** environment with a 5-second
    /// per-request deadline.
    ///
    /// # Parameters
    /// - `client_id`: the application's `clientId`.
    /// - `client_secret`: the application's `clientSecret`.
    #[must_use]
    pub fn new(client_id: impl Into<String>, client_secret: impl Into<String>) -> Self {
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            live: false,
            deadline: Duration::from_secs(5),
        }
    }

    /// Target the live servers (`live.ctraderapi.com`) instead of demo.
    ///
    /// Consumes and returns `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub fn live(mut self) -> Self {
        self.live = true;
        self
    }

    /// Set the per-request deadline applied while awaiting each response.
    ///
    /// Consumes and returns `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub fn deadline(mut self, d: Duration) -> Self {
        self.deadline = d;
        self
    }
}

/// The main client.  Construct with [`Client::start`] (or
/// [`Client::start_with_handler`] to receive unsolicited messages).
///
/// Holds the shared transport, the registry of in-flight requests, the
/// configuration the client was started with, and the optional event handler.
/// Requests are issued through [`Client::command`] or typed wrappers such as
/// [`Client::version`].
pub struct Client {
    /// Connection to the cTrader endpoint; a clone is held by the keepalive task.
    pub transport: Arc<Transport>,
    /// In-flight requests keyed by `client_msg_id`; a clone is held by the dispatch task.
    pub registry: Registry,
    /// Configuration the client was started with (host selection, per-request deadline).
    pub config: Config,
    /// Called for every unsolicited event (spot, execution, …).
    pub event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>>,
}

impl Client {
    /// Build the host string based on live/demo setting.
    ///
    ///
    ///
    ///
    ///
    ///
    fn host(live: bool) -> &'static str {
        if live {
            "live.ctraderapi.com"
        } else {
            "demo.ctraderapi.com"
        }
    }

    /// Connect, authenticate the application, and start the keepalive task.
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    pub async fn start(config: Config) -> Result<Self, Error> {
        Self::start_with_handler(config, None::<fn(ProtoMessage)>).await
    }

    /// Same as [`start`] but with an event handler for unsolicited messages.
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    pub async fn start_with_handler(
        config: Config,
        handler: Option<impl Fn(ProtoMessage) + Send + Sync + 'static>,
    ) -> Result<Self, Error> {
        let host = Self::host(config.live);
        let registry: Registry = Arc::new(Mutex::new(HashMap::new()));

        let (frame_tx, frame_rx) = mpsc::unbounded_channel::<Vec<u8>>();
        let transport = Arc::new(Transport::connect(host, 5035, frame_tx).await?);

        let event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>> =
            handler.map(|h| Arc::new(h) as _);

        // Spawn message dispatcher
        {
            let registry = registry.clone();
            let event_handler = event_handler.clone();
            tokio::spawn(dispatch_loop(frame_rx, registry, event_handler));
        }

        let client = Self {
            transport,
            registry,
            config: config.clone(),
            event_handler,
        };

        // Authenticate the application (blocking — must succeed before use)
        client.application_auth().await?;

        // Heartbeat keepalive every 10 s (matches Go SDK)
        {
            let transport = client.transport.clone();
            tokio::spawn(async move {
                keepalive(transport).await;
            });
        }

        Ok(client)
    }

    // ── Generic request/response helper ──────────────────────────────────────

    /// Encode `req`, send it, await the response envelope whose `payloadType`
    /// matches `res_type`, and decode it as `R`.
    ///
    ///
    ///
    ///
    ///
    ///
    pub async fn command<Q, R>(&self, req_type: u32, req: Q, res_type: u32) -> Result<R, Error>
    where
        Q: Message,
        R: Message + Default,
    {
        let id = Uuid::new_v4().to_string();

        // Encode inner message
        let mut payload_bytes = Vec::new();
        req.encode(&mut payload_bytes)?;

        // Wrap in ProtoMessage envelope
        let envelope = ProtoMessage {
            payload_type: req_type,
            payload: Some(payload_bytes),
            client_msg_id: Some(id.clone()),
        };
        let mut frame = Vec::new();
        envelope.encode(&mut frame)?;

        // Register callback channel before sending (avoid race)
        let (tx, rx) = oneshot::channel::<ProtoMessage>();
        {
            let mut reg = self.registry.lock().await;
            reg.insert(id.clone(), tx);
        }

        self.transport.send(&frame).await?;

        // Await response with deadline
        let response_envelope = time::timeout(self.config.deadline, rx)
            .await
            .map_err(|_| Error::Timeout)?
            .map_err(|_| Error::Disconnected)?;

        // Cleanup registry
        {
            let mut reg = self.registry.lock().await;
            reg.remove(&id);
        }

        let pt = response_envelope.payload_type;

        // Check for OA-level tracing::error response
        if pt == payload::OA_ERROR_RES || pt == payload::ERROR_RES {
            let err =
                ProtoOaErrorRes::decode(response_envelope.payload.as_deref().unwrap_or_default())?;
            return Err(Error::Api {
                error_code: err.error_code,
                description: err.description.clone().unwrap_or_default(),
            });
        }

        if pt != res_type {
            return Err(Error::UnexpectedPayload(pt));
        }

        Ok(R::decode(
            response_envelope.payload.as_deref().unwrap_or_default(),
        )?)
    }

    /// Get the API version from the server.
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    ///
    pub async fn version(&self) -> Result<ProtoOaVersionRes, Error> {
        let req = ProtoOaVersionReq {
            payload_type: Some(payload::OA_VERSION_REQ as i32),
        };
        self.command(payload::OA_VERSION_REQ, req, payload::OA_VERSION_RES)
            .await
    }
}

#[cfg(test)]
mod tests {
    //! Offline unit tests: they exercise configuration and host selection only,
    //! so no network connection or async runtime is required.
    use super::*;

    #[test]
    fn config_defaults_to_demo_with_five_second_deadline() {
        let config = Config::new("id", "secret");
        assert_eq!(config.client_id, "id");
        assert_eq!(config.client_secret, "secret");
        assert!(!config.live);
        assert_eq!(config.deadline, Duration::from_secs(5));
    }

    #[test]
    fn config_builders_override_defaults() {
        let config = Config::new("id", "secret")
            .live()
            .deadline(Duration::from_millis(250));
        assert!(config.live);
        assert_eq!(config.deadline, Duration::from_millis(250));
    }

    #[test]
    fn host_follows_live_flag() {
        assert_eq!(Client::host(false), "demo.ctraderapi.com");
        assert_eq!(Client::host(true), "live.ctraderapi.com");
    }
}