use crate::errors::{Error, Result};
use chrono::{DateTime, Utc};
use reqwest::Client as HttpClient;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use tracing::debug;
/// GraphQL endpoint used by `PredictGraphQL::new` (and therefore by its `Default` impl).
pub const DEFAULT_GRAPHQL_ENDPOINT: &str = "https://graphql.predict.fun/graphql";

/// Client for the predict.fun GraphQL API.
///
/// Holds the HTTP client used to send requests and the endpoint URL they are POSTed to.
/// Cloning is cheap enough to share between tasks: `reqwest::Client` is reference-counted
/// internally, so clones reuse the same connection pool.
#[derive(Debug, Clone)]
pub struct PredictGraphQL {
    http_client: HttpClient,
    endpoint: String,
}
impl PredictGraphQL {
pub fn new() -> Self {
Self {
http_client: HttpClient::new(),
endpoint: DEFAULT_GRAPHQL_ENDPOINT.to_string(),
}
}
pub fn with_endpoint(endpoint: String) -> Self {
Self {
http_client: HttpClient::new(),
endpoint,
}
}
pub async fn get_category(&self, slug: &str) -> Result<CategoryData> {
let query = GraphQLRequest {
query: GET_CATEGORY_QUERY.to_string(),
variables: GetCategoryVariables {
category_id: slug.to_string(),
},
operation_name: "GetCategory".to_string(),
};
debug!("Fetching category via GraphQL: {}", slug);
let response = self
.http_client
.post(&self.endpoint)
.json(&query)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(Error::ApiError(format!(
"GraphQL request failed: status={}, error={}",
status, error_text
)));
}
let gql_response: GraphQLResponse<GetCategoryResponse> = response.json().await?;
if let Some(errors) = gql_response.errors {
let error_messages: Vec<String> = errors.iter().map(|e| e.message.clone()).collect();
return Err(Error::ApiError(format!(
"GraphQL errors: {}",
error_messages.join(", ")
)));
}
gql_response
.data
.and_then(|d| d.category)
.ok_or_else(|| Error::ApiError(format!("Category not found: {}", slug)))
}
pub async fn get_market_strike_price(
&self,
slug: &str,
market_id: Option<u64>,
) -> Result<Option<MarketData>> {
let category = self.get_category(slug).await?;
if let Some(market_data) = category.market_data {
if let Some(id) = market_id {
let id_str = id.to_string();
return Ok(market_data.into_iter().find(|m| m.market_id == id_str));
} else {
return Ok(market_data.into_iter().next());
}
}
Ok(None)
}
}
impl Default for PredictGraphQL {
fn default() -> Self {
Self::new()
}
}
/// GraphQL document for the `GetCategory` operation sent by `PredictGraphQL::get_category`.
///
/// Looks up a category by the `$categoryId` variable (declared `ID!`) and selects the common
/// category fields that deserialize into `CategoryData` (camelCase names).
///
/// `marketData` is only requested through the inline fragment on `CryptoUpDownCategory`, so
/// other category types are expected to come back without it, leaving
/// `CategoryData::market_data` as `None`. NOTE(review): confirm against the schema.
///
/// The operation name declared here must stay in sync with the `operationName` ("GetCategory")
/// sent alongside the query.
const GET_CATEGORY_QUERY: &str = r#"
query GetCategory($categoryId: ID!) {
category(id: $categoryId) {
id
slug
title
startsAt
endsAt
status
isNegRisk
isYieldBearing
... on CryptoUpDownCategory {
marketData {
marketId
priceFeedId
startPrice
startPricePublishTime
endPrice
endPricePublishTime
}
}
}
}
"#;
/// JSON body of a GraphQL-over-HTTP POST request.
///
/// Generic over the `variables` payload so other operations can reuse it, mirroring the
/// generic `GraphQLResponse<T>`. The default type parameter keeps the bare `GraphQLRequest`
/// name meaning the `GetCategory` request.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GraphQLRequest<V = GetCategoryVariables> {
    /// GraphQL document text.
    query: String,
    /// Values for the operation's `$`-variables.
    variables: V,
    /// Operation within `query` to execute; serialized as `operationName`.
    operation_name: String,
}
/// Variables for the `GetCategory` query; serializes as `{"categoryId": "..."}`.
#[derive(Serialize, Debug)]
struct GetCategoryVariables {
    /// Bound to the query's `$categoryId: ID!`. `get_category` passes the category slug here.
    #[serde(rename = "categoryId")]
    category_id: String,
}
/// Top-level GraphQL response envelope, generic over the shape of `data`.
///
/// Both members are `Option`s, so a member that is missing or `null` deserializes to `None`.
#[derive(Deserialize, Debug)]
struct GraphQLResponse<T> {
    /// Operation result, typed per query.
    data: Option<T>,
    /// Errors reported by the server, if any.
    errors: Option<Vec<GraphQLError>>,
}
/// One entry of a GraphQL response's `errors` array.
///
/// Only `message` is kept. Other members the server may send are ignored, because serde
/// skips unknown fields unless `deny_unknown_fields` is set.
#[derive(Deserialize, Debug)]
struct GraphQLError {
    /// Human-readable error description from the server.
    message: String,
}
/// `data` payload of the `GetCategory` query.
#[derive(Deserialize, Debug)]
struct GetCategoryResponse {
    /// The requested category; `None` when the server returns `null` or omits it.
    category: Option<CategoryData>,
}
/// A category as returned by the `GetCategory` query.
///
/// Field names map to the query's camelCase fields (e.g. `starts_at` ← `startsAt`).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CategoryData {
    /// Category identifier (`id`).
    pub id: String,
    /// URL slug of the category (`slug`).
    pub slug: String,
    /// Display title (`title`).
    pub title: String,
    /// Start timestamp (`startsAt`), parsed by chrono's serde support.
    pub starts_at: DateTime<Utc>,
    /// End timestamp (`endsAt`), parsed by chrono's serde support.
    pub ends_at: DateTime<Utc>,
    /// Raw status string as returned by the API (`status`).
    pub status: String,
    /// The API's `isNegRisk` flag.
    pub is_neg_risk: bool,
    /// The API's `isYieldBearing` flag.
    pub is_yield_bearing: bool,
    /// Per-market price data (`marketData`).
    ///
    /// Only selected for `CryptoUpDownCategory`, so presumably `None` for other category
    /// types.
    pub market_data: Option<Vec<MarketData>>,
}
/// Price data for one market inside a `CryptoUpDownCategory` (`marketData` entry).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MarketData {
    /// Market identifier, delivered as a string (`marketId`).
    pub market_id: String,
    /// Identifier of the price feed backing this market (`priceFeedId`).
    pub price_feed_id: String,
    /// Starting price (`startPrice`); `strike_price` converts it to `Decimal`.
    pub start_price: f64,
    /// When `start_price` was published (`startPricePublishTime`).
    pub start_price_publish_time: DateTime<Utc>,
    /// Final price (`endPrice`); `None` when absent or `null`. Presumably unset until the
    /// market ends; verify with the API.
    pub end_price: Option<f64>,
    /// When `end_price` was published (`endPricePublishTime`); `None` when absent or `null`.
    pub end_price_publish_time: Option<DateTime<Utc>>,
}
impl MarketData {
    /// Returns `start_price` as a `Decimal`, for use as the market's strike price.
    ///
    /// If the `f64` cannot be represented as a `Decimal` (non-finite or out of range), this
    /// silently yields `Decimal::default()` (zero) instead of an error.
    pub fn strike_price(&self) -> Decimal {
        match Decimal::try_from(self.start_price) {
            Ok(price) => price,
            Err(_) => Decimal::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const CUSTOM_ENDPOINT: &str = "https://custom.endpoint.com/graphql";

    /// Builds a `MarketData` fixture where only `start_price` matters to the test.
    fn market_with_start_price(start_price: f64) -> MarketData {
        MarketData {
            market_id: "1".to_string(),
            price_feed_id: "feed".to_string(),
            start_price,
            start_price_publish_time: Utc::now(),
            end_price: None,
            end_price_publish_time: None,
        }
    }

    #[test]
    fn test_graphql_client_creation() {
        let client = PredictGraphQL::new();
        assert_eq!(client.endpoint, DEFAULT_GRAPHQL_ENDPOINT);
    }

    #[test]
    fn test_graphql_client_custom_endpoint() {
        let client = PredictGraphQL::with_endpoint(CUSTOM_ENDPOINT.to_string());
        assert_eq!(client.endpoint, CUSTOM_ENDPOINT);
    }

    #[test]
    fn test_graphql_client_default_uses_default_endpoint() {
        let client = PredictGraphQL::default();
        assert_eq!(client.endpoint, DEFAULT_GRAPHQL_ENDPOINT);
    }

    #[test]
    fn test_strike_price_converts_start_price() {
        // 100.5 is exactly representable in binary, so the conversion is exact.
        let market = market_with_start_price(100.5);
        assert_eq!(market.strike_price(), Decimal::new(1005, 1));
    }

    #[test]
    fn test_strike_price_falls_back_to_zero_when_unrepresentable() {
        let market = market_with_start_price(f64::NAN);
        assert_eq!(market.strike_price(), Decimal::default());
    }
}