use std::time::Duration;
use axum::{
body::Bytes,
http::{HeaderMap, HeaderValue, Method, StatusCode},
response::{IntoResponse, Response},
};
use indexmap::IndexMap;
use objectiveai_sdk::client_objectiveai_mcp::server_request;
use crate::streaming_ws::{ReverseChannelRegistry, send_server_request};
/// How long `handle_request` waits for the remote side of the reverse channel
/// to answer a forwarded request before replying `504 Gateway Timeout`.
const REVERSE_CHANNEL_TIMEOUT: Duration = Duration::from_millis(30_000);
pub async fn handle_request(
session_id: String,
method: Method,
registry: ReverseChannelRegistry,
headers: HeaderMap,
body: Bytes,
) -> Response {
let rc = match registry.get(&session_id) {
Some(rc) => rc.clone(),
None => {
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("no reverse channel for session_id {session_id:?}"),
)
.into_response();
}
};
let forward_headers: IndexMap<String, String> = headers
.iter()
.filter_map(|(k, v)| Some((k.as_str().to_string(), v.to_str().ok()?.to_string())))
.collect();
let body_value: Option<serde_json::Value> = if body.is_empty() {
None
} else {
match serde_json::from_slice(&body) {
Ok(v) => Some(v),
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("body is not valid JSON: {e}"),
)
.into_response();
}
}
};
let request_id = uuid::Uuid::new_v4().to_string();
let request = server_request::Request {
id: request_id.clone(),
method: method.as_str().to_string(),
headers: forward_headers,
body: body_value,
};
let rx = match send_server_request(&rc.sink, &rc.pending, request).await {
Ok(rx) => {
rx
}
Err(()) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
"reverse channel closed before request could be sent",
)
.into_response();
}
};
let server_resp = match tokio::time::timeout(REVERSE_CHANNEL_TIMEOUT, rx).await {
Ok(Ok(r)) => {
r
}
Ok(Err(_)) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
"reverse channel dropped before response arrived",
)
.into_response();
}
Err(_) => {
return (
StatusCode::GATEWAY_TIMEOUT,
"reverse channel timed out waiting for response",
)
.into_response();
}
};
let status = StatusCode::from_u16(server_resp.status)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut builder = axum::response::Response::builder().status(status);
let mut has_content_type = false;
for (k, v) in &server_resp.headers {
if let Ok(value) = HeaderValue::from_str(v) {
if k.eq_ignore_ascii_case("content-type") {
has_content_type = true;
}
builder = builder.header(k, value);
}
}
if !has_content_type && server_resp.body.is_some() {
builder = builder.header("Content-Type", "application/json");
}
let body_bytes: Vec<u8> = match server_resp.body {
Some(v) => serde_json::to_vec(&v).unwrap_or_default(),
None => Vec::new(),
};
let resp = builder
.body(axum::body::Body::from(body_bytes))
.unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response());
resp
}