use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use serde::Serialize;
use tokio::sync::RwLock;
use crate::error::{AkribesError, Result};
/// State shared by `AkribesClient`, `ProjectScope`, and the sub-clients through an `Arc`.
///
/// `Debug` is implemented by hand so the bearer token is never printed: a derived impl would
/// format the contents of the token's `RwLock`, leaking it into any log line that
/// `{:?}`-formats a client.
pub(crate) struct Inner {
    /// API root with trailing `/` characters removed; request URLs are built as `{base_url}/...`.
    pub(crate) base_url: String,
    /// Default project returned by `AkribesClient::project_id` and used by `scoped()`.
    pub(crate) project_id: Option<i64>,
    /// Client name (builder default: `"rust-sdk"`).
    /// NOTE(review): how this is used is not visible here — confirm.
    pub(crate) name: String,
    /// Client identifier (builder default: a random UUID v4).
    /// NOTE(review): how this is used is not visible here — confirm.
    pub(crate) id: String,
    /// HTTP client used for every request.
    pub(crate) http: reqwest::Client,
    /// Bearer token applied by `authed`; the `Arc` is also cloned into the ad-hoc SSE task.
    pub(crate) token: Arc<RwLock<Option<String>>>,
    /// Email sent as the `X-Akribes-User` header by `authed`, if set.
    pub(crate) on_behalf_of: Arc<RwLock<Option<String>>>,
    /// Abort handle of a background heartbeat task, if one has been started (set elsewhere);
    /// aborted when the client is torn down.
    pub(crate) heartbeat_handle: Mutex<Option<tokio::task::AbortHandle>>,
    /// Set to `true` when the client is torn down.
    /// NOTE(review): presumably polled by background tasks holding a clone of this `Arc`.
    pub(crate) shutdown: Arc<AtomicBool>,
    /// NOTE(review): cache of `(String, String)` pairs keyed by a string — presumably
    /// per-script schema fields; confirm against its users.
    pub(crate) schema_cache: Mutex<HashMap<String, Vec<(String, String)>>>,
    /// NOTE(review): presumably identifiers of scripts known to be broken; confirm.
    pub(crate) broken_scripts: Mutex<HashSet<String>>,
    /// Value returned by `AkribesClient::ingest_poll_timeout`: the builder setting if any, else
    /// `AKRIBES_SDK_INGEST_TIMEOUT_SECS`, else `DEFAULT_INGEST_POLL_TIMEOUT_SECS` seconds.
    pub(crate) ingest_poll_timeout: Duration,
}

impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a token is present, never its value.
        let token = match self.token.try_read() {
            Ok(guard) if guard.is_some() => "Some(<redacted>)",
            Ok(_) => "None",
            Err(_) => "<locked>",
        };
        f.debug_struct("Inner")
            .field("base_url", &self.base_url)
            .field("project_id", &self.project_id)
            .field("name", &self.name)
            .field("id", &self.id)
            .field("http", &self.http)
            .field("token", &format_args!("{token}"))
            .field("on_behalf_of", &self.on_behalf_of)
            .field("heartbeat_handle", &self.heartbeat_handle)
            .field("shutdown", &self.shutdown)
            .field("schema_cache", &self.schema_cache)
            .field("broken_scripts", &self.broken_scripts)
            .field("ingest_poll_timeout", &self.ingest_poll_timeout)
            .finish()
    }
}
/// Overall per-request timeout applied by `default_http_client`.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
/// TCP connect timeout applied by `default_http_client`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Ingest poll timeout, in seconds, used when neither the builder nor
/// `AKRIBES_SDK_INGEST_TIMEOUT_SECS` supplies one.
pub(crate) const DEFAULT_INGEST_POLL_TIMEOUT_SECS: u64 = 300;
/// Reads an ingest poll timeout override from `AKRIBES_SDK_INGEST_TIMEOUT_SECS`.
///
/// The value is a whole number of seconds; surrounding whitespace is ignored. Returns `None`
/// when the variable is unset, not valid Unicode, not a non-negative integer, or `0`, so the
/// caller falls back to its own default.
pub(crate) fn ingest_poll_timeout_from_env() -> Option<Duration> {
    std::env::var("AKRIBES_SDK_INGEST_TIMEOUT_SECS")
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        // Zero means "not configured", not "time out immediately".
        .filter(|&secs| secs != 0)
        .map(Duration::from_secs)
}
/// Builds the SDK's default HTTP client with `DEFAULT_CONNECT_TIMEOUT` and
/// `DEFAULT_REQUEST_TIMEOUT` applied.
///
/// If that configuration fails to build, falls back to `reqwest::Client::new()`, which has
/// no custom timeouts.
fn default_http_client() -> reqwest::Client {
    let configured = reqwest::Client::builder()
        .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
        .timeout(DEFAULT_REQUEST_TIMEOUT);
    match configured.build() {
        Ok(client) => client,
        Err(_) => reqwest::Client::new(),
    }
}
/// Maximum number of response-body bytes `read_body_capped` keeps for an error message.
const MAX_ERROR_BODY_BYTES: usize = 64 * 1024;
/// Maximum number of body bytes `body_snippet` shows when JSON decoding fails.
const DECODE_ERROR_SNIPPET_BYTES: usize = 512;
/// Reads a response body as text for use in an error message, keeping at most
/// `MAX_ERROR_BODY_BYTES` bytes.
///
/// This is best-effort: a transport error part-way through ends the read, and whatever arrived
/// so far is returned. Bytes are decoded lossily as UTF-8.
///
/// When the body exceeds the cap, the result ends with `"… (truncated)"`. A multi-byte
/// character split by the cap is dropped whole rather than rendered as U+FFFD.
async fn read_body_capped(res: reqwest::Response) -> String {
    use futures::StreamExt;

    // Largest index <= `cut` that does not fall inside a multi-byte UTF-8 sequence of `bytes`.
    // Moves back at most 3 bytes, and only onto a UTF-8 lead byte, so data that is not valid
    // UTF-8 keeps the original cut.
    fn utf8_floor(bytes: &[u8], cut: usize) -> usize {
        let mut i = cut;
        while i > 0 && cut - i < 3 && i < bytes.len() && (bytes[i] & 0xC0) == 0x80 {
            i -= 1;
        }
        if i < cut && bytes[i] >= 0xC0 {
            i
        } else {
            cut
        }
    }

    let mut buf: Vec<u8> = Vec::new();
    let mut truncated = false;
    let mut stream = res.bytes_stream();
    while let Some(chunk) = stream.next().await {
        // Best-effort: stop at the first transport error and keep what we have.
        let Ok(chunk) = chunk else { break };
        let room = MAX_ERROR_BODY_BYTES.saturating_sub(buf.len());
        // Chunks that fit (including empty ones) never count as truncation.
        if chunk.len() <= room {
            buf.extend_from_slice(&chunk);
            continue;
        }
        // Over the cap: copy one byte past it so `utf8_floor` can see whether the cut splits
        // a character, then trim back to a clean boundary.
        buf.extend_from_slice(&chunk[..=room]);
        let cut = utf8_floor(&buf, MAX_ERROR_BODY_BYTES);
        buf.truncate(cut);
        truncated = true;
        break;
    }
    let mut text = String::from_utf8_lossy(&buf).into_owned();
    if truncated {
        text.push_str("… (truncated)");
    }
    text
}
fn body_snippet(bytes: &[u8]) -> String {
let total = bytes.len();
let cut = total.min(DECODE_ERROR_SNIPPET_BYTES);
let s: String = String::from_utf8_lossy(&bytes[..cut]).into_owned();
if total > cut {
format!("{s}… (truncated, {total} bytes total)")
} else {
s
}
}
/// Decodes a JSON response body into `T`.
///
/// A `404 Not Found` response becomes `AkribesError::HttpStatus { status: 404, .. }`. Its
/// message includes the URL and, when non-blank, the capped response body.
///
/// # Errors
/// - A body that is not valid JSON for `T` becomes `AkribesError::Other`, naming the URL, the
///   target type, the serde error, and a snippet of the body.
/// - Transport errors raised while reading the body are propagated.
pub(crate) async fn decode_json<T: serde::de::DeserializeOwned>(
    res: reqwest::Response,
) -> Result<T> {
    let url = res.url().to_string();
    if res.status() != reqwest::StatusCode::NOT_FOUND {
        let bytes = res.bytes().await?;
        return serde_json::from_slice::<T>(&bytes).map_err(|e| {
            let ty = std::any::type_name::<T>();
            let snippet = body_snippet(&bytes);
            AkribesError::Other(format!(
                "failed to decode response from {url} as {ty}: {e}; body: {snippet}"
            ))
        });
    }
    let body = read_body_capped(res).await;
    let message = match body.trim() {
        "" => format!("HTTP 404 Not Found (url: {url})"),
        _ => format!("HTTP 404 (url: {url}): {body}"),
    };
    Err(AkribesError::HttpStatus {
        status: 404,
        message,
    })
}
/// Handle to the Akribes HTTP API.
///
/// Cloning is cheap: clones share the same state (HTTP client, credentials, caches) through an
/// `Arc`, so e.g. `set_token` on one clone affects every clone.
#[derive(Clone, Debug)]
pub struct AkribesClient {
// Shared state; also handed to `ProjectScope` and the sub-clients.
pub(crate) inner: Arc<Inner>,
}
impl AkribesClient {
#[deprecated(since = "0.4.0", note = "use AkribesClient::builder(base_url) instead")]
pub fn new(base_url: &str, project_id: i64, name: &str, id: &str) -> Self {
Self {
inner: Arc::new(Inner {
base_url: base_url.trim_end_matches('/').to_string(),
project_id: Some(project_id),
name: name.to_string(),
id: id.to_string(),
http: default_http_client(),
token: Arc::new(RwLock::new(None)),
on_behalf_of: Arc::new(RwLock::new(None)),
heartbeat_handle: Mutex::new(None),
shutdown: Arc::new(AtomicBool::new(false)),
schema_cache: Mutex::new(HashMap::new()),
broken_scripts: Mutex::new(HashSet::new()),
ingest_poll_timeout: ingest_poll_timeout_from_env()
.unwrap_or(Duration::from_secs(DEFAULT_INGEST_POLL_TIMEOUT_SECS)),
}),
}
}
pub fn builder(base_url: impl Into<String>) -> AkribesClientBuilder {
AkribesClientBuilder {
base_url: base_url.into().trim_end_matches('/').to_string(),
project_id: None,
name: None,
id: None,
token: None,
on_behalf_of: None,
http_client: None,
ingest_poll_timeout: None,
}
}
pub async fn set_token(&self, token: Option<String>) {
*self.inner.token.write().await = token;
}
pub async fn set_on_behalf_of(&self, email: Option<String>) {
*self.inner.on_behalf_of.write().await = email;
}
pub fn project_id(&self) -> Option<i64> {
self.inner.project_id
}
pub fn base_url(&self) -> &str {
&self.inner.base_url
}
pub async fn get_json_value(&self, url: &str) -> Result<serde_json::Value> {
let res = self.send(self.inner.http.get(url)).await?;
if res.status() == reqwest::StatusCode::NOT_FOUND {
return Err(AkribesError::HttpStatus {
status: 404,
message: format!("GET {url} returned 404"),
});
}
decode_json(res).await
}
pub async fn get_json_value_opt(&self, url: &str) -> Result<Option<serde_json::Value>> {
let res = self.send(self.inner.http.get(url)).await?;
if res.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
Ok(Some(decode_json(res).await?))
}
pub async fn post_json_value(
&self,
url: &str,
body: &serde_json::Value,
) -> Result<serde_json::Value> {
let res = self.send(self.inner.http.post(url).json(body)).await?;
if res.status().as_u16() == 204 || res.content_length() == Some(0) {
return Ok(serde_json::Value::Null);
}
decode_json(res).await
}
pub fn ingest_poll_timeout(&self) -> Duration {
self.inner.ingest_poll_timeout
}
pub(crate) async fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let token_guard = self.inner.token.read().await;
let mut builder = match token_guard.as_deref() {
Some(t) => builder.bearer_auth(t),
None => builder,
};
drop(token_guard);
let obo_guard = self.inner.on_behalf_of.read().await;
if let Some(email) = obo_guard.as_deref() {
builder = builder.header("X-Akribes-User", email);
}
builder
}
pub(crate) async fn send(&self, req: reqwest::RequestBuilder) -> Result<reqwest::Response> {
let req = self.authed(req).await;
let res = req.send().await?;
let status = res.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
let body = read_body_capped(res).await;
let message = if body.trim().is_empty() {
format!(
"HTTP {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
)
} else {
format!("HTTP {}: {}", status.as_u16(), body)
};
return Err(AkribesError::Fatal {
message,
execution_id: None,
});
}
if status.as_u16() == 429
|| status == reqwest::StatusCode::INTERNAL_SERVER_ERROR
|| status == reqwest::StatusCode::BAD_GATEWAY
|| status == reqwest::StatusCode::SERVICE_UNAVAILABLE
|| status == reqwest::StatusCode::GATEWAY_TIMEOUT
{
let retry_after = res
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(std::time::Duration::from_secs);
let body = read_body_capped(res).await;
let message = if body.trim().is_empty() {
format!(
"HTTP {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
)
} else {
format!("HTTP {}: {}", status.as_u16(), body)
};
return Err(AkribesError::Transient {
message,
execution_id: None,
retry_after,
status: Some(status.as_u16()),
});
}
if !status.is_success() && status != reqwest::StatusCode::NOT_FOUND {
let msg = read_body_capped(res).await;
if status.as_u16() == 409 {
if let Ok(body) = serde_json::from_str::<serde_json::Value>(&msg) {
if body.get("error_type").and_then(|v| v.as_str())
== Some("suite_already_exists")
{
if let Some(id) = body.get("existing_suite_id").and_then(|v| v.as_i64()) {
return Err(AkribesError::AlreadyExists {
message: body
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("already exists")
.to_string(),
existing_id: id,
});
}
}
}
}
return Err(AkribesError::HttpStatus {
status: status.as_u16(),
message: msg,
});
}
Ok(res)
}
pub(crate) async fn get_opt<T: serde::de::DeserializeOwned>(
&self,
url: &str,
) -> Result<Option<T>> {
let res = self.send(self.inner.http.get(url)).await?;
if res.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
Ok(Some(decode_json(res).await?))
}
pub(crate) async fn get_list<T: serde::de::DeserializeOwned>(
&self,
url: &str,
) -> Result<Vec<T>> {
Ok(self.get_opt::<Vec<T>>(url).await?.unwrap_or_default())
}
pub(crate) fn url_with_query<Q: Serialize>(base: &str, q: &Q) -> String {
match serde_urlencoded::to_string(q) {
Ok(qs) if !qs.is_empty() => format!("{base}?{qs}"),
_ => base.to_string(),
}
}
pub(crate) async fn post<B: Serialize, T: serde::de::DeserializeOwned>(
&self,
url: &str,
body: &B,
) -> Result<T> {
let res = self.send(self.inner.http.post(url).json(body)).await?;
decode_json(res).await
}
pub(crate) async fn patch<B: Serialize, T: serde::de::DeserializeOwned>(
&self,
url: &str,
body: &B,
) -> Result<T> {
let res = self.send(self.inner.http.patch(url).json(body)).await?;
decode_json(res).await
}
pub(crate) async fn patch_empty<B: Serialize>(&self, url: &str, body: &B) -> Result<()> {
self.send(self.inner.http.patch(url).json(body)).await?;
Ok(())
}
pub(crate) async fn put_empty<B: Serialize>(&self, url: &str, body: &B) -> Result<()> {
self.send(self.inner.http.put(url).json(body)).await?;
Ok(())
}
pub(crate) async fn put_json<B: Serialize, T: serde::de::DeserializeOwned>(
&self,
url: &str,
body: &B,
) -> Result<T> {
let res = self.send(self.inner.http.put(url).json(body)).await?;
decode_json(res).await
}
pub(crate) async fn delete(&self, url: &str) -> Result<bool> {
let res = self.send(self.inner.http.delete(url)).await?;
Ok(res.status() != reqwest::StatusCode::NOT_FOUND)
}
pub(crate) async fn delete_json<T: serde::de::DeserializeOwned>(&self, url: &str) -> Result<T> {
let res = self.send(self.inner.http.delete(url)).await?;
decode_json(res).await
}
pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
&self,
url: &str,
form: reqwest::multipart::Form,
) -> Result<T> {
let res = self.send(self.inner.http.post(url).multipart(form)).await?;
decode_json(res).await
}
}
impl AkribesClient {
    /// Returns a projects client sharing this client's state.
    pub fn projects(&self) -> crate::sub::projects::ProjectsClient {
        crate::sub::projects::ProjectsClient::new(Arc::clone(&self.inner))
    }

