use std::sync::Arc;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use serde_json::json;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::mpsc;
use crate::api::WeComApiClient;
use crate::crypto_utils::decrypt_file;
use crate::logger::DefaultLogger;
use crate::message_handler::{MessageEvent, MessageHandler};
use crate::types::{
Logger, SdkError, WsCmd, WsFrame, WSClientOptions,
};
use crate::utils::generate_req_id;
use crate::ws::WsConnectionManager;
/// WebSocket client facade bundling the connection manager, the API client,
/// the frame-to-event message handler and the registered user callbacks.
///
/// Every field except `options` is wrapped in an `Arc`, so the derived `Clone`
/// produces another handle onto the same shared state rather than an
/// independent client.
#[derive(Clone)]
pub struct WSClient {
// Options the client was built from. No method visible in this file reads the
// field after `new`, which is why the dead-code lint is silenced.
#[allow(dead_code)]
options: WSClientOptions,
// Logger shared with the API client, connection manager and message handler.
logger: Arc<dyn Logger>,
// Set by `connect`; guards against establishing the connection twice.
started: Arc<RwLock<bool>>,
// Client used by `download_file`; also exposed through `api()`.
api_client: Arc<WeComApiClient>,
// Owns the WebSocket connection; all replies and uploads go through it.
ws_manager: Arc<WsConnectionManager>,
// Converts each received frame into the `MessageEvent`s that get dispatched.
message_handler: Arc<MessageHandler>,
// Callback registry filled by the `on_*` methods.
event_handlers: Arc<Mutex<EventHandlers>>,
}
/// Registry of user callbacks, one list per event kind.
///
/// Callbacks are stored as `Arc<dyn Fn…>`, so cloning the registry only bumps
/// reference counts. `Default` yields a registry in which every list is empty.
#[derive(Clone, Default)]
struct EventHandlers {
    // --- Connection lifecycle callbacks -----------------------------------
    /// Registered through `on_connected`.
    connected: Vec<Arc<dyn Fn() + Send + Sync>>,
    /// Registered through `on_authenticated`.
    authenticated: Vec<Arc<dyn Fn() + Send + Sync>>,
    /// Registered through `on_disconnected`; the callback receives a `&str`.
    disconnected: Vec<Arc<dyn Fn(&str) + Send + Sync>>,
    /// Registered through `on_reconnecting`; the callback receives a `usize`.
    reconnecting: Vec<Arc<dyn Fn(usize) + Send + Sync>>,
    /// Registered through `on_error`.
    error: Vec<Arc<dyn Fn(&SdkError) + Send + Sync>>,

