1use crate::config::{AxonFlowConfig, Mode};
2use crate::error::AxonFlowError;
3use crate::heartbeat::maybe_send_heartbeat;
4use crate::types::agent::{ClientRequest, ClientResponse};
5use base64::engine::general_purpose::STANDARD as BASE64_STD;
6use base64::Engine as _;
7use moka::future::Cache;
8use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, warn};
14
/// Name of the HTTP header that carries the configured license key.
const LICENSE_KEY_HEADER: &str = "X-License-Key";
16
17#[derive(Clone)]
18pub struct AxonFlowClient {
19 config: AxonFlowConfig,
20 http_client: reqwest::Client,
21 map_http_client: reqwest::Client,
22 cache: Option<Arc<Cache<String, ClientResponse>>>,
23}
24
25impl AxonFlowClient {
26 pub fn new(mut config: AxonFlowConfig) -> Result<Self, AxonFlowError> {
27 if config.retry.max_attempts == 0 {
28 return Err(AxonFlowError::ConfigError(
29 "retry.max_attempts must be at least 1".to_string(),
30 ));
31 }
32
33 if std::env::var("AXONFLOW_TRY").unwrap_or_default() == "1" {
34 config.endpoint = "https://try.getaxonflow.com".to_string();
35 if config.client_id.is_none() {
36 return Err(AxonFlowError::ConfigError(
37 "ClientID is required in try mode (AXONFLOW_TRY=1).".to_string(),
38 ));
39 }
40 }
41
42 if config.client_secret.is_some() && config.client_id.is_none() {
43 warn!("ClientID is required when ClientSecret is set.");
44 }
45
46 let mut headers = HeaderMap::new();
47 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
48 headers.insert(
49 "User-Agent",
50 HeaderValue::from_static(concat!("axonflow-sdk-rust/", env!("CARGO_PKG_VERSION"))),
51 );
52
53 let basic_id = config
57 .client_id
58 .clone()
59 .unwrap_or_else(|| "community".to_string());
60 let basic_secret = config.client_secret.clone().unwrap_or_default();
61 let basic_credentials = BASE64_STD.encode(format!("{}:{}", basic_id, basic_secret));
62 let basic_value = format!("Basic {}", basic_credentials);
63 if let Ok(val) = HeaderValue::from_str(&basic_value) {
64 headers.insert(AUTHORIZATION, val);
65 }
66
67 if let Some(license_key) = &config.license_key {
69 if let Ok(mut val) = HeaderValue::from_str(license_key) {
70 val.set_sensitive(true);
71 headers.insert(LICENSE_KEY_HEADER, val);
72 }
73 }
74
75 let accept_invalid = config.insecure_skip_tls_verify
76 || std::env::var("AXONFLOW_INSECURE_TLS").unwrap_or_default() == "1";
77
78 if accept_invalid {
79 warn!("TLS certificate verification is disabled.");
80 }
81
82 let http_client = reqwest::Client::builder()
83 .timeout(config.timeout)
84 .default_headers(headers.clone())
85 .danger_accept_invalid_certs(accept_invalid)
86 .build()
87 .map_err(AxonFlowError::HttpError)?;
88
89 let map_http_client = reqwest::Client::builder()
90 .timeout(config.map_timeout)
91 .default_headers(headers)
92 .danger_accept_invalid_certs(accept_invalid)
93 .build()
94 .map_err(AxonFlowError::HttpError)?;
95
96 let cache = if config.cache.enabled {
97 Some(Arc::new(
98 Cache::builder().time_to_live(config.cache.ttl).build(),
99 ))
100 } else {
101 None
102 };
103
104 maybe_send_heartbeat(&config.endpoint);
105
106 Ok(Self {
107 config,
108 http_client,
109 map_http_client,
110 cache,
111 })
112 }
113
114 pub async fn proxy_llm_call(
115 &self,
116 user_token: &str,
117 query: &str,
118 request_type: &str,
119 context: HashMap<String, serde_json::Value>,
120 ) -> Result<ClientResponse, AxonFlowError> {
121 let user_token = if user_token.is_empty() {
122 "anonymous"
123 } else {
124 user_token
125 };
126
127 let is_mutation = matches!(
128 request_type,
129 "execute-plan" | "generate-plan" | "cancel-plan" | "update-plan"
130 );
131
132 if !is_mutation {
133 if let Some(cache) = &self.cache {
134 let cache_key = self.build_cache_key(request_type, query, user_token, &context);
135 if let Some(cached) = cache.get(&cache_key).await {
136 debug!("Cache hit for query");
137 return Ok(cached);
138 }
139 }
140 }
141
142 let req = ClientRequest {
143 query: query.to_string(),
144 user_token: user_token.to_string(),
145 client_id: self.config.client_id.clone(),
146 request_type: request_type.to_string(),
147 context,
148 media: None,
149 };
150
151 let resp = if self.config.retry.enabled && !is_mutation {
152 self.execute_with_retry(&req).await
153 } else {
154 self.execute_request(&req).await
155 };
156
157 match resp {
158 Ok(response) => {
159 if response.success && !is_mutation {
160 if let Some(cache) = &self.cache {
161 let cache_key =
162 self.build_cache_key(request_type, query, user_token, &req.context);
163 cache.insert(cache_key, response.clone()).await;
164 }
165 }
166 Ok(response)
167 }
168 Err(e) => {
169 if self.config.mode == Mode::Production && e.is_fail_open_eligible() {
170 debug!("AxonFlow unavailable, failing open: {}", e);
171 Ok(ClientResponse::fail_open(e))
172 } else {
173 Err(e)
174 }
175 }
176 }
177 }
178
179 pub async fn list_connectors(
184 &self,
185 ) -> Result<Vec<crate::types::agent::ConnectorMetadata>, AxonFlowError> {
186 let url = format!("{}/api/v1/connectors", self.config.endpoint);
187 let resp = self.checked_get(&url).await?;
188
189 let body: serde_json::Value = resp.json().await?;
190 let connectors = body["connectors"]
191 .as_array()
192 .ok_or_else(|| AxonFlowError::ApiError {
193 status: 200,
194 message: "response missing 'connectors' field".to_string(),
195 })?;
196
197 let result = serde_json::from_value(serde_json::Value::Array(connectors.clone()))?;
198 Ok(result)
199 }
200
201 pub async fn get_connector(
202 &self,
203 connector_id: &str,
204 ) -> Result<crate::types::agent::ConnectorMetadata, AxonFlowError> {
205 let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
206 let url = format!("{}/api/v1/connectors/{}", self.config.endpoint, encoded_id);
207 let resp = self.checked_get(&url).await?;
208 Ok(resp.json().await?)
209 }
210
211 pub async fn get_connector_health(
212 &self,
213 connector_id: &str,
214 ) -> Result<crate::types::agent::ConnectorHealthStatus, AxonFlowError> {
215 let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
216 let url = format!(
217 "{}/api/v1/connectors/{}/health",
218 self.config.endpoint, encoded_id
219 );
220 let resp = self.checked_get(&url).await?;
221 Ok(resp.json().await?)
222 }
223
224 pub async fn install_connector(
225 &self,
226 req: crate::types::agent::ConnectorInstallRequest,
227 ) -> Result<(), AxonFlowError> {
228 let encoded_id = utf8_percent_encode(&req.connector_id, NON_ALPHANUMERIC);
229 let url = format!(
230 "{}/api/v1/connectors/{}/install",
231 self.config.endpoint, encoded_id
232 );
233 let resp = self.http_client.post(&url).json(&req).send().await?;
234 Self::check_status(resp).await?;
235 Ok(())
236 }
237
238 pub async fn query_connector(
239 &self,
240 user_token: &str,
241 connector_name: &str,
242 query: &str,
243 params: HashMap<String, serde_json::Value>,
244 ) -> Result<crate::types::agent::ConnectorResponse, AxonFlowError> {
245 let mut context = HashMap::new();
249 context.insert("connector".to_string(), serde_json::json!(connector_name));
250 context.insert("params".to_string(), serde_json::json!(params));
251
252 let resp = self
253 .proxy_llm_call(user_token, query, "mcp-query", context)
254 .await?;
255
256 Ok(crate::types::agent::ConnectorResponse {
257 success: resp.success,
258 data: resp.data.unwrap_or(serde_json::Value::Null),
259 error: resp.error,
260 meta: resp.metadata,
261 redacted: false,
262 redacted_fields: Vec::new(),
263 policy_info: None,
264 })
265 }
266
267 pub async fn generate_plan(
272 &self,
273 query: &str,
274 domain: &str,
275 user_token: Option<&str>,
276 ) -> Result<crate::types::agent::PlanResponse, AxonFlowError> {
277 let mut context = HashMap::new();
278 context.insert("domain".to_string(), serde_json::json!(domain));
279 let user_token = user_token.unwrap_or("anonymous");
280
281 let resp = self
282 .proxy_llm_call(user_token, query, "generate-plan", context)
283 .await?;
284
285 if let Some(data) = resp.data {
286 let plan: crate::types::agent::PlanResponse = serde_json::from_value(data)?;
287 Ok(plan)
288 } else {
289 Err(AxonFlowError::ApiError {
290 status: 500,
291 message: "empty plan data".to_string(),
292 })
293 }
294 }
295
296 pub async fn execute_plan(
297 &self,
298 plan_id: &str,
299 user_token: Option<&str>,
300 ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
301 let mut context = HashMap::new();
302 context.insert("plan_id".to_string(), serde_json::json!(plan_id));
303 let user_token = user_token.unwrap_or("anonymous");
304
305 let resp = self
306 .proxy_llm_call(user_token, "", "execute-plan", context)
307 .await?;
308
309 if let Some(data) = resp.data {
310 let exec: crate::types::agent::PlanExecutionResponse = serde_json::from_value(data)?;
311 Ok(exec)
312 } else {
313 Err(AxonFlowError::ApiError {
314 status: 500,
315 message: "empty execution data".to_string(),
316 })
317 }
318 }
319
320 pub async fn get_plan_status(
321 &self,
322 plan_id: &str,
323 ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
324 let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
325 let url = format!("{}/api/v1/plan/{}", self.config.endpoint, encoded_id);
326 let resp = self.checked_map_get(&url).await?;
327 Ok(resp.json().await?)
328 }
329
330 pub async fn cancel_plan(
331 &self,
332 plan_id: &str,
333 reason: Option<&str>,
334 ) -> Result<crate::types::agent::CancelPlanResponse, AxonFlowError> {
335 let req_body = serde_json::json!({
336 "reason": reason.unwrap_or("user_cancelled"),
337 });
338
339 let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
340 let url = format!("{}/api/v1/plan/{}/cancel", self.config.endpoint, encoded_id);
341 let resp = self
342 .map_http_client
343 .post(&url)
344 .json(&req_body)
345 .send()
346 .await?;
347 let resp = Self::check_status(resp).await?;
348 Ok(resp.json().await?)
349 }
350
351 pub async fn audit_llm_call(
352 &self,
353 req: &crate::types::agent::AuditRequest,
354 ) -> Result<crate::types::agent::AuditResult, AxonFlowError> {
355 let client_id = self.get_effective_client_id();
356
357 let mut req_body = serde_json::json!({
358 "context_id": req.context_id,
359 "client_id": client_id,
360 "response_summary": req.response_summary,
361 "provider": req.provider,
362 "model": req.model,
363 "token_usage": {
364 "prompt_tokens": req.token_usage.prompt_tokens,
365 "completion_tokens": req.token_usage.completion_tokens,
366 "total_tokens": req.token_usage.total_tokens,
367 },
368 "latency_ms": req.latency_ms,
369 });
370
371 if let Some(meta) = &req.metadata {
372 req_body["metadata"] = serde_json::to_value(meta)?;
373 } else {
374 req_body["metadata"] = serde_json::json!({});
375 }
376
377 let url = format!("{}/api/audit/llm-call", self.config.endpoint);
378 let resp = self.http_client.post(&url).json(&req_body).send().await?;
379
380 let status = resp.status();
381 let body = resp.text().await?;
382
383 if status.is_success() {
384 let audit_resp: crate::types::agent::AuditResult = serde_json::from_str(&body)?;
385 Ok(audit_resp)
386 } else {
387 Err(AxonFlowError::ApiError {
388 status: status.as_u16(),
389 message: body,
390 })
391 }
392 }
393
394 fn get_effective_client_id(&self) -> String {
399 self.config
400 .client_id
401 .clone()
402 .unwrap_or_else(|| "community".to_string())
403 }
404
405 fn build_cache_key(
406 &self,
407 request_type: &str,
408 query: &str,
409 user_token: &str,
410 context: &HashMap<String, serde_json::Value>,
411 ) -> String {
412 let context_hash = if context.is_empty() {
413 String::new()
414 } else {
415 let sorted: std::collections::BTreeMap<_, _> = context.iter().collect();
416 format!(":{}", serde_json::to_string(&sorted).unwrap_or_default())
417 };
418 format!("{}:{}:{}{}", request_type, query, user_token, context_hash)
419 }
420
421 async fn checked_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
422 let resp = self.http_client.get(url).send().await?;
423 Self::check_status(resp).await
424 }
425
426 async fn checked_map_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
427 let resp = self.map_http_client.get(url).send().await?;
428 Self::check_status(resp).await
429 }
430
431 async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, AxonFlowError> {
432 if resp.status().is_success() {
433 Ok(resp)
434 } else {
435 let status = resp.status().as_u16();
436 let message = resp.text().await?;
437 Err(AxonFlowError::ApiError { status, message })
438 }
439 }
440
441 async fn execute_with_retry(
442 &self,
443 req: &ClientRequest,
444 ) -> Result<ClientResponse, AxonFlowError> {
445 let mut last_err = None;
446
447 for attempt in 0..self.config.retry.max_attempts {
448 if attempt > 0 {
449 let delay =
450 self.config.retry.initial_delay.as_secs_f64() * 2f64.powi((attempt - 1) as i32);
451 tokio::time::sleep(Duration::from_secs_f64(delay)).await;
452 }
453
454 match self.execute_request(req).await {
455 Ok(resp) => return Ok(resp),
456 Err(e) => {
457 if let AxonFlowError::ApiError { status, .. } = &e {
458 if *status >= 400
459 && *status < 500
460 && *status != 429
461 && *status != 402
462 && *status != 403
463 {
464 return Err(e);
465 }
466 }
467 last_err = Some(e);
468 }
469 }
470 }
471
472 Err(last_err.unwrap_or_else(|| {
473 AxonFlowError::ConfigError("retry loop completed with no attempts".to_string())
474 }))
475 }
476
477 async fn execute_request(&self, req: &ClientRequest) -> Result<ClientResponse, AxonFlowError> {
478 let url = format!("{}/api/request", self.config.endpoint);
479 let resp = self.http_client.post(&url).json(req).send().await?;
480
481 let status = resp.status();
482 let body = resp.text().await?;
483
484 if status.is_success() || status.as_u16() == 402 || status.as_u16() == 403 {
485 let client_resp: ClientResponse = serde_json::from_str(&body)?;
486 Ok(client_resp)
487 } else {
488 Err(AxonFlowError::ApiError {
489 status: status.as_u16(),
490 message: body,
491 })
492 }
493 }
494}