Skip to main content

oxios_telegram/
lib.rs

1pub mod format;
2pub mod plugin;
3
4pub use format::TelegramFormatter;
5pub use plugin::TelegramPlugin;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use oxios_gateway::channel::Channel;
11use oxios_gateway::format::ChannelFormatter;
12use oxios_gateway::message::{IncomingMessage, OutgoingMessage};
13use oxios_gateway::GatewayInbox;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{mpsc, watch, RwLock};
17
/// Per-chat session state for Telegram.
///
/// One value is stored per `chat_id` in `TelegramChannel::chat_sessions`.
#[derive(Debug, Clone)]
struct ChatSession {
    /// Current session ID used for multi-turn conversations (a UUID v4 string).
    session_id: String,
    /// When the session was created; reset whenever the session is rotated.
    created_at: DateTime<Utc>,
    /// When the last message was sent/received in this session.
    /// Time-based rotation measures inactivity from this instant.
    last_active_at: DateTime<Utc>,
    /// Number of messages exchanged in this session since creation or the last rotation.
    message_count: usize,
}
30
31impl ChatSession {
32    fn new() -> Self {
33        let now = Utc::now();
34        Self {
35            session_id: uuid::Uuid::new_v4().to_string(),
36            created_at: now,
37            last_active_at: now,
38            message_count: 0,
39        }
40    }
41
42    /// Check if this session should be rotated based on configuration.
43    fn should_rotate(&self, rotation_hours: u64, max_messages: usize) -> bool {
44        // Time-based rotation
45        if rotation_hours > 0 {
46            let elapsed = Utc::now() - self.last_active_at;
47            if elapsed > chrono::Duration::hours(rotation_hours as i64) {
48                return true;
49            }
50        }
51        // Message-count based rotation
52        if max_messages > 0 && self.message_count >= max_messages {
53            return true;
54        }
55        false
56    }
57
58    /// Touch the session (update last_active_at and increment message count).
59    fn touch(&mut self) {
60        self.last_active_at = Utc::now();
61        self.message_count += 1;
62    }
63
64    /// Rotate to a new session, returning the old session ID.
65    fn rotate(&mut self) -> String {
66        let old_id = self.session_id.clone();
67        let now = Utc::now();
68        self.session_id = uuid::Uuid::new_v4().to_string();
69        self.created_at = now;
70        self.last_active_at = now;
71        self.message_count = 0;
72        old_id
73    }
74}
75
/// Telegram session configuration.
///
/// Both values are passed to `ChatSession::should_rotate`; a value of `0`
/// disables the corresponding check.
#[derive(Debug, Clone)]
pub struct TelegramSessionSettings {
    /// Automatically rotate sessions after more than this many hours of
    /// inactivity (0 = no time-based rotation).
    pub rotation_hours: u64,
    /// Rotate after this many messages (0 = unlimited).
    pub max_messages_per_session: usize,
}
84
85impl Default for TelegramSessionSettings {
86    fn default() -> Self {
87        Self {
88            rotation_hours: 2,
89            max_messages_per_session: 0,
90        }
91    }
92}
93
/// Telegram channel adapter.
///
/// Uses long polling (getUpdates) to receive messages
/// and the Bot API to send responses.
///
/// Session management:
/// - Each `chat_id` gets its own session tracked in memory.
/// - Sessions auto-rotate after a configurable period of inactivity.
/// - Users can force a new session with the `/new` command.
///
/// The `offset` and `chat_sessions` fields are behind `Arc`, so clones of a
/// channel share the same polling offset and session table.
#[derive(Clone)]
pub struct TelegramChannel {
    /// Bot API token; embedded in every request URL by `api_url`.
    bot_token: String,
    /// Base URL of the Bot API server (defaults to `https://api.telegram.org`).
    api_base: String,
    /// Allowed Telegram user IDs; an empty list allows everyone.
    allowed_users: Vec<i64>,
    /// Shared HTTP client used for all Bot API calls.
    client: reqwest::Client,
    /// `offset` value sent with the next getUpdates call (last seen `update_id` + 1).
    offset: Arc<RwLock<i64>>,
    /// Maps chat_id -> per-chat session state
    chat_sessions: Arc<RwLock<HashMap<i64, ChatSession>>>,
    /// Session rotation settings
    session_settings: TelegramSessionSettings,
}
115
116impl TelegramChannel {
117    /// Create a new Telegram channel.
118    ///
119    /// # Arguments
120    /// * `bot_token` - Telegram Bot API token from @BotFather
121    /// * `allowed_users` - List of allowed Telegram user IDs (empty = allow all)
122    pub fn new(bot_token: String, allowed_users: Vec<i64>) -> Self {
123        Self {
124            bot_token,
125            api_base: "https://api.telegram.org".to_string(),
126            allowed_users,
127            client: reqwest::Client::builder()
128                .timeout(std::time::Duration::from_secs(60))
129                .build()
130                .unwrap_or_else(|_| reqwest::Client::new()),
131            offset: Arc::new(RwLock::new(0)),
132            chat_sessions: Arc::new(RwLock::new(HashMap::new())),
133            session_settings: TelegramSessionSettings::default(),
134        }
135    }
136
137    /// Override API base URL (for local Bot API servers).
138    pub fn with_api_base(mut self, base: String) -> Self {
139        self.api_base = base;
140        self
141    }
142
143    /// Set session management settings.
144    pub fn with_session_settings(mut self, settings: TelegramSessionSettings) -> Self {
145        self.session_settings = settings;
146        self
147    }
148
149    fn api_url(&self, method: &str) -> String {
150        format!("{}/bot{}/{}", self.api_base, self.bot_token, method)
151    }
152
153    /// Check if user is allowed.
154    fn is_user_allowed(&self, user_id: i64) -> bool {
155        self.allowed_users.is_empty() || self.allowed_users.contains(&user_id)
156    }
157
158    /// Get or create a session for a chat, auto-rotating if needed.
159    async fn get_or_create_session(&self, chat_id: i64) -> String {
160        let mut sessions = self.chat_sessions.write().await;
161        let session = sessions.entry(chat_id).or_insert_with(ChatSession::new);
162
163        // Check if rotation is needed
164        if session.should_rotate(
165            self.session_settings.rotation_hours,
166            self.session_settings.max_messages_per_session,
167        ) {
168            session.rotate();
169            tracing::info!(
170                chat_id = chat_id,
171                new_session = %session.session_id,
172                "Telegram session auto-rotated"
173            );
174        }
175
176        session.touch();
177        session.session_id.clone()
178    }
179
180    /// Force-rotate a chat's session (used for /new command).
181    async fn force_new_session(&self, chat_id: i64) -> String {
182        let mut sessions = self.chat_sessions.write().await;
183        let session = sessions.entry(chat_id).or_insert_with(ChatSession::new);
184        let old_id = session.rotate();
185        tracing::info!(
186            chat_id = chat_id,
187            old_session = %old_id,
188            new_session = %session.session_id,
189            "Telegram session force-rotated via /new command"
190        );
191        session.session_id.clone()
192    }
193
194    /// Poll for updates using getUpdates (long polling).
195    async fn poll_updates(&self) -> Result<Vec<serde_json::Value>> {
196        let offset = *self.offset.read().await;
197        let mut body = serde_json::json!({
198            "timeout": 30,
199            "limit": 100,
200        });
201        if offset > 0 {
202            body["offset"] = serde_json::Value::Number(offset.into());
203        }
204
205        let resp = self
206            .client
207            .post(self.api_url("getUpdates"))
208            .json(&body)
209            .send()
210            .await?;
211
212        if !resp.status().is_success() {
213            let err = resp.text().await.unwrap_or_default();
214            anyhow::bail!("Telegram getUpdates failed: {err}");
215        }
216
217        let json: serde_json::Value = resp.json().await?;
218        let updates = json
219            .get("result")
220            .and_then(|r| r.as_array())
221            .cloned()
222            .unwrap_or_default();
223
224        // Update offset
225        if let Some(last) = updates.last() {
226            if let Some(id) = last.get("update_id").and_then(|id| id.as_i64()) {
227                *self.offset.write().await = id + 1;
228            }
229        }
230
231        Ok(updates)
232    }
233
234    /// Send a chat action indicator (e.g. "typing").
235    async fn send_chat_action(&self, chat_id: i64, action: &str) -> Result<()> {
236        self.client
237            .post(self.api_url("sendChatAction"))
238            .json(&serde_json::json!({ "chat_id": chat_id, "action": action }))
239            .send()
240            .await?;
241        Ok(())
242    }
243
244    /// Send a text message to a chat.
245    async fn send_text(&self, chat_id: i64, text: &str, reply_to: Option<i64>) -> Result<()> {
246        for chunk in split_message(text, 4000) {
247            let mut body = serde_json::json!({
248                "chat_id": chat_id,
249                "text": &chunk,
250                "parse_mode": "Markdown",
251            });
252            if let Some(msg_id) = reply_to {
253                body["reply_to_message_id"] = serde_json::Value::Number(msg_id.into());
254            }
255            let resp = self
256                .client
257                .post(self.api_url("sendMessage"))
258                .json(&body)
259                .send()
260                .await?;
261            if !resp.status().is_success() {
262                // Fallback: send without parse_mode
263                body["parse_mode"] = serde_json::Value::Null;
264                self.client
265                    .post(self.api_url("sendMessage"))
266                    .json(&body)
267                    .send()
268                    .await?;
269            }
270        }
271        Ok(())
272    }
273}
274
275#[async_trait]
276impl Channel for TelegramChannel {
277    fn name(&self) -> &str {
278        "telegram"
279    }
280
281    async fn start(
282        &self,
283        tx: mpsc::Sender<GatewayInbox>,
284        mut shutdown: watch::Receiver<bool>,
285    ) -> Result<tokio::task::JoinHandle<()>> {
286        let this = Arc::new(self.clone());
287        let channel_name = this.name().to_owned();
288
289        let handle = tokio::spawn(async move {
290            let mut retry_count: u32 = 0;
291            loop {
292                tokio::select! {
293                    updates_result = this.poll_updates() => {
294                        match updates_result {
295                            Ok(updates) => {
296                                retry_count = 0;
297                                for update in updates {
298                                    let message = update
299                                        .get("message")
300                                        .or_else(|| update.get("channel_post"))
301                                        .or_else(|| update.get("edited_message"));
302                                    let Some(msg) = message else { continue };
303
304                                    let chat_id = msg
305                                        .get("chat")
306                                        .and_then(|c| c.get("id"))
307                                        .and_then(|id| id.as_i64());
308                                    let user_id = msg
309                                        .get("from")
310                                        .and_then(|f| f.get("id"))
311                                        .and_then(|id| id.as_i64());
312                                    let text = msg
313                                        .get("text")
314                                        .and_then(|t| t.as_str())
315                                        .unwrap_or("");
316                                    let message_id = msg
317                                        .get("message_id")
318                                        .and_then(|id| id.as_i64())
319                                        .unwrap_or(0);
320
321                                    if text.is_empty() {
322                                        continue;
323                                    }
324
325                                    // Permission check
326                                    if let Some(uid) = user_id {
327                                        if !this.is_user_allowed(uid) {
328                                            tracing::warn!(user_id = uid, "Unauthorized Telegram user");
329                                            if let Some(cid) = chat_id {
330                                                let _ = this
331                                                    .send_text(
332                                                        cid,
333                                                        "Unauthorized. Your user ID is not in the allowed list.",
334                                                        None,
335                                                    )
336                                                    .await;
337                                            }
338                                            continue;
339                                        }
340                                    }
341
342                                    let Some(cid) = chat_id else { continue };
343                                    let user_id_str = user_id
344                                        .map(|id| id.to_string())
345                                        .unwrap_or_else(|| "unknown".to_string());
346
347                                    // /new command β€” start a new session
348                                    let trimmed = text.trim();
349                                    if trimmed == "/new" || trimmed == "/new@me" {
350                                        let new_session_id = this.force_new_session(cid).await;
351                                        let _ = this
352                                            .send_text(
353                                                cid,
354                                                &format!("πŸ”„ μƒˆ μ„Έμ…˜μ„ μ‹œμž‘ν•©λ‹ˆλ‹€.\\n`{}`", &new_session_id[..8]),
355                                                Some(message_id),
356                                            )
357                                            .await;
358                                        continue;
359                                    }
360
361                                    // /session command β€” show current session info
362                                    if trimmed == "/session" || trimmed == "/session@me" {
363                                        let sessions = this.chat_sessions.read().await;
364                                        if let Some(session) = sessions.get(&cid) {
365                                            let info = format!(
366                                                "πŸ“‹ ν˜„μž¬ μ„Έμ…˜\\nβ€’ ID: `{}`\\nβ€’ λ©”μ‹œμ§€: {}개\\nβ€’ μ‹œμž‘: {}\\nβ€’ λ§ˆμ§€λ§‰ ν™œλ™: {}",
367                                                &session.session_id[..8],
368                                                session.message_count,
369                                                session.created_at.format("%m/%d %H:%M"),
370                                                session.last_active_at.format("%m/%d %H:%M"),
371                                            );
372                                            drop(sessions);
373                                            let _ = this.send_text(cid, &info, Some(message_id)).await;
374                                        } else {
375                                            drop(sessions);
376                                            let _ = this
377                                                .send_text(cid, "πŸ“‹ ν™œμ„± μ„Έμ…˜μ΄ μ—†μŠ΅λ‹ˆλ‹€.", Some(message_id))
378                                                .await;
379                                        }
380                                        continue;
381                                    }
382
383                                    // /spaces command β€” channels don't have kernel access
384                                    if trimmed == "/spaces" || trimmed.starts_with("/spaces@") {
385                                        let _ = this.send_text(cid, "Space κ΄€λ¦¬λŠ” Web λŒ€μ‹œλ³΄λ“œμ—μ„œ μ‚¬μš© κ°€λŠ₯ν•©λ‹ˆλ‹€.", Some(message_id)).await;
386                                        continue;
387                                    }
388
389                                    // /space command β€” channels don't have kernel access
390                                    if trimmed.starts_with("/space") && !trimmed.starts_with("/spaces") {
391                                        let _ = this.send_text(cid, "Space κ΄€λ¦¬λŠ” Web λŒ€μ‹œλ³΄λ“œμ—μ„œ μ‚¬μš© κ°€λŠ₯ν•©λ‹ˆλ‹€.", Some(message_id)).await;
392                                        continue;
393                                    }
394
395                                    // Skip other /command messages
396                                    if text.starts_with('/') {
397                                        continue;
398                                    }
399
400                                    // Get or auto-rotate session
401                                    let session_id = this.get_or_create_session(cid).await;
402
403                                    let mut metadata = HashMap::new();
404                                    metadata.insert("chat_id".to_string(), cid.to_string());
405                                    metadata.insert("message_id".to_string(), message_id.to_string());
406                                    metadata.insert("session_id".to_string(), session_id);
407
408                                    let incoming = IncomingMessage {
409                                        channel: "telegram".to_string(),
410                                        user_id: user_id_str,
411                                        content: text.to_string(),
412                                        metadata,
413                                        ..Default::default()
414                                    };
415
416                                    tracing::info!(
417                                        chat_id = cid,
418                                        text = %text.chars().take(50).collect::<String>(),
419                                        "Telegram message received"
420                                    );
421
422                                    if tx.send((channel_name.clone(), incoming)).await.is_err() {
423                                        break; // Gateway receiver closed
424                                    }
425                                    // Send typing indicator
426                                    let _ = this.send_chat_action(cid, "typing").await;
427                                }
428                            }
429                            Err(e) => {
430                                tracing::warn!(error = %e, "Telegram poll error");
431                                let delay = std::time::Duration::from_secs(5 * 2u64.pow(retry_count.min(4)));
432                                tokio::time::sleep(delay).await;
433                                retry_count += 1;
434                            }
435                        }
436                    }
437
438                    _ = shutdown.changed() => {
439                        tracing::info!(channel = %channel_name, "Telegram channel stopped");
440                        break;
441                    }
442                }
443            }
444        });
445
446        Ok(handle)
447    }
448
449    async fn send(&self, msg: OutgoingMessage) -> Result<()> {
450        let chat_id: i64 = msg
451            .metadata
452            .get("chat_id")
453            .and_then(|id| id.parse().ok())
454            .or_else(|| msg.user_id.parse().ok())
455            .ok_or_else(|| anyhow::anyhow!("No chat_id for Telegram message"))?;
456
457        let reply_to = msg
458            .metadata
459            .get("message_id")
460            .and_then(|id| id.parse().ok());
461
462        let formatter = crate::TelegramFormatter;
463        let raw = match &msg.meta {
464            Some(meta) if meta.error.is_some() => formatter.format_error(&msg),
465            Some(_) => formatter.format_success(&msg),
466            None => msg.content.clone(),
467        };
468
469        for chunk in split_message(&raw, 4000) {
470            self.send_text(chat_id, &chunk, reply_to).await?;
471        }
472
473        tracing::debug!(chat_id = chat_id, "Telegram response sent");
474        Ok(())
475    }
476}
477
/// Split a message into chunks of at most `max_chars` Unicode characters.
///
/// Unlike byte-based splitting, this is safe for multi-byte UTF-8
/// (Korean, Chinese, emoji, etc.): chunks are always cut on character
/// boundaries.
///
/// A `max_chars` of 0 is treated as 1 so that no empty chunks are produced.
/// The result always contains at least one element; empty input yields a
/// single empty string.
fn split_message(text: &str, max_chars: usize) -> Vec<String> {
    let limit = max_chars.max(1);
    let mut chunks = Vec::new();
    // Byte offset where the current chunk starts, and the number of
    // characters collected for it so far.
    let mut start = 0;
    let mut count = 0;

    // Single pass: track the running character count instead of recounting
    // the chunk for every character.
    for (idx, _) in text.char_indices() {
        if count == limit {
            chunks.push(text[start..idx].to_string());
            start = idx;
            count = 0;
        }
        count += 1;
    }

    // Push the trailing chunk, or the whole (possibly empty) text when
    // nothing has been split off.
    if start < text.len() || chunks.is_empty() {
        chunks.push(text[start..].to_string());
    }
    chunks
}
499
/// Unit tests for allowlist checks, URL building, session rotation and
/// message splitting. None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_telegram_channel_new() {
        let channel = TelegramChannel::new("test-token".to_string(), vec![12345]);
        assert_eq!(channel.name(), "telegram");
        assert!(channel.is_user_allowed(12345));
        assert!(!channel.is_user_allowed(99999));
    }

    #[test]
    fn test_telegram_channel_allow_all() {
        // An empty allowlist admits every user.
        let channel = TelegramChannel::new("test-token".to_string(), vec![]);
        assert!(channel.is_user_allowed(12345));
        assert!(channel.is_user_allowed(99999));
    }

    #[test]
    fn test_api_url() {
        let channel = TelegramChannel::new("123:ABC".to_string(), vec![]);
        assert_eq!(
            channel.api_url("getMe"),
            "https://api.telegram.org/bot123:ABC/getMe"
        );
    }

    #[test]
    fn test_chat_session_rotation_by_time() {
        let mut session = ChatSession::new();
        assert!(!session.should_rotate(2, 0)); // Just created, should not rotate

        // Simulate 3 hours of inactivity
        session.last_active_at = Utc::now() - chrono::Duration::hours(3);
        assert!(session.should_rotate(2, 0)); // Should rotate
        assert!(!session.should_rotate(0, 0)); // Disabled, should not rotate
    }

    #[test]
    fn test_chat_session_rotation_by_message_count() {
        let mut session = ChatSession::new();
        session.message_count = 50;
        assert!(session.should_rotate(0, 50)); // At limit
        assert!(session.should_rotate(0, 49)); // Over limit
        assert!(!session.should_rotate(0, 51)); // Under limit
        assert!(!session.should_rotate(0, 0)); // Disabled
    }

    #[test]
    fn test_chat_session_rotate_resets_state() {
        let mut session = ChatSession::new();
        let original_id = session.session_id.clone();
        session.message_count = 100;

        let old_id = session.rotate();
        assert_eq!(old_id, original_id);
        assert_ne!(session.session_id, original_id);
        assert_eq!(session.message_count, 0);
    }

    #[test]
    fn test_chat_session_touch() {
        let mut session = ChatSession::new();
        assert_eq!(session.message_count, 0);
        session.touch();
        assert_eq!(session.message_count, 1);
        session.touch();
        assert_eq!(session.message_count, 2);
    }

    #[test]
    fn test_split_message_ascii() {
        let text = "hello world";
        let chunks = split_message(text, 5);
        assert_eq!(chunks, vec!["hello", " worl", "d"]);
    }

    #[test]
    fn test_split_message_utf8() {
        // Multi-byte characters must be counted as characters, not bytes.
        let text = "안녕하세요세계"; // 7 Korean chars
        let chunks = split_message(text, 3);
        assert_eq!(chunks, vec!["안녕하", "세요세", "계"]);
    }

    #[test]
    fn test_split_message_short() {
        let text = "hello";
        let chunks = split_message(text, 10);
        assert_eq!(chunks, vec!["hello"]);
    }
}
591}