// predict_sdk/client.rs

//! HTTP client for the predict.fun REST API.
//!
//! This module provides a client for interacting with the predict.fun API,
//! including fetching markets and orderbooks, placing and cancelling orders,
//! JWT authentication, and on-chain approvals/position splits.

use std::sync::{Arc, PoisonError};

use alloy::signers::local::PrivateKeySigner;
use alloy::signers::Signer;
use reqwest::Client as HttpClient;
use rust_decimal::Decimal;
use tracing::{debug, info};

use crate::api_types::*;
use crate::errors::{Error, Result};
use crate::onchain::{OnchainClient, SplitOptions};
use crate::order_builder::OrderBuilder;
use crate::types::{BuildOrderInput, ChainId, LimitOrderData, OrderStrategy, Side};

18/// Client for interacting with predict.fun
19pub struct PredictClient {
20    order_builder: Arc<OrderBuilder>,
21    http_client: HttpClient,
22    api_base_url: String,
23    chain_id: ChainId,
24    api_key: Option<String>,
25    jwt_token: std::sync::RwLock<Option<String>>,
26}
27
28impl PredictClient {
29    /// Create a new PredictClient with full trading capability
30    pub fn new(
31        chain_id: u64,
32        private_key: &str,
33        api_base_url: String,
34        api_key: Option<String>,
35    ) -> Result<Self> {
36        let chain_id = ChainId::try_from(chain_id)?;
37        let signer = Self::parse_private_key(private_key)?;
38        let order_builder =
39            OrderBuilder::new(chain_id, Some(signer), None).map_err(|e| Error::Other(e.to_string()))?;
40
41        Ok(Self {
42            order_builder: Arc::new(order_builder),
43            http_client: HttpClient::new(),
44            api_base_url,
45            chain_id,
46            api_key,
47            jwt_token: std::sync::RwLock::new(None),
48        })
49    }
50
51    /// Create a new PredictClient with Predict Smart Wallet (Kernel) signing
52    pub fn new_with_predict_account(
53        chain_id: u64,
54        privy_private_key: &str,
55        predict_account: &str,
56        api_base_url: String,
57        api_key: Option<String>,
58    ) -> Result<Self> {
59        let chain_id = ChainId::try_from(chain_id)?;
60        let signer = Self::parse_private_key(privy_private_key)?;
61        let order_builder = OrderBuilder::with_predict_account(
62            chain_id,
63            signer,
64            predict_account,
65            None,
66        ).map_err(|e| Error::Other(e.to_string()))?;
67
68        Ok(Self {
69            order_builder: Arc::new(order_builder),
70            http_client: HttpClient::new(),
71            api_base_url,
72            chain_id,
73            api_key,
74            jwt_token: std::sync::RwLock::new(None),
75        })
76    }
77
78    /// Create a read-only PredictClient for market data operations
79    pub fn new_readonly(
80        chain_id: u64,
81        api_base_url: String,
82        api_key: Option<String>,
83    ) -> Result<Self> {
84        let chain_id = ChainId::try_from(chain_id)?;
85        let order_builder =
86            OrderBuilder::new(chain_id, None, None).map_err(|e| Error::Other(e.to_string()))?;
87
88        Ok(Self {
89            order_builder: Arc::new(order_builder),
90            http_client: HttpClient::new(),
91            api_base_url,
92            chain_id,
93            api_key,
94            jwt_token: std::sync::RwLock::new(None),
95        })
96    }
97
98    /// Check if this client has signing capability
99    pub fn can_sign(&self) -> bool {
100        self.order_builder.signer_address().is_ok()
101    }
102
103    /// Check if this client uses Predict Account (Kernel) signing
104    pub fn uses_predict_account(&self) -> bool {
105        self.order_builder.uses_predict_account()
106    }
107
108    /// Get the Predict Account address if configured
109    pub fn predict_account(&self) -> Option<String> {
110        self.order_builder.predict_account().map(|addr| format!("{}", addr))
111    }
112
113    /// Parse private key from hex string
114    fn parse_private_key(private_key: &str) -> Result<PrivateKeySigner> {
115        let key = private_key.trim().trim_start_matches("0x");
116        let bytes = hex::decode(key)
117            .map_err(|e| Error::ConfigError(format!("Invalid private key format: {}", e)))?;
118
119        if bytes.len() != 32 {
120            return Err(Error::ConfigError("Private key must be 32 bytes".into()));
121        }
122
123        let mut key_bytes = [0u8; 32];
124        key_bytes.copy_from_slice(&bytes);
125
126        PrivateKeySigner::from_bytes(&key_bytes.into())
127            .map_err(|e| Error::ConfigError(format!("Failed to create signer: {}", e)))
128    }
129
130    /// Add authentication headers to a request
131    ///
132    /// Adds `x-api-key` if configured, and `Authorization: Bearer {jwt}` if authenticated.
133    fn add_auth_headers(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
134        let mut request = request;
135        if let Some(ref api_key) = self.api_key {
136            request = request.header("x-api-key", api_key);
137        }
138        if let Ok(guard) = self.jwt_token.read() {
139            if let Some(ref jwt) = *guard {
140                request = request.header("Authorization", format!("Bearer {}", jwt));
141            }
142        }
143        request
144    }
145
146    // ========================================================================
147    // Authentication
148    // ========================================================================
149
150    /// Authenticate with Predict API and obtain a JWT token
151    ///
152    /// Flow:
153    /// 1. GET /v1/auth/message → dynamic message to sign
154    /// 2. Sign message with wallet private key (EIP-191 personal sign)
155    /// 3. POST /v1/auth → submit signature → receive JWT
156    ///
157    /// The JWT is required for authenticated WebSocket subscriptions
158    /// (e.g., `predictWalletEvents/{jwt}` for order fill notifications).
159    pub async fn authenticate(&self) -> Result<String> {
160        let signer = self.order_builder.signer()
161            .ok_or_else(|| Error::Other("No signer configured - cannot authenticate".into()))?;
162
163        // Step 1: Get the dynamic auth message
164        let url = format!("{}/v1/auth/message", self.api_base_url);
165        let request = self.add_auth_headers(self.http_client.get(&url));
166        let response = request.send().await?;
167
168        if !response.status().is_success() {
169            let error_text = response.text().await.unwrap_or_default();
170            return Err(Error::ApiError(format!(
171                "Failed to get auth message: {}", error_text
172            )));
173        }
174
175        let auth_msg: AuthMessageResponse = response.json().await?;
176        if !auth_msg.success {
177            return Err(Error::ApiError("Auth message request returned success=false".into()));
178        }
179
180        let message = auth_msg.data.message;
181        debug!("Got auth message to sign: {}", &message[..message.len().min(50)]);
182
183        // Step 2: Sign the message
184        // For Predict Accounts: use Kernel-wrapped signing, signer = predict_account
185        // For EOA: use plain EIP-191 personal sign, signer = EOA address
186        let (signature_hex, signer_address) = if let Some(predict_account) = self.order_builder.predict_account() {
187            let ecdsa_validator = self.order_builder.addresses().ecdsa_validator
188                .parse::<alloy::primitives::Address>()
189                .map_err(|e| Error::Other(format!("Invalid ECDSA validator address: {}", e)))?;
190
191            let sig = crate::internal::signing::sign_message_for_predict_account(
192                message.as_bytes(),
193                self.chain_id,
194                predict_account,
195                ecdsa_validator,
196                &signer,
197            ).await?;
198
199            (sig, format!("{}", predict_account))
200        } else {
201            let signature = signer
202                .sign_message(message.as_bytes())
203                .await
204                .map_err(|e| Error::SigningError(format!("Failed to sign auth message: {}", e)))?;
205
206            let mut sig_bytes = signature.as_bytes().to_vec();
207            if sig_bytes[64] < 27 {
208                sig_bytes[64] += 27;
209            }
210
211            (format!("0x{}", hex::encode(sig_bytes)), format!("{}", signer.address()))
212        };
213
214        // Step 3: Submit signature to get JWT
215        let url = format!("{}/v1/auth", self.api_base_url);
216        let auth_request = AuthRequest {
217            signer: signer_address,
218            signature: signature_hex,
219            message,
220        };
221
222        let request = self.add_auth_headers(self.http_client.post(&url))
223            .json(&auth_request);
224        let response = request.send().await?;
225
226        if !response.status().is_success() {
227            let error_text = response.text().await.unwrap_or_default();
228            return Err(Error::ApiError(format!(
229                "Failed to authenticate: {}", error_text
230            )));
231        }
232
233        let auth_response: AuthResponse = response.json().await?;
234        if !auth_response.success {
235            return Err(Error::ApiError("Authentication returned success=false".into()));
236        }
237
238        info!("Successfully authenticated with Predict API");
239        Ok(auth_response.data.token)
240    }
241
242    /// Authenticate and store the JWT token for subsequent REST API requests
243    ///
244    /// This calls `authenticate()` and stores the resulting JWT so that
245    /// `add_auth_headers()` will include `Authorization: Bearer {jwt}` on all requests.
246    pub async fn authenticate_and_store(&self) -> Result<String> {
247        let jwt = self.authenticate().await?;
248        if let Ok(mut guard) = self.jwt_token.write() {
249            *guard = Some(jwt.clone());
250        }
251        Ok(jwt)
252    }
253
254    /// Get the stored JWT token (if authenticated)
255    pub fn jwt_token(&self) -> Option<String> {
256        self.jwt_token.read().ok().and_then(|guard| guard.clone())
257    }
258
259    // ========================================================================
260    // Market Data
261    // ========================================================================
262
263    /// Fetch all markets from Predict
264    pub async fn get_markets(&self) -> Result<Vec<PredictMarket>> {
265        let url = format!("{}/markets", self.api_base_url);
266        debug!("Fetching markets from: {}", url);
267
268        let response = self.http_client.get(&url).send().await?;
269
270        if !response.status().is_success() {
271            return Err(Error::ApiError(format!(
272                "Failed to fetch markets: status={}",
273                response.status()
274            )));
275        }
276
277        let markets: Vec<PredictMarket> = response.json().await?;
278        info!("Fetched {} markets from Predict", markets.len());
279        Ok(markets)
280    }
281
282    /// Fetch orderbook for a specific market
283    pub async fn get_orderbook(&self, market_id: &str) -> Result<PredictOrderBook> {
284        let url = format!("{}/markets/{}/orderbook", self.api_base_url, market_id);
285        debug!("Fetching orderbook from: {}", url);
286
287        let response = self.http_client.get(&url).send().await?;
288
289        if !response.status().is_success() {
290            return Err(Error::ApiError(format!(
291                "Failed to fetch orderbook for market {}: status={}",
292                market_id,
293                response.status()
294            )));
295        }
296
297        let orderbook: PredictOrderBook = response.json().await?;
298        Ok(orderbook)
299    }
300
301    // ========================================================================
302    // Order Management
303    // ========================================================================
304
305    /// Place a limit order on Predict
306    ///
307    /// Builds, signs, and submits a limit order to the Predict API.
308    /// Returns the order ID and hash on success.
309    pub async fn place_limit_order(
310        &self,
311        token_id: &str,
312        side: Side,
313        price: Decimal,
314        quantity: Decimal,
315        is_neg_risk: bool,
316        is_yield_bearing: bool,
317        fee_rate_bps: u64,
318    ) -> Result<PlaceOrderResponse> {
319        info!(
320            "Placing limit order: token_id={}, side={:?}, price={}, quantity={}",
321            token_id, side, price, quantity
322        );
323
324        // Calculate order amounts
325        let amounts = self
326            .order_builder
327            .get_limit_order_amounts(LimitOrderData {
328                side,
329                price_per_share_wei: price,
330                quantity_wei: quantity,
331            })
332            .map_err(|e| Error::Other(format!("Failed to calculate order amounts: {}", e)))?;
333
334        // Build order - let build_order handle maker/signer based on config:
335        // - With predict_account: maker=predict_account, signer=predict_account (matching official SDKs)
336        // - Without: maker=EOA, signer=EOA
337        let order = self
338            .order_builder
339            .build_order(
340                OrderStrategy::Limit,
341                BuildOrderInput {
342                    side,
343                    token_id: token_id.to_string(),
344                    maker_amount: amounts.maker_amount.trunc().to_string(),
345                    taker_amount: amounts.taker_amount.trunc().to_string(),
346                    fee_rate_bps,
347                    signer: None,
348                    nonce: None,
349                    salt: None,
350                    maker: None,
351                    taker: None,
352                    signature_type: None,
353                    expires_at: None,
354                },
355            )
356            .map_err(|e| Error::Other(format!("Failed to build order: {}", e)))?;
357
358        let verifying_contract = self.order_builder.get_verifying_contract(is_neg_risk, is_yield_bearing);
359        info!(
360            "Signing order: chain_id={:?}, is_neg_risk={}, is_yield_bearing={}, verifying_contract={}, maker={}, signer={}, uses_predict_account={}",
361            self.chain_id, is_neg_risk, is_yield_bearing, verifying_contract,
362            order.maker, order.signer, self.order_builder.uses_predict_account(),
363        );
364
365        // Sign order — automatically uses Kernel-wrapped signing for predict accounts,
366        // or plain EOA EIP-712 for direct wallets.
367        let signed_order = self
368            .order_builder
369            .sign_typed_data_order(order, is_neg_risk, is_yield_bearing)
370            .await
371            .map_err(|e| Error::Other(format!("Failed to sign order: {}", e)))?;
372
373        // Build the API request body per Predict API spec:
374        // POST /v1/orders with body: {"data": {"order": {...}, "pricePerShare": "...", "strategy": "LIMIT"}}
375        let order_json = serde_json::to_value(&signed_order)?;
376        let price_per_share = amounts.price_per_share.to_string();
377
378        let request_body = CreateOrderRequest {
379            data: CreateOrderData {
380                order: order_json,
381                price_per_share,
382                strategy: "LIMIT".to_string(),
383            },
384        };
385
386        info!("Order request body: {}", serde_json::to_string(&request_body).unwrap_or_default());
387
388        // Submit to API
389        let url = format!("{}/v1/orders", self.api_base_url);
390        let request = self.add_auth_headers(self.http_client.post(&url))
391            .json(&request_body);
392        let response = request.send().await?;
393
394        let status = response.status();
395        if !status.is_success() {
396            let error_text = response
397                .text()
398                .await
399                .unwrap_or_else(|_| "Unknown error".to_string());
400            return Err(Error::ApiError(format!(
401                "Failed to place order: status={}, error={}",
402                status, error_text
403            )));
404        }
405
406        let place_response: PlaceOrderResponse = response.json().await?;
407
408        if !place_response.success {
409            return Err(Error::ApiError("Order placement returned success=false".into()));
410        }
411
412        if let Some(ref data) = place_response.data {
413            info!(
414                "Order placed successfully: order_id={}, hash={}",
415                data.order_id, data.order_hash
416            );
417        }
418
419        Ok(place_response)
420    }
421
422    /// Cancel orders by their IDs
423    ///
424    /// Removes orders from the Predict orderbook.
425    /// Note: This removes orders from the orderbook but does not cancel on-chain.
426    ///
427    /// # Arguments
428    /// * `order_ids` - Order IDs to cancel (max 100)
429    pub async fn cancel_orders(&self, order_ids: &[String]) -> Result<RemoveOrdersResponse> {
430        if order_ids.is_empty() {
431            return Ok(RemoveOrdersResponse {
432                success: true,
433                removed: vec![],
434                noop: vec![],
435            });
436        }
437
438        if order_ids.len() > 100 {
439            return Err(Error::Other("Cannot cancel more than 100 orders at once".into()));
440        }
441
442        info!("Cancelling {} orders on Predict", order_ids.len());
443
444        let request_body = RemoveOrdersRequest {
445            data: RemoveOrdersData {
446                ids: order_ids.to_vec(),
447            },
448        };
449
450        let url = format!("{}/v1/orders/remove", self.api_base_url);
451        let request = self.add_auth_headers(self.http_client.post(&url))
452            .json(&request_body);
453        let response = request.send().await?;
454
455        let status = response.status();
456        if !status.is_success() {
457            let error_text = response.text().await.unwrap_or_default();
458            return Err(Error::ApiError(format!(
459                "Failed to cancel orders: status={}, error={}",
460                status, error_text
461            )));
462        }
463
464        let cancel_response: RemoveOrdersResponse = response.json().await?;
465
466        info!(
467            "Cancel result: removed={}, noop={}",
468            cancel_response.removed.len(),
469            cancel_response.noop.len()
470        );
471
472        Ok(cancel_response)
473    }
474
475    /// Fetch open orders for this account
476    ///
477    /// Returns all orders with status OPEN for the authenticated user.
478    pub async fn get_open_orders(&self) -> Result<Vec<PredictOrder>> {
479        let url = format!("{}/v1/orders?status=OPEN", self.api_base_url);
480        debug!("Fetching open orders from: {}", url);
481
482        let request = self.add_auth_headers(self.http_client.get(&url));
483        let response = request.send().await?;
484
485        let status = response.status();
486        if !status.is_success() {
487            let error_text = response.text().await.unwrap_or_default();
488            return Err(Error::ApiError(format!(
489                "Failed to fetch open orders: status={}, error={}",
490                status, error_text
491            )));
492        }
493
494        let body = response.text().await?;
495        debug!("get_open_orders raw response: {}", &body[..500.min(body.len())]);
496
497        let orders_response: GetOrdersResponse = serde_json::from_str(&body)
498            .map_err(|e| Error::ApiError(format!(
499                "Failed to parse open orders: {} | body: {}",
500                e, &body[..500.min(body.len())]
501            )))?;
502
503        if !orders_response.success {
504            return Err(Error::ApiError("Get orders returned success=false".into()));
505        }
506
507        debug!("Fetched {} open orders", orders_response.data.len());
508        Ok(orders_response.data)
509    }
510
511    /// Fetch positions (token balances) for this account
512    pub async fn get_positions(&self) -> Result<Vec<PredictPosition>> {
513        let url = format!("{}/v1/positions", self.api_base_url);
514        debug!("Fetching positions from: {}", url);
515
516        let request = self.add_auth_headers(self.http_client.get(&url));
517        let response = request.send().await?;
518
519        let status = response.status();
520        if !status.is_success() {
521            let error_text = response.text().await.unwrap_or_default();
522            return Err(Error::ApiError(format!(
523                "Failed to fetch positions: status={}, error={}",
524                status, error_text
525            )));
526        }
527
528        let positions_response: GetPositionsResponse = response.json().await?;
529
530        if !positions_response.success {
531            return Err(Error::ApiError("Get positions returned success=false".into()));
532        }
533
534        debug!("Fetched {} positions", positions_response.data.len());
535        Ok(positions_response.data)
536    }
537
538    // ========================================================================
539    // Accessors
540    // ========================================================================
541
542    /// Get the signer address
543    pub fn signer_address(&self) -> Result<String> {
544        self.order_builder
545            .signer_address()
546            .map(|addr| format!("{}", addr))
547            .map_err(|e| Error::Other(format!("Failed to get signer address: {}", e)))
548    }
549
550    /// Get the chain ID
551    pub fn chain_id(&self) -> ChainId {
552        self.chain_id
553    }
554
555    /// Get the API key (if set)
556    pub fn api_key(&self) -> Option<&str> {
557        self.api_key.as_deref()
558    }
559
560    /// Get the order builder
561    pub fn order_builder(&self) -> &OrderBuilder {
562        &self.order_builder
563    }
564
565    /// Get the API base URL
566    pub fn api_base_url(&self) -> &str {
567        &self.api_base_url
568    }
569
570    // ========================================================================
571    // Category API for Market Matching
572    // ========================================================================
573
574    /// Fetch a category by slug from Predict
575    pub async fn get_category(&self, slug: &str) -> Result<PredictCategory> {
576        let url = format!("{}/v1/categories/{}", self.api_base_url, slug);
577        debug!("Fetching category from: {}", url);
578
579        let request = self.add_auth_headers(self.http_client.get(&url));
580        let response = request.send().await?;
581        let status = response.status();
582
583        if status == reqwest::StatusCode::NOT_FOUND {
584            return Err(Error::ApiError(format!("Category not found: slug={}", slug)));
585        }
586
587        if !status.is_success() {
588            let error_text = response
589                .text()
590                .await
591                .unwrap_or_else(|_| "Unknown error".to_string());
592            return Err(Error::ApiError(format!(
593                "Failed to fetch category {}: status={}, error={}",
594                slug, status, error_text
595            )));
596        }
597
598        let wrapper: CategoryResponse = response.json().await?;
599        debug!(
600            "Fetched category '{}' with {} markets",
601            wrapper.data.slug,
602            wrapper.data.markets.len()
603        );
604
605        Ok(wrapper.data)
606    }
607
608    /// Fetch a category by slug, returning None if not found
609    pub async fn get_category_optional(&self, slug: &str) -> Result<Option<PredictCategory>> {
610        match self.get_category(slug).await {
611            Ok(category) => Ok(Some(category)),
612            Err(Error::ApiError(msg)) if msg.contains("not found") => Ok(None),
613            Err(e) => Err(e),
614        }
615    }
616
617    // ========================================================================
618    // On-chain Operations (Split/Merge/Approvals)
619    // ========================================================================
620
621    /// Set all necessary on-chain approvals for trading.
622    ///
623    /// This must be called once before placing orders. It sets:
624    /// - ERC-1155 approval on Conditional Tokens for the CTF Exchange
625    /// - ERC-20 approval on USDT for the CTF Exchange
626    /// - Neg Risk Adapter approval (if is_neg_risk)
627    pub async fn set_approvals(
628        &self,
629        is_neg_risk: bool,
630        is_yield_bearing: bool,
631    ) -> Result<()> {
632        let signer = self
633            .order_builder
634            .signer()
635            .ok_or_else(|| Error::Other("No signer configured - cannot set approvals".into()))?;
636
637        // Approvals must be set for the trading address (predict_account or EOA).
638        // For predict_account: the smart wallet holds tokens, so it needs approvals.
639        // set_approvals checks `trading_address()` which returns predict_account if set.
640        let onchain_client = if let Some(predict_account) = self.order_builder.predict_account() {
641            OnchainClient::with_predict_account(
642                self.chain_id,
643                signer,
644                &format!("{}", predict_account),
645            )?
646        } else {
647            OnchainClient::new(self.chain_id, signer)
648        };
649        onchain_client.set_approvals(is_neg_risk, is_yield_bearing).await
650    }
651
652    /// Split USDT into UP/DOWN outcome tokens for a market.
653    ///
654    /// When a predict_account is configured, splits via Kernel so tokens land
655    /// on the predict_account (which is the order maker).
656    /// When using direct EOA, splits directly.
657    pub async fn split_positions(
658        &self,
659        condition_id: &str,
660        amount: f64,
661        is_neg_risk: bool,
662        is_yield_bearing: bool,
663    ) -> Result<String> {
664        let signer = self
665            .order_builder
666            .signer()
667            .ok_or_else(|| Error::Other("No signer configured - cannot perform on-chain operations".into()))?;
668
669        let onchain_client = if let Some(predict_account) = self.order_builder.predict_account() {
670            OnchainClient::with_predict_account(
671                self.chain_id,
672                signer,
673                &format!("{}", predict_account),
674            )?
675        } else {
676            OnchainClient::new(self.chain_id, signer)
677        };
678
679        let options = SplitOptions {
680            condition_id: condition_id.to_string(),
681            amount,
682            is_neg_risk,
683            is_yield_bearing,
684        };
685
686        onchain_client.split_positions(options).await
687    }
688
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// Syntactically valid 32-byte private key, hex-encoded without a prefix.
    const TEST_KEY_HEX: &str = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";

    #[test]
    fn test_parse_private_key() {
        // With and without the 0x prefix must decode to the *same* key, not merely
        // to some key.
        let Ok(with_prefix) = PredictClient::parse_private_key(&format!("0x{}", TEST_KEY_HEX))
        else {
            panic!("prefixed key should parse");
        };
        let Ok(without_prefix) = PredictClient::parse_private_key(TEST_KEY_HEX) else {
            panic!("unprefixed key should parse");
        };
        assert_eq!(with_prefix.address(), without_prefix.address());

        // Not hex at all
        assert!(PredictClient::parse_private_key("invalid").is_err());
    }

    #[test]
    fn test_parse_private_key_rejects_wrong_length() {
        // Valid hex, but 2 bytes and 33 bytes respectively.
        assert!(PredictClient::parse_private_key("0x1234").is_err());
        assert!(PredictClient::parse_private_key(&format!("{}00", TEST_KEY_HEX)).is_err());
    }

    #[test]
    fn test_parse_private_key_trims_whitespace() {
        let padded = format!("  0x{}\n", TEST_KEY_HEX);
        assert!(PredictClient::parse_private_key(&padded).is_ok());
    }
}