1use crate::token::RefreshedToken;
2use crate::{Error, Token};
3use reqwest::{header, Response};
4use serde::Serialize;
5use serde_json::Value;
6use std::collections::HashMap;
7
8#[derive(Clone, Debug)]
9pub struct Client {
11 http_client: reqwest::Client,
12 client_id: String,
13 client_secret: String,
14 pub token: Token,
16 pub base_url: String,
22}
23impl Client {
24 pub async fn send_request<F, T>(&mut self, request_fn: F) -> Result<Response, Error>
29 where
30 F: Fn(&Token) -> T,
31 T: std::future::Future<Output = Result<Response, reqwest::Error>>,
32 {
33 match request_fn(&self.token).await {
34 Ok(resp) => {
35 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
38 self.refresh_token().await?;
40
41 let retry_response = request_fn(&self.token).await?;
43 Ok(retry_response)
44 } else {
45 Ok(resp)
46 }
47 }
48 Err(e) => Err(Error::Request(e)),
49 }
50 }
51
52 pub async fn post<T: Serialize>(
54 &mut self,
55 endpoint: &str,
56 payload: &T,
57 ) -> Result<Response, Error> {
58 let url = format!("https://{}/{}", self.base_url, endpoint);
59 let resp = self
60 .clone()
61 .send_request(|token| {
62 self.http_client
63 .post(&url)
64 .header("Content-Type", "application/json")
65 .header(
66 header::AUTHORIZATION,
67 format!("Bearer {}", token.access_token),
68 )
69 .json(payload)
70 .send()
71 })
72 .await?;
73
74 Ok(resp)
75 }
76
77 pub async fn get(&mut self, endpoint: &str) -> Result<Response, Error> {
79 let url = format!("https://{}/{}", self.base_url, endpoint);
80 let resp = self
81 .clone()
82 .send_request(|token| {
83 self.http_client
84 .get(&url)
85 .header(
86 header::AUTHORIZATION,
87 format!("Bearer {}", token.access_token),
88 )
89 .send()
90 })
91 .await?;
92
93 Ok(resp)
94 }
95
96 pub async fn put<T: Serialize>(
98 &mut self,
99 endpoint: &str,
100 payload: &T,
101 ) -> Result<Response, Error> {
102 let url = format!("https://{}/{}", self.base_url, endpoint);
103 let resp = self
104 .clone()
105 .send_request(|token| {
106 self.http_client
107 .put(&url)
108 .header("Content-Type", "application/json")
109 .header(
110 header::AUTHORIZATION,
111 format!("Bearer {}", token.access_token),
112 )
113 .json(&payload)
114 .send()
115 })
116 .await?;
117
118 Ok(resp)
119 }
120
121 pub async fn delete(&mut self, endpoint: &str) -> Result<Response, Error> {
123 let url = format!("https://{}/{}", self.base_url, endpoint);
124 let resp = self
125 .clone()
126 .send_request(|token| {
127 self.http_client
128 .delete(&url)
129 .header("Content-Type", "application/json")
130 .header(
131 header::AUTHORIZATION,
132 format!("Bearer {}", token.access_token),
133 )
134 .send()
135 })
136 .await?;
137
138 Ok(resp)
139 }
140
141 pub async fn stream<F>(&mut self, endpoint: &str, mut process_chunk: F) -> Result<(), Error>
145 where
146 F: FnMut(Value) -> Result<(), Error>,
147 {
148 let url = format!("https://{}/{}", self.base_url, endpoint);
149
150 let mut resp = self
151 .clone()
152 .send_request(|token| {
153 self.http_client
154 .get(&url)
155 .header(
156 reqwest::header::AUTHORIZATION,
157 format!("Bearer {}", token.access_token),
158 )
159 .send()
160 })
161 .await?;
162
163 if !resp.status().is_success() {
164 return Err(Error::StreamIssue(format!(
165 "Request failed with status: {}",
166 resp.status()
167 )));
168 }
169
170 let mut buffer = String::new();
171 while let Some(chunk) = resp.chunk().await? {
172 let chunk_str = std::str::from_utf8(&chunk).unwrap_or("");
173 buffer.push_str(chunk_str);
174
175 while let Some(pos) = buffer.find("\n") {
176 let json_str = buffer[..pos].trim().to_string();
177 buffer = buffer[pos + 1..].to_string();
178 if json_str.is_empty() {
179 continue;
180 }
181
182 match serde_json::from_str::<Value>(&json_str) {
183 Ok(json_value) => {
184 if let Err(e) = process_chunk(json_value) {
185 if matches!(e, Error::StopStream) {
186 return Ok(());
187 } else {
188 return Err(e);
189 }
190 }
191 }
192 Err(e) => {
193 return Err(Error::Json(e));
194 }
195 }
196 }
197 }
198
199 if !buffer.trim().is_empty() {
201 match serde_json::from_str::<Value>(&buffer) {
202 Ok(json_value) => {
203 if let Err(e) = process_chunk(json_value) {
204 if matches!(e, Error::StopStream) {
205 return Ok(());
206 } else {
207 return Err(e);
208 }
209 }
210 }
211 Err(e) => {
212 return Err(Error::Json(e));
213 }
214 }
215 }
216
217 Ok(())
218 }
219
220 pub async fn refresh_token(&mut self) -> Result<(), Error> {
223 let form_data: HashMap<String, String> = HashMap::from([
224 ("grant_type".into(), "refresh_token".into()),
225 ("client_id".into(), self.client_id.clone()),
226 ("client_secret".into(), self.client_secret.clone()),
227 ("refresh_token".into(), self.token.refresh_token.clone()),
228 ("redirect_uri".into(), "http://localhost:8080/".into()),
229 ]);
230
231 let new_token = self
232 .http_client
233 .post("https://signin.tradestation.com/oauth/token")
234 .header("Content-Type", "application/x-www-form-urlencoded")
235 .form(&form_data)
236 .send()
237 .await?
238 .json::<RefreshedToken>()
239 .await?;
240
241 self.token = Token {
243 refresh_token: self.token.refresh_token.clone(),
244 access_token: new_token.access_token,
245 id_token: new_token.id_token,
246 scope: new_token.scope,
247 token_type: new_token.token_type,
248 expires_in: new_token.expires_in,
249 };
250
251 Ok(())
252 }
253}
254
/// Entry point of the step-by-step builder for [`Client`].
///
/// `ClientBuilder::new()` returns the first builder step.
#[derive(Default, Debug)]
pub struct ClientBuilder;
258
/// Builder state: HTTP client created, credentials not yet supplied.
#[derive(Default, Debug)]
pub struct Step1;

