1use anyhow::{anyhow, Context, Result};
7use reqwest::{Method, StatusCode};
8use serde_json::{json, Value};
9
/// Base URL (including the `/v1` prefix) of the Predict mainnet API; endpoint paths such as
/// `/markets` are appended to it verbatim.
pub const PREDICT_MAINNET_BASE: &str = "https://api.predict.fun/v1";
/// Base URL (including the `/v1` prefix) of the Predict testnet API; endpoint paths such as
/// `/markets` are appended to it verbatim.
pub const PREDICT_TESTNET_BASE: &str = "https://api-testnet.predict.fun/v1";
12
/// Static description of one Predict API endpoint: HTTP method, path template and whether
/// a JWT is needed.
#[derive(Debug, Clone, Copy)]
pub struct EndpointSpec {
    /// HTTP method used to call the endpoint.
    pub method: MethodName,
    /// Path relative to the API base URL; `{name}` marks a path parameter
    /// (e.g. `/markets/{id}`).
    pub path: &'static str,
    /// Whether the endpoint requires a JWT (`Authorization: Bearer ...`).
    pub requires_jwt: bool,
}
21
/// HTTP methods used by the Predict API endpoints.
///
/// Equality and hashing are derived so endpoint methods can be compared with `==` or used as
/// map/set keys directly instead of through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MethodName {
    /// HTTP `GET`.
    Get,
    /// HTTP `POST`.
    Post,
}
27
28impl MethodName {
29 #[inline]
30 pub fn as_reqwest(self) -> Method {
31 match self {
32 MethodName::Get => Method::GET,
33 MethodName::Post => Method::POST,
34 }
35 }
36}
37
38pub const PREDICT_ENDPOINTS: &[EndpointSpec] = &[
39 EndpointSpec {
40 method: MethodName::Get,
41 path: "/auth/message",
42 requires_jwt: false,
43 },
44 EndpointSpec {
45 method: MethodName::Post,
46 path: "/auth",
47 requires_jwt: false,
48 },
49 EndpointSpec {
50 method: MethodName::Post,
51 path: "/orders",
52 requires_jwt: true,
53 },
54 EndpointSpec {
55 method: MethodName::Get,
56 path: "/orders",
57 requires_jwt: true,
58 },
59 EndpointSpec {
60 method: MethodName::Post,
61 path: "/orders/remove",
62 requires_jwt: true,
63 },
64 EndpointSpec {
65 method: MethodName::Get,
66 path: "/orders/{hash}",
67 requires_jwt: true,
68 },
69 EndpointSpec {
70 method: MethodName::Get,
71 path: "/orders/matches",
72 requires_jwt: false,
73 },
74 EndpointSpec {
75 method: MethodName::Get,
76 path: "/markets",
77 requires_jwt: false,
78 },
79 EndpointSpec {
80 method: MethodName::Get,
81 path: "/markets/{id}",
82 requires_jwt: false,
83 },
84 EndpointSpec {
85 method: MethodName::Get,
86 path: "/markets/{id}/stats",
87 requires_jwt: false,
88 },
89 EndpointSpec {
90 method: MethodName::Get,
91 path: "/markets/{id}/last-sale",
92 requires_jwt: false,
93 },
94 EndpointSpec {
95 method: MethodName::Get,
96 path: "/markets/{id}/orderbook",
97 requires_jwt: false,
98 },
99 EndpointSpec {
100 method: MethodName::Get,
101 path: "/markets/{id}/timeseries",
102 requires_jwt: false,
103 },
104 EndpointSpec {
105 method: MethodName::Get,
106 path: "/markets/{id}/timeseries/latest",
107 requires_jwt: false,
108 },
109 EndpointSpec {
110 method: MethodName::Get,
111 path: "/categories",
112 requires_jwt: false,
113 },
114 EndpointSpec {
115 method: MethodName::Get,
116 path: "/categories/{slug}",
117 requires_jwt: false,
118 },
119 EndpointSpec {
120 method: MethodName::Get,
121 path: "/categories/{id}/stats",
122 requires_jwt: false,
123 },
124 EndpointSpec {
125 method: MethodName::Get,
126 path: "/tags",
127 requires_jwt: false,
128 },
129 EndpointSpec {
130 method: MethodName::Get,
131 path: "/positions",
132 requires_jwt: true,
133 },
134 EndpointSpec {
135 method: MethodName::Get,
136 path: "/positions/{address}",
137 requires_jwt: false,
138 },
139 EndpointSpec {
140 method: MethodName::Get,
141 path: "/account",
142 requires_jwt: true,
143 },
144 EndpointSpec {
145 method: MethodName::Post,
146 path: "/account/referral",
147 requires_jwt: true,
148 },
149 EndpointSpec {
150 method: MethodName::Get,
151 path: "/account/activity",
152 requires_jwt: true,
153 },
154 EndpointSpec {
155 method: MethodName::Post,
156 path: "/oauth/finalize",
157 requires_jwt: false,
158 },
159 EndpointSpec {
160 method: MethodName::Post,
161 path: "/oauth/orders",
162 requires_jwt: false,
163 },
164 EndpointSpec {
165 method: MethodName::Post,
166 path: "/oauth/orders/create",
167 requires_jwt: false,
168 },
169 EndpointSpec {
170 method: MethodName::Post,
171 path: "/oauth/orders/cancel",
172 requires_jwt: false,
173 },
174 EndpointSpec {
175 method: MethodName::Post,
176 path: "/oauth/positions",
177 requires_jwt: false,
178 },
179 EndpointSpec {
180 method: MethodName::Get,
181 path: "/search",
182 requires_jwt: false,
183 },
184 EndpointSpec {
185 method: MethodName::Get,
186 path: "/yield/pending",
187 requires_jwt: true,
188 },
189];
190
/// An HTTP response from the Predict API, returned as-is without any status-code check.
#[derive(Debug, Clone)]
pub struct RawApiResponse {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Full response body as text.
    pub body: String,
    /// `body` parsed as JSON, or `None` when the body is not valid JSON.
    pub json: Option<Value>,
}
197
/// HTTP client for the Predict API.
///
/// NOTE(review): `Debug` is not derived, presumably to keep the API key and JWT out of logs;
/// confirm before adding one.
#[derive(Clone)]
pub struct PredictApiClient {
    /// Base URL with trailing `/` removed; request paths are appended to it.
    base: String,
    /// Sent as the `x-api-key` header on every request when set.
    api_key: Option<String>,
    /// Sent as `Authorization: Bearer <jwt>` on every request when set.
    jwt: Option<String>,
    /// Underlying HTTP client (15 s timeout, pooled connections; see `new`).
    http: reqwest::Client,
}
209
210impl PredictApiClient {
211 pub fn new_mainnet(api_key: impl Into<String>) -> Result<Self> {
212 Self::new(PREDICT_MAINNET_BASE, Some(api_key.into()), None)
213 }
214
215 pub fn new_testnet() -> Result<Self> {
216 Self::new(PREDICT_TESTNET_BASE, None, None)
217 }
218
219 pub fn new(
220 base: impl Into<String>,
221 api_key: Option<String>,
222 jwt: Option<String>,
223 ) -> Result<Self> {
224 let http = reqwest::Client::builder()
225 .timeout(std::time::Duration::from_secs(15))
226 .pool_max_idle_per_host(16)
227 .pool_idle_timeout(std::time::Duration::from_secs(90))
228 .build()
229 .context("failed to build predict api client")?;
230
231 Ok(Self {
232 base: base.into().trim_end_matches('/').to_string(),
233 api_key,
234 jwt,
235 http,
236 })
237 }
238
239 pub fn with_jwt(mut self, jwt: impl Into<String>) -> Self {
240 self.jwt = Some(jwt.into());
241 self
242 }
243
244 pub fn set_jwt(&mut self, jwt: impl Into<String>) {
245 self.jwt = Some(jwt.into());
246 }
247
248 pub fn clear_jwt(&mut self) {
249 self.jwt = None;
250 }
251
252 pub fn endpoint_specs() -> &'static [EndpointSpec] {
253 PREDICT_ENDPOINTS
254 }
255
256 pub async fn auth_message(&self) -> Result<Value> {
259 self.get_ok("/auth/message", &[], false).await
260 }
261
262 pub async fn auth(&self, signer: &str, message: &str, signature: &str) -> Result<Value> {
263 let body = json!({
264 "signer": signer,
265 "message": message,
266 "signature": signature,
267 });
268 self.post_ok("/auth", &[], body, false).await
269 }
270
271 pub async fn create_order(&self, body: Value) -> Result<Value> {
274 self.post_ok("/orders", &[], body, true).await
275 }
276
277 pub async fn list_orders(&self, query: &[(&str, String)]) -> Result<Value> {
278 self.get_ok("/orders", query, true).await
279 }
280
281 pub async fn remove_orders(&self, body: Value) -> Result<Value> {
282 self.post_ok("/orders/remove", &[], body, true).await
283 }
284
285 pub async fn get_order(&self, hash: &str) -> Result<Value> {
286 self.get_ok(&format!("/orders/{}", hash), &[], true).await
287 }
288
289 pub async fn get_order_matches(&self, query: &[(&str, String)]) -> Result<Value> {
290 self.get_ok("/orders/matches", query, false).await
291 }
292
293 pub async fn list_markets(&self, query: &[(&str, String)]) -> Result<Value> {
296 self.get_ok("/markets", query, false).await
297 }
298
299 pub async fn get_market(&self, id: i64) -> Result<Value> {
300 self.get_ok(&format!("/markets/{}", id), &[], false).await
301 }
302
303 pub async fn get_market_stats(&self, id: i64) -> Result<Value> {
304 self.get_ok(&format!("/markets/{}/stats", id), &[], false)
305 .await
306 }
307
308 pub async fn get_market_last_sale(&self, id: i64) -> Result<Value> {
309 self.get_ok(&format!("/markets/{}/last-sale", id), &[], false)
310 .await
311 }
312
313 pub async fn get_market_orderbook(&self, id: i64) -> Result<Value> {
314 self.get_ok(&format!("/markets/{}/orderbook", id), &[], false)
315 .await
316 }
317
318 pub async fn get_market_timeseries(&self, id: i64, query: &[(&str, String)]) -> Result<Value> {
319 self.get_ok(&format!("/markets/{}/timeseries", id), query, false)
320 .await
321 }
322
323 pub async fn get_market_timeseries_latest(&self, id: i64) -> Result<Value> {
324 self.get_ok(&format!("/markets/{}/timeseries/latest", id), &[], false)
325 .await
326 }
327
328 pub async fn list_categories(&self, query: &[(&str, String)]) -> Result<Value> {
331 self.get_ok("/categories", query, false).await
332 }
333
334 pub async fn get_category(&self, slug: &str) -> Result<Value> {
335 self.get_ok(&format!("/categories/{}", slug), &[], false)
336 .await
337 }
338
339 pub async fn get_category_stats(&self, id: i64) -> Result<Value> {
340 self.get_ok(&format!("/categories/{}/stats", id), &[], false)
341 .await
342 }
343
344 pub async fn list_tags(&self) -> Result<Value> {
345 self.get_ok("/tags", &[], false).await
346 }
347
348 pub async fn list_positions(&self, query: &[(&str, String)]) -> Result<Value> {
351 self.get_ok("/positions", query, true).await
352 }
353
354 pub async fn list_positions_for_address(
355 &self,
356 address: &str,
357 query: &[(&str, String)],
358 ) -> Result<Value> {
359 self.get_ok(&format!("/positions/{}", address), query, false)
360 .await
361 }
362
363 pub async fn account(&self) -> Result<Value> {
366 self.get_ok("/account", &[], true).await
367 }
368
369 pub async fn set_referral(&self, code: &str) -> Result<Value> {
370 let body = json!({ "code": code });
371 self.post_ok("/account/referral", &[], body, true).await
372 }
373
374 pub async fn account_activity(&self, query: &[(&str, String)]) -> Result<Value> {
375 self.get_ok("/account/activity", query, true).await
376 }
377
378 pub async fn oauth_finalize(&self, body: Value) -> Result<Value> {
381 self.post_ok("/oauth/finalize", &[], body, false).await
382 }
383
384 pub async fn oauth_orders(&self, body: Value) -> Result<Value> {
385 self.post_ok("/oauth/orders", &[], body, false).await
386 }
387
388 pub async fn oauth_create_order(&self, body: Value) -> Result<Value> {
389 self.post_ok("/oauth/orders/create", &[], body, false).await
390 }
391
392 pub async fn oauth_cancel_order(&self, body: Value) -> Result<Value> {
393 self.post_ok("/oauth/orders/cancel", &[], body, false).await
394 }
395
396 pub async fn oauth_positions(&self, body: Value) -> Result<Value> {
397 self.post_ok("/oauth/positions", &[], body, false).await
398 }
399
400 pub async fn search(&self, query: &[(&str, String)]) -> Result<Value> {
403 self.get_ok("/search", query, false).await
404 }
405
406 pub async fn yield_pending(&self, query: &[(&str, String)]) -> Result<Value> {
407 self.get_ok("/yield/pending", query, true).await
408 }
409
410 pub async fn raw_get(
413 &self,
414 path: &str,
415 query: &[(&str, String)],
416 require_jwt: bool,
417 ) -> Result<RawApiResponse> {
418 self.raw_request(Method::GET, path, query, None, require_jwt)
419 .await
420 }
421
422 pub async fn raw_post(
423 &self,
424 path: &str,
425 query: &[(&str, String)],
426 body: Value,
427 require_jwt: bool,
428 ) -> Result<RawApiResponse> {
429 self.raw_request(Method::POST, path, query, Some(body), require_jwt)
430 .await
431 }
432
433 async fn get_ok(
434 &self,
435 path: &str,
436 query: &[(&str, String)],
437 require_jwt: bool,
438 ) -> Result<Value> {
439 self.expect_success(
440 self.raw_request(Method::GET, path, query, None, require_jwt)
441 .await?,
442 Method::GET,
443 path,
444 )
445 }
446
447 async fn post_ok(
448 &self,
449 path: &str,
450 query: &[(&str, String)],
451 body: Value,
452 require_jwt: bool,
453 ) -> Result<Value> {
454 self.expect_success(
455 self.raw_request(Method::POST, path, query, Some(body), require_jwt)
456 .await?,
457 Method::POST,
458 path,
459 )
460 }
461
462 fn expect_success(&self, resp: RawApiResponse, method: Method, path: &str) -> Result<Value> {
463 if !resp.status.is_success() {
464 return Err(anyhow!(
465 "Predict API {} {} failed: status={} body={}",
466 method,
467 path,
468 resp.status,
469 truncate_body(&resp.body)
470 ));
471 }
472
473 if let Some(json) = resp.json {
474 Ok(json)
475 } else {
476 Err(anyhow!(
477 "Predict API {} {} returned non-json body: {}",
478 method,
479 path,
480 truncate_body(&resp.body)
481 ))
482 }
483 }
484
485 async fn raw_request(
486 &self,
487 method: Method,
488 path: &str,
489 query: &[(&str, String)],
490 body: Option<Value>,
491 require_jwt: bool,
492 ) -> Result<RawApiResponse> {
493 if require_jwt && self.jwt.is_none() {
494 return Err(anyhow!(
495 "JWT is required for {} {} but client has no jwt configured",
496 method,
497 path
498 ));
499 }
500
501 let url = format!("{}{}", self.base, path);
502 let mut req = self
503 .http
504 .request(method.clone(), &url)
505 .header("Accept", "application/json")
506 .header("Content-Type", "application/json");
507
508 if !query.is_empty() {
509 req = req.query(query);
510 }
511
512 if let Some(api_key) = &self.api_key {
513 req = req.header("x-api-key", api_key);
514 }
515
516 if let Some(jwt) = &self.jwt {
517 req = req.header("Authorization", format!("Bearer {}", jwt));
518 }
519
520 if let Some(v) = body {
521 req = req.json(&v);
522 }
523
524 let resp = req
525 .send()
526 .await
527 .with_context(|| format!("predict api request failed for {} {}", method, path))?;
528
529 let status = resp.status();
530 let text = resp
531 .text()
532 .await
533 .with_context(|| format!("failed reading response body for {} {}", method, path))?;
534
535 let json = serde_json::from_str::<Value>(&text).ok();
536
537 Ok(RawApiResponse {
538 status,
539 body: text,
540 json,
541 })
542 }
543}
544
/// Shortens a response body for inclusion in error messages.
///
/// Bodies of at most 500 bytes are returned unchanged. Longer bodies are cut at the last UTF-8
/// character boundary at or before byte 500 and suffixed with `...<truncated>`. Cutting at a
/// boundary avoids the slicing panic that a multi-byte character straddling the limit would
/// otherwise cause.
fn truncate_body(body: &str) -> String {
    const MAX: usize = 500;
    if body.len() <= MAX {
        return body.to_string();
    }

    // Walk back to a char boundary; index 0 is always a boundary, so this terminates.
    let mut end = MAX;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...<truncated>", &body[..end])
}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the size of the endpoint table so additions and removals are deliberate.
    #[test]
    fn endpoint_count_matches_current_openapi_surface() {
        let total = PREDICT_ENDPOINTS.len();
        assert_eq!(total, 30);
    }

    /// The login (`POST /auth`) and order placement (`POST /orders`) routes must be listed.
    #[test]
    fn has_auth_and_orders_endpoints() {
        let is_post_to = |wanted: &str| {
            PREDICT_ENDPOINTS
                .iter()
                .any(|spec| matches!(spec.method, MethodName::Post) && spec.path == wanted)
        };

        assert!(is_post_to("/auth"));
        assert!(is_post_to("/orders"));
    }
}