// maple_proxy/proxy.rs
use crate::config::{Config, OpenAIError};
use axum::{
    extract::State,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response, Sse},
    Json,
};
use futures::Stream;
use opensecret::{
    ChatCompletionChunk, ChatCompletionRequest, EmbeddingRequest, EmbeddingResponse,
    ModelsResponse, OpenSecretClient, Result as OpenSecretResult,
};
use std::{convert::Infallible, sync::Arc};
use tracing::{debug, error};

16#[derive(Clone)]
17pub struct ProxyState {
18    pub config: Config,
19}
20
21impl ProxyState {
22    pub fn new(config: Config) -> Self {
23        Self { config }
24    }
25}
26
27pub async fn health_check() -> impl IntoResponse {
28    Json(serde_json::json!({
29        "status": "ok",
30        "service": "maple-proxy",
31        "version": env!("CARGO_PKG_VERSION")
32    }))
33}
34
35fn extract_api_key(
36    headers: &HeaderMap,
37    default_key: &Option<String>,
38) -> Result<String, OpenAIError> {
39    // Try to get API key from Authorization header first
40    if let Some(auth_header) = headers.get("authorization") {
41        let auth_str = auth_header.to_str().map_err(|_| {
42            OpenAIError::authentication_error("Invalid Authorization header format")
43        })?;
44
45        if let Some(key) = auth_str.strip_prefix("Bearer ") {
46            return Ok(key.to_string());
47        }
48    }
49
50    // Fall back to default API key from config
51    default_key
52        .as_ref()
53        .cloned()
54        .ok_or_else(|| OpenAIError::authentication_error("No API key provided. Set MAPLE_API_KEY environment variable or provide Authorization header"))
55}
56
57async fn create_client_with_auth(
58    backend_url: &str,
59    api_key: &str,
60) -> Result<OpenSecretClient, OpenAIError> {
61    let client = OpenSecretClient::new_with_api_key(backend_url, api_key.to_string())
62        .map_err(|e| OpenAIError::server_error(format!("Failed to create client: {}", e)))?;
63
64    // Perform attestation handshake
65    client.perform_attestation_handshake().await.map_err(|e| {
66        error!("Attestation handshake failed: {}", e);
67        OpenAIError::server_error("Failed to establish secure connection with Maple backend")
68    })?;
69
70    Ok(client)
71}
72
73pub async fn list_models(
74    State(state): State<Arc<ProxyState>>,
75    headers: HeaderMap,
76) -> Result<Json<ModelsResponse>, (StatusCode, Json<OpenAIError>)> {
77    let api_key = extract_api_key(&headers, &state.config.default_api_key)
78        .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
79
80    debug!(
81        "Listing models for API key: {}...",
82        &api_key[..8.min(api_key.len())]
83    );
84
85    let client = create_client_with_auth(&state.config.backend_url, &api_key)
86        .await
87        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
88
89    let models = client.get_models().await.map_err(|e| {
90        error!("Failed to get models: {}", e);
91        (
92            StatusCode::INTERNAL_SERVER_ERROR,
93            Json(OpenAIError::server_error(format!(
94                "Failed to retrieve models: {}",
95                e
96            ))),
97        )
98    })?;
99
100    debug!("Successfully retrieved {} models", models.data.len());
101    Ok(Json(models))
102}
103
104pub async fn create_chat_completion(
105    State(state): State<Arc<ProxyState>>,
106    headers: HeaderMap,
107    Json(mut request): Json<ChatCompletionRequest>,
108) -> Result<Response, (StatusCode, Json<OpenAIError>)> {
109    let api_key = extract_api_key(&headers, &state.config.default_api_key)
110        .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
111
112    debug!(
113        "Chat completion request for model: {}, stream: {:?}",
114        request.model,
115        request.stream.unwrap_or(false)
116    );
117
118    let client = create_client_with_auth(&state.config.backend_url, &api_key)
119        .await
120        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
121
122    // Check if streaming is requested
123    if request.stream.unwrap_or(false) {
124        // Handle streaming response
125        let stream = client
126            .create_chat_completion_stream(request)
127            .await
128            .map_err(|e| {
129                error!("Failed to create streaming chat completion: {}", e);
130                (
131                    StatusCode::INTERNAL_SERVER_ERROR,
132                    Json(OpenAIError::server_error(format!(
133                        "Failed to create streaming completion: {}",
134                        e
135                    ))),
136                )
137            })?;
138
139        let sse_stream = create_sse_stream(stream);
140        Ok(Sse::new(sse_stream).into_response())
141    } else {
142        // Handle non-streaming response
143        request.stream = Some(false); // Ensure it's explicitly false
144
145        let response = client.create_chat_completion(request).await.map_err(|e| {
146            error!("Failed to create chat completion: {}", e);
147            (
148                StatusCode::INTERNAL_SERVER_ERROR,
149                Json(OpenAIError::server_error(format!(
150                    "Failed to create completion: {}",
151                    e
152                ))),
153            )
154        })?;
155
156        debug!("Successfully created chat completion: {}", response.id);
157        Ok(Json(response).into_response())
158    }
159}
160
161fn create_sse_stream(
162    mut stream: std::pin::Pin<Box<dyn Stream<Item = OpenSecretResult<ChatCompletionChunk>> + Send>>,
163) -> impl Stream<Item = Result<axum::response::sse::Event, Infallible>> {
164    async_stream::stream! {
165        use futures::StreamExt;
166
167        while let Some(chunk_result) = stream.next().await {
168            match chunk_result {
169                Ok(chunk) => {
170                    match serde_json::to_string(&chunk) {
171                        Ok(json) => {
172                            let event = axum::response::sse::Event::default()
173                                .data(json);
174                            yield Ok(event);
175                        }
176                        Err(e) => {
177                            error!("Failed to serialize chunk: {}", e);
178                            let error_event = axum::response::sse::Event::default()
179                                .data(format!(r#"{{"error": "Failed to serialize chunk: {}"}}"#, e));
180                            yield Ok(error_event);
181                            break;
182                        }
183                    }
184                }
185                Err(e) => {
186                    error!("Stream error: {}", e);
187                    let error_event = axum::response::sse::Event::default()
188                        .data(format!(r#"{{"error": "Stream error: {}"}}"#, e));
189                    yield Ok(error_event);
190                    break;
191                }
192            }
193        }
194
195        // Send [DONE] event to indicate end of stream
196        let done_event = axum::response::sse::Event::default()
197            .data("[DONE]");
198        yield Ok(done_event);
199    }
200}
201
202pub async fn create_embeddings(
203    State(state): State<Arc<ProxyState>>,
204    headers: HeaderMap,
205    Json(request): Json<EmbeddingRequest>,
206) -> Result<Json<EmbeddingResponse>, (StatusCode, Json<OpenAIError>)> {
207    let api_key = extract_api_key(&headers, &state.config.default_api_key)
208        .map_err(|e| (StatusCode::UNAUTHORIZED, Json(e)))?;
209
210    debug!("Embeddings request for model: {}", request.model);
211
212    let client = create_client_with_auth(&state.config.backend_url, &api_key)
213        .await
214        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(e)))?;
215
216    let response = client.create_embeddings(request).await.map_err(|e| {
217        error!("Failed to create embeddings: {}", e);
218        (
219            StatusCode::INTERNAL_SERVER_ERROR,
220            Json(OpenAIError::server_error(format!(
221                "Failed to create embeddings: {}",
222                e
223            ))),
224        )
225    })?;
226
227    debug!(
228        "Successfully created embeddings with {} vectors",
229        response.data.len()
230    );
231    Ok(Json(response))
232}