Skip to main content

predict_fun_sdk/
execution.rs

1//! Predict.fun execution pipeline: auth → sign → submit with dry-run safety guard.
2//!
3//! [`PredictExecutionClient`] handles the full lifecycle: JWT authentication,
4//! market lookup, order preparation, EIP-712 signing, and submission.
5
6use anyhow::{anyhow, Context, Result};
7use serde_json::{json, Value};
8
9use crate::api::{PredictApiClient, RawApiResponse};
10use crate::order::{
11    predict_limit_order_amounts, PredictCreateOrderRequest, PredictOrder, PredictOrderSigner,
12    PredictOutcome, PredictSide, PredictStrategy, SignedPredictOrder, BNB_MAINNET_CHAIN_ID,
13};
14
/// Configuration for [`PredictExecutionClient`].
///
/// `Debug` is implemented by hand so that `api_key` and `private_key` are
/// redacted: a derived impl would print the raw signing key whenever the
/// config (or a struct embedding it) is formatted with `{:?}`.
#[derive(Clone)]
pub struct PredictExecConfig {
    /// Predict.fun API key.
    pub api_key: String,
    /// Private key passed to `PredictOrderSigner::from_private_key`.
    pub private_key: String,
    /// Chain id given to the order signer.
    pub chain_id: u64,
    /// When `false`, order submission and removal are dry-runs and nothing is POSTed.
    pub live_execution: bool,
    /// Forwarded as `Some(fill_or_kill)` when building the create-order request.
    pub fill_or_kill: bool,
}

