use crate::{management::Membership, user_agent::get_user_agent};
use cts_common::claims::Role;
use cts_common::{protocol::CreateWorkspaceResponse, Region, Workspace, WorkspaceId};
use miette::Diagnostic;
use reqwest::header;
use serde::{Deserialize, Serialize};
use stack_auth::{AuthError, AuthStrategy, ServiceToken};
use std::{
borrow::Cow,
fmt::{Display, Formatter},
};
use thiserror::Error;
use url::Url;
use uuid::Uuid;
/// JSON body sent with `POST /api/workspaces`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateWorkspaceRequest {
    /// Region the new workspace is created in.
    region: Region,
    /// Name for the new workspace.
    name: String,
}
/// An access key entry as returned by the access-key listing endpoint.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct AccessKey {
    /// Identifier of the key.
    pub key_id: String,
    /// Workspace the key belongs to.
    pub workspace_id: WorkspaceId,
    /// Name given to the key when it was created.
    pub key_name: String,
    /// Creation timestamp, kept as the raw string sent by the server.
    pub created_at: String,
    /// Last-use timestamp as a raw string, or `None` when the response carries no value.
    pub last_used_at: Option<String>,
}
/// JSON body sent with `POST /api/access-keys`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateAccessKeyInput {
    /// Workspace the key is created in.
    workspace_id: WorkspaceId,
    /// Name to give the new key.
    key_name: String,
    /// Role requested for the new key.
    role: Role,
}
/// JSON response of `POST /api/access-keys`.
///
/// NOTE(review): unlike its siblings this type does not derive `Debug`; presumably that is
/// deliberate so the issued key is never printed — confirm before adding it.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CreateAccessKeyResponse {
    /// The newly issued access key.
    access_key: String,
}
/// An OIDC provider as returned by the OIDC provider endpoints.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct OidcProvider {
    /// Identifier of the provider.
    pub id: Uuid,
    /// Issuer URL of the provider.
    pub issuer: Url,
    /// Vendor that operates the provider.
    pub vendor: OidcVendor,
}
/// JSON body sent with `POST /api/oidc/providers`.
///
/// The issuer is held as a `Cow` so a caller-owned `Url` can be serialized without cloning.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateOidcProviderRequest<'a> {
    /// Issuer URL of the provider to register.
    issuer: Cow<'a, Url>,
    /// Vendor of the provider to register.
    vendor: OidcVendor,
}
/// JSON body sent with `POST /api/memberships`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateMembershipRequest {
    /// User the membership is created for.
    user_id: String,
    /// Workspace the membership grants access to.
    workspace_id: String,
}
/// OIDC provider vendors understood by the API.
///
/// Serialized in lowercase (`"auth0"`, `"okta"`, `"clerk"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OidcVendor {
    /// Auth0.
    Auth0,
    /// Okta.
    Okta,
    /// Clerk.
    Clerk,
}
impl OidcVendor {
pub fn as_str(&self) -> &'static str {
match self {
OidcVendor::Auth0 => "auth0",
OidcVendor::Okta => "okta",
OidcVendor::Clerk => "clerk",
}
}
}
/// Formats the vendor as its lowercase identifier (e.g. `okta`).
///
/// Uses [`Formatter::pad`] so that width, alignment and precision flags (for example
/// `{:<8}` when printing tables) are honoured, exactly as they are for `str`.
impl Display for OidcVendor {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
#[derive(Diagnostic, Error, Debug)]
pub enum CtsClientError {
#[error(transparent)]
Auth(#[from] AuthError),
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error("Request failed: {body}")]
ErrorResponse {
body: String,
#[source]
error: reqwest::Error,
},
#[error("Unauthorized")]
Unauthorized(Option<String>),
}
/// JSON body sent when revoking an access key.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct RevokeAccessKeyInput {
    /// Workspace that owns the key.
    workspace_id: WorkspaceId,
    /// Name of the key to revoke.
    key_name: String,
}
/// JSON response returned when an access key is revoked.
#[derive(Deserialize, Debug)]
struct RevokeAccessKeyResponse {
    /// Message returned by the server.
    message: String,
}
/// Alternative spelling of [`CtsClient`]; presumably retained for backward compatibility.
pub type CTSClient<C> = CtsClient<C>;

