use crate::{Error, Result};
use async_stream::try_stream;
use futures::{Stream, StreamExt, pin_mut};
use http::header::{HeaderName, USER_AGENT};
use reqwest::{ClientBuilder, IntoUrl, Method, StatusCode, header::HeaderMap, header::HeaderValue};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{Map, Value};
use stac::api::{
GetItems, Item, ItemCollection, Items, ItemsClient, Search, StreamItemsClient, UrlBuilder,
};
use stac::{Collection, Link, Links, SelfHref};
use std::pin::Pin;
use tokio::{
runtime::{Builder, Runtime},
sync::mpsc::{self, error::SendError},
task::JoinHandle,
};
/// Default capacity of the bounded channel that carries fetched pages from the
/// background paging task to the item stream (see `stream_items`).
const DEFAULT_CHANNEL_BUFFER: usize = 4;
/// Searches the STAC API at `href` and collects the results into one
/// [`ItemCollection`].
///
/// This is [`search_with_headers`] with no extra request headers: pages are
/// followed until the server runs out of results or `max_items` items have
/// been collected.
///
/// # Errors
///
/// Returns an error if the client cannot be built or any request fails.
pub async fn search(
    href: &str,
    search: Search,
    max_items: Option<usize>,
) -> Result<ItemCollection> {
    let no_extra_headers: &[(String, String)] = &[];
    search_with_headers(href, search, max_items, no_extra_headers).await
}
/// Searches the STAC API at `href`, sending the extra `headers` with every
/// request, and collects the results into a single [`ItemCollection`].
///
/// Pages are followed through their `next` links until the server runs out
/// of results or `max_items` items have been collected. When `search.limit`
/// is unset and `max_items` is given, `max_items` is also used as the page
/// size.
///
/// # Errors
///
/// Returns an error if a header name or value is invalid, the client cannot
/// be built, `max_items` does not fit in the search limit, or any request
/// fails.
pub async fn search_with_headers(
    href: &str,
    mut search: Search,
    max_items: Option<usize>,
    headers: &[(String, String)],
) -> Result<ItemCollection> {
    let mut builder = ApiClientBuilder::new(href)?;
    if !headers.is_empty() {
        builder = builder.with_headers(headers)?;
    }
    let client = builder.build()?;
    // Asking for zero items needs no request at all. Checking before the
    // search avoids a wasted round trip (and a `limit: 0` search, which a
    // server may reject).
    if max_items == Some(0) {
        return Ok(ItemCollection::default());
    }
    if search.limit.is_none()
        && let Some(max_items) = max_items
    {
        search.limit = Some(max_items.try_into()?);
    }
    let stream = StreamItemsClient::search_stream(&client, search).await?;
    pin_mut!(stream);
    // `max_items` is caller-supplied, so it is deliberately not used to
    // preallocate: a huge value would panic or abort on allocation before a
    // single item arrived.
    let mut items = Vec::new();
    while let Some(item) = stream.next().await {
        items.push(item?);
        if max_items.is_some_and(|max_items| items.len() >= max_items) {
            break;
        }
    }
    let item_collection = ItemCollection::new(items)?;
    Ok(item_collection)
}
/// An asynchronous client for a STAC API.
///
/// Build one with [`Client::new`], [`Client::with_client`], or
/// [`ApiClientBuilder`].
#[derive(Clone, Debug)]
pub struct Client {
    // HTTP client used to send every request.
    client: reqwest::Client,
    // Capacity of the page channel used when streaming items.
    channel_buffer: usize,
    // Builds the endpoint URLs (collection, items, search) for this API.
    url_builder: UrlBuilder,
}
/// A synchronous wrapper around [`Client`].
///
/// Each call to [`BlockingClient::search`] builds its own current-thread Tokio
/// runtime to drive the underlying asynchronous client.
#[derive(Debug)]
pub struct BlockingClient(Client);
/// A blocking iterator over search results, returned by
/// [`BlockingClient::search`].
///
/// Each call to `next` blocks on the owned runtime until the next item (or
/// error) is available.
// `Debug` cannot be derived because the boxed `dyn Stream` has no `Debug` impl.
#[allow(missing_debug_implementations)]
pub struct BlockingIterator {
    // Runtime used to poll `stream`; the stream's paging task was spawned on it.
    runtime: Runtime,
    stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
}
/// Builder for a [`Client`] that sends custom default headers.
///
/// Created with [`ApiClientBuilder::new`], which pre-populates a
/// `rustac/<version>` `User-Agent` header. Extra headers are added with
/// [`ApiClientBuilder::with_headers`].
// Public types in this crate implement `Debug` (the crate runs the
// `missing_debug_implementations` lint); `Clone` lets a configured builder be
// reused for several clients.
#[derive(Clone, Debug)]
pub struct ApiClientBuilder {
    // URL handed to `UrlBuilder::new` when the client is built.
    url: String,
    // Default headers installed on the underlying `reqwest::Client`.
    headers: HeaderMap,
}
impl ApiClientBuilder {
pub fn new(url: &str) -> Result<Self> {
let mut headers = HeaderMap::new();
let _ = headers.insert(
USER_AGENT,
format!("rustac/{}", env!("CARGO_PKG_VERSION")).parse()?,
);
Ok(Self {
url: url.to_string(),
headers,
})
}
pub fn with_headers(mut self, headers: &[(String, String)]) -> Result<Self> {
for (key, val) in headers.iter() {
let header_name = key.parse::<HeaderName>()?;
let header_value = HeaderValue::from_str(val)?;
self.headers.insert(header_name, header_value);
}
Ok(self)
}
pub fn build(self) -> Result<Client> {
let client = ClientBuilder::new().default_headers(self.headers).build()?;
Client::with_client(client, &self.url)
}
}
impl Client {
    /// Creates a new API client for the STAC API at `url`.
    ///
    /// This is shorthand for `ApiClientBuilder::new(url)?.build()`, so the
    /// client sends a `rustac/<version>` `User-Agent` header.
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP client cannot be built or `url` is
    /// rejected by [`UrlBuilder::new`].
    pub fn new(url: &str) -> Result<Client> {
        ApiClientBuilder::new(url)?.build()
    }

