use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use cert_provider::CertProvider;
#[cfg(feature = "s3-sync")]
use cert_provider::S3CertProvider;
use cert_provider::provider::dns01::{BunnyDns, DnsAcmeProvider};
#[cfg(feature = "s3-sync")]
use cert_provider::s3_sync::{S3CertSync, S3Config};
use endpoint_libs::libs::error_code::ErrorCode;
use endpoint_libs::libs::handler::RequestHandler;
#[cfg(feature = "error_aggregation")]
use endpoint_libs::libs::log::error_aggregation::ErrorAggregationConfig;
use endpoint_libs::libs::log::{LogLevel, LoggingConfig, OtelConfig, setup_logging};
use endpoint_libs::libs::toolbox::{ArcToolbox, RequestContext};
use endpoint_libs::libs::ws::toolbox::CustomError;
use endpoint_libs::libs::ws::{
AuthController, WebsocketServer, WsConnection, WsRequest, WsResponse, WsServerConfig,
};
use eyre::Result;
use futures::FutureExt;
use futures::future::LocalBoxFuture;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Payload of the Echo method (method id 1): the text that the handler
/// sends back prefixed with `"echo: "`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EchoRequest {
    /// Text supplied by the client.
    pub message: String,
}
/// Reply of the Echo method.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EchoResponse {
    /// The request text with the `"echo: "` prefix applied by `MethodEcho`.
    pub message: String,
}
/// Describes [`EchoRequest`] as WebSocket method `1` ("Echo").
impl WsRequest for EchoRequest {
    /// Numeric method id; equal to `"code"` in `SCHEMA`.
    const METHOD_ID: u32 = 1;

    // NOTE(review): `ROLES` contains role 1 while the `"roles"` array inside
    // `SCHEMA` is empty — confirm which of the two the server actually uses.
    /// Role ids associated with this method (presumably the roles allowed to
    /// call it — verify against `endpoint_libs`).
    const ROLES: &'static [u32] = &[1];

    type Response = EchoResponse;

    /// JSON description of the endpoint. The text inside the raw string is
    /// runtime data and is kept verbatim.
    const SCHEMA: &'static str = r#"{
"name": "Echo",
"code": 1,
"parameters": [{"name": "message", "ty": "String"}],
"returns": [{"name": "message", "ty": "String"}],
"roles": []
}"#;
}
/// Links [`EchoResponse`] back to the request type it answers.
impl WsResponse for EchoResponse {
    type Request = EchoRequest;
}
/// Payload of the ReceiveUserInfo test method (method id 211).
///
/// Field names are (de)serialized in camelCase (`userPubId`, `username`,
/// `appPubId`, `token`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HoneyReceiveUserInfoRequest {
    /// Public id of the user.
    pub user_pub_id: Uuid,
    /// Name of the user.
    pub username: String,
    /// Optional public id of the app; `None` when the key is absent.
    #[serde(default)]
    pub app_pub_id: Option<Uuid>,
    /// Optional token; `None` when the key is absent. The handler only logs
    /// whether it is present, never its value.
    #[serde(default)]
    pub token: Option<String>,
}
/// Empty reply of the ReceiveUserInfo test method.
///
/// Kept as a braced struct with no fields; the handler in this file never
/// produces it because it always returns an error.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HoneyReceiveUserInfoResponse {}
/// Describes [`HoneyReceiveUserInfoRequest`] as WebSocket method `211`
/// ("ReceiveUserInfo").
impl WsRequest for HoneyReceiveUserInfoRequest {
    /// Numeric method id; equal to `"code"` in `SCHEMA`.
    const METHOD_ID: u32 = 211;

    // NOTE(review): `ROLES` contains role 1 while the `"roles"` array inside
    // `SCHEMA` is empty — confirm which of the two the server actually uses.
    /// Role ids associated with this method (presumably the roles allowed to
    /// call it — verify against `endpoint_libs`).
    const ROLES: &'static [u32] = &[1];

    type Response = HoneyReceiveUserInfoResponse;

