deribit_api/
lib.rs

1use futures_util::{SinkExt, Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
8use tokio::sync::{broadcast, mpsc, oneshot};
9use tokio_stream::wrappers::BroadcastStream;
10use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
11use tokio_tungstenite::connect_async;
12use tokio_tungstenite::tungstenite::Error as WSError;
13use tokio_tungstenite::tungstenite::Message;
14
15// Include the generated client code
16pub mod prod {
17    use serde::{Deserialize, Serialize};
18    use serde_json::Value;
19    include!(concat!(env!("OUT_DIR"), "/deribit_client_prod.rs"));
20}
21
/// Client code generated at build time for the testnet environment.
///
/// Only compiled when the `testnet` cargo feature is enabled. The contents
/// come from `deribit_client_testnet.rs` in `OUT_DIR`.
#[cfg(feature = "testnet")]
pub mod testnet {
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    include!(concat!(env!("OUT_DIR"), "/deribit_client_testnet.rs"));
}
28
29// Default to prod at crate root
30pub use prod::*;
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct RpcError {
33    pub code: i32,
34    pub message: String,
35    pub data: Option<Value>,
36}
37
38impl std::fmt::Display for RpcError {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "RPC Error {}: {}", self.code, self.message)
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45enum JsonRpcVersion {
46    #[serde(rename = "2.0")]
47    V2,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51struct RpcRequest {
52    jsonrpc: JsonRpcVersion,
53    id: u64,
54    method: String,
55    params: Value,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct RpcResponseBase {
60    jsonrpc: JsonRpcVersion,
61    id: u64,
62    testnet: bool,
63    #[serde(rename = "usIn")]
64    us_in: u64,
65    #[serde(rename = "usOut")]
66    us_out: u64,
67    #[serde(rename = "usDiff")]
68    us_diff: u64,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72struct RpcOkResponse {
73    #[serde(flatten)]
74    base: RpcResponseBase,
75    result: Value,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct RpcErrorResponse {
80    #[serde(flatten)]
81    base: RpcResponseBase,
82    error: RpcError,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86struct SubscriptionParams {
87    channel: String,
88    data: Value,
89    label: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93enum SubscriptionMethod {
94    #[serde(rename = "subscription")]
95    Subscription,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99struct SubscriptionNotification {
100    jsonrpc: JsonRpcVersion,
101    method: SubscriptionMethod,
102    params: SubscriptionParams,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106enum HeartbeatType {
107    #[serde(rename = "heartbeat")]
108    Heartbeat,
109    #[serde(rename = "test_request")]
110    TestRequest,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114struct HeartbeatParams {
115    r#type: HeartbeatType,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119enum HeartbeatMethod {
120    #[serde(rename = "heartbeat")]
121    Heartbeat,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125struct Heartbeat {
126    jsonrpc: JsonRpcVersion,
127    method: HeartbeatMethod,
128    params: HeartbeatParams,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132#[serde(untagged)]
133enum JsonRPCMessage {
134    Heartbeat(Heartbeat),
135    Notification(SubscriptionNotification),
136    OkResponse(RpcOkResponse),
137    ErrorResponse(RpcErrorResponse),
138}
139
140#[derive(Debug, thiserror::Error)]
141pub enum Error {
142    #[error("RPC error: {0}")]
143    RpcError(RpcError),
144    #[error("WebSocket error: {0}")]
145    WebSocketError(#[from] WSError),
146    #[error("JSON decode error: {0}")]
147    JsonError(#[from] serde_json::Error),
148    #[error("Invalid subscription channel: {0}")]
149    InvalidSubscriptionChannel(String),
150    #[error("Subscription messages lagged: {0}")]
151    SubscriptionLagged(u64),
152}
153
154type Result<T> = std::result::Result<T, Error>;
155
156// ApiRequest trait for all request types
157pub trait ApiRequest: serde::Serialize {
158    type Response: DeserializeOwned + Serialize;
159    fn method_name(&self) -> &'static str;
160
161    fn is_private(&self) -> bool {
162        self.method_name().starts_with("private/")
163    }
164
165    fn to_params(&self) -> Value {
166        serde_json::to_value(self).unwrap_or_default()
167    }
168}
169
170// Subscription trait implemented by generated channel structs
171pub trait Subscription {
172    type Data: DeserializeOwned + Serialize + Send + 'static;
173    fn channel_string(&self) -> String;
174}
175
176// Helper used by generated code to stringify subscription path parameters
177pub(crate) fn sub_param_to_string<T: Serialize>(value: &T) -> String {
178    let json = serde_json::to_value(value).unwrap_or(Value::Null);
179    match json {
180        Value::String(s) => s,
181        Value::Number(n) => n.to_string(),
182        Value::Bool(b) => b.to_string(),
183        _ => json.to_string(),
184    }
185}
186
/// Deribit environment to connect to.
///
/// The type is a fieldless selector, so it is `Copy` and comparable; callers
/// can keep using a value after passing it to `DeribitClient::connect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Env {
    /// The live exchange.
    Production,
    /// The test exchange.
    Testnet,
}
192
193#[derive(Debug)]
194pub struct DeribitClient {
195    authenticated: AtomicBool,
196    id_counter: Arc<AtomicU64>,
197    request_channel: mpsc::Sender<(RpcRequest, oneshot::Sender<Result<Value>>)>,
198    subscription_channel: mpsc::Sender<(String, oneshot::Sender<broadcast::Receiver<Value>>)>,
199}
200
201impl DeribitClient {
202    pub async fn connect(env: Env) -> Result<Self> {
203        let ws_url = match env {
204            Env::Production => "wss://www.deribit.com/ws/api/v2",
205            Env::Testnet => "wss://test.deribit.com/ws/api/v2",
206        };
207
208        let (mut ws_stream, _) = connect_async(ws_url).await?;
209        let (request_tx, mut request_rx) =
210            mpsc::channel::<(RpcRequest, oneshot::Sender<Result<Value>>)>(100);
211        let (subscription_tx, mut subscription_rx) =
212            mpsc::channel::<(String, oneshot::Sender<broadcast::Receiver<Value>>)>(100);
213
214        let id_counter = Arc::new(AtomicU64::new(0));
215        let id_counter_clone = id_counter.clone();
216
217        tokio::spawn(async move {
218            let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value>>> = HashMap::new();
219            let mut subscribers: HashMap<String, broadcast::Sender<Value>> = HashMap::new();
220
221            loop {
222                tokio::select! {
223                    msg = ws_stream.next() => {
224                        match msg {
225                            Some(Ok(Message::Text(text))) => {
226                                match serde_json::from_str::<JsonRPCMessage>(&text) {
227                                    Ok(JsonRPCMessage::Heartbeat(heartbeat)) => {
228                                        if heartbeat.params.r#type == HeartbeatType::TestRequest {
229                                            let test_request = RpcRequest {
230                                                jsonrpc: JsonRpcVersion::V2,
231                                                id: id_counter_clone.fetch_add(1, Ordering::Relaxed),
232                                                method: "public/test".to_string(),
233                                                params: Value::Null,
234                                            };
235                                            ws_stream
236                                                .send(Message::Text(
237                                                    serde_json::to_string(&test_request).unwrap().into(),
238                                                ))
239                                                .await
240                                                .unwrap();
241                                        }
242                                    }
243                                    Ok(JsonRPCMessage::Notification(notification)) => {
244                                        if let Some(tx) = subscribers.get(&notification.params.channel)
245                                            && tx.send(notification.params.data.clone()).is_err()
246                                        {
247                                            subscribers.remove(&notification.params.channel);
248                                        }
249                                    }
250                                    Ok(JsonRPCMessage::OkResponse(response)) => {
251                                        let result = Ok(response.result);
252                                        if let Some(tx) = pending_requests.remove(&response.base.id) {
253                                            let _ = tx.send(result);
254                                        }
255                                    }
256                                    Ok(JsonRPCMessage::ErrorResponse(response)) => {
257                                        let error = Err(Error::RpcError(response.error));
258                                        if let Some(tx) = pending_requests.remove(&response.base.id) {
259                                            let _ = tx.send(error);
260                                        }
261                                    }
262                                    Err(e) => {
263                                        panic!("Received invalid json message: {e}\nOriginal message: {text}");
264                                    }
265                                }
266                            }
267                            Some(Ok(msg)) => {
268                                panic!("Received non-text message: {msg:?}");
269                            }
270                            Some(Err(e)) => {
271                                panic!("WebSocket error: {e:?}");
272                            }
273                            None => {
274                                panic!("WebSocket connection closed");
275                            }
276                        }
277                    }
278                    Some((request, tx)) = request_rx.recv() => {
279                        pending_requests.insert(request.id, tx);
280                        ws_stream
281                            .send(Message::Text(
282                                serde_json::to_string(&request).unwrap().into(),
283                            ))
284                            .await
285                            .unwrap();
286                    }
287                    Some((channel, oneshot_tx)) = subscription_rx.recv() => {
288                        if let Some(broadcast_tx) = subscribers.get(&channel) {
289                            let _ = oneshot_tx.send(broadcast_tx.subscribe());
290                        } else {
291                            let (broadcast_tx, broadcast_rx) = broadcast::channel(100);
292                            subscribers.insert(channel, broadcast_tx);
293                            let _ = oneshot_tx.send(broadcast_rx);
294                        }
295                    }
296                }
297            }
298        });
299
300        Ok(Self {
301            authenticated: AtomicBool::new(false),
302            id_counter,
303            request_channel: request_tx,
304            subscription_channel: subscription_tx,
305        })
306    }
307
308    fn next_id(&self) -> u64 {
309        self.id_counter.fetch_add(1, Ordering::Relaxed)
310    }
311
312    pub async fn call_raw(&self, method: &str, params: Value) -> Result<Value> {
313        let request = RpcRequest {
314            jsonrpc: JsonRpcVersion::V2,
315            id: self.next_id(),
316            method: method.to_string(),
317            params,
318        };
319
320        let (tx, rx) = oneshot::channel();
321
322        self.request_channel
323            .send((request, tx))
324            .await
325            .map_err(|_| WSError::ConnectionClosed)?;
326
327        let value = rx.await.map_err(|_| WSError::ConnectionClosed)??;
328
329        if method == "public/auth" {
330            self.authenticated.store(true, Ordering::Release);
331        }
332
333        Ok(value)
334    }
335
336    pub async fn call<T: ApiRequest>(&self, req: T) -> Result<T::Response> {
337        let value = self.call_raw(req.method_name(), req.to_params()).await?;
338        let typed: T::Response = serde_json::from_value(value)?;
339        Ok(typed)
340    }
341
342    pub async fn subscribe_raw(
343        &self,
344        channel: &str,
345    ) -> Result<impl Stream<Item = Result<Value>> + Send + 'static + use<>> {
346        let channels = vec![channel.to_string()];
347        let subscribed_channels = if self.authenticated.load(Ordering::Acquire) {
348            self.call(PrivateSubscribeRequest {
349                channels,
350                label: None,
351            })
352            .await?
353        } else {
354            self.call(PublicSubscribeRequest { channels }).await?
355        };
356        if let Some(channel) = subscribed_channels.first() {
357            let (tx, rx) = oneshot::channel();
358            self.subscription_channel
359                .send((channel.clone(), tx))
360                .await
361                .map_err(|_| WSError::ConnectionClosed)?;
362            let channel_rx = rx.await.map_err(|_| WSError::ConnectionClosed)?;
363            let stream = BroadcastStream::new(channel_rx).map(|msg| match msg {
364                Ok(msg) => Ok(msg),
365                Err(BroadcastStreamRecvError::Lagged(lag)) => Err(Error::SubscriptionLagged(lag)),
366            });
367            Ok(stream)
368        } else {
369            Err(Error::InvalidSubscriptionChannel(channel.to_string()))
370        }
371    }
372
373    // Typed subscription: accepts a generated Subscription and returns a typed broadcast receiver
374    pub async fn subscribe<S: Subscription + Send + 'static>(
375        &self,
376        subscription: S,
377    ) -> Result<impl Stream<Item = Result<S::Data>> + Send + 'static> {
378        let channel = subscription.channel_string();
379        let raw_stream = self.subscribe_raw(&channel).await?;
380        let typed_stream = raw_stream.map(|msg| match msg {
381            Ok(msg) => serde_json::from_value::<S::Data>(msg).map_err(Error::JsonError),
382            Err(e) => Err(e),
383        });
384        Ok(typed_stream)
385    }
386}