1use crate::config::{Config, OpenAIError};
2use axum::{
3 extract::State,
4 http::{HeaderMap, StatusCode},
5 response::{IntoResponse, Response, Sse},
6 Json,
7};
8use futures::Stream;
9use opensecret::{
10 ChatCompletionChunk, ChatCompletionRequest, EmbeddingRequest, EmbeddingResponse,
11 ModelsResponse, OpenSecretClient, Result as OpenSecretResult,
12};
13use std::{convert::Infallible, sync::Arc};
14use tracing::{debug, error};
15
16#[derive(Clone)]
17pub struct ProxyState {
18 pub config: Config,
19}
20
21impl ProxyState {
22 pub fn new(config: Config) -> Self {
23 Self { config }
24 }
25}
26
27pub async fn health_check() -> impl IntoResponse {
28 Json(serde_json::json!({
29 "status": "ok",
30 "service": "maple-proxy",
31 "version": env!("CARGO_PKG_VERSION")
32 }))
33}
34
35fn extract_api_key(
36 headers: &HeaderMap,
37 default_key: &Option<String>,
38) -> Result<String, OpenAIError> {
39 if let Some(auth_header) = headers.get("authorization") {
41 let auth_str = auth_header.to_str().map_err(|_| {
42 OpenAIError::authentication_error("Invalid Authorization header format")
43 })?;
44
45 if let Some(key) = auth_str.strip_prefix("Bearer ") {
46 return Ok(key.to_string());
47 }
48 }
49
50 default_key
52 .as_ref()
53 .cloned()
54 .ok_or_else(|| OpenAIError::authentication_error("No API key provided. Set MAPLE_API_KEY environment variable or provide Authorization header"))
55}
56
57async fn create_client_with_auth(
58 backend_url: &str,
59 api_key: &str,
60) -> Result<OpenSecretClient, OpenAIError> {
61 let client = OpenSecretClient::new_with_api_key(backend_url, api_key.to_string())
62 .map_err(|e| OpenAIError::server_error(format!("Failed to create client: {}", e)))?;
63
64 client.perform_attestation_handshake().await.map_err(|e| {
66 error!("Attestation handshake failed: {}", e);
67 OpenAIError::server_error("Failed to establish secure connection with Maple backend")
68 })?;
69
70 Ok(client)
71}
72
73pub async fn list_models(
74 State(state): State<Arc<ProxyState>>,
75 headers: HeaderMap,
76) -> Result<Json<ModelsResponse>, (StatusCode, Json<OpenAIError>)> {
77 let api_key = extract_api_key(&headers, &state.config.default_api_key)
78 .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
79
80 debug!(
81 "Listing models for API key: {}...",
82 &api_key[..8.min(api_key.len())]
83 );
84
85 let client = create_client_with_auth(&state.config.backend_url, &api_key)
86 .await
87 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
88
89 let models = client.get_models().await.map_err(|e| {
90 error!("Failed to get models: {}", e);
91 (
92 StatusCode::INTERNAL_SERVER_ERROR,
93 Json(OpenAIError::server_error(format!(
94 "Failed to retrieve models: {}",
95 e
96 ))),
97 )
98 })?;
99
100 debug!("Successfully retrieved {} models", models.data.len());
101 Ok(Json(models))
102}
103
104pub async fn create_chat_completion(
105 State(state): State<Arc<ProxyState>>,
106 headers: HeaderMap,
107 Json(mut request): Json<ChatCompletionRequest>,
108) -> Result<Response, (StatusCode, Json<OpenAIError>)> {
109 let api_key = extract_api_key(&headers, &state.config.default_api_key)
110 .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
111
112 debug!(
113 "Chat completion request for model: {}, stream: {:?}",
114 request.model,
115 request.stream.unwrap_or(false)
116 );
117
118 let client = create_client_with_auth(&state.config.backend_url, &api_key)
119 .await
120 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
121
122 if request.stream.unwrap_or(false) {
124 let stream = client
126 .create_chat_completion_stream(request)
127 .await
128 .map_err(|e| {
129 error!("Failed to create streaming chat completion: {}", e);
130 (
131 StatusCode::INTERNAL_SERVER_ERROR,
132 Json(OpenAIError::server_error(format!(
133 "Failed to create streaming completion: {}",
134 e
135 ))),
136 )
137 })?;
138
139 let sse_stream = create_sse_stream(stream);
140 Ok(Sse::new(sse_stream).into_response())
141 } else {
142 request.stream = Some(false); let response = client.create_chat_completion(request).await.map_err(|e| {
146 error!("Failed to create chat completion: {}", e);
147 (
148 StatusCode::INTERNAL_SERVER_ERROR,
149 Json(OpenAIError::server_error(format!(
150 "Failed to create completion: {}",
151 e
152 ))),
153 )
154 })?;
155
156 debug!("Successfully created chat completion: {}", response.id);
157 Ok(Json(response).into_response())
158 }
159}
160
161fn create_sse_stream(
162 mut stream: std::pin::Pin<Box<dyn Stream<Item = OpenSecretResult<ChatCompletionChunk>> + Send>>,
163) -> impl Stream<Item = Result<axum::response::sse::Event, Infallible>> {
164 async_stream::stream! {
165 use futures::StreamExt;
166
167 while let Some(chunk_result) = stream.next().await {
168 match chunk_result {
169 Ok(chunk) => {
170 match serde_json::to_string(&chunk) {
171 Ok(json) => {
172 let event = axum::response::sse::Event::default()
173 .data(json);
174 yield Ok(event);
175 }
176 Err(e) => {
177 error!("Failed to serialize chunk: {}", e);
178 let error_event = axum::response::sse::Event::default()
179 .data(format!(r#"{{"error": "Failed to serialize chunk: {}"}}"#, e));
180 yield Ok(error_event);
181 break;
182 }
183 }
184 }
185 Err(e) => {
186 error!("Stream error: {}", e);
187 let error_event = axum::response::sse::Event::default()
188 .data(format!(r#"{{"error": "Stream error: {}"}}"#, e));
189 yield Ok(error_event);
190 break;
191 }
192 }
193 }
194
195 let done_event = axum::response::sse::Event::default()
197 .data("[DONE]");
198 yield Ok(done_event);
199 }
200}
201
202pub async fn create_embeddings(
203 State(state): State<Arc<ProxyState>>,
204 headers: HeaderMap,
205 Json(request): Json<EmbeddingRequest>,
206) -> Result<Json<EmbeddingResponse>, (StatusCode, Json<OpenAIError>)> {
207 let api_key = extract_api_key(&headers, &state.config.default_api_key)
208 .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
209
210 debug!("Embeddings request for model: {}", request.model);
211
212 let client = create_client_with_auth(&state.config.backend_url, &api_key)
213 .await
214 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
215
216 let response = client.create_embeddings(request).await.map_err(|e| {
217 error!("Failed to create embeddings: {}", e);
218 (
219 StatusCode::INTERNAL_SERVER_ERROR,
220 Json(OpenAIError::server_error(format!(
221 "Failed to create embeddings: {}",
222 e
223 ))),
224 )
225 })?;
226
227 debug!(
228 "Successfully created embeddings with {} vectors",
229 response.data.len()
230 );
231 Ok(Json(response))
232}