use reflex::cache::TieredCache;
use reflex::cache::{BqSearchBackend, L1CacheHandle, L2Config, L2SemanticCache, NvmeStorageLoader};
use reflex::embedding::RerankerConfig;
use reflex::embedding::sinter::{SinterConfig, SinterEmbedder};
use reflex::scoring::CrossEncoderScorer;
use reflex::vectordb::bq::{BqClient, MockBqClient};
use reflex_server::gateway::{HandlerState, create_router_with_state};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
/// Readiness timeout, in seconds, passed to `wait_for_server_ready` by the server spawners.
const STARTUP_WAIT_TIMEOUT_SECS: u64 = 5;
/// Delay, in milliseconds, between readiness connection attempts made by the spawners.
const STARTUP_POLL_INTERVAL_MS: u64 = 50;
/// Vector-store collection name used when `TestServerConfig::collection_name` is `None`
/// (the real-server spawner appends a random UUID suffix to it).
const TEST_COLLECTION_NAME: &str = "reflex_test_bq";
/// Options for spawning a test server; start from `TestServerConfig::default()`.
#[derive(Debug, Clone)]
pub struct TestServerConfig {
    /// Port to listen on at `127.0.0.1`; `0` lets the OS pick a free port.
    pub port: u16,
    /// Vector-store collection name; `None` selects a test default.
    pub collection_name: Option<String>,
    /// Storage directory for the server; `None` creates a temporary directory that lives as
    /// long as the returned server handle.
    pub storage_path: Option<std::path::PathBuf>,
    /// Threshold handed to the reranker configuration (`with_threshold`).
    pub reranker_threshold: f32,
}

