Skip to main content

axonflow_sdk_rust/
client.rs

1use crate::config::{AxonFlowConfig, Mode};
2use crate::error::AxonFlowError;
3use crate::heartbeat::maybe_send_heartbeat;
4use crate::types::agent::{ClientRequest, ClientResponse};
5use base64::engine::general_purpose::STANDARD as BASE64_STD;
6use base64::Engine as _;
7use moka::future::Cache;
8use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, warn};
14
/// Name of the request header that carries the enterprise license key.
/// The header is attached to the default headers only when a license key
/// is configured.
const LICENSE_KEY_HEADER: &str = "X-License-Key";
16
17// Path-segment encode set: mirrors Go's `url.PathEscape` semantics so
18// percent-encoding parity holds across SDKs. Keeps RFC 3986 unreserved
19// characters (alphanum, `-`, `.`, `_`, `~`) unencoded; escapes path-
20// significant chars (`/`, `?`, `#`, `%`) plus controls and characters
21// that web infra commonly rejects (` "<>``\\{}`).
22//
23// Replaces the previous `NON_ALPHANUMERIC` usage which over-escaped
24// underscores and dashes — observable as `dec_wf1_step2` becoming
25// `dec%5Fwf1%5Fstep2` in the explain path, and `amadeus-travel`
26// becoming `amadeus%2Dtravel` for connector lookups. Gorilla mux
27// percent-decodes path segments so the platform happened to tolerate
28// the over-escaped form, but the wire was wrong and any stricter
29// router would 404. Found while wiring `decisions::explain_decision`.
30pub(crate) const PATH_SEGMENT: &AsciiSet = &CONTROLS
31    .add(b' ')
32    .add(b'"')
33    .add(b'<')
34    .add(b'>')
35    .add(b'`')
36    .add(b'\\')
37    .add(b'{')
38    .add(b'}')
39    .add(b'#')
40    .add(b'?')
41    .add(b'/')
42    .add(b'%');
43
44#[derive(Clone)]
45pub struct AxonFlowClient {
46    config: AxonFlowConfig,
47    http_client: reqwest::Client,
48    map_http_client: reqwest::Client,
49    cache: Option<Arc<Cache<String, ClientResponse>>>,
50}
51
52impl AxonFlowClient {
53    pub fn new(mut config: AxonFlowConfig) -> Result<Self, AxonFlowError> {
54        if config.retry.max_attempts == 0 {
55            return Err(AxonFlowError::ConfigError(
56                "retry.max_attempts must be at least 1".to_string(),
57            ));
58        }
59
60        if std::env::var("AXONFLOW_TRY").unwrap_or_default() == "1" {
61            config.endpoint = "https://try.getaxonflow.com".to_string();
62            if config.client_id.is_none() {
63                return Err(AxonFlowError::ConfigError(
64                    "ClientID is required in try mode (AXONFLOW_TRY=1).".to_string(),
65                ));
66            }
67        }
68
69        if config.client_secret.is_some() && config.client_id.is_none() {
70            warn!("ClientID is required when ClientSecret is set.");
71        }
72
73        let mut headers = HeaderMap::new();
74        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
75        headers.insert(
76            "User-Agent",
77            HeaderValue::from_static(concat!("axonflow-sdk-rust/", env!("CARGO_PKG_VERSION"))),
78        );
79        // ADR-050 §4: every governed request to the agent carries
80        // X-Axonflow-Client so the agent can derive request scope (sdk)
81        // and validate against the token's aud.scope via HasScope().
82        // Sourced from CARGO_PKG_VERSION; no env override (the consumer
83        // doesn't get to spoof its own client identity to the agent).
84        headers.insert(
85            "X-Axonflow-Client",
86            HeaderValue::from_static(concat!("sdk-rust/", env!("CARGO_PKG_VERSION"))),
87        );
88
89        // HTTP Basic auth: "Basic base64(client_id:client_secret)".
90        // When neither is configured, default to the community tenant —
91        // matches the cross-SDK contract (see axonflow-sdk-go selfhosted_auth_headers_test.go).
92        let basic_id = config
93            .client_id
94            .clone()
95            .unwrap_or_else(|| "community".to_string());
96        let basic_secret = config.client_secret.clone().unwrap_or_default();
97        let basic_credentials = BASE64_STD.encode(format!("{}:{}", basic_id, basic_secret));
98        let basic_value = format!("Basic {}", basic_credentials);
99        if let Ok(val) = HeaderValue::from_str(&basic_value) {
100            headers.insert(AUTHORIZATION, val);
101        }
102
103        // X-Client-ID (v9): server-side identity decisions don't have to
104        // re-decode Basic auth. The agent's apiAuthMiddleware overwrites
105        // the header with its auth-derived value, so caller-supplied
106        // values are harmless (no spoofing surface).
107        if let Ok(val) = HeaderValue::from_str(&basic_id) {
108            headers.insert("X-Client-ID", val);
109        }
110
111        // Enterprise license key — sent only when configured.
112        if let Some(license_key) = &config.license_key {
113            if let Ok(mut val) = HeaderValue::from_str(license_key) {
114                val.set_sensitive(true);
115                headers.insert(LICENSE_KEY_HEADER, val);
116            }
117        }
118
119        let accept_invalid = config.insecure_skip_tls_verify
120            || std::env::var("AXONFLOW_INSECURE_TLS").unwrap_or_default() == "1";
121
122        if accept_invalid {
123            warn!("TLS certificate verification is disabled.");
124        }
125
126        let http_client = reqwest::Client::builder()
127            .timeout(config.timeout)
128            .default_headers(headers.clone())
129            .danger_accept_invalid_certs(accept_invalid)
130            .build()
131            .map_err(AxonFlowError::HttpError)?;
132
133        let map_http_client = reqwest::Client::builder()
134            .timeout(config.map_timeout)
135            .default_headers(headers)
136            .danger_accept_invalid_certs(accept_invalid)
137            .build()
138            .map_err(AxonFlowError::HttpError)?;
139
140        let cache = if config.cache.enabled {
141            Some(Arc::new(
142                Cache::builder().time_to_live(config.cache.ttl).build(),
143            ))
144        } else {
145            None
146        };
147
148        maybe_send_heartbeat(&config.endpoint, &config.mode);
149
150        Ok(Self {
151            config,
152            http_client,
153            map_http_client,
154            cache,
155        })
156    }
157
158    pub async fn proxy_llm_call(
159        &self,
160        user_token: &str,
161        query: &str,
162        request_type: &str,
163        context: HashMap<String, serde_json::Value>,
164    ) -> Result<ClientResponse, AxonFlowError> {
165        let user_token = if user_token.is_empty() {
166            "anonymous"
167        } else {
168            user_token
169        };
170
171        let is_mutation = matches!(
172            request_type,
173            "execute-plan" | "generate-plan" | "cancel-plan" | "update-plan"
174        );
175
176        if !is_mutation {
177            if let Some(cache) = &self.cache {
178                let cache_key = self.build_cache_key(request_type, query, user_token, &context);
179                if let Some(cached) = cache.get(&cache_key).await {
180                    debug!("Cache hit for query");
181                    return Ok(cached);
182                }
183            }
184        }
185
186        let req = ClientRequest {
187            query: query.to_string(),
188            user_token: user_token.to_string(),
189            client_id: self.config.client_id.clone(),
190            request_type: request_type.to_string(),
191            context,
192            media: None,
193        };
194
195        let resp = if self.config.retry.enabled && !is_mutation {
196            self.execute_with_retry(&req).await
197        } else {
198            self.execute_request(&req).await
199        };
200
201        match resp {
202            Ok(response) => {
203                if response.success && !is_mutation {
204                    if let Some(cache) = &self.cache {
205                        let cache_key =
206                            self.build_cache_key(request_type, query, user_token, &req.context);
207                        cache.insert(cache_key, response.clone()).await;
208                    }
209                }
210                Ok(response)
211            }
212            Err(e) => {
213                if self.config.mode == Mode::Production && e.is_fail_open_eligible() {
214                    debug!("AxonFlow unavailable, failing open: {}", e);
215                    Ok(ClientResponse::fail_open(e))
216                } else {
217                    Err(e)
218                }
219            }
220        }
221    }
222
223    // ============================================================================
224    // MCP Connector Management
225    // ============================================================================
226
227    pub async fn list_connectors(
228        &self,
229    ) -> Result<Vec<crate::types::agent::ConnectorMetadata>, AxonFlowError> {
230        let url = format!("{}/api/v1/connectors", self.config.endpoint);
231        let resp = self.checked_get(&url).await?;
232
233        let body: serde_json::Value = resp.json().await?;
234        let connectors = body["connectors"]
235            .as_array()
236            .ok_or_else(|| AxonFlowError::ApiError {
237                status: 200,
238                message: "response missing 'connectors' field".to_string(),
239            })?;
240
241        let result = serde_json::from_value(serde_json::Value::Array(connectors.clone()))?;
242        Ok(result)
243    }
244
245    pub async fn get_connector(
246        &self,
247        connector_id: &str,
248    ) -> Result<crate::types::agent::ConnectorMetadata, AxonFlowError> {
249        let encoded_id = utf8_percent_encode(connector_id, PATH_SEGMENT);
250        let url = format!("{}/api/v1/connectors/{}", self.config.endpoint, encoded_id);
251        let resp = self.checked_get(&url).await?;
252        Ok(resp.json().await?)
253    }
254
255    pub async fn get_connector_health(
256        &self,
257        connector_id: &str,
258    ) -> Result<crate::types::agent::ConnectorHealthStatus, AxonFlowError> {
259        let encoded_id = utf8_percent_encode(connector_id, PATH_SEGMENT);
260        let url = format!(
261            "{}/api/v1/connectors/{}/health",
262            self.config.endpoint, encoded_id
263        );
264        let resp = self.checked_get(&url).await?;
265        Ok(resp.json().await?)
266    }
267
268    pub async fn install_connector(
269        &self,
270        req: crate::types::agent::ConnectorInstallRequest,
271    ) -> Result<(), AxonFlowError> {
272        let encoded_id = utf8_percent_encode(&req.connector_id, PATH_SEGMENT);
273        let url = format!(
274            "{}/api/v1/connectors/{}/install",
275            self.config.endpoint, encoded_id
276        );
277        let resp = self.http_client.post(&url).json(&req).send().await?;
278        Self::check_status(resp).await?;
279        Ok(())
280    }
281
282    pub async fn query_connector(
283        &self,
284        user_token: &str,
285        connector_name: &str,
286        query: &str,
287        params: HashMap<String, serde_json::Value>,
288    ) -> Result<crate::types::agent::ConnectorResponse, AxonFlowError> {
289        // Connector queries are dispatched through the agent's proxy endpoint
290        // with request_type=mcp-query — there is no standalone /api/v1/query.
291        // Mirror the Go SDK's QueryConnector contract.
292        let mut context = HashMap::new();
293        context.insert("connector".to_string(), serde_json::json!(connector_name));
294        context.insert("params".to_string(), serde_json::json!(params));
295
296        let resp = self
297            .proxy_llm_call(user_token, query, "mcp-query", context)
298            .await?;
299
300        Ok(crate::types::agent::ConnectorResponse {
301            success: resp.success,
302            data: resp.data.unwrap_or(serde_json::Value::Null),
303            error: resp.error,
304            meta: resp.metadata,
305            redacted: false,
306            redacted_fields: Vec::new(),
307            policy_info: None,
308        })
309    }
310
311    // ============================================================================
312    // Multi-Agent Planning (MAP)
313    // ============================================================================
314
315    pub async fn generate_plan(
316        &self,
317        query: &str,
318        domain: &str,
319        user_token: Option<&str>,
320    ) -> Result<crate::types::agent::PlanResponse, AxonFlowError> {
321        let mut context = HashMap::new();
322        context.insert("domain".to_string(), serde_json::json!(domain));
323        let user_token = user_token.unwrap_or("anonymous");
324
325        let resp = self
326            .proxy_llm_call(user_token, query, "generate-plan", context)
327            .await?;
328
329        if let Some(data) = resp.data {
330            let plan: crate::types::agent::PlanResponse = serde_json::from_value(data)?;
331            Ok(plan)
332        } else {
333            Err(AxonFlowError::ApiError {
334                status: 500,
335                message: "empty plan data".to_string(),
336            })
337        }
338    }
339
340    pub async fn execute_plan(
341        &self,
342        plan_id: &str,
343        user_token: Option<&str>,
344    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
345        let mut context = HashMap::new();
346        context.insert("plan_id".to_string(), serde_json::json!(plan_id));
347        let user_token = user_token.unwrap_or("anonymous");
348
349        let resp = self
350            .proxy_llm_call(user_token, "", "execute-plan", context)
351            .await?;
352
353        if let Some(data) = resp.data {
354            let exec: crate::types::agent::PlanExecutionResponse = serde_json::from_value(data)?;
355            Ok(exec)
356        } else {
357            Err(AxonFlowError::ApiError {
358                status: 500,
359                message: "empty execution data".to_string(),
360            })
361        }
362    }
363
364    pub async fn get_plan_status(
365        &self,
366        plan_id: &str,
367    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
368        let encoded_id = utf8_percent_encode(plan_id, PATH_SEGMENT);
369        let url = format!("{}/api/v1/plan/{}", self.config.endpoint, encoded_id);
370        let resp = self.checked_map_get(&url).await?;
371        Ok(resp.json().await?)
372    }
373
374    pub async fn cancel_plan(
375        &self,
376        plan_id: &str,
377        reason: Option<&str>,
378    ) -> Result<crate::types::agent::CancelPlanResponse, AxonFlowError> {
379        let req_body = serde_json::json!({
380            "reason": reason.unwrap_or("user_cancelled"),
381        });
382
383        let encoded_id = utf8_percent_encode(plan_id, PATH_SEGMENT);
384        let url = format!("{}/api/v1/plan/{}/cancel", self.config.endpoint, encoded_id);
385        let resp = self
386            .map_http_client
387            .post(&url)
388            .json(&req_body)
389            .send()
390            .await?;
391        let resp = Self::check_status(resp).await?;
392        Ok(resp.json().await?)
393    }
394
395    pub async fn audit_llm_call(
396        &self,
397        req: &crate::types::agent::AuditRequest,
398    ) -> Result<crate::types::agent::AuditResult, AxonFlowError> {
399        let client_id = self.get_effective_client_id();
400
401        let mut req_body = serde_json::json!({
402            "context_id": req.context_id,
403            "client_id": client_id,
404            "response_summary": req.response_summary,
405            "provider": req.provider,
406            "model": req.model,
407            "token_usage": {
408                "prompt_tokens": req.token_usage.prompt_tokens,
409                "completion_tokens": req.token_usage.completion_tokens,
410                "total_tokens": req.token_usage.total_tokens,
411            },
412            "latency_ms": req.latency_ms,
413        });
414
415        if let Some(meta) = &req.metadata {
416            req_body["metadata"] = serde_json::to_value(meta)?;
417        } else {
418            req_body["metadata"] = serde_json::json!({});
419        }
420
421        let url = format!("{}/api/audit/llm-call", self.config.endpoint);
422        let resp = self.http_client.post(&url).json(&req_body).send().await?;
423
424        let status = resp.status();
425        let body = resp.text().await?;
426
427        if status.is_success() {
428            let audit_resp: crate::types::agent::AuditResult = serde_json::from_str(&body)?;
429            Ok(audit_resp)
430        } else {
431            Err(AxonFlowError::ApiError {
432                status: status.as_u16(),
433                message: body,
434            })
435        }
436    }
437
438    // ============================================================================
439    // Private helpers
440    // ============================================================================
441
442    fn get_effective_client_id(&self) -> String {
443        self.config
444            .client_id
445            .clone()
446            .unwrap_or_else(|| "community".to_string())
447    }
448
449    fn build_cache_key(
450        &self,
451        request_type: &str,
452        query: &str,
453        user_token: &str,
454        context: &HashMap<String, serde_json::Value>,
455    ) -> String {
456        let context_hash = if context.is_empty() {
457            String::new()
458        } else {
459            let sorted: std::collections::BTreeMap<_, _> = context.iter().collect();
460            format!(":{}", serde_json::to_string(&sorted).unwrap_or_default())
461        };
462        format!("{}:{}:{}{}", request_type, query, user_token, context_hash)
463    }
464
465    /// Endpoint URL the client is configured against.
466    /// Crate-internal accessor for sibling modules (e.g. `decisions.rs`)
467    /// that need to build absolute URLs without exposing `config`.
468    pub(crate) fn endpoint(&self) -> &str {
469        &self.config.endpoint
470    }
471
472    pub(crate) async fn checked_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
473        let resp = self.http_client.get(url).send().await?;
474        Self::check_status(resp).await
475    }
476
477    /// Crate-internal POST that serializes `body` as JSON and translates
478    /// non-2xx into [`AxonFlowError::ApiError`] — the symmetric helper to
479    /// [`checked_get`](Self::checked_get). Used by sibling modules
480    /// (e.g. `hitl`) that POST a typed payload and don't need to branch
481    /// on specific status codes before falling back to the generic error
482    /// path.
483    pub(crate) async fn checked_post_json<T: serde::Serialize + ?Sized>(
484        &self,
485        url: &str,
486        body: &T,
487    ) -> Result<reqwest::Response, AxonFlowError> {
488        let resp = self.http_client.post(url).json(body).send().await?;
489        Self::check_status(resp).await
490    }
491
492    /// Crate-internal GET that returns the raw response without translating
493    /// non-2xx into [`AxonFlowError::ApiError`]. Lets sibling modules branch
494    /// on specific status codes (e.g. parse a 429 V1 upgrade envelope into
495    /// [`AxonFlowError::RateLimited`]) before falling back to the generic
496    /// error path.
497    pub(crate) async fn raw_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
498        Ok(self.http_client.get(url).send().await?)
499    }
500
501    async fn checked_map_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
502        let resp = self.map_http_client.get(url).send().await?;
503        Self::check_status(resp).await
504    }
505
506    async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, AxonFlowError> {
507        if resp.status().is_success() {
508            Ok(resp)
509        } else {
510            let status = resp.status().as_u16();
511            let message = resp.text().await?;
512            Err(AxonFlowError::ApiError { status, message })
513        }
514    }
515
516    /// Retry the request with exponential backoff, honoring the
517    /// SDK-wide retry contract.
518    ///
519    /// **Retried status codes:**
520    /// - 5xx — server-side failures (treated as transient).
521    /// - 429 — rate-limit responses (transient by definition).
522    /// - Transport-level errors (connection refused, DNS, TLS) —
523    ///   surfaced as non-`ApiError` variants of [`AxonFlowError`];
524    ///   the `if let AxonFlowError::ApiError { .. }` guard doesn't
525    ///   match them, so they fall through to `last_err = Some(e)` and
526    ///   retry on the next iteration.
527    ///
528    /// **Terminal status codes (early `return Err(e)`):**
529    /// - 401 — auth failure. Retrying with the same invalid
530    ///   credential just compounds the storm on the agent. See
531    ///   issue [#2275](https://github.com/getaxonflow/axonflow-enterprise/issues/2275)
532    ///   for the customer-observed retry loop that motivated the
533    ///   regression-locking test `test_401_not_retried_issue_2275`.
534    /// - 400, 404, 405, 406, 408, 409, 410, 411, 412, 413, 414, 415,
535    ///   416, 417, 418, 421, 422, 423, 424, 425, 426, 428, 431, 451 —
536    ///   every other 4xx that isn't in the `{429, 402, 403}` allowlist.
537    ///
538    /// **Caveat on 402/403:** `execute_request` returns 402 + 403 as
539    /// `Ok(client_resp)` because those are SUCCESS responses carrying
540    /// policy/quota envelope data — not errors. They never reach this
541    /// function as `Err`, so the `*status != 402` and `*status != 403`
542    /// clauses below are functionally dead in current code. They're
543    /// kept as intent-preserving belt-and-suspenders for any future
544    /// refactor that converts 402/403 back to `Err`.
545    ///
546    /// See `CHANGELOG.md` for the contract's history.
547    async fn execute_with_retry(
548        &self,
549        req: &ClientRequest,
550    ) -> Result<ClientResponse, AxonFlowError> {
551        let mut last_err = None;
552
553        for attempt in 0..self.config.retry.max_attempts {
554            if attempt > 0 {
555                let delay =
556                    self.config.retry.initial_delay.as_secs_f64() * 2f64.powi((attempt - 1) as i32);
557                tokio::time::sleep(Duration::from_secs_f64(delay)).await;
558            }
559
560            match self.execute_request(req).await {
561                Ok(resp) => return Ok(resp),
562                Err(e) => {
563                    if let AxonFlowError::ApiError { status, .. } = &e {
564                        // Retry allowlist: any 4xx NOT in {429, 402, 403} is
565                        // terminal. 5xx always retries (falls through to the
566                        // `last_err = Some(e)` path below).
567                        //
568                        // 402/403 NEVER reach this branch as `Err`: see
569                        // `execute_request` at line 586 — those statuses
570                        // return as `Ok(client_resp)` because they carry
571                        // policy/quota envelope data. The `*status != 402`
572                        // and `*status != 403` clauses are intentional
573                        // belt-and-suspenders for a hypothetical future
574                        // refactor that errors on those statuses.
575                        if *status >= 400
576                            && *status < 500
577                            && *status != 429
578                            && *status != 402
579                            && *status != 403
580                        {
581                            return Err(e);
582                        }
583                    }
584                    last_err = Some(e);
585                }
586            }
587        }
588
589        Err(last_err.unwrap_or_else(|| {
590            AxonFlowError::ConfigError("retry loop completed with no attempts".to_string())
591        }))
592    }
593
594    async fn execute_request(&self, req: &ClientRequest) -> Result<ClientResponse, AxonFlowError> {
595        let url = format!("{}/api/request", self.config.endpoint);
596        let resp = self.http_client.post(&url).json(req).send().await?;
597
598        let status = resp.status();
599        let body = resp.text().await?;
600
601        if status.is_success() || status.as_u16() == 402 || status.as_u16() == 403 {
602            let client_resp: ClientResponse = serde_json::from_str(&body)?;
603            Ok(client_resp)
604        } else {
605            Err(AxonFlowError::ApiError {
606                status: status.as_u16(),
607                message: body,
608            })
609        }
610    }
611}