    // --- Frame callbacks, dispatched per `MessageEvent` variant ------------
    /// Invoked for `MessageEvent::Message`.
    message: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::Text`.
    message_text: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::Image`.
    message_image: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::Mixed`.
    message_mixed: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::Voice`.
    message_voice: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::File`.
    message_file: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::Event`.
    event: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::EnterChat`.
    event_enter_chat: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::TemplateCardEvent`.
    event_template_card: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
    /// Invoked for `MessageEvent::FeedbackEvent`.
    event_feedback: Vec<Arc<dyn Fn(&WsFrame) + Send + Sync>>,
}
impl WSClient {
/// Builds a client from `options`.
///
/// Uses the logger supplied in `options.logger`, falling back to
/// `DefaultLogger::default()` when none is set. The API client, connection
/// manager and message handler are created here and share that logger. No
/// connection is opened; call `connect` for that.
pub fn new(options: WSClientOptions) -> Self {
    let logger: Arc<dyn Logger> = match options.logger.clone() {
        Some(custom) => custom,
        None => Arc::new(DefaultLogger::default()),
    };

    let api_client = Arc::new(WeComApiClient::new(
        Arc::clone(&logger),
        options.request_timeout,
    ));
    // The manager receives its own copy of the options; the original is kept on `Self`.
    let ws_manager = Arc::new(WsConnectionManager::new(
        options.clone(),
        Arc::clone(&logger),
    ));
    let message_handler = Arc::new(MessageHandler::new(Arc::clone(&logger)));

    let started = Arc::new(RwLock::new(false));
    let event_handlers = Arc::new(Mutex::new(EventHandlers::default()));

    Self {
        options,
        logger,
        started,
        api_client,
        ws_manager,
        message_handler,
        event_handlers,
    }
}
/// Opens the WebSocket connection and starts the background receive loop.
///
/// Calling this on a client that is already started logs a warning and
/// returns `Ok(())` without doing anything else.
///
/// # Errors
///
/// Propagates the error returned by the connection manager. In that case the
/// client is left in the "not started" state, so `connect` can be retried.
pub async fn connect(&self) -> Result<(), SdkError> {
    // The write guard is held for the whole attempt so that concurrent
    // `connect` calls are serialized and cannot both open a connection.
    let mut started = self.started.write().await;
    if *started {
        self.logger.warn("Client already connected");
        return Ok(());
    }

    self.logger.info("Establishing WebSocket connection...");
    self.ws_manager.connect().await?;

    // Mark the client as started only once the connection succeeded; setting
    // the flag earlier would make every retry after a failure a silent no-op.
    *started = true;
    self._start_receive_loop();
    Ok(())
}
fn _start_receive_loop(&self) {
let ws_manager = self.ws_manager.clone();
let event_handlers = self.event_handlers.clone();
let message_handler = self.message_handler.clone();
let logger = self.logger.clone();
tokio::spawn(async move {
let event_rx = ws_manager.get_event_receiver().await;
Self::_receive_loop(event_rx, event_handlers, message_handler, logger).await;
});
}
/// Pulls frames from `event_rx` until the receiver yields `None`, turns each
/// frame into `MessageEvent`s via `message_handler.handle_frame`, and invokes
/// the callbacks registered for each event, in registration order.
///
/// The callback registry is snapshotted and the mutex released *before* any
/// callback runs. User callbacks therefore never execute while the registry
/// lock is held, so a slow callback cannot stall `on_*` registrations and a
/// callback that (indirectly) waits on a registration cannot deadlock.
///
/// `_logger` is accepted for signature stability but is not used here.
async fn _receive_loop(
    mut event_rx: mpsc::UnboundedReceiver<WsFrame>,
    event_handlers: Arc<Mutex<EventHandlers>>,
    message_handler: Arc<MessageHandler>,
    _logger: Arc<dyn Logger>,
) {
    /// Calls every callback in `callbacks` with `frame`.
    fn dispatch(callbacks: &[Arc<dyn Fn(&WsFrame) + Send + Sync>], frame: &WsFrame) {
        for callback in callbacks {
            callback(frame);
        }
    }

    while let Some(frame) = event_rx.recv().await {
        for event in message_handler.handle_frame(&frame) {
            // Cloning the registry only copies `Arc` pointers; the guard is a
            // temporary and is dropped at the end of this statement.
            let snapshot = event_handlers.lock().await.clone();

            match event {
                MessageEvent::Message(f) => dispatch(&snapshot.message, &f),
                MessageEvent::Text(f) => dispatch(&snapshot.message_text, &f),
                MessageEvent::Image(f) => dispatch(&snapshot.message_image, &f),
                MessageEvent::Mixed(f) => dispatch(&snapshot.message_mixed, &f),
                MessageEvent::Voice(f) => dispatch(&snapshot.message_voice, &f),
                MessageEvent::File(f) => dispatch(&snapshot.message_file, &f),
                MessageEvent::Event(f) => dispatch(&snapshot.event, &f),
                MessageEvent::EnterChat(f) => dispatch(&snapshot.event_enter_chat, &f),
                MessageEvent::TemplateCardEvent(f) => {
                    dispatch(&snapshot.event_template_card, &f)
                }
                MessageEvent::FeedbackEvent(f) => dispatch(&snapshot.event_feedback, &f),
            }
        }
    }
}
/// Asks the connection manager to disconnect and logs the fact.
///
/// NOTE(review): this does not reset the `started` flag set by `connect`, so
/// a later `connect` call on the same client returns early — confirm whether
/// reconnecting after an explicit disconnect is meant to be supported.
pub fn disconnect(&self) {
    let manager = &self.ws_manager;
    manager.disconnect();

    let log = &self.logger;
    log.info("Disconnected");
}
/// Sends `body` as a reply to `frame`, reusing the frame's `req_id`.
///
/// `cmd` selects the command to send; `None` means `WsCmd::RESPONSE`.
/// Returns whatever the connection manager's `send_reply` returns.
pub async fn reply(
    &self,
    frame: &WsFrame,
    body: serde_json::Value,
    cmd: Option<&str>,
) -> Result<WsFrame, SdkError> {
    let command = match cmd {
        Some(explicit) => explicit,
        None => WsCmd::RESPONSE,
    };
    self.ws_manager
        .send_reply(&frame.headers.req_id, body, command)
        .await
}
/// Replies to `frame` with a `"stream"` message.
///
/// The `stream` object always carries `id`, `finish` and `content`.
/// `msg_item` is attached only when `finish` is `true` and the list is
/// non-empty; `feedback` is attached whenever it is provided. The reply is
/// sent with the default command of `reply`.
pub async fn reply_stream(
    &self,
    frame: &WsFrame,
    stream_id: &str,
    content: &str,
    finish: bool,
    msg_item: Option<Vec<serde_json::Value>>,
    feedback: Option<serde_json::Value>,
) -> Result<WsFrame, SdkError> {
    let mut stream = serde_json::Map::new();
    stream.insert("id".to_string(), json!(stream_id));
    stream.insert("finish".to_string(), json!(finish));
    stream.insert("content".to_string(), json!(content));

    // Items are dropped for intermediate chunks and for empty lists.
    if let Some(items) = msg_item.filter(|items| finish && !items.is_empty()) {
        stream.insert("msg_item".to_string(), serde_json::Value::Array(items));
    }
    if let Some(fb) = feedback {
        stream.insert("feedback".to_string(), fb);
    }

    let body = json!({
        "msgtype": "stream",
        "stream": serde_json::Value::Object(stream),
    });
    self.reply(frame, body, None).await
}
/// Replies to `frame` with `body`, using the `WsCmd::RESPONSE_WELCOME`
/// command instead of the default one.
pub async fn reply_welcome(
    &self,
    frame: &WsFrame,
    body: serde_json::Value,
) -> Result<WsFrame, SdkError> {
    let welcome_cmd = Some(WsCmd::RESPONSE_WELCOME);
    self.reply(frame, body, welcome_cmd).await
}
/// Replies to `frame` with a `"template_card"` message.
///
/// Without `feedback`, `template_card` is sent unchanged. With `feedback`,
/// it is stored under the card's `"feedback"` key; a `template_card` that is
/// not a JSON object is replaced by an object holding only that key.
pub async fn reply_template_card(
    &self,
    frame: &WsFrame,
    template_card: serde_json::Value,
    feedback: Option<serde_json::Value>,
) -> Result<WsFrame, SdkError> {
    let card = match feedback {
        None => template_card,
        Some(fb) => {
            let mut fields = match template_card {
                serde_json::Value::Object(map) => map,
                _ => serde_json::Map::new(),
            };
            fields.insert("feedback".to_string(), fb);
            serde_json::Value::Object(fields)
        }
    };

    self.reply(
        frame,
        json!({
            "msgtype": "template_card",
            "template_card": card,
        }),
        None,
    )
    .await
}
/// Replies to `frame` with a `"stream_with_template_card"` message.
///
/// The `stream` part follows the same rules as `reply_stream`: `msg_item` is
/// included only when `finish` is `true` and the list is non-empty, and
/// `stream_feedback` is included whenever present.
///
/// When `template_card` is given it is sent as a JSON object (a non-object
/// value becomes an empty object), with `card_feedback` stored under its
/// `"feedback"` key if provided. `card_feedback` is ignored when there is no
/// `template_card`.
pub async fn reply_stream_with_card(
    &self,
    frame: &WsFrame,
    stream_id: &str,
    content: &str,
    finish: bool,
    msg_item: Option<Vec<serde_json::Value>>,
    stream_feedback: Option<serde_json::Value>,
    template_card: Option<serde_json::Value>,
    card_feedback: Option<serde_json::Value>,
) -> Result<WsFrame, SdkError> {
    let mut stream = serde_json::Map::new();
    stream.insert("id".to_string(), json!(stream_id));
    stream.insert("finish".to_string(), json!(finish));
    stream.insert("content".to_string(), json!(content));
    if let Some(items) = msg_item.filter(|items| finish && !items.is_empty()) {
        stream.insert("msg_item".to_string(), serde_json::Value::Array(items));
    }
    if let Some(fb) = stream_feedback {
        stream.insert("feedback".to_string(), fb);
    }

    let mut body = serde_json::Map::new();
    body.insert(
        "msgtype".to_string(),
        json!("stream_with_template_card"),
    );
    body.insert("stream".to_string(), serde_json::Value::Object(stream));

    if let Some(card) = template_card {
        let mut fields = match card {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        if let Some(fb) = card_feedback {
            fields.insert("feedback".to_string(), fb);
        }
        body.insert(
            "template_card".to_string(),
            serde_json::Value::Object(fields),
        );
    }

    self.reply(frame, serde_json::Value::Object(body), None).await
}
/// Replies to `frame` with an `"update_template_card"` response, sent with
/// the `WsCmd::RESPONSE_UPDATE` command.
///
/// `userids` is included in the body only when it is present and non-empty.
pub async fn update_template_card(
    &self,
    frame: &WsFrame,
    template_card: serde_json::Value,
    userids: Option<Vec<String>>,
) -> Result<WsFrame, SdkError> {
    let mut payload = serde_json::Map::new();
    payload.insert(
        "response_type".to_string(),
        json!("update_template_card"),
    );
    payload.insert("template_card".to_string(), template_card);

    match userids {
        Some(ids) if !ids.is_empty() => {
            payload.insert("userids".to_string(), json!(ids));
        }
        _ => {}
    }

    let update_cmd = Some(WsCmd::RESPONSE_UPDATE);
    self.reply(frame, serde_json::Value::Object(payload), update_cmd)
        .await
}
/// Sends `body` to the chat identified by `chatid` using the
/// `WsCmd::SEND_MSG` command and a freshly generated request id.
///
/// `chatid` is written into the body under the `"chatid"` key, replacing any
/// existing value. A `body` that is not a JSON object is discarded and only
/// the `chatid` is sent.
pub async fn send_message(
    &self,
    chatid: &str,
    body: serde_json::Value,
) -> Result<WsFrame, SdkError> {
    let req_id = generate_req_id(WsCmd::SEND_MSG);

    let mut payload = match body {
        serde_json::Value::Object(map) => map,
        _ => serde_json::Map::new(),
    };
    payload.insert("chatid".to_string(), json!(chatid));

    let outgoing = serde_json::Value::Object(payload);
    self.ws_manager
        .send_reply(&req_id, outgoing, WsCmd::SEND_MSG)
        .await
}
/// Downloads the file at `url` and, when `aes_key` is given, decrypts it.
///
/// Returns the file bytes together with the filename reported by the API
/// client (if any). Without an `aes_key` the downloaded bytes are returned
/// as-is and a warning is logged.
///
/// # Errors
///
/// Propagates failures from the download and from `decrypt_file`.
pub async fn download_file(
    &self,
    url: &str,
    aes_key: Option<&str>,
) -> Result<(Vec<u8>, Option<String>), SdkError> {
    self.logger.info("Downloading and decrypting file...");
    let (payload, filename) = self.api_client.download_file_raw(url).await?;

    match aes_key {
        None => {
            self.logger.warn("No aes_key provided, returning raw file data");
            Ok((payload, filename))
        }
        Some(key) => {
            let plain = decrypt_file(&payload, key)?;
            self.logger.info("File downloaded and decrypted successfully");
            Ok((plain, filename))
        }
    }
}
/// Reports the connection state as seen by the connection manager.
pub fn is_connected(&self) -> bool {
    let manager = &self.ws_manager;
    manager.is_connected()
}
/// Returns another shared handle to the client's `WeComApiClient`.
pub fn api(&self) -> Arc<WeComApiClient> {
    Arc::clone(&self.api_client)
}
pub async fn upload_media(
&self,
media_type: &str,
file_data: &[u8],
filename: &str,
) -> Result<serde_json::Value, SdkError> {
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
self.logger.info(&format!(
"Uploading media file: {}, type: {}, size: {} bytes",
filename, media_type, file_data.len()
));
let mut hasher = DefaultHasher::new();
file_data.hash(&mut hasher);
let md5 = format!("{:x}", hasher.finish());
const CHUNK_SIZE: usize = 6000;
let total_size = file_data.len();
let total_chunks = (total_size + CHUNK_SIZE - 1) / CHUNK_SIZE;
let upload_id = self
.ws_manager
.upload_media_init(media_type, filename, total_size, total_chunks, &md5)
.await?;
self.logger.debug(&format!("Upload initialized, upload_id: {}", upload_id));
for (index, chunk) in file_data.chunks(CHUNK_SIZE).enumerate() {
let base64_data = BASE64.encode(chunk);
self.ws_manager
.upload_media_chunk(&upload_id, index + 1, base64_data)
.await?;
}
self.logger.debug("All chunks uploaded");
let result = self.ws_manager.upload_media_finish(&upload_id).await?;
self.logger.info(&format!(
"Media upload completed, result: {:?}",
result
));
Ok(result)
}
/// Appends `handler` to the list of `connected` callbacks.
///
/// NOTE(review): the receive loop in this file only dispatches frame-derived
/// events; confirm where `connected` callbacks are actually invoked.
pub async fn on_connected<F>(&self, handler: F)
where
    F: Fn() + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.connected.push(Arc::new(handler));
}
/// Appends `handler` to the list of `authenticated` callbacks.
///
/// NOTE(review): the receive loop in this file only dispatches frame-derived
/// events; confirm where `authenticated` callbacks are actually invoked.
pub async fn on_authenticated<F>(&self, handler: F)
where
    F: Fn() + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.authenticated.push(Arc::new(handler));
}
/// Appends `handler` to the list of `disconnected` callbacks. The callback
/// takes a `&str` argument (presumably the disconnect reason — confirm
/// against the code that fires it).
pub async fn on_disconnected<F>(&self, handler: F)
where
    F: Fn(&str) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.disconnected.push(Arc::new(handler));
}
/// Appends `handler` to the list of `reconnecting` callbacks. The callback
/// takes a `usize` argument (presumably the attempt number — confirm against
/// the code that fires it).
pub async fn on_reconnecting<F>(&self, handler: F)
where
    F: Fn(usize) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.reconnecting.push(Arc::new(handler));
}
/// Appends `handler` to the list of `error` callbacks, which receive a
/// reference to the `SdkError`.
///
/// NOTE(review): the receive loop in this file only dispatches frame-derived
/// events; confirm where `error` callbacks are actually invoked.
pub async fn on_error<F>(&self, handler: F)
where
    F: Fn(&SdkError) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.error.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Message` produced by the receive loop.
pub async fn on_message<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Text` produced by the receive loop.
pub async fn on_message_text<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message_text.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Image` produced by the receive loop.
pub async fn on_message_image<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message_image.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Mixed` produced by the receive loop.
pub async fn on_message_mixed<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message_mixed.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Voice` produced by the receive loop.
pub async fn on_message_voice<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message_voice.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::File` produced by the receive loop.
pub async fn on_message_file<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.message_file.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::Event` produced by the receive loop.
pub async fn on_event<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.event.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::EnterChat` produced by the receive loop.
pub async fn on_event_enter_chat<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.event_enter_chat.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::TemplateCardEvent` produced by the receive loop.
pub async fn on_event_template_card<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.event_template_card.push(Arc::new(handler));
}
/// Registers `handler` to be called with the frame of every
/// `MessageEvent::FeedbackEvent` produced by the receive loop.
pub async fn on_event_feedback<F>(&self, handler: F)
where
    F: Fn(&WsFrame) + Send + Sync + 'static,
{
    let mut registry = self.event_handlers.lock().await;
    registry.event_feedback.push(Arc::new(handler));
}
}