use std::{
borrow::Cow,
fmt::{Display, Formatter},
};
use crate::{
credentials::{user_credentials::UserToken, Credentials, GetTokenError},
management::Membership,
user_agent::get_user_agent,
};
use cts_common::{
protocol::CreateWorkspaceResponse, CtsServiceDiscovery, Region, RegionError, ServiceDiscovery,
Workspace, WorkspaceId,
};
use miette::Diagnostic;
use reqwest::{header, Method};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use url::Url;
use uuid::Uuid;
/// JSON body sent by `CtsClient::create_workspace` when POSTing to `/api/workspaces`.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateWorkspaceRequest {
// Region passed by the caller of `create_workspace`.
region: Region,
// Name passed by the caller of `create_workspace`.
name: String,
}
/// One access key entry, as deserialized from the response of `CtsClient::list_access_keys`.
///
/// Fields are read from camelCase JSON keys (`keyId`, `workspaceId`, `keyName`,
/// `createdAt`, `lastUsedAt`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccessKey {
/// Value of the `keyId` field of the response.
pub key_id: String,
/// Workspace the key is associated with (`workspaceId` in the response).
pub workspace_id: WorkspaceId,
/// Value of the `keyName` field of the response.
pub key_name: String,
/// Value of the `createdAt` field, kept as the raw string sent by the server.
// NOTE(review): presumably a timestamp; the format is not visible here — confirm.
pub created_at: String,
/// Value of the `lastUsedAt` field; `None` when the field is null or absent.
// NOTE(review): presumably `None` means the key has never been used — confirm.
pub last_used_at: Option<String>,
}
/// JSON body sent by `CtsClient::create_access_key` when POSTing to `/api/access-key`.
///
/// Field names are serialized in camelCase (`workspaceId`, `keyName`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateAccessKeyInput {
// Workspace passed by the caller of `create_access_key`.
workspace_id: WorkspaceId,
// Key name passed by the caller of `create_access_key`.
key_name: String,
}
/// An OIDC provider as deserialized from the responses of
/// `CtsClient::list_oidc_providers` and `CtsClient::create_oidc_provider`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OidcProvider {
/// Provider identifier; the same kind of value `CtsClient::delete_oidc_provider` accepts.
pub id: Uuid,
/// Issuer URL of the provider.
pub issuer: Url,
/// Vendor of the provider.
pub vendor: OidcVendor,
}
/// JSON body sent by `CtsClient::create_oidc_provider` when POSTing to `/api/oidc/providers`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateOidcProviderRequest<'u> {
// `Cow` lets the request borrow the caller's `Url` (`create_oidc_provider` uses
// `Cow::Borrowed`) instead of cloning it.
issuer: Cow<'u, Url>,
// Vendor passed by the caller of `create_oidc_provider`.
vendor: OidcVendor,
}
/// JSON body sent by `CtsClient::create_membership` when POSTing to
/// `/api/meta/admin/memberships`.
///
/// Field names are serialized in camelCase (`userId`, `workspaceId`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateMembershipRequest {
// User identifier passed by the caller, sent as a plain string.
user_id: String,
// Workspace identifier passed by the caller, sent as a plain string.
workspace_id: String,
}
/// The OIDC vendors this client can name in requests and read from responses.
///
/// Variants are (de)serialized as their lowercase names: `"auth0"`, `"okta"`, `"clerk"`.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum OidcVendor {
/// Serialized as `"auth0"`.
Auth0,
/// Serialized as `"okta"`.
Okta,
/// Serialized as `"clerk"`.
Clerk,
}
impl OidcVendor {
pub fn as_str(&self) -> &'static str {
match self {
OidcVendor::Auth0 => "auth0",
OidcVendor::Okta => "okta",
OidcVendor::Clerk => "clerk",
}
}
}
impl Display for OidcVendor {
    /// Writes the same lowercase name that `as_str` returns.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
/// Errors returned by `CtsClient` operations.
#[derive(Diagnostic, Error, Debug)]
pub enum CtsClientError {
/// Obtaining a token from the client's credentials failed.
#[error(transparent)]
GetToken(#[from] GetTokenError),
/// An error raised by `reqwest`, converted via `From` (for example while sending a
/// request or decoding a response body).
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// The server answered with an error status other than 401 Unauthorized.
#[error("Request failed: {body}")]
ErrorResponse {
/// Text of the response body; empty when the body could not be read.
body: String,
/// The status error `reqwest` reported for the response.
#[source]
error: reqwest::Error,
},
/// The server answered with 401 Unauthorized. Carries the response body text when it
/// could be read.
#[error("Unauthorized")]
Unauthorized(Option<String>),
}
/// JSON body sent by `CtsClient::revoke_access_key` with its DELETE request to
/// `/api/access-key`.
///
/// Field names are serialized in camelCase (`workspaceId`, `keyName`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct RevokeAccessKeyInput {
// Workspace passed by the caller of `revoke_access_key`.
workspace_id: WorkspaceId,
// Name of the key to revoke, as passed by the caller.
key_name: String,
}
/// JSON response to the revoke-access-key request.
#[derive(Debug, Deserialize)]
struct RevokeAccessKeyResponse {
// Returned verbatim to the caller of `revoke_access_key`.
message: String,
}
/// Alias for [`CtsClient`] under an all-caps acronym spelling.
// NOTE(review): presumably kept so callers using the older spelling keep compiling — confirm.
pub type CTSClient<C> = CtsClient<C>;
/// HTTP client for the CTS API.
///
/// Every request is resolved against `base_url` and authenticated with a token
/// obtained from `credentials` at the time the request is made.
pub struct CtsClient<C: Credentials<Token = UserToken>> {
// Underlying HTTP client, created once in `new` and reused for all requests.
client: reqwest::Client,
// URL that endpoint paths are resolved against.
base_url: Url,
// Source of the user token attached to each request.
credentials: C,
}
impl<C> CtsClient<C>
where
    C: Credentials<Token = UserToken>,
{
    /// Creates a client that sends requests to `base_url` and authenticates them with
    /// tokens obtained from `credentials`.
    pub fn new(base_url: Url, credentials: C) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url,
            credentials,
        }
    }

    /// Creates a client whose base URL is the CTS endpoint discovered for `region`.
    ///
    /// # Errors
    ///
    /// Returns the [`RegionError`] reported by service discovery.
    pub fn for_region(region: Region, credentials: C) -> Result<Self, RegionError> {
        CtsServiceDiscovery::endpoint(region).map(|base_url| Self::new(base_url, credentials))
    }

    /// Resolves `path` against the base URL.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid url" if `path` cannot be joined onto the base URL.
    fn join_endpoint(&self, path: &str) -> Url {
        self.base_url.join(path).expect("Invalid url")
    }

    /// Turns a response with an error status into a [`CtsClientError`].
    ///
    /// A 401 becomes [`CtsClientError::Unauthorized`]; any other error status becomes
    /// [`CtsClientError::ErrorResponse`]. In both cases the response body is read
    /// best-effort so it can be reported. Successful responses are returned untouched.
    async fn ensure_success(
        response: reqwest::Response,
    ) -> Result<reqwest::Response, CtsClientError> {
        if let Err(error) = response.error_for_status_ref() {
            // Failing to read the body must not mask the status error itself.
            let body: Option<String> = response.text().await.ok();
            return if error.status() == Some(reqwest::StatusCode::UNAUTHORIZED) {
                Err(CtsClientError::Unauthorized(body))
            } else {
                Err(CtsClientError::ErrorResponse {
                    body: body.unwrap_or_default(),
                    error,
                })
            };
        }
        Ok(response)
    }

    /// Builds an authenticated request for `method` on `path`.
    ///
    /// The path of the base URL is replaced by `path`. The returned builder already
    /// carries the bearer token and the user-agent header; the caller sends it.
    ///
    /// # Errors
    ///
    /// Returns [`CtsClientError::GetToken`] if no token could be obtained.
    async fn build_request(
        &self,
        method: Method,
        path: &str,
    ) -> Result<reqwest::RequestBuilder, CtsClientError> {
        let url = {
            let mut url = self.base_url.clone();
            url.set_path(path);
            url
        };
        let token = self.credentials.get_token().await?;
        let builder = self
            .client
            .request(method, url)
            .bearer_auth(token.access_token())
            .header(header::USER_AGENT, get_user_agent());
        Ok(builder)
    }

    /// Sends the request produced by `callback` with authorization and user-agent
    /// headers attached, and checks the response status.
    ///
    /// # Errors
    ///
    /// Returns [`CtsClientError::GetToken`] if no token could be obtained,
    /// [`CtsClientError::Reqwest`] if sending fails, [`CtsClientError::Unauthorized`]
    /// on a 401 response and [`CtsClientError::ErrorResponse`] on any other error status.
    async fn send_reqwest(
        &self,
        callback: impl FnOnce(&reqwest::Client) -> reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, CtsClientError> {
        let token = self.credentials.get_token().await?;
        let response = callback(&self.client)
            .header("authorization", token.as_header())
            .header("user-agent", get_user_agent())
            .send()
            .await?;
        Self::ensure_success(response).await
    }

    /// Creates a workspace called `name` in `region`.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn create_workspace(
        &self,
        name: &str,
        region: Region,
    ) -> Result<CreateWorkspaceResponse, CtsClientError> {
        let url = self.join_endpoint("/api/workspaces");
        let body = CreateWorkspaceRequest {
            region,
            name: name.into(),
        };
        let response = self
            .send_reqwest(|client| client.post(url).json(&body))
            .await?;
        Ok(response.json().await?)
    }

    /// Lists the workspaces returned by `GET /api/workspaces`.
    ///
    /// # Errors
    ///
    /// Returns [`CtsClientError::Unauthorized`] on a 401 response,
    /// [`CtsClientError::ErrorResponse`] on any other error status, and another
    /// [`CtsClientError`] if the request fails or the body cannot be decoded.
    pub async fn list_workspaces(&self) -> Result<Vec<Workspace>, CtsClientError> {
        let request = self.build_request(Method::GET, "/api/workspaces").await?;
        // Check the status before decoding so that an error response is reported as
        // such instead of surfacing as a JSON decode failure.
        let response = Self::ensure_success(request.send().await?).await?;
        Ok(response.json().await?)
    }

    /// Lists the workspaces returned by `GET /api/meta/admin/workspaces`.
    ///
    /// # Errors
    ///
    /// Returns [`CtsClientError::Unauthorized`] on a 401 response,
    /// [`CtsClientError::ErrorResponse`] on any other error status, and another
    /// [`CtsClientError`] if the request fails or the body cannot be decoded.
    #[deprecated(
        since = "0.22.2",
        note = "Use `list_workspaces` instead, which returns all workspaces the user has access to."
    )]
    pub async fn list_all_workspaces(&self) -> Result<Vec<Workspace>, CtsClientError> {
        let request = self
            .build_request(Method::GET, "/api/meta/admin/workspaces")
            .await?;
        // Same status handling as every other endpoint; see `ensure_success`.
        let response = Self::ensure_success(request.send().await?).await?;
        Ok(response.json().await?)
    }

    /// Creates an access key called `name` in the given workspace and returns the
    /// response body text as the key.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn create_access_key(
        &self,
        name: &str,
        workspace_id: WorkspaceId,
    ) -> Result<String, CtsClientError> {
        let url = self.join_endpoint("/api/access-key");
        let body = CreateAccessKeyInput {
            workspace_id,
            key_name: name.into(),
        };
        let response = self
            .send_reqwest(|client| client.post(url).json(&body))
            .await?;
        let access_key: String = response.text().await?;
        Ok(access_key)
    }

    /// Lists the access keys of the given workspace.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn list_access_keys(
        &self,
        workspace_id: WorkspaceId,
    ) -> Result<Vec<AccessKey>, CtsClientError> {
        let url = self.join_endpoint(&format!("/api/access-keys/{workspace_id}"));
        let response = self.send_reqwest(|client| client.get(url)).await?;
        let access_keys: Vec<AccessKey> = response.json().await?;
        Ok(access_keys)
    }

    /// Revokes the access key called `name` in the given workspace and returns the
    /// message sent back by the server.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn revoke_access_key(
        &self,
        name: &str,
        workspace_id: WorkspaceId,
    ) -> Result<String, CtsClientError> {
        let url = self.join_endpoint("/api/access-key");
        let body = RevokeAccessKeyInput {
            workspace_id,
            key_name: name.into(),
        };
        let response = self
            .send_reqwest(|client| client.delete(url).json(&body))
            .await?;
        let revoke_ak_response: RevokeAccessKeyResponse = response.json().await?;
        Ok(revoke_ak_response.message)
    }

    /// Lists the OIDC providers of the given workspace.
    ///
    /// The workspace is passed in the `x-cs-workspace-id` header.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    pub async fn list_oidc_providers(
        &self,
        workspace_id: WorkspaceId,
    ) -> Result<Vec<OidcProvider>, CtsClientError> {
        let mut url = self.base_url.clone();
        url.set_path("api/oidc/providers");
        let response = self
            .send_reqwest(|client| {
                client
                    .get(url)
                    .header("x-cs-workspace-id", workspace_id.as_str())
            })
            .await?;
        Ok(response.json().await?)
    }

    /// Registers an OIDC provider with the given `issuer` and `vendor` for the workspace.
    ///
    /// The workspace is passed in the `x-cs-workspace-id` header.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn create_oidc_provider(
        &self,
        workspace_id: WorkspaceId,
        issuer: &Url,
        vendor: OidcVendor,
    ) -> Result<OidcProvider, CtsClientError> {
        let url = self.join_endpoint("/api/oidc/providers");
        let body = CreateOidcProviderRequest {
            issuer: Cow::Borrowed(issuer),
            vendor,
        };
        let response = self
            .send_reqwest(|client| {
                client
                    .post(url)
                    .header("x-cs-workspace-id", workspace_id.as_str())
                    .json(&body)
            })
            .await?;
        Ok(response.json().await?)
    }

    /// Deletes the OIDC provider `provider_id` from the given workspace.
    ///
    /// The workspace is passed in the `x-cs-workspace-id` header. The response body is
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails or the server answers with an
    /// error status.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn delete_oidc_provider(
        &self,
        workspace_id: WorkspaceId,
        provider_id: Uuid,
    ) -> Result<(), CtsClientError> {
        let url = self.join_endpoint(&format!("/api/oidc/providers/{provider_id}"));
        self.send_reqwest(|client| {
            client
                .delete(url)
                .header("x-cs-workspace-id", workspace_id.as_str())
        })
        .await?;
        Ok(())
    }

    /// Creates a membership linking `user_id` to `workspace_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn create_membership(
        &self,
        user_id: &str,
        workspace_id: &str,
    ) -> Result<Membership, CtsClientError> {
        let url = self.join_endpoint("/api/meta/admin/memberships");
        let body = CreateMembershipRequest {
            user_id: user_id.into(),
            workspace_id: workspace_id.into(),
        };
        let response = self
            .send_reqwest(|client| client.post(url).json(&body))
            .await?;
        Ok(response.json().await?)
    }

    /// Lists the memberships returned by `GET /api/meta/admin/memberships`.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn list_memberships(&self) -> Result<Vec<Membership>, CtsClientError> {
        let url = self.join_endpoint("/api/meta/admin/memberships");
        let response = self.send_reqwest(|client| client.get(url)).await?;
        Ok(response.json().await?)
    }

    /// Lists the memberships of the workspace `ws_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn list_memberships_for_workspace(
        &self,
        ws_id: WorkspaceId,
    ) -> Result<Vec<Membership>, CtsClientError> {
        let url = self.join_endpoint(&format!("/api/meta/admin/workspaces/{ws_id}/memberships"));
        let response = self.send_reqwest(|client| client.get(url)).await?;
        Ok(response.json().await?)
    }

    /// Deletes the membership `membership_id` and returns the membership decoded from
    /// the response.
    ///
    /// # Errors
    ///
    /// Returns a [`CtsClientError`] if the request fails, the server answers with an
    /// error status, or the response body cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint path cannot be joined onto the base URL.
    pub async fn delete_membership(
        &self,
        membership_id: Uuid,
    ) -> Result<Membership, CtsClientError> {
        let url = self.join_endpoint(&format!("/api/meta/admin/memberships/{membership_id}"));
        let response = self.send_reqwest(|client| client.delete(url)).await?;
        Ok(response.json().await?)
    }
}