pub mod auth;
pub mod log_layer;
pub mod openai_compat;
pub mod server;
pub mod sse;
pub mod types;
pub mod ws;
use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::agent::SessionManager;
use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate};
use crate::config::GatewayConfig;
use crate::db::Database;
use crate::error::ChannelError;
use crate::extensions::ExtensionManager;
use crate::orchestrator::job_manager::ContainerJobManager;
use crate::tools::ToolRegistry;
use crate::workspace::Workspace;
use self::log_layer::LogBroadcaster;
use self::server::GatewayState;
use self::sse::SseManager;
use self::types::SseEvent;
pub struct GatewayChannel {
config: GatewayConfig,
state: Arc<GatewayState>,
auth_token: String,
}
impl GatewayChannel {
    /// Create a gateway channel from `config`.
    ///
    /// If `config.auth_token` is `None`, a random 32-character alphanumeric
    /// token is generated; read it back with [`GatewayChannel::auth_token`].
    /// Optional subsystems start unset and are attached with the `with_*`
    /// builder methods.
    pub fn new(config: GatewayConfig) -> Self {
        let auth_token = config
            .auth_token
            .clone()
            .unwrap_or_else(Self::generate_auth_token);
        let state = Arc::new(GatewayState {
            msg_tx: tokio::sync::RwLock::new(None),
            sse: SseManager::new(),
            workspace: None,
            session_manager: None,
            log_broadcaster: None,
            extension_manager: None,
            tool_registry: None,
            store: None,
            job_manager: None,
            prompt_queue: None,
            user_id: config.user_id.clone(),
            shutdown_tx: tokio::sync::RwLock::new(None),
            ws_tracker: Some(Arc::new(ws::WsConnectionTracker::new())),
            llm_provider: None,
            chat_rate_limiter: Self::new_chat_rate_limiter(),
        });
        Self {
            config,
            state,
            auth_token,
        }
    }

    /// Generate a random 32-character alphanumeric auth token.
    fn generate_auth_token() -> String {
        use rand::Rng;
        rand::thread_rng()
            .sample_iter(&rand::distributions::Alphanumeric)
            .take(32)
            .map(char::from)
            .collect()
    }

    /// Build the chat rate limiter with the gateway's default limits.
    ///
    /// `new` and `rebuild_state` both use this, so the limits are defined once.
    fn new_chat_rate_limiter() -> server::RateLimiter {
        // NOTE(review): argument semantics are defined by `server::RateLimiter::new`
        // (presumably 30 requests per 60 seconds) — confirm there.
        server::RateLimiter::new(30, 60)
    }

    /// Apply `mutate` to the gateway state.
    ///
    /// When this channel holds the only reference to the state (the normal
    /// case while the builder chain runs), the state is mutated in place, so
    /// `msg_tx`, `sse`, `shutdown_tx` and `chat_rate_limiter` are kept rather
    /// than reallocated.
    ///
    /// If the `Arc` is shared (for example, a caller cloned it via
    /// [`GatewayChannel::state`]), a fresh state is built instead. It copies the
    /// shareable fields and resets `msg_tx`, `sse`, `shutdown_tx` and the rate
    /// limiter. Other holders of the old `Arc` do not observe the change.
    fn rebuild_state(&mut self, mutate: impl FnOnce(&mut GatewayState)) {
        if let Some(state) = Arc::get_mut(&mut self.state) {
            mutate(state);
            return;
        }
        let mut new_state = GatewayState {
            msg_tx: tokio::sync::RwLock::new(None),
            sse: SseManager::new(),
            workspace: self.state.workspace.clone(),
            session_manager: self.state.session_manager.clone(),
            log_broadcaster: self.state.log_broadcaster.clone(),
            extension_manager: self.state.extension_manager.clone(),
            tool_registry: self.state.tool_registry.clone(),
            store: self.state.store.clone(),
            job_manager: self.state.job_manager.clone(),
            prompt_queue: self.state.prompt_queue.clone(),
            user_id: self.state.user_id.clone(),
            shutdown_tx: tokio::sync::RwLock::new(None),
            ws_tracker: self.state.ws_tracker.clone(),
            llm_provider: self.state.llm_provider.clone(),
            chat_rate_limiter: Self::new_chat_rate_limiter(),
        };
        mutate(&mut new_state);
        self.state = Arc::new(new_state);
    }

