use std::time::Duration;
use std::{marker::PhantomData, sync::Arc, task::Poll};
use crate::core::ExcService;
use crate::{core::types::instrument::SubscribeInstruments, ExchangeError};
use exc_core::types::instrument::FetchInstruments;
use exc_core::{ExcLayer, ExcServiceExt};
use futures::{
future::{ready, BoxFuture},
FutureExt, TryFutureExt,
};
use tokio::task::JoinHandle;
use tower::util::BoxCloneService;
use tower::{util::BoxService, Layer, Service, ServiceBuilder, ServiceExt};
use self::{state::State, worker::Worker};
use self::options::InstrumentsOptions;
use super::{
request::{InstrumentsRequest, Kind},
response::InstrumentsResponse,
};
/// Boxed (type-erased) service that takes [`SubscribeInstruments`] requests,
/// yields that request's associated `Response`, and fails with [`ExchangeError`].
type SubscribeInstrumentSvc = BoxService<
SubscribeInstruments,
<SubscribeInstruments as crate::Request>::Response,
ExchangeError,
>;
/// Boxed (type-erased) service that takes [`FetchInstruments`] requests,
/// yields that request's associated `Response`, and fails with [`ExchangeError`].
type FetchInstrumentSvc =
BoxService<FetchInstruments, <FetchInstruments as crate::Request>::Response, ExchangeError>;
// Provides `State`, the shared data that `Inner::call` reads instruments from.
mod state;
// Provides `Worker`, which `Inner::poll_ready` initialises and then spawns as a task.
mod worker;
// Provides `InstrumentsOptions`, the configuration held by `InstrumentsLayer`.
pub mod options;
/// Lifecycle of the background worker as tracked by [`Inner`].
///
/// The default variant is [`ServiceState::Failed`]; that is the value
/// `std::mem::take` leaves in place while another variant is being moved out.
#[derive(Default)]
enum ServiceState {
/// Worker has been built but not spawned yet; `poll_ready` drives
/// `Worker::poll_init` while in this state.
Init(Worker),
/// Worker has been spawned with `tokio::spawn`; holds the task's join handle.
Running(JoinHandle<Result<(), ExchangeError>>),
/// Worker task was observed to be finished; the handle is polled to collect
/// its result.
Closing(JoinHandle<Result<(), ExchangeError>>),
/// Terminal state: `poll_ready` always answers with a "market worker dead" error.
#[default]
Failed,
}
/// Non-cloneable core of the instruments service: owns the worker lifecycle
/// and answers requests from the shared state.
struct Inner {
// Shared with the worker (a reference is handed to `Worker::new`) and read by `call`.
state: Arc<State>,
// Current lifecycle stage of the worker; advanced by `poll_ready`.
svc_state: ServiceState,
}
impl Inner {
fn new(
opts: &InstrumentsOptions,
inst: SubscribeInstrumentSvc,
fetch: FetchInstrumentSvc,
) -> Self {
let state = Arc::default();
Self {
svc_state: ServiceState::Init(Worker::new(&state, opts, inst, fetch)),
state,
}
}
}
impl Drop for Inner {
    /// Aborts the spawned worker task when the service is dropped while the
    /// worker is still in the `Running` state.
    ///
    /// The state is swapped for `Failed` (the enum's default) so the join
    /// handle can be taken by value. Every other state holds nothing that
    /// needs aborting and is simply dropped.
    fn drop(&mut self) {
        let previous = std::mem::replace(&mut self.svc_state, ServiceState::Failed);
        match previous {
            ServiceState::Running(task) => task.abort(),
            ServiceState::Init(_) | ServiceState::Closing(_) | ServiceState::Failed => {}
        }
    }
}
/// `tower::Service` implementation that drives the worker lifecycle in
/// `poll_ready` and answers requests from the shared state in `call`.
impl Service<InstrumentsRequest> for Inner {
type Response = InstrumentsResponse;
type Error = ExchangeError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
/// Advances the [`ServiceState`] machine until the service is ready,
/// pending, or failed.
///
/// * `Init`: polls `Worker::poll_init`; once it completes successfully the
///   worker is spawned and the service becomes ready.
/// * `Running`: ready while the task is alive; a finished task moves the
///   state to `Closing`.
/// * `Closing`: polls the join handle and surfaces the worker's outcome.
/// * `Failed`: always returns a "market worker dead" error.
///
/// # Errors
///
/// Returns the init error, the worker's own error, a join error wrapped in
/// `ExchangeError::Other`, or the "market worker dead" error.
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// Loop so that a state transition is re-examined immediately instead of
// waiting for another wake-up.
loop {
match &mut self.svc_state {
ServiceState::Init(worker) => {
tracing::trace!("init; wait init");
// Returns `Pending` while init is in progress. On an init error `?`
// returns it without touching the state, so the state stays `Init`.
futures::ready!(worker.poll_init(cx))?;
tracing::trace!("init; spawn worker task");
// Move the worker out by value; `take` leaves `Failed` behind until
// `Running` is stored below.
let ServiceState::Init(worker) = std::mem::take(&mut self.svc_state) else {
unreachable!();
};
// Run the worker as a tokio task; an error it ends with is logged
// here and is also kept as the task's output.
let handle = tokio::spawn(
worker
.start()
.inspect_err(|err| tracing::error!(%err, "market worker error")),
);
self.svc_state = ServiceState::Running(handle);
break;
}
ServiceState::Running(handle) => {
if handle.is_finished() {
tracing::trace!("running; found finished");
// The task has ended: hand the handle over to `Closing` and loop
// again so its result is collected right away.
let ServiceState::Running(handle) = std::mem::take(&mut self.svc_state)
else {
unreachable!()
};
self.svc_state = ServiceState::Closing(handle);
} else {
tracing::trace!("running; ready");
break;
}
}
ServiceState::Closing(handle) => {
tracing::trace!("closing; closing");
match handle.try_poll_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(res) => {
// Mark the service dead before inspecting the result, so every
// later poll lands in the `Failed` arm.
self.svc_state = ServiceState::Failed;
// Outer error = join failure (wrapped in `Other`); inner error =
// the worker's own result. If the worker ended with `Ok(())`
// nothing is returned here and the loop falls into `Failed`.
res.map_err(|err| ExchangeError::Other(err.into()))
.and_then(|res| res)?;
}
}
}
ServiceState::Failed => {
tracing::trace!("failed; failed");
return Poll::Ready(Err(ExchangeError::Other(anyhow::anyhow!(
"market worker dead"
))));
}
}
}
// Reached only via `break`, i.e. the worker was just spawned or is still running.
Poll::Ready(Ok(()))
}
/// Answers `req` synchronously from the shared state and returns an
/// already-resolved boxed future.
fn call(&mut self, req: InstrumentsRequest) -> Self::Future {
match req.kind() {
// The only request kind matched here: look the instrument up in the
// shared state and convert the result into a response.
Kind::GetInstrument(req) => {
let meta = self.state.clone().get_instrument(req);
ready(Ok(InstrumentsResponse::from(meta))).boxed()
}
}
}
}
/// Cloneable instruments service handle.
///
/// Wraps a boxed, cloneable service mapping [`InstrumentsRequest`] to
/// [`InstrumentsResponse`] with [`ExchangeError`] as the error type.
#[derive(Debug, Clone)]
pub struct Instruments {
// Type-erased service that every request is delegated to.
inner: BoxCloneService<InstrumentsRequest, InstrumentsResponse, ExchangeError>,
}
/// Pure delegation: every call is forwarded to the wrapped boxed service.
impl Service<InstrumentsRequest> for Instruments {
    type Response = InstrumentsResponse;
    type Error = ExchangeError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Reports the readiness of the wrapped service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let Self { inner } = self;
        // Fully qualified so the `tower::Service` method is selected explicitly.
        Service::poll_ready(inner, cx)
    }

    /// Hands `req` to the wrapped service and returns its future unchanged.
    #[inline]
    fn call(&mut self, req: InstrumentsRequest) -> Self::Future {
        let Self { inner } = self;
        Service::call(inner, req)
    }
}
/// Layer that builds an [`Instruments`] service on top of an exchange service.
///
/// * `Req` - request type of the underlying exchange service. It is only a
///   marker (`PhantomData<fn() -> Req>`); no value of it is stored.
/// * `L1`  - layer producing the service used for `FetchInstruments`.
/// * `L2`  - layer producing the service used for `SubscribeInstruments`.
///
/// `Clone` and `Debug` are implemented by hand rather than derived. A derive
/// would add `Req: Clone` / `Req: Debug` bounds although `Req` never appears
/// outside the marker, making the layer needlessly non-cloneable for request
/// types that are not `Clone`.
pub struct InstrumentsLayer<Req, L1 = ExcLayer<Req>, L2 = ExcLayer<Req>> {
    opts: InstrumentsOptions,
    fetch_instruments: L1,
    subscribe_instruments: L2,
    _req: PhantomData<fn() -> Req>,
}

