lichess_api/
client.rs

1use async_std::io::prelude::BufReadExt;
2use async_std::stream::StreamExt;
3
4use bytes::Bytes;
5
6use futures::TryStreamExt;
7
8use serde::de::DeserializeOwned;
9use tracing::debug;
10
11use crate::error::{Error, Result};
12
/// Handle for talking to the Lichess HTTP API.
///
/// Generic over the HTTP client type so the request-sending methods can be
/// implemented per concrete client.
#[derive(Clone, Debug)]
pub struct LichessApi<HttpClient> {
    /// The underlying HTTP client used to execute requests.
    pub client: HttpClient,
    /// Pre-formatted `Authorization` header value (`"Bearer <token>"`), if a
    /// token was supplied at construction time.
    bearer_auth: Option<String>,
}
18
19impl<HttpClient> LichessApi<HttpClient> {
20    pub fn new(client: HttpClient, auth_token: Option<String>) -> Self {
21        let bearer_auth = auth_token.map(|token| format!("Bearer {}", token));
22        Self {
23            client,
24            bearer_auth,
25        }
26    }
27
28    pub(crate) async fn expect_one_model<Model, G>(&self, stream: &mut G) -> Result<Model>
29    where
30        G: StreamExt<Item = Result<Model>> + std::marker::Unpin,
31    {
32        stream
33            .next()
34            .await
35            .ok_or(Error::Response("empty response stream".to_string()))?
36    }
37
38    pub(crate) async fn expect_empty<G>(&self, stream: &mut G) -> Result<()>
39    where
40        G: StreamExt<Item = Result<()>> + std::marker::Unpin,
41    {
42        if stream.next().await.is_some() {
43            Err(Error::Response(
44                "expected empty response stream".to_string(),
45            ))
46        } else {
47            Ok(())
48        }
49    }
50}
51
52impl LichessApi<reqwest::Client> {
53    pub(crate) async fn make_request<Model: DeserializeOwned>(
54        &self,
55        http_request: http::Request<Bytes>,
56    ) -> Result<impl StreamExt<Item = Result<Model>>> {
57        let stream =
58            self.make_request_as_raw_lines(http_request)
59                .await?
60                .map(|l| -> Result<Model> {
61                    serde_json::from_str(&l?).map_err(|e| crate::error::Error::Json(e))
62                });
63
64        Ok(stream)
65    }
66
67    pub(crate) async fn make_request_as_raw_lines(
68        &self,
69        mut http_request: http::Request<Bytes>,
70    ) -> Result<impl StreamExt<Item = Result<String>>> {
71        if let Some(auth) = &self.bearer_auth {
72            let mut auth_header = http::HeaderValue::from_str(&auth)
73                .map_err(|e| Error::HttpRequestBuilder(http::Error::from(e)))?;
74            // exclude the auth header from being logged
75            auth_header.set_sensitive(true);
76            http_request
77                .headers_mut()
78                .insert(http::header::AUTHORIZATION, auth_header);
79        };
80
81        let convert_err = |e: reqwest::Error| Error::Request(e.to_string());
82        let request = reqwest::Request::try_from(http_request).map_err(convert_err)?;
83        let body_text = if let Some(body) = request.body() {
84            match body.as_bytes() {
85                Some(bytes) => String::from_utf8_lossy(bytes).to_string(),
86                None => "<streaming body>".to_string(),
87            }
88        } else {
89            "<empty body>".to_string()
90        };
91        debug!(?request, body = %body_text, "sending");
92        let response = self.client.execute(request).await;
93        debug!(?response, "received");
94        let stream = response
95            .map_err(convert_err)?
96            .bytes_stream()
97            .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
98            .into_async_read()
99            .lines()
100            .filter(|l| match l {
101                // To avoid trying to serialize blank keep alive lines.
102                Ok(line) => !line.is_empty(),
103                Err(_) => true,
104            })
105            .map(|l| -> Result<String> {
106                let line = l?;
107                debug!(line, "model line");
108                if line.starts_with("<!DOCTYPE html>") {
109                    return Err(crate::error::Error::PageNotFound());
110                }
111                // Check for error responses returned as json before model serialization is attempted.
112                // This can happen when not authorized to access an endpoint.
113                if let Ok(error_value) = serde_json::from_str::<serde_json::Value>(&line) {
114                    if let Some(error_msg) = error_value.get("error").and_then(|e| e.as_str()) {
115                        return Err(crate::error::Error::Response(error_msg.to_string()));
116                    }
117                }
118                Ok(line)
119            });
120
121        Ok(stream)
122    }
123}