use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, sleep};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use url::Url;

use crate::crypto::{decrypt_data, encrypt_data, unsalt_key};
use crate::error::{Result, SdkError};
15
16
/// Connection settings for a [`Worker`].
#[derive(Clone, Debug)]
pub struct Config {
    /// WebSocket URL of the scheduler endpoint; parsed as a URL and dialled on each connect.
    pub scheduler_url: String,
    /// Group name the worker registers under (also used in log messages).
    pub worker_group: String,
    /// NOTE(review): not consulted by the worker's reconnect loop, which retries
    /// indefinitely; confirm the intended semantics before wiring it in.
    pub max_retry: usize,
    /// Heartbeat period, in seconds.
    pub ping_interval: u64,
}
29
30pub type MethodHandler = Box<dyn Fn(Value) -> Result<Value> + Send + Sync>;
32
33#[derive(Clone)]
35pub struct Method {
36 pub name: String,
37 pub handler: Arc<MethodHandler>,
38 pub docs: Vec<String>,
39}
40
41#[derive(Deserialize, Debug)]
43#[serde(tag = "type")]
44enum IncomingMessage {
45 #[serde(rename = "task")]
46 Task {
47 #[serde(rename = "taskId")]
48 task_id: String,
49 method: String,
50 params: Value,
51 },
52 #[serde(rename = "encrypted_task")]
53 EncryptedTask {
54 #[serde(rename = "taskId")]
55 task_id: String,
56 method: String,
57 params: Value,
58 key: String,
59 crypto: String,
60 },
61 #[serde(rename = "ping")]
62 Ping,
63}
64
65#[derive(Serialize, Debug)]
66#[serde(tag = "type")]
67enum OutgoingMessage {
68 #[serde(rename = "result")]
69 Result {
70 #[serde(rename = "taskId")]
71 task_id: String,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 result: Option<Value>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 error: Option<String>,
76 },
77 #[serde(rename = "pong")]
78 Pong,
79}
80
81#[derive(Serialize, Debug)]
82struct RegistrationMessage {
83 group: String,
84 methods: Vec<MethodInfo>,
85}
86
87#[derive(Serialize, Debug)]
88struct MethodInfo {
89 name: String,
90 docs: Vec<String>,
91}
92
93pub struct Worker {
95 config: Config,
96 methods: Arc<RwLock<HashMap<String, Method>>>,
97 running: Arc<RwLock<bool>>,
98 shutdown_tx: Option<mpsc::Sender<()>>,
99}
100
101impl Worker {
102 pub fn new(config: Config) -> Self {
123 Self {
124 config,
125 methods: Arc::new(RwLock::new(HashMap::new())),
126 running: Arc::new(RwLock::new(false)),
127 shutdown_tx: None,
128 }
129 }
130
131 pub fn register_method<F>(&mut self, name: impl Into<String>, handler: F, docs: Vec<String>)
160 where
161 F: Fn(Value) -> Result<Value> + Send + Sync + 'static,
162 {
163 let method = Method {
164 name: name.into(),
165 handler: Arc::new(Box::new(handler)),
166 docs,
167 };
168
169 let methods = self.methods.clone();
171 tokio::spawn(async move {
172 let mut methods_guard = methods.write().await;
173 methods_guard.insert(method.name.clone(), method);
174 });
175 }
176
177 pub async fn start(&mut self) -> Result<()> {
205 *self.running.write().await = true;
206
207 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
208 self.shutdown_tx = Some(shutdown_tx);
209
210 log::info!("Worker {} starting", self.config.worker_group);
211
212 loop {
213 if !*self.running.read().await {
215 break;
216 }
217
218 match self.connect_and_run(&mut shutdown_rx).await {
220 Ok(_) => {
221 log::info!("Worker connection closed normally");
222 break;
223 }
224 Err(e) => {
225 log::error!("Worker connection failed: {}", e);
226 if *self.running.read().await {
227 log::info!("Retrying connection in 5 seconds...");
228 sleep(Duration::from_secs(5)).await;
229 }
230 }
231 }
232 }
233
234 log::info!("Worker {} stopped", self.config.worker_group);
235 Ok(())
236 }
237
238 pub async fn stop(&mut self) {
263 *self.running.write().await = false;
264
265 if let Some(tx) = &self.shutdown_tx {
266 let _ = tx.send(()).await;
267 }
268 }
269
270 async fn connect_and_run(&self, shutdown_rx: &mut mpsc::Receiver<()>) -> Result<()> {
271 let url = Url::parse(&self.config.scheduler_url)?;
272 let (ws_stream, _) = connect_async(url).await?;
273 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
274
275 let methods = self.get_methods_info().await;
277 let registration = RegistrationMessage {
278 group: self.config.worker_group.clone(),
279 methods,
280 };
281
282 let registration_msg = serde_json::to_string(®istration)?;
283 ws_sender.send(Message::Text(registration_msg)).await?;
284
285 log::info!("Worker {} connected and registered", self.config.worker_group);
286
287 let mut ping_interval = interval(Duration::from_secs(self.config.ping_interval));
289
290 loop {
291 tokio::select! {
292 _ = shutdown_rx.recv() => {
294 log::info!("Received shutdown signal");
295 let _ = ws_sender.close().await;
296 break;
297 }
298
299 _ = ping_interval.tick() => {
301 let ping_msg = serde_json::to_string(&OutgoingMessage::Pong)?;
302 if let Err(e) = ws_sender.send(Message::Text(ping_msg)).await {
303 log::error!("Failed to send ping: {}", e);
304 break;
305 }
306 }
307
308 msg = ws_receiver.next() => {
310 match msg {
311 Some(Ok(Message::Text(text))) => {
312 if let Err(e) = self.handle_message(&text, &mut ws_sender).await {
313 log::error!("Error handling message: {}", e);
314 }
315 }
316 Some(Ok(Message::Close(_))) => {
317 log::info!("WebSocket connection closed by server");
318 break;
319 }
320 Some(Err(e)) => {
321 log::error!("WebSocket error: {}", e);
322 break;
323 }
324 None => {
325 log::info!("WebSocket stream ended");
326 break;
327 }
328 _ => {}
329 }
330 }
331 }
332 }
333
334 Ok(())
335 }
336
337 async fn handle_message(
338 &self,
339 text: &str,
340 ws_sender: &mut futures_util::stream::SplitSink<
341 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
342 Message,
343 >,
344 ) -> Result<()> {
345 let message: IncomingMessage = serde_json::from_str(text)?;
346
347 match message {
348 IncomingMessage::Task { task_id, method, params } => {
349 self.handle_task(task_id, method, params, ws_sender).await?
350 }
351 IncomingMessage::EncryptedTask { task_id, method, params, key, crypto } => {
352 self.handle_encrypted_task(task_id, method, params, key, crypto, ws_sender).await?
353 }
354 IncomingMessage::Ping => {
355 let pong_msg = serde_json::to_string(&OutgoingMessage::Pong)?;
356 ws_sender.send(Message::Text(pong_msg)).await?;
357 }
358 }
359
360 Ok(())
361 }
362
363 async fn handle_task(
364 &self,
365 task_id: String,
366 method: String,
367 params: Value,
368 ws_sender: &mut futures_util::stream::SplitSink<
369 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
370 Message,
371 >,
372 ) -> Result<()> {
373 let methods = self.methods.read().await;
374 let method_handler = methods.get(&method).cloned();
375 drop(methods);
376
377 let response = match method_handler {
378 Some(handler) => {
379 match (handler.handler)(params) {
380 Ok(result) => OutgoingMessage::Result {
381 task_id,
382 result: Some(result),
383 error: None,
384 },
385 Err(e) => OutgoingMessage::Result {
386 task_id,
387 result: None,
388 error: Some(e.to_string()),
389 },
390 }
391 }
392 None => OutgoingMessage::Result {
393 task_id,
394 result: None,
395 error: Some(format!("Method '{}' not found", method)),
396 },
397 };
398
399 let response_text = serde_json::to_string(&response)?;
400 ws_sender.send(Message::Text(response_text)).await?;
401
402 Ok(())
403 }
404
405 async fn handle_encrypted_task(
406 &self,
407 task_id: String,
408 method: String,
409 encrypted_params: Value,
410 salted_key: String,
411 crypto: String,
412 ws_sender: &mut futures_util::stream::SplitSink<
413 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
414 Message,
415 >,
416 ) -> Result<()> {
417 let methods = self.methods.read().await;
418 let method_handler = methods.get(&method).cloned();
419 drop(methods);
420
421 let response = match method_handler {
422 Some(handler) => {
423 match self.decrypt_task_params(encrypted_params, &salted_key, &crypto).await {
425 Ok(params) => {
426 match (handler.handler)(params) {
428 Ok(result) => {
429 match self.encrypt_task_result(result, &salted_key, &crypto).await {
431 Ok(encrypted_result) => OutgoingMessage::Result {
432 task_id,
433 result: Some(encrypted_result),
434 error: None,
435 },
436 Err(e) => OutgoingMessage::Result {
437 task_id,
438 result: None,
439 error: Some(format!("Failed to encrypt result: {}", e)),
440 },
441 }
442 }
443 Err(e) => OutgoingMessage::Result {
444 task_id,
445 result: None,
446 error: Some(e.to_string()),
447 },
448 }
449 }
450 Err(e) => OutgoingMessage::Result {
451 task_id,
452 result: None,
453 error: Some(format!("Failed to decrypt params: {}", e)),
454 },
455 }
456 }
457 None => OutgoingMessage::Result {
458 task_id,
459 result: None,
460 error: Some(format!("Method '{}' not found", method)),
461 },
462 };
463
464 let response_text = serde_json::to_string(&response)?;
465 ws_sender.send(Message::Text(response_text)).await?;
466
467 Ok(())
468 }
469
470 async fn decrypt_task_params(
471 &self,
472 encrypted_params: Value,
473 salted_key: &str,
474 crypto: &str,
475 ) -> Result<Value> {
476 let encrypted_str = encrypted_params
478 .as_str()
479 .ok_or_else(|| SdkError::Crypto("Invalid encrypted params format".to_string()))?;
480
481 let salt: i32 = crypto.parse()
483 .map_err(|_| SdkError::Crypto("Invalid crypto salt format".to_string()))?;
484
485 let original_key = unsalt_key(salted_key, salt)?;
487
488 decrypt_data(encrypted_str, &original_key)
490 }
491
492 async fn encrypt_task_result(
493 &self,
494 result: Value,
495 salted_key: &str,
496 crypto: &str,
497 ) -> Result<Value> {
498 let salt: i32 = crypto.parse()
500 .map_err(|_| SdkError::Crypto("Invalid crypto salt format".to_string()))?;
501
502 let original_key = unsalt_key(salted_key, salt)?;
504
505 let result_str = serde_json::to_string(&result)?;
507
508 let encrypted_result = encrypt_data(&Value::String(result_str), &original_key)?;
510
511 Ok(Value::String(encrypted_result))
512 }
513
514 async fn get_methods_info(&self) -> Vec<MethodInfo> {
515 let methods = self.methods.read().await;
516 methods
517 .values()
518 .map(|method| MethodInfo {
519 name: method.name.clone(),
520 docs: method.docs.clone(),
521 })
522 .collect()
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Config pointing at a local scheduler URL; none of these tests connect to it.
    fn test_config() -> Config {
        Config {
            scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
            worker_group: "test".to_string(),
            max_retry: 3,
            ping_interval: 5,
        }
    }

    #[test]
    fn test_worker_creation() {
        let worker = Worker::new(test_config());

        assert_eq!(worker.config.worker_group, "test");
        assert_eq!(worker.config.max_retry, 3);
    }

    #[tokio::test]
    async fn test_method_registration() {
        let mut worker = Worker::new(test_config());

        worker.register_method(
            "test_method",
            |params: Value| Ok(json!({ "received": params })),
            vec!["Test method".to_string()],
        );

        // Allow any registration performed on a spawned task to complete.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(worker.methods.read().await.contains_key("test_method"));
    }
}