use kraken_types::{Channel, Depth, SubscribeParams, SubscribeRequest};
use std::collections::HashSet;
/// A single channel subscription for one or more symbols.
#[derive(Debug, Clone, PartialEq)]
pub struct Subscription {
    /// Channel to subscribe to.
    pub channel: Channel,
    /// Symbols covered by this subscription, e.g. `"BTC/USD"`.
    pub symbols: Vec<String>,
    /// Optional order-book depth. When `None`, `Subscription::to_request` picks a
    /// default for book subscriptions.
    pub depth: Option<Depth>,
    /// Snapshot flag for the subscribe request; every constructor sets it to `true`.
    pub snapshot: bool,
}
impl Subscription {
    /// Creates a subscription to `channel` for `symbols`.
    ///
    /// No depth is set and `snapshot` defaults to `true`. The channel-specific
    /// constructors below are thin wrappers around this one.
    pub fn new(channel: Channel, symbols: Vec<String>) -> Self {
        Self {
            channel,
            symbols,
            depth: None,
            snapshot: true,
        }
    }

    /// Creates a [`Channel::Book`] subscription with an explicit `depth`.
    pub fn orderbook(symbols: Vec<String>, depth: Depth) -> Self {
        Self {
            depth: Some(depth),
            ..Self::new(Channel::Book, symbols)
        }
    }

    /// Creates a [`Channel::Ticker`] subscription.
    pub fn ticker(symbols: Vec<String>) -> Self {
        Self::new(Channel::Ticker, symbols)
    }

    /// Creates a [`Channel::Trade`] subscription.
    pub fn trade(symbols: Vec<String>) -> Self {
        Self::new(Channel::Trade, symbols)
    }

    /// Creates a [`Channel::Level3`] subscription.
    pub fn level3(symbols: Vec<String>) -> Self {
        Self::new(Channel::Level3, symbols)
    }

    /// Builds the `subscribe` request for this subscription, tagged with `req_id`.
    ///
    /// Book, ticker and trade channels go through the typed `SubscribeParams`
    /// builders; book falls back to `Depth::D10` when no depth is set. Every other
    /// channel gets params assembled directly from this struct's fields.
    /// An explicit `snapshot = false` is honoured for every channel.
    pub fn to_request(&self, req_id: Option<u64>) -> SubscribeRequest {
        let mut params = match self.channel {
            Channel::Book => SubscribeParams::book(
                self.symbols.clone(),
                self.depth.unwrap_or(Depth::D10),
            ),
            Channel::Ticker => SubscribeParams::ticker(self.symbols.clone()),
            Channel::Trade => SubscribeParams::trade(self.symbols.clone()),
            _ => SubscribeParams {
                channel: self.channel,
                symbol: self.symbols.clone(),
                depth: self.depth.map(|d| d.as_u32()),
                snapshot: Some(self.snapshot),
                interval: None,
                event_trigger: None,
                token: None,
            },
        };
        // The typed builders take no snapshot argument, so an opt-out set on the
        // public `snapshot` field would otherwise be silently dropped for them.
        // Only `false` is forced, so the builders' own default is kept in the
        // common (`true`) case.
        if !self.snapshot {
            params.snapshot = Some(false);
        }
        SubscribeRequest {
            method: "subscribe",
            params,
            req_id,
        }
    }
}
/// Tracks registered subscriptions and the request ids still awaiting a response.
#[derive(Debug, Default)]
pub struct SubscriptionManager {
    /// Every subscription registered via `add`, in insertion order.
    subscriptions: Vec<Subscription>,
    /// Request ids that have been allocated but not yet confirmed or rejected.
    pending: HashSet<u64>,
    /// Next request id to hand out. It starts at 0 (via `Default`) and is only
    /// ever incremented.
    next_req_id: u64,
}
impl SubscriptionManager {
    /// Creates an empty manager; request ids start at 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `sub` and allocates a fresh request id for it.
    ///
    /// The returned id stays pending until [`confirm`](Self::confirm) or
    /// [`reject`](Self::reject) is called with it. This method does not send
    /// anything; building and sending the request is up to the caller.
    pub fn add(&mut self, sub: Subscription) -> u64 {
        let req_id = self.next_req_id;
        self.next_req_id += 1;
        self.pending.insert(req_id);
        self.subscriptions.push(sub);
        req_id
    }

    /// Marks `req_id` as acknowledged. Unknown ids are ignored.
    pub fn confirm(&mut self, req_id: u64) {
        self.pending.remove(&req_id);
    }

