use std::{net, time::Duration};
use bytes::Bytes;
use reqwest::ClientBuilder;
use scion_proto::address::IsdAsn;
use thiserror::Error;
use url::Url;
use super::api::{AuthServerResponse, SnapsResponse, StatusResponse};
use crate::{
api::admin::api::{EndhostApisResponse, RoutersResponse, SetLinkStateRequest},
dto::IoConfigDto,
state::snap::SnapId,
};
/// Client for the versioned admin HTTP API.
///
/// Bundles the HTTP client that all requests go through with the resolved API root URL.
/// Cloning shares the underlying connection pool (`reqwest::Client` is internally
/// reference-counted) and copies the API root URL.
#[derive(Clone, Debug)]
pub struct ApiClient {
    /// HTTP client that every request is sent through.
    client: reqwest::Client,
    /// Versioned API root (resolved from the base URL with `api/v1/`); endpoint URLs are
    /// built relative to it.
    api: Url,
}
/// Request methods for the admin API.
///
/// Every request method fails with:
/// - [`ClientError::InvalidURL`] if its URL cannot be built;
/// - [`ClientError::ReqwestError`] if the request fails or a response body cannot be read or
///   decoded;
/// - [`ClientError::Unauthorized`] on `401 Unauthorized`;
/// - [`ClientError::InvalidResponseStatus`] on any other status than the expected one
///   (`200 OK`, or `204 No Content` for deletions).
impl ApiClient {
    /// Creates a client for the admin API served under `url`.
    ///
    /// Requests go to `<url>/api/v1/…`. A path prefix in `url` is preserved whether or not it
    /// ends with a slash. The internal HTTP client uses a 5-second request timeout and
    /// `tls_certs_only(std::iter::empty())`.
    ///
    /// NOTE(review): that TLS setting presumably leaves no trusted root certificates, so
    /// HTTPS endpoints would fail verification. Confirm against the reqwest version in use,
    /// or use [`ApiClient::new_with_client`] to supply a differently configured client.
    ///
    /// # Errors
    ///
    /// [`ClientError::InvalidURL`] if the API root cannot be resolved from `url`, and
    /// [`ClientError::ReqwestError`] if the HTTP client cannot be built.
    pub fn new(url: &Url) -> Result<Self, ClientError> {
        let api = Self::api_root(url)?;
        let client = ClientBuilder::new()
            .tls_certs_only(std::iter::empty())
            .timeout(Duration::from_secs(5))
            .build()?;
        Ok(ApiClient { client, api })
    }

    /// Creates a client for the admin API served under `url`, sending requests through the
    /// caller-provided `client` (e.g. one with custom TLS or timeout settings).
    ///
    /// # Errors
    ///
    /// [`ClientError::InvalidURL`] if the API root cannot be resolved from `url`.
    pub fn new_with_client(url: &Url, client: reqwest::Client) -> Result<Self, ClientError> {
        let api = Self::api_root(url)?;
        Ok(ApiClient { client, api })
    }

    /// Fetches `GET api/v1/status`, decoded as [`StatusResponse`].
    pub async fn get_status(&self) -> Result<StatusResponse, ClientError> {
        self.get("status").await
    }

    /// Fetches `GET api/v1/snaps`, decoded as [`SnapsResponse`].
    pub async fn get_snaps(&self) -> Result<SnapsResponse, ClientError> {
        self.get("snaps").await
    }

    /// Fetches `GET api/v1/routers`, decoded as [`RoutersResponse`].
    pub async fn get_routers(&self) -> Result<RoutersResponse, ClientError> {
        self.get("routers").await
    }

    /// Fetches `GET api/v1/endhost_apis`, decoded as [`EndhostApisResponse`].
    pub async fn get_endhost_apis(&self) -> Result<EndhostApisResponse, ClientError> {
        self.get("endhost_apis").await
    }

    /// Fetches `GET api/v1/io_config`, decoded as [`IoConfigDto`].
    pub async fn get_io_config(&self) -> Result<IoConfigDto, ClientError> {
        self.get("io_config").await
    }

    /// Fetches `GET api/v1/auth_server`, decoded as [`AuthServerResponse`].
    pub async fn get_auth_server(&self) -> Result<AuthServerResponse, ClientError> {
        self.get("auth_server").await
    }

    /// Sends `POST api/v1/link_state` with a JSON [`SetLinkStateRequest`] built from the
    /// arguments. Succeeds only on `200 OK`.
    pub async fn set_link_state(
        &self,
        isd_as: IsdAsn,
        interface_id: u16,
        up: bool,
    ) -> Result<(), ClientError> {
        let url = self.api.join("link_state")?;
        let body = SetLinkStateRequest {
            isd_as,
            interface_id,
            up,
        };
        let response = self.client.post(url).json(&body).send().await?;
        Self::expect_status(response, reqwest::StatusCode::OK).await
    }