impl Default for TestServerConfig {
    /// OS-assigned port, default collection, temporary storage, reranker threshold `0.70`.
    fn default() -> Self {
        Self {
            port: 0,
            collection_name: None,
            storage_path: None,
            reranker_threshold: 0.70,
        }
    }
}
/// Handle to a test server running on a background Tokio task.
///
/// Dropping the value, or calling `shutdown`, sends the graceful-shutdown signal to the
/// server task.
pub struct TestServer {
    /// Address the server's listener is bound to.
    pub addr: SocketAddr,
    /// Handle of the spawned task running `axum::serve`. Note that dropping a Tokio
    /// `JoinHandle` detaches the task rather than aborting it.
    _server_handle: JoinHandle<()>,
    /// Sender half of the graceful-shutdown channel; `None` once the signal has been sent.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Keeps an auto-created temporary storage directory alive for the server's lifetime;
    /// `None` when the caller supplied `storage_path`.
    _temp_dir: Option<TempDir>,
}
impl TestServer {
pub fn url(&self) -> String {
format!("http://{}", self.addr)
}
pub async fn shutdown(mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
}
}
impl Drop for TestServer {
    /// Sends the graceful-shutdown signal unless `shutdown` has already sent it.
    fn drop(&mut self) {
        let Some(sender) = self.shutdown_tx.take() else {
            return;
        };
        // Failure only means the server task no longer holds the receiver; nothing to do.
        let _ = sender.send(());
    }
}
/// Asks the OS for a currently free loopback port by binding `127.0.0.1:0`.
///
/// The probe listener is dropped before returning, so the port is only likely to still be
/// free when the caller binds it: another process may claim it in between. When the
/// listener itself is what you need, bind port `0` directly instead.
///
/// # Errors
///
/// Propagates any I/O error from binding or from reading the local address.
pub async fn find_available_port() -> std::io::Result<u16> {
    let probe = TcpListener::bind("127.0.0.1:0").await?;
    probe.local_addr().map(|bound| bound.port())
}
/// Polls `addr` with TCP connects until one succeeds or `timeout` elapses.
///
/// Each connect attempt is bounded by the time remaining before the deadline, and so is the
/// sleep between attempts. A connect that stalls (e.g. a dropped SYN) therefore cannot push
/// the total wait far past `timeout`.
///
/// # Errors
///
/// Returns [`ServerStartupError::Timeout`] if no connection succeeds in time.
pub async fn wait_for_server_ready(
    addr: SocketAddr,
    timeout: Duration,
    interval: Duration,
) -> Result<(), ServerStartupError> {
    let start = std::time::Instant::now();
    loop {
        let remaining = timeout.saturating_sub(start.elapsed());
        if remaining.is_zero() {
            return Err(ServerStartupError::Timeout);
        }
        if let Ok(Ok(_stream)) =
            tokio::time::timeout(remaining, tokio::net::TcpStream::connect(addr)).await
        {
            return Ok(());
        }
        // Never sleep beyond the deadline; the next iteration reports the timeout.
        let remaining = timeout.saturating_sub(start.elapsed());
        tokio::time::sleep(interval.min(remaining)).await;
    }
}
/// Errors returned while starting a test server.
#[derive(Debug, thiserror::Error)]
pub enum ServerStartupError {
    /// No TCP connection to the server succeeded before the readiness timeout elapsed.
    #[error("Server failed to start within timeout")]
    Timeout,
    /// Any `std::io::Error` propagated with `?` during startup. Despite the name, this covers
    /// port probing and reading the local address as well as binding.
    #[error("Failed to bind to address: {0}")]
    BindError(#[from] std::io::Error),
    /// Any other startup failure (temp dir, vector store, embedder, cache or reranker setup),
    /// carrying the underlying error's message.
    #[error("Server startup failed: {0}")]
    StartupFailed(String),
}
pub async fn spawn_test_server(config: TestServerConfig) -> Result<TestServer, ServerStartupError> {
let port = if config.port == 0 {
find_available_port().await?
} else {
config.port
};
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
let (storage_path, _temp_dir) = if let Some(path) = config.storage_path {
(path, None)
} else {
let temp_dir =
TempDir::new().map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
(temp_dir.path().to_path_buf(), Some(temp_dir))
};
let collection_name = config
.collection_name
.unwrap_or_else(|| TEST_COLLECTION_NAME.to_string());
let bq_client = MockBqClient::new();
bq_client
.ensure_bq_collection(&collection_name, reflex::constants::DEFAULT_VECTOR_SIZE_U64)
.await
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
let loader = NvmeStorageLoader::new(storage_path.clone());
let embedder = SinterEmbedder::load(SinterConfig::stub())
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
let l2_config = L2Config::default().collection_name(&collection_name);
let l2_cache = L2SemanticCache::new(embedder, bq_client.clone(), loader, l2_config)
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
let l1_cache = L1CacheHandle::new();
let tiered_cache = TieredCache::new(l1_cache, l2_cache);
let tiered_cache = Arc::new(tiered_cache);
let scorer =
CrossEncoderScorer::new(RerankerConfig::stub().with_threshold(config.reranker_threshold))
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
let scorer = Arc::new(scorer);
let state = HandlerState::new_with_mock_provider(
tiered_cache,
scorer,
storage_path,
bq_client,
collection_name,
true,
);
let app = create_router_with_state(state);
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let server_handle = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.await
.unwrap();
});
wait_for_server_ready(
local_addr,
Duration::from_secs(STARTUP_WAIT_TIMEOUT_SECS),
Duration::from_millis(STARTUP_POLL_INTERVAL_MS),
)
.await?;
Ok(TestServer {
addr: local_addr,
_server_handle: server_handle,
shutdown_tx: Some(shutdown_tx),
_temp_dir,
})
}
pub async fn spawn_real_server(config: TestServerConfig) -> Result<TestServer, ServerStartupError> {
let port = if config.port == 0 {
find_available_port().await?
} else {
config.port
};
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
let (storage_path, _temp_dir) = if let Some(path) = config.storage_path {
(path, None)
} else {
let temp_dir =
TempDir::new().map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
(temp_dir.path().to_path_buf(), Some(temp_dir))
};
let collection_name = config
.collection_name
.unwrap_or_else(|| format!("{}_{}", TEST_COLLECTION_NAME, uuid::Uuid::new_v4().simple()));
let qdrant_url =
std::env::var("REFLEX_QDRANT_URL").unwrap_or_else(|_| "http://localhost:6334".to_string());
let bq_client = BqClient::new(&qdrant_url).await.map_err(|e| {
ServerStartupError::StartupFailed(format!("Failed to connect to Qdrant: {}", e))
})?;
bq_client
.ensure_collection(&collection_name, reflex::constants::DEFAULT_VECTOR_SIZE_U64)
.await
.map_err(|e| {
ServerStartupError::StartupFailed(format!("Failed to ensure collection: {}", e))
})?;
let loader = NvmeStorageLoader::new(storage_path.clone());
let embedder = if let Ok(path) = std::env::var("REFLEX_MODEL_PATH") {
println!("Using Real Embedder: {}", path);
SinterEmbedder::load(SinterConfig::new(path))
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?
} else {
println!("Using Stub Embedder");
SinterEmbedder::load(SinterConfig::stub())
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?
};
let l2_config = L2Config::default().collection_name(&collection_name);
let l2_cache = L2SemanticCache::new(embedder, bq_client.clone(), loader, l2_config)
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?;
let l1_cache = L1CacheHandle::new();
let tiered_cache = TieredCache::new(l1_cache, l2_cache);
let tiered_cache = Arc::new(tiered_cache);
let scorer = if let Ok(path) = std::env::var("REFLEX_RERANKER_PATH") {
println!("Using Real Reranker: {}", path);
let config = RerankerConfig::new(path).with_threshold(config.reranker_threshold);
CrossEncoderScorer::new(config)
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?
} else {
println!("Using Stub Reranker");
CrossEncoderScorer::new(RerankerConfig::stub().with_threshold(config.reranker_threshold))
.map_err(|e| ServerStartupError::StartupFailed(e.to_string()))?
};
let scorer = Arc::new(scorer);
let state = HandlerState::new_with_mock_provider(
tiered_cache,
scorer,
storage_path,
bq_client,
collection_name,
true, );
let app = create_router_with_state(state);
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let server_handle = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.await
.unwrap();
});
wait_for_server_ready(
local_addr,
Duration::from_secs(STARTUP_WAIT_TIMEOUT_SECS),
Duration::from_millis(STARTUP_POLL_INTERVAL_MS),
)
.await?;
Ok(TestServer {
addr: local_addr,
_server_handle: server_handle,
shutdown_tx: Some(shutdown_tx),
_temp_dir,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TestServer` around a no-op task so accessors and shutdown signalling can be
    /// tested without a real server. Must be called inside a Tokio runtime.
    fn dummy_server(addr: SocketAddr) -> (TestServer, oneshot::Receiver<()>) {
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let server = TestServer {
            addr,
            _server_handle: tokio::spawn(async {}),
            shutdown_tx: Some(shutdown_tx),
            _temp_dir: None,
        };
        (server, shutdown_rx)
    }

    #[tokio::test]
    async fn test_find_available_port() {
        let port = find_available_port()
            .await
            .expect("Should find available port");
        assert!(port > 0);
    }

    #[test]
    fn test_server_config_defaults() {
        let config = TestServerConfig::default();
        assert_eq!(config.port, 0);
        assert!(config.collection_name.is_none());
        assert!(config.storage_path.is_none());
        assert!((config.reranker_threshold - 0.70).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_server_helpers_are_callable() {
        let (server, shutdown_rx) = dummy_server("127.0.0.1:0".parse().unwrap());
        assert_eq!(server.url(), "http://127.0.0.1:0");
        server.shutdown().await;
        assert!(
            shutdown_rx.await.is_ok(),
            "shutdown() must send the shutdown signal"
        );
    }

    #[tokio::test]
    async fn test_drop_sends_shutdown_signal() {
        let (server, shutdown_rx) = dummy_server("127.0.0.1:0".parse().unwrap());
        drop(server);
        assert!(
            shutdown_rx.await.is_ok(),
            "dropping TestServer must send the shutdown signal"
        );
    }

    #[test]
    fn test_spawners_are_referenced() {
        // Futures are lazy: dropping them unpolled only checks that the spawners type-check.
        std::mem::drop(spawn_test_server(TestServerConfig::default()));
        std::mem::drop(spawn_real_server(TestServerConfig::default()));
    }

    #[tokio::test]
    async fn test_server_url_formatting() {
        let (server, _shutdown_rx) = dummy_server("127.0.0.1:8080".parse().unwrap());
        assert_eq!(server.url(), "http://127.0.0.1:8080");
    }
}