pub mod auth;
pub(crate) mod handlers;
pub mod log_layer;
pub mod openai_compat;
pub mod responses_api;
pub mod server;
pub mod sse;
pub mod types;
pub(crate) mod util;
pub mod ws;
pub mod test_helpers;
#[cfg(test)]
mod tests;
use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::agent::SessionManager;
use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate};
use crate::config::GatewayConfig;
use crate::db::Database;
use crate::error::ChannelError;
use crate::extensions::ExtensionManager;
use crate::orchestrator::job_manager::ContainerJobManager;
use crate::skills::catalog::SkillCatalog;
use crate::skills::registry::SkillRegistry;
use crate::tools::ToolRegistry;
use crate::workspace::Workspace;
use self::log_layer::{LogBroadcaster, LogLevelHandle};
use self::auth::{CombinedAuthState, DbAuthenticator, MultiAuthState};
use self::server::GatewayState;
use self::sse::SseManager;
use self::types::AppEvent;
/// The `"gateway"` channel: runs the gateway HTTP server and delivers agent
/// responses and status updates to connected clients through the SSE manager
/// held in the shared [`GatewayState`].
pub struct GatewayChannel {
    /// Gateway configuration; `host`/`port` are used by `start()` to bind the
    /// server, `auth_token`/`oidc` are read by `new()` to build `auth`.
    config: GatewayConfig,
    /// Shared state handed to the HTTP server on `start()`; the `with_*`
    /// builders replace this `Arc` with a rebuilt copy.
    state: Arc<GatewayState>,
    /// Combined authentication state (env token, optional DB authenticator,
    /// optional OIDC), cloned into the server on `start()`.
    auth: CombinedAuthState,
}
impl GatewayChannel {
pub fn new(config: GatewayConfig, owner_id: String) -> Self {
let auth_token = config.auth_token.clone().unwrap_or_else(|| {
use rand::RngCore;
use rand::rngs::OsRng;
let mut bytes = [0u8; 32];
OsRng.fill_bytes(&mut bytes);
bytes.iter().map(|b| format!("{b:02x}")).collect()
});
let oidc_state = config.oidc.as_ref().and_then(|oidc_config| {
match auth::OidcState::from_config(oidc_config) {
Ok(state) => {
tracing::info!(
header = %oidc_config.header,
jwks_url = %oidc_config.jwks_url,
"OIDC JWT authentication enabled"
);
Some(state)
}
Err(e) => {
tracing::error!(error = %e, "Failed to initialize OIDC auth — falling back to token-only auth");
None
}
}
});
let auth = CombinedAuthState {
env_auth: MultiAuthState::single(auth_token, owner_id.clone()),
db_auth: None,
oidc: oidc_state,
};
let state = Arc::new(GatewayState {
msg_tx: tokio::sync::RwLock::new(None),
sse: Arc::new(SseManager::new()),
workspace: None,
workspace_pool: None,
session_manager: None,
log_broadcaster: None,
log_level_handle: None,
extension_manager: None,
tool_registry: None,
store: None,
job_manager: None,
prompt_queue: None,
scheduler: None,
owner_id,
shutdown_tx: tokio::sync::RwLock::new(None),
ws_tracker: Some(Arc::new(ws::WsConnectionTracker::new())),
llm_provider: None,
skill_registry: None,
skill_catalog: None,
chat_rate_limiter: server::PerUserRateLimiter::new(30, 60),
oauth_rate_limiter: server::RateLimiter::new(10, 60),
webhook_rate_limiter: server::RateLimiter::new(10, 60),
registry_entries: Vec::new(),
cost_guard: None,
routine_engine: Arc::new(tokio::sync::RwLock::new(None)),
startup_time: std::time::Instant::now(),
active_config: server::ActiveConfigSnapshot::default(),
secrets_store: None,
db_auth: None,
});
Self {
config,
state,
auth,
}
}
fn rebuild_state(&mut self, mutate: impl FnOnce(&mut GatewayState)) {
let mut new_state = GatewayState {
msg_tx: tokio::sync::RwLock::new(None),
sse: Arc::new(SseManager::from_sender(self.state.sse.sender())),
workspace: self.state.workspace.clone(),
workspace_pool: self.state.workspace_pool.clone(),
session_manager: self.state.session_manager.clone(),
log_broadcaster: self.state.log_broadcaster.clone(),
log_level_handle: self.state.log_level_handle.clone(),
extension_manager: self.state.extension_manager.clone(),
tool_registry: self.state.tool_registry.clone(),
store: self.state.store.clone(),
job_manager: self.state.job_manager.clone(),
prompt_queue: self.state.prompt_queue.clone(),
scheduler: self.state.scheduler.clone(),
owner_id: self.state.owner_id.clone(),
shutdown_tx: tokio::sync::RwLock::new(None),
ws_tracker: self.state.ws_tracker.clone(),
llm_provider: self.state.llm_provider.clone(),
skill_registry: self.state.skill_registry.clone(),
skill_catalog: self.state.skill_catalog.clone(),
chat_rate_limiter: server::PerUserRateLimiter::new(30, 60),
oauth_rate_limiter: server::RateLimiter::new(10, 60),
webhook_rate_limiter: server::RateLimiter::new(10, 60),
registry_entries: self.state.registry_entries.clone(),
cost_guard: self.state.cost_guard.clone(),
routine_engine: Arc::clone(&self.state.routine_engine),
startup_time: self.state.startup_time,
active_config: self.state.active_config.clone(),
secrets_store: self.state.secrets_store.clone(),
db_auth: self.state.db_auth.clone(),
};
mutate(&mut new_state);
self.state = Arc::new(new_state);
}
pub fn with_workspace(mut self, workspace: Arc<Workspace>) -> Self {
self.rebuild_state(|s| s.workspace = Some(workspace));
self
}
pub fn with_session_manager(mut self, sm: Arc<SessionManager>) -> Self {
self.rebuild_state(|s| s.session_manager = Some(sm));
self
}
pub fn with_log_broadcaster(mut self, lb: Arc<LogBroadcaster>) -> Self {
self.rebuild_state(|s| s.log_broadcaster = Some(lb));
self
}
pub fn with_log_level_handle(mut self, h: Arc<LogLevelHandle>) -> Self {
self.rebuild_state(|s| s.log_level_handle = Some(h));
self
}
pub fn with_extension_manager(mut self, em: Arc<ExtensionManager>) -> Self {
self.rebuild_state(|s| s.extension_manager = Some(em));
self
}
pub fn with_tool_registry(mut self, tr: Arc<ToolRegistry>) -> Self {
self.rebuild_state(|s| s.tool_registry = Some(tr));
self
}
pub fn with_store(mut self, store: Arc<dyn Database>) -> Self {
self.rebuild_state(|s| s.store = Some(store));
self
}
pub fn with_db_auth(mut self, store: Arc<dyn Database>) -> Self {
let authenticator = DbAuthenticator::new(store);
self.rebuild_state(|s| s.db_auth = Some(Arc::new(authenticator.clone())));
self.auth.db_auth = Some(authenticator);
self
}
pub fn with_job_manager(mut self, jm: Arc<ContainerJobManager>) -> Self {
self.rebuild_state(|s| s.job_manager = Some(jm));
self
}
pub fn with_prompt_queue(
mut self,
pq: Arc<
tokio::sync::Mutex<
std::collections::HashMap<
uuid::Uuid,
std::collections::VecDeque<crate::orchestrator::api::PendingPrompt>,
>,
>,
>,
) -> Self {
self.rebuild_state(|s| s.prompt_queue = Some(pq));
self
}
pub fn with_scheduler(mut self, slot: crate::tools::builtin::SchedulerSlot) -> Self {
self.rebuild_state(|s| s.scheduler = Some(slot));
self
}
pub fn with_skill_registry(mut self, sr: Arc<std::sync::RwLock<SkillRegistry>>) -> Self {
self.rebuild_state(|s| s.skill_registry = Some(sr));
self
}
pub fn with_skill_catalog(mut self, sc: Arc<SkillCatalog>) -> Self {
self.rebuild_state(|s| s.skill_catalog = Some(sc));
self
}
pub fn with_llm_provider(mut self, llm: Arc<dyn crate::llm::LlmProvider>) -> Self {
self.rebuild_state(|s| s.llm_provider = Some(llm));
self
}
pub fn with_registry_entries(mut self, entries: Vec<crate::extensions::RegistryEntry>) -> Self {
self.rebuild_state(|s| s.registry_entries = entries);
self
}
pub fn with_cost_guard(mut self, cg: Arc<crate::agent::cost_guard::CostGuard>) -> Self {
self.rebuild_state(|s| s.cost_guard = Some(cg));
self
}
pub fn with_routine_engine_slot(mut self, slot: server::RoutineEngineSlot) -> Self {
self.rebuild_state(|s| s.routine_engine = slot);
self
}
pub fn with_active_config(mut self, config: server::ActiveConfigSnapshot) -> Self {
self.rebuild_state(|s| s.active_config = config);
self
}
pub fn with_secrets_store(
mut self,
store: Arc<dyn crate::secrets::SecretsStore + Send + Sync>,
) -> Self {
self.rebuild_state(|s| s.secrets_store = Some(store));
self
}
pub fn with_workspace_pool(mut self, pool: Arc<server::WorkspacePool>) -> Self {
self.rebuild_state(|s| s.workspace_pool = Some(pool));
self
}
pub fn auth_token(&self) -> &str {
self.auth.env_auth.first_token().unwrap_or("")
}
pub fn state(&self) -> &Arc<GatewayState> {
&self.state
}
}
#[async_trait]
impl Channel for GatewayChannel {
    fn name(&self) -> &str {
        "gateway"
    }