    /// Creates a new API client that sends its requests through `client`.
    ///
    /// This constructor adds no headers of its own.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` is rejected by [`UrlBuilder::new`].
    pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
        Ok(Client {
            client,
            channel_buffer: DEFAULT_CHANNEL_BUFFER,
            url_builder: UrlBuilder::new(url)?,
        })
    }

    /// Fetches the collection with the given `id`.
    ///
    /// Returns `Ok(None)` when the server responds with `404 Not Found`. Any
    /// other HTTP or decoding failure is returned as an error.
    pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
        let url = self.url_builder.collection(id)?;
        not_found_to_none(self.get(url).await)
    }

    /// Returns a stream over the items of collection `id`.
    ///
    /// The first page is fetched with a `GET` request before this returns;
    /// `items`, if given, is encoded as query parameters. Later pages are
    /// fetched in a background task by following `next` links.
    ///
    /// # Errors
    ///
    /// Returns an error if `items` cannot be converted to GET parameters or
    /// the first request fails. Errors on later pages are yielded by the
    /// stream.
    pub async fn items(
        &self,
        id: &str,
        items: impl Into<Option<Items>>,
    ) -> Result<impl Stream<Item = Result<Item>>> {
        let url = self.url_builder.items(id)?;
        let items: Option<Items> = items.into();
        let params = items.map(GetItems::try_from).transpose()?;
        let page = self
            .request(Method::GET, url, params.as_ref(), None)
            .await?;
        Ok(stream_items(self.clone(), page, self.channel_buffer))
    }

    /// GETs `url`, deserializes the JSON response, and sets the value's self
    /// href to `url`.
    async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
    where
        V: DeserializeOwned + SelfHref,
    {
        let url = url.into_url()?;
        let mut value = self
            .request::<(), V>(Method::GET, url.clone(), None, None)
            .await?;
        value.set_self_href(url);
        Ok(value)
    }

    /// POSTs `data` as a JSON body to `url` and deserializes the JSON response.
    async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
    where
        S: Serialize + 'static,
        R: DeserializeOwned,
    {
        self.request(Method::POST, url, Some(data), None).await
    }

    /// Sends a `method` request to `url` and deserializes the JSON response.
    ///
    /// `params`, if present, are encoded as query parameters for `GET` and as
    /// a JSON body for every other method. `headers`, if present, are added
    /// to the request.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` is invalid, the request fails, the server
    /// answers with an error status, or the body does not deserialize as `R`.
    async fn request<S, R>(
        &self,
        method: Method,
        url: impl IntoUrl,
        params: impl Into<Option<&S>>,
        headers: impl Into<Option<HeaderMap>>,
    ) -> Result<R>
    where
        S: Serialize + 'static,
        R: DeserializeOwned,
    {
        let url = url.into_url()?;
        let params: Option<&S> = params.into();
        // The method can come from a server-provided link (see
        // `request_from_link`). Any method other than GET therefore gets a
        // JSON body instead of hitting a panic on an unexpected verb.
        let sends_query = method == Method::GET;
        let mut request = self.client.request(method, url);
        if let Some(params) = params {
            request = if sends_query {
                request.query(params)
            } else {
                request.json(params)
            };
        }
        if let Some(headers) = headers.into() {
            request = request.headers(headers);
        }
        let response = request.send().await?.error_for_status()?;
        response.json().await.map_err(Error::from)
    }

    /// Follows `link`, using its optional `method`, `headers`, and `body`.
    ///
    /// The method defaults to `GET`. A header given as a JSON string is sent
    /// verbatim. A header given as an array produces one header entry per
    /// element. Any other JSON value is sent as its JSON text.
    ///
    /// # Errors
    ///
    /// Returns an error if the method, a header name, or a header value is
    /// invalid, or if the request fails.
    async fn request_from_link<R>(&self, link: Link) -> Result<R>
    where
        R: DeserializeOwned,
    {
        // Renders one JSON header value as header text. `Value::to_string`
        // would wrap strings in JSON quotes, so strings are unwrapped.
        fn header_text(value: Value) -> String {
            match value {
                Value::String(text) => text,
                other => other.to_string(),
            }
        }

        let method = match link.method {
            Some(method) => method.parse()?,
            None => Method::GET,
        };
        let headers = match link.headers {
            Some(headers) => {
                let mut header_map = HeaderMap::new();
                for (key, value) in headers {
                    let header_name: HeaderName = key.parse()?;
                    if let Value::Array(values) = value {
                        for value in values {
                            let header_value = HeaderValue::from_str(&header_text(value))?;
                            let _ = header_map.append(header_name.clone(), header_value);
                        }
                    } else {
                        let header_value = HeaderValue::from_str(&header_text(value))?;
                        let _ = header_map.insert(header_name, header_value);
                    }
                }
                Some(header_map)
            }
            None => None,
        };
        self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
            .await
    }
}
impl ItemsClient for Client {
    type Error = Error;

