use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::channels::IncomingMessage;
use crate::channels::web::auth::MultiAuthState;
use crate::channels::web::server::{GatewayState, PerUserRateLimiter, RateLimiter, start_server};
use crate::channels::web::sse::SseManager;
use crate::channels::web::ws::WsConnectionTracker;
/// Builder that assembles a [`GatewayState`] for tests.
///
/// Every optional subsystem of the gateway is left unset by [`TestGatewayBuilder::build`];
/// only the fields below can be customised.
pub struct TestGatewayBuilder {
    /// Sender installed as the gateway's `msg_tx`; `None` unless set via `msg_tx(...)`.
    msg_tx: Option<mpsc::Sender<IncomingMessage>>,
    /// LLM provider installed as the gateway's `llm_provider`; `None` unless set.
    llm_provider: Option<Arc<dyn crate::llm::LlmProvider>>,
    /// Identity used for the gateway's `owner_id` and `default_sender_id`.
    user_id: String,
}
impl Default for TestGatewayBuilder {
    /// Returns a builder with no message channel, no LLM provider and the
    /// user id `"test-user"`.
    fn default() -> Self {
        let user_id = String::from("test-user");
        Self {
            user_id,
            llm_provider: None,
            msg_tx: None,
        }
    }
}
impl TestGatewayBuilder {
    /// Creates a builder with default settings (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sender stored in the gateway's `msg_tx` slot.
    pub fn msg_tx(mut self, tx: mpsc::Sender<IncomingMessage>) -> Self {
        self.msg_tx = Some(tx);
        self
    }

    /// Sets the LLM provider stored in the gateway's `llm_provider` field.
    pub fn llm_provider(mut self, provider: Arc<dyn crate::llm::LlmProvider>) -> Self {
        self.llm_provider = Some(provider);
        self
    }

    /// Sets the identity used for `owner_id` and `default_sender_id`.
    ///
    /// [`TestGatewayBuilder::start`] also authenticates its token as this user.
    pub fn user_id(mut self, id: impl Into<String>) -> Self {
        self.user_id = id.into();
        self
    }

    /// Assembles the [`GatewayState`] without starting a server.
    ///
    /// All optional subsystems are `None`. A fresh SSE manager and WebSocket
    /// connection tracker are created. Rate limiters are built with fixed
    /// test values.
    pub fn build(self) -> Arc<GatewayState> {
        Arc::new(GatewayState {
            msg_tx: tokio::sync::RwLock::new(self.msg_tx),
            sse: Arc::new(SseManager::new()),
            workspace: None,
            workspace_pool: None,
            session_manager: None,
            log_broadcaster: None,
            log_level_handle: None,
            extension_manager: None,
            tool_registry: None,
            store: None,
            job_manager: None,
            prompt_queue: None,
            owner_id: self.user_id.clone(),
            default_sender_id: self.user_id,
            shutdown_tx: tokio::sync::RwLock::new(None),
            ws_tracker: Some(Arc::new(WsConnectionTracker::new())),
            llm_provider: self.llm_provider,
            skill_registry: None,
            skill_catalog: None,
            scheduler: None,
            chat_rate_limiter: PerUserRateLimiter::new(30, 60),
            oauth_rate_limiter: RateLimiter::new(10, 60),
            webhook_rate_limiter: RateLimiter::new(10, 60),
            registry_entries: Vec::new(),
            cost_guard: None,
            routine_engine: Arc::new(tokio::sync::RwLock::new(None)),
            startup_time: std::time::Instant::now(),
            active_config: crate::channels::web::server::ActiveConfigSnapshot::default(),
        })
    }

    /// Builds the state and starts a server on an ephemeral loopback port.
    ///
    /// The server accepts a single bearer token, `auth_token`. That token
    /// authenticates as the builder's `user_id`, so the authenticated identity
    /// matches the gateway's `owner_id`.
    ///
    /// # Errors
    ///
    /// Propagates any [`crate::error::ChannelError`] returned by `start_server`.
    pub async fn start(
        self,
        auth_token: &str,
    ) -> Result<(SocketAddr, Arc<GatewayState>), crate::error::ChannelError> {
        // Previously hard-coded to "test-user", which diverged from `owner_id`
        // whenever a custom `user_id` was configured.
        let auth = MultiAuthState::single(auth_token.to_string(), self.user_id.clone());
        self.start_multi(auth).await
    }

    /// Builds the state and starts a server on an ephemeral loopback port
    /// using the caller-supplied authentication state.
    ///
    /// Returns the address the server actually bound to, together with the
    /// shared gateway state.
    ///
    /// # Errors
    ///
    /// Propagates any [`crate::error::ChannelError`] returned by `start_server`.
    pub async fn start_multi(
        self,
        auth: MultiAuthState,
    ) -> Result<(SocketAddr, Arc<GatewayState>), crate::error::ChannelError> {
        let state = self.build();
        // Port 0 lets the OS pick a free port. Constructing the address
        // directly avoids a parse-and-expect.
        let addr = SocketAddr::from(([127, 0, 0, 1], 0));
        let bound = start_server(addr, Arc::clone(&state), auth).await?;
        Ok((bound, state))
    }
}