    /// Starts the HTTP gateway and returns the stream of incoming messages.
    ///
    /// The bind address is validated before any state is modified. If the
    /// server then fails to start, the message sender is cleared again, so a
    /// failed `start()` never leaves `health_check()` reporting healthy.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::StartupFailed` when `host:port` is not a valid
    /// socket address. Otherwise returns whatever error `server::start_server`
    /// reports.
    async fn start(&self) -> Result<MessageStream, ChannelError> {
        let addr: SocketAddr = format_bind_addr(&self.config.host, &self.config.port)
            .parse()
            .map_err(|e| ChannelError::StartupFailed {
                name: "gateway".to_string(),
                reason: format!(
                    "Invalid address '{}:{}': {}",
                    self.config.host, self.config.port, e
                ),
            })?;

        let (tx, rx) = mpsc::channel(256);
        // Install the sender before the server starts, so requests arriving
        // right after the bind can be forwarded.
        *self.state.msg_tx.write().await = Some(tx);

        let started =
            server::start_server(addr, Arc::clone(&self.state), self.auth.clone()).await;
        if started.is_err() {
            // Roll back so health_check() does not report a server that never started.
            *self.state.msg_tx.write().await = None;
        }
        started?;

        Ok(Box::pin(ReceiverStream::new(rx)))
    }

