use std::sync::Arc;
use subtle::ConstantTimeEq;
use tonic::{Request, Response, Status};
use crate::server::{Server, ServerRegistry};
/// Protobuf message and gRPC service types generated by `tonic::include_proto!` for the
/// `deepslate` package.
///
/// Generated code is not written to this crate's lint standards, so the lints it would
/// otherwise trip are silenced for this module only.
pub mod proto {
    #![allow(
        missing_docs,
        clippy::doc_markdown,
        clippy::default_trait_access,
        clippy::missing_const_for_fn,
        clippy::too_many_lines,
        clippy::derive_partial_eq_without_eq,
        clippy::needless_pass_by_value,
        clippy::missing_errors_doc
    )]

    tonic::include_proto!("deepslate");

    /// Encoded protobuf file descriptor set named `deepslate_descriptor`.
    ///
    /// Only compiled when the `grpc-reflection` feature is enabled; presumably handed to a
    /// reflection service by the server setup code (NOTE(review): confirm at the call site).
    #[cfg(feature = "grpc-reflection")]
    pub const FILE_DESCRIPTOR_SET: &[u8] =
        tonic::include_file_descriptor_set!("deepslate_descriptor");
}
use proto::deepslate_server::Deepslate;
use proto::{
DeregisterServerRequest, DeregisterServerResponse, ForcedHostEntry, ListServersRequest,
ListServersResponse, RegisterServerRequest, RegisterServerResponse, SetForcedHostsRequest,
SetForcedHostsResponse, SetTryOrderRequest, SetTryOrderResponse,
};
/// gRPC implementation of the `Deepslate` service.
///
/// Every RPC reads from or writes to the shared [`ServerRegistry`] held here.
pub struct DeepslateService {
    /// Registry shared (via `Arc`) with whatever else holds a clone of it.
    registry: Arc<ServerRegistry>,
}

impl DeepslateService {
    /// Wraps an already-shared registry in a new service instance.
    ///
    /// The `Arc` is stored as-is; no copy of the registry is made.
    #[must_use]
    pub const fn new(registry: Arc<ServerRegistry>) -> Self {
        Self { registry }
    }
}
#[tonic::async_trait]
impl Deepslate for DeepslateService {
    /// Registers a new backend server under `req.id` at `req.address`.
    ///
    /// A rejected registration (an ID that is already registered) is reported in-band as
    /// `success: false` with an `error` message, not as a gRPC error status.
    async fn register_server(
        &self,
        request: Request<RegisterServerRequest>,
    ) -> Result<Response<RegisterServerResponse>, Status> {
        let req = request.into_inner();
        let server = Server::new(req.id.clone(), req.address.clone());
        // `register` returning false is surfaced to the caller as a duplicate-ID error.
        if self.registry.register(&server) {
            tracing::info!(id = %req.id, addr = %req.address, "server registered via gRPC");
            Ok(Response::new(RegisterServerResponse {
                success: true,
                error: String::new(),
            }))
        } else {
            Ok(Response::new(RegisterServerResponse {
                success: false,
                error: format!("server with ID '{}' already exists", req.id),
            }))
        }
    }

    /// Removes the server registered under `req.id`.
    ///
    /// An unknown ID is reported in-band as `success: false` with an `error` message.
    async fn deregister_server(
        &self,
        request: Request<DeregisterServerRequest>,
    ) -> Result<Response<DeregisterServerResponse>, Status> {
        let req = request.into_inner();
        if self.registry.deregister(&req.id).is_some() {
            tracing::info!(id = %req.id, "server deregistered via gRPC");
            Ok(Response::new(DeregisterServerResponse {
                success: true,
                error: String::new(),
            }))
        } else {
            Ok(Response::new(DeregisterServerResponse {
                success: false,
                error: format!("server with ID '{}' not found", req.id),
            }))
        }
    }