    /// Returns an (unscoped) executions client sharing this client's state.
    pub fn executions(&self) -> crate::sub::executions::ExecutionsClient {
        crate::sub::executions::ExecutionsClient::new(Arc::clone(&self.inner))
    }

    /// Returns a tokens client sharing this client's state.
    pub fn tokens(&self) -> crate::sub::tokens::TokensClient {
        crate::sub::tokens::TokensClient::new(Arc::clone(&self.inner))
    }

    /// Returns a conversion client sharing this client's state.
    pub fn convert(&self) -> crate::sub::convert::ConvertClient {
        crate::sub::convert::ConvertClient::new(Arc::clone(&self.inner))
    }

    /// Returns a bench-runs client sharing this client's state.
    pub fn bench_runs(&self) -> crate::sub::bench::BenchRunsClient {
        crate::sub::bench::BenchRunsClient::new(Arc::clone(&self.inner))
    }

    /// Returns a view scoped to `project_id`, regardless of the client's default project.
    pub fn project(&self, project_id: i64) -> ProjectScope {
        ProjectScope {
            inner: Arc::clone(&self.inner),
            project_id,
        }
    }

    /// Returns a view scoped to the client's default project.
    ///
    /// # Errors
    /// `AkribesError::MissingProjectId` if no default project was configured.
    pub fn scoped(&self) -> Result<ProjectScope> {
        let pid = self
            .inner
            .project_id
            .ok_or(AkribesError::MissingProjectId)?;
        Ok(ProjectScope {
            inner: Arc::clone(&self.inner),
            project_id: pid,
        })
    }

