use std::sync::Arc;
use async_trait::async_trait;
use endpoint_libs::libs::error_code::ErrorCode;
use endpoint_libs::libs::handler::RequestHandler;
#[cfg(feature = "error_aggregation")]
use endpoint_libs::libs::log::error_aggregation::ErrorAggregationConfig;
use endpoint_libs::libs::log::{LogLevel, LoggingConfig, OtelConfig, setup_logging};
use endpoint_libs::libs::toolbox::{ArcToolbox, RequestContext};
use endpoint_libs::libs::ws::toolbox::CustomError;
use endpoint_libs::libs::ws::{
AuthController, WebsocketServer, WsConnection, WsRequest, WsResponse, WsServerConfig,
};
use eyre::Result;
use futures::FutureExt;
use futures::future::LocalBoxFuture;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Request payload for the Echo method (`METHOD_ID` 1).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EchoRequest {
/// Text the client wants echoed back.
pub message: String,
}
/// Response payload for the Echo method.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EchoResponse {
/// Text returned to the client; `MethodEcho` fills it with the request
/// message prefixed by `echo: `.
pub message: String,
}
/// Endpoint metadata for the Echo method: response type, method id, roles and schema.
impl WsRequest for EchoRequest {
/// Response type paired with this request.
type Response = EchoResponse;
/// Numeric method identifier; the same value appears as `"code"` in `SCHEMA`.
const METHOD_ID: u32 = 1;
// NOTE(review): `ROLES` lists role 1 while the `"roles"` array inside `SCHEMA` is empty.
// Presumably only `ROLES` is consulted for access checks — confirm against `endpoint_libs`.
const ROLES: &'static [u32] = &[1];
/// JSON description of the endpoint: one `String` parameter (`message`) and one
/// `String` return value (`message`).
const SCHEMA: &'static str = r#"{
"name": "Echo",
"code": 1,
"parameters": [{"name": "message", "ty": "String"}],
"returns": [{"name": "message", "ty": "String"}],
"roles": []
}"#;
}
/// Links `EchoResponse` back to the request type it answers.
impl WsResponse for EchoResponse {
type Request = EchoRequest;
}
/// Request payload for the ReceiveUserInfo test endpoint (`METHOD_ID` 211).
///
/// Field names are serialized and deserialized in camelCase
/// (`userPubId`, `username`, `appPubId`, `token`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct HoneyReceiveUserInfoRequest {
/// UUID identifying the user (`userPubId` on the wire).
pub user_pub_id: Uuid,
/// User name supplied by the caller.
pub username: String,
/// Optional application UUID (`appPubId` on the wire); `None` when the field is absent.
#[serde(default)]
pub app_pub_id: Option<Uuid>,
/// Optional token; `None` when the field is absent. The test handler only logs
/// whether it is present, never its value.
#[serde(default)]
pub token: Option<String>,
}
/// Empty response body for ReceiveUserInfo.
///
/// The test handler `MethodReceiveUserInfo` never produces this value: it always
/// returns an error instead.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HoneyReceiveUserInfoResponse {}
/// Endpoint metadata for the ReceiveUserInfo test method: response type, method id,
/// roles and schema.
impl WsRequest for HoneyReceiveUserInfoRequest {
/// Response type paired with this request.
type Response = HoneyReceiveUserInfoResponse;
/// Numeric method identifier; the same value appears as `"code"` in `SCHEMA`.
const METHOD_ID: u32 = 211;
// NOTE(review): `ROLES` lists role 1 while the `"roles"` array inside `SCHEMA` is empty.
// Presumably only `ROLES` is consulted for access checks — confirm against `endpoint_libs`.
const ROLES: &'static [u32] = &[1];
/// JSON description of the endpoint. The parameter names use the same camelCase
/// spelling that the request struct's serde attributes produce; `returns` is empty.
const SCHEMA: &'static str = r#"{
"name": "ReceiveUserInfo",
"code": 211,
"parameters": [
{"name": "userPubId", "ty": "UUID"},
{"name": "username", "ty": "String"},
{"name": "appPubId", "ty": {"Optional": "UUID"}},
{"name": "token", "ty": {"Optional": "String"}}
],
"returns": [],
"stream_response": null,
"description": "Test endpoint mirroring ReceiveUserInfo (code 211)",
"json_schema": null,
"roles": []
}"#;
}
/// Links `HoneyReceiveUserInfoResponse` back to the request type it answers.
impl WsResponse for HoneyReceiveUserInfoResponse {
type Request = HoneyReceiveUserInfoRequest;
}
pub struct MethodEcho;
#[async_trait(?Send)]
impl RequestHandler for MethodEcho {
type Request = EchoRequest;
async fn handle(&self, ctx: RequestContext, req: EchoRequest) -> Result<EchoResponse> {
tracing::info!(
conn_id = %ctx.connection_id,
message = %req.message,
"Echo request received"
);
let response = EchoResponse {
message: format!("echo: {}", req.message),
};
tracing::info!(
conn_id = %ctx.connection_id,
response_message = %response.message,
"Echo response sent"
);
Ok(response)
}
}
pub struct MethodReceiveUserInfo;
#[async_trait(?Send)]
impl RequestHandler for MethodReceiveUserInfo {
type Request = HoneyReceiveUserInfoRequest;
async fn handle(
&self,
ctx: RequestContext,
req: HoneyReceiveUserInfoRequest,
) -> Result<HoneyReceiveUserInfoResponse> {
tracing::info!(
conn_id = %ctx.connection_id,
user_pub_id = %req.user_pub_id,
username = %req.username,
app_pub_id = ?req.app_pub_id,
has_token = req.token.is_some(),
"ReceiveUserInfo request received (test server — will reject)"
);
let msg = format!(
"Test passed: received ReceiveUserInfo for user '{}' (id: {}){} — \
this is a test server and will not process the request",
req.username,
req.user_pub_id,
req.app_pub_id
.map(|id| format!(", app: {id}"))
.unwrap_or_default(),
);
tracing::info!(
conn_id = %ctx.connection_id,
user_pub_id = %req.user_pub_id,
"Rejecting ReceiveUserInfo with BAD_REQUEST"
);
Err(CustomError::new(ErrorCode::BAD_REQUEST, msg).into())
}
}
/// Auth controller that accepts every connection and assigns it role `1`.
struct AllowAllAuthController;

