Skip to main content

axonflow_sdk_rust/
client.rs

1use crate::config::{AxonFlowConfig, Mode};
2use crate::error::AxonFlowError;
3use crate::heartbeat::maybe_send_heartbeat;
4use crate::types::agent::{ClientRequest, ClientResponse};
5use base64::engine::general_purpose::STANDARD as BASE64_STD;
6use base64::Engine as _;
7use moka::future::Cache;
8use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, warn};
14
/// Name of the header that carries the enterprise license key.
///
/// `AxonFlowClient::new` adds it to the default headers of both HTTP clients
/// when `license_key` is configured and is a valid header value.
const LICENSE_KEY_HEADER: &str = "X-License-Key";
16
17// Path-segment encode set: mirrors Go's `url.PathEscape` semantics so
18// percent-encoding parity holds across SDKs. Keeps RFC 3986 unreserved
19// characters (alphanum, `-`, `.`, `_`, `~`) unencoded; escapes path-
20// significant chars (`/`, `?`, `#`, `%`) plus controls and characters
21// that web infra commonly rejects (` "<>``\\{}`).
22//
23// Replaces the previous `NON_ALPHANUMERIC` usage which over-escaped
24// underscores and dashes — observable as `dec_wf1_step2` becoming
25// `dec%5Fwf1%5Fstep2` in the explain path, and `amadeus-travel`
26// becoming `amadeus%2Dtravel` for connector lookups. Gorilla mux
27// percent-decodes path segments so the platform happened to tolerate
28// the over-escaped form, but the wire was wrong and any stricter
29// router would 404. Found while wiring `decisions::explain_decision`.
30pub(crate) const PATH_SEGMENT: &AsciiSet = &CONTROLS
31    .add(b' ')
32    .add(b'"')
33    .add(b'<')
34    .add(b'>')
35    .add(b'`')
36    .add(b'\\')
37    .add(b'{')
38    .add(b'}')
39    .add(b'#')
40    .add(b'?')
41    .add(b'/')
42    .add(b'%');
43
/// Client for the AxonFlow agent API.
///
/// Holds two HTTP clients built with the same default headers:
/// - `http_client`, used for most calls;
/// - `map_http_client`, used for the Multi-Agent Planning plan-status and
///   plan-cancel endpoints, with its own timeout.
#[derive(Clone)]
pub struct AxonFlowClient {
    /// Configuration the client was built from. `AXONFLOW_TRY=1` overrides the endpoint.
    config: AxonFlowConfig,
    /// General-purpose HTTP client (request timeout: `config.timeout`).
    http_client: reqwest::Client,
    /// HTTP client for MAP plan endpoints (request timeout: `config.map_timeout`).
    map_http_client: reqwest::Client,
    /// Response cache for non-mutation proxy calls; `None` when caching is disabled.
    /// Held behind an `Arc`, so clones of the client share cached entries.
    cache: Option<Arc<Cache<String, ClientResponse>>>,
}
51
52impl AxonFlowClient {
53    pub fn new(mut config: AxonFlowConfig) -> Result<Self, AxonFlowError> {
54        if config.retry.max_attempts == 0 {
55            return Err(AxonFlowError::ConfigError(
56                "retry.max_attempts must be at least 1".to_string(),
57            ));
58        }
59
60        if std::env::var("AXONFLOW_TRY").unwrap_or_default() == "1" {
61            config.endpoint = "https://try.getaxonflow.com".to_string();
62            if config.client_id.is_none() {
63                return Err(AxonFlowError::ConfigError(
64                    "ClientID is required in try mode (AXONFLOW_TRY=1).".to_string(),
65                ));
66            }
67        }
68
69        if config.client_secret.is_some() && config.client_id.is_none() {
70            warn!("ClientID is required when ClientSecret is set.");
71        }
72
73        let mut headers = HeaderMap::new();
74        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
75        headers.insert(
76            "User-Agent",
77            HeaderValue::from_static(concat!("axonflow-sdk-rust/", env!("CARGO_PKG_VERSION"))),
78        );
79
80        // HTTP Basic auth: "Basic base64(client_id:client_secret)".
81        // When neither is configured, default to the community tenant —
82        // matches the cross-SDK contract (see axonflow-sdk-go selfhosted_auth_headers_test.go).
83        let basic_id = config
84            .client_id
85            .clone()
86            .unwrap_or_else(|| "community".to_string());
87        let basic_secret = config.client_secret.clone().unwrap_or_default();
88        let basic_credentials = BASE64_STD.encode(format!("{}:{}", basic_id, basic_secret));
89        let basic_value = format!("Basic {}", basic_credentials);
90        if let Ok(val) = HeaderValue::from_str(&basic_value) {
91            headers.insert(AUTHORIZATION, val);
92        }
93
94        // Enterprise license key — sent only when configured.
95        if let Some(license_key) = &config.license_key {
96            if let Ok(mut val) = HeaderValue::from_str(license_key) {
97                val.set_sensitive(true);
98                headers.insert(LICENSE_KEY_HEADER, val);
99            }
100        }
101
102        let accept_invalid = config.insecure_skip_tls_verify
103            || std::env::var("AXONFLOW_INSECURE_TLS").unwrap_or_default() == "1";
104
105        if accept_invalid {
106            warn!("TLS certificate verification is disabled.");
107        }
108
109        let http_client = reqwest::Client::builder()
110            .timeout(config.timeout)
111            .default_headers(headers.clone())
112            .danger_accept_invalid_certs(accept_invalid)
113            .build()
114            .map_err(AxonFlowError::HttpError)?;
115
116        let map_http_client = reqwest::Client::builder()
117            .timeout(config.map_timeout)
118            .default_headers(headers)
119            .danger_accept_invalid_certs(accept_invalid)
120            .build()
121            .map_err(AxonFlowError::HttpError)?;
122
123        let cache = if config.cache.enabled {
124            Some(Arc::new(
125                Cache::builder().time_to_live(config.cache.ttl).build(),
126            ))
127        } else {
128            None
129        };
130
131        maybe_send_heartbeat(&config.endpoint, &config.mode);
132
133        Ok(Self {
134            config,
135            http_client,
136            map_http_client,
137            cache,
138        })
139    }
140
141    pub async fn proxy_llm_call(
142        &self,
143        user_token: &str,
144        query: &str,
145        request_type: &str,
146        context: HashMap<String, serde_json::Value>,
147    ) -> Result<ClientResponse, AxonFlowError> {
148        let user_token = if user_token.is_empty() {
149            "anonymous"
150        } else {
151            user_token
152        };
153
154        let is_mutation = matches!(
155            request_type,
156            "execute-plan" | "generate-plan" | "cancel-plan" | "update-plan"
157        );
158
159        if !is_mutation {
160            if let Some(cache) = &self.cache {
161                let cache_key = self.build_cache_key(request_type, query, user_token, &context);
162                if let Some(cached) = cache.get(&cache_key).await {
163                    debug!("Cache hit for query");
164                    return Ok(cached);
165                }
166            }
167        }
168
169        let req = ClientRequest {
170            query: query.to_string(),
171            user_token: user_token.to_string(),
172            client_id: self.config.client_id.clone(),
173            request_type: request_type.to_string(),
174            context,
175            media: None,
176        };
177
178        let resp = if self.config.retry.enabled && !is_mutation {
179            self.execute_with_retry(&req).await
180        } else {
181            self.execute_request(&req).await
182        };
183
184        match resp {
185            Ok(response) => {
186                if response.success && !is_mutation {
187                    if let Some(cache) = &self.cache {
188                        let cache_key =
189                            self.build_cache_key(request_type, query, user_token, &req.context);
190                        cache.insert(cache_key, response.clone()).await;
191                    }
192                }
193                Ok(response)
194            }
195            Err(e) => {
196                if self.config.mode == Mode::Production && e.is_fail_open_eligible() {
197                    debug!("AxonFlow unavailable, failing open: {}", e);
198                    Ok(ClientResponse::fail_open(e))
199                } else {
200                    Err(e)
201                }
202            }
203        }
204    }
205
206    // ============================================================================
207    // MCP Connector Management
208    // ============================================================================
209
210    pub async fn list_connectors(
211        &self,
212    ) -> Result<Vec<crate::types::agent::ConnectorMetadata>, AxonFlowError> {
213        let url = format!("{}/api/v1/connectors", self.config.endpoint);
214        let resp = self.checked_get(&url).await?;
215
216        let body: serde_json::Value = resp.json().await?;
217        let connectors = body["connectors"]
218            .as_array()
219            .ok_or_else(|| AxonFlowError::ApiError {
220                status: 200,
221                message: "response missing 'connectors' field".to_string(),
222            })?;
223
224        let result = serde_json::from_value(serde_json::Value::Array(connectors.clone()))?;
225        Ok(result)
226    }
227
228    pub async fn get_connector(
229        &self,
230        connector_id: &str,
231    ) -> Result<crate::types::agent::ConnectorMetadata, AxonFlowError> {
232        let encoded_id = utf8_percent_encode(connector_id, PATH_SEGMENT);
233        let url = format!("{}/api/v1/connectors/{}", self.config.endpoint, encoded_id);
234        let resp = self.checked_get(&url).await?;
235        Ok(resp.json().await?)
236    }
237
238    pub async fn get_connector_health(
239        &self,
240        connector_id: &str,
241    ) -> Result<crate::types::agent::ConnectorHealthStatus, AxonFlowError> {
242        let encoded_id = utf8_percent_encode(connector_id, PATH_SEGMENT);
243        let url = format!(
244            "{}/api/v1/connectors/{}/health",
245            self.config.endpoint, encoded_id
246        );
247        let resp = self.checked_get(&url).await?;
248        Ok(resp.json().await?)
249    }
250
251    pub async fn install_connector(
252        &self,
253        req: crate::types::agent::ConnectorInstallRequest,
254    ) -> Result<(), AxonFlowError> {
255        let encoded_id = utf8_percent_encode(&req.connector_id, PATH_SEGMENT);
256        let url = format!(
257            "{}/api/v1/connectors/{}/install",
258            self.config.endpoint, encoded_id
259        );
260        let resp = self.http_client.post(&url).json(&req).send().await?;
261        Self::check_status(resp).await?;
262        Ok(())
263    }
264
265    pub async fn query_connector(
266        &self,
267        user_token: &str,
268        connector_name: &str,
269        query: &str,
270        params: HashMap<String, serde_json::Value>,
271    ) -> Result<crate::types::agent::ConnectorResponse, AxonFlowError> {
272        // Connector queries are dispatched through the agent's proxy endpoint
273        // with request_type=mcp-query — there is no standalone /api/v1/query.
274        // Mirror the Go SDK's QueryConnector contract.
275        let mut context = HashMap::new();
276        context.insert("connector".to_string(), serde_json::json!(connector_name));
277        context.insert("params".to_string(), serde_json::json!(params));
278
279        let resp = self
280            .proxy_llm_call(user_token, query, "mcp-query", context)
281            .await?;
282
283        Ok(crate::types::agent::ConnectorResponse {
284            success: resp.success,
285            data: resp.data.unwrap_or(serde_json::Value::Null),
286            error: resp.error,
287            meta: resp.metadata,
288            redacted: false,
289            redacted_fields: Vec::new(),
290            policy_info: None,
291        })
292    }
293
294    // ============================================================================
295    // Multi-Agent Planning (MAP)
296    // ============================================================================
297
298    pub async fn generate_plan(
299        &self,
300        query: &str,
301        domain: &str,
302        user_token: Option<&str>,
303    ) -> Result<crate::types::agent::PlanResponse, AxonFlowError> {
304        let mut context = HashMap::new();
305        context.insert("domain".to_string(), serde_json::json!(domain));
306        let user_token = user_token.unwrap_or("anonymous");
307
308        let resp = self
309            .proxy_llm_call(user_token, query, "generate-plan", context)
310            .await?;
311
312        if let Some(data) = resp.data {
313            let plan: crate::types::agent::PlanResponse = serde_json::from_value(data)?;
314            Ok(plan)
315        } else {
316            Err(AxonFlowError::ApiError {
317                status: 500,
318                message: "empty plan data".to_string(),
319            })
320        }
321    }
322
323    pub async fn execute_plan(
324        &self,
325        plan_id: &str,
326        user_token: Option<&str>,
327    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
328        let mut context = HashMap::new();
329        context.insert("plan_id".to_string(), serde_json::json!(plan_id));
330        let user_token = user_token.unwrap_or("anonymous");
331
332        let resp = self
333            .proxy_llm_call(user_token, "", "execute-plan", context)
334            .await?;
335
336        if let Some(data) = resp.data {
337            let exec: crate::types::agent::PlanExecutionResponse = serde_json::from_value(data)?;
338            Ok(exec)
339        } else {
340            Err(AxonFlowError::ApiError {
341                status: 500,
342                message: "empty execution data".to_string(),
343            })
344        }
345    }
346
347    pub async fn get_plan_status(
348        &self,
349        plan_id: &str,
350    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
351        let encoded_id = utf8_percent_encode(plan_id, PATH_SEGMENT);
352        let url = format!("{}/api/v1/plan/{}", self.config.endpoint, encoded_id);
353        let resp = self.checked_map_get(&url).await?;
354        Ok(resp.json().await?)
355    }
356
357    pub async fn cancel_plan(
358        &self,
359        plan_id: &str,
360        reason: Option<&str>,
361    ) -> Result<crate::types::agent::CancelPlanResponse, AxonFlowError> {
362        let req_body = serde_json::json!({
363            "reason": reason.unwrap_or("user_cancelled"),
364        });
365
366        let encoded_id = utf8_percent_encode(plan_id, PATH_SEGMENT);
367        let url = format!("{}/api/v1/plan/{}/cancel", self.config.endpoint, encoded_id);
368        let resp = self
369            .map_http_client
370            .post(&url)
371            .json(&req_body)
372            .send()
373            .await?;
374        let resp = Self::check_status(resp).await?;
375        Ok(resp.json().await?)
376    }
377
378    pub async fn audit_llm_call(
379        &self,
380        req: &crate::types::agent::AuditRequest,
381    ) -> Result<crate::types::agent::AuditResult, AxonFlowError> {
382        let client_id = self.get_effective_client_id();
383
384        let mut req_body = serde_json::json!({
385            "context_id": req.context_id,
386            "client_id": client_id,
387            "response_summary": req.response_summary,
388            "provider": req.provider,
389            "model": req.model,
390            "token_usage": {
391                "prompt_tokens": req.token_usage.prompt_tokens,
392                "completion_tokens": req.token_usage.completion_tokens,
393                "total_tokens": req.token_usage.total_tokens,
394            },
395            "latency_ms": req.latency_ms,
396        });
397
398        if let Some(meta) = &req.metadata {
399            req_body["metadata"] = serde_json::to_value(meta)?;
400        } else {
401            req_body["metadata"] = serde_json::json!({});
402        }
403
404        let url = format!("{}/api/audit/llm-call", self.config.endpoint);
405        let resp = self.http_client.post(&url).json(&req_body).send().await?;
406
407        let status = resp.status();
408        let body = resp.text().await?;
409
410        if status.is_success() {
411            let audit_resp: crate::types::agent::AuditResult = serde_json::from_str(&body)?;
412            Ok(audit_resp)
413        } else {
414            Err(AxonFlowError::ApiError {
415                status: status.as_u16(),
416                message: body,
417            })
418        }
419    }
420
421    // ============================================================================
422    // Private helpers
423    // ============================================================================
424
425    fn get_effective_client_id(&self) -> String {
426        self.config
427            .client_id
428            .clone()
429            .unwrap_or_else(|| "community".to_string())
430    }
431
432    fn build_cache_key(
433        &self,
434        request_type: &str,
435        query: &str,
436        user_token: &str,
437        context: &HashMap<String, serde_json::Value>,
438    ) -> String {
439        let context_hash = if context.is_empty() {
440            String::new()
441        } else {
442            let sorted: std::collections::BTreeMap<_, _> = context.iter().collect();
443            format!(":{}", serde_json::to_string(&sorted).unwrap_or_default())
444        };
445        format!("{}:{}:{}{}", request_type, query, user_token, context_hash)
446    }
447
448    /// Endpoint URL the client is configured against.
449    /// Crate-internal accessor for sibling modules (e.g. `decisions.rs`)
450    /// that need to build absolute URLs without exposing `config`.
451    pub(crate) fn endpoint(&self) -> &str {
452        &self.config.endpoint
453    }
454
455    pub(crate) async fn checked_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
456        let resp = self.http_client.get(url).send().await?;
457        Self::check_status(resp).await
458    }
459
460    /// Crate-internal GET that returns the raw response without translating
461    /// non-2xx into [`AxonFlowError::ApiError`]. Lets sibling modules branch
462    /// on specific status codes (e.g. parse a 429 V1 upgrade envelope into
463    /// [`AxonFlowError::RateLimited`]) before falling back to the generic
464    /// error path.
465    pub(crate) async fn raw_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
466        Ok(self.http_client.get(url).send().await?)
467    }
468
469    async fn checked_map_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
470        let resp = self.map_http_client.get(url).send().await?;
471        Self::check_status(resp).await
472    }
473
474    async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, AxonFlowError> {
475        if resp.status().is_success() {
476            Ok(resp)
477        } else {
478            let status = resp.status().as_u16();
479            let message = resp.text().await?;
480            Err(AxonFlowError::ApiError { status, message })
481        }
482    }
483
484    async fn execute_with_retry(
485        &self,
486        req: &ClientRequest,
487    ) -> Result<ClientResponse, AxonFlowError> {
488        let mut last_err = None;
489
490        for attempt in 0..self.config.retry.max_attempts {
491            if attempt > 0 {
492                let delay =
493                    self.config.retry.initial_delay.as_secs_f64() * 2f64.powi((attempt - 1) as i32);
494                tokio::time::sleep(Duration::from_secs_f64(delay)).await;
495            }
496
497            match self.execute_request(req).await {
498                Ok(resp) => return Ok(resp),
499                Err(e) => {
500                    if let AxonFlowError::ApiError { status, .. } = &e {
501                        if *status >= 400
502                            && *status < 500
503                            && *status != 429
504                            && *status != 402
505                            && *status != 403
506                        {
507                            return Err(e);
508                        }
509                    }
510                    last_err = Some(e);
511                }
512            }
513        }
514
515        Err(last_err.unwrap_or_else(|| {
516            AxonFlowError::ConfigError("retry loop completed with no attempts".to_string())
517        }))
518    }
519
520    async fn execute_request(&self, req: &ClientRequest) -> Result<ClientResponse, AxonFlowError> {
521        let url = format!("{}/api/request", self.config.endpoint);
522        let resp = self.http_client.post(&url).json(req).send().await?;
523
524        let status = resp.status();
525        let body = resp.text().await?;
526
527        if status.is_success() || status.as_u16() == 402 || status.as_u16() == 403 {
528            let client_resp: ClientResponse = serde_json::from_str(&body)?;
529            Ok(client_resp)
530        } else {
531            Err(AxonFlowError::ApiError {
532                status: status.as_u16(),
533                message: body,
534            })
535        }
536    }
537}