1use std::collections::HashMap;
13use std::sync::Arc;
14
15use anyhow::{anyhow, Context, Result};
16use serde_json::{json, Value};
17use tokio::sync::RwLock;
18
19use crate::api::{PredictApiClient, RawApiResponse};
20use crate::order::{
21 predict_limit_order_amounts, PredictCreateOrderRequest, PredictOrder, PredictOrderSigner,
22 PredictOutcome, PredictSide, PredictStrategy, SignedPredictOrder, BNB_MAINNET_CHAIN_ID,
23};
24
25#[derive(Debug, Clone)]
27pub struct MarketMeta {
28 pub market_id: i64,
29 pub yes_token_id: String,
30 pub no_token_id: String,
31 pub fee_rate_bps: u32,
32 pub is_neg_risk: bool,
33 pub is_yield_bearing: bool,
34}
35
36impl MarketMeta {
37 pub fn token_id(&self, outcome: PredictOutcome) -> &str {
39 match outcome {
40 PredictOutcome::Yes => &self.yes_token_id,
41 PredictOutcome::No => &self.no_token_id,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
47pub struct PredictExecConfig {
48 pub api_key: String,
49 pub private_key: String,
50 pub chain_id: u64,
51 pub live_execution: bool,
52 pub fill_or_kill: bool,
53}
54
55impl PredictExecConfig {
56 pub fn from_env() -> Result<Self> {
57 let api_key = std::env::var("PREDICT_API_KEY")
58 .context("PREDICT_API_KEY is required for Predict execution")?;
59 let private_key = std::env::var("PREDICT_PRIVATE_KEY")
60 .or_else(|_| std::env::var("PREDICT_TEST_PRIVATE_KEY"))
61 .context("PREDICT_PRIVATE_KEY (or PREDICT_TEST_PRIVATE_KEY) is required")?;
62
63 let live_execution = std::env::var("PREDICT_LIVE_EXECUTION")
64 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
65 .unwrap_or(false);
66
67 let fill_or_kill = std::env::var("PREDICT_FILL_OR_KILL")
68 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
69 .unwrap_or(true);
70
71 let chain_id = std::env::var("PREDICT_CHAIN_ID")
72 .ok()
73 .and_then(|v| v.parse::<u64>().ok())
74 .unwrap_or(BNB_MAINNET_CHAIN_ID);
75
76 Ok(Self {
77 api_key,
78 private_key,
79 chain_id,
80 live_execution,
81 fill_or_kill,
82 })
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct PredictLimitOrderRequest {
88 pub market_id: i64,
89 pub outcome: PredictOutcome,
90 pub side: PredictSide,
91 pub price_per_share: f64,
92 pub quantity: f64,
93 pub strategy: PredictStrategy,
94 pub slippage_bps: Option<u32>,
95}
96
97#[derive(Debug, Clone)]
98pub struct PredictPreparedOrder {
99 pub signed_order: SignedPredictOrder,
100 pub request: PredictCreateOrderRequest,
101 pub is_neg_risk: bool,
102 pub is_yield_bearing: bool,
103}
104
105#[derive(Debug, Clone)]
106pub struct PredictSubmitResult {
107 pub prepared: PredictPreparedOrder,
108 pub submitted: bool,
109 pub response: Option<Value>,
110 pub raw: Option<RawApiResponse>,
111}
112
113#[derive(Clone)]
118pub struct PredictExecutionClient {
119 pub api: PredictApiClient,
120 pub signer: PredictOrderSigner,
121 pub config: PredictExecConfig,
122 market_cache: Arc<RwLock<HashMap<i64, MarketMeta>>>,
123}
124
125impl PredictExecutionClient {
126 pub async fn new(config: PredictExecConfig) -> Result<Self> {
127 let signer = PredictOrderSigner::from_private_key(&config.private_key, config.chain_id)?;
128 let api = PredictApiClient::new_mainnet(&config.api_key)?;
129 let jwt = Self::authenticate_jwt(&api, &signer).await?;
130 let api = api.with_jwt(jwt);
131
132 Ok(Self {
133 api,
134 signer,
135 config,
136 market_cache: Arc::new(RwLock::new(HashMap::new())),
137 })
138 }
139
140 pub async fn from_env() -> Result<Self> {
141 let cfg = PredictExecConfig::from_env()?;
142 Self::new(cfg).await
143 }
144
145 pub async fn refresh_jwt(&mut self) -> Result<()> {
147 let jwt = Self::authenticate_jwt(&self.api, &self.signer).await?;
148 self.api.set_jwt(jwt);
149 Ok(())
150 }
151
152 pub async fn authenticate_jwt(
153 api: &PredictApiClient,
154 signer: &PredictOrderSigner,
155 ) -> Result<String> {
156 let auth_message = api.auth_message().await.context("GET /auth/message failed")?;
157 let message = auth_message
158 .get("data")
159 .and_then(|d| d.get("message"))
160 .and_then(|m| m.as_str())
161 .ok_or_else(|| anyhow!("missing data.message in auth response"))?;
162
163 let signature = signer.sign_auth_message(message)?;
164 let auth = api
165 .auth(&signer.address().to_string(), message, &signature)
166 .await
167 .context("POST /auth failed")?;
168
169 auth.get("data")
170 .and_then(|d| d.get("token"))
171 .and_then(|t| t.as_str())
172 .map(str::to_string)
173 .ok_or_else(|| anyhow!("missing data.token in auth response"))
174 }
175
176 pub async fn market_meta(&self, market_id: i64) -> Result<MarketMeta> {
180 {
182 let cache = self.market_cache.read().await;
183 if let Some(meta) = cache.get(&market_id) {
184 return Ok(meta.clone());
185 }
186 }
187
188 let meta = self.fetch_market_meta(market_id).await?;
190 {
191 let mut cache = self.market_cache.write().await;
192 cache.insert(market_id, meta.clone());
193 }
194 Ok(meta)
195 }
196
197 pub async fn refresh_market_meta(&self, market_id: i64) -> Result<MarketMeta> {
199 let meta = self.fetch_market_meta(market_id).await?;
200 let mut cache = self.market_cache.write().await;
201 cache.insert(market_id, meta.clone());
202 Ok(meta)
203 }
204
205 pub async fn preload_markets(&self, market_ids: &[i64]) -> Result<()> {
207 let mut tasks = Vec::new();
208 for &id in market_ids {
209 let client = self.clone();
210 tasks.push(tokio::spawn(async move { client.market_meta(id).await }));
211 }
212 for task in tasks {
213 task.await.map_err(|e| anyhow!("join error: {}", e))??;
214 }
215 Ok(())
216 }
217
218 pub async fn clear_cache(&self) {
220 self.market_cache.write().await.clear();
221 }
222
223 fn fetch_market_meta(
224 &self,
225 market_id: i64,
226 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<MarketMeta>> + Send + '_>> {
227 Box::pin(async move {
228 let market = self
229 .api
230 .get_market(market_id)
231 .await
232 .with_context(|| format!("GET /markets/{} failed", market_id))?;
233
234 let data = market
235 .get("data")
236 .ok_or_else(|| anyhow!("missing data in market response"))?;
237
238 let outcomes = data
239 .get("outcomes")
240 .and_then(|o| o.as_array())
241 .ok_or_else(|| anyhow!("missing outcomes in market {}", market_id))?;
242
243 let yes_token = extract_token_id_from_outcomes(outcomes, 1)
244 .ok_or_else(|| anyhow!("missing YES token (indexSet=1) in market {}", market_id))?;
245 let no_token = extract_token_id_from_outcomes(outcomes, 2)
246 .ok_or_else(|| anyhow!("missing NO token (indexSet=2) in market {}", market_id))?;
247
248 Ok(MarketMeta {
249 market_id,
250 yes_token_id: yes_token,
251 no_token_id: no_token,
252 fee_rate_bps: data
253 .get("feeRateBps")
254 .and_then(|v| v.as_u64())
255 .unwrap_or(0) as u32,
256 is_neg_risk: data
257 .get("isNegRisk")
258 .and_then(|v| v.as_bool())
259 .unwrap_or(false),
260 is_yield_bearing: data
261 .get("isYieldBearing")
262 .and_then(|v| v.as_bool())
263 .unwrap_or(true),
264 })
265 })
266 }
267
268 pub async fn prepare_limit_order(
273 &self,
274 req: &PredictLimitOrderRequest,
275 ) -> Result<PredictPreparedOrder> {
276 let meta = self.market_meta(req.market_id).await?;
277 self.prepare_limit_order_with_meta(req, &meta)
278 }
279
280 pub fn prepare_limit_order_with_meta(
282 &self,
283 req: &PredictLimitOrderRequest,
284 meta: &MarketMeta,
285 ) -> Result<PredictPreparedOrder> {
286 let token_id = meta.token_id(req.outcome);
287 let price_wei = wei_from_decimal(req.price_per_share)?;
288 let quantity_wei = wei_from_decimal(req.quantity)?;
289
290 let (maker_amount, taker_amount) =
291 predict_limit_order_amounts(req.side, price_wei, quantity_wei);
292
293 let maker = self.signer.address();
294 let order = PredictOrder::new_limit(
295 maker,
296 maker,
297 token_id,
298 req.side,
299 maker_amount,
300 taker_amount,
301 meta.fee_rate_bps,
302 );
303
304 let signed_order = self
305 .signer
306 .sign_order(&order, meta.is_neg_risk, meta.is_yield_bearing)
307 .context("failed to sign predict order")?;
308
309 let create_request = signed_order.to_create_order_request(
310 price_wei,
311 req.strategy,
312 req.slippage_bps,
313 Some(self.config.fill_or_kill),
314 );
315
316 Ok(PredictPreparedOrder {
317 signed_order,
318 request: create_request,
319 is_neg_risk: meta.is_neg_risk,
320 is_yield_bearing: meta.is_yield_bearing,
321 })
322 }
323
324 pub async fn submit_prepared_order(
329 &self,
330 prepared: PredictPreparedOrder,
331 ) -> Result<PredictSubmitResult> {
332 if !self.config.live_execution {
333 return Ok(PredictSubmitResult {
334 prepared,
335 submitted: false,
336 response: None,
337 raw: None,
338 });
339 }
340
341 let body = serde_json::to_value(&prepared.request)
342 .context("failed to serialize create-order request")?;
343
344 let raw = self
345 .api
346 .raw_post("/orders", &[], body, true)
347 .await
348 .context("POST /orders failed")?;
349
350 let response = raw.json.clone();
351
352 Ok(PredictSubmitResult {
353 prepared,
354 submitted: true,
355 response,
356 raw: Some(raw),
357 })
358 }
359
360 pub async fn place_limit_order(
362 &self,
363 req: &PredictLimitOrderRequest,
364 ) -> Result<PredictSubmitResult> {
365 let prepared = self.prepare_limit_order(req).await?;
366 self.submit_prepared_order(prepared).await
367 }
368
369 pub async fn remove_order_ids(&self, ids: &[String]) -> Result<RawApiResponse> {
371 if !self.config.live_execution {
372 return Ok(RawApiResponse {
373 status: reqwest::StatusCode::OK,
374 json: Some(json!({"success": true, "dryRun": true})),
375 });
376 }
377
378 let body = json!({ "data": { "ids": ids } });
379 self.api
380 .raw_post("/orders/remove", &[], body, true)
381 .await
382 .context("POST /orders/remove failed")
383 }
384}
385
386fn extract_token_id_from_outcomes(outcomes: &[Value], index_set: u64) -> Option<String> {
387 outcomes
388 .iter()
389 .find(|o| o.get("indexSet").and_then(|v| v.as_u64()) == Some(index_set))
390 .and_then(|o| o.get("onChainId"))
391 .and_then(|v| v.as_str())
392 .map(str::to_string)
393}
394
395fn wei_from_decimal(value: f64) -> Result<alloy_primitives::U256> {
396 if !value.is_finite() || value <= 0.0 {
397 return Err(anyhow!("invalid decimal value {}, expected > 0", value));
398 }
399
400 let scaled = (value * 1e18_f64).round();
401 if scaled <= 0.0 {
402 return Err(anyhow!("value too small after scaling: {}", value));
403 }
404
405 Ok(alloy_primitives::U256::from(scaled as u128))
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wei_from_decimal() {
        let v = wei_from_decimal(0.1).unwrap();
        assert_eq!(v.to_string(), "100000000000000000");

        let v = wei_from_decimal(1.0).unwrap();
        assert_eq!(v.to_string(), "1000000000000000000");

        assert!(wei_from_decimal(0.0).is_err());
        assert!(wei_from_decimal(-1.0).is_err());
    }

    #[test]
    fn wei_from_decimal_scales_mixed_values() {
        assert_eq!(
            wei_from_decimal(2.5).unwrap().to_string(),
            "2500000000000000000"
        );
    }

    #[test]
    fn wei_from_decimal_rejects_non_finite_and_sub_unit_values() {
        assert!(wei_from_decimal(f64::NAN).is_err());
        assert!(wei_from_decimal(f64::INFINITY).is_err());
        assert!(wei_from_decimal(f64::NEG_INFINITY).is_err());
        // 1e-19 is a tenth of one 1e-18 unit, so it rounds to zero.
        assert!(wei_from_decimal(1e-19).is_err());
    }

    #[test]
    fn test_extract_token_id() {
        let outcomes = vec![
            serde_json::json!({"indexSet": 1, "onChainId": "yes_token"}),
            serde_json::json!({"indexSet": 2, "onChainId": "no_token"}),
        ];

        assert_eq!(
            extract_token_id_from_outcomes(&outcomes, 1).unwrap(),
            "yes_token"
        );
        assert_eq!(
            extract_token_id_from_outcomes(&outcomes, 2).unwrap(),
            "no_token"
        );
        assert!(extract_token_id_from_outcomes(&outcomes, 3).is_none());
    }

    #[test]
    fn extract_token_id_handles_malformed_outcomes() {
        assert!(extract_token_id_from_outcomes(&[], 1).is_none());

        let outcomes = vec![
            // Matching indexSet but no onChainId at all.
            serde_json::json!({"indexSet": 1}),
            // onChainId present but not a string.
            serde_json::json!({"indexSet": 2, "onChainId": 42}),
        ];
        assert!(extract_token_id_from_outcomes(&outcomes, 1).is_none());
        assert!(extract_token_id_from_outcomes(&outcomes, 2).is_none());
    }

    #[test]
    fn market_meta_token_lookup() {
        let meta = MarketMeta {
            market_id: 123,
            yes_token_id: "yes_abc".into(),
            no_token_id: "no_xyz".into(),
            fee_rate_bps: 200,
            is_neg_risk: false,
            is_yield_bearing: true,
        };
        assert_eq!(meta.token_id(PredictOutcome::Yes), "yes_abc");
        assert_eq!(meta.token_id(PredictOutcome::No), "no_xyz");
    }
}