// etrade/session.rs

1use crate::{Credentials, Mode, Store};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4
5use bytes::Buf;
6use chrono::{NaiveDate, Utc};
7use http::{
8  header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
9  Method, Request, Response,
10};
11
12use serde::de::DeserializeOwned;
13use serde::ser::Serialize;
14use tokio::io::{self, *};
15
16use hyper::{
17  client::{connect::dns::GaiResolver, HttpConnector},
18  Client,
19};
20use hyper_tls::HttpsConnector;
21
22use secstr::SecUtf8;
23
24use hyper::service::Service;
25use std::{collections::BTreeMap, fmt::Debug, iter::FromIterator};
26
27use super::{LIVE_URL, SANDBOX_URL};
28
// Store namespaces: one per trading mode, so sandbox and live secrets are kept apart.
const LIVE_NAMESPACE: &str = "etrade";
const SANDBOX_NAMESPACE: &str = "etradesandbox";

// Store keys for the consumer (application) key and secret.
const API_KEY: &str = "apikey";
const SECRET_KEY: &str = "secret";

// Store keys for the cached OAuth access token pair.
const ACCESS_TOKEN_KEY: &str = "access_token_key";
const ACCESS_TOKEN_SECRET: &str = "access_token_secret";

// Store keys for the cached OAuth request token pair and the date (%Y-%m-%d) it was created.
const REQUEST_TOKEN_KEY: &str = "request_token_key";
const REQUEST_TOKEN_SECRET: &str = "request_token_secret";
const REQUEST_TOKEN_CREATED: &str = "request_token_ts";

