use anyhow::Result;
use clap::{Args, Parser, Subcommand};
use std::sync::Arc;
use std::path::PathBuf;
use crate::config::AppConfig;
use tracing::{info, error};
use crate::cli::commands::CliArgs;
use qdrant_client::Qdrant;
#[cfg(feature = "server")]
use crate::server::ServerConfig;
// Top-level arguments for the server command group: a thin wrapper holding
// whichever subcommand was selected.
// Plain `//` comments are used on purpose: clap-derive turns `///` doc comments
// into user-visible help text, which would change the CLI output.
#[derive(Debug, Clone, Args)]
pub struct ServerArgs {
// The selected subcommand (see `ServerCommands`).
#[command(subcommand)]
pub command: ServerCommands,
}
// Subcommands carried by `ServerArgs`. `handle_server_command` matches on this
// enum exhaustively, so adding a variant forces that dispatcher to be updated.
#[derive(Debug, Clone, Subcommand)]
pub enum ServerCommands {
// Start the server; its flags are defined by `ServerStartArgs`.
Start(ServerStartArgs),
}
// Flags accepted by the `Start` subcommand. In builds with the `server`
// feature, `handle_server_start` copies these into a `ServerConfig`.
// Plain `//` comments are used on purpose: clap-derive turns `///` doc comments
// into user-visible help text, which would change the CLI output.
#[derive(Debug, Clone, Args)]
pub struct ServerStartArgs {
// `-p` / `--port`; defaults to 50051.
#[arg(short, long, default_value = "50051")]
pub port: u16,
// `--host`, plus a short flag derived from the field name; defaults to "0.0.0.0".
// NOTE(review): the derived short is `-h`, which is also the short form of
// clap's auto-generated help flag. Confirm the help flag is disabled for this
// command (not visible in this file), or give this flag an explicit short.
#[arg(short, long, default_value = "0.0.0.0")]
pub host: String,
// `--tls`; becomes `ServerConfig::use_tls`.
#[arg(long)]
pub tls: bool,
// `--tls-cert`; optional path, becomes `ServerConfig::cert_path`.
#[arg(long = "tls-cert")]
pub cert: Option<PathBuf>,
// `--tls-key`; optional path, becomes `ServerConfig::key_path`.
#[arg(long = "tls-key")]
pub key: Option<PathBuf>,
// `--api-key`; optional string, becomes `ServerConfig::api_key`.
#[arg(long = "api-key")]
pub api_key: Option<String>,
// `--api-key-file`; optional path, becomes `ServerConfig::api_key_file`.
#[arg(long = "api-key-file")]
pub api_key_file: Option<PathBuf>,
// Boolean flag whose long name clap derives from the field name;
// becomes `ServerConfig::require_auth`.
#[arg(long)]
pub require_auth: bool,
// Defaults to 100; becomes `ServerConfig::max_concurrent_requests`.
#[arg(long, default_value = "100")]
pub max_concurrent_requests: usize,
}
pub async fn handle_server_command(
args: ServerArgs,
_cli_args: &CliArgs,
config: AppConfig,
client: Arc<Qdrant>,
) -> Result<()> {
match args.command {
ServerCommands::Start(start_args) => handle_server_start(start_args, config, client).await,
}
}
/// Runs the `Start` subcommand: builds a `ServerConfig` from the CLI flags,
/// validates it, and hands control to `crate::server::start_server`.
///
/// # Errors
///
/// Logs and returns an error when `validate_tls`, `validate_auth`, or
/// `socket_addr` fails on the assembled config, and propagates any error
/// returned by `start_server`.
#[cfg(feature = "server")]
async fn handle_server_start(
    args: ServerStartArgs,
    config: AppConfig,
    client: Arc<Qdrant>,
) -> Result<()> {
    // `args` is owned and never used again, so its fields are moved into the
    // config instead of being cloned.
    let server_config = ServerConfig {
        port: args.port,
        host: args.host,
        use_tls: args.tls,
        cert_path: args.cert,
        key_path: args.key,
        api_key: args.api_key,
        api_key_file: args.api_key_file,
        require_auth: args.require_auth,
        max_concurrent_requests: args.max_concurrent_requests,
        // Not exposed as a CLI flag; fixed at 128 here.
        max_batch_size: 128,
    };

    // Each failure is logged and then returned with the same message.
    server_config.validate_tls().map_err(|e| {
        error!("TLS configuration error: {}", e);
        anyhow::anyhow!("TLS configuration error: {}", e)
    })?;

    server_config.validate_auth().map_err(|e| {
        error!("Authentication configuration error: {}", e);
        anyhow::anyhow!("Authentication configuration error: {}", e)
    })?;

    let addr = server_config.socket_addr().map_err(|e| {
        error!("Invalid server address: {}", e);
        anyhow::anyhow!("Invalid server address: {}", e)
    })?;

    info!("Starting VectorDB server on {}...", addr);

    // NOTE(review): only the address and TLS settings are forwarded below. The
    // api_key, api_key_file, require_auth, max_concurrent_requests and
    // max_batch_size values are validated/stored on `server_config` but are not
    // passed to `start_server` in this call; confirm they are consumed elsewhere.
    // The fourth argument is passed as `None`; its meaning is defined by
    // `start_server`, which is not visible in this file.
    crate::server::start_server(
        addr,
        Arc::new(config),
        client,
        None,
        server_config.use_tls,
        server_config
            .cert_path
            .map(|p| p.to_string_lossy().into_owned()),
        server_config
            .key_path
            .map(|p| p.to_string_lossy().into_owned()),
    )
    .await?;

    Ok(())
}
/// Fallback used when the crate is built without the `server` feature.
///
/// All parameters are ignored (hence the `_` prefixes, which also keep the
/// build free of `unused_variables` warnings).
///
/// # Errors
///
/// Always logs and returns an error explaining that the `server` feature is
/// not enabled.
#[cfg(not(feature = "server"))]
async fn handle_server_start(
    _args: ServerStartArgs,
    _config: AppConfig,
    _client: Arc<Qdrant>,
) -> Result<()> {
    error!("Server feature not enabled. Compile with --features=server to enable server mode.");
    Err(anyhow::anyhow!("Server feature not enabled. Use `cargo build --features server` to compile with server support."))
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    // Minimal parser that flattens `ServerStartArgs`, so the tests exercise the
    // real clap attributes (defaults, long names) instead of hand-built structs.
    // The auto-generated help flag is disabled because `host` is declared with
    // `short`, which clap derives as `-h`, the same short the help flag uses.
    #[derive(Debug, Parser)]
    #[command(disable_help_flag = true)]
    struct StartCli {
        #[command(flatten)]
        start: ServerStartArgs,
    }

    // Parses `flags` as if they followed the program name on the command line.
    fn parse_start(flags: &[&str]) -> ServerStartArgs {
        let argv = std::iter::once("server-start").chain(flags.iter().copied());
        StartCli::try_parse_from(argv)
            .expect("arguments should parse")
            .start
    }

    #[test]
    fn test_server_args_structure() {
        let cmd = ServerCommands::Start(ServerStartArgs {
            port: 8080,
            host: "127.0.0.1".to_string(),
            tls: false,
            cert: None,
            key: None,
            api_key: None,
            api_key_file: None,
            require_auth: false,
            max_concurrent_requests: 100,
        });
        let args = ServerArgs { command: cmd };

        match args.command {
            ServerCommands::Start(start_args) => {
                assert_eq!(start_args.port, 8080);
                assert_eq!(start_args.host, "127.0.0.1");
            }
        }
    }

    #[test]
    fn test_server_start_args_defaults() {
        // No flags given: every value must come from the clap attributes.
        let args = parse_start(&[]);

        assert_eq!(args.port, 50051, "Default port should be 50051");
        assert_eq!(args.host, "0.0.0.0", "Default host should be 0.0.0.0");
        assert!(!args.tls, "TLS should be disabled by default");
        assert!(args.cert.is_none(), "No TLS certificate by default");
        assert!(args.key.is_none(), "No TLS key by default");
        assert!(args.api_key.is_none(), "No API key by default");
        assert!(args.api_key_file.is_none(), "No API key file by default");
        assert!(!args.require_auth, "Auth should not be required by default");
        assert_eq!(
            args.max_concurrent_requests, 100,
            "Default max concurrent requests should be 100"
        );
    }

    #[test]
    fn test_server_start_args_custom_values() {
        let args = parse_start(&[
            "--port",
            "8080",
            "--host",
            "127.0.0.1",
            "--tls",
            "--tls-cert",
            "/path/to/cert.pem",
            "--tls-key",
            "/path/to/key.pem",
            "--api-key",
            "test-key",
            "--api-key-file",
            "/path/to/api-key.txt",
            "--require-auth",
            "--max-concurrent-requests",
            "200",
        ]);

        assert_eq!(args.port, 8080);
        assert_eq!(args.host, "127.0.0.1");
        assert!(args.tls);
        assert_eq!(args.cert, Some(PathBuf::from("/path/to/cert.pem")));
        assert_eq!(args.key, Some(PathBuf::from("/path/to/key.pem")));
        assert_eq!(args.api_key.as_deref(), Some("test-key"));
        assert_eq!(
            args.api_key_file,
            Some(PathBuf::from("/path/to/api-key.txt"))
        );
        assert!(args.require_auth);
        assert_eq!(args.max_concurrent_requests, 200);
    }
}