    /// POSTs `search` to the API's search endpoint and returns the first page
    /// of results (no `next` links are followed).
    async fn search(&self, search: Search) -> std::result::Result<ItemCollection, Error> {
        let url = self.url_builder.search().clone();
        tracing::debug!("searching {url}");
        let first_page: ItemCollection = self.post(url, &search).await?;
        Ok(first_page)
    }
}
impl StreamItemsClient for Client {
type Error = Error;
async fn search_stream(
&self,
search: Search,
) -> std::result::Result<impl Stream<Item = std::result::Result<Item, Error>> + Send, Error>
{
let page = ItemsClient::search(self, search).await?;
Ok(stream_items(self.clone(), page, self.channel_buffer))
}
async fn items_stream(
&self,
collection_id: &str,
items: Items,
) -> std::result::Result<impl Stream<Item = std::result::Result<Item, Error>> + Send, Error>
{
self.items(collection_id, Some(items)).await
}
}
impl BlockingClient {
pub fn new(url: &str) -> Result<BlockingClient> {
Client::new(url).map(Self)
}
pub fn search(&self, search: Search) -> Result<BlockingIterator> {
let runtime = Builder::new_current_thread().enable_all().build()?;
let client = self.0.clone();
let stream = runtime.block_on(async move {
let page = ItemsClient::search(&client, search).await?;
Ok::<_, Error>(stream_items(client, page, self.0.channel_buffer))
})?;
Ok(BlockingIterator {
runtime,
stream: Box::pin(stream),
})
}
}
impl Iterator for BlockingIterator {
    type Item = Result<Item>;