// OAuth handshake endpoints.
const REQUEST_TOKEN_URL: &str = "https://api.etrade.com/oauth/request_token";
const ACCESS_TOKEN_URL: &str = "https://api.etrade.com/oauth/access_token";
const RENEW_ACCESS_TOKEN_URL: &str = "https://api.etrade.com/oauth/renew_access_token";
43
44type HttpClient = Client<HttpsConnector<HttpConnector<GaiResolver>>, hyper::Body>;
45
46#[async_trait]
47pub trait CallbackProvider: Clone {
48  async fn verifier_code(&self, url: &str) -> Result<String>;
49}
50
/// Out-of-band callback provider: asks for the verifier code on the terminal.
#[derive(Clone, Copy, Debug)]
pub struct OOB;
53
54#[async_trait]
55impl CallbackProvider for OOB {
56  async fn verifier_code(&self, url: &str) -> Result<String> {
57    let msg = format!("please visit and accept the license: {}\ninput pin: \n", url,);
58    io::stderr().write_all(msg.as_bytes()).await?;
59
60    let stdin = io::stdin();
61    let mut user_input = String::new();
62    io::BufReader::new(stdin).read_line(&mut user_input).await?;
63
64    let result = Ok(user_input.trim().to_owned());
65    debug!("got verificaton code: {}", result.as_ref().unwrap());
66    result
67  }
68}
69
/// OAuth handshake endpoints used by a session.
#[derive(Clone, Copy, Debug)]
struct UrlConfig<'a> {
  /// Exchanges an authorized request token plus verifier for an access token.
  pub access_token_url: &'a str,
  /// Token renewal endpoint.
  pub renew_access_token_url: &'a str,
  /// Issues a new request token.
  pub request_token_url: &'a str,
}
76
77impl<'a> UrlConfig<'a> {
78  pub fn authorize_url(&self, key: &SecUtf8, token: &SecUtf8) -> String {
79    format!(
80      "https://us.etrade.com/e/t/etws/authorize?key={}&token={}",
81      key.unsecure(),
82      token.unsecure(),
83    )
84  }
85}
86
87impl<'a> Default for UrlConfig<'a> {
88  fn default() -> Self {
89    Self {
90      access_token_url: ACCESS_TOKEN_URL,
91      renew_access_token_url: RENEW_ACCESS_TOKEN_URL,
92      request_token_url: REQUEST_TOKEN_URL,
93    }
94  }
95}
96
97pub struct Session<T: Store> {
98  store: T,
99  mode: Mode,
100  client: HttpClient,
101  urls: UrlConfig<'static>,
102}
103
104impl<T> Session<T>
105where
106  T: Store,
107{
108  pub fn new(mode: Mode, store: T) -> Self {
109    let https = HttpsConnector::new();
110
111    Self {
112      store,
113      mode,
114      client: Client::builder().build(https),
115      urls: UrlConfig::default(),
116    }
117  }
118
119  fn base_url(&self) -> &str {
120    match self.mode {
121      Mode::Sandbox => SANDBOX_URL,
122      Mode::Live => LIVE_URL,
123    }
124  }
125
126  fn namespace(&self) -> &str {
127    match self.mode {
128      Mode::Sandbox => SANDBOX_NAMESPACE,
129      Mode::Live => LIVE_NAMESPACE,
130    }
131  }
132
133  pub async fn initialize(&self, key: String, secret: String) -> Result<()> {
134    self.store.put(self.namespace(), API_KEY, key).await?;
135    self.store.put(self.namespace(), SECRET_KEY, secret).await?;
136    Ok(())
137  }
138
139  async fn consumer(&self) -> Result<Credentials> {
140    let consumer_key = self
141      .store
142      .get(self.namespace(), API_KEY)
143      .await
144      .and_then(|r| r.ok_or_else(|| anyhow!("secret {}@{} not found.", API_KEY, self.namespace())))?;
145    let consumer_secret = self
146      .store
147      .get(self.namespace(), SECRET_KEY)
148      .await
149      .and_then(|r| r.ok_or_else(|| anyhow!("secret {}@{} not found.", SECRET_KEY, self.namespace())))?;
150
151    Ok(Credentials::new(consumer_key, consumer_secret))
152  }
153
154  pub async fn invalidate(&self) -> Result<()> {
155    debug!("invalidating credentials");
156    self.store.del(self.namespace(), ACCESS_TOKEN_KEY).await?;
157    self.store.del(self.namespace(), ACCESS_TOKEN_SECRET).await?;
158
159    self.store.del(self.namespace(), REQUEST_TOKEN_SECRET).await?;
160    self.store.del(self.namespace(), REQUEST_TOKEN_KEY).await?;
161    self.store.del(self.namespace(), REQUEST_TOKEN_CREATED).await
162  }
163
164  async fn request_token(&self, consumer: &Credentials) -> Result<Credentials> {
165    debug!("getting a request token");
166    let request_token = self.store.get(self.namespace(), REQUEST_TOKEN_KEY).await?;
167    let request_secret = self.store.get(self.namespace(), REQUEST_TOKEN_SECRET).await?;
168
169    let request_token_ts = self
170      .store
171      .get(self.namespace(), REQUEST_TOKEN_CREATED)
172      .await?
173      .and_then(|v| {
174        let b = NaiveDate::parse_from_str(v.unsecure(), "%Y-%m-%d").unwrap();
175
176        let d = Utc::now().with_timezone(&chrono_tz::US::Eastern).naive_local().date();
177        if b.eq(&d) {
178          Some(d)
179        } else {
180          None
181        }
182      });
183    match (request_token_ts, request_token, request_secret) {
184      (Some(_), Some(rt), Some(rs)) => {
185        debug!("using cached request token");
186        Ok(Credentials::new(rt, rs))
187      }
188      _ => {
189        debug!("getting a new request token");
190        let uri = http::Uri::from_static(self.urls.request_token_url);
191        let authorization = oauth::Builder::<_, _>::new(consumer.clone().into(), oauth::HMAC_SHA1)
192          .callback("oob")
193          .get(&uri, &());
194
195        let body = send_request(uri, authorization, &self.client).await;
196        let creds: oauth_credentials::Credentials<Box<str>> = serde_urlencoded::from_bytes(&body)?;
197
198        debug!("created request token: {:?}", &creds);
199        let request_token: Credentials = creds.into();
200        self
201          .store
202          .put(self.namespace(), REQUEST_TOKEN_KEY, request_token.key.unsecure())
203          .await?;
204        self
205          .store
206          .put(self.namespace(), REQUEST_TOKEN_SECRET, request_token.secret.unsecure())
207          .await?;
208
209        let today = Utc::now()
210          .with_timezone(&chrono_tz::US::Eastern)
211          .date_naive()
212          .format("%Y-%m-%d")
213          .to_string();
214        self.store.put(self.namespace(), REQUEST_TOKEN_CREATED, &today).await?;
215        Ok(request_token)
216      }
217    }
218  }
219
220  async fn access_token(&self, callback: impl CallbackProvider) -> Result<Credentials> {
221    let consumer = self.consumer().await?;
222
223    let access_token = self.store.get(self.namespace(), ACCESS_TOKEN_KEY).await?;
224    let access_secret = self.store.get(self.namespace(), ACCESS_TOKEN_SECRET).await?;
225
226    match (access_token, access_secret) {
227      (Some(token), Some(secret)) => {
228        debug!("using cached access token");
229        Ok(Credentials::new(token, secret))
230      }
231      _ => {
232        let request_token = self.request_token(&consumer).await;
233        if request_token.is_err() {
234          debug!("restarting full flow because request token has an error");
235          return self.full_access_token_flow(consumer, callback).await;
236        }
237
238        match self.renew_access_token(&consumer, &request_token.unwrap()).await {
239          Ok(access_token) => {
240            debug!("using renewed access token");
241            Ok(access_token)
242          }
243          Err(_) => self.full_access_token_flow(consumer, callback).await,
244        }
245      }
246    }
247  }
248
249  async fn full_access_token_flow(
250    &self,
251    consumer: Credentials,
252    callback: impl CallbackProvider,
253  ) -> Result<Credentials> {
254    self.invalidate().await?;
255
256    let request_token = self.request_token(&consumer).await?;
257    let auth_url = self.urls.authorize_url(&consumer.key, &request_token.key);
258    let pin = callback.verifier_code(&auth_url).await?;
259
260    let access_token = self.create_access_token(&consumer, &request_token, &pin).await?;
261
262    Ok(access_token)
263  }
264
265  async fn create_access_token(
266    &self,
267    consumer: &Credentials,
268    request_token: &Credentials,
269    pin: impl AsRef<str>,
270  ) -> Result<Credentials> {
271    debug!("getting an access token");
272    let uri = http::Uri::from_static(self.urls.access_token_url);
273    let authorization = oauth::Builder::<_, _>::new(consumer.clone().into(), oauth::HMAC_SHA1)
274      .token(Some(request_token.clone().into()))
275      .verifier(pin.as_ref())
276      .get(&uri, &());
277    let body = send_request(uri, authorization, &self.client).await;
278    let creds: oauth_credentials::Credentials<Box<str>> = serde_urlencoded::from_bytes(&body)?;
279
280    debug!("created access token: {:?}", &creds);
281    let access_token: Credentials = creds.into();
282    self
283      .store
284      .put(self.namespace(), ACCESS_TOKEN_KEY, access_token.key.unsecure())
285      .await?;
286    self
287      .store
288      .put(self.namespace(), ACCESS_TOKEN_SECRET, access_token.secret.unsecure())
289      .await?;
290    Ok(access_token)
291  }
292
293  async fn renew_access_token(&self, consumer: &Credentials, request_token: &Credentials) -> Result<Credentials> {
294    debug!("renewing an access token");
295    let uri = http::Uri::from_static(self.urls.renew_access_token_url);
296    let authorization = oauth::Builder::<_, _>::new(consumer.clone().into(), oauth::HMAC_SHA1)
297      .token(Some(request_token.clone().into()))
298      .get(&uri, &());
299
300    let body = send_request(uri, authorization, &self.client).await;
301    let creds: oauth_credentials::Credentials<Box<str>> = serde_urlencoded::from_bytes(&body)?;
302    debug!("renewed access token: {:?}", &creds);
303    let access_token: Credentials = creds.into();
304    self
305      .store
306      .put(self.namespace(), ACCESS_TOKEN_KEY, access_token.key.unsecure())
307      .await?;
308    self
309      .store
310      .put(self.namespace(), ACCESS_TOKEN_SECRET, access_token.secret.unsecure())
311      .await?;
312    Ok(access_token)
313  }
314
315  async fn do_send<P, B, C>(
316    &self,
317    method: http::Method,
318    path: P,
319    input: Option<B>,
320    callback: C,
321  ) -> Result<Response<hyper::Body>>
322  where
323    P: AsRef<str> + Send + Sync,
324    B: Serialize + Clone + Send + Sync,
325    C: CallbackProvider + Clone,
326  {
327    let consumer = self.consumer().await?;
328    let access_token = self.access_token(callback.clone()).await?;
329
330    let uri = format!("{}{}", self.base_url(), path.as_ref());
331
332    let (bare_uri, full_uri, params): (&str, String, BTreeMap<String, String>) = match method {
333      Method::GET => {
334        let qs = serde_urlencoded::to_string(&input)?;
335        if qs.is_empty() {
336          (&uri, uri.clone(), BTreeMap::default())
337        } else {
338          let qss: Vec<(String, String)> = serde_urlencoded::from_str(qs.as_ref())?;
339          (
340            &uri,
341            format!("{}?{}", uri, serde_urlencoded::to_string(&input)?).parse()?,
342            BTreeMap::from_iter(qss),
343          )
344        }
345      }
346      _ => (&uri, uri.clone(), BTreeMap::default()),
347    };
348
349    let oreq = oauth::request::AssertSorted::new(&params);
350
351    let authorization = oauth::Builder::new(consumer.into(), oauth::HMAC_SHA1)
352      .token(Some(access_token.into()))
353      .authorize(method.as_str(), bare_uri, &oreq);
354
355    let body: hyper::Body = match input.clone() {
356      Some(v) => match method {
357        Method::GET => hyper::Body::empty(),
358        _ => serde_json::to_vec(&v)?.into(),
359      },
360      _ => hyper::Body::empty(),
361    };
362
363    let req = Request::builder()
364      .method(method.clone())
365      .header(ACCEPT, "application/json")
366      .header(AUTHORIZATION, authorization)
367      .uri(full_uri)
368      .body(body)
369      .unwrap();
370
371    // let req = builder
372    debug!("{:?}", req);
373    let resp = self.client.request(req).await?;
374    debug!("{:?}", resp);
375    Ok(resp)
376  }
377
378  pub async fn send<P, B, R, C>(&self, method: http::Method, path: P, input: Option<B>, callback: C) -> Result<R>
379  where
380    P: AsRef<str> + Send + Sync,
381    B: Serialize + Clone + Send + Sync,
382    R: DeserializeOwned + Send + Sync,
383    C: CallbackProvider + Clone,
384  {
385    let mut resp = self
386      .do_send(method.clone(), path.as_ref(), input.clone(), callback.clone())
387      .await?;
388
389    if resp.status().as_u16() == 401 {
390      debug!("auth error, retrying with invalidated session");
391      self.invalidate().await?;
392      resp = self.do_send(method, path, input, callback).await?;
393    }
394
395    debug!("reading status code");
396    let status_code = resp.status().as_u16();
397    debug!("reading content type code");
398    let content_type = resp
399      .headers()
400      .get(CONTENT_TYPE)
401      .map(|ct| ct.to_str().unwrap_or("application/json"))
402      .unwrap_or("application/json")
403      .to_string();
404    debug!("aggregating body");
405
406    let body = hyper::body::aggregate(resp).await?;
407    if status_code / 100 != 2 {
408      debug!("non 200 status code, reading error");
409      let edata: ErrorData = quick_xml::de::from_reader(body.reader())?;
410      return Err(anyhow!("{} (code: {})", edata.message, edata.code));
411    }
412    debug!("got a successful response");
413    match content_type.as_str() {
414      "application/xml" => Ok(quick_xml::de::from_reader(body.reader())?),
415      "application/json" => Ok(serde_json::from_reader(body.reader())?),
416      v => return Err(anyhow!("api responded with unknown content type {}", v)),
417    }
418  }
419}
420
421#[derive(Debug, Deserialize, PartialEq)]
422pub struct ErrorData {
423  pub code: isize,
424  pub message: String,
425}
426
427async fn send_request<S, B>(uri: http::Uri, authorization: String, mut http: S) -> Vec<u8>
428where
429  S: Service<http::Request<B>, Response = http::Response<B>>,
430  S::Error: Debug,
431  B: http_body::Body<Error = S::Error> + Default + From<Vec<u8>> + Debug,
432{
433  let req = http::Request::get(uri)
434    .header(AUTHORIZATION, authorization)
435    .body(B::default())
436    .unwrap();
437
438  debug!("{:?}", req);
439  let resp = http.call(req).await.unwrap();
440  debug!("{:?}", resp);
441  if resp.status().as_u16() / 100 == 2 {
442    hyper::body::to_bytes(resp.into_body()).await.unwrap().to_vec()
443  } else {
444    vec![]
445  }
446}
447
#[cfg(test)]
mod tests {
  use std::net::TcpListener;