    /// Sends `response` to `msg.user_id` as an `AppEvent::Response` over SSE.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::MissingRoutingTarget` if the incoming message has
    /// no `thread_id`.
    async fn respond(
        &self,
        msg: &IncomingMessage,
        response: OutgoingResponse,
    ) -> Result<(), ChannelError> {
        let thread_id = msg
            .thread_id
            .clone()
            .ok_or_else(|| ChannelError::MissingRoutingTarget {
                name: "gateway".to_string(),
                reason: "respond() requires a thread_id on the incoming message".to_string(),
            })?;
        self.state.sse.broadcast_for_user(
            &msg.user_id,
            AppEvent::Response {
                content: response.content,
                thread_id,
            },
        );
        Ok(())
    }

    /// Converts a `StatusUpdate` into the matching `AppEvent` and pushes it over SSE.
    ///
    /// `metadata["thread_id"]` (if a string) is attached to events that carry a
    /// thread id. If `metadata["user_id"]` is a string, the event goes only to
    /// that user. Otherwise it is sent through `SseManager::broadcast`, which
    /// the debug message describes as a global broadcast.
    async fn send_status(
        &self,
        status: StatusUpdate,
        metadata: &serde_json::Value,
    ) -> Result<(), ChannelError> {
        let thread_id = metadata
            .get("thread_id")
            .and_then(|v| v.as_str())
            .map(String::from);
        // Exactly one arm runs, so `thread_id` can be moved without cloning.
        let event = match status {
            StatusUpdate::Thinking(msg) => AppEvent::Thinking {
                message: msg,
                thread_id,
            },
            StatusUpdate::ToolStarted { name } => AppEvent::ToolStarted { name, thread_id },
            StatusUpdate::ToolCompleted {
                name,
                success,
                error,
                parameters,
            } => AppEvent::ToolCompleted {
                name,
                success,
                error,
                parameters,
                thread_id,
            },
            StatusUpdate::ToolResult { name, preview } => AppEvent::ToolResult {
                name,
                preview,
                thread_id,
            },
            StatusUpdate::StreamChunk(content) => AppEvent::StreamChunk { content, thread_id },
            StatusUpdate::Status(msg) => AppEvent::Status {
                message: msg,
                thread_id,
            },
            StatusUpdate::JobStarted {
                job_id,
                title,
                browse_url,
            } => AppEvent::JobStarted {
                job_id,
                title,
                browse_url,
            },
            StatusUpdate::ApprovalNeeded {
                request_id,
                tool_name,
                description,
                parameters,
                allow_always,
            } => AppEvent::ApprovalNeeded {
                request_id,
                tool_name,
                description,
                // Pretty-print for display; fall back to compact JSON if that fails.
                parameters: serde_json::to_string_pretty(&parameters)
                    .unwrap_or_else(|_| parameters.to_string()),
                thread_id,
                allow_always,
            },
            StatusUpdate::AuthRequired {
                extension_name,
                instructions,
                auth_url,
                setup_url,
            } => AppEvent::AuthRequired {
                extension_name,
                instructions,
                auth_url,
                setup_url,
            },
            StatusUpdate::AuthCompleted {
                extension_name,
                success,
                message,
            } => AppEvent::AuthCompleted {
                extension_name,
                success,
                message,
            },
            StatusUpdate::ImageGenerated { data_url, path } => AppEvent::ImageGenerated {
                data_url,
                path,
                thread_id,
            },
            StatusUpdate::Suggestions { suggestions } => AppEvent::Suggestions {
                suggestions,
                thread_id,
            },
            StatusUpdate::ReasoningUpdate {
                narrative,
                decisions,
            } => AppEvent::ReasoningUpdate {
                narrative,
                decisions: decisions
                    .into_iter()
                    .map(|d| crate::channels::web::types::ToolDecisionDto {
                        tool_name: d.tool_name,
                        rationale: d.rationale,
                    })
                    .collect(),
                thread_id,
            },
            StatusUpdate::TurnCost {
                input_tokens,
                output_tokens,
                cost_usd,
            } => AppEvent::TurnCost {
                input_tokens,
                output_tokens,
                cost_usd,
                thread_id,
            },
        };
        match metadata.get("user_id").and_then(|v| v.as_str()) {
            Some(uid) => self.state.sse.broadcast_for_user(uid, event),
            None => {
                tracing::debug!("Status event missing user_id in metadata; broadcasting globally");
                self.state.sse.broadcast(event);
            }
        }
        Ok(())
    }

