1use std::collections::HashMap;
13use std::sync::Arc;
14
15use anyhow::{anyhow, Context, Result};
16use serde_json::{json, Value};
17use tokio::sync::RwLock;
18
19use crate::api::{PredictApiClient, RawApiResponse};
20use crate::order::{
21 predict_limit_order_amounts, PredictCreateOrderRequest, PredictOrder, PredictOrderSigner,
22 PredictOutcome, PredictSide, PredictStrategy, SignedPredictOrder, BNB_MAINNET_CHAIN_ID,
23};
24
/// Per-market data needed to build and sign Predict orders.
///
/// Produced by `PredictExecutionClient::fetch_market_meta` from the `GET /markets/{id}`
/// response and cached per `market_id` in `PredictExecutionClient::market_cache`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MarketMeta {
    /// Numeric market identifier used in API paths.
    pub market_id: i64,
    /// `onChainId` of the outcome whose `indexSet` is 1 (YES).
    pub yes_token_id: String,
    /// `onChainId` of the outcome whose `indexSet` is 2 (NO).
    pub no_token_id: String,
    /// Fee rate in basis points, passed to `PredictOrder::new_limit`.
    pub fee_rate_bps: u32,
    /// Neg-risk flag, passed to the order signer.
    pub is_neg_risk: bool,
    /// Yield-bearing flag, passed to the order signer.
    pub is_yield_bearing: bool,
}
35
36impl MarketMeta {
37 pub fn token_id(&self, outcome: PredictOutcome) -> &str {
39 match outcome {
40 PredictOutcome::Yes => &self.yes_token_id,
41 PredictOutcome::No => &self.no_token_id,
42 }
43 }
44}
45
/// Settings for `PredictExecutionClient`, usually loaded via `PredictExecConfig::from_env`.
///
/// `Debug` is implemented by hand (not derived) so that `api_key` and `private_key`
/// are never written to logs or panic messages.
#[derive(Clone)]
pub struct PredictExecConfig {
    /// API key passed to `PredictApiClient::new_mainnet`.
    pub api_key: String,
    /// Key material passed to `PredictOrderSigner::from_private_key`.
    pub private_key: String,
    /// Chain id passed to the order signer.
    pub chain_id: u64,
    /// When false, order submission and removal are dry runs that send nothing.
    pub live_execution: bool,
    /// Forwarded to every create-order request.
    pub fill_or_kill: bool,
}