impl std::fmt::Debug for PredictExecConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PredictExecConfig")
            .field("api_key", &"<redacted>")
            .field("private_key", &"<redacted>")
            .field("chain_id", &self.chain_id)
            .field("live_execution", &self.live_execution)
            .field("fill_or_kill", &self.fill_or_kill)
            .finish()
    }
}
23
24impl PredictExecConfig {
25    pub fn from_env() -> Result<Self> {
26        let api_key = std::env::var("PREDICT_API_KEY")
27            .context("PREDICT_API_KEY is required for Predict execution")?;
28        let private_key = std::env::var("PREDICT_PRIVATE_KEY")
29            .or_else(|_| std::env::var("PREDICT_TEST_PRIVATE_KEY"))
30            .context("PREDICT_PRIVATE_KEY (or PREDICT_TEST_PRIVATE_KEY) is required")?;
31
32        let live_execution = std::env::var("PREDICT_LIVE_EXECUTION")
33            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
34            .unwrap_or(false);
35
36        let fill_or_kill = std::env::var("PREDICT_FILL_OR_KILL")
37            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
38            .unwrap_or(true);
39
40        let chain_id = std::env::var("PREDICT_CHAIN_ID")
41            .ok()
42            .and_then(|v| v.parse::<u64>().ok())
43            .unwrap_or(BNB_MAINNET_CHAIN_ID);
44
45        Ok(Self {
46            api_key,
47            private_key,
48            chain_id,
49            live_execution,
50            fill_or_kill,
51        })
52    }
53}
54
55#[derive(Debug, Clone)]
56pub struct PredictLimitOrderRequest {
57    pub market_id: i64,
58    pub outcome: PredictOutcome,
59    pub side: PredictSide,
60    pub price_per_share: f64,
61    pub quantity: f64,
62    pub strategy: PredictStrategy,
63    pub slippage_bps: Option<u32>,
64}
65
66#[derive(Debug, Clone)]
67pub struct PredictPreparedOrder {
68    pub signed_order: SignedPredictOrder,
69    pub request: PredictCreateOrderRequest,
70    pub is_neg_risk: bool,
71    pub is_yield_bearing: bool,
72}
73
74#[derive(Debug, Clone)]
75pub struct PredictSubmitResult {
76    pub prepared: PredictPreparedOrder,
77    pub submitted: bool,
78    pub response: Option<Value>,
79    pub raw: Option<RawApiResponse>,
80}
81
82#[derive(Clone)]
83pub struct PredictExecutionClient {
84    pub api: PredictApiClient,
85    pub signer: PredictOrderSigner,
86    pub config: PredictExecConfig,
87}
88
89impl PredictExecutionClient {
90    pub async fn new(config: PredictExecConfig) -> Result<Self> {
91        let signer = PredictOrderSigner::from_private_key(&config.private_key, config.chain_id)?;
92        let api = PredictApiClient::new_mainnet(&config.api_key)?;
93        let jwt = Self::authenticate_jwt(&api, &signer).await?;
94        let api = api.with_jwt(jwt);
95
96        Ok(Self {
97            api,
98            signer,
99            config,
100        })
101    }
102
103    pub async fn from_env() -> Result<Self> {
104        let cfg = PredictExecConfig::from_env()?;
105        Self::new(cfg).await
106    }
107
108    pub async fn authenticate_jwt(api: &PredictApiClient, signer: &PredictOrderSigner) -> Result<String> {
109        let auth_message = api.auth_message().await.context("GET /auth/message failed")?;
110        let message = auth_message
111            .get("data")
112            .and_then(|d| d.get("message"))
113            .and_then(|m| m.as_str())
114            .ok_or_else(|| anyhow!("missing data.message in auth response"))?;
115
116        let signature = signer.sign_auth_message(message)?;
117        let auth = api
118            .auth(&signer.address().to_string(), message, &signature)
119            .await
120            .context("POST /auth failed")?;
121
122        auth.get("data")
123            .and_then(|d| d.get("token"))
124            .and_then(|t| t.as_str())
125            .map(str::to_string)
126            .ok_or_else(|| anyhow!("missing data.token in auth response"))
127    }
128
129    pub async fn prepare_limit_order(&self, req: &PredictLimitOrderRequest) -> Result<PredictPreparedOrder> {
130        let market = self
131            .api
132            .get_market(req.market_id)
133            .await
134            .with_context(|| format!("GET /markets/{} failed", req.market_id))?;
135
136        let market_data = market
137            .get("data")
138            .ok_or_else(|| anyhow!("missing data in market response"))?;
139
140        let token_id = extract_token_id(market_data, req.outcome)?;
141        let fee_rate_bps = market_data
142            .get("feeRateBps")
143            .and_then(|v| v.as_u64())
144            .unwrap_or(0) as u32;
145        let is_neg_risk = market_data
146            .get("isNegRisk")
147            .and_then(|v| v.as_bool())
148            .unwrap_or(false);
149        let is_yield_bearing = market_data
150            .get("isYieldBearing")
151            .and_then(|v| v.as_bool())
152            .unwrap_or(true);
153
154        let price_per_share_wei = wei_from_decimal(req.price_per_share)?;
155        let quantity_wei = wei_from_decimal(req.quantity)?;
156
157        let (maker_amount, taker_amount) =
158            predict_limit_order_amounts(req.side, price_per_share_wei, quantity_wei);
159
160        let maker = self.signer.address();
161        let order = PredictOrder::new_limit(
162            maker,
163            maker,
164            token_id,
165            req.side,
166            maker_amount,
167            taker_amount,
168            fee_rate_bps,
169        );
170
171        let signed_order = self
172            .signer
173            .sign_order(&order, is_neg_risk, is_yield_bearing)
174            .context("failed to sign predict order")?;
175
176        let create_request = signed_order.to_create_order_request(
177            price_per_share_wei,
178            req.strategy,
179            req.slippage_bps,
180            Some(self.config.fill_or_kill),
181        );
182
183        Ok(PredictPreparedOrder {
184            signed_order,
185            request: create_request,
186            is_neg_risk,
187            is_yield_bearing,
188        })
189    }
190
191    /// Submit a signed order request.
192    ///
193    /// - If `live_execution=false`, this is a dry-run and does not POST.
194    /// - If `live_execution=true`, this sends POST /orders and returns response.
195    pub async fn submit_prepared_order(&self, prepared: PredictPreparedOrder) -> Result<PredictSubmitResult> {
196        if !self.config.live_execution {
197            return Ok(PredictSubmitResult {
198                prepared,
199                submitted: false,
200                response: None,
201                raw: None,
202            });
203        }
204
205        let body = serde_json::to_value(&prepared.request)
206            .context("failed to serialize create-order request")?;
207
208        let raw = self
209            .api
210            .raw_post("/orders", &[], body, true)
211            .await
212            .context("POST /orders failed")?;
213
214        let response = raw.json.clone();
215
216        Ok(PredictSubmitResult {
217            prepared,
218            submitted: true,
219            response,
220            raw: Some(raw),
221        })
222    }
223
224    pub async fn place_limit_order(
225        &self,
226        req: &PredictLimitOrderRequest,
227    ) -> Result<PredictSubmitResult> {
228        let prepared = self.prepare_limit_order(req).await?;
229        self.submit_prepared_order(prepared).await
230    }
231
232    /// Remove orders from orderbook via POST /orders/remove.
233    /// Input must be order IDs (not hashes).
234    pub async fn remove_order_ids(&self, ids: &[String]) -> Result<RawApiResponse> {
235        if !self.config.live_execution {
236            return Ok(RawApiResponse {
237                status: reqwest::StatusCode::OK,
238                body: "{\"success\":true,\"dryRun\":true}".to_string(),
239                json: Some(json!({"success": true, "dryRun": true})),
240            });
241        }
242
243        let body = json!({
244            "data": {
245                "ids": ids,
246            }
247        });
248
249        self.api
250            .raw_post("/orders/remove", &[], body, true)
251            .await
252            .context("POST /orders/remove failed")
253    }
254}
255
256fn extract_token_id(market_data: &Value, outcome: PredictOutcome) -> Result<String> {
257    let target_index = match outcome {
258        PredictOutcome::Yes => 1_u64,
259        PredictOutcome::No => 2_u64,
260    };
261
262    let token = market_data
263        .get("outcomes")
264        .and_then(|o| o.as_array())
265        .and_then(|arr| {
266            arr.iter().find(|item| {
267                item.get("indexSet")
268                    .and_then(|v| v.as_u64())
269                    .map(|v| v == target_index)
270                    .unwrap_or(false)
271            })
272        })
273        .and_then(|o| o.get("onChainId"))
274        .and_then(|v| v.as_str())
275        .ok_or_else(|| anyhow!("missing outcome token id for indexSet={}", target_index))?;
276
277    Ok(token.to_string())
278}
279
280fn wei_from_decimal(value: f64) -> Result<alloy_primitives::U256> {
281    if !value.is_finite() || value <= 0.0 {
282        return Err(anyhow!("invalid decimal value {}, expected > 0", value));
283    }
284
285    let scaled = (value * 1e18_f64).round();
286    if scaled <= 0.0 {
287        return Err(anyhow!("value too small after scaling: {}", value));
288    }
289
290    let scaled_u128: u128 = scaled as u128;
291    Ok(alloy_primitives::U256::from(scaled_u128))
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Known decimal -> wei conversions, plus inputs that must be rejected.
    #[test]
    fn test_wei_from_decimal() {
        let expected_conversions = [(0.1, "100000000000000000"), (1.0, "1000000000000000000")];
        for &(input, wei) in expected_conversions.iter() {
            let converted = wei_from_decimal(input).unwrap();
            assert_eq!(converted.to_string(), wei);
        }

        for &rejected in [0.0, -1.0].iter() {
            assert!(wei_from_decimal(rejected).is_err());
        }
    }

    /// Each outcome maps to the `onChainId` of the entry with its `indexSet`.
    #[test]
    fn test_extract_token_id_by_indexset() {
        let market = json!({
            "outcomes": [
                {"indexSet": 1, "onChainId": "yes_token"},
                {"indexSet": 2, "onChainId": "no_token"}
            ]
        });

        let cases = [
            (PredictOutcome::Yes, "yes_token"),
            (PredictOutcome::No, "no_token"),
        ];
        for &(outcome, token) in cases.iter() {
            assert_eq!(extract_token_id(&market, outcome).unwrap(), token);
        }
    }
}