    /// GETs `{base_url}/state`; returns an empty JSON object if the server responds 404.
    pub async fn get_state(&self) -> crate::error::Result<serde_json::Value> {
        let url = format!("{}/state", self.inner.base_url);
        Ok(self
            .get_opt::<serde_json::Value>(&url)
            .await?
            .unwrap_or(serde_json::json!({})))
    }

    /// GETs `{base_url}/me/sandbox` and returns the caller's sandbox project id.
    ///
    /// # Errors
    /// Returns `AkribesError::HttpStatus { status: 404, .. }` if the server responds 404, and a
    /// decode error (naming the URL and a body snippet) if the body is malformed.
    pub async fn get_sandbox_project_id(&self) -> crate::error::Result<i64> {
        let url = format!("{}/me/sandbox", self.inner.base_url);
        // `send` passes 404 through; going via `get_opt` (and so `decode_json`) turns it into a
        // clear error instead of an opaque JSON decode failure on the 404 body.
        match self
            .get_opt::<crate::models::SandboxProjectIdResponse>(&url)
            .await?
        {
            Some(body) => Ok(body.project_id),
            None => Err(AkribesError::HttpStatus {
                status: 404,
                message: format!("GET {url} returned 404"),
            }),
        }
    }

    /// Runs `source` ad hoc via `POST {base_url}/execute`, without a channel or `triggered_by`.
    pub async fn run_adhoc(
        &self,
        source: &str,
        inputs: Option<std::collections::HashMap<String, serde_json::Value>>,
        breakpoint_lines: Option<Vec<usize>>,
    ) -> crate::error::Result<crate::models::AdhocRunResult> {
        self.run_adhoc_with(source, inputs, breakpoint_lines, None, None)
            .await
    }

