// aster/scheduler/delivery.rs

//! 结果投递模块
//!
//! 本模块定义结果投递的 trait 和实现,包括:
//! - `DeliveryChannel`: 投递渠道 trait
//! - `DeliveryRouter`: 投递路由器
//! - `DeliveryResult`: 投递结果
//! - `LogChannel`: 日志渠道(用于测试和调试)
//!
//! ## 需求映射
//!
//! - **Requirement 5.5**: 投递渠道 trait 定义
//! - **Requirement 5.6**: 投递路由器实现
//! - **Requirement 5.7**: best_effort 模式支持
13
14use anyhow::Result;
15use async_trait::async_trait;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use super::executor::ExecutionResult;
20use super::types::DeliveryConfig;
21
22// ============================================================================
23// DeliveryResult 结构体
24// ============================================================================
25
/// Outcome of a single delivery attempt.
///
/// Records which channel and target were used and, when the attempt failed,
/// the reason. Derives `PartialEq`/`Eq` so whole results can be compared
/// (e.g. with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DeliveryResult {
    /// Whether the delivery succeeded.
    pub success: bool,

    /// ID of the channel the delivery was attempted on.
    pub channel: String,

    /// Delivery target; its interpretation is up to the channel.
    pub to: String,

    /// Error message when the delivery failed; `None` on success.
    pub error: Option<String>,
}

impl DeliveryResult {
    /// Creates a successful delivery result for `channel` -> `to`.
    pub fn success(channel: impl Into<String>, to: impl Into<String>) -> Self {
        Self {
            success: true,
            channel: channel.into(),
            to: to.into(),
            error: None,
        }
    }

    /// Creates a failed delivery result for `channel` -> `to`, carrying `error`.
    pub fn failure(
        channel: impl Into<String>,
        to: impl Into<String>,
        error: impl Into<String>,
    ) -> Self {
        Self {
            success: false,
            channel: channel.into(),
            to: to.into(),
            error: Some(error.into()),
        }
    }
}
69
70// ============================================================================
71// DeliveryChannel Trait (Task 8.1)
72// ============================================================================
73
/// Delivery channel trait.
///
/// Standard interface for delivering results to a destination. Each
/// implementation represents one channel and is registered with
/// `DeliveryRouter` under its `channel_id()`.
///
/// # Requirement mapping
///
/// - **Requirement 5.5**: delivery channel trait definition
///
/// # Implementors
///
/// This module defines `LogChannel` (logs the message, never fails) and, in
/// tests only, `MockChannel`. The design also names channels such as
/// `SlackChannel`, `TelegramChannel` and `EmailChannel`.
/// NOTE(review): those are not defined in this module; confirm where (or
/// whether) they exist.
#[async_trait]
pub trait DeliveryChannel: Send + Sync {
    /// Returns the channel ID; `DeliveryRouter::register` uses it as the
    /// registry key.
    fn channel_id(&self) -> &str;

    /// Sends a message.
    ///
    /// # Parameters
    /// - `to`: delivery target (interpretation is channel-specific)
    /// - `message`: message content
    ///
    /// # Returns
    /// - `Ok(())`: the message was sent
    /// - `Err`: sending failed; `DeliveryRouter::deliver` either propagates
    ///   this error or downgrades it to a warning depending on `best_effort`
    async fn send(&self, to: &str, message: &str) -> Result<()>;