/// Builder state: credentials supplied, token not yet obtained.
#[derive(Default, Debug)]
pub struct Step2;

/// Builder state: ready to build a [`Client`], either with a token or with a
/// testing URL.
#[derive(Default, Debug)]
pub struct Step3;
268
269#[derive(Debug, Default)]
270pub struct ClientBuilderStep<CurrentStep> {
273 _current_step: CurrentStep,
274 http_client: Option<reqwest::Client>,
275 client_id: Option<String>,
276 client_secret: Option<String>,
277 token: Option<Token>,
278 testing_url: Option<String>,
279}
280
281impl ClientBuilder {
282 #[allow(clippy::new_ret_no_self)]
283 pub fn new() -> Result<ClientBuilderStep<Step1>, Error> {
285 Ok(ClientBuilderStep {
286 _current_step: Step1,
287 http_client: Some(reqwest::Client::new()),
288 ..Default::default()
289 })
290 }
291}
292impl ClientBuilderStep<Step1> {
293 pub fn credentials(
295 self,
296 client_id: &str,
297 client_secret: &str,
298 ) -> Result<ClientBuilderStep<Step2>, Error> {
299 Ok(ClientBuilderStep {
300 _current_step: Step2,
301 http_client: Some(self.http_client.unwrap()),
302 client_id: Some(client_id.into()),
303 client_secret: Some(client_secret.into()),
304 ..Default::default()
305 })
306 }
307
308 pub fn testing_url(self, url: &str) -> ClientBuilderStep<Step3> {
316 ClientBuilderStep {
317 _current_step: Step3,
318 http_client: self.http_client,
319 client_id: self.client_id,
320 client_secret: self.client_secret,
321 token: self.token,
322 testing_url: Some(url.into()),
323 }
324 }
325}
326impl ClientBuilderStep<Step2> {
327 pub async fn authorize(
329 self,
330 authorization_code: &str,
331 ) -> Result<ClientBuilderStep<Step3>, Error> {
332 let http_client = self.http_client.unwrap();
335 let client_id = self.client_id.as_ref().unwrap();
336 let client_secret = self.client_secret.as_ref().unwrap();
337
338 let form_data = HashMap::from([
340 ("grant_type", "authorization_code"),
341 ("client_id", client_id),
342 ("client_secret", client_secret),
343 ("code", authorization_code),
344 ("redirect_uri", "http://localhost:8080/"),
345 ]);
346 let token = http_client
347 .post("https://signin.tradestation.com/oauth/token")
348 .header("Content-Type", "application/x-www-form-urlencoded")
349 .form(&form_data)
350 .send()
351 .await?
352 .json::<Token>()
353 .await?;
354
355 Ok(ClientBuilderStep {
356 _current_step: Step3,
357 http_client: Some(http_client),
358 client_id: self.client_id,
359 client_secret: self.client_secret,
360 token: Some(token),
361 testing_url: self.testing_url,
362 })
363 }
364
365 pub fn token(self, token: Token) -> Result<ClientBuilderStep<Step3>, Error> {
367 Ok(ClientBuilderStep {
368 _current_step: Step3,
369 http_client: self.http_client,
370 client_id: self.client_id,
371 client_secret: self.client_secret,
372 token: Some(token),
373 testing_url: self.testing_url,
374 })
375 }
376}
377impl ClientBuilderStep<Step3> {
378 pub async fn build(self) -> Result<Client, Error> {
380 let http_client = self.http_client.unwrap();
381
382 if self.testing_url.is_none() {
383 let client_id = self.client_id.unwrap();
384 let client_secret = self.client_secret.unwrap();
385 let token = self.token.unwrap();
386 let base_url = "api.tradestation.com/v3".to_string();
387
388 Ok(Client {
389 http_client,
390 client_id,
391 client_secret,
392 token,
393 base_url,
394 })
395 } else {
396 let client_id = "NO_CLIENT_ID_IN_TEST_MODE".to_string();
397 let client_secret = "NO_CLIENT_SECRET_IN_TEST_MODE".to_string();
398 let token = Token {
399 access_token: String::from("NO_ACCESS_TOKEN_IN_TEST_MODE"),
400 refresh_token: String::from("NO_REFRESH_TOKEN_IN_TEST_MODE"),
401 id_token: String::from("NO_ID_TOKEN_IN_TEST_MODE"),
402 token_type: String::from("TESTING"),
403 scope: String::from("NO SCOPES IN TEST MODE"),
404 expires_in: 9999,
405 };
406 let base_url = self
407 .testing_url
408 .expect("Some `Client::testing_url` to be set due to invariant check.");
409
410 Ok(Client {
411 http_client,
412 client_id,
413 client_secret,
414 token,
415 base_url,
416 })
417 }
418 }
419}