1use async_std::io::prelude::BufReadExt;
2use async_std::stream::StreamExt;
3
4use bytes::Bytes;
5
6use futures::TryStreamExt;
7
8use serde::de::DeserializeOwned;
9use tracing::debug;
10
11use crate::error::{Error, Result};
12
/// Handle for talking to the Lichess HTTP API through an `HttpClient`.
///
/// `Debug` is implemented by hand instead of being derived so that the stored
/// authorization value is never printed in logs or panic messages.
#[derive(Clone)]
pub struct LichessApi<HttpClient> {
    /// HTTP client used to execute requests.
    pub client: HttpClient,
    /// Pre-formatted `Bearer <token>` authorization value, if a token was supplied.
    bearer_auth: Option<String>,
}

impl<HttpClient: std::fmt::Debug> std::fmt::Debug for LichessApi<HttpClient> {
    /// Formats the handle like a derived `Debug` would, except that the
    /// authorization value is replaced by a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether credentials are present, never their content.
        let redacted_auth = self.bearer_auth.as_ref().map(|_| "<redacted>");
        f.debug_struct("LichessApi")
            .field("client", &self.client)
            .field("bearer_auth", &redacted_auth)
            .finish()
    }
}
18
19impl<HttpClient> LichessApi<HttpClient> {
20 pub fn new(client: HttpClient, auth_token: Option<String>) -> Self {
21 let bearer_auth = auth_token.map(|token| format!("Bearer {}", token));
22 Self {
23 client,
24 bearer_auth,
25 }
26 }
27
28 pub(crate) async fn expect_one_model<Model, G>(&self, stream: &mut G) -> Result<Model>
29 where
30 G: StreamExt<Item = Result<Model>> + std::marker::Unpin,
31 {
32 stream
33 .next()
34 .await
35 .ok_or(Error::Response("empty response stream".to_string()))?
36 }
37
38 pub(crate) async fn expect_empty<G>(&self, stream: &mut G) -> Result<()>
39 where
40 G: StreamExt<Item = Result<()>> + std::marker::Unpin,
41 {
42 if stream.next().await.is_some() {
43 Err(Error::Response(
44 "expected empty response stream".to_string(),
45 ))
46 } else {
47 Ok(())
48 }
49 }
50}
51
52impl LichessApi<reqwest::Client> {
53 pub(crate) async fn make_request<Model: DeserializeOwned>(
54 &self,
55 http_request: http::Request<Bytes>,
56 ) -> Result<impl StreamExt<Item = Result<Model>>> {
57 let stream =
58 self.make_request_as_raw_lines(http_request)
59 .await?
60 .map(|l| -> Result<Model> {
61 serde_json::from_str(&l?).map_err(|e| crate::error::Error::Json(e))
62 });
63
64 Ok(stream)
65 }
66
67 pub(crate) async fn make_request_as_raw_lines(
68 &self,
69 mut http_request: http::Request<Bytes>,
70 ) -> Result<impl StreamExt<Item = Result<String>>> {
71 if let Some(auth) = &self.bearer_auth {
72 let mut auth_header = http::HeaderValue::from_str(&auth)
73 .map_err(|e| Error::HttpRequestBuilder(http::Error::from(e)))?;
74 auth_header.set_sensitive(true);
76 http_request
77 .headers_mut()
78 .insert(http::header::AUTHORIZATION, auth_header);
79 };
80
81 let convert_err = |e: reqwest::Error| Error::Request(e.to_string());
82 let request = reqwest::Request::try_from(http_request).map_err(convert_err)?;
83 let body_text = if let Some(body) = request.body() {
84 match body.as_bytes() {
85 Some(bytes) => String::from_utf8_lossy(bytes).to_string(),
86 None => "<streaming body>".to_string(),
87 }
88 } else {
89 "<empty body>".to_string()
90 };
91 debug!(?request, body = %body_text, "sending");
92 let response = self.client.execute(request).await;
93 debug!(?response, "received");
94 let stream = response
95 .map_err(convert_err)?
96 .bytes_stream()
97 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
98 .into_async_read()
99 .lines()
100 .filter(|l| match l {
101 Ok(line) => !line.is_empty(),
103 Err(_) => true,
104 })
105 .map(|l| -> Result<String> {
106 let line = l?;
107 debug!(line, "model line");
108 if line.starts_with("<!DOCTYPE html>") {
109 return Err(crate::error::Error::PageNotFound());
110 }
111 if let Ok(error_value) = serde_json::from_str::<serde_json::Value>(&line) {
114 if let Some(error_msg) = error_value.get("error").and_then(|e| e.as_str()) {
115 return Err(crate::error::Error::Response(error_msg.to_string()));
116 }
117 }
118 Ok(line)
119 });
120
121 Ok(stream)
122 }
123}