    /// Sends `response` to `user_id` as an `AppEvent::Response` over SSE.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::MissingRoutingTarget` if the response has no `thread_id`.
    async fn broadcast(
        &self,
        user_id: &str,
        response: OutgoingResponse,
    ) -> Result<(), ChannelError> {
        let thread_id = response
            .thread_id
            .ok_or_else(|| ChannelError::MissingRoutingTarget {
                name: "gateway".to_string(),
                reason: "broadcast() requires a thread_id on the response".to_string(),
            })?;
        self.state.sse.broadcast_for_user(
            user_id,
            AppEvent::Response {
                content: response.content,
                thread_id,
            },
        );
        Ok(())
    }

    /// Reports healthy while a message sender is installed, i.e. between a
    /// successful `start()` and `shutdown()`.
    async fn health_check(&self) -> Result<(), ChannelError> {
        if self.state.msg_tx.read().await.is_some() {
            Ok(())
        } else {
            Err(ChannelError::HealthCheckFailed {
                name: "gateway".to_string(),
            })
        }
    }

    /// Signals the server to stop, if a shutdown sender is present, and drops the
    /// message sender so the stream returned by `start()` ends.
    async fn shutdown(&self) -> Result<(), ChannelError> {
        if let Some(tx) = self.state.shutdown_tx.write().await.take() {
            // The receiver may already be gone; nothing more to do in that case.
            let _ = tx.send(());
        }
        *self.state.msg_tx.write().await = None;
        Ok(())
    }
}

/// Joins `host` and `port` into a string that `SocketAddr`'s `FromStr` accepts.
///
/// `SocketAddr` parsing requires IPv6 addresses in `[ip]:port` form, so a bare
/// IPv6 literal such as `::` or `::1` is wrapped in brackets. Hosts that are
/// already bracketed, and hosts without a `:`, are passed through unchanged.
fn format_bind_addr(host: impl std::fmt::Display, port: impl std::fmt::Display) -> String {
    let host = host.to_string();
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}