    /// Blocks on the owned runtime until the stream yields its next item, or
    /// returns `None` once the stream is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let Self { runtime, stream } = self;
        runtime.block_on(stream.as_mut().next())
    }
}
/// Turns a first `page` into a stream of items and fetches later pages in a
/// background task.
///
/// A task started with `tokio::spawn` walks the pages (see `stream_pages`)
/// and sends each one over a bounded channel of capacity `channel_buffer`.
/// The returned stream flattens every received page into its items. The
/// first error ends both the paging task and the returned stream.
///
/// Must be called from within a Tokio runtime, because it uses
/// `tokio::spawn`.
fn stream_items(
    client: Client,
    page: ItemCollection,
    channel_buffer: usize,
) -> impl Stream<Item = Result<Item>> {
    let (sender, mut receiver) = mpsc::channel(channel_buffer);
    let producer: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
        let pages = stream_pages(client, page);
        pin_mut!(pages);
        while let Some(result) = pages.next().await {
            // Errors are forwarded too, but paging stops after the first one.
            let failed = result.is_err();
            sender.send(result).await?;
            if failed {
                break;
            }
        }
        Ok(())
    });
    try_stream! {
        while let Some(received) = receiver.recv().await {
            let page = received?;
            for item in page.items {
                yield item;
            }
        }
        // The producer's `SendError` (receiver dropped) is ignored. Only a
        // join failure, such as a panic in the task, is surfaced.
        let _ = producer.await?;
    }
}
/// Streams `page` and every following page reached through `next` links.
///
/// The stream ends at the first empty page, at a page with no `next` link,
/// or when following a `next` link yields a JSON `null` body. A request error
/// is yielded and ends the stream.
fn stream_pages(
    client: Client,
    mut page: ItemCollection,
) -> impl Stream<Item = Result<ItemCollection>> {
    try_stream! {
        while !page.items.is_empty() {
            let next_link = page.link("next").cloned();
            yield page;
            let Some(next_link) = next_link else {
                break;
            };
            let Some(next_page) = client
                .request_from_link::<Option<ItemCollection>>(next_link)
                .await?
            else {
                break;
            };
            page = next_page;
        }
    }
}
fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
let mut result = result.map(Some);
if let Err(Error::Reqwest(ref err)) = result
&& err
.status()
.map(|s| s == StatusCode::NOT_FOUND)
.unwrap_or_default()
{
result = Ok(None);
}
result
}
// Tests against a local `mockito` server; response bodies come from the
// crate's `mocks/` fixtures.
#[cfg(test)]
mod tests {
    use crate::api::ApiClientBuilder;
    use super::Client;
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use stac::api::{ItemCollection, Items, ItemsClient, Search, StreamItemsClient};
    use url::Url;

    // A 404 from `/collections/{id}` must come back as `Ok(None)`, not an error.
    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        assert!(
            client
                .collection("not-a-collection")
                .await
                .unwrap()
                .is_none()
        );
        collection.assert_async().await;
    }

    // Search streaming follows the `next` link. The fixture's link is
    // repointed at the mock server, and the second mock only matches a POST
    // body that carries the paging token.
    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url());
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        let items: Vec<_> = StreamItemsClient::search_stream(&client, search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // Item streaming follows a GET `next` link. The fixture link's query
    // string is kept while its host is repointed at the mock server.
    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // An empty first page ends the stream even though it still has a `next`
    // link; no mock exists for the second page.
    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_body.set_link(next_link);
        page_body.items = vec![];
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    // The mock only matches requests that carry the `rustac/<version>` user
    // agent. The mock is never asserted, so this test relies on `search`
    // failing when no mock matches.
    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        let _ = server
            .mock("POST", "/search")
            .with_body_from_file("mocks/items-page-1.json")
            .match_header(
                "user-agent",
                format!("rustac/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
    }

    // Headers added through `ApiClientBuilder::with_headers` are sent on
    // search requests. As in `user_agent`, this relies on an unmatched
    // request making `search` fail.
    #[tokio::test]
    async fn custom_header() {
        let mut server = Server::new_async().await;
        let _ = server
            .mock("POST", "/search")
            .with_body_from_file("mocks/items-page-1.json")
            .match_header("x-my-header", "value")
            .match_header("x-my-other-header", "othervalue")
            .create_async()
            .await;
        let headers = vec![
            ("x-my-header".to_string(), "value".to_string()),
            ("x-my-other-header".to_string(), "othervalue".to_string()),
        ];
        let builder = ApiClientBuilder::new(&server.url())
            .unwrap()
            .with_headers(&headers)
            .unwrap();
        let client = builder.build().unwrap();
        let _ = client.search(Default::default()).await.unwrap();
    }
}