    /// Marks `req_id` as answered with a rejection. Unknown ids are ignored.
    ///
    /// NOTE(review): the subscription itself is kept. It still counts in
    /// `count`/`all` and is re-sent by `restoration_requests`. Confirm this is
    /// intended.
    pub fn reject(&mut self, req_id: u64) {
        self.pending.remove(&req_id);
    }

    /// All registered subscriptions, in insertion order.
    pub fn all(&self) -> &[Subscription] {
        &self.subscriptions
    }

    /// Number of registered subscriptions.
    pub fn count(&self) -> usize {
        self.subscriptions.len()
    }

    /// Forgets all subscriptions and pending ids.
    ///
    /// The id counter is not reset, so ids allocated afterwards never repeat
    /// earlier ones.
    pub fn clear(&mut self) {
        self.subscriptions.clear();
        self.pending.clear();
    }

    /// Returns `true` while any allocated request id is still awaiting a response.
    pub fn has_pending(&self) -> bool {
        !self.pending.is_empty()
    }

    /// Re-issues every registered subscription under a fresh request id
    /// (presumably used after a reconnect). Returns `(req_id, request)` pairs in
    /// registration order, and each new id is marked pending.
    pub fn restoration_requests(&mut self) -> Vec<(u64, SubscribeRequest)> {
        // Every subscription is re-sent below under a new id, so ids still pending
        // from earlier requests are superseded. Responses to them will not arrive
        // on a fresh connection, and keeping them would leave `has_pending` stuck
        // at `true`.
        self.pending.clear();
        let mut requests = Vec::with_capacity(self.subscriptions.len());
        for sub in &self.subscriptions {
            let req_id = self.next_req_id;
            self.next_req_id += 1;
            self.pending.insert(req_id);
            requests.push((req_id, sub.to_request(Some(req_id))));
        }
        requests
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-symbol list helper to keep the tests terse.
    fn symbols(pair: &str) -> Vec<String> {
        vec![pair.to_string()]
    }

    #[test]
    fn test_subscription_creation() {
        let sub = Subscription::orderbook(symbols("BTC/USD"), Depth::D10);
        assert_eq!(sub.channel, Channel::Book);
        assert_eq!(sub.depth, Some(Depth::D10));
        assert!(sub.snapshot);
    }

    #[test]
    fn test_subscription_manager() {
        let mut manager = SubscriptionManager::new();
        let req_id1 = manager.add(Subscription::ticker(symbols("BTC/USD")));
        let req_id2 = manager.add(Subscription::orderbook(symbols("ETH/USD"), Depth::D10));
        assert_eq!(manager.count(), 2);
        assert!(manager.has_pending());
        manager.confirm(req_id1);
        manager.confirm(req_id2);
        assert!(!manager.has_pending());
    }

    #[test]
    fn test_reject_clears_pending() {
        let mut manager = SubscriptionManager::new();
        let req_id = manager.add(Subscription::trade(symbols("BTC/USD")));
        assert!(manager.has_pending());
        manager.reject(req_id);
        assert!(!manager.has_pending());
    }

    #[test]
    fn test_req_ids_are_not_reused_after_clear() {
        let mut manager = SubscriptionManager::new();
        let first = manager.add(Subscription::ticker(symbols("BTC/USD")));
        manager.clear();
        assert_eq!(manager.count(), 0);
        assert!(!manager.has_pending());
        let second = manager.add(Subscription::ticker(symbols("BTC/USD")));
        assert_ne!(first, second);
    }

    #[test]
    fn test_restoration_requests_reissue_with_fresh_ids() {
        let mut manager = SubscriptionManager::new();
        let a = manager.add(Subscription::ticker(symbols("BTC/USD")));
        let b = manager.add(Subscription::level3(symbols("ETH/USD")));
        manager.confirm(a);
        manager.confirm(b);

        let requests = manager.restoration_requests();
        assert_eq!(requests.len(), 2);
        for (req_id, request) in &requests {
            assert!(*req_id != a && *req_id != b);
            assert_eq!(request.req_id, Some(*req_id));
            assert_eq!(request.method, "subscribe");
        }
        assert!(manager.has_pending());

        for (req_id, _) in &requests {
            manager.confirm(*req_id);
        }
        assert!(!manager.has_pending());
    }

    #[test]
    fn test_fallback_channel_forwards_snapshot_flag() {
        let mut sub = Subscription::level3(symbols("BTC/USD"));
        sub.snapshot = false;
        let request = sub.to_request(None);
        assert_eq!(request.params.channel, Channel::Level3);
        assert_eq!(request.params.symbol, symbols("BTC/USD"));
        assert_eq!(request.params.snapshot, Some(false));
        assert_eq!(request.req_id, None);
    }
}