deribit_api/
lib.rs

1use futures_util::{SinkExt, Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tokio_stream::wrappers::BroadcastStream;
9use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
10use tokio_tungstenite::connect_async;
11use tokio_tungstenite::tungstenite::Error as WSError;
12use tokio_tungstenite::tungstenite::Message;
13
14// Include the generated client code
15pub mod prod {
16    use serde::{Deserialize, Serialize};
17    use serde_json::Value;
18    include!(concat!(env!("OUT_DIR"), "/deribit_client_prod.rs"));
19}
20
/// Client code generated at build time for the testnet variant of the API.
///
/// Only compiled when the `testnet` cargo feature is enabled. The body is pulled in verbatim
/// from `$OUT_DIR/deribit_client_testnet.rs`.
#[cfg(feature = "testnet")]
pub mod testnet {
    use serde_json::Value;
    use serde::{Deserialize, Serialize};
    include!(concat!(env!("OUT_DIR"), "/deribit_client_testnet.rs"));
}
27
28// Default to prod at crate root
29pub use prod::*;
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct RpcError {
32    pub code: i32,
33    pub message: String,
34    pub data: Option<Value>,
35}
36
37impl std::fmt::Display for RpcError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "RPC Error {}: {}", self.code, self.message)
40    }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44struct RpcRequest {
45    jsonrpc: String,
46    id: u64,
47    method: String,
48    params: Value,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52struct RpcResponseBase {
53    jsonrpc: String,
54    id: u64,
55    testnet: bool,
56    #[serde(rename = "usIn")]
57    us_in: u64,
58    #[serde(rename = "usOut")]
59    us_out: u64,
60    #[serde(rename = "usDiff")]
61    us_diff: u64,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65struct RpcOkResponse {
66    #[serde(flatten)]
67    base: RpcResponseBase,
68    result: Value,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72struct RpcErrorResponse {
73    #[serde(flatten)]
74    base: RpcResponseBase,
75    error: RpcError,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct SubscriptionParams {
80    channel: String,
81    data: Value,
82    label: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86struct SubscriptionNotification {
87    jsonrpc: String,
88    method: String,
89    params: SubscriptionParams,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(untagged)]
94enum JsonRPCMessage {
95    Notification(SubscriptionNotification),
96    OkResponse(RpcOkResponse),
97    ErrorResponse(RpcErrorResponse),
98}
99
100#[derive(Debug, thiserror::Error)]
101pub enum Error {
102    #[error("RPC error: {0}")]
103    RpcError(RpcError),
104    #[error("WebSocket error: {0}")]
105    WebSocketError(#[from] WSError),
106    #[error("JSON decode error: {0}")]
107    JsonError(#[from] serde_json::Error),
108    #[error("Invalid subscription channel: {0}")]
109    InvalidSubscriptionChannel(String),
110    #[error("Subscription messages lagged: {0}")]
111    SubscriptionLagged(u64),
112}
113
114pub type Result<T> = std::result::Result<T, Error>;
115
116// ApiRequest trait for all request types
117pub trait ApiRequest: serde::Serialize {
118    type Response: DeserializeOwned + Serialize;
119    fn method_name(&self) -> &'static str;
120
121    fn is_private(&self) -> bool {
122        self.method_name().starts_with("private/")
123    }
124
125    fn to_params(&self) -> Value {
126        serde_json::to_value(self).unwrap_or_default()
127    }
128}
129
130// Subscription trait implemented by generated channel structs
131pub trait Subscription {
132    type Data: DeserializeOwned + Serialize + Send + 'static;
133    fn channel_string(&self) -> String;
134}
135
136// Helper used by generated code to stringify subscription path parameters
137pub(crate) fn sub_param_to_string<T: Serialize>(value: &T) -> String {
138    // For enums and basic scalars we can serialize to JSON and strip quotes if needed
139    // This keeps parity with channel naming used by Deribit (e.g., "BTC-...", "agg2", numbers)
140    let json = serde_json::to_value(value).unwrap_or(Value::Null);
141    match json {
142        Value::String(s) => s,
143        Value::Number(n) => n.to_string(),
144        Value::Bool(b) => b.to_string(),
145        // Fallback: compact JSON string
146        _ => json.to_string(),
147    }
148}
149
/// Deribit environment a client connects to.
///
/// The enum is a plain, fieldless selector, so it is cheap to copy and can be compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Env {
    /// The live production environment.
    Production,
    /// The test environment.
    Testnet,
}
155
156#[derive(Debug)]
157pub struct DeribitClient {
158    authenticated: AtomicBool,
159    id_counter: AtomicU64,
160    request_channel: mpsc::Sender<(RpcRequest, oneshot::Sender<Result<Value>>)>,
161    subscription_channel: mpsc::Sender<(String, oneshot::Sender<broadcast::Receiver<Value>>)>,
162    // typed_subscription_channel: mpsc::Sender<(String, oneshot::Sender<broadcast::Receiver<Value>>)>,
163}
164
165impl DeribitClient {
166    pub async fn connect(env: Env) -> Result<Self> {
167        let ws_url = match env {
168            Env::Production => "wss://www.deribit.com/ws/api/v2",
169            Env::Testnet => "wss://test.deribit.com/ws/api/v2",
170        };
171
172        let (mut ws_stream, _) = connect_async(ws_url).await?;
173        let (request_tx, mut request_rx) =
174            mpsc::channel::<(RpcRequest, oneshot::Sender<Result<Value>>)>(100);
175        let (subscription_tx, mut subscription_rx) =
176            mpsc::channel::<(String, oneshot::Sender<broadcast::Receiver<Value>>)>(100);
177
178        tokio::spawn(async move {
179            let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value>>> = HashMap::new();
180            let mut subscribers: HashMap<String, broadcast::Sender<Value>> = HashMap::new();
181
182            loop {
183                tokio::select! {
184                    msg = ws_stream.next() => {
185                        match msg {
186                            Some(Ok(Message::Text(text))) => {
187                                match serde_json::from_str::<JsonRPCMessage>(&text) {
188                                    Ok(JsonRPCMessage::Notification(notification)) => {
189                                        if let Some(tx) = subscribers.get(&notification.params.channel)
190                                            && tx.send(notification.params.data.clone()).is_err()
191                                        {
192                                            subscribers.remove(&notification.params.channel);
193                                        }
194                                    }
195                                    Ok(JsonRPCMessage::OkResponse(response)) => {
196                                        let result = Ok(response.result);
197                                        if let Some(tx) = pending_requests.remove(&response.base.id) {
198                                            let _ = tx.send(result);
199                                        }
200                                    }
201                                    Ok(JsonRPCMessage::ErrorResponse(response)) => {
202                                        let error = Err(Error::RpcError(response.error));
203                                        if let Some(tx) = pending_requests.remove(&response.base.id) {
204                                            let _ = tx.send(error);
205                                        }
206                                    }
207                                    Err(e) => {
208                                        panic!("Received invalid json message: {e}");
209                                    }
210                                }
211                            }
212                            Some(Ok(msg)) => {
213                                panic!("Received non-text message: {msg:?}");
214                            }
215                            Some(Err(e)) => {
216                                panic!("WebSocket error: {e:?}");
217                            }
218                            None => {
219                                panic!("WebSocket connection closed");
220                            }
221                        }
222                    }
223                    Some((request, tx)) = request_rx.recv() => {
224                        pending_requests.insert(request.id, tx);
225                        ws_stream
226                            .send(Message::Text(
227                                serde_json::to_string(&request).unwrap().into(),
228                            ))
229                            .await
230                            .unwrap();
231                    }
232                    Some((channel, oneshot_tx)) = subscription_rx.recv() => {
233                        if let Some(broadcast_tx) = subscribers.get(&channel) {
234                            let _ = oneshot_tx.send(broadcast_tx.subscribe());
235                        } else {
236                            let (broadcast_tx, broadcast_rx) = broadcast::channel(100);
237                            subscribers.insert(channel, broadcast_tx);
238                            let _ = oneshot_tx.send(broadcast_rx);
239                        }
240                    }
241                }
242            }
243        });
244
245        Ok(Self {
246            authenticated: AtomicBool::new(false),
247            id_counter: AtomicU64::new(0),
248            request_channel: request_tx,
249            subscription_channel: subscription_tx,
250            // typed_subscription_channel: typed_subscription_tx,
251        })
252    }
253
254    fn next_id(&self) -> u64 {
255        self.id_counter.fetch_add(1, Ordering::Relaxed)
256    }
257
258    pub async fn call_raw(&self, method: &str, params: Value) -> Result<Value> {
259        let request = RpcRequest {
260            jsonrpc: "2.0".to_string(),
261            id: self.next_id(),
262            method: method.to_string(),
263            params,
264        };
265
266        let (tx, rx) = oneshot::channel();
267
268        self.request_channel
269            .send((request, tx))
270            .await
271            .map_err(|_| WSError::ConnectionClosed)?;
272
273        let value = rx.await.map_err(|_| WSError::ConnectionClosed)??;
274
275        if method == "public/auth" {
276            self.authenticated.store(true, Ordering::Release);
277        }
278
279        Ok(value)
280    }
281
282    pub async fn call<T: ApiRequest>(&self, req: T) -> Result<T::Response> {
283        let value = self.call_raw(req.method_name(), req.to_params()).await?;
284        let typed: T::Response = serde_json::from_value(value)?;
285        Ok(typed)
286    }
287
288    pub async fn subscribe_raw(
289        &self,
290        channel: &str,
291    ) -> Result<impl Stream<Item = Result<Value>> + Send + 'static + use<>> {
292        let channels = vec![channel.to_string()];
293        let subscribed_channels = if self.authenticated.load(Ordering::Acquire) {
294            self.call(PrivateSubscribeRequest {
295                channels,
296                label: None,
297            })
298            .await?
299        } else {
300            self.call(PublicSubscribeRequest { channels }).await?
301        };
302        if let Some(channel) = subscribed_channels.first() {
303            let (tx, rx) = oneshot::channel();
304            self.subscription_channel
305                .send((channel.clone(), tx))
306                .await
307                .map_err(|_| WSError::ConnectionClosed)?;
308            let channel_rx = rx.await.map_err(|_| WSError::ConnectionClosed)?;
309            let stream = BroadcastStream::new(channel_rx).map(|msg| match msg {
310                Ok(msg) => Ok(msg),
311                Err(BroadcastStreamRecvError::Lagged(lag)) => Err(Error::SubscriptionLagged(lag)),
312            });
313            Ok(stream)
314        } else {
315            Err(Error::InvalidSubscriptionChannel(channel.to_string()))
316        }
317    }
318
319    // Typed subscription: accepts a generated Subscription and returns a typed broadcast receiver
320    pub async fn subscribe<S: Subscription + Send + 'static>(
321        &self,
322        subscription: S,
323    ) -> Result<impl Stream<Item = Result<S::Data>> + Send + 'static> {
324        let channel = subscription.channel_string();
325        let raw_stream = self.subscribe_raw(&channel).await?;
326        let typed_stream = raw_stream.map(|msg| match msg {
327            Ok(msg) => serde_json::from_value::<S::Data>(msg).map_err(Error::JsonError),
328            Err(e) => Err(e),
329        });
330        Ok(typed_stream)
331    }
332}