1use crate::errors::{Error, Result};
7use chrono::{DateTime, Utc};
8use reqwest::Client as HttpClient;
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11use tracing::debug;
12
/// GraphQL endpoint used by [`PredictGraphQL::new`] (and therefore by its `Default` impl).
pub const DEFAULT_GRAPHQL_ENDPOINT: &str = "https://graphql.predict.fun/graphql";
15
16pub struct PredictGraphQL {
18 http_client: HttpClient,
19 endpoint: String,
20}
21
22impl PredictGraphQL {
23 pub fn new() -> Self {
25 Self {
26 http_client: HttpClient::new(),
27 endpoint: DEFAULT_GRAPHQL_ENDPOINT.to_string(),
28 }
29 }
30
31 pub fn with_endpoint(endpoint: String) -> Self {
33 Self {
34 http_client: HttpClient::new(),
35 endpoint,
36 }
37 }
38
39 pub async fn get_category(&self, slug: &str) -> Result<CategoryData> {
49 let query = GraphQLRequest {
50 query: GET_CATEGORY_QUERY.to_string(),
51 variables: GetCategoryVariables {
52 category_id: slug.to_string(),
53 },
54 operation_name: "GetCategory".to_string(),
55 };
56
57 debug!("Fetching category via GraphQL: {}", slug);
58
59 let response = self
60 .http_client
61 .post(&self.endpoint)
62 .json(&query)
63 .send()
64 .await?;
65
66 let status = response.status();
67 if !status.is_success() {
68 let error_text = response
69 .text()
70 .await
71 .unwrap_or_else(|_| "Unknown error".to_string());
72 return Err(Error::ApiError(format!(
73 "GraphQL request failed: status={}, error={}",
74 status, error_text
75 )));
76 }
77
78 let gql_response: GraphQLResponse<GetCategoryResponse> = response.json().await?;
79
80 if let Some(errors) = gql_response.errors {
81 let error_messages: Vec<String> = errors.iter().map(|e| e.message.clone()).collect();
82 return Err(Error::ApiError(format!(
83 "GraphQL errors: {}",
84 error_messages.join(", ")
85 )));
86 }
87
88 gql_response
89 .data
90 .and_then(|d| d.category)
91 .ok_or_else(|| Error::ApiError(format!("Category not found: {}", slug)))
92 }
93
94 pub async fn get_market_strike_price(
105 &self,
106 slug: &str,
107 market_id: Option<u64>,
108 ) -> Result<Option<MarketData>> {
109 let category = self.get_category(slug).await?;
110
111 if let Some(market_data) = category.market_data {
112 if let Some(id) = market_id {
113 let id_str = id.to_string();
115 return Ok(market_data.into_iter().find(|m| m.market_id == id_str));
116 } else {
117 return Ok(market_data.into_iter().next());
119 }
120 }
121
122 Ok(None)
123 }
124}
125
126impl Default for PredictGraphQL {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
/// GraphQL document for the `GetCategory` operation.
///
/// It selects the category's core fields. For `CryptoUpDownCategory` categories only (via the
/// inline fragment), it also selects the per-market price data. Keep the selection in sync with
/// [`CategoryData`] and [`MarketData`], which are deserialized from the response.
const GET_CATEGORY_QUERY: &str = r#"
query GetCategory($categoryId: ID!) {
 category(id: $categoryId) {
 id
 slug
 title
 startsAt
 endsAt
 status
 isNegRisk
 isYieldBearing
 ... on CryptoUpDownCategory {
 marketData {
 marketId
 priceFeedId
 startPrice
 startPricePublishTime
 endPrice
 endPricePublishTime
 }
 }
 }
}
"#;
157
158#[derive(Debug, Serialize)]
161#[serde(rename_all = "camelCase")]
162struct GraphQLRequest {
163 query: String,
164 variables: GetCategoryVariables,
165 operation_name: String,
166}
167
168#[derive(Debug, Serialize)]
169#[serde(rename_all = "camelCase")]
170struct GetCategoryVariables {
171 category_id: String,
172}
173
174#[derive(Debug, Deserialize)]
175struct GraphQLResponse<T> {
176 data: Option<T>,
177 errors: Option<Vec<GraphQLError>>,
178}
179
180#[derive(Debug, Deserialize)]
181struct GraphQLError {
182 message: String,
183}
184
185#[derive(Debug, Deserialize)]
186struct GetCategoryResponse {
187 category: Option<CategoryData>,
188}
189
190#[derive(Debug, Clone, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct CategoryData {
194 pub id: String,
196 pub slug: String,
198 pub title: String,
200 pub starts_at: DateTime<Utc>,
202 pub ends_at: DateTime<Utc>,
204 pub status: String,
206 pub is_neg_risk: bool,
208 pub is_yield_bearing: bool,
210 pub market_data: Option<Vec<MarketData>>,
212}
213
214#[derive(Debug, Clone, Deserialize)]
216#[serde(rename_all = "camelCase")]
217pub struct MarketData {
218 pub market_id: String,
220 pub price_feed_id: String,
222 pub start_price: f64,
224 pub start_price_publish_time: DateTime<Utc>,
226 pub end_price: Option<f64>,
228 pub end_price_publish_time: Option<DateTime<Utc>>,
230}
231
232impl MarketData {
233 pub fn strike_price(&self) -> Decimal {
235 Decimal::try_from(self.start_price).unwrap_or_default()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MarketData` whose only interesting field is `start_price`.
    fn market_with_start_price(start_price: f64) -> MarketData {
        MarketData {
            market_id: "1".to_string(),
            price_feed_id: "feed".to_string(),
            start_price,
            start_price_publish_time: Utc::now(),
            end_price: None,
            end_price_publish_time: None,
        }
    }

    #[test]
    fn test_graphql_client_creation() {
        let client = PredictGraphQL::new();
        assert_eq!(client.endpoint, DEFAULT_GRAPHQL_ENDPOINT);
    }

    #[test]
    fn test_graphql_client_custom_endpoint() {
        let client = PredictGraphQL::with_endpoint("https://custom.endpoint.com/graphql".to_string());
        assert_eq!(client.endpoint, "https://custom.endpoint.com/graphql");
    }

    #[test]
    fn test_default_uses_default_endpoint() {
        let client = PredictGraphQL::default();
        assert_eq!(client.endpoint, DEFAULT_GRAPHQL_ENDPOINT);
    }

    #[test]
    fn test_strike_price_converts_start_price() {
        let market = market_with_start_price(100.5);
        assert_eq!(market.strike_price(), Decimal::new(1005, 1));
    }

    #[test]
    fn test_strike_price_falls_back_to_zero_for_nan() {
        let market = market_with_start_price(f64::NAN);
        assert_eq!(market.strike_price(), Decimal::default());
    }
}