    /// Runs `source` ad hoc via `POST {base_url}/execute` with optional `channel` and
    /// `triggered_by` values included in the request body.
    pub async fn run_adhoc_with(
        &self,
        source: &str,
        inputs: Option<std::collections::HashMap<String, serde_json::Value>>,
        breakpoint_lines: Option<Vec<usize>>,
        channel: Option<&str>,
        triggered_by: Option<&str>,
    ) -> crate::error::Result<crate::models::AdhocRunResult> {
        let url = format!("{}/execute", self.inner.base_url);
        self.post(
            &url,
            &crate::models::AdhocRunRequest {
                source,
                inputs,
                breakpoint_lines,
                channel,
                triggered_by,
            },
        )
        .await
    }

    /// Subscribes to ad-hoc execution events for `project_id`.
    ///
    /// Equivalent to [`adhoc_event_stream_with_ready`](Self::adhoc_event_stream_with_ready)
    /// with no readiness notification.
    pub async fn adhoc_event_stream(
        &self,
        project_id: i64,
    ) -> crate::error::Result<(
        tokio::sync::mpsc::UnboundedReceiver<crate::models::EngineEvent>,
        crate::sub::events::EventSubscription,
    )> {
        self.adhoc_event_stream_with_ready(project_id, None).await
    }

