use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use prost::Message;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time;
use uuid::Uuid;
use crate::client_helper::{dispatch_loop, keepalive};
use crate::error::Error;
use crate::payload;
use crate::proto::common::{ProtoMessage, ProtoOaErrorRes, ProtoOaVersionReq, ProtoOaVersionRes};
use crate::transport::Transport;
/// Shared, async-locked table of requests that are still waiting for a reply.
///
/// Keys are the `client_msg_id` strings that `Client::command` generates per
/// request; each value is the one-shot sender whose receiving half
/// `Client::command` awaits. The map is also handed to `dispatch_loop`
/// (defined outside this file), which presumably completes the senders when a
/// matching response arrives — confirm against `client_helper`.
pub type Registry = Arc<Mutex<HashMap<String, oneshot::Sender<ProtoMessage>>>>;
/// Connection settings used when starting a `Client`.
///
/// `Config::new` fills in the defaults (demo environment, 5-second deadline);
/// the `live` and `deadline` builder methods override them.
// NOTE(review): the derived `Debug` output includes `client_secret` verbatim —
// confirm this value is never written to logs via `{:?}`.
#[derive(Debug, Clone)]
pub struct Config {
/// Application client id. Its consumer is outside this view (presumably the
/// application-auth request — TODO confirm).
pub client_id: String,
/// Application client secret. Its consumer is outside this view (presumably
/// the application-auth request — TODO confirm).
pub client_secret: String,
/// When `true` the client connects to `live.ctraderapi.com`, otherwise to
/// `demo.ctraderapi.com`.
pub live: bool,
/// Maximum time `Client::command` waits for a response before failing with
/// `Error::Timeout`.
pub deadline: Duration,
}
impl Config {
    /// Creates a configuration from the application credentials.
    ///
    /// The result targets the demo environment (`live == false`) and uses a
    /// 5-second deadline; chain [`Config::live`] or [`Config::deadline`] to
    /// change either.
    pub fn new(client_id: impl Into<String>, client_secret: impl Into<String>) -> Self {
        let client_id = client_id.into();
        let client_secret = client_secret.into();
        let deadline = Duration::from_secs(5);
        Self {
            client_id,
            client_secret,
            live: false,
            deadline,
        }
    }

    /// Returns the configuration with the live environment selected.
    pub fn live(self) -> Self {
        Self { live: true, ..self }
    }

    /// Returns the configuration with the deadline replaced by `d`.
    pub fn deadline(self, d: Duration) -> Self {
        Self { deadline: d, ..self }
    }
}
/// Handle to an established connection.
///
/// Instances are built by `Client::start` / `Client::start_with_handler`.
pub struct Client {
/// Shared connection handle; `Client::command` sends encoded frames through it
/// and a clone is handed to the keepalive task.
pub transport: Arc<Transport>,
/// Requests awaiting a response, keyed by client message id. The same map is
/// shared with the spawned `dispatch_loop` task.
pub registry: Registry,
/// The configuration the client was started with; `deadline` bounds each
/// `Client::command` call.
pub config: Config,
/// Optional callback supplied at start-up and also passed to `dispatch_loop`.
/// Presumably invoked for messages that do not answer a pending request —
/// TODO confirm against `client_helper::dispatch_loop`.
pub event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>>,
}
impl Client {
fn host(live: bool) -> &'static str {
if live {
"live.ctraderapi.com"
} else {
"demo.ctraderapi.com"
}
}
pub async fn start(config: Config) -> Result<Self, Error> {
Self::start_with_handler(config, None::<fn(ProtoMessage)>).await
}
pub async fn start_with_handler(
config: Config,
handler: Option<impl Fn(ProtoMessage) + Send + Sync + 'static>,
) -> Result<Self, Error> {
let host = Self::host(config.live);
let registry: Registry = Arc::new(Mutex::new(HashMap::new()));
let (frame_tx, frame_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let transport = Arc::new(Transport::connect(host, 5035, frame_tx).await?);
let event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>> =
handler.map(|h| Arc::new(h) as _);
{
let registry = registry.clone();
let event_handler = event_handler.clone();
tokio::spawn(dispatch_loop(frame_rx, registry, event_handler));
}
let client = Self {
transport,
registry,
config: config.clone(),
event_handler,
};
client.application_auth().await?;
{
let transport = client.transport.clone();
tokio::spawn(async move {
keepalive(transport).await;
});
}
Ok(client)
}
pub async fn command<Q, R>(&self, req_type: u32, req: Q, res_type: u32) -> Result<R, Error>
where
Q: Message,
R: Message + Default,
{
let id = Uuid::new_v4().to_string();
let mut payload_bytes = Vec::new();
req.encode(&mut payload_bytes)?;
let envelope = ProtoMessage {
payload_type: req_type,
payload: Some(payload_bytes),
client_msg_id: Some(id.clone()),
};
let mut frame = Vec::new();
envelope.encode(&mut frame)?;
let (tx, rx) = oneshot::channel::<ProtoMessage>();
{
let mut reg = self.registry.lock().await;
reg.insert(id.clone(), tx);
}
self.transport.send(&frame).await?;
let response_envelope = time::timeout(self.config.deadline, rx)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::Disconnected)?;
{
let mut reg = self.registry.lock().await;
reg.remove(&id);
}
let pt = response_envelope.payload_type;
if pt == payload::OA_ERROR_RES || pt == payload::ERROR_RES {
let err =
ProtoOaErrorRes::decode(response_envelope.payload.as_deref().unwrap_or_default())?;
return Err(Error::Api {
error_code: err.error_code,
description: err.description.clone().unwrap_or_default(),
});
}
if pt != res_type {
return Err(Error::UnexpectedPayload(pt));
}
Ok(R::decode(
response_envelope.payload.as_deref().unwrap_or_default(),
)?)
}
pub async fn version(&self) -> Result<ProtoOaVersionRes, Error> {
let req = ProtoOaVersionReq {
payload_type: Some(payload::OA_VERSION_REQ as i32),
};
self.command(payload::OA_VERSION_REQ, req, payload::OA_VERSION_RES)
.await
}
}
// Unit-test module; it currently contains a single placeholder test.
#[cfg(test)]
mod tests {
// NOTE(review): this uses the `async_std` test attribute although the rest of
// the file is built on tokio — confirm `async_std` is an intended
// dev-dependency rather than a leftover.
// The test body is empty, so it asserts nothing.
#[async_std::test]
async fn test() {}
}