    /// Attach the workspace.
    pub fn with_workspace(mut self, workspace: Arc<Workspace>) -> Self {
        self.rebuild_state(|s| s.workspace = Some(workspace));
        self
    }

    /// Attach the session manager.
    pub fn with_session_manager(mut self, sm: Arc<SessionManager>) -> Self {
        self.rebuild_state(|s| s.session_manager = Some(sm));
        self
    }

    /// Attach the log broadcaster.
    pub fn with_log_broadcaster(mut self, lb: Arc<LogBroadcaster>) -> Self {
        self.rebuild_state(|s| s.log_broadcaster = Some(lb));
        self
    }

    /// Attach the extension manager.
    pub fn with_extension_manager(mut self, em: Arc<ExtensionManager>) -> Self {
        self.rebuild_state(|s| s.extension_manager = Some(em));
        self
    }

    /// Attach the tool registry.
    pub fn with_tool_registry(mut self, tr: Arc<ToolRegistry>) -> Self {
        self.rebuild_state(|s| s.tool_registry = Some(tr));
        self
    }

    /// Attach the database store.
    pub fn with_store(mut self, store: Arc<dyn Database>) -> Self {
        self.rebuild_state(|s| s.store = Some(store));
        self
    }

    /// Attach the container job manager.
    pub fn with_job_manager(mut self, jm: Arc<ContainerJobManager>) -> Self {
        self.rebuild_state(|s| s.job_manager = Some(jm));
        self
    }

    /// Attach the per-job queue of pending prompts, keyed by `Uuid`.
    pub fn with_prompt_queue(
        mut self,
        pq: Arc<
            tokio::sync::Mutex<
                std::collections::HashMap<
                    uuid::Uuid,
                    std::collections::VecDeque<crate::orchestrator::api::PendingPrompt>,
                >,
            >,
        >,
    ) -> Self {
        self.rebuild_state(|s| s.prompt_queue = Some(pq));
        self
    }

    /// Attach the LLM provider.
    pub fn with_llm_provider(mut self, llm: Arc<dyn crate::llm::LlmProvider>) -> Self {
        self.rebuild_state(|s| s.llm_provider = Some(llm));
        self
    }

    /// Return the auth token passed to the HTTP server, whether configured or generated.
    pub fn auth_token(&self) -> &str {
        &self.auth_token
    }

    /// Return the shared gateway state.
    pub fn state(&self) -> &Arc<GatewayState> {
        &self.state
    }
}
#[async_trait]
impl Channel for GatewayChannel {
    fn name(&self) -> &str {
        "gateway"
    }

    /// Start the HTTP server and return the stream of incoming messages.
    ///
    /// The bind address is validated before any shared state is touched. If
    /// the server fails to start, the previous `msg_tx` value is restored, so
    /// `health_check` does not report a gateway that never came up.
    async fn start(&self) -> Result<MessageStream, ChannelError> {
        let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
            .parse()
            .map_err(|e| ChannelError::StartupFailed {
                name: "gateway".to_string(),
                reason: format!(
                    "Invalid address '{}:{}': {}",
                    self.config.host, self.config.port, e
                ),
            })?;

        let (tx, rx) = mpsc::channel(256);
        // Publish the sender before the server starts so early requests can be
        // forwarded; keep whatever was there for rollback on failure.
        let previous_tx = self.state.msg_tx.write().await.replace(tx);

        let started =
            server::start_server(addr, Arc::clone(&self.state), self.auth_token.clone()).await;
        if started.is_err() {
            // `rx` is about to be dropped; don't leave a dead sender published.
            *self.state.msg_tx.write().await = previous_tx;
        }
        started?;

        Ok(Box::pin(ReceiverStream::new(rx)))
    }

