pub mod content_type;
use crate::engine::interfaces::{
L7Terminator, ParamDef, ParamType, Plugin, ResolvedInputs, TerminatorResult,
};
use crate::layers::l7::{
container::{Container, PayloadState},
http::wrapper::VaneBody,
};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use bytes::Bytes;
use fancy_log::{LogLevel, log};
use http::{HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::Full;
use serde_json::Value;
use std::any::Any;
use std::borrow::Cow;
pub struct SendResponsePlugin;
impl Plugin for SendResponsePlugin {
fn name(&self) -> &'static str {
"internal.terminator.response"
}
fn params(&self) -> Vec<ParamDef> {
vec![
ParamDef {
name: "status".into(),
required: false,
param_type: ParamType::Integer,
},
ParamDef {
name: "headers".into(),
required: false,
param_type: ParamType::Map,
},
ParamDef {
name: "body".into(),
required: false,
param_type: ParamType::Any,
},
]
}
fn supported_protocols(&self) -> Vec<Cow<'static, str>> {
vec!["httpx".into()]
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_l7_terminator(&self) -> Option<&dyn L7Terminator> {
Some(self)
}
}
#[async_trait]
impl L7Terminator for SendResponsePlugin {
async fn execute_l7(
&self,
context: &mut (dyn Any + Send),
inputs: ResolvedInputs,
) -> Result<TerminatorResult> {
let container = context
.downcast_mut::<Container>()
.ok_or_else(|| anyhow!("Context is not a Container"))?;
if let (Some(client_upgrade), Some(upstream_upgrade)) = (
container
.http_data_mut()
.and_then(|d| d.client_upgrade.take()),
container
.http_data_mut()
.and_then(|d| d.upstream_upgrade.take()),
) {
log(
LogLevel::Debug,
"➜ Establishing WebSocket bidirectional tunnel...",
);
let mut response = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.body(())
.map_err(|e| anyhow!("Failed to build WebSocket 101 response: {e}"))?;
*response.headers_mut() = std::mem::take(&mut container.response_headers);
if let Some(tx) = container.response_tx.take() {
if tx.send(response).is_err() {
return Err(anyhow!("Failed to send WebSocket upgrade response"));
}
} else {
return Err(anyhow!("Response channel missing for WebSocket upgrade"));
}
tokio::spawn(async move {
log(LogLevel::Debug, "⚙ Waiting for upgrade to complete...");
let tunnel_result = tokio::try_join!(client_upgrade, upstream_upgrade);
match tunnel_result {
Ok((client_io, upstream_io)) => {
let mut client_io = hyper_util::rt::TokioIo::new(client_io);
let mut upstream_io = hyper_util::rt::TokioIo::new(upstream_io);
log(
LogLevel::Debug,
"✓ WebSocket tunnel established, starting bidirectional copy",
);
match tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await {
Ok((client_to_upstream, upstream_to_client)) => {
log(
LogLevel::Debug,
&format!(
"✓ WebSocket tunnel closed gracefully. Client→Upstream: {client_to_upstream} bytes, Upstream→Client: {upstream_to_client} bytes"
),
);
}
Err(e) => {
log(
LogLevel::Warn,
&format!("⚠ WebSocket tunnel I/O error: {e}"),
);
}
}
}
Err(e) => {
log(LogLevel::Error, &format!("✗ WebSocket upgrade failed: {e}"));
}
}
});
return Ok(TerminatorResult::Finished);
}
let status_code = if let Some(s) = inputs.get("status").and_then(Value::as_u64) {
StatusCode::from_u16(s as u16).unwrap_or(StatusCode::OK)
} else if let Some(s) = container
.kv
.get("res.status")
.and_then(|s| s.parse::<u16>().ok())
{
StatusCode::from_u16(s).unwrap_or(StatusCode::OK)
} else {
StatusCode::OK
};
let headers = &mut container.response_headers;
if let Some(headers_input) = inputs.get("headers").and_then(Value::as_object) {
headers.clear();
for (k, v) in headers_input {
let Ok(header_name) = HeaderName::from_bytes(k.as_bytes()) else {
continue;
};
match v {
Value::String(s) => {
if let Ok(val) = HeaderValue::from_str(s) {
headers.insert(header_name, val);
}
}
Value::Array(arr) => {
for item in arr {
if let Some(s) = item.as_str()
&& let Ok(val) = HeaderValue::from_str(s)
{
headers.append(header_name.clone(), val);
}
}
}
_ => {}
}
}
}
if let Some(body_input) = inputs.get("body") {
let body_bytes = parse_body_input(body_input)?;
if !headers.contains_key(http::header::CONTENT_TYPE) {
let mime = content_type::guess_mime(&body_bytes);
headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_str(mime).map_err(|e| anyhow!("Invalid mime type: {e}"))?,
);
}
let full_body = Full::new(body_bytes);
container.response_body = PayloadState::Http(VaneBody::Buffered(full_body));
} else if let Some(body_str) = container.kv.get("res.body") {
let body_bytes = Bytes::copy_from_slice(body_str.as_bytes());
if !headers.contains_key(http::header::CONTENT_TYPE) {
headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_static("text/plain; charset=utf-8"),
);
}
let full_body = Full::new(body_bytes);
container.response_body = PayloadState::Http(VaneBody::Buffered(full_body));
}
let mut response = Response::builder()
.status(status_code)
.body(())
.map_err(|e| anyhow!("Failed to build response: {e}"))?;
*response.headers_mut() = std::mem::take(headers);
if let Some(tx) = container.response_tx.take() {
let _ = tx.send(response);
} else {
log(
LogLevel::Warn,
"⚠ SendResponse called but response channel is missing.",
);
}
Ok(TerminatorResult::Finished)
}
}
/// Decodes the `body` input into raw bytes.
///
/// Accepted shapes:
/// - a JSON string, used verbatim as UTF-8 bytes;
/// - an object `{ "content": <string>, "encoding": <string>? }`, where `encoding`
///   is one of `"text"` / `"utf8"` (the default), `"base64"` (standard alphabet)
///   or `"hex"`.
///
/// # Errors
/// Fails when the value is neither a string nor an object, when `content` is absent
/// or not a string, when decoding fails, or when `encoding` is unrecognised.
fn parse_body_input(input: &Value) -> Result<Bytes> {
    let map = match input {
        Value::String(text) => return Ok(Bytes::copy_from_slice(text.as_bytes())),
        Value::Object(map) => map,
        _ => return Err(anyhow!("Invalid body format. Expected String or Object.")),
    };

    let Some(content) = map.get("content").and_then(Value::as_str) else {
        return Err(anyhow!("Structured body missing 'content' field"));
    };
    let encoding = map
        .get("encoding")
        .and_then(Value::as_str)
        .unwrap_or("text");

    let bytes = match encoding {
        "text" | "utf8" => Bytes::copy_from_slice(content.as_bytes()),
        "base64" => {
            use base64::prelude::*;
            BASE64_STANDARD
                .decode(content)
                .map(Bytes::from)
                .map_err(|e| anyhow!("Base64 decode failed: {e}"))?
        }
        "hex" => hex::decode(content)
            .map(Bytes::from)
            .map_err(|e| anyhow!("Hex decode failed: {e}"))?,
        other => return Err(anyhow!("Unknown encoding: {other}")),
    };
    Ok(bytes)
}