1use futures_util::{SinkExt, Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tokio_stream::wrappers::BroadcastStream;
9use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
10use tokio_tungstenite::connect_async;
11use tokio_tungstenite::tungstenite::Error as WSError;
12use tokio_tungstenite::tungstenite::Message;
13
14pub mod prod {
16 use serde::{Deserialize, Serialize};
17 use serde_json::Value;
18 include!(concat!(env!("OUT_DIR"), "/deribit_client_prod.rs"));
19}
20
21#[cfg(feature = "testnet")]
22pub mod testnet {
23 use serde::{Deserialize, Serialize};
24 use serde_json::Value;
25 include!(concat!(env!("OUT_DIR"), "/deribit_client_testnet.rs"));
26}
27
28pub use prod::*;
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct RpcError {
32 pub code: i32,
33 pub message: String,
34 pub data: Option<Value>,
35}
36
37impl std::fmt::Display for RpcError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "RPC Error {}: {}", self.code, self.message)
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44struct RpcRequest {
45 jsonrpc: String,
46 id: u64,
47 method: String,
48 params: Value,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52struct RpcResponseBase {
53 jsonrpc: String,
54 id: u64,
55 testnet: bool,
56 #[serde(rename = "usIn")]
57 us_in: u64,
58 #[serde(rename = "usOut")]
59 us_out: u64,
60 #[serde(rename = "usDiff")]
61 us_diff: u64,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65struct RpcOkResponse {
66 #[serde(flatten)]
67 base: RpcResponseBase,
68 result: Value,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72struct RpcErrorResponse {
73 #[serde(flatten)]
74 base: RpcResponseBase,
75 error: RpcError,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct SubscriptionParams {
80 channel: String,
81 data: Value,
82 label: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86struct SubscriptionNotification {
87 jsonrpc: String,
88 method: String,
89 params: SubscriptionParams,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(untagged)]
94enum JsonRPCMessage {
95 Notification(SubscriptionNotification),
96 OkResponse(RpcOkResponse),
97 ErrorResponse(RpcErrorResponse),
98}
99
100#[derive(Debug, thiserror::Error)]
101pub enum Error {
102 #[error("RPC error: {0}")]
103 RpcError(RpcError),
104 #[error("WebSocket error: {0}")]
105 WebSocketError(#[from] WSError),
106 #[error("JSON decode error: {0}")]
107 JsonError(#[from] serde_json::Error),
108 #[error("Invalid subscription channel: {0}")]
109 InvalidSubscriptionChannel(String),
110 #[error("Subscription messages lagged: {0}")]
111 SubscriptionLagged(u64),
112}
113
114pub type Result<T> = std::result::Result<T, Error>;
115
116pub trait ApiRequest: serde::Serialize {
118 type Response: DeserializeOwned + Serialize;
119 fn method_name(&self) -> &'static str;
120
121 fn is_private(&self) -> bool {
122 self.method_name().starts_with("private/")
123 }
124
125 fn to_params(&self) -> Value {
126 serde_json::to_value(self).unwrap_or_default()
127 }
128}
129
130pub trait Subscription {
132 type Data: DeserializeOwned + Serialize + Send + 'static;
133 fn channel_string(&self) -> String;
134}
135
136pub(crate) fn sub_param_to_string<T: Serialize>(value: &T) -> String {
138 let json = serde_json::to_value(value).unwrap_or(Value::Null);
141 match json {
142 Value::String(s) => s,
143 Value::Number(n) => n.to_string(),
144 Value::Bool(b) => b.to_string(),
145 _ => json.to_string(),
147 }
148}
149
/// Which Deribit deployment [`DeribitClient::connect`] talks to.
///
/// A plain fieldless selector, so it is `Copy` and comparable: callers can keep it in their
/// configuration, reuse it for several connections, and test it with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Env {
    /// `wss://www.deribit.com/ws/api/v2`
    Production,
    /// `wss://test.deribit.com/ws/api/v2`
    Testnet,
}
155
156#[derive(Debug)]
157pub struct DeribitClient {
158 authenticated: AtomicBool,
159 id_counter: AtomicU64,
160 request_channel: mpsc::Sender<(RpcRequest, oneshot::Sender<Result<Value>>)>,
161 subscription_channel: mpsc::Sender<(String, oneshot::Sender<broadcast::Receiver<Value>>)>,
162 }
164
165impl DeribitClient {
166 pub async fn connect(env: Env) -> Result<Self> {
167 let ws_url = match env {
168 Env::Production => "wss://www.deribit.com/ws/api/v2",
169 Env::Testnet => "wss://test.deribit.com/ws/api/v2",
170 };
171
172 let (mut ws_stream, _) = connect_async(ws_url).await?;
173 let (request_tx, mut request_rx) =
174 mpsc::channel::<(RpcRequest, oneshot::Sender<Result<Value>>)>(100);
175 let (subscription_tx, mut subscription_rx) =
176 mpsc::channel::<(String, oneshot::Sender<broadcast::Receiver<Value>>)>(100);
177
178 tokio::spawn(async move {
179 let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value>>> = HashMap::new();
180 let mut subscribers: HashMap<String, broadcast::Sender<Value>> = HashMap::new();
181
182 loop {
183 tokio::select! {
184 msg = ws_stream.next() => {
185 match msg {
186 Some(Ok(Message::Text(text))) => {
187 match serde_json::from_str::<JsonRPCMessage>(&text) {
188 Ok(JsonRPCMessage::Notification(notification)) => {
189 if let Some(tx) = subscribers.get(¬ification.params.channel)
190 && tx.send(notification.params.data.clone()).is_err()
191 {
192 subscribers.remove(¬ification.params.channel);
193 }
194 }
195 Ok(JsonRPCMessage::OkResponse(response)) => {
196 let result = Ok(response.result);
197 if let Some(tx) = pending_requests.remove(&response.base.id) {
198 let _ = tx.send(result);
199 }
200 }
201 Ok(JsonRPCMessage::ErrorResponse(response)) => {
202 let error = Err(Error::RpcError(response.error));
203 if let Some(tx) = pending_requests.remove(&response.base.id) {
204 let _ = tx.send(error);
205 }
206 }
207 Err(e) => {
208 panic!("Received invalid json message: {e}");
209 }
210 }
211 }
212 Some(Ok(msg)) => {
213 panic!("Received non-text message: {msg:?}");
214 }
215 Some(Err(e)) => {
216 panic!("WebSocket error: {e:?}");
217 }
218 None => {
219 panic!("WebSocket connection closed");
220 }
221 }
222 }
223 Some((request, tx)) = request_rx.recv() => {
224 pending_requests.insert(request.id, tx);
225 ws_stream
226 .send(Message::Text(
227 serde_json::to_string(&request).unwrap().into(),
228 ))
229 .await
230 .unwrap();
231 }
232 Some((channel, oneshot_tx)) = subscription_rx.recv() => {
233 if let Some(broadcast_tx) = subscribers.get(&channel) {
234 let _ = oneshot_tx.send(broadcast_tx.subscribe());
235 } else {
236 let (broadcast_tx, broadcast_rx) = broadcast::channel(100);
237 subscribers.insert(channel, broadcast_tx);
238 let _ = oneshot_tx.send(broadcast_rx);
239 }
240 }
241 }
242 }
243 });
244
245 Ok(Self {
246 authenticated: AtomicBool::new(false),
247 id_counter: AtomicU64::new(0),
248 request_channel: request_tx,
249 subscription_channel: subscription_tx,
250 })
252 }
253
254 fn next_id(&self) -> u64 {
255 self.id_counter.fetch_add(1, Ordering::Relaxed)
256 }
257
258 pub async fn call_raw(&self, method: &str, params: Value) -> Result<Value> {
259 let request = RpcRequest {
260 jsonrpc: "2.0".to_string(),
261 id: self.next_id(),
262 method: method.to_string(),
263 params,
264 };
265
266 let (tx, rx) = oneshot::channel();
267
268 self.request_channel
269 .send((request, tx))
270 .await
271 .map_err(|_| WSError::ConnectionClosed)?;
272
273 let value = rx.await.map_err(|_| WSError::ConnectionClosed)??;
274
275 if method == "public/auth" {
276 self.authenticated.store(true, Ordering::Release);
277 }
278
279 Ok(value)
280 }
281
282 pub async fn call<T: ApiRequest>(&self, req: T) -> Result<T::Response> {
283 let value = self.call_raw(req.method_name(), req.to_params()).await?;
284 let typed: T::Response = serde_json::from_value(value)?;
285 Ok(typed)
286 }
287
288 pub async fn subscribe_raw(
289 &self,
290 channel: &str,
291 ) -> Result<impl Stream<Item = Result<Value>> + Send + 'static + use<>> {
292 let channels = vec![channel.to_string()];
293 let subscribed_channels = if self.authenticated.load(Ordering::Acquire) {
294 self.call(PrivateSubscribeRequest {
295 channels,
296 label: None,
297 })
298 .await?
299 } else {
300 self.call(PublicSubscribeRequest { channels }).await?
301 };
302 if let Some(channel) = subscribed_channels.first() {
303 let (tx, rx) = oneshot::channel();
304 self.subscription_channel
305 .send((channel.clone(), tx))
306 .await
307 .map_err(|_| WSError::ConnectionClosed)?;
308 let channel_rx = rx.await.map_err(|_| WSError::ConnectionClosed)?;
309 let stream = BroadcastStream::new(channel_rx).map(|msg| match msg {
310 Ok(msg) => Ok(msg),
311 Err(BroadcastStreamRecvError::Lagged(lag)) => Err(Error::SubscriptionLagged(lag)),
312 });
313 Ok(stream)
314 } else {
315 Err(Error::InvalidSubscriptionChannel(channel.to_string()))
316 }
317 }
318
319 pub async fn subscribe<S: Subscription + Send + 'static>(
321 &self,
322 subscription: S,
323 ) -> Result<impl Stream<Item = Result<S::Data>> + Send + 'static> {
324 let channel = subscription.channel_string();
325 let raw_stream = self.subscribe_raw(&channel).await?;
326 let typed_stream = raw_stream.map(|msg| match msg {
327 Ok(msg) => serde_json::from_value::<S::Data>(msg).map_err(Error::JsonError),
328 Err(e) => Err(e),
329 });
330 Ok(typed_stream)
331 }
332}