    /// Broadcast a response to SSE clients.
    ///
    /// A missing `thread_id` on `msg` is sent as an empty string.
    async fn respond(
        &self,
        msg: &IncomingMessage,
        response: OutgoingResponse,
    ) -> Result<(), ChannelError> {
        let thread_id = msg.thread_id.clone().unwrap_or_default();
        self.state.sse.broadcast(SseEvent::Response {
            content: response.content,
            thread_id,
        });
        Ok(())
    }

    /// Translate a status update into an SSE event and broadcast it.
    ///
    /// `metadata["thread_id"]`, when it is a string, is attached to events
    /// that carry a thread id.
    async fn send_status(
        &self,
        status: StatusUpdate,
        metadata: &serde_json::Value,
    ) -> Result<(), ChannelError> {
        // Only one match arm runs, so each arm can take `thread_id` by move.
        let thread_id = metadata
            .get("thread_id")
            .and_then(|v| v.as_str())
            .map(String::from);
        let event = match status {
            StatusUpdate::Thinking(msg) => SseEvent::Thinking {
                message: msg,
                thread_id,
            },
            StatusUpdate::ToolStarted { name } => SseEvent::ToolStarted { name, thread_id },
            StatusUpdate::ToolCompleted { name, success } => SseEvent::ToolCompleted {
                name,
                success,
                thread_id,
            },
            StatusUpdate::ToolResult { name, preview } => SseEvent::ToolResult {
                name,
                preview,
                thread_id,
            },
            StatusUpdate::StreamChunk(content) => SseEvent::StreamChunk { content, thread_id },
            StatusUpdate::Status(msg) => SseEvent::Status {
                message: msg,
                thread_id,
            },
            StatusUpdate::JobStarted {
                job_id,
                title,
                browse_url,
            } => SseEvent::JobStarted {
                job_id,
                title,
                browse_url,
            },
            StatusUpdate::ApprovalNeeded {
                request_id,
                tool_name,
                description,
                parameters,
            } => SseEvent::ApprovalNeeded {
                request_id,
                tool_name,
                description,
                // Pretty-printed for display; fall back to the compact form on error.
                parameters: serde_json::to_string_pretty(&parameters)
                    .unwrap_or_else(|_| parameters.to_string()),
            },
            StatusUpdate::AuthRequired {
                extension_name,
                instructions,
                auth_url,
                setup_url,
            } => SseEvent::AuthRequired {
                extension_name,
                instructions,
                auth_url,
                setup_url,
            },
            StatusUpdate::AuthCompleted {
                extension_name,
                success,
                message,
            } => SseEvent::AuthCompleted {
                extension_name,
                success,
                message,
            },
        };
        self.state.sse.broadcast(event);
        Ok(())
    }

    /// Broadcast a response to all SSE clients with an empty thread id.
    ///
    /// `_user_id` is ignored.
    async fn broadcast(
        &self,
        _user_id: &str,
        response: OutgoingResponse,
    ) -> Result<(), ChannelError> {
        self.state.sse.broadcast(SseEvent::Response {
            content: response.content,
            thread_id: String::new(),
        });
        Ok(())
    }

    /// Report healthy while a message sender is published (between a
    /// successful `start` and `shutdown`).
    async fn health_check(&self) -> Result<(), ChannelError> {
        if self.state.msg_tx.read().await.is_some() {
            Ok(())
        } else {
            Err(ChannelError::HealthCheckFailed {
                name: "gateway".to_string(),
            })
        }
    }

    /// Signal the server's shutdown sender (if registered) and unpublish the
    /// message sender.
    async fn shutdown(&self) -> Result<(), ChannelError> {
        if let Some(tx) = self.state.shutdown_tx.write().await.take() {
            // The receiver may already be gone; shutdown is best-effort.
            let _ = tx.send(());
        }
        *self.state.msg_tx.write().await = None;
        Ok(())
    }
}