    /// Subscribes to ad-hoc execution events for `project_id`.
    ///
    /// Returns a receiver of the `Execution` events' payloads plus a subscription wrapping the
    /// filtering task. The SSE task is aborted whenever the filtering task ends, including when
    /// that task is itself aborted.
    ///
    /// If `ready` is given, it is notified once the stream reports readiness successfully.
    /// NOTE(review): exactly when the stream reports readiness is decided by
    /// `stream_sse_with_retry`.
    pub async fn adhoc_event_stream_with_ready(
        &self,
        project_id: i64,
        ready: Option<Arc<tokio::sync::Notify>>,
    ) -> crate::error::Result<(
        tokio::sync::mpsc::UnboundedReceiver<crate::models::EngineEvent>,
        crate::sub::events::EventSubscription,
    )> {
        use crate::models::HubEvent;
        use tokio::sync::oneshot;

        // Aborts the wrapped task when dropped.
        struct AbortOnDrop(tokio::task::AbortHandle);
        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                self.0.abort();
            }
        }

        let (hub_tx, mut hub_rx) = tokio::sync::mpsc::unbounded_channel();
        let (engine_tx, engine_rx) = tokio::sync::mpsc::unbounded_channel();
        let (ready_tx, ready_rx) = oneshot::channel::<crate::error::Result<()>>();
        if let Some(notify) = ready {
            tokio::spawn(async move {
                if let Ok(Ok(())) = ready_rx.await {
                    notify.notify_one();
                }
            });
        } else {
            // Keep the receiver alive until the stream reports (or drops the sender).
            // NOTE(review): presumably so the readiness send never sees a closed channel.
            tokio::spawn(async move {
                let _ = ready_rx.await;
            });
        }
        let http = self.inner.http.clone();
        let token = self.inner.token.clone();
        let base_url = self.inner.base_url.clone();
        let sse_handle = tokio::spawn(async move {
            let _ = crate::sub::events::stream_sse_with_retry(
                http,
                token,
                base_url,
                project_id,
                Some("adhoc".to_string()),
                hub_tx,
                Some(ready_tx),
            )
            .await;
        });
        let filter_handle = tokio::spawn(async move {
            // Stop the SSE task whenever this task finishes. That includes being aborted
            // (e.g. via the returned subscription), in which case code after the loop would
            // never run and the SSE connection would outlive the subscription.
            let _stop_sse = AbortOnDrop(sse_handle.abort_handle());
            while let Some(evt) = hub_rx.recv().await {
                if let HubEvent::Execution { event, .. } = evt {
                    if engine_tx.send(event).is_err() {
                        break;
                    }
                }
            }
        });
        Ok((
            engine_rx,
            crate::sub::events::EventSubscription::from_handle(filter_handle),
        ))
    }
}
/// A view of the client bound to a single project id.
///
/// Obtained from `AkribesClient::project` or `AkribesClient::scoped`. It shares the client's
/// state through an `Arc`, so it is cheap to clone.
#[derive(Clone, Debug)]
pub struct ProjectScope {
// Shared client state.
pub(crate) inner: Arc<Inner>,
// Project id passed to every project-scoped sub-client.
pub(crate) project_id: i64,
}
impl ProjectScope {
    /// Returns the id of the project this scope targets.
    pub fn project_id(&self) -> i64 {
        let Self { project_id, .. } = self;
        *project_id
    }

