axonflow_sdk_rust/interceptors/openai.rs

use crate::client::AxonFlowClient;
use crate::error::AxonFlowError;
use crate::types::agent::{AuditRequest, TokenUsage};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: Usage,
}
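
// The request/response types above mirror the OpenAI Chat Completions wire
// format, so a raw API payload deserializes directly via serde_json. A minimal
// sketch (the JSON literal is illustrative, not a captured response):
//
//     let raw = r#"{
//         "id": "chatcmpl-abc123",
//         "object": "chat.completion",
//         "created": 1700000000,
//         "model": "gpt-4o",
//         "choices": [{
//             "index": 0,
//             "message": {"role": "assistant", "content": "Hello!"},
//             "finish_reason": "stop"
//         }],
//         "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12}
//     }"#;
//     let resp: ChatCompletionResponse = serde_json::from_str(raw).unwrap();
//     assert_eq!(resp.choices[0].finish_reason, "stop");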

#[async_trait]
pub trait OpenAIChatCompleter: Send + Sync {
    async fn create_chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
}
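
// Any concrete OpenAI client can be adapted by implementing the trait above.
// A minimal sketch, assuming a reqwest-based HTTP transport; the struct name,
// field layout, and endpoint handling here are illustrative, not part of this
// SDK:
//
//     struct HttpOpenAIClient {
//         http: reqwest::Client,
//         api_key: String,
//     }
//
//     #[async_trait]
//     impl OpenAIChatCompleter for HttpOpenAIClient {
//         async fn create_chat_completion(
//             &self,
//             req: ChatCompletionRequest,
//         ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
//             let resp = self
//                 .http
//                 .post("https://api.openai.com/v1/chat/completions")
//                 .bearer_auth(&self.api_key)
//                 .json(&req)
//                 .send()
//                 .await?
//                 .error_for_status()?
//                 .json::<ChatCompletionResponse>()
//                 .await?;
//             Ok(resp)
//         }
//     }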

pub struct WrappedOpenAIClient<C: OpenAIChatCompleter> {
    client: Arc<C>,
    axonflow: Arc<AxonFlowClient>,
    user_token: String,
}

impl<C: OpenAIChatCompleter> WrappedOpenAIClient<C> {
    pub fn new(client: C, axonflow: AxonFlowClient, user_token: impl Into<String>) -> Self {
        Self {
            client: Arc::new(client),
            axonflow: Arc::new(axonflow),
            user_token: user_token.into(),
        }
    }

    pub async fn create_chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        // Flatten the conversation into a single prompt string for policy
        // evaluation.
        let prompt = req
            .messages
            .iter()
            .map(|m| format!("{}: {}", m.role, m.content))
            .collect::<Vec<_>>()
            .join("\n");

        // Request metadata for the policy engine's evaluation context.
        let mut eval_context = std::collections::HashMap::new();
        eval_context.insert("provider".to_string(), serde_json::json!("openai"));
        eval_context.insert("model".to_string(), serde_json::json!(req.model));
        if let Some(t) = req.temperature {
            eval_context.insert("temperature".to_string(), serde_json::json!(t));
        }
        if let Some(m) = req.max_tokens {
            eval_context.insert("max_tokens".to_string(), serde_json::json!(m));
        }

        let start_time = std::time::Instant::now();

        // Pre-flight check: ask AxonFlow whether this call is allowed.
        let response = self
            .axonflow
            .proxy_llm_call(&self.user_token, &prompt, "llm_chat", eval_context)
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;

        if response.blocked {
            return Err(Box::new(AxonFlowError::ApiError {
                status: 403,
                message: response
                    .block_reason
                    .unwrap_or_else(|| "Blocked by policy".to_string()),
            }));
        }

        // Policy allowed the call; forward it to the underlying client.
        let result = self.client.create_chat_completion(req.clone()).await?;

        // Audit in a background task so reporting adds no latency for the
        // caller.
        let axonflow = Arc::clone(&self.axonflow);
        let result_clone = result.clone();
        let request_id = response.request_id.clone();
        let latency_ms = start_time.elapsed().as_millis() as i64;
        let model = req.model.clone();

        tokio::spawn(async move {
            if let Some(context_id) = request_id {
                // Summarize the first choice, truncating to 100 characters on
                // a char boundary (byte slicing could split a multi-byte char).
                let summary = result_clone
                    .choices
                    .first()
                    .map(|c| {
                        let content = &c.message.content;
                        match content.char_indices().nth(100) {
                            Some((idx, _)) => format!("{}...", &content[..idx]),
                            None => content.clone(),
                        }
                    })
                    .unwrap_or_default();

                let token_usage = TokenUsage {
                    prompt_tokens: result_clone.usage.prompt_tokens,
                    completion_tokens: result_clone.usage.completion_tokens,
                    total_tokens: result_clone.usage.total_tokens,
                };

                let audit_req = AuditRequest {
                    context_id,
                    response_summary: summary,
                    provider: "openai".to_string(),
                    model,
                    token_usage,
                    latency_ms,
                    metadata: None,
                };

                // Best-effort: audit failures are swallowed, not surfaced.
                let _ = axonflow.audit_llm_call(&audit_req).await;
            }
        });

        Ok(result)
    }
}
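
// Example wiring with a canned completer, kept under #[cfg(test)]. The mock is
// self-contained; constructing a real `AxonFlowClient` depends on the SDK's
// constructor, so that step is only sketched in the trailing comment.
#[cfg(test)]
mod tests {
    use super::*;

    struct MockCompleter;

    #[async_trait]
    impl OpenAIChatCompleter for MockCompleter {
        async fn create_chat_completion(
            &self,
            req: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
            Ok(ChatCompletionResponse {
                id: "chatcmpl-mock".to_string(),
                object: "chat.completion".to_string(),
                created: 0,
                model: req.model,
                choices: vec![ChatCompletionChoice {
                    index: 0,
                    message: ChatMessage {
                        role: "assistant".to_string(),
                        content: "mock reply".to_string(),
                    },
                    finish_reason: "stop".to_string(),
                }],
                usage: Usage {
                    prompt_tokens: 1,
                    completion_tokens: 2,
                    total_tokens: 3,
                },
            })
        }
    }

    // Wiring sketch (assumes a configured AxonFlowClient, e.g. one pointing at
    // a test gateway):
    //
    //     let wrapped = WrappedOpenAIClient::new(MockCompleter, axonflow, "user-token");
    //     let resp = wrapped.create_chat_completion(req).await?;
}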