Skip to main content

predict_sdk/
graphql.rs

1//! GraphQL client for Predict.fun
2//!
3//! This module provides access to Predict's GraphQL API for fetching
4//! category data including strike prices (`startPrice`).
5
6use crate::errors::{Error, Result};
7use chrono::{DateTime, Utc};
8use reqwest::Client as HttpClient;
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11use tracing::debug;
12
/// GraphQL endpoint used by [`PredictGraphQL::new`] and [`PredictGraphQL::default`].
pub const DEFAULT_GRAPHQL_ENDPOINT: &str = "https://graphql.predict.fun/graphql";
15
16/// GraphQL client for Predict.fun
17pub struct PredictGraphQL {
18    http_client: HttpClient,
19    endpoint: String,
20}
21
22impl PredictGraphQL {
23    /// Create a new GraphQL client with default endpoint
24    pub fn new() -> Self {
25        Self {
26            http_client: HttpClient::new(),
27            endpoint: DEFAULT_GRAPHQL_ENDPOINT.to_string(),
28        }
29    }
30
31    /// Create a new GraphQL client with custom endpoint
32    pub fn with_endpoint(endpoint: String) -> Self {
33        Self {
34            http_client: HttpClient::new(),
35            endpoint,
36        }
37    }
38
39    /// Fetch category data including market strike prices
40    ///
41    /// # Arguments
42    ///
43    /// * `slug` - The category slug (e.g., "btc-usd-up-down-2026-01-29-11-00-15-minutes")
44    ///
45    /// # Returns
46    ///
47    /// Category data including `marketData` with `startPrice` (strike price)
48    pub async fn get_category(&self, slug: &str) -> Result<CategoryData> {
49        let query = GraphQLRequest {
50            query: GET_CATEGORY_QUERY.to_string(),
51            variables: GetCategoryVariables {
52                category_id: slug.to_string(),
53            },
54            operation_name: "GetCategory".to_string(),
55        };
56
57        debug!("Fetching category via GraphQL: {}", slug);
58
59        let response = self
60            .http_client
61            .post(&self.endpoint)
62            .json(&query)
63            .send()
64            .await?;
65
66        let status = response.status();
67        if !status.is_success() {
68            let error_text = response
69                .text()
70                .await
71                .unwrap_or_else(|_| "Unknown error".to_string());
72            return Err(Error::ApiError(format!(
73                "GraphQL request failed: status={}, error={}",
74                status, error_text
75            )));
76        }
77
78        let gql_response: GraphQLResponse<GetCategoryResponse> = response.json().await?;
79
80        if let Some(errors) = gql_response.errors {
81            let error_messages: Vec<String> = errors.iter().map(|e| e.message.clone()).collect();
82            return Err(Error::ApiError(format!(
83                "GraphQL errors: {}",
84                error_messages.join(", ")
85            )));
86        }
87
88        gql_response
89            .data
90            .and_then(|d| d.category)
91            .ok_or_else(|| Error::ApiError(format!("Category not found: {}", slug)))
92    }
93
94    /// Fetch category and extract the strike price for a specific market
95    ///
96    /// # Arguments
97    ///
98    /// * `slug` - The category slug
99    /// * `market_id` - Optional market ID to filter (if None, returns first market's data)
100    ///
101    /// # Returns
102    ///
103    /// The market data including strike price, or None if not found
104    pub async fn get_market_strike_price(
105        &self,
106        slug: &str,
107        market_id: Option<u64>,
108    ) -> Result<Option<MarketData>> {
109        let category = self.get_category(slug).await?;
110
111        if let Some(market_data) = category.market_data {
112            if let Some(id) = market_id {
113                // Find specific market
114                let id_str = id.to_string();
115                return Ok(market_data.into_iter().find(|m| m.market_id == id_str));
116            } else {
117                // Return first market
118                return Ok(market_data.into_iter().next());
119            }
120        }
121
122        Ok(None)
123    }
124}
125
126impl Default for PredictGraphQL {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
/// GraphQL document sent by `PredictGraphQL::get_category` (operation `GetCategory`).
///
/// `$categoryId` is bound to the category slug. The `marketData` selection sits inside
/// an inline fragment on `CryptoUpDownCategory`, so it only appears in the response
/// when the category has that concrete type; otherwise `CategoryData::market_data`
/// deserializes as `None`.
const GET_CATEGORY_QUERY: &str = r#"
query GetCategory($categoryId: ID!) {
  category(id: $categoryId) {
    id
    slug
    title
    startsAt
    endsAt
    status
    isNegRisk
    isYieldBearing
    ... on CryptoUpDownCategory {
      marketData {
        marketId
        priceFeedId
        startPrice
        startPricePublishTime
        endPrice
        endPricePublishTime
      }
    }
  }
}
"#;
157
158// Request/Response types
159
160#[derive(Debug, Serialize)]
161#[serde(rename_all = "camelCase")]
162struct GraphQLRequest {
163    query: String,
164    variables: GetCategoryVariables,
165    operation_name: String,
166}
167
168#[derive(Debug, Serialize)]
169#[serde(rename_all = "camelCase")]
170struct GetCategoryVariables {
171    category_id: String,
172}
173
174#[derive(Debug, Deserialize)]
175struct GraphQLResponse<T> {
176    data: Option<T>,
177    errors: Option<Vec<GraphQLError>>,
178}
179
180#[derive(Debug, Deserialize)]
181struct GraphQLError {
182    message: String,
183}
184
185#[derive(Debug, Deserialize)]
186struct GetCategoryResponse {
187    category: Option<CategoryData>,
188}
189
190/// Category data from GraphQL API
191#[derive(Debug, Clone, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct CategoryData {
194    /// Category ID (same as slug)
195    pub id: String,
196    /// Category slug
197    pub slug: String,
198    /// Category title
199    pub title: String,
200    /// Market start time
201    pub starts_at: DateTime<Utc>,
202    /// Market end time
203    pub ends_at: DateTime<Utc>,
204    /// Category status (OPEN, RESOLVED, etc.)
205    pub status: String,
206    /// Whether this is a neg-risk market
207    pub is_neg_risk: bool,
208    /// Whether yield bearing is enabled
209    pub is_yield_bearing: bool,
210    /// Market data with strike prices (only for CryptoUpDownCategory)
211    pub market_data: Option<Vec<MarketData>>,
212}
213
214/// Market data including strike price from Pyth oracle
215#[derive(Debug, Clone, Deserialize)]
216#[serde(rename_all = "camelCase")]
217pub struct MarketData {
218    /// Market ID
219    pub market_id: String,
220    /// Price feed ID (e.g., "1" for BTC/USD)
221    pub price_feed_id: String,
222    /// Starting price / Strike price (Pyth)
223    pub start_price: f64,
224    /// When the start price was published
225    pub start_price_publish_time: DateTime<Utc>,
226    /// Ending price (populated after resolution)
227    pub end_price: Option<f64>,
228    /// When the end price was published
229    pub end_price_publish_time: Option<DateTime<Utc>>,
230}
231
232impl MarketData {
233    /// Get the strike price as a Decimal
234    pub fn strike_price(&self) -> Decimal {
235        Decimal::try_from(self.start_price).unwrap_or_default()
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must point the client at the crate's default endpoint.
    #[test]
    fn test_graphql_client_creation() {
        let client = PredictGraphQL::new();
        assert_eq!(client.endpoint.as_str(), DEFAULT_GRAPHQL_ENDPOINT);
    }

    /// `with_endpoint()` must store the caller-supplied URL verbatim.
    #[test]
    fn test_graphql_client_custom_endpoint() {
        let custom = "https://custom.endpoint.com/graphql";
        let client = PredictGraphQL::with_endpoint(String::from(custom));
        assert_eq!(client.endpoint, custom);
    }
}