use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::types::candles::{CandleData, Timeframe};
use crate::types::market::{L2BookUpdate, MarketStatsUpdate};
use crate::types::trader::TraderStateServerMessage;
use crate::types::trades::{TradesMessage, TradesSubscriptionRequest};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct FundingRateSubscriptionRequest {
pub symbol: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct OrderbookSubscriptionRequest {
pub symbol: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bypass_execution_band: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct TraderStateSubscriptionRequest {
pub authority: String,
pub trader_pda_index: u8,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct MarketSubscriptionRequest {
pub symbol: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct CandlesSubscriptionRequest {
pub symbol: String,
pub timeframe: Timeframe,
}
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
#[serde(tag = "channel")]
pub enum SubscriptionRequest {
#[serde(rename = "allMids")]
AllMids,
#[serde(rename = "fundingRate")]
FundingRate(FundingRateSubscriptionRequest),
#[serde(rename = "orderbook")]
Orderbook(OrderbookSubscriptionRequest),
#[serde(rename = "traderState")]
TraderState(TraderStateSubscriptionRequest),
#[serde(rename = "market")]
Market(MarketSubscriptionRequest),
#[serde(rename = "trades")]
Trades(TradesSubscriptionRequest),
#[serde(rename = "candles")]
Candles(CandlesSubscriptionRequest),
#[serde(other)]
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum ClientMessage {
#[serde(rename = "subscribe")]
Subscribe { subscription: SubscriptionRequest },
#[serde(rename = "unsubscribe")]
Unsubscribe { subscription: SubscriptionRequest },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AllMidsData {
pub mids: HashMap<String, f64>,
pub slot: u64,
pub slot_index: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FundingRateMessage {
pub symbol: String,
pub funding: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "channel")]
#[serde(rename_all = "camelCase")]
pub enum ServerMessage {
#[serde(rename = "allMids")]
AllMids(AllMidsData),
#[serde(rename = "fundingRate")]
FundingRate(FundingRateMessage),
#[serde(rename = "orderbook")]
Orderbook(L2BookUpdate),
#[serde(rename = "traderState")]
TraderState(TraderStateServerMessage),
#[serde(rename = "market")]
Market(MarketStatsUpdate),
#[serde(rename = "trades")]
Trades(TradesMessage),
#[serde(rename = "candle", alias = "candles")]
Candles(CandleData),
#[serde(rename = "error")]
Error(ErrorMessage),
#[serde(other)]
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename = "subscriptionConfirmed")]
pub struct SubscriptionConfirmedMessage {
pub subscription: SubscriptionRequest,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename = "subscriptionError")]
pub struct SubscriptionErrorMessage {
pub subscription: SubscriptionRequest,
pub code: String,
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ErrorMessage {
pub error: String,
pub code: u16,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deserialize_client_message() {
let json = r#"{
"type": "subscribe",
"subscription": {
"channel": "traderState",
"authority": "ABC123",
"traderPdaIndex": 0
}
}"#;
let msg: ClientMessage = serde_json::from_str(json).unwrap();
assert!(matches!(msg, ClientMessage::Subscribe { .. }));
}
#[test]
fn test_serialize_client_message() {
let msg = ClientMessage::Subscribe {
subscription: SubscriptionRequest::TraderState(TraderStateSubscriptionRequest {
authority: "ABC123".to_string(),
trader_pda_index: 0,
}),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("subscribe"));
assert!(json.contains("traderState"));
}
#[test]
fn test_orderbook_subscription_request() {
let msg = ClientMessage::Subscribe {
subscription: SubscriptionRequest::Orderbook(OrderbookSubscriptionRequest {
symbol: "SOL".to_string(),
bypass_execution_band: None,
}),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("subscribe"));
assert!(json.contains("orderbook"));
assert!(json.contains("SOL"));
}
#[test]
fn test_deserialize_orderbook_server_message() {
let json = r#"{
"channel": "orderbook",
"symbol": "SOL",
"orderbook": {
"bids": [[150.25, 100.0], [150.20, 200.0]],
"asks": [[150.30, 150.0], [150.35, 250.0]],
"mid": 150.275
}
}"#;
let msg: ServerMessage = serde_json::from_str(json).unwrap();
if let ServerMessage::Orderbook(update) = msg {
assert_eq!(update.symbol, "SOL");
assert_eq!(update.orderbook.bids.len(), 2);
assert_eq!(update.orderbook.asks.len(), 2);
assert_eq!(update.orderbook.mid, Some(150.275));
} else {
panic!("Expected Orderbook message");
}
}
#[test]
fn test_deserialize_funding_rate_server_message() {
let json = r#"{
"channel": "fundingRate",
"symbol": "SOL",
"funding": 0.0125
}"#;
let msg: ServerMessage = serde_json::from_str(json).unwrap();
if let ServerMessage::FundingRate(update) = msg {
assert_eq!(update.symbol, "SOL");
assert_eq!(update.funding, 0.0125);
} else {
panic!("Expected FundingRate message");
}
}
#[test]
fn test_deserialize_trades_server_message() {
let json = r#"{
"channel": "trades",
"symbol": "SOL",
"trades": [{
"slot": "123456789",
"slotIndex": 5,
"timestamp": "1775578550",
"symbol": "SOL",
"taker": "ABC123pubkey",
"tradeSequenceNumber": "100",
"side": "bid",
"baseLotsFilled": "1000",
"quoteLotsFilled": "150000",
"feeInQuoteLots": "30",
"baseAmount": 10.0,
"quoteAmount": 1500.0,
"numFills": 2
}]
}"#;
let msg: ServerMessage = serde_json::from_str(json).unwrap();
if let ServerMessage::Trades(update) = msg {
assert_eq!(update.symbol, "SOL");
assert_eq!(update.trades.len(), 1);
} else {
panic!("Expected Trades message");
}
}
#[test]
fn test_serialize_candles_subscription_request() {
let req = CandlesSubscriptionRequest {
symbol: "SOL".to_string(),
timeframe: Timeframe::Minute1,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"symbol\":\"SOL\""));
assert!(json.contains("\"timeframe\":\"1m\""));
}
#[test]
fn test_deserialize_candle_server_message() {
let json = r#"{
"channel": "candle",
"symbol": "SOL",
"timeframe": "1m",
"candle": {
"time": 1776801600,
"open": 85.0,
"high": 85.5,
"low": 84.9,
"close": 85.2
}
}"#;
let msg: ServerMessage = serde_json::from_str(json).unwrap();
if let ServerMessage::Candles(update) = msg {
assert_eq!(update.symbol, "SOL");
assert_eq!(update.timeframe, "1m");
assert_eq!(update.candle.close, 85.2);
} else {
panic!("Expected Candles message");
}
}
}