go_server_rust_sdk/worker/
client.rs

1//! Worker implementation for handling distributed tasks
2
3use futures_util::{SinkExt, StreamExt};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, RwLock};
10use tokio::time::{interval, sleep};
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use url::Url;
13use crate::crypto::{encrypt_data, decrypt_data, unsalt_key};
14use crate::error::{Result, SdkError};
15
16
/// Settings that control how a [`Worker`] reaches the scheduler and keeps the
/// connection alive.
#[derive(Clone, Debug)]
pub struct Config {
    /// WebSocket URL of the scheduler endpoint the worker connects to.
    pub scheduler_url: String,
    /// Name of the worker group this worker registers under.
    pub worker_group: String,
    /// Upper bound on connection retry attempts.
    pub max_retry: usize,
    /// Period between heartbeat messages, expressed in seconds.
    pub ping_interval: u64,
}
29
30/// Method handler function type
31pub type MethodHandler = Box<dyn Fn(Value) -> Result<Value> + Send + Sync>;
32
33/// Method definition with handler and documentation
34#[derive(Clone)]
35pub struct Method {
36    pub name: String,
37    pub handler: Arc<MethodHandler>,
38    pub docs: Vec<String>,
39}
40
41/// WebSocket message types
42#[derive(Deserialize, Debug)]
43#[serde(tag = "type")]
44enum IncomingMessage {
45    #[serde(rename = "task")]
46    Task {
47        #[serde(rename = "taskId")]
48        task_id: String,
49        method: String,
50        params: Value,
51    },
52    #[serde(rename = "encrypted_task")]
53    EncryptedTask {
54        #[serde(rename = "taskId")]
55        task_id: String,
56        method: String,
57        params: Value,
58        key: String,
59        crypto: String,
60    },
61    #[serde(rename = "ping")]
62    Ping,
63}
64
65#[derive(Serialize, Debug)]
66#[serde(tag = "type")]
67enum OutgoingMessage {
68    #[serde(rename = "result")]
69    Result {
70        #[serde(rename = "taskId")]
71        task_id: String,
72        #[serde(skip_serializing_if = "Option::is_none")]
73        result: Option<Value>,
74        #[serde(skip_serializing_if = "Option::is_none")]
75        error: Option<String>,
76    },
77    #[serde(rename = "pong")]
78    Pong,
79}
80
81#[derive(Serialize, Debug)]
82struct RegistrationMessage {
83    group: String,
84    methods: Vec<MethodInfo>,
85}
86
87#[derive(Serialize, Debug)]
88struct MethodInfo {
89    name: String,
90    docs: Vec<String>,
91}
92
93/// Distributed task worker
94pub struct Worker {
95    config: Config,
96    methods: Arc<RwLock<HashMap<String, Method>>>,
97    running: Arc<RwLock<bool>>,
98    shutdown_tx: Option<mpsc::Sender<()>>,
99}
100
101impl Worker {
102    /// Creates a new worker with the given configuration
103    /// 
104    /// # Arguments
105    /// 
106    /// * `config` - Worker configuration
107    /// 
108    /// # Example
109    /// 
110    /// ```rust
111    /// use go_server_rust_sdk::worker::{Worker, Config};
112    /// 
113    /// let config = Config {
114    ///     scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
115    ///     worker_group: "math".to_string(),
116    ///     max_retry: 5,
117    ///     ping_interval: 5,
118    /// };
119    /// 
120    /// let worker = Worker::new(config);
121    /// ```
122    pub fn new(config: Config) -> Self {
123        Self {
124            config,
125            methods: Arc::new(RwLock::new(HashMap::new())),
126            running: Arc::new(RwLock::new(false)),
127            shutdown_tx: None,
128        }
129    }
130
131    /// Registers a method handler with the worker
132    /// 
133    /// # Arguments
134    /// 
135    /// * `name` - Method name
136    /// * `handler` - Function that handles the method call
137    /// * `docs` - Documentation strings for the method
138    /// 
139    /// # Example
140    /// 
141    /// ```rust
142    /// use go_server_rust_sdk::worker::{Worker, Config};
143    /// use serde_json::{json, Value};
144    /// 
145    /// # let config = Config {
146    /// #     scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
147    /// #     worker_group: "math".to_string(),
148    /// #     max_retry: 5,
149    /// #     ping_interval: 5,
150    /// # };
151    /// let mut worker = Worker::new(config);
152    /// 
153    /// worker.register_method("add", |params: Value| {
154    ///     let a = params["a"].as_f64().unwrap_or(0.0);
155    ///     let b = params["b"].as_f64().unwrap_or(0.0);
156    ///     Ok(json!(a + b))
157    /// }, vec!["Add two numbers".to_string()]);
158    /// ```
159    pub fn register_method<F>(&mut self, name: impl Into<String>, handler: F, docs: Vec<String>)
160    where
161        F: Fn(Value) -> Result<Value> + Send + Sync + 'static,
162    {
163        let method = Method {
164            name: name.into(),
165            handler: Arc::new(Box::new(handler)),
166            docs,
167        };
168        
169        // We need to use a blocking approach here since this is a sync method
170        let methods = self.methods.clone();
171        tokio::spawn(async move {
172            let mut methods_guard = methods.write().await;
173            methods_guard.insert(method.name.clone(), method);
174        });
175    }
176
177    /// Starts the worker with automatic reconnection support
178    /// 
179    /// This method will block until the worker is stopped.
180    /// 
181    /// # Returns
182    /// 
183    /// A `Result` indicating success or failure
184    /// 
185    /// # Example
186    /// 
187    /// ```rust,no_run
188    /// use go_server_rust_sdk::worker::{Worker, Config};
189    /// 
190    /// # #[tokio::main]
191    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
192    /// # let config = Config {
193    /// #     scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
194    /// #     worker_group: "math".to_string(),
195    /// #     max_retry: 5,
196    /// #     ping_interval: 5,
197    /// # };
198    /// let mut worker = Worker::new(config);
199    /// // Register methods...
200    /// worker.start().await?;
201    /// # Ok(())
202    /// # }
203    /// ```
204    pub async fn start(&mut self) -> Result<()> {
205        *self.running.write().await = true;
206        
207        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
208        self.shutdown_tx = Some(shutdown_tx);
209
210        log::info!("Worker {} starting", self.config.worker_group);
211
212        loop {
213            // Check if we should stop
214            if !*self.running.read().await {
215                break;
216            }
217
218            // Try to connect and run
219            match self.connect_and_run(&mut shutdown_rx).await {
220                Ok(_) => {
221                    log::info!("Worker connection closed normally");
222                    break;
223                }
224                Err(e) => {
225                    log::error!("Worker connection failed: {}", e);
226                    if *self.running.read().await {
227                        log::info!("Retrying connection in 5 seconds...");
228                        sleep(Duration::from_secs(5)).await;
229                    }
230                }
231            }
232        }
233
234        log::info!("Worker {} stopped", self.config.worker_group);
235        Ok(())
236    }
237
238    /// Stops the worker
239    /// 
240    /// This method will signal the worker to stop and close all connections.
241    /// 
242    /// # Example
243    /// 
244    /// ```rust,no_run
245    /// use go_server_rust_sdk::worker::{Worker, Config};
246    /// 
247    /// # #[tokio::main]
248    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
249    /// # let config = Config {
250    /// #     scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
251    /// #     worker_group: "math".to_string(),
252    /// #     max_retry: 5,
253    /// #     ping_interval: 5,
254    /// # };
255    /// let mut worker = Worker::new(config);
256    /// 
257    /// // In another task or signal handler:
258    /// worker.stop().await;
259    /// # Ok(())
260    /// # }
261    /// ```
262    pub async fn stop(&mut self) {
263        *self.running.write().await = false;
264        
265        if let Some(tx) = &self.shutdown_tx {
266            let _ = tx.send(()).await;
267        }
268    }
269
270    async fn connect_and_run(&self, shutdown_rx: &mut mpsc::Receiver<()>) -> Result<()> {
271        let url = Url::parse(&self.config.scheduler_url)?;
272        let (ws_stream, _) = connect_async(url).await?;
273        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
274
275        // Send registration message
276        let methods = self.get_methods_info().await;
277        let registration = RegistrationMessage {
278            group: self.config.worker_group.clone(),
279            methods,
280        };
281        
282        let registration_msg = serde_json::to_string(&registration)?;
283        ws_sender.send(Message::Text(registration_msg)).await?;
284
285        log::info!("Worker {} connected and registered", self.config.worker_group);
286
287        // Start ping interval
288        let mut ping_interval = interval(Duration::from_secs(self.config.ping_interval));
289        
290        loop {
291            tokio::select! {
292                // Handle shutdown signal
293                _ = shutdown_rx.recv() => {
294                    log::info!("Received shutdown signal");
295                    let _ = ws_sender.close().await;
296                    break;
297                }
298                
299                // Handle ping interval
300                _ = ping_interval.tick() => {
301                    let ping_msg = serde_json::to_string(&OutgoingMessage::Pong)?;
302                    if let Err(e) = ws_sender.send(Message::Text(ping_msg)).await {
303                        log::error!("Failed to send ping: {}", e);
304                        break;
305                    }
306                }
307                
308                // Handle incoming messages
309                msg = ws_receiver.next() => {
310                    match msg {
311                        Some(Ok(Message::Text(text))) => {
312                            if let Err(e) = self.handle_message(&text, &mut ws_sender).await {
313                                log::error!("Error handling message: {}", e);
314                            }
315                        }
316                        Some(Ok(Message::Close(_))) => {
317                            log::info!("WebSocket connection closed by server");
318                            break;
319                        }
320                        Some(Err(e)) => {
321                            log::error!("WebSocket error: {}", e);
322                            break;
323                        }
324                        None => {
325                            log::info!("WebSocket stream ended");
326                            break;
327                        }
328                        _ => {}
329                    }
330                }
331            }
332        }
333
334        Ok(())
335    }
336
337    async fn handle_message(
338        &self,
339        text: &str,
340        ws_sender: &mut futures_util::stream::SplitSink<
341            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
342            Message,
343        >,
344    ) -> Result<()> {
345        let message: IncomingMessage = serde_json::from_str(text)?;
346
347        match message {
348            IncomingMessage::Task { task_id, method, params } => {
349                self.handle_task(task_id, method, params, ws_sender).await?
350            }
351            IncomingMessage::EncryptedTask { task_id, method, params, key, crypto } => {
352                self.handle_encrypted_task(task_id, method, params, key, crypto, ws_sender).await?
353            }
354            IncomingMessage::Ping => {
355                let pong_msg = serde_json::to_string(&OutgoingMessage::Pong)?;
356                ws_sender.send(Message::Text(pong_msg)).await?;
357            }
358        }
359
360        Ok(())
361    }
362
363    async fn handle_task(
364        &self,
365        task_id: String,
366        method: String,
367        params: Value,
368        ws_sender: &mut futures_util::stream::SplitSink<
369            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
370            Message,
371        >,
372    ) -> Result<()> {
373        let methods = self.methods.read().await;
374        let method_handler = methods.get(&method).cloned();
375        drop(methods);
376
377        let response = match method_handler {
378            Some(handler) => {
379                match (handler.handler)(params) {
380                    Ok(result) => OutgoingMessage::Result {
381                        task_id,
382                        result: Some(result),
383                        error: None,
384                    },
385                    Err(e) => OutgoingMessage::Result {
386                        task_id,
387                        result: None,
388                        error: Some(e.to_string()),
389                    },
390                }
391            }
392            None => OutgoingMessage::Result {
393                task_id,
394                result: None,
395                error: Some(format!("Method '{}' not found", method)),
396            },
397        };
398
399        let response_text = serde_json::to_string(&response)?;
400        ws_sender.send(Message::Text(response_text)).await?;
401
402        Ok(())
403    }
404
405    async fn handle_encrypted_task(
406        &self,
407        task_id: String,
408        method: String,
409        encrypted_params: Value,
410        salted_key: String,
411        crypto: String,
412        ws_sender: &mut futures_util::stream::SplitSink<
413            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
414            Message,
415        >,
416    ) -> Result<()> {
417        let methods = self.methods.read().await;
418        let method_handler = methods.get(&method).cloned();
419        drop(methods);
420
421        let response = match method_handler {
422            Some(handler) => {
423                // Decrypt parameters
424                match self.decrypt_task_params(encrypted_params, &salted_key, &crypto).await {
425                    Ok(params) => {
426                        // Execute method
427                        match (handler.handler)(params) {
428                            Ok(result) => {
429                                // Encrypt result
430                                match self.encrypt_task_result(result, &salted_key, &crypto).await {
431                                    Ok(encrypted_result) => OutgoingMessage::Result {
432                                        task_id,
433                                        result: Some(encrypted_result),
434                                        error: None,
435                                    },
436                                    Err(e) => OutgoingMessage::Result {
437                                        task_id,
438                                        result: None,
439                                        error: Some(format!("Failed to encrypt result: {}", e)),
440                                    },
441                                }
442                            }
443                            Err(e) => OutgoingMessage::Result {
444                                task_id,
445                                result: None,
446                                error: Some(e.to_string()),
447                            },
448                        }
449                    }
450                    Err(e) => OutgoingMessage::Result {
451                        task_id,
452                        result: None,
453                        error: Some(format!("Failed to decrypt params: {}", e)),
454                    },
455                }
456            }
457            None => OutgoingMessage::Result {
458                task_id,
459                result: None,
460                error: Some(format!("Method '{}' not found", method)),
461            },
462        };
463
464        let response_text = serde_json::to_string(&response)?;
465        ws_sender.send(Message::Text(response_text)).await?;
466
467        Ok(())
468    }
469
470    async fn decrypt_task_params(
471        &self,
472        encrypted_params: Value,
473        salted_key: &str,
474        crypto: &str,
475    ) -> Result<Value> {
476        // Extract encrypted string from JSON
477        let encrypted_str = encrypted_params
478            .as_str()
479            .ok_or_else(|| SdkError::Crypto("Invalid encrypted params format".to_string()))?;
480
481        // Parse salt from crypto string
482        let salt: i32 = crypto.parse()
483            .map_err(|_| SdkError::Crypto("Invalid crypto salt format".to_string()))?;
484
485        // Unsalt the key
486        let original_key = unsalt_key(salted_key, salt)?;
487
488        // Decrypt the parameters
489        decrypt_data(encrypted_str, &original_key)
490    }
491
492    async fn encrypt_task_result(
493        &self,
494        result: Value,
495        salted_key: &str,
496        crypto: &str,
497    ) -> Result<Value> {
498        // Parse salt from crypto string
499        let salt: i32 = crypto.parse()
500            .map_err(|_| SdkError::Crypto("Invalid crypto salt format".to_string()))?;
501
502        // Unsalt the key
503        let original_key = unsalt_key(salted_key, salt)?;
504
505        // Serialize result to JSON string
506        let result_str = serde_json::to_string(&result)?;
507
508        // Encrypt the result using the original key
509        let encrypted_result = encrypt_data(&Value::String(result_str), &original_key)?;
510
511        Ok(Value::String(encrypted_result))
512    }
513
514    async fn get_methods_info(&self) -> Vec<MethodInfo> {
515        let methods = self.methods.read().await;
516        methods
517            .values()
518            .map(|method| MethodInfo {
519                name: method.name.clone(),
520                docs: method.docs.clone(),
521            })
522            .collect()
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Configuration shared by the tests in this module.
    fn test_config() -> Config {
        Config {
            scheduler_url: "ws://localhost:8080/api/worker/connect/123456".to_string(),
            worker_group: "test".to_string(),
            max_retry: 3,
            ping_interval: 5,
        }
    }

    #[test]
    fn test_worker_creation() {
        let worker = Worker::new(test_config());

        assert_eq!(worker.config.worker_group, "test");
        assert_eq!(worker.config.max_retry, 3);
    }

    #[tokio::test]
    async fn test_method_registration() {
        let mut worker = Worker::new(test_config());

        worker.register_method(
            "test_method",
            |params: Value| Ok(json!({ "received": params })),
            vec!["Test method".to_string()],
        );

        // Leave room for registration in case it completes asynchronously.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(worker.methods.read().await.contains_key("test_method"));
    }
}