/// HTTP client for the CTS `/api/...` endpoints.
///
/// Every request carries a bearer token obtained from the `C` auth strategy and is sent to
/// the token's issuer URL (or to an override in test builds).
pub struct CtsClient<C>
where
    C: Send + Sync + 'static,
    for<'a> &'a C: AuthStrategy,
{
    /// reqwest client reused for all requests.
    client: reqwest::Client,
    /// Source of bearer tokens for each request.
    auth_strategy: C,
    /// When set, replaces the token issuer as the base URL (test builds only).
    #[cfg(any(test, feature = "test-utils"))]
    base_url_override: Option<Url>,
}
impl<C> CtsClient<C>
where
C: Send + Sync + 'static,
for<'a> &'a C: AuthStrategy,
{
pub fn new(credentials: C) -> Self {
Self {
client: reqwest::Client::new(),
auth_strategy: credentials,
#[cfg(any(test, feature = "test-utils"))]
base_url_override: None,
}
}
#[cfg(any(test, feature = "test-utils"))]
pub fn with_base_url(mut self, base_url: Url) -> Self {
self.base_url_override = Some(base_url);
self
}
fn base_url(&self, token: &ServiceToken) -> Result<Url, CtsClientError> {
#[cfg(any(test, feature = "test-utils"))]
if let Some(url) = &self.base_url_override {
return Ok(url.clone());
}
token.issuer().cloned().map_err(CtsClientError::Auth)
}
async fn send_reqwest(
&self,
callback: impl FnOnce(&reqwest::Client, &Url) -> reqwest::RequestBuilder,
) -> Result<reqwest::Response, CtsClientError> {
let token = (&self.auth_strategy).get_token().await?;
let base_url = self.base_url(&token)?;
let response = callback(&self.client, &base_url)
.bearer_auth(token.as_str())
.header(header::USER_AGENT, get_user_agent())
.send()
.await?;
if let Err(error) = response.error_for_status_ref() {
let body: Option<String> = response.text().await.ok();
return if error.status() == Some(reqwest::StatusCode::UNAUTHORIZED) {
Err(CtsClientError::Unauthorized(body))
} else {
Err(CtsClientError::ErrorResponse {
body: body.unwrap_or_default(),
error,
})
};
}
Ok(response)
}
pub async fn create_workspace(
&self,
name: &str,
region: Region,
) -> Result<CreateWorkspaceResponse, CtsClientError> {
let body = CreateWorkspaceRequest {
region,
name: name.into(),
};
self.send_reqwest(|client, base_url| {
let url = base_url.join("/api/workspaces").expect("Invalid url");
client.post(url).json(&body)
})
.await?
.json()
.await
.map_err(CtsClientError::from)
}
pub async fn list_workspaces(&self) -> Result<Vec<Workspace>, CtsClientError> {
self.send_reqwest(|client, base_url| {
let url = base_url.join("/api/workspaces").expect("Invalid url");
client.get(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)
}
pub async fn create_access_key(
&self,
name: &str,
workspace_id: WorkspaceId,
role: Role,
) -> Result<String, CtsClientError> {
let body = CreateAccessKeyInput {
workspace_id,
key_name: name.into(),
role,
};
let response = self
.send_reqwest(|client, base_url| {
let url = base_url.join("/api/access-keys").expect("Invalid url");
client
.post(url)
.header("x-cs-workspace-id", workspace_id.as_str())
.json(&body)
})
.await?;
let CreateAccessKeyResponse { access_key } = response.json().await?;
Ok(access_key)
}
pub async fn list_access_keys(&self) -> Result<Vec<AccessKey>, CtsClientError> {
self.send_reqwest(|client, base_url| {
let url = base_url.join("/api/access-keys").expect("Invalid url");
client.get(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)
}
pub async fn revoke_access_key(
&self,
name: &str,
workspace_id: WorkspaceId,
) -> Result<String, CtsClientError> {
let body = RevokeAccessKeyInput {
workspace_id,
key_name: name.into(),
};
let response = self
.send_reqwest(|client, base_url| {
let url = base_url.join("/api/access-key").expect("Invalid url");
client.delete(url).json(&body)
})
.await?;
let revoke_ak_response: RevokeAccessKeyResponse = response.json().await?;
Ok(revoke_ak_response.message)
}
pub async fn list_oidc_providers(&self) -> Result<Vec<OidcProvider>, CtsClientError> {
self.send_reqwest(|client, base_url| {
let mut url = base_url.clone();
url.set_path("api/oidc/providers");
client.get(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)
}
pub async fn create_oidc_provider(
&self,
issuer: &Url,
vendor: OidcVendor,
) -> Result<OidcProvider, CtsClientError> {
let body = CreateOidcProviderRequest {
issuer: Cow::Borrowed(issuer),
vendor,
};
self.send_reqwest(|client, base_url| {
let url = base_url.join("/api/oidc/providers").expect("Invalid url");
client.post(url).json(&body)
})
.await?
.json()
.await
.map_err(CtsClientError::from)
}
pub async fn delete_oidc_provider(&self, provider_id: Uuid) -> Result<(), CtsClientError> {
self.send_reqwest(|client, base_url| {
let url = base_url
.join(&format!("/api/oidc/providers/{provider_id}"))
.expect("Invalid url");
client.delete(url)
})
.await?;
Ok(())
}
pub async fn create_membership(
&self,
user_id: &str,
workspace_id: &str,
) -> Result<Membership, CtsClientError> {
let body = CreateMembershipRequest {
user_id: user_id.into(),
workspace_id: workspace_id.into(),
};
let membership = self
.send_reqwest(|client, base_url| {
let url = base_url.join("/api/memberships").expect("Invalid url");
client.post(url).json(&body)
})
.await?
.json()
.await
.map_err(CtsClientError::from)?;
Ok(membership)
}
pub async fn list_memberships(&self) -> Result<Vec<Membership>, CtsClientError> {
let memberships = self
.send_reqwest(|client, base_url| {
let url = base_url.join("/api/memberships").expect("Invalid url");
client.get(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)?;
Ok(memberships)
}
pub async fn list_memberships_for_workspace(
&self,
ws_id: WorkspaceId,
) -> Result<Vec<Membership>, CtsClientError> {
let memberships = self
.send_reqwest(|client, base_url| {
let url = base_url
.join(&format!("/api/workspaces/{ws_id}/memberships"))
.expect("Invalid url");
client.get(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)?;
Ok(memberships)
}
pub async fn delete_membership(
&self,
membership_id: Uuid,
) -> Result<Membership, CtsClientError> {
let membership = self
.send_reqwest(|client, base_url| {
let url = base_url
.join(&format!("/api/memberships/{membership_id}"))
.expect("Invalid url");
client.delete(url)
})
.await?
.json()
.await
.map_err(CtsClientError::from)?;
Ok(membership)
}
}