impl<Req, L1: Clone, L2: Clone> Clone for InstrumentsLayer<Req, L1, L2> {
    fn clone(&self) -> Self {
        Self {
            opts: self.opts.clone(),
            fetch_instruments: self.fetch_instruments.clone(),
            subscribe_instruments: self.subscribe_instruments.clone(),
            _req: PhantomData,
        }
    }
}

impl<Req, L1: std::fmt::Debug, L2: std::fmt::Debug> std::fmt::Debug
    for InstrumentsLayer<Req, L1, L2>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field list and order a derived `Debug` would print.
        f.debug_struct("InstrumentsLayer")
            .field("opts", &self.opts)
            .field("fetch_instruments", &self.fetch_instruments)
            .field("subscribe_instruments", &self.subscribe_instruments)
            .field("_req", &self._req)
            .finish()
    }
}
impl<Req> InstrumentsLayer<Req> {
pub fn new(inst_tags: &[&str]) -> Self {
let opts = InstrumentsOptions::default().tags(inst_tags);
Self {
opts,
fetch_instruments: ExcLayer::default(),
subscribe_instruments: ExcLayer::default(),
_req: PhantomData,
}
}
pub fn with_buffer_bound(inst_tags: &[&str], bound: usize) -> Self {
let opts = InstrumentsOptions::default()
.tags(inst_tags)
.buffer_bound(bound);
Self {
opts,
fetch_instruments: ExcLayer::default(),
subscribe_instruments: ExcLayer::default(),
_req: PhantomData,
}
}
}
impl<Req, L1, L2> InstrumentsLayer<Req, L1, L2> {
    /// Replaces the layer used to build the `FetchInstruments` service.
    ///
    /// Options and the subscribe layer are carried over; the previous fetch
    /// layer is dropped.
    pub fn fetch_instruments<L>(self, layer: L) -> InstrumentsLayer<Req, L, L2> {
        let Self {
            opts,
            subscribe_instruments,
            ..
        } = self;
        InstrumentsLayer {
            opts,
            fetch_instruments: layer,
            subscribe_instruments,
            _req: PhantomData,
        }
    }

