use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use neumann_server::proto::health_client::HealthClient;
use neumann_server::proto::query_service_client::QueryServiceClient;
use neumann_server::proto::{BatchQueryRequest, HealthCheckRequest, QueryRequest, ServingStatus};
use neumann_server::{AuthConfig, NeumannServer, RateLimitConfig, ServerConfig, ShutdownConfig};
use parking_lot::RwLock;
use query_router::QueryRouter;
use tokio::sync::oneshot;
use tonic::metadata::MetadataValue;
use tonic::transport::Channel;
async fn start_grpc_test_server(
config_override: Option<ServerConfig>,
) -> (
SocketAddr,
oneshot::Sender<()>,
QueryServiceClient<Channel>,
HealthClient<Channel>,
) {
let router = Arc::new(RwLock::new(QueryRouter::new()));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let config = config_override.unwrap_or_else(|| ServerConfig {
bind_addr: addr,
enable_grpc_web: false, ..Default::default()
});
let config = config.with_bind_addr(addr);
let server = NeumannServer::new(router, config);
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let _ = server
.serve_with_shutdown(async {
let _ = shutdown_rx.await;
})
.await;
});
tokio::time::sleep(Duration::from_millis(100)).await;
let channel = Channel::from_shared(format!("http://{addr}"))
.unwrap()
.connect()
.await
.unwrap();
let query_client = QueryServiceClient::new(channel.clone());
let health_client = HealthClient::new(channel);
(addr, shutdown_tx, query_client, health_client)
}
#[tokio::test]
async fn test_grpc_health_check_serving() {
let (_addr, shutdown, _query, mut health) = start_grpc_test_server(None).await;
let response = health
.check(HealthCheckRequest { service: None })
.await
.unwrap();
assert_eq!(
response.into_inner().status,
i32::from(ServingStatus::Serving)
);
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_health_check_query_service() {
let (_addr, shutdown, _query, mut health) = start_grpc_test_server(None).await;
let response = health
.check(HealthCheckRequest {
service: Some("neumann.v1.QueryService".to_string()),
})
.await
.unwrap();
assert_eq!(
response.into_inner().status,
i32::from(ServingStatus::Serving)
);
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_execute_create_table() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let response = query
.execute(QueryRequest {
query: "CREATE TABLE grpc_test (name text, age int)".to_string(),
identity: None,
})
.await
.unwrap();
let inner = response.into_inner();
assert!(inner.error.is_none(), "Expected no error");
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_execute_insert_and_select() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let _ = query
.execute(QueryRequest {
query: "CREATE TABLE users (name text, age int)".to_string(),
identity: None,
})
.await
.unwrap();
let _ = query
.execute(QueryRequest {
query: "INSERT INTO users (name, age) VALUES ('Alice', 30)".to_string(),
identity: None,
})
.await
.unwrap();
let response = query
.execute(QueryRequest {
query: "SELECT * FROM users".to_string(),
identity: None,
})
.await
.unwrap();
let inner = response.into_inner();
assert!(inner.error.is_none(), "Expected no error");
if let Some(neumann_server::proto::query_response::Result::Rows(rows)) = inner.result {
assert_eq!(rows.rows.len(), 1);
} else {
panic!("Expected Rows result");
}
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_execute_invalid_query() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let result = query
.execute(QueryRequest {
query: "INVALID QUERY!!!".to_string(),
identity: None,
})
.await;
assert!(result.is_err(), "Expected error for invalid query");
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_execute_batch() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let response = query
.execute_batch(BatchQueryRequest {
queries: vec![
QueryRequest {
query: "CREATE TABLE batch_tbl (x int)".to_string(),
identity: None,
},
QueryRequest {
query: "INSERT INTO batch_tbl (x) VALUES (1)".to_string(),
identity: None,
},
QueryRequest {
query: "INSERT INTO batch_tbl (x) VALUES (2)".to_string(),
identity: None,
},
QueryRequest {
query: "SELECT * FROM batch_tbl".to_string(),
identity: None,
},
],
})
.await
.unwrap();
let inner = response.into_inner();
assert_eq!(inner.results.len(), 4);
for (i, r) in inner.results.iter().enumerate() {
assert!(r.error.is_none(), "Query {i} failed: {:?}", r.error);
}
let last = &inner.results[3];
if let Some(neumann_server::proto::query_response::Result::Rows(rows)) = &last.result {
assert_eq!(rows.rows.len(), 2);
} else {
panic!("Expected Rows result for SELECT, got: {:?}", last.result);
}
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_execute_stream_rows() {
use tokio_stream::StreamExt;
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let _ = query
.execute(QueryRequest {
query: "CREATE TABLE stream (name text)".to_string(),
identity: None,
})
.await
.unwrap();
let _ = query
.execute(QueryRequest {
query: "INSERT INTO stream (name) VALUES ('Alice')".to_string(),
identity: None,
})
.await
.unwrap();
let _ = query
.execute(QueryRequest {
query: "INSERT INTO stream (name) VALUES ('Bob')".to_string(),
identity: None,
})
.await
.unwrap();
let response = query
.execute_stream(QueryRequest {
query: "SELECT * FROM stream".to_string(),
identity: None,
})
.await
.unwrap();
let mut stream = response.into_inner();
let mut chunks = vec![];
while let Some(chunk) = stream.next().await {
chunks.push(chunk.unwrap());
}
assert!(chunks.len() >= 2, "Expected at least 2 chunks");
assert!(
chunks.last().unwrap().is_final,
"Last chunk should be final"
);
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_auth_rejection_without_key() {
let config = ServerConfig::default().with_auth(
AuthConfig::new()
.with_api_key(neumann_server::config::ApiKey::new(
"test-api-key-12345678".to_string(),
"user:test".to_string(),
))
.with_anonymous(false),
);
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(Some(config)).await;
let result = query
.execute(QueryRequest {
query: "CREATE TABLE auth_test (x int)".to_string(),
identity: None,
})
.await;
assert!(result.is_err());
let status = result.unwrap_err();
assert_eq!(status.code(), tonic::Code::Unauthenticated);
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_auth_success_with_key() {
let config = ServerConfig::default().with_auth(
AuthConfig::new()
.with_api_key(neumann_server::config::ApiKey::new(
"test-api-key-12345678".to_string(),
"user:test".to_string(),
))
.with_anonymous(false),
);
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(Some(config)).await;
let mut request = tonic::Request::new(QueryRequest {
query: "CREATE TABLE auth_success (x int)".to_string(),
identity: None,
});
request.metadata_mut().insert(
"x-api-key",
MetadataValue::from_static("test-api-key-12345678"),
);
let result = query.execute(request).await;
assert!(result.is_ok(), "Expected success with valid API key");
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_rate_limiting() {
let config = ServerConfig::default()
.with_auth(AuthConfig::new().with_anonymous(true))
.with_rate_limit(RateLimitConfig::new().with_max_queries(2));
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(Some(config)).await;
let result = query
.execute(QueryRequest {
query: "CREATE TABLE rate_test (x int)".to_string(),
identity: None,
})
.await;
assert!(result.is_ok());
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_graceful_shutdown() {
let config = ServerConfig::default().with_shutdown(
ShutdownConfig::new()
.with_drain_timeout(Duration::from_millis(500))
.with_grace_period(Duration::from_millis(100)),
);
let (addr, shutdown, mut query, mut health) = start_grpc_test_server(Some(config)).await;
let _ = query
.execute(QueryRequest {
query: "CREATE TABLE shutdown_test (x int)".to_string(),
identity: None,
})
.await
.unwrap();
let response = health
.check(HealthCheckRequest { service: None })
.await
.unwrap();
assert_eq!(
response.into_inner().status,
i32::from(ServingStatus::Serving)
);
let _ = shutdown.send(());
tokio::time::sleep(Duration::from_millis(50)).await;
let channel_result = Channel::from_shared(format!("http://{addr}"))
.unwrap()
.connect()
.await;
if let Ok(channel) = channel_result {
let mut health = HealthClient::new(channel);
let response = health.check(HealthCheckRequest { service: None }).await;
if let Ok(resp) = response {
let status = resp.into_inner().status;
assert!(
status == i32::from(ServingStatus::NotServing)
|| status == i32::from(ServingStatus::Serving),
"Expected Serving or NotServing"
);
}
}
tokio::time::sleep(Duration::from_millis(700)).await;
}
#[tokio::test]
async fn test_grpc_serve_cancellation() {
let router = Arc::new(RwLock::new(QueryRouter::new()));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let config = ServerConfig {
bind_addr: addr,
enable_grpc_web: false,
..Default::default()
};
let server = NeumannServer::new(router, config);
let handle = tokio::spawn(async move {
let _ = server.serve().await;
});
tokio::time::sleep(Duration::from_millis(100)).await;
let channel = Channel::from_shared(format!("http://{addr}"))
.unwrap()
.connect()
.await
.unwrap();
let mut health = HealthClient::new(channel);
let response = health
.check(HealthCheckRequest { service: None })
.await
.unwrap();
assert_eq!(
response.into_inner().status,
i32::from(ServingStatus::Serving)
);
handle.abort();
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_grpc_multiple_concurrent_requests() {
let (_addr, shutdown, query, _health) = start_grpc_test_server(None).await;
let mut q = query.clone();
let _ = q
.execute(QueryRequest {
query: "CREATE TABLE concurrent (id int, value text)".to_string(),
identity: None,
})
.await
.unwrap();
let mut handles = vec![];
for i in 0..10 {
let mut client = query.clone();
handles.push(tokio::spawn(async move {
client
.execute(QueryRequest {
query: format!("INSERT INTO concurrent (id, value) VALUES ({i}, 'test{i}')"),
identity: None,
})
.await
}));
}
let results: Vec<_> = futures::future::join_all(handles).await;
for result in results {
assert!(result.is_ok(), "Task should not panic");
assert!(result.unwrap().is_ok(), "Request should succeed");
}
let mut q = query.clone();
let response = q
.execute(QueryRequest {
query: "SELECT * FROM concurrent".to_string(),
identity: None,
})
.await
.unwrap();
let inner = response.into_inner();
if let Some(neumann_server::proto::query_response::Result::Rows(rows)) = inner.result {
assert_eq!(rows.rows.len(), 10);
} else {
panic!("Expected Rows result");
}
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_connection_with_http2_settings() {
let config = ServerConfig::default()
.with_max_concurrent_connections(100)
.with_max_concurrent_streams_per_connection(50)
.with_initial_window_size(65535)
.with_request_timeout(Duration::from_secs(30));
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(Some(config)).await;
let response = query
.execute(QueryRequest {
query: "CREATE TABLE http2_test (x int)".to_string(),
identity: None,
})
.await
.unwrap();
assert!(response.into_inner().error.is_none());
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_graph_operations() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let response = query
.execute(QueryRequest {
query: "NEIGHBORS 1".to_string(),
identity: None,
})
.await;
assert!(response.is_ok() || response.is_err());
drop(shutdown);
}
#[tokio::test]
async fn test_grpc_vector_operations() {
let (_addr, shutdown, mut query, _health) = start_grpc_test_server(None).await;
let response = query
.execute(QueryRequest {
query: "EMBED doc1 [1.0, 2.0, 3.0]".to_string(),
identity: None,
})
.await
.unwrap();
assert!(response.into_inner().error.is_none());
let response = query
.execute(QueryRequest {
query: "EMBED doc2 [1.1, 2.1, 3.1]".to_string(),
identity: None,
})
.await
.unwrap();
assert!(response.into_inner().error.is_none());
let response = query
.execute(QueryRequest {
query: "SIMILAR [1.0, 2.0, 3.0] TOP 5".to_string(),
identity: None,
})
.await
.unwrap();
let inner = response.into_inner();
assert!(inner.error.is_none(), "Vector search should succeed");
drop(shutdown);
}