    /// Sends `DELETE api/v1/snaps/{snap_id}/connections/{socket_addr}`. Succeeds only on
    /// `204 No Content`.
    ///
    /// Each path segment is percent-encoded on its own. Otherwise an IPv6 socket address with
    /// a scope id (e.g. `[fe80::1%12]:443`) would put a raw `%12` into the path, where it reads
    /// as a percent-escape. A `#` or `?` would also cut the path short.
    pub async fn delete_snap_connection(
        &self,
        snap_id: SnapId,
        socket_addr: net::SocketAddr,
    ) -> Result<(), ClientError> {
        let mut endpoint = self.api.clone();
        endpoint
            .path_segments_mut()
            // `self.api` comes from `Url::join`, which rejects cannot-be-a-base URLs, so this
            // cannot fail in practice. Report it as a URL error rather than panicking.
            .map_err(|()| url::ParseError::RelativeUrlWithCannotBeABaseBase)?
            // Drop the empty segment left by the API root's trailing slash before appending.
            .pop_if_empty()
            .push("snaps")
            .push(&snap_id.to_string())
            .push("connections")
            .push(&socket_addr.to_string());
        let response = self.client.delete(endpoint).send().await?;
        Self::expect_status(response, reqwest::StatusCode::NO_CONTENT).await
    }

    /// Sends `GET api/v1/<endpoint>` and decodes a `200 OK` JSON body as `T`.
    async fn get<T>(&self, endpoint: &str) -> Result<T, ClientError>
    where
        T: serde::de::DeserializeOwned,
    {
        let url = self.api.join(endpoint)?;
        let response = self.client.get(url).send().await?;
        if response.status() != reqwest::StatusCode::OK {
            return Err(Self::unexpected_response(response).await);
        }
        Ok(response.json::<T>().await?)
    }

    /// Resolves the versioned API root `<url>/api/v1/`.
    ///
    /// `Url::join` replaces the last path segment of a base that does not end in `/`. A base
    /// like `http://host/prefix` would therefore resolve to `http://host/api/v1/` and silently
    /// drop `prefix`, so a trailing slash is appended first. Bases that already end in `/`
    /// (including a bare origin, whose path is `/`) resolve exactly as before.
    fn api_root(url: &Url) -> Result<Url, ClientError> {
        let mut base = url.clone();
        // Cannot-be-a-base URLs are left alone; `join` rejects them below either way.
        if !base.cannot_be_a_base() && !base.path().ends_with('/') {
            let dir = format!("{}/", base.path());
            base.set_path(&dir);
        }
        base.join("api/v1/").map_err(ClientError::InvalidURL)
    }

    /// Returns `Ok(())` if `response` has the `expected` status. Otherwise returns the error
    /// derived from the response.
    async fn expect_status(
        response: reqwest::Response,
        expected: reqwest::StatusCode,
    ) -> Result<(), ClientError> {
        if response.status() == expected {
            Ok(())
        } else {
            Err(Self::unexpected_response(response).await)
        }
    }

    /// Turns a response with an unexpected status into a [`ClientError`], keeping its body.
    ///
    /// `401 Unauthorized` becomes [`ClientError::Unauthorized`]; any other status becomes
    /// [`ClientError::InvalidResponseStatus`]. If the body itself cannot be read, that read
    /// failure is returned as [`ClientError::ReqwestError`] instead.
    async fn unexpected_response(response: reqwest::Response) -> ClientError {
        let status = response.status();
        let body = match response.bytes().await {
            Ok(body) => body,
            Err(err) => return ClientError::ReqwestError(err),
        };
        if status == reqwest::StatusCode::UNAUTHORIZED {
            ClientError::Unauthorized(body)
        } else {
            ClientError::InvalidResponseStatus(status, body)
        }
    }
}
#[derive(Error, Debug)]
pub enum ClientError {
#[error("invalid URL: {0:?}")]
InvalidURL(#[from] url::ParseError),
#[error("reqwest error: {0:?}")]
ReqwestError(#[from] reqwest::Error),
#[error("the request could not be authorized: {0:?}")]
Unauthorized(Bytes),
#[error("invalid response status ({0}): {1:?}")]
InvalidResponseStatus(reqwest::StatusCode, Bytes),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`ApiClient`] for `base` and asserts that its API root equals `expected`.
    fn assert_api_root(base: &str, expected: &str) {
        // Presumably installs the ring-based rustls crypto provider needed before the
        // TLS-enabled client builder can run; kept first, as in the original tests.
        scion_sdk_utils::rustls::select_ring_crypto_provider();
        let base_url: Url = base.parse().unwrap();
        let client = ApiClient::new(&base_url).expect("Failed to create ApiClient");
        assert_eq!(client.api, Url::parse(expected).unwrap());
    }

    #[test]
    fn should_normalize_url_with_http_schema() {
        assert_api_root("http://localhost:9000", "http://localhost:9000/api/v1/");
    }

    #[test]
    fn should_normalize_url_with_trailing_slash() {
        assert_api_root("http://localhost:9000/", "http://localhost:9000/api/v1/");
    }

    #[test]
    fn should_normalize_url_with_https_schema() {
        assert_api_root("https://localhost:9000", "https://localhost:9000/api/v1/");
    }
}