    /// Returns an unscoped [`AkribesClient`] sharing this scope's state.
    pub fn client(&self) -> AkribesClient {
        let inner = Arc::clone(&self.inner);
        AkribesClient { inner }
    }

    /// Returns a scripts client for this project.
    pub fn scripts(&self) -> crate::sub::scripts::ScriptsClient {
        let Self { inner, project_id } = self;
        crate::sub::scripts::ScriptsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a drafts client for this project.
    pub fn drafts(&self) -> crate::sub::drafts::DraftsClient {
        let Self { inner, project_id } = self;
        crate::sub::drafts::DraftsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a versions client for this project.
    pub fn versions(&self) -> crate::sub::versions::VersionsClient {
        let Self { inner, project_id } = self;
        crate::sub::versions::VersionsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a channels client for this project.
    pub fn channels(&self) -> crate::sub::channels::ChannelsClient {
        let Self { inner, project_id } = self;
        crate::sub::channels::ChannelsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a bench client for this project.
    pub fn bench(&self) -> crate::sub::bench::BenchClient {
        let Self { inner, project_id } = self;
        crate::sub::bench::BenchClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns an events client for this project.
    pub fn events(&self) -> crate::sub::events::EventsClient {
        let Self { inner, project_id } = self;
        crate::sub::events::EventsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a registered-clients client for this project.
    pub fn registered_clients(&self) -> crate::sub::clients::RegisteredClientsClient {
        let Self { inner, project_id } = self;
        crate::sub::clients::RegisteredClientsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a project-scoped executions client.
    pub fn executions(&self) -> crate::sub::executions::ScopedExecutionsClient {
        let Self { inner, project_id } = self;
        crate::sub::executions::ScopedExecutionsClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns an MCP client for this project.
    pub fn mcp(&self) -> crate::sub::mcp::McpClient {
        let Self { inner, project_id } = self;
        crate::sub::mcp::McpClient::new(Arc::clone(inner), *project_id)
    }

    /// Returns a documents client for this project.
    pub fn documents(&self) -> crate::sub::documents::DocumentsClient {
        let Self { inner, project_id } = self;
        crate::sub::documents::DocumentsClient::new(Arc::clone(inner), *project_id)
    }

    /// Converts the file `filename` (contents `data`) within this project.
    pub async fn convert_file(
        &self,
        filename: &str,
        data: Vec<u8>,
    ) -> Result<crate::models::ConvertResult> {
        let converter = crate::sub::convert::ConvertClient::new(Arc::clone(&self.inner));
        converter
            .convert_file_for_project(self.project_id, filename, data)
            .await
    }
}
impl Drop for AkribesClient {
fn drop(&mut self) {
if Arc::strong_count(&self.inner) == 1 {
self.inner.shutdown.store(true, Ordering::Release);
if let Ok(mut h) = self.inner.heartbeat_handle.lock() {
if let Some(handle) = h.take() {
handle.abort();
}
}
}
}
}
/// Builder for [`AkribesClient`], created by `AkribesClient::builder`.
///
/// Every setting is optional; `build` fills in the defaults.
#[must_use = "a builder does nothing until .build() is called"]
pub struct AkribesClientBuilder {
// Already stripped of trailing `/` by `AkribesClient::builder`.
base_url: String,
project_id: Option<i64>,
name: Option<String>,
id: Option<String>,
token: Option<String>,
on_behalf_of: Option<String>,
http_client: Option<reqwest::Client>,
ingest_poll_timeout: Option<Duration>,
}
impl AkribesClientBuilder {
    /// Sets the default project used by `AkribesClient::scoped`.
    pub fn project_id(self, project_id: i64) -> Self {
        Self {
            project_id: Some(project_id),
            ..self
        }
    }

    /// Sets the client name (default: `"rust-sdk"`).
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the client id (default: a random UUID v4).
    pub fn id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Sets the initial bearer token.
    pub fn token(self, token: impl Into<String>) -> Self {
        Self {
            token: Some(token.into()),
            ..self
        }
    }

    /// Sets the initial email sent as the `X-Akribes-User` header.
    pub fn on_behalf_of(self, email: impl Into<String>) -> Self {
        Self {
            on_behalf_of: Some(email.into()),
            ..self
        }
    }

    /// Uses `client` instead of the default HTTP client, which applies the SDK's request and
    /// connect timeouts.
    pub fn http_client(self, client: reqwest::Client) -> Self {
        Self {
            http_client: Some(client),
            ..self
        }
    }

    /// Sets the ingest poll timeout; takes precedence over `AKRIBES_SDK_INGEST_TIMEOUT_SECS`.
    pub fn ingest_poll_timeout(self, timeout: Duration) -> Self {
        Self {
            ingest_poll_timeout: Some(timeout),
            ..self
        }
    }

    /// Builds the client, filling unset options with defaults.
    ///
    /// The ingest poll timeout is the builder value if set, otherwise the environment override,
    /// otherwise `DEFAULT_INGEST_POLL_TIMEOUT_SECS` seconds.
    pub fn build(self) -> AkribesClient {
        let Self {
            base_url,
            project_id,
            name,
            id,
            token,
            on_behalf_of,
            http_client,
            ingest_poll_timeout,
        } = self;
        let ingest_poll_timeout = match ingest_poll_timeout.or_else(ingest_poll_timeout_from_env) {
            Some(timeout) => timeout,
            None => Duration::from_secs(DEFAULT_INGEST_POLL_TIMEOUT_SECS),
        };
        let inner = Inner {
            base_url,
            project_id,
            name: name.unwrap_or_else(|| String::from("rust-sdk")),
            id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
            http: http_client.unwrap_or_else(default_http_client),
            token: Arc::new(RwLock::new(token)),
            on_behalf_of: Arc::new(RwLock::new(on_behalf_of)),
            heartbeat_handle: Mutex::new(None),
            shutdown: Arc::new(AtomicBool::new(false)),
            schema_cache: Mutex::new(HashMap::new()),
            broken_scripts: Mutex::new(HashSet::new()),
            ingest_poll_timeout,
        };
        AkribesClient {
            inner: Arc::new(inner),
        }
    }
}