use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::channels::{Channel, ChannelError, InboundMessage, OutboundError, OutboundMessage};
use crate::message::Message;
use crate::namespace::Namespace;
/// Settings for the gateway.
///
/// Only `api_key` is consumed by code in this file (see
/// `GatewayChannel::authenticate`); the remaining fields are read elsewhere.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Host or address string; defaults to `"0.0.0.0"`.
    pub host: String,
    /// Port number; defaults to `8080`.
    pub port: u16,
    /// Expected API key. `None` means no key is required.
    pub api_key: Option<String>,
    /// Upper bound on body size; defaults to `1024 * 1024`.
    // NOTE(review): unit is presumably bytes - confirm against the consumer.
    pub max_body_size: usize,
    /// CORS origin entries; defaults to the single entry `"*"`.
    pub cors_origins: Vec<String>,
}

impl Default for GatewayConfig {
    /// Builds the configuration used when nothing is overridden: host
    /// `"0.0.0.0"`, port `8080`, no API key, a body limit of `1024 * 1024`
    /// and a single `"*"` CORS origin.
    fn default() -> Self {
        // 1024 * 1024 written as a shift.
        let one_mebibyte: usize = 1 << 20;

        GatewayConfig {
            cors_origins: vec![String::from("*")],
            max_body_size: one_mebibyte,
            api_key: None,
            port: 8080,
            host: String::from("0.0.0.0"),
        }
    }
}
/// A chat request as deserialized by serde; consumed by
/// `GatewayChannel::submit_and_wait`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ChatRequest {
/// Text of the message; wrapped with `Message::user` when submitted.
pub message: String,
/// Optional namespace key. When absent, `submit_and_wait` substitutes a
/// freshly generated UUID v4 string.
pub namespace: Option<String>,
/// Arbitrary metadata forwarded with the inbound message. Falls back to an
/// empty map when the field is missing from the input.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
/// Serializable result of a completed chat request, as produced by
/// `GatewayChannel::submit_and_wait`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ChatResponse {
/// Content of the outbound message that answered the request.
pub message: String,
/// Namespace key the request ran under (the caller-supplied one, or the
/// UUID generated for it).
pub namespace: String,
/// Token counts copied from the run's total usage.
pub usage: ChatUsage,
}
/// Token counts attached to a response; used by both `ChatResponse` and
/// `WsMessage::Response`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChatUsage {
/// Input token count.
pub input_tokens: u32,
/// Output token count.
pub output_tokens: u32,
/// Total token count. `submit_and_wait` fills this from the usage value's
/// `total_tokens()` rather than computing it here.
pub total_tokens: u32,
}
/// Serializable error payload. It is only constructed by tests in this file;
/// the production call sites live elsewhere.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ErrorResponse {
/// Error message text.
pub error: String,
/// Error code string (the tests in this file use `"internal_error"`).
pub code: String,
}
/// JSON message envelope with a `"type"` discriminator.
///
/// Serde's internally tagged representation is used: every serialized object
/// carries a `"type"` field holding the variant name in snake_case (for
/// example `"text_delta"` or `"tool_approval_request"`), next to the
/// variant's own fields.
// NOTE(review): the name suggests these travel over a WebSocket; the
// transport itself is not part of this file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WsMessage {
    /// `"chat"`: a chat message with optional namespace, model and agent
    /// selectors.
    Chat {
        message: String,
        namespace: Option<String>,
        model: Option<String>,
        agent: Option<String>,
    },
    /// `"text_delta"`: a piece of text content.
    TextDelta { content: String },
    /// `"response"`: a complete response with its namespace and token usage.
    Response {
        message: String,
        namespace: String,
        usage: ChatUsage,
        /// Left out of the serialized object entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        agent: Option<String>,
    },
    /// `"error"`: an error description.
    Error { error: String },
    /// `"tool_approval_request"`: identifies a tool call by `call_id`, with
    /// the tool's name and its JSON arguments.
    ToolApprovalRequest {
        call_id: String,
        tool_name: String,
        arguments: serde_json::Value,
    },
    /// `"tool_approval_response"`: the decision for the call identified by
    /// `call_id`.
    ToolApprovalResponse { call_id: String, approved: bool },
    /// `"ping"`: carries no fields.
    Ping,
    /// `"pong"`: carries no fields.
    Pong,
}
/// Channel that turns request/response style calls into inbound messages and
/// routes the matching result back to the waiting caller.
pub struct GatewayChannel {
/// Gateway settings; `api_key` drives `authenticate`.
config: GatewayConfig,
/// Producer side of the inbound queue, fed by `submit_and_wait`.
inbound_tx: tokio::sync::mpsc::Sender<InboundMessage>,
/// Consumer side of the inbound queue. Kept behind an async mutex so that
/// `receive(&self)` can poll it through a shared reference.
inbound_rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<InboundMessage>>,
/// Pending requests keyed by request ID; each value is the one-shot sender
/// that resolves the corresponding `submit_and_wait` call.
response_map: Arc<tokio::sync::RwLock<HashMap<String, ResponseSender>>>,
}
/// One-shot sender stored per pending request; it carries either the outbound
/// message or the outbound error back to the caller waiting on that request.
type ResponseSender = tokio::sync::oneshot::Sender<Result<OutboundMessage, OutboundError>>;
impl GatewayChannel {
pub fn new(config: GatewayConfig) -> Self {
let (tx, rx) = tokio::sync::mpsc::channel(256);
Self {
config,
inbound_tx: tx,
inbound_rx: tokio::sync::Mutex::new(rx),
response_map: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub fn config(&self) -> &GatewayConfig {
&self.config
}
pub async fn submit_and_wait(
&self,
request: ChatRequest,
) -> Result<ChatResponse, GatewayError> {
let ns_key = request
.namespace
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let request_id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::oneshot::channel();
self.response_map
.write()
.await
.insert(request_id.clone(), tx);
let mut metadata = request.metadata;
metadata.insert("_gateway_request_id".into(), serde_json::json!(request_id));
let inbound = InboundMessage {
namespace: Namespace::parse(&ns_key),
message: Message::user(request.message),
metadata,
};
self.inbound_tx
.send(inbound)
.await
.map_err(|_| GatewayError::ChannelClosed)?;
let result = rx
.await
.map_err(|_| GatewayError::ResponseDropped)?;
match result {
Ok(outbound) => Ok(ChatResponse {
message: outbound.message.content.clone(),
namespace: ns_key,
usage: ChatUsage {
input_tokens: outbound.run_result.total_usage.input_tokens,
output_tokens: outbound.run_result.total_usage.output_tokens,
total_tokens: outbound.run_result.total_usage.total_tokens(),
},
}),
Err(err) => Err(GatewayError::Runtime(err.error)),
}
}
pub fn authenticate(&self, bearer_token: Option<&str>) -> bool {
match &self.config.api_key {
None => true,
Some(expected) => bearer_token == Some(expected.as_str()),
}
}
async fn route_response(
&self,
metadata: &HashMap<String, serde_json::Value>,
result: Result<OutboundMessage, OutboundError>,
) -> Result<(), GatewayError> {
let request_id = metadata
.get("_gateway_request_id")
.and_then(|v| v.as_str())
.ok_or_else(|| GatewayError::MissingRequestId)?;
let sender = self
.response_map
.write()
.await
.remove(request_id)
.ok_or_else(|| GatewayError::NoWaitingRequest(request_id.to_string()))?;
let _ = sender.send(result);
Ok(())
}
}
#[async_trait]
impl Channel for GatewayChannel {
    /// Waits for the next inbound message queued by `submit_and_wait`.
    ///
    /// The receiver lock is held while waiting. Returns whatever the
    /// underlying mpsc receiver yields, including `None`.
    async fn receive(&self) -> Option<InboundMessage> {
        let mut inbound = self.inbound_rx.lock().await;
        inbound.recv().await
    }

    /// Routes a successful response to the request named in its metadata.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Send`] with the routing error's message when
    /// the metadata lacks a request ID or no request is waiting for it.
    async fn send(&self, response: OutboundMessage) -> Result<(), ChannelError> {
        // Only the metadata has to outlive the move of `response`, so copy
        // that alone instead of cloning the whole message.
        let metadata = response.metadata.clone();
        self.route_response(&metadata, Ok(response))
            .await
            .map_err(|e| ChannelError::Send(e.to_string()))
    }

    /// Routes an error to the request named in its metadata.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Send`] with the routing error's message when
    /// the metadata lacks a request ID or no request is waiting for it.
    async fn send_error(&self, error: OutboundError) -> Result<(), ChannelError> {
        // Same as `send`: clone the metadata only, then move the error.
        let metadata = error.metadata.clone();
        self.route_response(&metadata, Err(error))
            .await
            .map_err(|e| ChannelError::Send(e.to_string()))
    }
}
/// Errors produced by `GatewayChannel`.
#[derive(Debug, thiserror::Error)]
pub enum GatewayError {
/// The inbound queue rejected a message in `submit_and_wait`.
#[error("gateway channel closed")]
ChannelClosed,
/// The one-shot sender for a pending request was dropped before a result
/// was delivered.
#[error("response sender was dropped")]
ResponseDropped,
/// An outbound error was routed back; holds that error's `error` text.
#[error("runtime error: {0}")]
Runtime(String),
/// Response metadata had no string value under `_gateway_request_id`.
#[error("missing gateway request ID in metadata")]
MissingRequestId,
/// No pending request is registered under the given ID.
#[error("no waiting request for ID: {0}")]
NoWaitingRequest(String),
/// Authentication failure.
// NOTE(review): not constructed anywhere in this file; presumably raised by
// the HTTP/WebSocket layer when `authenticate` returns false - confirm.
#[error("authentication failed")]
Unauthorized,
}
#[cfg(test)]
mod tests {
    //! Unit tests for the gateway's configuration defaults, serde shapes,
    //! authentication check and request submission path.

    use super::*;

    #[test]
    fn default_gateway_config() {
        let config = GatewayConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert!(config.api_key.is_none());
        assert_eq!(config.max_body_size, 1024 * 1024);
        assert_eq!(config.cors_origins, vec!["*"]);
    }

    #[test]
    fn chat_request_deserialization() {
        let json = r#"{"message": "Hello!", "namespace": "user-123"}"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.message, "Hello!");
        assert_eq!(req.namespace.as_deref(), Some("user-123"));
        assert!(req.metadata.is_empty());
    }

    #[test]
    fn chat_request_minimal() {
        let json = r#"{"message": "Hi"}"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.message, "Hi");
        assert!(req.namespace.is_none());
    }

    #[test]
    fn chat_response_serialization() {
        let resp = ChatResponse {
            message: "Hello!".into(),
            namespace: "ns-1".into(),
            usage: ChatUsage {
                input_tokens: 10,
                output_tokens: 5,
                total_tokens: 15,
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("Hello!"));
        assert!(json.contains("ns-1"));
    }

    #[test]
    fn ws_message_chat_serialization() {
        let msg = WsMessage::Chat {
            message: "Hello".into(),
            namespace: Some("test".into()),
            model: None,
            agent: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"chat\""));
        assert!(json.contains("Hello"));
    }

    #[test]
    fn ws_message_text_delta_serialization() {
        let msg = WsMessage::TextDelta {
            content: "chunk".into(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"text_delta\""));
    }

    #[test]
    fn authenticate_no_key_required() {
        let channel = GatewayChannel::new(GatewayConfig::default());
        assert!(channel.authenticate(None));
        assert!(channel.authenticate(Some("anything")));
    }

    #[test]
    fn authenticate_with_key() {
        let config = GatewayConfig {
            api_key: Some("secret-key".into()),
            ..Default::default()
        };
        let channel = GatewayChannel::new(config);
        assert!(channel.authenticate(Some("secret-key")));
        assert!(!channel.authenticate(Some("wrong-key")));
        assert!(!channel.authenticate(None));
    }

    /// Drives the submission path end to end with types available here: the
    /// request must reach the inbound queue tagged with a request ID, that ID
    /// must have a pending entry, and dropping the pending sender must
    /// surface as `ResponseDropped` to the submitter.
    #[tokio::test]
    async fn submit_delivers_inbound_and_reports_dropped_responder() {
        let channel = GatewayChannel::new(GatewayConfig::default());
        let request = ChatRequest {
            message: "test message".into(),
            namespace: Some("test-ns".into()),
            metadata: HashMap::new(),
        };

        let responder = async {
            let inbound = channel
                .receive()
                .await
                .expect("submitted request should arrive on the inbound queue");
            let request_id = inbound
                .metadata
                .get("_gateway_request_id")
                .and_then(|v| v.as_str())
                .expect("inbound metadata should carry the gateway request ID")
                .to_owned();

            let pending = channel.response_map.write().await.remove(&request_id);
            assert!(
                pending.is_some(),
                "a pending entry should exist for the submitted request"
            );
            // Dropping the sender without sending simulates a responder that
            // went away.
            drop(pending);
        };

        let (outcome, ()) = tokio::join!(channel.submit_and_wait(request), responder);
        assert!(matches!(outcome, Err(GatewayError::ResponseDropped)));
        assert!(channel.response_map.read().await.is_empty());
    }

    #[test]
    fn error_response_serialization() {
        let err = ErrorResponse {
            error: "something went wrong".into(),
            code: "internal_error".into(),
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("internal_error"));
    }

    #[test]
    fn ws_message_roundtrip() {
        let msg = WsMessage::Response {
            message: "Hi there".into(),
            namespace: "ns".into(),
            usage: ChatUsage {
                input_tokens: 5,
                output_tokens: 3,
                total_tokens: 8,
            },
            agent: Some("Atlas".into()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WsMessage = serde_json::from_str(&json).unwrap();
        match parsed {
            WsMessage::Response {
                message, namespace, ..
            } => {
                assert_eq!(message, "Hi there");
                assert_eq!(namespace, "ns");
            }
            _ => panic!("expected Response variant"),
        }
    }
}