    /// Replaces the layer used to build the `SubscribeInstruments` service.
    ///
    /// Options and the fetch layer are carried over; the previous subscribe
    /// layer is dropped.
    pub fn subscribe_instruments<L>(self, layer: L) -> InstrumentsLayer<Req, L1, L> {
        let Self {
            opts,
            fetch_instruments,
            ..
        } = self;
        InstrumentsLayer {
            opts,
            fetch_instruments,
            subscribe_instruments: layer,
            _req: PhantomData,
        }
    }

    /// Stores `(num, dur)` as the rate limit for the fetch service and returns
    /// `self` for chaining.
    pub fn set_fetch_rate_limit(&mut self, num: u64, dur: Duration) -> &mut Self {
        let limit = (num, dur);
        self.opts.fetch_rate_limit = limit;
        self
    }

    /// Stores `(num, dur)` as the rate limit for the subscribe service and
    /// returns `self` for chaining.
    pub fn set_subscribe_rate_limit(&mut self, num: u64, dur: Duration) -> &mut Self {
        let limit = (num, dur);
        self.opts.subscribe_rate_limit = limit;
        self
    }
}
/// Turns an exchange service `S` into an [`Instruments`] service.
///
/// `L1` and `L2` are applied to `S` to obtain the services that handle
/// `FetchInstruments` and `SubscribeInstruments` respectively.
impl<S, Req, L1: Layer<S>, L2: Layer<S>> Layer<S> for InstrumentsLayer<Req, L1, L2>
where
S: ExcService<Req> + Send + 'static + Clone,
S::Future: Send + 'static,
Req: crate::Request + 'static,
L1::Service: ExcService<FetchInstruments> + Send + 'static,
<L1::Service as ExcService<FetchInstruments>>::Future: Send,
L2::Service: ExcService<SubscribeInstruments> + Send + 'static,
<L2::Service as ExcService<SubscribeInstruments>>::Future: Send,
{
type Service = Instruments;
/// Builds the fetch and subscribe pipelines from `svc`, wires them into an
/// [`Inner`] service, and wraps that in a buffer so the result is cloneable.
///
/// NOTE(review): `ServiceBuilder::buffer` is documented by tower to spawn a
/// worker task, so this presumably must be called inside a tokio runtime.
/// Confirm against the tower version in use.
fn layer(&self, svc: S) -> Self::Service {
// Fetch pipeline: a clone of `svc` goes through the user-supplied fetch
// layer, is converted with `into_service()`, and is rate limited with the
// configured `(num, per)` pair.
let fetch = ServiceBuilder::default()
.rate_limit(self.opts.fetch_rate_limit.0, self.opts.fetch_rate_limit.1)
.layer_fn(|svc: L1::Service| svc.into_service())
.layer(&self.fetch_instruments)
.service(svc.clone());
// Subscribe pipeline: same shape, consuming the original `svc`.
let subscribe = ServiceBuilder::default()
.rate_limit(
self.opts.subscribe_rate_limit.0,
self.opts.subscribe_rate_limit.1,
)
.layer_fn(|svc: L2::Service| svc.into_service())
.layer(&self.subscribe_instruments)
.service(svc);
// Argument order follows `Inner::new(opts, inst, fetch)`: subscribe first,
// then fetch. Both are boxed to erase their concrete types.
let svc = Inner::new(
&self.opts,
ServiceExt::boxed(subscribe),
ServiceExt::boxed(fetch),
);
// Buffer `Inner` with the configured bound, convert the buffered service's
// error back into `ExchangeError` (then `flatten()` it), and box the result
// as a cloneable service.
let inner = ServiceBuilder::default()
.buffer(self.opts.buffer_bound)
.service(svc)
.map_err(|err| ExchangeError::from(err).flatten())
.boxed_clone();
Instruments { inner }
}
}