axonflow_sdk_rust/interceptors/openai.rs

//! OpenAI interceptor: wraps any OpenAI-style chat completion client with
//! AxonFlow policy enforcement and fire-and-forget audit logging.

use crate::client::AxonFlowClient;
use crate::error::AxonFlowError;
use crate::types::agent::{AuditRequest, TokenUsage};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// A single message in an OpenAI-style chat conversation.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

/// Request body for a chat completion call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
}

/// One candidate completion returned by the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub finish_reason: String,
}

/// Token accounting reported by the provider.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

/// Response body for a chat completion call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: Usage,
}

/// Abstraction over any client capable of OpenAI-style chat completions,
/// so the wrapper can intercept real and mock clients alike.
#[async_trait]
pub trait OpenAIChatCompleter: Send + Sync {
    async fn create_chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
}

/// Wraps an [`OpenAIChatCompleter`] so every call is policy-checked by
/// AxonFlow before dispatch and audited after completion.
pub struct WrappedOpenAIClient<C: OpenAIChatCompleter> {
    client: Arc<C>,
    axonflow: Arc<AxonFlowClient>,
    user_token: String,
}

impl<C: OpenAIChatCompleter> WrappedOpenAIClient<C> {
    pub fn new(client: C, axonflow: AxonFlowClient, user_token: impl Into<String>) -> Self {
        Self {
            client: Arc::new(client),
            axonflow: Arc::new(axonflow),
            user_token: user_token.into(),
        }
    }

    pub async fn create_chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        // Flatten the conversation into a single prompt string for policy evaluation.
        let prompt = req
            .messages
            .iter()
            .map(|m| format!("{}: {}", m.role, m.content))
            .collect::<Vec<_>>()
            .join("\n");

        // Evaluation context passed to the policy engine alongside the prompt.
        let mut eval_context = std::collections::HashMap::new();
        eval_context.insert("provider".to_string(), serde_json::json!("openai"));
        eval_context.insert("model".to_string(), serde_json::json!(req.model));
        if let Some(t) = req.temperature {
            eval_context.insert("temperature".to_string(), serde_json::json!(t));
        }
        if let Some(m) = req.max_tokens {
            eval_context.insert("max_tokens".to_string(), serde_json::json!(m));
        }

        // Policy check with AxonFlow before the provider is called. Note that
        // `start_time` is taken here, so the reported latency covers both the
        // policy check and the provider round trip.
        let start_time = std::time::Instant::now();
        let response = self
            .axonflow
            .proxy_llm_call(&self.user_token, &prompt, "llm_chat", eval_context)
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;

        if response.blocked {
            return Err(Box::new(AxonFlowError::ApiError {
                status: 403,
                message: response
                    .block_reason
                    .unwrap_or_else(|| "Blocked by policy".to_string()),
            }));
        }

        // The call was approved; dispatch it to the underlying client.
        let result = self.client.create_chat_completion(req.clone()).await?;

        // Audit asynchronously (fire-and-forget) so audit latency never blocks
        // the caller; audit failures are deliberately swallowed.
        let axonflow = Arc::clone(&self.axonflow);
        let result_clone = result.clone();
        let request_id = response.request_id.clone();
        let latency_ms = start_time.elapsed().as_millis() as i64;
        let model = req.model.clone();

        tokio::spawn(async move {
            if let Some(context_id) = request_id {
                // Summarize the first choice, truncated to 100 characters on a
                // char boundary so multi-byte UTF-8 content cannot panic the slice.
                let summary = result_clone
                    .choices
                    .first()
                    .map(|c| {
                        let content = &c.message.content;
                        match content.char_indices().nth(100) {
                            Some((idx, _)) => format!("{}...", &content[..idx]),
                            None => content.clone(),
                        }
                    })
                    .unwrap_or_default();

                let token_usage = TokenUsage {
                    prompt_tokens: result_clone.usage.prompt_tokens,
                    completion_tokens: result_clone.usage.completion_tokens,
                    total_tokens: result_clone.usage.total_tokens,
                };

                let audit_req = AuditRequest {
                    context_id,
                    response_summary: summary,
                    provider: "openai".to_string(),
                    model,
                    token_usage,
                    latency_ms,
                    metadata: None,
                };

                let _ = axonflow.audit_llm_call(&audit_req).await;
            }
        });

        Ok(result)
    }
}
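
// A minimal sketch of how the trait can be stubbed for tests: `StubCompleter`
// is hypothetical (not part of the SDK) and returns a canned response without
// touching a real API. Wiring it into `WrappedOpenAIClient` also requires an
// `AxonFlowClient`, whose constructor lives in crate::client, not in this file.
#[cfg(test)]
mod tests {
    use super::*;

    struct StubCompleter;

    #[async_trait]
    impl OpenAIChatCompleter for StubCompleter {
        async fn create_chat_completion(
            &self,
            req: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
            // Echo the requested model back in a fixed, well-formed response.
            Ok(ChatCompletionResponse {
                id: "chatcmpl-stub".to_string(),
                object: "chat.completion".to_string(),
                created: 0,
                model: req.model,
                choices: vec![ChatCompletionChoice {
                    index: 0,
                    message: ChatMessage {
                        role: "assistant".to_string(),
                        content: "stub reply".to_string(),
                    },
                    finish_reason: "stop".to_string(),
                }],
                usage: Usage {
                    prompt_tokens: 1,
                    completion_tokens: 2,
                    total_tokens: 3,
                },
            })
        }
    }

    // Usage sketch (assumes an `AxonFlowClient` constructed per the crate docs):
    //
    //     let wrapped = WrappedOpenAIClient::new(StubCompleter, axonflow, "user-token");
    //     let resp = wrapped.create_chat_completion(req).await?;
}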