    /// JSON description of the endpoint; parameter names use the same
    /// camelCase spelling as the serde representation of the request. The
    /// text inside the raw string is runtime data and is kept verbatim.
    const SCHEMA: &'static str = r#"{
"name": "ReceiveUserInfo",
"code": 211,
"parameters": [
{"name": "userPubId", "ty": "UUID"},
{"name": "username", "ty": "String"},
{"name": "appPubId", "ty": {"Optional": "UUID"}},
{"name": "token", "ty": {"Optional": "String"}}
],
"returns": [],
"stream_response": null,
"description": "Test endpoint mirroring ReceiveUserInfo (code 211)",
"json_schema": null,
"roles": []
}"#;
}
/// Links [`HoneyReceiveUserInfoResponse`] back to the request type it answers.
impl WsResponse for HoneyReceiveUserInfoResponse {
    type Request = HoneyReceiveUserInfoRequest;
}
pub struct MethodEcho;
#[async_trait(?Send)]
impl RequestHandler for MethodEcho {
type Request = EchoRequest;
async fn handle(&self, ctx: RequestContext, req: EchoRequest) -> Result<EchoResponse> {
tracing::info!(
conn_id = %ctx.connection_id,
message = %req.message,
"Echo request received"
);
let response = EchoResponse {
message: format!("echo: {}", req.message),
};
tracing::info!(
conn_id = %ctx.connection_id,
response_message = %response.message,
"Echo response sent"
);
Ok(response)
}
}
pub struct MethodReceiveUserInfo;
#[async_trait(?Send)]
impl RequestHandler for MethodReceiveUserInfo {
type Request = HoneyReceiveUserInfoRequest;
async fn handle(
&self,
ctx: RequestContext,
req: HoneyReceiveUserInfoRequest,
) -> Result<HoneyReceiveUserInfoResponse> {
tracing::info!(
conn_id = %ctx.connection_id,
user_pub_id = %req.user_pub_id,
username = %req.username,
app_pub_id = ?req.app_pub_id,
has_token = req.token.is_some(),
"ReceiveUserInfo request received (test server — will reject)"
);
let msg = format!(
"Test passed: received ReceiveUserInfo for user '{}' (id: {}){} — \
this is a test server and will not process the request",
req.username,
req.user_pub_id,
req.app_pub_id
.map(|id| format!(", app: {id}"))
.unwrap_or_default(),
);
tracing::info!(
conn_id = %ctx.connection_id,
user_pub_id = %req.user_pub_id,
"Rejecting ReceiveUserInfo with BAD_REQUEST"
);
Err(CustomError::new(ErrorCode::BAD_REQUEST, msg).into())
}
}
/// Auth controller that accepts every connection and gives it role `1`.
///
/// The auth header is never inspected; only its length is logged.
struct AllowAllAuthController;

impl AuthController for AllowAllAuthController {
    /// Returns a future that logs the new connection, sets its roles to
    /// `[1]` and resolves to `Ok(())`.
    fn auth(
        self: Arc<Self>,
        _toolbox: &ArcToolbox,
        header: String,
        conn: Arc<WsConnection>,
    ) -> LocalBoxFuture<'static, Result<()>> {
        let grant_role = async move {
            tracing::info!(
                conn_id = %conn.connection_id,
                ip = %conn.address,
                header_len = header.len(),
                "New connection — granting role 1 (allow-all auth)"
            );

            conn.set_roles(Arc::new(vec![1]));
            tracing::info!(conn_id = %conn.connection_id, "Roles set successfully");

            Ok(())
        };
        grant_role.boxed_local()
    }
}
#[tokio::main]
async fn main() -> Result<()> {
let _log = setup_logging(LoggingConfig {
level: LogLevel::Debug,
otel_config: OtelConfig::default(),
file_config: None,
#[cfg(feature = "error_aggregation")]
error_aggregation: ErrorAggregationConfig {
limit: 100,
normalize: true,
},
#[cfg(feature = "log_throttling")]
throttling_config: None,
})?;
tracing::info!("Logging initialised at DEBUG level");
let acme_email =
std::env::var("ACME_EMAIL").unwrap_or_else(|_| "admin@example.com".to_string());
let cert_dir =
PathBuf::from(std::env::var("CERT_DIR").unwrap_or_else(|_| "/certs".to_string()));
let domains =
vec![std::env::var("DOMAIN").unwrap_or_else(|_| "ws-echo.example.com".to_string())];
let bunny_api_key = std::env::var("BUNNY_API_KEY").expect("BUNNY_API_KEY must be set");
let dns = BunnyDns::new(bunny_api_key);
#[cfg(not(feature = "s3-sync"))]
let mut provider: Box<dyn CertProvider> =
Box::new(DnsAcmeProvider::new(acme_email, dns).production());
#[cfg(feature = "s3-sync")]
let mut provider: Box<dyn CertProvider> = {
let sync = S3CertSync::new(S3Config {
bucket_name: std::env::var("CERT_S3_BUCKET")
.expect("CERT_S3_BUCKET must be set when s3-sync feature is enabled"),
endpoint: std::env::var("CERT_S3_ENDPOINT")
.expect("CERT_S3_ENDPOINT must be set when s3-sync feature is enabled"),
access_key: std::env::var("CERT_S3_ACCESS_KEY")
.expect("CERT_S3_ACCESS_KEY must be set when s3-sync feature is enabled"),
secret_key: std::env::var("CERT_S3_SECRET_KEY")
.expect("CERT_S3_SECRET_KEY must be set when s3-sync feature is enabled"),
region: std::env::var("CERT_S3_REGION").ok(),
prefix: Some("ws-echo".into()),
})?;
Box::new(S3CertProvider::new(
DnsAcmeProvider::new(acme_email, dns).production(),
sync,
))
};
let _cert_guard = provider.init(cert_dir.clone(), Some(domains)).await?;
let cert_path = cert_dir.join("fullchain.pem");
let key_path = cert_dir.join("privkey.pem");
let config = WsServerConfig {
name: "ws-echo".to_string(),
address: "0.0.0.0:443".to_string(),
insecure: false,
pub_certs: Some(vec![cert_path]),
priv_key: Some(key_path),
..Default::default()
};
tracing::info!(
name = %config.name,
address = %config.address,
"Starting WebSocket server"
);
let mut server = WebsocketServer::new(config);
server.set_auth_controller(AllowAllAuthController);
server.add_handler(MethodEcho);
server.add_handler(MethodReceiveUserInfo);
tracing::info!("Registered handlers: Echo (1), ReceiveUserInfo (211)");
server.listen().await
}