use std::collections::HashMap;
use serde_json::Value;
use crate::core::types::{AccountType, StreamEvent, WebSocketError, WebSocketResult};
use super::stream_kind::StreamKind;
/// String newtype for a concrete topic identifier, i.e. the value that a
/// [`TopicPattern`] is matched against.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TopicKey(pub String);

impl TopicKey {
    /// Builds a key from anything convertible into a `String`.
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into())
    }

    /// Borrows the key's text.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for TopicKey {
    /// Writes the raw key text, honoring the caller's width, fill, alignment
    /// and precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the formatter's flags (which `write_str` ignores), so
        // `format!("{:<20}", key)` pads as requested. Plain `{}` is unchanged.
        f.pad(&self.0)
    }
}
/// String newtype for a topic pattern that may contain a `*` wildcard.
///
/// Matching semantics are those of [`topic_pattern_matches`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TopicPattern(pub String);

impl TopicPattern {
    /// Builds a pattern from anything convertible into a `String`.
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into())
    }

    /// Returns `true` when `key` satisfies this pattern, as decided by
    /// [`topic_pattern_matches`].
    pub fn matches(&self, key: &TopicKey) -> bool {
        topic_pattern_matches(&self.0, &key.0)
    }
}

impl std::fmt::Display for TopicPattern {
    /// Writes the raw pattern text, honoring the caller's width, fill,
    /// alignment and precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the formatter's flags (which `write_str` ignores), so
        // `format!("{:<20}", pattern)` pads as requested. Plain `{}` is unchanged.
        f.pad(&self.0)
    }
}
/// Tests `key` against `pattern`, where the first `*` in `pattern` (if any)
/// stands for any run of characters, including an empty one.
///
/// * Without a `*`, the pattern must equal the key exactly.
/// * With a `*`, the key must start with the text before the star and the
///   remainder of the key must end with the text after it, so the two literal
///   parts can never overlap inside the key.
///
/// Only the first `*` is a wildcard; any later `*` is compared literally.
pub fn topic_pattern_matches(pattern: &str, key: &str) -> bool {
    match pattern.split_once('*') {
        // No wildcard: plain string equality.
        None => pattern == key,
        Some((head, tail)) => match key.strip_prefix(head) {
            // An empty `tail` is a suffix of every string, so `head*` accepts
            // anything that starts with `head`.
            Some(remainder) => remainder.ends_with(tail),
            None => false,
        },
    }
}
/// Function-pointer type for the parsers stored in the registry: takes a
/// borrowed JSON [`Value`] and returns a `WebSocketResult<StreamEvent>`.
pub type ParserFn = fn(&Value) -> WebSocketResult<StreamEvent>;
/// Hashable key identifying a registry entry: a stream kind paired with an
/// account type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RegistryKey {
/// Stream kind the entry was registered for.
pub kind: StreamKind,
/// Account type the entry was registered for.
pub account_type: AccountType,
}
/// Value stored per [`RegistryKey`]: the topic pattern and the parser
/// registered together with it.
#[derive(Clone)]
pub struct RegistryEntry {
/// Pattern that incoming topic keys are matched against.
pub pattern: TopicPattern,
/// Parser function registered alongside `pattern`.
pub parser: ParserFn,
}
/// Registry of topic patterns and their parsers, assembled by
/// [`TopicRegistryBuilder`].
pub struct TopicRegistry {
// Entries keyed by (stream kind, account type); backs `supports`,
// `native_pairs` and `entries`.
entries: HashMap<RegistryKey, RegistryEntry>,
// Ordered (pattern, parser) list that `dispatch` scans front to back.
dispatch: Vec<(TopicPattern, ParserFn)>,
}
impl TopicRegistry {
    /// Returns an empty [`TopicRegistryBuilder`] to start registering entries.
    pub fn builder() -> TopicRegistryBuilder {
        Default::default()
    }

    /// Finds the parser for `key`.
    ///
    /// The dispatch list is scanned in order and the parser of the first
    /// pattern that matches is returned; `None` means no pattern matched.
    pub fn dispatch(&self, key: &TopicKey) -> Option<ParserFn> {
        self.dispatch
            .iter()
            .find(|(pattern, _)| pattern.matches(key))
            .map(|(_, parser)| *parser)
    }

    /// Reports whether an entry exists for exactly this `(kind, account)` pair.
    pub fn supports(&self, kind: &StreamKind, account: AccountType) -> bool {
        // A temporary owned key is needed for the map lookup, hence the clone.
        self.entries.contains_key(&RegistryKey {
            kind: kind.clone(),
            account_type: account,
        })
    }

