predict-sdk 0.1.0

Rust SDK for Predict.fun prediction market - order building, EIP-712 signing, and real-time WebSocket data
Documentation
//! GraphQL client for Predict.fun
//!
//! This module provides access to Predict's GraphQL API for fetching
//! category data including strike prices (`startPrice`).

use crate::errors::{Error, Result};
use chrono::{DateTime, Utc};
use reqwest::Client as HttpClient;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use tracing::debug;

/// Default GraphQL endpoint, used by [`PredictGraphQL::new`].
pub const DEFAULT_GRAPHQL_ENDPOINT: &str = "https://graphql.predict.fun/graphql";

/// GraphQL client for Predict.fun
///
/// Bundles the HTTP client used to send requests with the URL of the GraphQL
/// endpoint they are posted to. `Debug` and `Clone` are derived so the client
/// can be logged and handed to several tasks; `reqwest::Client` is documented
/// as internally reference-counted, so a clone reuses the same connection pool.
#[derive(Debug, Clone)]
pub struct PredictGraphQL {
    /// HTTP client used to send GraphQL requests.
    http_client: HttpClient,
    /// URL of the GraphQL endpoint requests are posted to.
    endpoint: String,
}

impl PredictGraphQL {
    /// Create a new GraphQL client that targets [`DEFAULT_GRAPHQL_ENDPOINT`].
    pub fn new() -> Self {
        Self::with_endpoint(DEFAULT_GRAPHQL_ENDPOINT.to_string())
    }

    /// Create a new GraphQL client with a custom endpoint.
    ///
    /// # Arguments
    ///
    /// * `endpoint` - Full URL of the GraphQL endpoint requests are posted to
    pub fn with_endpoint(endpoint: String) -> Self {
        Self {
            http_client: HttpClient::new(),
            endpoint,
        }
    }

    /// Fetch category data including market strike prices
    ///
    /// # Arguments
    ///
    /// * `slug` - The category slug (e.g., "btc-usd-up-down-2026-01-29-11-00-15-minutes")
    ///
    /// # Returns
    ///
    /// Category data including `marketData` with `startPrice` (strike price)
    ///
    /// # Errors
    ///
    /// * Propagates the error from sending the request or decoding the JSON body.
    /// * Returns [`Error::ApiError`] when the HTTP status is not a success, when the
    ///   response carries a non-empty `errors` list, or when the response contains
    ///   no category.
    pub async fn get_category(&self, slug: &str) -> Result<CategoryData> {
        let request = GraphQLRequest {
            query: GET_CATEGORY_QUERY.to_string(),
            variables: GetCategoryVariables {
                category_id: slug.to_string(),
            },
            operation_name: "GetCategory".to_string(),
        };

        debug!("Fetching category via GraphQL: {}", slug);

        let response = self
            .http_client
            .post(&self.endpoint)
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: the body is only used to enrich the error message.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(Error::ApiError(format!(
                "GraphQL request failed: status={}, error={}",
                status, error_text
            )));
        }

        let gql_response: GraphQLResponse<GetCategoryResponse> = response.json().await?;

        // Only a non-empty `errors` list counts as a failure; an empty array must
        // not mask valid `data` behind an error with a blank message.
        if let Some(errors) = gql_response.errors.filter(|errors| !errors.is_empty()) {
            let messages: Vec<&str> = errors.iter().map(|e| e.message.as_str()).collect();
            return Err(Error::ApiError(format!(
                "GraphQL errors: {}",
                messages.join(", ")
            )));
        }

        gql_response
            .data
            .and_then(|d| d.category)
            .ok_or_else(|| Error::ApiError(format!("Category not found: {}", slug)))
    }

    /// Fetch category and extract the strike price for a specific market
    ///
    /// # Arguments
    ///
    /// * `slug` - The category slug
    /// * `market_id` - Optional market ID to filter (if None, returns first market's data)
    ///
    /// # Returns
    ///
    /// The market data including strike price, or None if the category has no
    /// market data or no market matches `market_id`
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`PredictGraphQL::get_category`].
    pub async fn get_market_strike_price(
        &self,
        slug: &str,
        market_id: Option<u64>,
    ) -> Result<Option<MarketData>> {
        let category = self.get_category(slug).await?;

        // A category without market data behaves like one with an empty list.
        let mut markets = category.market_data.unwrap_or_default().into_iter();

        Ok(match market_id {
            Some(id) => {
                // Market IDs are strings in the response; compare against the
                // decimal rendering of the requested numeric ID.
                let id_str = id.to_string();
                markets.find(|m| m.market_id == id_str)
            }
            None => markets.next(),
        })
    }
}

impl Default for PredictGraphQL {
    fn default() -> Self {
        Self::new()
    }
}

