use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, watch, Mutex, RwLock, Semaphore};
use tokio::task::JoinHandle;
use tokio::time::Duration;
use crate::channel::Channel;
use crate::error_classify::classify_error;
use crate::message::{IncomingMessage, OutgoingMessage, ResponseMeta};
use crate::meta::meta;
use crate::GatewayInbox;
/// Capacity of the bounded mpsc queue that carries inbound messages from
/// channels to the gateway event loop.
const GATEWAY_BUFFER: usize = 1024;
/// Number of permits in the routing semaphore, i.e. how many messages may be
/// processed by the orchestrator at the same time.
const MAX_CONCURRENT_ROUTES: usize = 32;
/// Maximum time to wait for a channel task to finish after it has been told
/// to shut down.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Bookkeeping for one registered channel: the channel itself plus the
/// handles needed to stop its background task.
struct ChannelEntry {
/// The channel implementation; used to deliver outgoing messages.
channel: Arc<dyn Channel>,
/// Sender half of the per-channel shutdown flag. The matching receiver is
/// handed to `Channel::start` when the channel is registered.
shutdown_tx: watch::Sender<bool>,
/// Handle of the task returned by `Channel::start`.
task: JoinHandle<()>,
}
/// Routes messages between registered channels and the orchestrator.
///
/// Channels push inbound messages into a shared queue; the event loop pulls
/// them out, hands each one to the orchestrator and sends the response back
/// through the channel the message came from.
pub struct Gateway {
/// Registered channels keyed by channel name. Shared with the tasks spawned
/// for each routed message.
channels: Arc<RwLock<HashMap<String, ChannelEntry>>>,
/// Receiving half of the inbound queue; locked by the event loop while it
/// is running.
rx: Mutex<mpsc::Receiver<GatewayInbox>>,
/// Sending half of the inbound queue; a clone is given to every channel
/// when it is started.
tx: mpsc::Sender<GatewayInbox>,
/// Orchestrator that processes each incoming message.
orchestrator: Arc<oxios_kernel::Orchestrator>,
/// Gateway-wide shutdown flag (`true` once shutdown has been requested).
shutdown: watch::Sender<bool>,
/// Semaphore limiting how many messages are processed concurrently.
concurrency: Arc<Semaphore>,
}
impl Gateway {
pub fn new(orchestrator: Arc<oxios_kernel::Orchestrator>) -> Self {
let (tx, rx) = mpsc::channel(GATEWAY_BUFFER);
let (shutdown, _) = watch::channel(false);
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
rx: Mutex::new(rx),
tx,
orchestrator,
shutdown,
concurrency: Arc::new(Semaphore::new(MAX_CONCURRENT_ROUTES)),
}
}
pub fn signal_shutdown(&self) {
let _ = self.shutdown.send(true);
tracing::info!("Gateway shutdown signal sent");
}
pub fn is_shutdown(&self) -> bool {
*self.shutdown.borrow()
}
pub async fn register(&self, channel: Box<dyn Channel>) -> Result<()> {
let name = channel.name().to_owned();
let (ch_shutdown, ch_shutdown_rx) = watch::channel(false);
let ch_arc: Arc<dyn Channel> = Arc::from(channel);
let task = ch_arc.start(self.tx.clone(), ch_shutdown_rx).await?;
self.channels.write().await.insert(
name.clone(),
ChannelEntry {
channel: ch_arc,
shutdown_tx: ch_shutdown,
task,
},
);
tracing::info!(channel = %name, "Channel registered and started");
Ok(())
}
pub async fn unregister(&self, name: &str) -> Result<()> {
let entry = self.channels.write().await.remove(name);
if let Some(entry) = entry {
let _ = entry.shutdown_tx.send(true);
let _ = tokio::time::timeout(SHUTDOWN_TIMEOUT, entry.task).await;
tracing::info!(channel = %name, "Channel unregistered");
}
Ok(())
}
pub async fn channel_names(&self) -> Vec<String> {
self.channels.read().await.keys().cloned().collect()
}
pub async fn run(&self) -> Result<()> {
tracing::info!("Gateway event loop started");
let mut rx = self.rx.lock().await;
let mut shutdown = self.shutdown.subscribe();
loop {
tokio::select! {
inbox = rx.recv() => {
match inbox {
Some((channel_name, msg)) => {
self.dispatch(channel_name, msg);
}
None => {
tracing::info!("All channels disconnected, exiting");
break;
}
}
}
_ = shutdown.changed() => {
tracing::info!("Gateway shutting down");
let channels = self.channels.read().await;
for (name, entry) in channels.iter() {
let _ = entry.shutdown_tx.send(true);
tracing::info!(channel = %name, "Shutdown signal sent");
}
break;
}
}
}
Ok(())
}
fn dispatch(&self, channel_name: String, msg: IncomingMessage) {
let orchestrator = self.orchestrator.clone();
let channels = self.channels.clone();
let semaphore = self.concurrency.clone();
tokio::spawn(async move {
let _permit = match semaphore.acquire().await {
Ok(p) => p,
Err(_) => {
tracing::warn!("Semaphore closed, dropping message");
return;
}
};
tracing::info!(
channel = %msg.channel,
user = %msg.user_id,
content_len = msg.content.len(),
request_id = %msg.id,
"Routing incoming message"
);
let start = std::time::Instant::now();
let session_id = msg.metadata.get(meta::SESSION_ID).cloned();
let project_ids = msg.metadata.get(meta::PROJECT_IDS).cloned();
let request_id = msg.id.to_string();
let result = orchestrator
.handle_message(
&msg.user_id,
&msg.content,
session_id.as_deref(),
project_ids.as_deref(),
&request_id,
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
let guard = channels.read().await;
let entry = guard.get(&channel_name);
match (result, entry) {
(Ok(orchestration), Some(entry)) => {
tracing::info!(
phase = %orchestration.phase_reached,
seed_id = ?orchestration.seed_id,
duration_ms = duration_ms,
"Orchestration complete"
);
let mut channel_meta = HashMap::new();
if let Some(ref sid) = orchestration.session_id {
channel_meta.insert(meta::SESSION_ID.to_owned(), sid.clone());
}
if let Some(ref pid) = orchestration.primary_project_id {
channel_meta.insert(meta::PROJECT_IDS.to_owned(), pid.to_string());
}
let response_meta = ResponseMeta {
session_id: orchestration.session_id,
project_id: orchestration.primary_project_id.map(|u| u.to_string()),
project_tag: orchestration.project_tag,
seed_id: orchestration.seed_id.map(|u| u.to_string()),
phase: orchestration.phase_reached.to_string(),
evaluation_passed: orchestration.evaluation_passed,
duration_ms: Some(duration_ms),
error: None,
};
let outgoing = OutgoingMessage::success(
msg.id,
&msg.channel,
&msg.user_id,
&orchestration.response,
channel_meta,
response_meta,
);
if let Err(e) = entry.channel.send(outgoing).await {
tracing::error!(error = %e, "Failed to send response");
}
}
(Err(e), Some(entry)) => {
tracing::error!(error = %e, "Orchestration failed");
let user_err = classify_error(&e);
let mut outgoing =
OutgoingMessage::error(msg.id, &msg.channel, &msg.user_id, user_err);
if let Some(sid) = msg.metadata.get(meta::SESSION_ID).cloned() {
outgoing.metadata.insert(meta::SESSION_ID.to_string(), sid);
}
if let Err(e) = entry.channel.send(outgoing).await {
tracing::error!(error = %e, "Failed to send error response");
}
}
(_, None) => {
tracing::warn!(channel = %channel_name, "Channel no longer registered");
}
}
});
}
pub async fn send_to(&self, channel_name: &str, msg: OutgoingMessage) -> Result<()> {
let channels = self.channels.read().await;
if let Some(entry) = channels.get(channel_name) {
entry.channel.send(msg).await?;
} else {
tracing::warn!(channel = %channel_name, "No such channel registered");
}
Ok(())
}
}
impl std::fmt::Debug for Gateway {
    /// Formats the gateway as just its type name; no fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as an empty `debug_struct("Gateway")`, in both the
        // compact and the pretty (`{:#?}`) form.
        f.write_str("Gateway")
    }
}