    /// Iterates over every registered `(stream kind, account type)` pair.
    ///
    /// The pairs come from the keys of a `HashMap`, so their order is
    /// unspecified.
    pub fn native_pairs(&self) -> impl Iterator<Item = (&StreamKind, AccountType)> + '_ {
        self.entries
            .keys()
            .map(|RegistryKey { kind, account_type }| (kind, *account_type))
    }

    /// Borrows the underlying key-to-entry map.
    pub fn entries(&self) -> &HashMap<RegistryKey, RegistryEntry> {
        &self.entries
    }
}
/// Builder that collects registrations, in call order, before they are turned
/// into a [`TopicRegistry`] by `build`.
#[derive(Default)]
pub struct TopicRegistryBuilder {
// Pending (key, entry) pairs in the order `register` was called.
entries: Vec<(RegistryKey, RegistryEntry)>,
}
impl TopicRegistryBuilder {
    /// Queues one registration and returns the builder for chaining.
    ///
    /// `pattern` is wrapped in a [`TopicPattern`]; `parser` is the function
    /// handed back by [`TopicRegistry::dispatch`] when that pattern matches.
    pub fn register(
        mut self,
        kind: StreamKind,
        account_type: AccountType,
        pattern: impl Into<String>,
        parser: ParserFn,
    ) -> Self {
        self.entries.push((
            RegistryKey { kind, account_type },
            RegistryEntry {
                pattern: TopicPattern::new(pattern),
                parser,
            },
        ));
        self
    }

    /// Consumes the builder and produces the finished [`TopicRegistry`].
    ///
    /// The dispatch list keeps every registration in call order. The keyed map
    /// holds one entry per `(kind, account_type)`; when the same key was
    /// registered more than once, the last registration is the one kept there.
    pub fn build(self) -> TopicRegistry {
        // First pass borrows the pending list to build the ordered dispatch table.
        let dispatch: Vec<(TopicPattern, ParserFn)> = self
            .entries
            .iter()
            .map(|(_, entry)| (entry.pattern.clone(), entry.parser))
            .collect();
        // Second pass moves the pairs into the map; later duplicates overwrite
        // earlier ones, exactly like repeated `insert` calls.
        let entries: HashMap<RegistryKey, RegistryEntry> = self.entries.into_iter().collect();
        TopicRegistry { entries, dispatch }
    }
}
/// Builds a `WebSocketError::Parse` whose message is `missing field: <field>`.
pub fn missing_field(field: &str) -> WebSocketError {
    let message = format!("missing field: {}", field);
    WebSocketError::Parse(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parser stub that always fails; the registry tests only check whether a
    /// parser is found, never what it returns.
    fn dummy_parser(_v: &Value) -> WebSocketResult<StreamEvent> {
        Err(WebSocketError::Parse("test".into()))
    }

    /// Registry holding a single spot trade entry with the `*@trade` pattern.
    fn spot_trade_registry() -> TopicRegistry {
        TopicRegistry::builder()
            .register(StreamKind::Trade, AccountType::Spot, "*@trade", dummy_parser)
            .build()
    }

    /// Asserts that `pattern` accepts every key in `hits` and rejects every
    /// key in `misses`.
    fn check_pattern(pattern: &str, hits: &[&str], misses: &[&str]) {
        for &key in hits {
            assert!(
                topic_pattern_matches(pattern, key),
                "{:?} should match {:?}",
                pattern,
                key
            );
        }
        for &key in misses {
            assert!(
                !topic_pattern_matches(pattern, key),
                "{:?} should not match {:?}",
                pattern,
                key
            );
        }
    }

    #[test]
    fn exact_match() {
        check_pattern("spot.trades", &["spot.trades"], &["spot.trade"]);
    }

    #[test]
    fn suffix_wildcard() {
        check_pattern(
            "*@trade",
            &["btcusdt@trade", "ethusdt@trade"],
            &["btcusdt@kline_1m"],
        );
    }

    #[test]
    fn prefix_wildcard() {
        check_pattern(
            "publicTrade.*",
            &["publicTrade.BTCUSDT", "publicTrade.ETHUSDT"],
            &["orderbook.BTCUSDT"],
        );
    }

    #[test]
    fn mid_wildcard() {
        check_pattern(
            "market.*.trade.detail",
            &["market.BTC-USDT.trade.detail"],
            &["market.BTC-USDT.depth"],
        );
    }

    #[test]
    fn topic_key_display() {
        assert_eq!(TopicKey::new("btcusdt@trade").to_string(), "btcusdt@trade");
    }

    #[test]
    fn registry_dispatch() {
        let registry = spot_trade_registry();
        assert!(registry.dispatch(&TopicKey::new("btcusdt@trade")).is_some());
        assert!(registry
            .dispatch(&TopicKey::new("btcusdt@kline_1m"))
            .is_none());
    }

    #[test]
    fn registry_supports() {
        let registry = spot_trade_registry();
        assert!(registry.supports(&StreamKind::Trade, AccountType::Spot));
        assert!(!registry.supports(&StreamKind::Ticker, AccountType::Spot));
    }
}