impl AuthController for AllowAllAuthController {
    /// Returns a future that logs the new connection, sets its roles to `[1]`
    /// and resolves to `Ok(())`. The header is not inspected; only its length
    /// is logged.
    fn auth(
        self: Arc<Self>,
        _toolbox: &ArcToolbox,
        header: String,
        conn: Arc<WsConnection>,
    ) -> LocalBoxFuture<'static, Result<()>> {
        let grant_roles = async move {
            tracing::info!(
                conn_id = %conn.connection_id,
                ip = %conn.address,
                header_len = header.len(),
                "New connection — granting role 1 (allow-all auth)"
            );

            conn.set_roles(Arc::new(vec![1]));

            tracing::info!(conn_id = %conn.connection_id, "Roles set successfully");
            Ok(())
        };

        grant_roles.boxed_local()
    }
}
/// Entry point: sets up DEBUG logging, generates a throw-away self-signed TLS
/// certificate, and serves the Echo and ReceiveUserInfo handlers over TLS on
/// `0.0.0.0:8443`.
#[tokio::main]
async fn main() -> Result<()> {
    let logging = LoggingConfig {
        level: LogLevel::Debug,
        otel_config: OtelConfig::default(),
        file_config: None,
        #[cfg(feature = "error_aggregation")]
        error_aggregation: ErrorAggregationConfig {
            limit: 100,
            normalize: true,
        },
        #[cfg(feature = "log_throttling")]
        throttling_config: None,
    };
    // Bound to a named variable (not `_`) so the value returned by
    // `setup_logging` stays alive until `main` returns.
    let _logging_guard = setup_logging(logging)?;
    tracing::info!("Logging initialised at DEBUG level");

    // `tls_dir` stays bound for the rest of `main`, so the generated PEM files
    // remain on disk while the server runs.
    let tls_dir = tempfile::tempdir()?;
    let signing_key = rcgen::KeyPair::generate()?;
    let subject_alt_names = vec![String::from("localhost"), String::from("127.0.0.1")];
    let cert_params = rcgen::CertificateParams::new(subject_alt_names)?;
    let certificate = cert_params.self_signed(&signing_key)?;

    let cert_file = tls_dir.path().join("cert.pem");
    let key_file = tls_dir.path().join("key.pem");
    std::fs::write(&cert_file, certificate.pem())?;
    std::fs::write(&key_file, signing_key.serialize_pem())?;
    tracing::info!("Self-signed TLS certificate generated");

    let server_config = WsServerConfig {
        name: String::from("ws-echo"),
        address: String::from("0.0.0.0:8443"),
        insecure: false,
        pub_certs: Some(vec![cert_file]),
        priv_key: Some(key_file),
        ..Default::default()
    };
    tracing::info!(
        name = %server_config.name,
        address = %server_config.address,
        "Starting WebSocket server (TLS)"
    );

    let mut ws_server = WebsocketServer::new(server_config);
    ws_server.set_auth_controller(AllowAllAuthController);
    ws_server.add_handler(MethodEcho);
    ws_server.add_handler(MethodReceiveUserInfo);
    tracing::info!("Registered handlers: Echo (1), ReceiveUserInfo (211)");

    ws_server.listen().await
}