  use hyper::Client;

  #[test]
  fn encodes_query_string() {
    crate::tests::init();
    let data = Some(&[("blah", "some things go here"), ("other", "and others go here")]);
    let v = serde_urlencoded::to_string(data).unwrap();
    info!("query string: '{}'", v);
    // Spaces are form-encoded as `+` and pairs keep their input order.
    assert_eq!(v, "blah=some+things+go+here&other=and+others+go+here");

    // A missing input must serialize to an empty query string (`do_send` checks `is_empty`).
    let n = serde_urlencoded::to_string(None as Option<&[u8]>).unwrap();
    info!("query string: '{}'", n);
    assert_eq!(n, "");
  }

  #[test]
  fn encodes_json_string() {
    crate::tests::init();
    let data = Some(&[("blah", "some things go here"), ("other", "and others go here")]);

    let data2: Option<&[(&str, &str)]> = None;

    // Tuples serialize as two-element JSON arrays.
    assert_eq!(
      data.map(|v| serde_json::to_string(&v).unwrap()).as_deref(),
      Some(r#"[["blah","some things go here"],["other","and others go here"]]"#)
    );
    assert_eq!(data2.map(|v| serde_json::to_string(&v).unwrap()), None);

    info!(
      "query string: '{:?}'",
      data
        .map(|v| hyper::body::Body::from(serde_json::to_string(&v).unwrap()))
        .unwrap_or_else(hyper::Body::empty)
    );
    info!(
      "query string: '{:?}'",
      data2
        .map(|v| hyper::body::Body::from(serde_json::to_string(&v).unwrap()))
        .unwrap_or_else(hyper::Body::empty)
    );
  }

  #[tokio::test]
  async fn it_works() {
    crate::tests::init();
    info!("inside the working test");
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let client = Client::new();
    let base_url = format!("http://127.0.0.1:{}", listener.local_addr().unwrap().port());

    let _th = tokio::task::spawn(async move { server::test_server(listener).await });
    let uri: http::Uri = base_url.parse().unwrap();
    let resp = client.get(uri).await.unwrap();
    let body = String::from_utf8(hyper::body::to_bytes(resp.into_body()).await.unwrap().to_vec()).unwrap();

    assert_eq!("Hello, world!", &body);
  }

  mod server {
    use anyhow::{anyhow, Result};
    use http::Response;
    use hyper::service::{make_service_fn, service_fn};
    use hyper::Body;
    use hyper::Server;
    use std::{convert::Infallible, net::TcpListener};

    /// Serves "Hello, world!" to every request on `listener` until the task is dropped.
    pub async fn test_server(listener: TcpListener) -> Result<()> {
      let server = Server::from_tcp(listener)?;
      let service = service_fn(|req| async move {
        info!("{:?}", req);
        Ok::<_, Infallible>(Response::new(Body::from("Hello, world!")))
      });
      let make_service = make_service_fn(|_| async move { Ok::<_, Infallible>(service) });
      server
        .tcp_nodelay(true)
        .tcp_keepalive(None)
        .serve(make_service)
        .await
        .map_err(|e| anyhow!("{}", e))
    }
  }
}