use crate::protocol::message::{
AllSchemasMsg, ClientMessage, ConfigMsg, DownloadEvent, DownloadInfo, DownloadList, ErrorMsg,
HubAction, Pong, SendMsg, SendResponse, ServerMessage, ServiceInfoMsg, ServiceListMsg,
ServiceQueryResultMsg, ServiceSchemaMsg, SessionInfo, SessionList, StreamEvent, StreamMsg,
TaskEvent, TaskInfo, TaskList, client_message, server_message,
};
use anyhow::Result;
use futures_core::Stream;
use futures_util::StreamExt;
fn server_error(code: u32, message: String) -> ServerMessage {
ServerMessage {
msg: Some(server_message::Msg::Error(ErrorMsg { code, message })),
}
}
/// Build the empty `Pong` reply.
///
/// Besides answering pings, `dispatch` also sends it as the success acknowledgement for
/// side-effecting requests.
fn server_pong() -> ServerMessage {
    let pong = server_message::Msg::Pong(Pong {});
    ServerMessage { msg: Some(pong) }
}
/// Turn an operation result into a single reply.
///
/// Success values are converted through their `Into<ServerMessage>` impl. Any error becomes
/// a `500` error reply carrying the error's `Display` text.
fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
    result.map_or_else(|err| server_error(500, err.to_string()), Into::into)
}
pub trait Server: Sync {
fn send(&self, req: SendMsg) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
fn stream(&self, req: StreamMsg) -> impl Stream<Item = Result<StreamEvent>> + Send;
fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
fn hub(
&self,
package: String,
action: HubAction,
filters: Vec<String>,
) -> impl Stream<Item = Result<DownloadEvent>> + Send;
fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
fn approve_task(
&self,
task_id: u64,
response: String,
) -> impl std::future::Future<Output = Result<bool>> + Send;
fn subscribe_tasks(&self) -> impl Stream<Item = Result<TaskEvent>> + Send;
fn list_downloads(&self)
-> impl std::future::Future<Output = Result<Vec<DownloadInfo>>> + Send;
fn subscribe_downloads(&self) -> impl Stream<Item = Result<DownloadEvent>> + Send;
fn get_config(&self) -> impl std::future::Future<Output = Result<String>> + Send;
fn set_config(&self, config: String) -> impl std::future::Future<Output = Result<()>> + Send;
fn service_query(
&self,
service: String,
query: String,
) -> impl std::future::Future<Output = Result<String>> + Send;
fn get_service_schema(
&self,
service: String,
) -> impl std::future::Future<Output = Result<String>> + Send;
fn get_all_schemas(
&self,
) -> impl std::future::Future<Output = Result<std::collections::HashMap<String, String>>> + Send;
fn list_services(
&self,
) -> impl std::future::Future<Output = Result<Vec<ServiceInfoMsg>>> + Send;
fn set_service_config(
&self,
service: String,
config: String,
) -> impl std::future::Future<Output = Result<()>> + Send;
fn reload(&self) -> impl std::future::Future<Output = Result<()>> + Send;
fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
async_stream::stream! {
let Some(inner) = msg.msg else {
yield server_error(400, "empty client message".to_string());
return;
};
match inner {
client_message::Msg::Send(send_msg) => {
yield result_to_msg(self.send(send_msg).await);
}
client_message::Msg::Stream(stream_msg) => {
let s = self.stream(stream_msg);
tokio::pin!(s);
while let Some(result) = s.next().await {
yield result_to_msg(result);
}
}
client_message::Msg::Ping(_) => {
yield match self.ping().await {
Ok(()) => server_pong(),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::Hub(hub_msg) => {
let action = hub_msg.action();
let s = self.hub(hub_msg.package, action, hub_msg.filters);
tokio::pin!(s);
while let Some(result) = s.next().await {
yield result_to_msg(result);
}
}
client_message::Msg::Sessions(_) => {
yield match self.list_sessions().await {
Ok(sessions) => ServerMessage {
msg: Some(server_message::Msg::Sessions(SessionList { sessions })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::Kill(kill_msg) => {
yield match self.kill_session(kill_msg.session).await {
Ok(true) => server_pong(),
Ok(false) => server_error(
404,
format!("session {} not found", kill_msg.session),
),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::Tasks(_) => {
yield match self.list_tasks().await {
Ok(tasks) => ServerMessage {
msg: Some(server_message::Msg::Tasks(TaskList { tasks })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::KillTask(kill_task_msg) => {
yield match self.kill_task(kill_task_msg.task_id).await {
Ok(true) => server_pong(),
Ok(false) => server_error(
404,
format!("task {} not found", kill_task_msg.task_id),
),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::Approve(approve_msg) => {
yield match self.approve_task(approve_msg.task_id, approve_msg.response).await {
Ok(true) => server_pong(),
Ok(false) => server_error(
404,
format!("task {} not found or not blocked", approve_msg.task_id),
),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::SubscribeTasks(_) => {
let s = self.subscribe_tasks();
tokio::pin!(s);
while let Some(result) = s.next().await {
yield result_to_msg(result);
}
}
client_message::Msg::Downloads(_) => {
yield match self.list_downloads().await {
Ok(downloads) => ServerMessage {
msg: Some(server_message::Msg::Downloads(DownloadList { downloads })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::SubscribeDownloads(_) => {
let s = self.subscribe_downloads();
tokio::pin!(s);
while let Some(result) = s.next().await {
yield result_to_msg(result);
}
}
client_message::Msg::GetConfig(_) => {
yield match self.get_config().await {
Ok(config) => ServerMessage {
msg: Some(server_message::Msg::Config(ConfigMsg { config })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::SetConfig(set_config_msg) => {
yield match self.set_config(set_config_msg.config).await {
Ok(()) => server_pong(),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::ServiceQuery(sq) => {
yield match self.service_query(sq.service, sq.query).await {
Ok(result) => ServerMessage {
msg: Some(server_message::Msg::ServiceQueryResult(
ServiceQueryResultMsg { result },
)),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::GetServiceSchema(req) => {
let service = req.service;
yield match self.get_service_schema(service.clone()).await {
Ok(schema) => ServerMessage {
msg: Some(server_message::Msg::ServiceSchema(ServiceSchemaMsg {
service,
schema,
})),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::GetAllSchemas(_) => {
yield match self.get_all_schemas().await {
Ok(schemas) => ServerMessage {
msg: Some(server_message::Msg::AllSchemas(AllSchemasMsg { schemas })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::GetServices(_) => {
yield match self.list_services().await {
Ok(services) => ServerMessage {
msg: Some(server_message::Msg::ServiceList(ServiceListMsg { services })),
},
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::SetServiceConfig(req) => {
yield match self.set_service_config(req.service, req.config).await {
Ok(()) => server_pong(),
Err(e) => server_error(500, e.to_string()),
};
}
client_message::Msg::Reload(_) => {
yield match self.reload().await {
Ok(()) => server_pong(),
Err(e) => server_error(500, e.to_string()),
};
}
}
}
}
}