Skip to main content

punch_channels/
lib.rs

1//! # punch-channels
2//!
3//! Channel adapters for messaging platforms in the Punch Agent Combat System.
4//!
5//! Provides a unified [`ChannelAdapter`] trait that abstracts over different
6//! messaging platforms (Telegram, Discord, Slack, etc.), a [`ChannelRouter`]
7//! that maps platform users to fighters, and a [`ChannelBridge`] that manages
8//! adapters and dispatches messages through the Ring.
9
10pub mod adapters;
11pub mod bridge;
12pub mod onboarding;
13pub mod router;
14pub mod security;
15
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use chrono::{DateTime, Utc};
21use serde::{Deserialize, Serialize};
22use tokio::sync::RwLock;
23use tracing::{info, warn};
24
25use punch_types::{PunchError, PunchResult};
26
27// ---------------------------------------------------------------------------
28// Core types
29// ---------------------------------------------------------------------------
30
31/// Identifies the messaging platform an adapter connects to.
32#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum ChannelPlatform {
35    Telegram,
36    Discord,
37    Slack,
38    WhatsApp,
39    Signal,
40    Matrix,
41    Email,
42    Teams,
43    Irc,
44    Mastodon,
45    Reddit,
46    Twitch,
47    GitHub,
48    Line,
49    WebChat,
50    GoogleChat,
51    Bluesky,
52    LinkedIn,
53    Sms,
54    DingTalk,
55    Feishu,
56    Nostr,
57    Mattermost,
58    Zulip,
59    RocketChat,
60    Custom(String),
61}
62
63impl std::fmt::Display for ChannelPlatform {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Telegram => write!(f, "telegram"),
67            Self::Discord => write!(f, "discord"),
68            Self::Slack => write!(f, "slack"),
69            Self::WhatsApp => write!(f, "whatsapp"),
70            Self::Signal => write!(f, "signal"),
71            Self::Matrix => write!(f, "matrix"),
72            Self::Email => write!(f, "email"),
73            Self::Teams => write!(f, "teams"),
74            Self::Irc => write!(f, "irc"),
75            Self::Mastodon => write!(f, "mastodon"),
76            Self::Reddit => write!(f, "reddit"),
77            Self::Twitch => write!(f, "twitch"),
78            Self::GitHub => write!(f, "github"),
79            Self::Line => write!(f, "line"),
80            Self::WebChat => write!(f, "webchat"),
81            Self::GoogleChat => write!(f, "google_chat"),
82            Self::Bluesky => write!(f, "bluesky"),
83            Self::LinkedIn => write!(f, "linkedin"),
84            Self::Sms => write!(f, "sms"),
85            Self::DingTalk => write!(f, "dingtalk"),
86            Self::Feishu => write!(f, "feishu"),
87            Self::Nostr => write!(f, "nostr"),
88            Self::Mattermost => write!(f, "mattermost"),
89            Self::Zulip => write!(f, "zulip"),
90            Self::RocketChat => write!(f, "rocketchat"),
91            Self::Custom(name) => write!(f, "custom({})", name),
92        }
93    }
94}
95
96/// A message received from an external messaging platform.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct IncomingMessage {
99    /// The channel or conversation identifier on the platform.
100    pub channel_id: String,
101    /// The user identifier on the platform.
102    pub user_id: String,
103    /// The display name of the user.
104    pub display_name: String,
105    /// The text content of the message.
106    pub text: String,
107    /// When the message was sent.
108    pub timestamp: DateTime<Utc>,
109    /// Which platform the message originated from.
110    pub platform: ChannelPlatform,
111    /// Platform-specific message ID.
112    pub platform_message_id: String,
113    /// Whether this is from a group chat.
114    #[serde(default)]
115    pub is_group: bool,
116    /// Arbitrary platform metadata.
117    #[serde(default)]
118    pub metadata: HashMap<String, serde_json::Value>,
119}
120
121/// Status of a channel adapter.
122#[derive(Debug, Clone, Default, Serialize, Deserialize)]
123pub struct ChannelStatus {
124    /// Whether the adapter is currently running.
125    pub connected: bool,
126    /// When the adapter was started.
127    pub started_at: Option<DateTime<Utc>>,
128    /// Total messages received since start.
129    pub messages_received: u64,
130    /// Total messages sent since start.
131    pub messages_sent: u64,
132    /// Last error message (if any).
133    pub last_error: Option<String>,
134}
135
136// ---------------------------------------------------------------------------
137// Trait
138// ---------------------------------------------------------------------------
139
140/// Abstraction over a messaging platform connection.
141///
142/// Each adapter receives incoming messages and can send responses back.
143/// The lifecycle is: start() -> process messages -> stop().
144#[async_trait]
145pub trait ChannelAdapter: Send + Sync + 'static {
146    /// Human-readable name for this adapter (e.g. "telegram", "discord").
147    fn name(&self) -> &str;
148
149    /// The platform this adapter connects to.
150    fn platform(&self) -> ChannelPlatform;
151
152    /// Start the adapter and begin listening for messages.
153    async fn start(&self) -> PunchResult<()>;
154
155    /// Stop the adapter and clean up resources.
156    async fn stop(&self) -> PunchResult<()>;
157
158    /// Send a text response to a specific channel/conversation.
159    async fn send_response(&self, channel_id: &str, message: &str) -> PunchResult<()>;
160
161    /// Get the current status of this adapter.
162    fn status(&self) -> ChannelStatus {
163        ChannelStatus::default()
164    }
165
166    /// Validate that configured credentials are valid by calling the platform API.
167    /// Returns Ok(()) if valid. Default implementation assumes valid.
168    async fn validate_credentials(&self) -> PunchResult<()> {
169        Ok(())
170    }
171}
172
173// ---------------------------------------------------------------------------
174// ChannelBridge
175// ---------------------------------------------------------------------------
176
177/// Manages multiple [`ChannelAdapter`]s and routes messages between them.
178pub struct ChannelBridge {
179    adapters: RwLock<HashMap<String, Arc<dyn ChannelAdapter>>>,
180}
181
182impl ChannelBridge {
183    /// Create a new, empty bridge.
184    pub fn new() -> Self {
185        Self {
186            adapters: RwLock::new(HashMap::new()),
187        }
188    }
189
190    /// Register an adapter with the bridge.
191    pub async fn register(&self, adapter: Arc<dyn ChannelAdapter>) {
192        let name = adapter.name().to_string();
193        info!(adapter = %name, "registering channel adapter");
194        self.adapters.write().await.insert(name, adapter);
195    }
196
197    /// Start all registered adapters.
198    pub async fn start_all(&self) -> PunchResult<()> {
199        let adapters = self.adapters.read().await;
200        for (name, adapter) in adapters.iter() {
201            info!(adapter = %name, "starting channel adapter");
202            adapter.start().await.map_err(|e| PunchError::Channel {
203                channel: name.clone(),
204                message: format!("failed to start: {e}"),
205            })?;
206        }
207        Ok(())
208    }
209
210    /// Stop all registered adapters.
211    pub async fn stop_all(&self) -> PunchResult<()> {
212        let adapters = self.adapters.read().await;
213        for (name, adapter) in adapters.iter() {
214            info!(adapter = %name, "stopping channel adapter");
215            if let Err(e) = adapter.stop().await {
216                warn!(adapter = %name, error = %e, "failed to stop adapter");
217            }
218        }
219        Ok(())
220    }
221
222    /// Send a message through a specific adapter by name.
223    pub async fn send_message(
224        &self,
225        adapter_name: &str,
226        channel_id: &str,
227        text: &str,
228    ) -> PunchResult<()> {
229        let adapters = self.adapters.read().await;
230        let adapter = adapters
231            .get(adapter_name)
232            .ok_or_else(|| PunchError::Channel {
233                channel: adapter_name.to_string(),
234                message: "adapter not found".to_string(),
235            })?;
236        adapter.send_response(channel_id, text).await
237    }
238
239    /// List the names of all registered adapters.
240    pub async fn list_adapters(&self) -> Vec<String> {
241        self.adapters.read().await.keys().cloned().collect()
242    }
243
244    /// Get the status of all adapters.
245    pub async fn adapter_statuses(&self) -> Vec<(String, ChannelPlatform, ChannelStatus)> {
246        let adapters = self.adapters.read().await;
247        adapters
248            .iter()
249            .map(|(name, adapter)| (name.clone(), adapter.platform(), adapter.status()))
250            .collect()
251    }
252}
253
254impl Default for ChannelBridge {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
/// Split a message into chunks of at most `max_len` **bytes**, preferring to
/// split at newline boundaries.
///
/// Lengths are measured in UTF-8 bytes. A byte count is never less than the
/// character count, so every chunk also fits a `max_len`-character platform limit.
/// Split points are always moved back onto a character boundary, so multi-byte
/// text never causes a slicing panic.
///
/// Details:
/// - Text that already fits (including `""`) is returned as a single chunk.
/// - When splitting, the last `'\n'` within the limit is preferred. That newline
///   is consumed rather than kept in either chunk. A `'\r'` just before it stays
///   at the end of the chunk.
/// - After a hard split (no newline in range), a leading `"\r\n"` or `'\n'` on
///   the remainder is consumed.
/// - Empty chunks (e.g. from runs of blank lines) are never emitted.
/// - If `max_len` is smaller than the next character's encoded width (including
///   `max_len == 0`), that character is emitted alone so the function always
///   terminates. Only in this case can a chunk exceed `max_len`.
pub fn split_message(text: &str, max_len: usize) -> Vec<&str> {
    if text.len() <= max_len {
        return vec![text];
    }
    let mut chunks = Vec::new();
    let mut remaining = text;
    while !remaining.is_empty() {
        if remaining.len() <= max_len {
            chunks.push(remaining);
            break;
        }
        // Largest char boundary <= max_len. Index 0 is always a boundary, so this
        // backs off at most three bytes and never slices inside a character.
        let mut limit = max_len;
        while !remaining.is_char_boundary(limit) {
            limit -= 1;
        }
        let split_at = match remaining[..limit].rfind('\n') {
            Some(newline) => newline,
            None if limit > 0 => limit,
            // The next character alone is wider than max_len: emit just that
            // character so the loop always makes progress.
            None => remaining
                .chars()
                .next()
                .map_or(remaining.len(), char::len_utf8),
        };
        let (chunk, rest) = remaining.split_at(split_at);
        // A newline at index 0 produces an empty chunk. Skip it; the newline is
        // stripped below, so the loop still advances.
        if !chunk.is_empty() {
            chunks.push(chunk);
        }
        remaining = rest
            .strip_prefix("\r\n")
            .or_else(|| rest.strip_prefix('\n'))
            .unwrap_or(rest);
    }
    chunks
}
283
#[cfg(test)]
mod tests {
    use super::*;

    // --- ChannelPlatform: Display ---

    #[test]
    fn test_channel_platform_display() {
        assert_eq!(ChannelPlatform::Telegram.to_string(), "telegram");
        assert_eq!(ChannelPlatform::Discord.to_string(), "discord");
        assert_eq!(ChannelPlatform::Slack.to_string(), "slack");
        assert_eq!(
            ChannelPlatform::Custom("irc".to_string()).to_string(),
            "custom(irc)"
        );
    }

    #[test]
    fn test_channel_platform_display_all() {
        assert_eq!(ChannelPlatform::WhatsApp.to_string(), "whatsapp");
        assert_eq!(ChannelPlatform::Signal.to_string(), "signal");
        assert_eq!(ChannelPlatform::Matrix.to_string(), "matrix");
        assert_eq!(ChannelPlatform::Email.to_string(), "email");
        assert_eq!(ChannelPlatform::Teams.to_string(), "teams");
        assert_eq!(ChannelPlatform::Irc.to_string(), "irc");
        assert_eq!(ChannelPlatform::Mastodon.to_string(), "mastodon");
        assert_eq!(ChannelPlatform::Reddit.to_string(), "reddit");
        assert_eq!(ChannelPlatform::Twitch.to_string(), "twitch");
        assert_eq!(ChannelPlatform::GitHub.to_string(), "github");
        assert_eq!(ChannelPlatform::Line.to_string(), "line");
        assert_eq!(ChannelPlatform::WebChat.to_string(), "webchat");
        assert_eq!(ChannelPlatform::GoogleChat.to_string(), "google_chat");
        assert_eq!(ChannelPlatform::Bluesky.to_string(), "bluesky");
        assert_eq!(ChannelPlatform::LinkedIn.to_string(), "linkedin");
        assert_eq!(ChannelPlatform::Sms.to_string(), "sms");
        assert_eq!(ChannelPlatform::DingTalk.to_string(), "dingtalk");
        assert_eq!(ChannelPlatform::Feishu.to_string(), "feishu");
        assert_eq!(ChannelPlatform::Nostr.to_string(), "nostr");
        assert_eq!(ChannelPlatform::Mattermost.to_string(), "mattermost");
        assert_eq!(ChannelPlatform::Zulip.to_string(), "zulip");
        assert_eq!(ChannelPlatform::RocketChat.to_string(), "rocketchat");
    }

    // --- ChannelPlatform: serde ---

    #[test]
    fn test_channel_platform_serde_roundtrip() {
        // Every built-in variant plus a custom one.
        let platforms = vec![
            ChannelPlatform::Telegram,
            ChannelPlatform::Discord,
            ChannelPlatform::Slack,
            ChannelPlatform::WhatsApp,
            ChannelPlatform::Signal,
            ChannelPlatform::Matrix,
            ChannelPlatform::Email,
            ChannelPlatform::Teams,
            ChannelPlatform::Irc,
            ChannelPlatform::Mastodon,
            ChannelPlatform::Reddit,
            ChannelPlatform::Twitch,
            ChannelPlatform::GitHub,
            ChannelPlatform::Line,
            ChannelPlatform::WebChat,
            ChannelPlatform::GoogleChat,
            ChannelPlatform::Bluesky,
            ChannelPlatform::LinkedIn,
            ChannelPlatform::Sms,
            ChannelPlatform::DingTalk,
            ChannelPlatform::Feishu,
            ChannelPlatform::Nostr,
            ChannelPlatform::Mattermost,
            ChannelPlatform::Zulip,
            ChannelPlatform::RocketChat,
            ChannelPlatform::Custom("test".to_string()),
        ];
        for p in platforms {
            let json = serde_json::to_string(&p).unwrap();
            let deserialized: ChannelPlatform = serde_json::from_str(&json).unwrap();
            assert_eq!(p, deserialized, "round-trip failed for {json}");
        }
    }

    // --- split_message ---

    #[test]
    fn test_split_message_short() {
        assert_eq!(split_message("hello", 100), vec!["hello"]);
    }

    #[test]
    fn test_split_message_at_newlines() {
        let text = "line1\nline2\nline3";
        let chunks = split_message(text, 10);
        assert_eq!(chunks, vec!["line1", "line2", "line3"]);
    }

    #[test]
    fn test_split_message_empty_string() {
        let chunks = split_message("", 100);
        assert_eq!(chunks, vec![""]);
    }

    #[test]
    fn test_split_message_exact_boundary() {
        let text = "12345";
        let chunks = split_message(text, 5);
        assert_eq!(chunks, vec!["12345"]);
    }

    #[test]
    fn test_split_message_one_over_boundary() {
        let text = "123456";
        let chunks = split_message(text, 5);
        assert_eq!(chunks, vec!["12345", "6"]);
    }

    #[test]
    fn test_split_message_no_newlines() {
        let text = "abcdefghijklmnopqrstuvwxyz";
        let chunks = split_message(text, 10);
        // Without newlines the text is cut hard at max_len.
        assert_eq!(chunks, vec!["abcdefghij", "klmnopqrst", "uvwxyz"]);
    }

    #[test]
    fn test_split_message_unicode() {
        let text = "Hello \u{1F600} World \u{1F600} Test";
        let chunks = split_message(text, 100);
        assert_eq!(chunks, vec![text]);
    }

    #[test]
    fn test_split_message_crlf_newlines() {
        // split_message splits on \n, so \r remains attached to each line.
        let text = "line1\r\nline2\r\nline3";
        let chunks = split_message(text, 10);
        assert_eq!(chunks, vec!["line1\r", "line2\r", "line3"]);
    }

    #[test]
    fn test_split_message_consecutive_newlines() {
        let text = "line1\n\nline3";
        let chunks = split_message(text, 8);
        // Only the newline chosen as the split point is consumed; the earlier one
        // stays at the end of the first chunk.
        assert_eq!(chunks, vec!["line1\n", "line3"]);
    }

    // --- IncomingMessage ---

    #[test]
    fn test_incoming_message_serde() {
        let msg = IncomingMessage {
            channel_id: "ch1".to_string(),
            user_id: "user1".to_string(),
            display_name: "Alice".to_string(),
            text: "Hello!".to_string(),
            timestamp: Utc::now(),
            platform: ChannelPlatform::Telegram,
            platform_message_id: "123".to_string(),
            is_group: false,
            metadata: HashMap::new(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: IncomingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.platform, ChannelPlatform::Telegram);
        assert_eq!(deserialized.user_id, "user1");
    }

    #[test]
    fn test_incoming_message_field_access() {
        let ts = Utc::now();
        let mut meta = HashMap::new();
        meta.insert("key".to_string(), serde_json::json!("value"));

        let msg = IncomingMessage {
            channel_id: "ch42".to_string(),
            user_id: "u99".to_string(),
            display_name: "Bob".to_string(),
            text: "Test message".to_string(),
            timestamp: ts,
            platform: ChannelPlatform::Discord,
            platform_message_id: "msg-555".to_string(),
            is_group: true,
            metadata: meta,
        };

        assert_eq!(msg.channel_id, "ch42");
        assert_eq!(msg.user_id, "u99");
        assert_eq!(msg.display_name, "Bob");
        assert_eq!(msg.text, "Test message");
        assert_eq!(msg.timestamp, ts);
        assert_eq!(msg.platform, ChannelPlatform::Discord);
        assert_eq!(msg.platform_message_id, "msg-555");
        assert!(msg.is_group);
        assert_eq!(
            msg.metadata.get("key").unwrap(),
            &serde_json::json!("value")
        );
    }

    #[test]
    fn test_incoming_message_default_is_group() {
        // is_group defaults to false when absent from the JSON.
        let json = r#"{
            "channel_id":"c","user_id":"u","display_name":"n",
            "text":"t","timestamp":"2024-01-01T00:00:00Z",
            "platform":"telegram","platform_message_id":"1"
        }"#;
        let msg: IncomingMessage = serde_json::from_str(json).unwrap();
        assert!(!msg.is_group);
    }

    #[test]
    fn test_incoming_message_default_metadata() {
        let json = r#"{
            "channel_id":"c","user_id":"u","display_name":"n",
            "text":"t","timestamp":"2024-01-01T00:00:00Z",
            "platform":"discord","platform_message_id":"1"
        }"#;
        let msg: IncomingMessage = serde_json::from_str(json).unwrap();
        assert!(msg.metadata.is_empty());
    }

    // --- ChannelStatus ---

    #[test]
    fn test_channel_status_defaults() {
        let status = ChannelStatus::default();
        assert!(!status.connected);
        assert!(status.started_at.is_none());
        assert_eq!(status.messages_received, 0);
        assert_eq!(status.messages_sent, 0);
        assert!(status.last_error.is_none());
    }

    #[test]
    fn test_channel_status_serde() {
        let status = ChannelStatus {
            connected: true,
            started_at: Some(Utc::now()),
            messages_received: 42,
            messages_sent: 10,
            last_error: Some("test error".to_string()),
        };
        let json = serde_json::to_string(&status).unwrap();
        let restored: ChannelStatus = serde_json::from_str(&json).unwrap();
        assert!(restored.connected);
        assert_eq!(restored.started_at, status.started_at);
        assert_eq!(restored.messages_received, 42);
        assert_eq!(restored.messages_sent, 10);
        assert_eq!(restored.last_error, Some("test error".to_string()));
    }

    // --- ChannelBridge ---

    /// Minimal in-memory adapter that records every message it is asked to send.
    struct RecordingAdapter {
        name: &'static str,
        sent: std::sync::Mutex<Vec<(String, String)>>,
    }

    impl RecordingAdapter {
        fn new(name: &'static str) -> Arc<Self> {
            Arc::new(Self {
                name,
                sent: std::sync::Mutex::new(Vec::new()),
            })
        }
    }

    #[async_trait::async_trait]
    impl ChannelAdapter for RecordingAdapter {
        fn name(&self) -> &str {
            self.name
        }

        fn platform(&self) -> ChannelPlatform {
            ChannelPlatform::Custom(self.name.to_string())
        }

        async fn start(&self) -> PunchResult<()> {
            Ok(())
        }

        async fn stop(&self) -> PunchResult<()> {
            Ok(())
        }

        async fn send_response(&self, channel_id: &str, message: &str) -> PunchResult<()> {
            self.sent
                .lock()
                .unwrap()
                .push((channel_id.to_string(), message.to_string()));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_channel_bridge_new_has_no_adapters() {
        let bridge = ChannelBridge::new();
        let adapters = bridge.list_adapters().await;
        assert!(adapters.is_empty());
    }

    #[tokio::test]
    async fn test_channel_bridge_default() {
        let bridge = ChannelBridge::default();
        let adapters = bridge.list_adapters().await;
        assert!(adapters.is_empty());
    }

    #[tokio::test]
    async fn test_channel_bridge_send_message_unknown_adapter() {
        let bridge = ChannelBridge::new();
        let result = bridge.send_message("nonexistent", "ch1", "hello").await;
        assert!(matches!(
            &result,
            Err(PunchError::Channel { channel, .. }) if channel == "nonexistent"
        ));
    }

    #[tokio::test]
    async fn test_channel_bridge_register_start_send_stop() {
        let bridge = ChannelBridge::new();
        let adapter = RecordingAdapter::new("mock");
        bridge.register(adapter.clone()).await;

        assert_eq!(bridge.list_adapters().await, vec!["mock".to_string()]);
        assert!(bridge.start_all().await.is_ok());
        assert!(bridge.send_message("mock", "ch1", "hello").await.is_ok());
        assert!(bridge.stop_all().await.is_ok());
        assert_eq!(
            *adapter.sent.lock().unwrap(),
            vec![("ch1".to_string(), "hello".to_string())]
        );

        let statuses = bridge.adapter_statuses().await;
        assert_eq!(statuses.len(), 1);
        assert_eq!(statuses[0].0, "mock");
        assert_eq!(statuses[0].1, ChannelPlatform::Custom("mock".to_string()));
        assert!(!statuses[0].2.connected);
    }
}