impl std::fmt::Debug for PredictExecConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced with a placeholder; everything else is shown as-is.
        f.debug_struct("PredictExecConfig")
            .field("api_key", &"<redacted>")
            .field("private_key", &"<redacted>")
            .field("chain_id", &self.chain_id)
            .field("live_execution", &self.live_execution)
            .field("fill_or_kill", &self.fill_or_kill)
            .finish()
    }
}
54
55impl PredictExecConfig {
56 pub fn from_env() -> Result<Self> {
57 let api_key = std::env::var("PREDICT_API_KEY")
58 .context("PREDICT_API_KEY is required for Predict execution")?;
59 let private_key = std::env::var("PREDICT_PRIVATE_KEY")
60 .or_else(|_| std::env::var("PREDICT_TEST_PRIVATE_KEY"))
61 .context("PREDICT_PRIVATE_KEY (or PREDICT_TEST_PRIVATE_KEY) is required")?;
62
63 let live_execution = std::env::var("PREDICT_LIVE_EXECUTION")
64 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
65 .unwrap_or(false);
66
67 let fill_or_kill = std::env::var("PREDICT_FILL_OR_KILL")
68 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
69 .unwrap_or(true);
70
71 let chain_id = std::env::var("PREDICT_CHAIN_ID")
72 .ok()
73 .and_then(|v| v.parse::<u64>().ok())
74 .unwrap_or(BNB_MAINNET_CHAIN_ID);
75
76 Ok(Self {
77 api_key,
78 private_key,
79 chain_id,
80 live_execution,
81 fill_or_kill,
82 })
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct PredictLimitOrderRequest {
88 pub market_id: i64,
89 pub outcome: PredictOutcome,
90 pub side: PredictSide,
91 pub price_per_share: f64,
92 pub quantity: f64,
93 pub strategy: PredictStrategy,
94 pub slippage_bps: Option<u32>,
95}
96
97#[derive(Debug, Clone)]
98pub struct PredictPreparedOrder {
99 pub signed_order: SignedPredictOrder,
100 pub request: PredictCreateOrderRequest,
101 pub is_neg_risk: bool,
102 pub is_yield_bearing: bool,
103}
104
105#[derive(Debug, Clone)]
106pub struct PredictSubmitResult {
107 pub prepared: PredictPreparedOrder,
108 pub submitted: bool,
109 pub response: Option<Value>,
110 pub raw: Option<RawApiResponse>,
111}
112
113#[derive(Clone)]
118pub struct PredictExecutionClient {
119 pub api: PredictApiClient,
120 pub signer: PredictOrderSigner,
121 pub config: PredictExecConfig,
122 pub market_cache: Arc<RwLock<HashMap<i64, MarketMeta>>>,
124}
125
126impl PredictExecutionClient {
127 pub async fn new(config: PredictExecConfig) -> Result<Self> {
128 let signer = PredictOrderSigner::from_private_key(&config.private_key, config.chain_id)?;
129 let api = PredictApiClient::new_mainnet(&config.api_key)?;
130 let jwt = Self::authenticate_jwt(&api, &signer).await?;
131 let api = api.with_jwt(jwt);
132
133 Ok(Self {
134 api,
135 signer,
136 config,
137 market_cache: Arc::new(RwLock::new(HashMap::new())),
138 })
139 }
140
141 pub async fn from_env() -> Result<Self> {
142 let cfg = PredictExecConfig::from_env()?;
143 Self::new(cfg).await
144 }
145
146 pub async fn refresh_jwt(&mut self) -> Result<()> {
148 let jwt = Self::authenticate_jwt(&self.api, &self.signer).await?;
149 self.api.set_jwt(jwt);
150 Ok(())
151 }
152
153 pub async fn authenticate_jwt(
154 api: &PredictApiClient,
155 signer: &PredictOrderSigner,
156 ) -> Result<String> {
157 let auth_message = api.auth_message().await.context("GET /auth/message failed")?;
158 let message = auth_message
159 .get("data")
160 .and_then(|d| d.get("message"))
161 .and_then(|m| m.as_str())
162 .ok_or_else(|| anyhow!("missing data.message in auth response"))?;
163
164 let signature = signer.sign_auth_message(message)?;
165 let auth = api
166 .auth(&signer.address().to_string(), message, &signature)
167 .await
168 .context("POST /auth failed")?;
169
170 auth.get("data")
171 .and_then(|d| d.get("token"))
172 .and_then(|t| t.as_str())
173 .map(str::to_string)
174 .ok_or_else(|| anyhow!("missing data.token in auth response"))
175 }
176
177 pub async fn market_meta(&self, market_id: i64) -> Result<MarketMeta> {
181 {
183 let cache = self.market_cache.read().await;
184 if let Some(meta) = cache.get(&market_id) {
185 return Ok(meta.clone());
186 }
187 }
188
189 let meta = self.fetch_market_meta(market_id).await?;
191 {
192 let mut cache = self.market_cache.write().await;
193 cache.insert(market_id, meta.clone());
194 }
195 Ok(meta)
196 }
197
198 pub async fn refresh_market_meta(&self, market_id: i64) -> Result<MarketMeta> {
200 let meta = self.fetch_market_meta(market_id).await?;
201 let mut cache = self.market_cache.write().await;
202 cache.insert(market_id, meta.clone());
203 Ok(meta)
204 }
205
206 pub async fn preload_markets(&self, market_ids: &[i64]) -> Result<()> {
208 let mut tasks = Vec::new();
209 for &id in market_ids {
210 let client = self.clone();
211 tasks.push(tokio::spawn(async move { client.market_meta(id).await }));
212 }
213 for task in tasks {
214 task.await.map_err(|e| anyhow!("join error: {}", e))??;
215 }
216 Ok(())
217 }
218
219 pub async fn clear_cache(&self) {
221 self.market_cache.write().await.clear();
222 }
223
224 fn fetch_market_meta(
225 &self,
226 market_id: i64,
227 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<MarketMeta>> + Send + '_>> {
228 Box::pin(async move {
229 let market = self
230 .api
231 .get_market(market_id)
232 .await
233 .with_context(|| format!("GET /markets/{} failed", market_id))?;
234
235 let data = market
236 .get("data")
237 .ok_or_else(|| anyhow!("missing data in market response"))?;
238
239 let outcomes = data
240 .get("outcomes")
241 .and_then(|o| o.as_array())
242 .ok_or_else(|| anyhow!("missing outcomes in market {}", market_id))?;
243
244 let yes_token = extract_token_id_from_outcomes(outcomes, 1)
245 .ok_or_else(|| anyhow!("missing YES token (indexSet=1) in market {}", market_id))?;
246 let no_token = extract_token_id_from_outcomes(outcomes, 2)
247 .ok_or_else(|| anyhow!("missing NO token (indexSet=2) in market {}", market_id))?;
248
249 Ok(MarketMeta {
250 market_id,
251 yes_token_id: yes_token,
252 no_token_id: no_token,
253 fee_rate_bps: data
254 .get("feeRateBps")
255 .and_then(|v| v.as_u64())
256 .unwrap_or(0) as u32,
257 is_neg_risk: data
258 .get("isNegRisk")
259 .and_then(|v| v.as_bool())
260 .unwrap_or(false),
261 is_yield_bearing: data
262 .get("isYieldBearing")
263 .and_then(|v| v.as_bool())
264 .unwrap_or(true),
265 })
266 })
267 }
268
269 pub async fn prepare_limit_order(
274 &self,
275 req: &PredictLimitOrderRequest,
276 ) -> Result<PredictPreparedOrder> {
277 let meta = self.market_meta(req.market_id).await?;
278 self.prepare_limit_order_with_meta(req, &meta)
279 }
280
281 pub fn prepare_limit_order_with_meta(
283 &self,
284 req: &PredictLimitOrderRequest,
285 meta: &MarketMeta,
286 ) -> Result<PredictPreparedOrder> {
287 let token_id = meta.token_id(req.outcome);
288 let price_wei = wei_from_decimal(req.price_per_share)?;
289 let quantity_wei = wei_from_decimal(req.quantity)?;
290
291 let (maker_amount, taker_amount) =
292 predict_limit_order_amounts(req.side, price_wei, quantity_wei);
293
294 let maker = self.signer.address();
295 let order = PredictOrder::new_limit(
296 maker,
297 maker,
298 token_id,
299 req.side,
300 maker_amount,
301 taker_amount,
302 meta.fee_rate_bps,
303 );
304
305 let signed_order = self
306 .signer
307 .sign_order(&order, meta.is_neg_risk, meta.is_yield_bearing)
308 .context("failed to sign predict order")?;
309
310 let create_request = signed_order.to_create_order_request(
311 price_wei,
312 req.strategy,
313 req.slippage_bps,
314 Some(self.config.fill_or_kill),
315 );
316
317 Ok(PredictPreparedOrder {
318 signed_order,
319 request: create_request,
320 is_neg_risk: meta.is_neg_risk,
321 is_yield_bearing: meta.is_yield_bearing,
322 })
323 }
324
325 pub async fn submit_prepared_order(
330 &self,
331 prepared: PredictPreparedOrder,
332 ) -> Result<PredictSubmitResult> {
333 if !self.config.live_execution {
334 return Ok(PredictSubmitResult {
335 prepared,
336 submitted: false,
337 response: None,
338 raw: None,
339 });
340 }
341
342 let body = serde_json::to_value(&prepared.request)
343 .context("failed to serialize create-order request")?;
344
345 let raw = self
346 .api
347 .raw_post("/orders", &[], body, true)
348 .await
349 .context("POST /orders failed")?;
350
351 let response = raw.json.clone();
352
353 Ok(PredictSubmitResult {
354 prepared,
355 submitted: true,
356 response,
357 raw: Some(raw),
358 })
359 }
360
361 pub async fn place_limit_order(
363 &self,
364 req: &PredictLimitOrderRequest,
365 ) -> Result<PredictSubmitResult> {
366 let prepared = self.prepare_limit_order(req).await?;
367 self.submit_prepared_order(prepared).await
368 }
369
370 pub async fn remove_order_ids(&self, ids: &[String]) -> Result<RawApiResponse> {
372 if !self.config.live_execution {
373 return Ok(RawApiResponse {
374 status: reqwest::StatusCode::OK,
375 json: Some(json!({"success": true, "dryRun": true})),
376 });
377 }
378
379 let body = json!({ "data": { "ids": ids } });
380 self.api
381 .raw_post("/orders/remove", &[], body, true)
382 .await
383 .context("POST /orders/remove failed")
384 }
385}
386
387fn extract_token_id_from_outcomes(outcomes: &[Value], index_set: u64) -> Option<String> {
388 outcomes
389 .iter()
390 .find(|o| o.get("indexSet").and_then(|v| v.as_u64()) == Some(index_set))
391 .and_then(|o| o.get("onChainId"))
392 .and_then(|v| v.as_str())
393 .map(str::to_string)
394}
395
396fn wei_from_decimal(value: f64) -> Result<alloy_primitives::U256> {
397 if !value.is_finite() || value <= 0.0 {
398 return Err(anyhow!("invalid decimal value {}, expected > 0", value));
399 }
400
401 let scaled = (value * 1e18_f64).round();
402 if scaled <= 0.0 {
403 return Err(anyhow!("value too small after scaling: {}", value));
404 }
405
406 Ok(alloy_primitives::U256::from(scaled as u128))
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wei_from_decimal() {
        let accepted = [(0.1, "100000000000000000"), (1.0, "1000000000000000000")];
        for (input, expected) in accepted {
            let wei = wei_from_decimal(input).unwrap();
            assert_eq!(wei.to_string(), expected);
        }

        for rejected in [0.0, -1.0] {
            assert!(wei_from_decimal(rejected).is_err());
        }
    }

    #[test]
    fn test_extract_token_id() {
        let outcomes = [
            serde_json::json!({"indexSet": 1, "onChainId": "yes_token"}),
            serde_json::json!({"indexSet": 2, "onChainId": "no_token"}),
        ];

        let yes = extract_token_id_from_outcomes(&outcomes, 1);
        let no = extract_token_id_from_outcomes(&outcomes, 2);
        assert_eq!(yes.as_deref(), Some("yes_token"));
        assert_eq!(no.as_deref(), Some("no_token"));
        assert!(extract_token_id_from_outcomes(&outcomes, 3).is_none());
    }

    #[test]
    fn market_meta_token_lookup() {
        let meta = MarketMeta {
            market_id: 123,
            yes_token_id: String::from("yes_abc"),
            no_token_id: String::from("no_xyz"),
            fee_rate_bps: 200,
            is_neg_risk: false,
            is_yield_bearing: true,
        };

        for (outcome, expected) in [(PredictOutcome::Yes, "yes_abc"), (PredictOutcome::No, "no_xyz")] {
            assert_eq!(meta.token_id(outcome), expected);
        }
    }
}