    /// Reports whether the channel is currently usable.
    ///
    /// The default implementation always returns `true`.
    /// NOTE(review): `DeliveryRouter::deliver` in this module does not call
    /// this method before sending.
    async fn is_available(&self) -> bool {
        true
    }
}
109
110// ============================================================================
111// DeliveryRouter (Task 8.1)
112// ============================================================================
113
114/// 投递路由器
115///
116/// 管理多个投递渠道,根据配置将结果投递到指定渠道。
117///
118/// # 需求映射
119///
120/// - **Requirement 5.6**: 投递路由器实现
121/// - **Requirement 5.7**: best_effort 模式支持
122pub struct DeliveryRouter {
123    /// 注册的投递渠道
124    channels: HashMap<String, Arc<dyn DeliveryChannel>>,
125
126    /// 默认渠道 ID
127    default_channel: Option<String>,
128}
129
130impl DeliveryRouter {
131    /// 创建新的投递路由器
132    pub fn new() -> Self {
133        Self {
134            channels: HashMap::new(),
135            default_channel: None,
136        }
137    }
138
139    /// 注册投递渠道
140    ///
141    /// # 参数
142    /// - `channel`: 投递渠道实例
143    pub fn register(&mut self, channel: Arc<dyn DeliveryChannel>) {
144        let id = channel.channel_id().to_string();
145        self.channels.insert(id, channel);
146    }
147
148    /// 设置默认渠道
149    ///
150    /// # 参数
151    /// - `channel_id`: 默认渠道 ID
152    pub fn set_default(&mut self, channel_id: impl Into<String>) {
153        self.default_channel = Some(channel_id.into());
154    }
155
156    /// 获取渠道
157    ///
158    /// # 参数
159    /// - `channel_id`: 渠道 ID(如果为 None,使用默认渠道)
160    pub fn get_channel(&self, channel_id: Option<&str>) -> Option<&Arc<dyn DeliveryChannel>> {
161        let id = channel_id.or(self.default_channel.as_deref())?;
162        self.channels.get(id)
163    }
164
165    /// 列出所有注册的渠道
166    pub fn list_channels(&self) -> Vec<&str> {
167        self.channels.keys().map(|s| s.as_str()).collect()
168    }
169
170    /// 投递执行结果
171    ///
172    /// # 参数
173    /// - `config`: 投递配置
174    /// - `result`: 执行结果
175    ///
176    /// # 返回值
177    /// - `Ok(DeliveryResult)`: 投递结果
178    /// - `Err`: 投递失败(仅当 best_effort 为 false 时)
179    ///
180    /// # 行为说明
181    ///
182    /// - 如果 `config.enabled` 为 false,直接返回成功
183    /// - 如果 `config.best_effort` 为 true,投递失败时记录警告但不返回错误
184    /// - 如果 `config.best_effort` 为 false,投递失败时返回错误
185    pub async fn deliver(
186        &self,
187        config: &DeliveryConfig,
188        result: &ExecutionResult,
189    ) -> Result<DeliveryResult> {
190        // 检查是否启用投递
191        if !config.enabled {
192            return Ok(DeliveryResult::success("none", "none"));
193        }
194
195        // 获取渠道和目标
196        let channel_id = config.channel.as_deref().unwrap_or("default");
197        let to = config.to.as_deref().unwrap_or("default");
198
199        // 获取渠道
200        let channel = match self.get_channel(Some(channel_id)) {
201            Some(ch) => ch,
202            None => {
203                let err_msg = format!("渠道未找到: {}", channel_id);
204                if config.best_effort {
205                    tracing::warn!("投递失败 (best effort): {}", err_msg);
206                    return Ok(DeliveryResult::failure(channel_id, to, err_msg));
207                }
208                return Err(anyhow::anyhow!(err_msg));
209            }
210        };
211
212        // 构建消息
213        let message = result.output.as_deref().unwrap_or("任务执行完成");
214
215        // 发送消息
216        match channel.send(to, message).await {
217            Ok(()) => {
218                tracing::info!("投递成功: {} -> {}", channel_id, to);
219                Ok(DeliveryResult::success(channel_id, to))
220            }
221            Err(e) => {
222                let err_msg = e.to_string();
223                if config.best_effort {
224                    tracing::warn!("投递失败 (best effort): {}", err_msg);
225                    Ok(DeliveryResult::failure(channel_id, to, err_msg))
226                } else {
227                    Err(e)
228                }
229            }
230        }
231    }
232}
233
234impl Default for DeliveryRouter {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
240// ============================================================================
241// 示例渠道实现
242// ============================================================================
243
244/// 日志渠道(用于测试和调试)
245///
246/// 将消息输出到日志,不实际发送。
247pub struct LogChannel {
248    id: String,
249}
250
251impl LogChannel {
252    /// 创建新的日志渠道
253    pub fn new(id: impl Into<String>) -> Self {
254        Self { id: id.into() }
255    }
256}
257
258#[async_trait]
259impl DeliveryChannel for LogChannel {
260    fn channel_id(&self) -> &str {
261        &self.id
262    }
263
264    async fn send(&self, to: &str, message: &str) -> Result<()> {
265        tracing::info!("[LogChannel:{}] 发送到 {}: {}", self.id, to, message);
266        Ok(())
267    }
268}
269
/// Mock delivery channel for tests.
///
/// Either always succeeds or always fails with a fixed error message,
/// depending on which constructor built it.
#[cfg(test)]
pub struct MockChannel {
    id: String,
    should_fail: bool,
    fail_message: String,
}

#[cfg(test)]
impl MockChannel {
    /// Shared constructor behind [`success`](Self::success) and
    /// [`failure`](Self::failure).
    fn with_outcome(id: String, should_fail: bool, fail_message: String) -> Self {
        Self {
            id,
            should_fail,
            fail_message,
        }
    }

    /// Builds a mock channel whose `send` always succeeds.
    pub fn success(id: impl Into<String>) -> Self {
        Self::with_outcome(id.into(), false, String::new())
    }

    /// Builds a mock channel whose `send` always fails with `message`.
    pub fn failure(id: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_outcome(id.into(), true, message.into())
    }
}

#[cfg(test)]
#[async_trait]
impl DeliveryChannel for MockChannel {
    fn channel_id(&self) -> &str {
        &self.id
    }

    async fn send(&self, _to: &str, _message: &str) -> Result<()> {
        if !self.should_fail {
            return Ok(());
        }
        Err(anyhow::anyhow!("{}", self.fail_message))
    }
}
316
317// ============================================================================
318// 单元测试 (Task 8.2)
319// ============================================================================
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduler::types::JobStatus;

    /// Builds a successful `ExecutionResult` carrying `output`.
    fn create_test_result(output: Option<&str>) -> ExecutionResult {
        ExecutionResult {
            session_id: "test-session".to_string(),
            output: output.map(String::from),
            duration_ms: 100,
            status: JobStatus::Ok,
            error: None,
        }
    }

    /// Builds an enabled `DeliveryConfig` aimed at `channel` / `to`.
    fn enabled_config(channel: &str, to: &str, best_effort: bool) -> DeliveryConfig {
        DeliveryConfig {
            enabled: true,
            channel: Some(channel.to_string()),
            to: Some(to.to_string()),
            best_effort,
        }
    }

    // ------------------------------------------------------------------------
    // DeliveryResult
    // ------------------------------------------------------------------------

    #[test]
    fn test_delivery_result_success() {
        let outcome = DeliveryResult::success("slack", "#general");

        assert!(outcome.success);
        assert_eq!(outcome.channel, "slack");
        assert_eq!(outcome.to, "#general");
        assert_eq!(outcome.error, None);
    }

    #[test]
    fn test_delivery_result_failure() {
        let outcome = DeliveryResult::failure("email", "user@example.com", "SMTP error");

        assert!(!outcome.success);
        assert_eq!(outcome.channel, "email");
        assert_eq!(outcome.to, "user@example.com");
        assert_eq!(outcome.error.as_deref(), Some("SMTP error"));
    }

    // ------------------------------------------------------------------------
    // DeliveryRouter
    // ------------------------------------------------------------------------

    #[test]
    fn test_router_new() {
        let router = DeliveryRouter::new();

        assert_eq!(router.channels.len(), 0);
        assert_eq!(router.default_channel, None);
    }

    #[test]
    fn test_router_register() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::success("test")));

        assert_eq!(router.channels.len(), 1);
        assert!(router.channels.contains_key("test"));
    }

    #[test]
    fn test_router_set_default() {
        let mut router = DeliveryRouter::new();
        router.set_default("slack");

        assert_eq!(router.default_channel.as_deref(), Some("slack"));
    }

    #[test]
    fn test_router_get_channel() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::success("test")));

        assert!(router.get_channel(Some("test")).is_some());
        assert!(router.get_channel(Some("nonexistent")).is_none());
    }

    #[test]
    fn test_router_get_channel_default() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::success("default")));
        router.set_default("default");

        // With no explicit channel, the lookup goes through the default.
        assert!(router.get_channel(None).is_some());
    }

    #[test]
    fn test_router_list_channels() {
        let mut router = DeliveryRouter::new();
        for id in ["slack", "email"].iter().copied() {
            router.register(Arc::new(MockChannel::success(id)));
        }

        let listed = router.list_channels();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&"slack"));
        assert!(listed.contains(&"email"));
    }

    // ------------------------------------------------------------------------
    // deliver
    // ------------------------------------------------------------------------

    #[tokio::test]
    async fn test_deliver_disabled() {
        let router = DeliveryRouter::new();
        // Relies on `DeliveryConfig::default()` having `enabled == false`.
        let config = DeliveryConfig::default();
        let execution = create_test_result(Some("output"));

        let outcome = router.deliver(&config, &execution).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.channel, "none");
    }

    #[tokio::test]
    async fn test_deliver_success() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::success("slack")));
        let config = DeliveryConfig::enabled("slack", "#general");
        let execution = create_test_result(Some("Task completed"));

        let outcome = router.deliver(&config, &execution).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.channel, "slack");
        assert_eq!(outcome.to, "#general");
    }

    #[tokio::test]
    async fn test_deliver_channel_not_found_best_effort() {
        let router = DeliveryRouter::new();
        let config = enabled_config("nonexistent", "target", true);
        let execution = create_test_result(Some("output"));

        let outcome = router.deliver(&config, &execution).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_deliver_channel_not_found_strict() {
        let router = DeliveryRouter::new();
        let config = enabled_config("nonexistent", "target", false);
        let execution = create_test_result(Some("output"));

        let attempt = router.deliver(&config, &execution).await;

        assert!(attempt.is_err());
    }

    #[tokio::test]
    async fn test_deliver_send_failure_best_effort() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::failure("slack", "Network error")));
        let config = enabled_config("slack", "#general", true);
        let execution = create_test_result(Some("output"));

        let outcome = router.deliver(&config, &execution).await.unwrap();

        assert!(!outcome.success);
        let reason = outcome.error.expect("best-effort failure must carry an error");
        assert!(reason.contains("Network error"));
    }

    #[tokio::test]
    async fn test_deliver_send_failure_strict() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::failure("slack", "Network error")));
        let config = enabled_config("slack", "#general", false);
        let execution = create_test_result(Some("output"));

        let attempt = router.deliver(&config, &execution).await;

        assert!(attempt.is_err());
    }

    #[tokio::test]
    async fn test_deliver_no_output() {
        let mut router = DeliveryRouter::new();
        router.register(Arc::new(MockChannel::success("slack")));
        let config = DeliveryConfig::enabled("slack", "#general");
        let execution = create_test_result(None);

        let outcome = router.deliver(&config, &execution).await.unwrap();

        // A missing output falls back to the built-in default message.
        assert!(outcome.success);
    }

    // ------------------------------------------------------------------------
    // LogChannel
    // ------------------------------------------------------------------------

    #[test]
    fn test_log_channel_id() {
        let channel = LogChannel::new("test-log");
        assert_eq!(channel.channel_id(), "test-log");
    }

    #[tokio::test]
    async fn test_log_channel_send() {
        let channel = LogChannel::new("test-log");
        let sent = channel.send("target", "Hello").await;
        assert!(sent.is_ok());
    }
}