1use futures_util::{SinkExt, Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
8use tokio::sync::{broadcast, mpsc, oneshot};
9use tokio_stream::wrappers::BroadcastStream;
10use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
11use tokio_tungstenite::connect_async;
12use tokio_tungstenite::tungstenite::Error as WSError;
13use tokio_tungstenite::tungstenite::Message;
14
15pub mod prod {
17 use serde::{Deserialize, Serialize};
18 use serde_json::Value;
19 include!(concat!(env!("OUT_DIR"), "/deribit_client_prod.rs"));
20}
21
/// Request/response types generated at build time for the testnet API.
///
/// Only compiled with the `testnet` feature. As with `prod`, the generated
/// source comes from `OUT_DIR` and presumably relies on the imports below.
#[cfg(feature = "testnet")]
pub mod testnet {
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::Value;

    include!(concat!(env!("OUT_DIR"), "/deribit_client_testnet.rs"));
}
28
29pub use prod::*;
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct RpcError {
33 pub code: i32,
34 pub message: String,
35 pub data: Option<Value>,
36}
37
38impl std::fmt::Display for RpcError {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "RPC Error {}: {}", self.code, self.message)
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45enum JsonRpcVersion {
46 #[serde(rename = "2.0")]
47 V2,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51struct RpcRequest {
52 jsonrpc: JsonRpcVersion,
53 id: u64,
54 method: String,
55 params: Value,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct RpcResponseBase {
60 jsonrpc: JsonRpcVersion,
61 id: u64,
62 testnet: bool,
63 #[serde(rename = "usIn")]
64 us_in: u64,
65 #[serde(rename = "usOut")]
66 us_out: u64,
67 #[serde(rename = "usDiff")]
68 us_diff: u64,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72struct RpcOkResponse {
73 #[serde(flatten)]
74 base: RpcResponseBase,
75 result: Value,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct RpcErrorResponse {
80 #[serde(flatten)]
81 base: RpcResponseBase,
82 error: RpcError,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86struct SubscriptionParams {
87 channel: String,
88 data: Value,
89 label: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93enum SubscriptionMethod {
94 #[serde(rename = "subscription")]
95 Subscription,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99struct SubscriptionNotification {
100 jsonrpc: JsonRpcVersion,
101 method: SubscriptionMethod,
102 params: SubscriptionParams,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106enum HeartbeatType {
107 #[serde(rename = "heartbeat")]
108 Heartbeat,
109 #[serde(rename = "test_request")]
110 TestRequest,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114struct HeartbeatParams {
115 r#type: HeartbeatType,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119enum HeartbeatMethod {
120 #[serde(rename = "heartbeat")]
121 Heartbeat,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125struct Heartbeat {
126 jsonrpc: JsonRpcVersion,
127 method: HeartbeatMethod,
128 params: HeartbeatParams,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132#[serde(untagged)]
133enum JsonRPCMessage {
134 Heartbeat(Heartbeat),
135 Notification(SubscriptionNotification),
136 OkResponse(RpcOkResponse),
137 ErrorResponse(RpcErrorResponse),
138}
139
140#[derive(Debug, thiserror::Error)]
141pub enum Error {
142 #[error("RPC error: {0}")]
143 RpcError(RpcError),
144 #[error("WebSocket error: {0}")]
145 WebSocketError(#[from] WSError),
146 #[error("JSON decode error: {0}")]
147 JsonError(#[from] serde_json::Error),
148 #[error("Invalid subscription channel: {0}")]
149 InvalidSubscriptionChannel(String),
150 #[error("Subscription messages lagged: {0}")]
151 SubscriptionLagged(u64),
152}
153
154type Result<T> = std::result::Result<T, Error>;
155
156pub trait ApiRequest: serde::Serialize {
158 type Response: DeserializeOwned + Serialize;
159 fn method_name(&self) -> &'static str;
160
161 fn is_private(&self) -> bool {
162 self.method_name().starts_with("private/")
163 }
164
165 fn to_params(&self) -> Value {
166 serde_json::to_value(self).unwrap_or_default()
167 }
168}
169
170pub trait Subscription {
172 type Data: DeserializeOwned + Serialize + Send + 'static;
173 fn channel_string(&self) -> String;
174}
175
176pub(crate) fn sub_param_to_string<T: Serialize>(value: &T) -> String {
178 let json = serde_json::to_value(value).unwrap_or(Value::Null);
179 match json {
180 Value::String(s) => s,
181 Value::Number(n) => n.to_string(),
182 Value::Bool(b) => b.to_string(),
183 _ => json.to_string(),
184 }
185}
186
/// Which Deribit deployment [`DeribitClient::connect`] talks to.
///
/// A plain fieldless selector, so it is `Copy` and comparable. This lets callers
/// pass it around freely, store it in config structs, or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Env {
    /// The live exchange at `www.deribit.com`.
    Production,
    /// The test exchange at `test.deribit.com`.
    Testnet,
}
192
193#[derive(Debug)]
194pub struct DeribitClient {
195 authenticated: AtomicBool,
196 id_counter: Arc<AtomicU64>,
197 request_channel: mpsc::Sender<(RpcRequest, oneshot::Sender<Result<Value>>)>,
198 subscription_channel: mpsc::Sender<(String, oneshot::Sender<broadcast::Receiver<Value>>)>,
199}
200
201impl DeribitClient {
202 pub async fn connect(env: Env) -> Result<Self> {
203 let ws_url = match env {
204 Env::Production => "wss://www.deribit.com/ws/api/v2",
205 Env::Testnet => "wss://test.deribit.com/ws/api/v2",
206 };
207
208 let (mut ws_stream, _) = connect_async(ws_url).await?;
209 let (request_tx, mut request_rx) =
210 mpsc::channel::<(RpcRequest, oneshot::Sender<Result<Value>>)>(100);
211 let (subscription_tx, mut subscription_rx) =
212 mpsc::channel::<(String, oneshot::Sender<broadcast::Receiver<Value>>)>(100);
213
214 let id_counter = Arc::new(AtomicU64::new(0));
215 let id_counter_clone = id_counter.clone();
216
217 tokio::spawn(async move {
218 let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value>>> = HashMap::new();
219 let mut subscribers: HashMap<String, broadcast::Sender<Value>> = HashMap::new();
220
221 loop {
222 tokio::select! {
223 msg = ws_stream.next() => {
224 match msg {
225 Some(Ok(Message::Text(text))) => {
226 match serde_json::from_str::<JsonRPCMessage>(&text) {
227 Ok(JsonRPCMessage::Heartbeat(heartbeat)) => {
228 if heartbeat.params.r#type == HeartbeatType::TestRequest {
229 let test_request = RpcRequest {
230 jsonrpc: JsonRpcVersion::V2,
231 id: id_counter_clone.fetch_add(1, Ordering::Relaxed),
232 method: "public/test".to_string(),
233 params: Value::Null,
234 };
235 ws_stream
236 .send(Message::Text(
237 serde_json::to_string(&test_request).unwrap().into(),
238 ))
239 .await
240 .unwrap();
241 }
242 }
243 Ok(JsonRPCMessage::Notification(notification)) => {
244 if let Some(tx) = subscribers.get(¬ification.params.channel)
245 && tx.send(notification.params.data.clone()).is_err()
246 {
247 subscribers.remove(¬ification.params.channel);
248 }
249 }
250 Ok(JsonRPCMessage::OkResponse(response)) => {
251 let result = Ok(response.result);
252 if let Some(tx) = pending_requests.remove(&response.base.id) {
253 let _ = tx.send(result);
254 }
255 }
256 Ok(JsonRPCMessage::ErrorResponse(response)) => {
257 let error = Err(Error::RpcError(response.error));
258 if let Some(tx) = pending_requests.remove(&response.base.id) {
259 let _ = tx.send(error);
260 }
261 }
262 Err(e) => {
263 panic!("Received invalid json message: {e}\nOriginal message: {text}");
264 }
265 }
266 }
267 Some(Ok(msg)) => {
268 panic!("Received non-text message: {msg:?}");
269 }
270 Some(Err(e)) => {
271 panic!("WebSocket error: {e:?}");
272 }
273 None => {
274 panic!("WebSocket connection closed");
275 }
276 }
277 }
278 Some((request, tx)) = request_rx.recv() => {
279 pending_requests.insert(request.id, tx);
280 ws_stream
281 .send(Message::Text(
282 serde_json::to_string(&request).unwrap().into(),
283 ))
284 .await
285 .unwrap();
286 }
287 Some((channel, oneshot_tx)) = subscription_rx.recv() => {
288 if let Some(broadcast_tx) = subscribers.get(&channel) {
289 let _ = oneshot_tx.send(broadcast_tx.subscribe());
290 } else {
291 let (broadcast_tx, broadcast_rx) = broadcast::channel(100);
292 subscribers.insert(channel, broadcast_tx);
293 let _ = oneshot_tx.send(broadcast_rx);
294 }
295 }
296 }
297 }
298 });
299
300 Ok(Self {
301 authenticated: AtomicBool::new(false),
302 id_counter,
303 request_channel: request_tx,
304 subscription_channel: subscription_tx,
305 })
306 }
307
308 fn next_id(&self) -> u64 {
309 self.id_counter.fetch_add(1, Ordering::Relaxed)
310 }
311
312 pub async fn call_raw(&self, method: &str, params: Value) -> Result<Value> {
313 let request = RpcRequest {
314 jsonrpc: JsonRpcVersion::V2,
315 id: self.next_id(),
316 method: method.to_string(),
317 params,
318 };
319
320 let (tx, rx) = oneshot::channel();
321
322 self.request_channel
323 .send((request, tx))
324 .await
325 .map_err(|_| WSError::ConnectionClosed)?;
326
327 let value = rx.await.map_err(|_| WSError::ConnectionClosed)??;
328
329 if method == "public/auth" {
330 self.authenticated.store(true, Ordering::Release);
331 }
332
333 Ok(value)
334 }
335
336 pub async fn call<T: ApiRequest>(&self, req: T) -> Result<T::Response> {
337 let value = self.call_raw(req.method_name(), req.to_params()).await?;
338 let typed: T::Response = serde_json::from_value(value)?;
339 Ok(typed)
340 }
341
342 pub async fn subscribe_raw(
343 &self,
344 channel: &str,
345 ) -> Result<impl Stream<Item = Result<Value>> + Send + 'static + use<>> {
346 let channels = vec![channel.to_string()];
347 let subscribed_channels = if self.authenticated.load(Ordering::Acquire) {
348 self.call(PrivateSubscribeRequest {
349 channels,
350 label: None,
351 })
352 .await?
353 } else {
354 self.call(PublicSubscribeRequest { channels }).await?
355 };
356 if let Some(channel) = subscribed_channels.first() {
357 let (tx, rx) = oneshot::channel();
358 self.subscription_channel
359 .send((channel.clone(), tx))
360 .await
361 .map_err(|_| WSError::ConnectionClosed)?;
362 let channel_rx = rx.await.map_err(|_| WSError::ConnectionClosed)?;
363 let stream = BroadcastStream::new(channel_rx).map(|msg| match msg {
364 Ok(msg) => Ok(msg),
365 Err(BroadcastStreamRecvError::Lagged(lag)) => Err(Error::SubscriptionLagged(lag)),
366 });
367 Ok(stream)
368 } else {
369 Err(Error::InvalidSubscriptionChannel(channel.to_string()))
370 }
371 }
372
373 pub async fn subscribe<S: Subscription + Send + 'static>(
375 &self,
376 subscription: S,
377 ) -> Result<impl Stream<Item = Result<S::Data>> + Send + 'static> {
378 let channel = subscription.channel_string();
379 let raw_stream = self.subscribe_raw(&channel).await?;
380 let typed_stream = raw_stream.map(|msg| match msg {
381 Ok(msg) => serde_json::from_value::<S::Data>(msg).map_err(Error::JsonError),
382 Err(e) => Err(e),
383 });
384 Ok(typed_stream)
385 }
386}