// GraphQL document for the `GetCategory` operation.
//
// Looks a category up by the `$categoryId` variable and selects its common
// fields; the per-market `marketData` price fields are only selected through
// the `CryptoUpDownCategory` inline fragment.
const GET_CATEGORY_QUERY: &str = r#"
query GetCategory($categoryId: ID!) {
  category(id: $categoryId) {
    id
    slug
    title
    startsAt
    endsAt
    status
    isNegRisk
    isYieldBearing
    ... on CryptoUpDownCategory {
      marketData {
        marketId
        priceFeedId
        startPrice
        startPricePublishTime
        endPrice
        endPricePublishTime
      }
    }
  }
}
"#;

// Request/Response types

/// JSON body of a GraphQL request.
///
/// Serialized with camelCase keys, so `operation_name` is sent as `operationName`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GraphQLRequest {
    /// The GraphQL query document.
    query: String,
    /// Values for the variables declared by the query.
    variables: GetCategoryVariables,
    /// Name of the operation in `query` to execute.
    operation_name: String,
}

/// Variables for the `GetCategory` query.
///
/// Serialized with camelCase keys, so `category_id` is sent as `categoryId`,
/// matching the `$categoryId` variable declared in the query.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GetCategoryVariables {
    /// Identifier of the category to fetch.
    category_id: String,
}

/// Envelope of a GraphQL response, generic over the shape of its `data` payload.
#[derive(Debug, Deserialize)]
struct GraphQLResponse<T> {
    /// Result payload, if the response contained one.
    data: Option<T>,
    /// Errors reported in the response, if any.
    errors: Option<Vec<GraphQLError>>,
}

/// A single entry of a response's `errors` list; only the message is read.
#[derive(Debug, Deserialize)]
struct GraphQLError {
    /// Error description, surfaced in the resulting API error.
    message: String,
}

/// `data` payload of the `GetCategory` query.
#[derive(Debug, Deserialize)]
struct GetCategoryResponse {
    /// The requested category; `None` is reported to callers as "Category not found".
    category: Option<CategoryData>,
}

/// Category data from GraphQL API
///
/// Deserialized from the `category` object of the `GetCategory` query; JSON keys
/// are camelCase (e.g. `startsAt` maps to `starts_at`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CategoryData {
    /// Category ID (same as slug)
    pub id: String,
    /// Category slug
    pub slug: String,
    /// Category title
    pub title: String,
    /// Market start time
    pub starts_at: DateTime<Utc>,
    /// Market end time
    pub ends_at: DateTime<Utc>,
    /// Category status (OPEN, RESOLVED, etc.)
    pub status: String,
    /// Whether this is a neg-risk market
    pub is_neg_risk: bool,
    /// Whether yield bearing is enabled
    pub is_yield_bearing: bool,
    /// Market data with strike prices (only for CryptoUpDownCategory)
    ///
    /// The query requests `marketData` solely through the `CryptoUpDownCategory`
    /// fragment, so this is `None` whenever the response does not include it.
    pub market_data: Option<Vec<MarketData>>,
}

/// Market data including strike price from Pyth oracle
///
/// Deserialized from one entry of the category's `marketData` list; JSON keys
/// are camelCase (e.g. `startPrice` maps to `start_price`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MarketData {
    /// Market ID
    ///
    /// Kept as a string; `PredictGraphQL::get_market_strike_price` matches it
    /// against the decimal rendering of the numeric ID it is given.
    pub market_id: String,
    /// Price feed ID (e.g., "1" for BTC/USD)
    pub price_feed_id: String,
    /// Starting price / Strike price (Pyth)
    pub start_price: f64,
    /// When the start price was published
    pub start_price_publish_time: DateTime<Utc>,
    /// Ending price (populated after resolution)
    pub end_price: Option<f64>,
    /// When the end price was published
    pub end_price_publish_time: Option<DateTime<Utc>>,
}

impl MarketData {
    /// Convert `start_price` into a [`Decimal`].
    ///
    /// The conversion is fallible; when it fails, the `Decimal` default value is
    /// returned instead of an error, so callers cannot tell a failed conversion
    /// apart from a start price equal to that default.
    pub fn strike_price(&self) -> Decimal {
        match Decimal::try_from(self.start_price) {
            Ok(strike) => strike,
            Err(_) => Decimal::default(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A client built with `new()` points at the default endpoint.
    #[test]
    fn test_graphql_client_creation() {
        let default_client = PredictGraphQL::new();
        assert_eq!(default_client.endpoint.as_str(), DEFAULT_GRAPHQL_ENDPOINT);
    }

    /// A client built with `with_endpoint()` keeps the URL it was given.
    #[test]
    fn test_graphql_client_custom_endpoint() {
        let custom_url = "https://custom.endpoint.com/graphql";
        let custom_client = PredictGraphQL::with_endpoint(custom_url.to_string());
        assert_eq!(custom_client.endpoint, custom_url);
    }
}