Skip to main content

axonflow_sdk_rust/
client.rs

1use crate::config::{AxonFlowConfig, Mode};
2use crate::error::AxonFlowError;
3use crate::heartbeat::maybe_send_heartbeat;
4use crate::types::agent::{ClientRequest, ClientResponse};
5use base64::engine::general_purpose::STANDARD as BASE64_STD;
6use base64::Engine as _;
7use moka::future::Cache;
8use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, warn};
14
15const LICENSE_KEY_HEADER: &str = "X-License-Key";
16
17#[derive(Clone)]
18pub struct AxonFlowClient {
19    config: AxonFlowConfig,
20    http_client: reqwest::Client,
21    map_http_client: reqwest::Client,
22    cache: Option<Arc<Cache<String, ClientResponse>>>,
23}
24
25impl AxonFlowClient {
26    pub fn new(mut config: AxonFlowConfig) -> Result<Self, AxonFlowError> {
27        if config.retry.max_attempts == 0 {
28            return Err(AxonFlowError::ConfigError(
29                "retry.max_attempts must be at least 1".to_string(),
30            ));
31        }
32
33        if std::env::var("AXONFLOW_TRY").unwrap_or_default() == "1" {
34            config.endpoint = "https://try.getaxonflow.com".to_string();
35            if config.client_id.is_none() {
36                return Err(AxonFlowError::ConfigError(
37                    "ClientID is required in try mode (AXONFLOW_TRY=1).".to_string(),
38                ));
39            }
40        }
41
42        if config.client_secret.is_some() && config.client_id.is_none() {
43            warn!("ClientID is required when ClientSecret is set.");
44        }
45
46        let mut headers = HeaderMap::new();
47        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
48        headers.insert(
49            "User-Agent",
50            HeaderValue::from_static(concat!("axonflow-sdk-rust/", env!("CARGO_PKG_VERSION"))),
51        );
52
53        // HTTP Basic auth: "Basic base64(client_id:client_secret)".
54        // When neither is configured, default to the community tenant —
55        // matches the cross-SDK contract (see axonflow-sdk-go selfhosted_auth_headers_test.go).
56        let basic_id = config
57            .client_id
58            .clone()
59            .unwrap_or_else(|| "community".to_string());
60        let basic_secret = config.client_secret.clone().unwrap_or_default();
61        let basic_credentials = BASE64_STD.encode(format!("{}:{}", basic_id, basic_secret));
62        let basic_value = format!("Basic {}", basic_credentials);
63        if let Ok(val) = HeaderValue::from_str(&basic_value) {
64            headers.insert(AUTHORIZATION, val);
65        }
66
67        // Enterprise license key — sent only when configured.
68        if let Some(license_key) = &config.license_key {
69            if let Ok(mut val) = HeaderValue::from_str(license_key) {
70                val.set_sensitive(true);
71                headers.insert(LICENSE_KEY_HEADER, val);
72            }
73        }
74
75        let accept_invalid = config.insecure_skip_tls_verify
76            || std::env::var("AXONFLOW_INSECURE_TLS").unwrap_or_default() == "1";
77
78        if accept_invalid {
79            warn!("TLS certificate verification is disabled.");
80        }
81
82        let http_client = reqwest::Client::builder()
83            .timeout(config.timeout)
84            .default_headers(headers.clone())
85            .danger_accept_invalid_certs(accept_invalid)
86            .build()
87            .map_err(AxonFlowError::HttpError)?;
88
89        let map_http_client = reqwest::Client::builder()
90            .timeout(config.map_timeout)
91            .default_headers(headers)
92            .danger_accept_invalid_certs(accept_invalid)
93            .build()
94            .map_err(AxonFlowError::HttpError)?;
95
96        let cache = if config.cache.enabled {
97            Some(Arc::new(
98                Cache::builder().time_to_live(config.cache.ttl).build(),
99            ))
100        } else {
101            None
102        };
103
104        maybe_send_heartbeat(&config.endpoint);
105
106        Ok(Self {
107            config,
108            http_client,
109            map_http_client,
110            cache,
111        })
112    }
113
114    pub async fn proxy_llm_call(
115        &self,
116        user_token: &str,
117        query: &str,
118        request_type: &str,
119        context: HashMap<String, serde_json::Value>,
120    ) -> Result<ClientResponse, AxonFlowError> {
121        let user_token = if user_token.is_empty() {
122            "anonymous"
123        } else {
124            user_token
125        };
126
127        let is_mutation = matches!(
128            request_type,
129            "execute-plan" | "generate-plan" | "cancel-plan" | "update-plan"
130        );
131
132        if !is_mutation {
133            if let Some(cache) = &self.cache {
134                let cache_key = self.build_cache_key(request_type, query, user_token, &context);
135                if let Some(cached) = cache.get(&cache_key).await {
136                    debug!("Cache hit for query");
137                    return Ok(cached);
138                }
139            }
140        }
141
142        let req = ClientRequest {
143            query: query.to_string(),
144            user_token: user_token.to_string(),
145            client_id: self.config.client_id.clone(),
146            request_type: request_type.to_string(),
147            context,
148            media: None,
149        };
150
151        let resp = if self.config.retry.enabled && !is_mutation {
152            self.execute_with_retry(&req).await
153        } else {
154            self.execute_request(&req).await
155        };
156
157        match resp {
158            Ok(response) => {
159                if response.success && !is_mutation {
160                    if let Some(cache) = &self.cache {
161                        let cache_key =
162                            self.build_cache_key(request_type, query, user_token, &req.context);
163                        cache.insert(cache_key, response.clone()).await;
164                    }
165                }
166                Ok(response)
167            }
168            Err(e) => {
169                if self.config.mode == Mode::Production && e.is_fail_open_eligible() {
170                    debug!("AxonFlow unavailable, failing open: {}", e);
171                    Ok(ClientResponse::fail_open(e))
172                } else {
173                    Err(e)
174                }
175            }
176        }
177    }
178
179    // ============================================================================
180    // MCP Connector Management
181    // ============================================================================
182
183    pub async fn list_connectors(
184        &self,
185    ) -> Result<Vec<crate::types::agent::ConnectorMetadata>, AxonFlowError> {
186        let url = format!("{}/api/v1/connectors", self.config.endpoint);
187        let resp = self.checked_get(&url).await?;
188
189        let body: serde_json::Value = resp.json().await?;
190        let connectors = body["connectors"]
191            .as_array()
192            .ok_or_else(|| AxonFlowError::ApiError {
193                status: 200,
194                message: "response missing 'connectors' field".to_string(),
195            })?;
196
197        let result = serde_json::from_value(serde_json::Value::Array(connectors.clone()))?;
198        Ok(result)
199    }
200
201    pub async fn get_connector(
202        &self,
203        connector_id: &str,
204    ) -> Result<crate::types::agent::ConnectorMetadata, AxonFlowError> {
205        let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
206        let url = format!("{}/api/v1/connectors/{}", self.config.endpoint, encoded_id);
207        let resp = self.checked_get(&url).await?;
208        Ok(resp.json().await?)
209    }
210
211    pub async fn get_connector_health(
212        &self,
213        connector_id: &str,
214    ) -> Result<crate::types::agent::ConnectorHealthStatus, AxonFlowError> {
215        let encoded_id = utf8_percent_encode(connector_id, NON_ALPHANUMERIC);
216        let url = format!(
217            "{}/api/v1/connectors/{}/health",
218            self.config.endpoint, encoded_id
219        );
220        let resp = self.checked_get(&url).await?;
221        Ok(resp.json().await?)
222    }
223
224    pub async fn install_connector(
225        &self,
226        req: crate::types::agent::ConnectorInstallRequest,
227    ) -> Result<(), AxonFlowError> {
228        let encoded_id = utf8_percent_encode(&req.connector_id, NON_ALPHANUMERIC);
229        let url = format!(
230            "{}/api/v1/connectors/{}/install",
231            self.config.endpoint, encoded_id
232        );
233        let resp = self.http_client.post(&url).json(&req).send().await?;
234        Self::check_status(resp).await?;
235        Ok(())
236    }
237
238    pub async fn query_connector(
239        &self,
240        user_token: &str,
241        connector_name: &str,
242        query: &str,
243        params: HashMap<String, serde_json::Value>,
244    ) -> Result<crate::types::agent::ConnectorResponse, AxonFlowError> {
245        // Connector queries are dispatched through the agent's proxy endpoint
246        // with request_type=mcp-query — there is no standalone /api/v1/query.
247        // Mirror the Go SDK's QueryConnector contract.
248        let mut context = HashMap::new();
249        context.insert("connector".to_string(), serde_json::json!(connector_name));
250        context.insert("params".to_string(), serde_json::json!(params));
251
252        let resp = self
253            .proxy_llm_call(user_token, query, "mcp-query", context)
254            .await?;
255
256        Ok(crate::types::agent::ConnectorResponse {
257            success: resp.success,
258            data: resp.data.unwrap_or(serde_json::Value::Null),
259            error: resp.error,
260            meta: resp.metadata,
261            redacted: false,
262            redacted_fields: Vec::new(),
263            policy_info: None,
264        })
265    }
266
267    // ============================================================================
268    // Multi-Agent Planning (MAP)
269    // ============================================================================
270
271    pub async fn generate_plan(
272        &self,
273        query: &str,
274        domain: &str,
275        user_token: Option<&str>,
276    ) -> Result<crate::types::agent::PlanResponse, AxonFlowError> {
277        let mut context = HashMap::new();
278        context.insert("domain".to_string(), serde_json::json!(domain));
279        let user_token = user_token.unwrap_or("anonymous");
280
281        let resp = self
282            .proxy_llm_call(user_token, query, "generate-plan", context)
283            .await?;
284
285        if let Some(data) = resp.data {
286            let plan: crate::types::agent::PlanResponse = serde_json::from_value(data)?;
287            Ok(plan)
288        } else {
289            Err(AxonFlowError::ApiError {
290                status: 500,
291                message: "empty plan data".to_string(),
292            })
293        }
294    }
295
296    pub async fn execute_plan(
297        &self,
298        plan_id: &str,
299        user_token: Option<&str>,
300    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
301        let mut context = HashMap::new();
302        context.insert("plan_id".to_string(), serde_json::json!(plan_id));
303        let user_token = user_token.unwrap_or("anonymous");
304
305        let resp = self
306            .proxy_llm_call(user_token, "", "execute-plan", context)
307            .await?;
308
309        if let Some(data) = resp.data {
310            let exec: crate::types::agent::PlanExecutionResponse = serde_json::from_value(data)?;
311            Ok(exec)
312        } else {
313            Err(AxonFlowError::ApiError {
314                status: 500,
315                message: "empty execution data".to_string(),
316            })
317        }
318    }
319
320    pub async fn get_plan_status(
321        &self,
322        plan_id: &str,
323    ) -> Result<crate::types::agent::PlanExecutionResponse, AxonFlowError> {
324        let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
325        let url = format!("{}/api/v1/plan/{}", self.config.endpoint, encoded_id);
326        let resp = self.checked_map_get(&url).await?;
327        Ok(resp.json().await?)
328    }
329
330    pub async fn cancel_plan(
331        &self,
332        plan_id: &str,
333        reason: Option<&str>,
334    ) -> Result<crate::types::agent::CancelPlanResponse, AxonFlowError> {
335        let req_body = serde_json::json!({
336            "reason": reason.unwrap_or("user_cancelled"),
337        });
338
339        let encoded_id = utf8_percent_encode(plan_id, NON_ALPHANUMERIC);
340        let url = format!("{}/api/v1/plan/{}/cancel", self.config.endpoint, encoded_id);
341        let resp = self
342            .map_http_client
343            .post(&url)
344            .json(&req_body)
345            .send()
346            .await?;
347        let resp = Self::check_status(resp).await?;
348        Ok(resp.json().await?)
349    }
350
351    pub async fn audit_llm_call(
352        &self,
353        req: &crate::types::agent::AuditRequest,
354    ) -> Result<crate::types::agent::AuditResult, AxonFlowError> {
355        let client_id = self.get_effective_client_id();
356
357        let mut req_body = serde_json::json!({
358            "context_id": req.context_id,
359            "client_id": client_id,
360            "response_summary": req.response_summary,
361            "provider": req.provider,
362            "model": req.model,
363            "token_usage": {
364                "prompt_tokens": req.token_usage.prompt_tokens,
365                "completion_tokens": req.token_usage.completion_tokens,
366                "total_tokens": req.token_usage.total_tokens,
367            },
368            "latency_ms": req.latency_ms,
369        });
370
371        if let Some(meta) = &req.metadata {
372            req_body["metadata"] = serde_json::to_value(meta)?;
373        } else {
374            req_body["metadata"] = serde_json::json!({});
375        }
376
377        let url = format!("{}/api/audit/llm-call", self.config.endpoint);
378        let resp = self.http_client.post(&url).json(&req_body).send().await?;
379
380        let status = resp.status();
381        let body = resp.text().await?;
382
383        if status.is_success() {
384            let audit_resp: crate::types::agent::AuditResult = serde_json::from_str(&body)?;
385            Ok(audit_resp)
386        } else {
387            Err(AxonFlowError::ApiError {
388                status: status.as_u16(),
389                message: body,
390            })
391        }
392    }
393
394    // ============================================================================
395    // Private helpers
396    // ============================================================================
397
398    fn get_effective_client_id(&self) -> String {
399        self.config
400            .client_id
401            .clone()
402            .unwrap_or_else(|| "community".to_string())
403    }
404
405    fn build_cache_key(
406        &self,
407        request_type: &str,
408        query: &str,
409        user_token: &str,
410        context: &HashMap<String, serde_json::Value>,
411    ) -> String {
412        let context_hash = if context.is_empty() {
413            String::new()
414        } else {
415            let sorted: std::collections::BTreeMap<_, _> = context.iter().collect();
416            format!(":{}", serde_json::to_string(&sorted).unwrap_or_default())
417        };
418        format!("{}:{}:{}{}", request_type, query, user_token, context_hash)
419    }
420
421    async fn checked_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
422        let resp = self.http_client.get(url).send().await?;
423        Self::check_status(resp).await
424    }
425
426    async fn checked_map_get(&self, url: &str) -> Result<reqwest::Response, AxonFlowError> {
427        let resp = self.map_http_client.get(url).send().await?;
428        Self::check_status(resp).await
429    }
430
431    async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, AxonFlowError> {
432        if resp.status().is_success() {
433            Ok(resp)
434        } else {
435            let status = resp.status().as_u16();
436            let message = resp.text().await?;
437            Err(AxonFlowError::ApiError { status, message })
438        }
439    }
440
441    async fn execute_with_retry(
442        &self,
443        req: &ClientRequest,
444    ) -> Result<ClientResponse, AxonFlowError> {
445        let mut last_err = None;
446
447        for attempt in 0..self.config.retry.max_attempts {
448            if attempt > 0 {
449                let delay =
450                    self.config.retry.initial_delay.as_secs_f64() * 2f64.powi((attempt - 1) as i32);
451                tokio::time::sleep(Duration::from_secs_f64(delay)).await;
452            }
453
454            match self.execute_request(req).await {
455                Ok(resp) => return Ok(resp),
456                Err(e) => {
457                    if let AxonFlowError::ApiError { status, .. } = &e {
458                        if *status >= 400
459                            && *status < 500
460                            && *status != 429
461                            && *status != 402
462                            && *status != 403
463                        {
464                            return Err(e);
465                        }
466                    }
467                    last_err = Some(e);
468                }
469            }
470        }
471
472        Err(last_err.unwrap_or_else(|| {
473            AxonFlowError::ConfigError("retry loop completed with no attempts".to_string())
474        }))
475    }
476
477    async fn execute_request(&self, req: &ClientRequest) -> Result<ClientResponse, AxonFlowError> {
478        let url = format!("{}/api/request", self.config.endpoint);
479        let resp = self.http_client.post(&url).json(req).send().await?;
480
481        let status = resp.status();
482        let body = resp.text().await?;
483
484        if status.is_success() || status.as_u16() == 402 || status.as_u16() == 403 {
485            let client_resp: ClientResponse = serde_json::from_str(&body)?;
486            Ok(client_resp)
487        } else {
488            Err(AxonFlowError::ApiError {
489                status: status.as_u16(),
490                message: body,
491            })
492        }
493    }
494}