    /// Returns the registered servers, the try order, and the forced-host table.
    ///
    /// NOTE(review): the three parts come from three separate registry calls, so they may
    /// not form one atomic snapshot if the registry is modified concurrently.
    async fn list_servers(
        &self,
        _request: Request<ListServersRequest>,
    ) -> Result<Response<ListServersResponse>, Status> {
        let servers = self
            .registry
            .list()
            .into_iter()
            .map(|s| proto::Server {
                id: s.id,
                address: s.addr,
            })
            .collect();
        let try_order = self.registry.try_order();
        let forced_hosts = self
            .registry
            .forced_hosts()
            .into_iter()
            .map(|(host, ids)| (host, ForcedHostEntry { server_ids: ids }))
            .collect();
        Ok(Response::new(ListServersResponse {
            servers,
            try_order,
            forced_hosts,
        }))
    }

    /// Replaces the try order with `req.ids`.
    ///
    /// The IDs are handed to the registry unchanged; this handler does not check that they
    /// refer to registered servers.
    async fn set_try_order(
        &self,
        request: Request<SetTryOrderRequest>,
    ) -> Result<Response<SetTryOrderResponse>, Status> {
        let req = request.into_inner();
        tracing::info!(order = ?req.ids, "try order updated via gRPC");
        self.registry.set_try_order(req.ids);
        Ok(Response::new(SetTryOrderResponse { success: true }))
    }

    /// Replaces the forced-host table with `req.forced_hosts`, lowercasing every host name.
    ///
    /// # Errors
    ///
    /// Returns `Status::invalid_argument` if two request keys become the same host after
    /// lowercasing (e.g. `"Play.Example"` and `"play.example"`). Without this check one list
    /// would silently overwrite the other, and which one wins would depend on the map's
    /// unspecified iteration order. On error the registry is left untouched.
    async fn set_forced_hosts(
        &self,
        request: Request<SetForcedHostsRequest>,
    ) -> Result<Response<SetForcedHostsResponse>, Status> {
        use std::collections::{hash_map::Entry, HashMap};

        let req = request.into_inner();
        let mut map: HashMap<String, Vec<String>> =
            HashMap::with_capacity(req.forced_hosts.len());
        for (host, forced) in req.forced_hosts {
            match map.entry(host.to_lowercase()) {
                Entry::Vacant(slot) => {
                    slot.insert(forced.server_ids);
                }
                Entry::Occupied(slot) => {
                    return Err(Status::invalid_argument(format!(
                        "forced host '{}' is listed more than once \
                         (host names are matched case-insensitively)",
                        slot.key()
                    )));
                }
            }
        }
        tracing::info!(hosts = ?map.keys().collect::<Vec<_>>(), "forced hosts updated via gRPC");
        self.registry.set_forced_hosts(map);
        Ok(Response::new(SetForcedHostsResponse { success: true }))
    }
}
/// Builds a tonic interceptor that admits only requests with `authorization: Bearer <token>`
/// metadata whose token equals `expected`.
///
/// Header parsing:
/// - The scheme name is matched ASCII case-insensitively (RFC 7235 §2.1).
/// - Any run of spaces between the scheme and the token is accepted.
/// - An empty token is always rejected. Without this, a header of just `"Bearer "` would
///   authenticate whenever `expected` is empty.
///
/// The token is compared with `subtle::ConstantTimeEq`, so the comparison time does not
/// depend on where the bytes differ (a length mismatch still fails fast).
///
/// # Errors
///
/// The returned closure yields `Status::unauthenticated` in any of these cases:
/// - the header is missing or is not a visible-ASCII string;
/// - the header does not use the Bearer scheme;
/// - the token is empty;
/// - the token does not match `expected`.
pub fn bearer_auth_interceptor(
    expected: Arc<String>,
) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone + Send + Sync + 'static {
    move |req: Request<()>| {
        let provided = req
            .metadata()
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(parse_bearer_token);
        match provided {
            Some(token) if bool::from(token.as_bytes().ct_eq(expected.as_bytes())) => Ok(req),
            _ => Err(Status::unauthenticated("invalid or missing auth token")),
        }
    }
}

/// Extracts the credential from an `Authorization` header value that uses the Bearer scheme.
///
/// Returns `None` if any of the following holds:
/// - there is no space-separated credential;
/// - the scheme is not `Bearer` (compared ASCII case-insensitively);
/// - the credential is empty after skipping the separating spaces.
fn parse_bearer_token(header: &str) -> Option<&str> {
    let (scheme, rest) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // RFC 6750 allows one or more spaces between the scheme and the token.
    let token = rest.trim_start_matches(' ');
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}