use crate::config::{AxonFlowConfig, Mode};
use crate::error::AxonFlowError;
use crate::heartbeat::maybe_send_heartbeat;
use crate::types::agent::{ClientRequest, ClientResponse};
use base64::engine::general_purpose::STANDARD as BASE64_STD;
use base64::Engine as _;
use moka::future::Cache;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
/// Name of the header that carries the optional license key on every request.
const LICENSE_KEY_HEADER: &str = "X-License-Key";

/// Asynchronous client for the AxonFlow API.
///
/// Cloning the client clones its configuration and shares the response cache
/// (held behind an [`Arc`]) with the original.
#[derive(Clone)]
pub struct AxonFlowClient {
    /// Settings the client was constructed with.
    config: AxonFlowConfig,
    /// HTTP client used for general API calls (built with `config.timeout`).
    http_client: reqwest::Client,
    /// HTTP client used for plan endpoints (built with `config.map_timeout`).
    map_http_client: reqwest::Client,
    /// Response cache for non-mutating requests; `None` when caching is disabled.
    cache: Option<Arc<Cache<String, ClientResponse>>>,
}
impl AxonFlowClient {
pub fn new(mut config: AxonFlowConfig) -> Result<Self, AxonFlowError> {
if config.retry.max_attempts == 0 {
return Err(AxonFlowError::ConfigError(
"retry.max_attempts must be at least 1".to_string(),
));
}
if std::env::var("AXONFLOW_TRY").unwrap_or_default() == "1" {
config.endpoint = "https://try.getaxonflow.com".to_string();
if config.client_id.is_none() {
return Err(AxonFlowError::ConfigError(
"ClientID is required in try mode (AXONFLOW_TRY=1).".to_string(),
));
}
}
if config.client_secret.is_some() && config.client_id.is_none() {
warn!("ClientID is required when ClientSecret is set.");
}
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert(
"User-Agent",
HeaderValue::from_static(concat!("axonflow-sdk-rust/", env!("CARGO_PKG_VERSION"))),
);
let basic_id = config
.client_id
.clone()
.unwrap_or_else(|| "community".to_string());
let basic_secret = config.client_secret.clone().unwrap_or_default();
let basic_credentials = BASE64_STD.encode(format!("{}:{}", basic_id, basic_secret));
let basic_value = format!("Basic {}", basic_credentials);
if let Ok(val) = HeaderValue::from_str(&basic_value) {
headers.insert(AUTHORIZATION, val);
}
if let Some(license_key) = &config.license_key {
if let Ok(mut val) = HeaderValue::from_str(license_key) {
val.set_sensitive(true);
headers.insert(LICENSE_KEY_HEADER, val);
}
}
let accept_invalid = config.insecure_skip_tls_verify
|| std::env::var("AXONFLOW_INSECURE_TLS").unwrap_or_default() == "1";
if accept_invalid {
warn!("TLS certificate verification is disabled.");
}
let http_client = reqwest::Client::builder()
.timeout(config.timeout)
.default_headers(headers.clone())
.danger_accept_invalid_certs(accept_invalid)
.build()
.map_err(AxonFlowError::HttpError)?;
let map_http_client = reqwest::Client::builder()
.timeout(config.map_timeout)
.default_headers(headers)
.danger_accept_invalid_certs(accept_invalid)
.build()
.map_err(AxonFlowError::HttpError)?;
let cache = if config.cache.enabled {
Some(Arc::new(
Cache::builder().time_to_live(config.cache.ttl).build(),
))
} else {
None
};
maybe_send_heartbeat(&config.endpoint);
Ok(Self {
config,
http_client,
map_http_client,
cache,
})
}
pub async fn proxy_llm_call(
&self,
user_token: &str,
query: &str,
request_type: &str,
context: HashMap<String, serde_json::Value>,
) -> Result<ClientResponse, AxonFlowError> {
let user_token = if user_token.is_empty() {
"anonymous"
} else {
user_token
};
let is_mutation = matches!(
request_type,
"execute-plan" | "generate-plan" | "cancel-plan" | "update-plan"
);
if !is_mutation {
if let Some(cache) = &self.cache {
let cache_key = self.build_cache_key(request_type, query, user_token, &context);
if let Some(cached) = cache.get(&cache_key).await {
debug!("Cache hit for query");
return Ok(cached);
}
}
}
let req = ClientRequest {
query: query.to_string(),
user_token: user_token.to_string(),
client_id: self.config.client_id.clone(),
request_type: request_type.to_string(),
context,
media: None,
};
let resp = if self.config.retry.enabled && !is_mutation {
self.execute_with_retry(&req).await
} else {
self.execute_request(&req).await
};
match resp {
Ok(response) => {
if response.success && !is_mutation {
if let Some(cache) = &self.cache {
let cache_key =
self.build_cache_key(request_type, query, user_token, &req.context);
cache.insert(cache_key, response.clone()).await;
}
}
Ok(response)
}
Err(e) => {
if self.config.mode == Mode::Production && e.is_fail_open_eligible() {
debug!("AxonFlow unavailable, failing open: {}", e);
Ok(ClientResponse::fail_open(e))
} else {
Err(e)
}
}
}
}
pub async fn list_connectors(
&self,
) -> Result<Vec<crate::types::agent::ConnectorMetadata>, AxonFlowError> {
let url = format!("{}/api/v1/connectors", self.config.endpoint);
let resp = self.checked_get(&url).await?;
let body: serde_json::Value = resp.json().await?;
let connectors = body["connectors"]
.as_array()
.ok_or_else(|| AxonFlowError::ApiError {
status: 200,
message: "response missing 'connectors' field".to_string(),
})?;
let result = serde_json::from_value(serde_json::Value::Array(connectors.clone()))?;
Ok(result)
}
pub async fn get_connector(
&self,
connector_id: &str,
) -> Result<crate::types::agent::ConnectorMetadata, AxonFlowError> {
let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
let url = format!("{}/api/v1/connectors/{}", self.config.endpoint, encoded_id);
let resp = self.checked_get(&url).await?;
Ok(resp.json().await?)
}
pub async fn get_connector_health(
&self,
connector_id: &str,
) -> Result<crate::types::agent::ConnectorHealthStatus, AxonFlowError> {
let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
let url = format!(
"{}/api/v1/connectors/{}/health",
self.config.endpoint, encoded_id
);
let resp = self.checked_get(&url).await?;
Ok(resp.json().await?)
}
pub async fn install_connector(
&self,
req: crate::types::agent::ConnectorInstallRequest,
) -> Result<(), AxonFlowError> {
let encoded_id = utf8_percent_encode(&req.connector_id, NON_ALPHANUMERIC);
let url = format!(
"{}/api/v1/connectors/{}/install",
self.config.endpoint, encoded_id
);
let resp = self.http_client.post(&url).json(&req).send().await?;
Self::check_status(resp).await?;
Ok(())
}
pub async fn query_connector(
&self,
user_token: &str,
connector_name: &str,
query: &str,
params: HashMap<String, serde_json::Value>,
) -> Result<crate::types::agent::ConnectorResponse, AxonFlowError> {
let mut context = HashMap::new();
context.insert("connector".to_string(), serde_json::json!(connector_name));
context.insert("params".to_string(), serde_json::json!(params));
let resp = self
.proxy_llm_call(user_token, query, "mcp-query", context)
.await?;
Ok(crate::types::agent::ConnectorResponse {
success: resp.success,
data: resp.data.unwrap_or(serde_json::Value::Null),
error: resp.error,
meta: resp.metadata,
redacted: false,
redacted_fields: Vec::new(),
policy_info: None,
})
}
pub async fn generate_plan(
&self,
query: &str,
domain: &str,
user_token: Option<&str>,
) -> Result<crate::types::agent::PlanResponse, AxonFlowError> {
let mut context = HashMap::new();
context.insert("domain".to_string(), serde_json::json!(domain));
let user_token = user_token.unwrap_or("anonymous");
let resp = self
.proxy_llm_call(user_token, query, "generate-plan", context)
.await?;
if let Some(data) = resp.data {
let plan: crate::types::agent::PlanResponse = serde_json::from_value(data)?;
Ok(plan)
} else {
Err(AxonFlowError::ApiError {
status: 500,
message: "empty plan data".to_string(),
})
}
}
pub async fn execute_plan(
&self,
plan_id: &str,
user_token: Option<&str>,
) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
let mut context = HashMap::new();
context.insert("plan_id".to_string(), serde_json::json!(plan_id));
let user_token = user_token.unwrap_or("anonymous");
let resp = self
.proxy_llm_call(user_token, "", "execute-plan", context)
.await?;
if let Some(data) = resp.data {
let exec: crate::types::agent::PlanExecutionResponse = serde_json::from_value(data)?;
Ok(exec)
} else {
Err(AxonFlowError::ApiError {
status: 500,
message: "empty execution data".to_string(),
})
}
}
pub async fn get_plan_status(
&self,
plan_id: &str,
) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
let url = format!("{}/api/v1/plan/{}", self.config.endpoint, encoded_id);
let resp = self.checked_map_get(&url).await?;
Ok(resp.json().await?)
}
pub async fn cancel_plan(
&self,
plan_id: &str,
reason: Option<&str>,
) -> Result<crate::types::agent::CancelPlanResponse, AxonFlowError> {
let req_body = serde_json::json!({
"reason": reason.unwrap_or("user_cancelled"),
});
let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
let url = format!("{}/api/v1/plan/{}/cancel", self.config.endpoint, encoded_id);
let resp = self
.map_http_client
.post(&url)
.json(&req_body)
.send()
.await?;
let resp = Self::check_status(resp).await?;
Ok(resp.json().await?)
}
pub async fn audit_llm_call(
&self,
req: &crate::types::agent::AuditRequest,
) -> Result<crate::types::agent::AuditResult, AxonFlowError> {
let client_id = self.get_effective_client_id();
let mut req_body = serde_json::json!({
"context_id": req.context_id,
"client_id": client_id,
"response_summary": req.response_summary,
"provider": req.provider,
"model": req.model,
"token_usage": {
"prompt_tokens": req.token_usage.prompt_tokens,
"completion_tokens": req.token_usage.completion_tokens,
"total_tokens": req.token_usage.total_tokens,
},
"latency_ms": req.latency_ms,
});
if let Some(meta) = &req.metadata {
req_body["metadata"] = serde_json::to_value(meta)?;
} else {
req_body["metadata"] = serde_json::json!({});
}
let url = format!("{}/api/audit/llm-call", self.config.endpoint);
let resp = self.http_client.post(&url).json(&req_body).send().await?;
let status = resp.status();
let body = resp.text().await?;
if status.is_success() {
let audit_resp: crate::types::agent::AuditResult = serde_json::from_str(&body)?;
Ok(audit_resp)
} else {
Err(AxonFlowError::ApiError {
status: status.as_u16(),
message: body,
})
}
}
fn get_effective_client_id(&self) -> String {
self.config
.client_id
.clone()
.unwrap_or_else(|| "community".to_string())
}
fn build_cache_key(
&self,
request_type: &str,
query: &str,
user_token: &str,
context: &HashMap<String, serde_json::Value>,
) -> String {
let context_hash = if context.is_empty() {
String::new()
} else {
let sorted: std::collections::BTreeMap<_, _> = context.iter().collect();
format!(":{}", serde_json::to_string(&sorted).unwrap_or_default())
};
format!("{}:{}:{}{}", request_type, query, user_token, context_hash)
}
async fn checked_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
let resp = self.http_client.get(url).send().await?;
Self::check_status(resp).await
}
async fn checked_map_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
let resp = self.map_http_client.get(url).send().await?;
Self::check_status(resp).await
}
async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, AxonFlowError> {
if resp.status().is_success() {
Ok(resp)
} else {
let status = resp.status().as_u16();
let message = resp.text().await?;
Err(AxonFlowError::ApiError { status, message })
}
}
async fn execute_with_retry(
&self,
req: &ClientRequest,
) -> Result<ClientResponse, AxonFlowError> {
let mut last_err = None;
for attempt in 0..self.config.retry.max_attempts {
if attempt > 0 {
let delay =
self.config.retry.initial_delay.as_secs_f64() * 2f64.powi((attempt - 1) as i32);
tokio::time::sleep(Duration::from_secs_f64(delay)).await;
}
match self.execute_request(req).await {
Ok(resp) => return Ok(resp),
Err(e) => {
if let AxonFlowError::ApiError { status, .. } = &e {
if *status >= 400
&& *status < 500
&& *status != 429
&& *status != 402
&& *status != 403
{
return Err(e);
}
}
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| {
AxonFlowError::ConfigError("retry loop completed with no attempts".to_string())
}))
}
async fn execute_request(&self, req: &ClientRequest) -> Result<ClientResponse, AxonFlowError> {
let url = format!("{}/api/request", self.config.endpoint);
let resp = self.http_client.post(&url).json(req).send().await?;
let status = resp.status();
let body = resp.text().await?;
if status.is_success() || status.as_u16() == 402 || status.as_u16() == 403 {
let client_resp: ClientResponse = serde_json::from_str(&body)?;
Ok(client_resp)
} else {
Err(AxonFlowError::ApiError {
status: status.as_u16(),
message: body,
})
}
}
}