use std::time::{Duration, Instant};
use crate::markdown::markdown_to_telegram;
use teloxide::prelude::*;
use teloxide::types::{BotCommand, ChatAction, MessageId, ParseMode};
use tokio::sync::mpsc;
use zeph_common::TaskSupervisor;
use zeph_core::channel::{
Attachment, AttachmentKind, Channel, ChannelError, ChannelMessage, ElicitationField,
ElicitationFieldType, ElicitationRequest, ElicitationResponse,
};
/// Upper bound on the length of one outgoing Telegram message, compared against
/// the byte length (`str::len`) of the formatted text throughout this file.
const MAX_MESSAGE_LEN: usize = 4096;
/// Largest photo variant (in bytes, as reported by Telegram) that will be downloaded: 20 MiB.
const MAX_IMAGE_BYTES: u32 = 20 << 20;
/// Telegram-backed chat channel.
///
/// Incoming updates are handled by a teloxide dispatcher (spawned in
/// `TelegramChannel::start`) and forwarded over an mpsc channel. Streamed reply
/// text is buffered in `accumulated` and periodically pushed to Telegram as a new
/// or edited message.
pub struct TelegramChannel {
    bot: Bot,
    /// Chat of the most recently received message; replies are sent here.
    chat_id: Option<ChatId>,
    /// Receives messages forwarded by the dispatcher endpoint.
    rx: mpsc::Receiver<IncomingMessage>,
    /// Telegram usernames allowed to talk to the bot.
    allowed_users: Vec<String>,
    /// Raw text of the reply currently being streamed.
    accumulated: String,
    /// When the streaming preview was last sent/edited; used for throttling.
    last_edit: Option<Instant>,
    /// Telegram message currently being edited with streamed text, if any.
    message_id: Option<MessageId>,
    /// Optional supervisor that runs the listener task.
    supervisor: Option<TaskSupervisor>,
}
impl std::fmt::Debug for TelegramChannel {
    /// Produces a redacted view: the bot handle (which holds the token) and the
    /// receiver are omitted, and the buffered text is reported only by length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let accumulated_len = self.accumulated.len();
        let has_supervisor = self.supervisor.is_some();
        let mut out = f.debug_struct("TelegramChannel");
        out.field("chat_id", &self.chat_id);
        out.field("allowed_users", &self.allowed_users);
        out.field("accumulated_len", &accumulated_len);
        out.field("supervisor", &has_supervisor);
        out.finish_non_exhaustive()
    }
}
/// A message accepted by the dispatcher endpoint and queued for the channel.
#[derive(Debug)]
struct IncomingMessage {
    /// Chat the message came from.
    chat_id: ChatId,
    /// Message text; may be empty when the message only carried attachments.
    text: String,
    /// Attachments that were downloaded successfully.
    attachments: Vec<Attachment>,
}
impl TelegramChannel {
    /// Creates a channel for the bot identified by `token` without starting it.
    ///
    /// Only users whose Telegram username appears in `allowed_users` are served once
    /// the channel is started. Until [`TelegramChannel::start`] runs, the receiver is
    /// a placeholder whose sender has already been dropped, so receiving yields no
    /// messages.
    #[must_use]
    pub fn new(token: String, allowed_users: Vec<String>) -> Self {
        let bot = Bot::new(token);
        // Placeholder receiver: its sender is dropped immediately and `start()`
        // replaces it with the receiver fed by the dispatcher.
        let (_, rx) = mpsc::channel(64);
        Self {
            bot,
            chat_id: None,
            rx,
            allowed_users,
            accumulated: String::new(),
            last_edit: None,
            message_id: None,
            supervisor: None,
        }
    }

    /// Attaches a [`TaskSupervisor`] that will run the listener task spawned by
    /// [`TelegramChannel::start`].
    #[must_use]
    pub fn with_supervisor(mut self, supervisor: TaskSupervisor) -> Self {
        self.supervisor = Some(supervisor);
        self
    }

    /// Publishes the bot's slash-command menu from a detached task.
    ///
    /// Registration is best-effort: a failure is logged and otherwise ignored.
    fn register_commands(bot: Bot) {
        tokio::spawn(async move {
            let commands = vec![
                BotCommand::new("start", "Start a new conversation"),
                BotCommand::new("reset", "Reset conversation history"),
                BotCommand::new("skills", "List loaded skills"),
                BotCommand::new("agent", "Manage sub-agents (list/spawn/status/cancel)"),
            ];
            if let Err(e) = bot.set_my_commands(commands).await {
                tracing::warn!("failed to register bot commands: {e}");
            }
        });
    }

    /// Starts receiving updates from Telegram.
    ///
    /// Registers the command menu, then runs a teloxide dispatcher whose endpoint
    /// forwards authorized messages to this channel. With a supervisor attached, the
    /// dispatcher runs as the `telegram_listener` task under
    /// `RestartPolicy::Restart { max: 5, base_delay: 2s }`; otherwise it is spawned
    /// directly with `tokio::spawn`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Other`] when `allowed_users` is empty, so the bot is
    /// never started open to everyone.
    pub fn start(mut self) -> Result<Self, ChannelError> {
        if self.allowed_users.is_empty() {
            tracing::error!("telegram.allowed_users is empty; refusing to start an open bot");
            return Err(ChannelError::Other(
                "telegram.allowed_users must not be empty".into(),
            ));
        }
        let (tx, rx) = mpsc::channel::<IncomingMessage>(64);
        self.rx = rx;
        Self::register_commands(self.bot.clone());
        let bot = self.bot.clone();
        let allowed = self.allowed_users.clone();
        // The supervisor may invoke the factory again after a failure, so it clones
        // its captures for every fresh dispatcher instead of moving them out.
        let listener_factory = move || {
            let bot = bot.clone();
            let allowed = allowed.clone();
            let tx = tx.clone();
            async move {
                let handler = Update::filter_message().endpoint(move |msg: Message, bot: Bot| {
                    let tx = tx.clone();
                    let allowed = allowed.clone();
                    async move { handle_telegram_message(bot, msg, tx, allowed).await }
                });
                Dispatcher::builder(bot, handler)
                    .enable_ctrlc_handler()
                    .build()
                    .dispatch()
                    .await;
            }
        };
        if let Some(sup) = &self.supervisor {
            sup.spawn(zeph_common::TaskDescriptor {
                name: "telegram_listener",
                restart: zeph_common::RestartPolicy::Restart {
                    max: 5,
                    base_delay: Duration::from_secs(2),
                },
                factory: listener_factory,
            });
        } else {
            tokio::spawn(listener_factory());
        }
        tracing::info!("telegram bot listener started");
        Ok(self)
    }

    /// Builds a channel wired to a test-controlled sender instead of a dispatcher.
    #[cfg(test)]
    fn new_test(allowed_users: Vec<String>) -> (Self, mpsc::Sender<IncomingMessage>) {
        let (tx, rx) = mpsc::channel(64);
        let channel = Self {
            bot: Bot::new("test_token"),
            chat_id: None,
            rx,
            allowed_users,
            accumulated: String::new(),
            last_edit: None,
            message_id: None,
            supervisor: None,
        };
        (channel, tx)
    }

    /// Returns the first whitespace-separated token of `text` if it starts with `/`.
    fn is_command(text: &str) -> Option<&str> {
        let cmd = text.split_whitespace().next()?;
        if cmd.starts_with('/') {
            Some(cmd)
        } else {
            None
        }
    }

    /// Throttle for streaming previews: true when nothing was pushed yet for this
    /// reply, or when more than 3 seconds passed since the last push.
    fn should_send_update(&self) -> bool {
        match self.last_edit {
            None => true,
            Some(last) => last.elapsed() > Duration::from_secs(3),
        }
    }

    /// Pushes the current `accumulated` text to Telegram as a streaming preview.
    ///
    /// The first push for a reply sends a new message; later pushes edit it in place.
    /// A stale message id (edit target not found) falls back to sending a new
    /// message. When the formatted text exceeds [`MAX_MESSAGE_LEN`] it is split: the
    /// first piece goes into the current message (or a new one) and the rest are
    /// sent as new messages. After such a split the streaming state is reset, so
    /// text arriving later starts a fresh message instead of re-sending the pieces
    /// that were already delivered.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::NoActiveSession`] when no chat is active, or a
    /// Telegram error when a send (or a non-recoverable edit) fails.
    async fn send_or_edit(&mut self) -> Result<(), ChannelError> {
        let Some(chat_id) = self.chat_id else {
            return Err(ChannelError::NoActiveSession);
        };
        // Show a placeholder until real text has arrived.
        let text = if self.accumulated.is_empty() {
            "..."
        } else {
            &self.accumulated
        };
        let formatted_text = markdown_to_telegram(text);
        if formatted_text.is_empty() {
            tracing::debug!("skipping send: formatted text is empty");
            return Ok(());
        }
        tracing::debug!("formatted_text (full): {}", formatted_text);
        let overflowed = formatted_text.len() > MAX_MESSAGE_LEN;
        match self.message_id {
            None => {
                tracing::debug!("sending new message (length: {})", formatted_text.len());
                let chunks = crate::markdown::utf8_chunks(&formatted_text, MAX_MESSAGE_LEN);
                for chunk in chunks {
                    let msg = self
                        .bot
                        .send_message(chat_id, chunk)
                        .parse_mode(ParseMode::MarkdownV2)
                        .await
                        .map_err(ChannelError::telegram)?;
                    self.message_id = Some(msg.id);
                    tracing::debug!("new message sent with id: {:?}", msg.id);
                }
            }
            Some(msg_id) => {
                tracing::debug!(
                    "editing message {:?} (length: {})",
                    msg_id,
                    formatted_text.len()
                );
                if !overflowed {
                    let edit_result = self
                        .bot
                        .edit_message_text(chat_id, msg_id, &formatted_text)
                        .parse_mode(ParseMode::MarkdownV2)
                        .await;
                    if let Err(e) = edit_result {
                        let error_msg = e.to_string();
                        if error_msg.contains("message is not modified") {
                            tracing::debug!("message content unchanged, skipping edit");
                        } else if error_msg.contains("message to edit not found")
                            || error_msg.contains("MESSAGE_ID_INVALID")
                        {
                            tracing::warn!(
                                "Telegram edit failed (message_id stale?): {e}, sending new message"
                            );
                            self.message_id = None;
                            self.last_edit = None;
                            let msg = self
                                .bot
                                .send_message(chat_id, &formatted_text)
                                .parse_mode(ParseMode::MarkdownV2)
                                .await
                                .map_err(ChannelError::telegram)?;
                            self.message_id = Some(msg.id);
                        } else {
                            return Err(ChannelError::telegram(e));
                        }
                    } else {
                        tracing::debug!("message edited successfully");
                    }
                } else {
                    let chunks = crate::markdown::utf8_chunks(&formatted_text, MAX_MESSAGE_LEN);
                    let mut iter = chunks.into_iter();
                    if let Some(first) = iter.next() {
                        let edit_result = self
                            .bot
                            .edit_message_text(chat_id, msg_id, first)
                            .parse_mode(ParseMode::MarkdownV2)
                            .await;
                        if let Err(e) = edit_result {
                            let error_msg = e.to_string();
                            if !error_msg.contains("message is not modified") {
                                tracing::warn!("Telegram edit failed during split: {e}");
                            }
                        }
                    }
                    for chunk in iter {
                        let msg = self
                            .bot
                            .send_message(chat_id, chunk)
                            .parse_mode(ParseMode::MarkdownV2)
                            .await
                            .map_err(ChannelError::telegram)?;
                        self.message_id = Some(msg.id);
                        tracing::debug!("overflow chunk sent with id: {:?}", msg.id);
                    }
                }
            }
        }
        if overflowed {
            // The text is now spread over several messages and `message_id` points at
            // the last one. Editing it with the full text on the next update would
            // re-send the leading pieces, duplicating them. Treat what was delivered
            // as final and let further output start a new message. `last_edit` stays
            // `None` so the next chunk is pushed immediately; `flush_chunks` only
            // finalizes an existing message, so later text must not wait for it.
            tracing::debug!("reply split across messages; further output starts a new message");
            self.accumulated.clear();
            self.message_id = None;
            self.last_edit = None;
            return Ok(());
        }
        self.last_edit = Some(Instant::now());
        Ok(())
    }
}
/// Decides whether a sender may use the bot.
///
/// An empty `allowed` list permits everyone (`TelegramChannel::start` refuses to
/// run with an empty list). Otherwise the sender must have a username that
/// exactly matches one of the entries; senders without a username are rejected.
fn is_user_authorized(username: Option<&str>, allowed: &[String]) -> bool {
    if allowed.is_empty() {
        return true;
    }
    match username {
        Some(name) => allowed.iter().any(|entry| entry.as_str() == name),
        None => false,
    }
}
/// Returns `(file_id, file_size)` of the voice note in `msg`, or failing that of
/// its audio file; `None` when the message carries neither.
fn extract_audio_attachment(msg: &Message) -> Option<(String, u32)> {
    if let Some(voice) = msg.voice() {
        return Some((voice.file.id.0.clone(), voice.file.size));
    }
    msg.audio()
        .map(|audio| (audio.file.id.0.clone(), audio.file.size))
}
/// Picks the photo variant to download from `msg`.
///
/// Telegram delivers a photo as several size variants. This returns
/// `(file_id, file_size)` of the largest variant whose reported size does not
/// exceed [`MAX_IMAGE_BYTES`]. An oversized original therefore falls back to a
/// smaller resolution instead of being dropped. Returns `None` when the message has
/// no photo, or (with a warning) when every variant exceeds the limit.
fn extract_photo_attachment(msg: &Message) -> Option<(String, u32)> {
    let photos = msg.photo()?;
    let Some(photo) = photos
        .iter()
        .filter(|p| p.file.size <= MAX_IMAGE_BYTES)
        .max_by_key(|p| p.file.size)
    else {
        if let Some(largest) = photos.iter().map(|p| p.file.size).max() {
            tracing::warn!(
                size = largest,
                max = MAX_IMAGE_BYTES,
                "photo exceeds size limit, skipping"
            );
        }
        return None;
    };
    Some((photo.file.id.0.clone(), photo.file.size))
}
/// Dispatcher endpoint: filters one Telegram message and forwards it to the channel.
///
/// Messages from senders rejected by `is_user_authorized` are dropped with a
/// warning. Voice/audio and photo attachments are downloaded best-effort; a failed
/// download is logged and skipped. The message text (or, for media messages, the
/// caption) is forwarded over `tx` together with the attachments. Messages with
/// neither text nor attachments are ignored.
///
/// Always returns `respond(())`, so a single bad message never stops the dispatcher.
async fn handle_telegram_message(
    bot: Bot,
    msg: Message,
    tx: mpsc::Sender<IncomingMessage>,
    allowed: Vec<String>,
) -> Result<(), teloxide::RequestError> {
    let username = msg.from.as_ref().and_then(|u| u.username.as_deref());
    if !is_user_authorized(username, &allowed) {
        tracing::warn!("rejected message from unauthorized user: {:?}", username);
        return respond(());
    }
    // Media messages (photo, audio, voice) carry their text in `caption`, not
    // `text`. Fall back to it so e.g. a photo sent with a question keeps the question.
    let text = msg
        .text()
        .or_else(|| msg.caption())
        .unwrap_or_default()
        .to_string();
    let mut attachments = Vec::new();
    if let Some((file_id, file_size)) = extract_audio_attachment(&msg) {
        match download_file(&bot, file_id, file_size).await {
            Ok(data) => {
                attachments.push(Attachment {
                    kind: AttachmentKind::Audio,
                    data,
                    filename: msg.audio().and_then(|a| a.file_name.clone()),
                });
            }
            Err(e) => {
                tracing::warn!("failed to download audio attachment: {e}");
            }
        }
    }
    if let Some((file_id, file_size)) = extract_photo_attachment(&msg) {
        match download_file(&bot, file_id, file_size).await {
            Ok(data) => {
                attachments.push(Attachment {
                    kind: AttachmentKind::Image,
                    data,
                    filename: None,
                });
            }
            Err(e) => {
                tracing::warn!("failed to download photo attachment: {e}");
            }
        }
    }
    if text.is_empty() && attachments.is_empty() {
        return respond(());
    }
    let incoming = IncomingMessage {
        chat_id: msg.chat.id,
        text,
        attachments,
    };
    // A send error means the receiving side is gone. The message cannot be
    // delivered, but the loss should be visible rather than silently ignored.
    if tx.send(incoming).await.is_err() {
        tracing::warn!("telegram channel receiver closed; dropping incoming message");
    }
    respond(())
}
/// Resolves `file_id` via `getFile` and downloads the file into memory.
///
/// `capacity` (the size Telegram reported for the file) pre-sizes the buffer.
/// Errors are flattened into strings prefixed with the failing step
/// (`get_file:` / `download_file:`).
async fn download_file(bot: &Bot, file_id: String, capacity: u32) -> Result<Vec<u8>, String> {
    use teloxide::net::Download;

    let remote = match bot.get_file(file_id.into()).await {
        Ok(found) => found,
        Err(e) => return Err(format!("get_file: {e}")),
    };
    let mut bytes = Vec::<u8>::with_capacity(capacity as usize);
    if let Err(e) = bot.download_file(&remote.path, &mut bytes).await {
        return Err(format!("download_file: {e}"));
    }
    Ok(bytes)
}
impl Channel for TelegramChannel {
    /// This channel reports that it does not support exiting (always `false`).
    fn supports_exit(&self) -> bool {
        false
    }

    /// Non-blocking receive of an already queued message.
    ///
    /// Updates `chat_id` to the sender's chat. Unlike `recv`, it neither resets the
    /// streaming state nor interprets slash commands.
    fn try_recv(&mut self) -> Option<ChannelMessage> {
        self.rx.try_recv().ok().map(|incoming| {
            self.chat_id = Some(incoming.chat_id);
            ChannelMessage {
                text: incoming.text,
                attachments: incoming.attachments,
            }
        })
    }

    /// Waits for the next message and starts a new turn.
    ///
    /// Resets the streaming state and remembers the sender's chat. `/start` is
    /// answered with a welcome message and not returned to the caller. `/reset` and
    /// `/skills` are returned as the bare command, without arguments or attachments.
    /// Everything else, including unknown commands, is passed through unchanged.
    /// Returns `Ok(None)` once every sender feeding this channel has been dropped.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "channel.telegram.recv", skip_all, fields(msg_len = tracing::field::Empty))
    )]
    async fn recv(&mut self) -> Result<Option<ChannelMessage>, ChannelError> {
        loop {
            let Some(incoming) = self.rx.recv().await else {
                return Ok(None);
            };
            self.chat_id = Some(incoming.chat_id);
            self.accumulated.clear();
            self.last_edit = None;
            self.message_id = None;
            if let Some(cmd) = Self::is_command(&incoming.text) {
                match cmd {
                    "/start" => {
                        self.send("Welcome to Zeph! Send me a message to get started.")
                            .await?;
                        continue;
                    }
                    "/reset" => {
                        return Ok(Some(ChannelMessage {
                            text: "/reset".to_string(),
                            attachments: vec![],
                        }));
                    }
                    "/skills" => {
                        return Ok(Some(ChannelMessage {
                            text: "/skills".to_string(),
                            attachments: vec![],
                        }));
                    }
                    _ => {}
                }
            }
            return Ok(Some(ChannelMessage {
                text: incoming.text,
                attachments: incoming.attachments,
            }));
        }
    }

    /// Sends `text` to the active chat as MarkdownV2.
    ///
    /// Text longer than [`MAX_MESSAGE_LEN`] after formatting is split over several
    /// messages. Empty formatted output is skipped.
    ///
    /// # Errors
    ///
    /// [`ChannelError::NoActiveSession`] without an active chat, or a Telegram error.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "channel.telegram.send", skip_all, fields(msg_len = %text.len()))
    )]
    async fn send(&mut self, text: &str) -> Result<(), ChannelError> {
        let Some(chat_id) = self.chat_id else {
            return Err(ChannelError::NoActiveSession);
        };
        let formatted_text = markdown_to_telegram(text);
        if formatted_text.is_empty() {
            tracing::debug!("skipping send: formatted text is empty");
            return Ok(());
        }
        if formatted_text.len() <= MAX_MESSAGE_LEN {
            self.bot
                .send_message(chat_id, &formatted_text)
                .parse_mode(ParseMode::MarkdownV2)
                .await
                .map_err(ChannelError::telegram)?;
        } else {
            let chunks = crate::markdown::utf8_chunks(&formatted_text, MAX_MESSAGE_LEN);
            for chunk in chunks {
                self.bot
                    .send_message(chat_id, chunk)
                    .parse_mode(ParseMode::MarkdownV2)
                    .await
                    .map_err(ChannelError::telegram)?;
            }
        }
        Ok(())
    }

    /// Appends a streamed chunk and refreshes the live preview when the throttle
    /// (`should_send_update`) allows it.
    async fn send_chunk(&mut self, chunk: &str) -> Result<(), ChannelError> {
        self.accumulated.push_str(chunk);
        tracing::debug!(
            "received chunk (size: {}, total: {})",
            chunk.len(),
            self.accumulated.len()
        );
        if self.should_send_update() {
            tracing::debug!("sending update (should_send_update returned true)");
            self.send_or_edit().await?;
        }
        Ok(())
    }

    /// Ends the streamed reply.
    ///
    /// When a preview message exists it is updated with the complete text. Then all
    /// streaming state is cleared; with no preview message, any buffered text is
    /// discarded without being sent.
    async fn flush_chunks(&mut self) -> Result<(), ChannelError> {
        tracing::debug!(
            "flushing chunks (message_id: {:?}, accumulated: {} bytes)",
            self.message_id,
            self.accumulated.len()
        );
        if self.message_id.is_some() {
            self.send_or_edit().await?;
        }
        self.accumulated.clear();
        self.last_edit = None;
        self.message_id = None;
        Ok(())
    }

    /// Shows the "typing" chat action in the active chat; a no-op without one.
    async fn send_typing(&mut self) -> Result<(), ChannelError> {
        let Some(chat_id) = self.chat_id else {
            return Ok(());
        };
        self.bot
            .send_chat_action(chat_id, ChatAction::Typing)
            .await
            .map_err(ChannelError::telegram)?;
        Ok(())
    }

    /// Asks the user to confirm `prompt`.
    ///
    /// Waits up to `crate::CONFIRM_TIMEOUT` for the next queued message. Returns
    /// `true` only if that reply is "yes" (trimmed, ASCII case-insensitive). A
    /// closed channel or a timeout denies.
    async fn confirm(&mut self, prompt: &str) -> Result<bool, ChannelError> {
        let timeout = crate::CONFIRM_TIMEOUT;
        self.send(&format!(
            "{prompt}\nReply 'yes' to confirm (timeout: {}s).",
            timeout.as_secs()
        ))
        .await?;
        match tokio::time::timeout(timeout, self.rx.recv()).await {
            Ok(Some(incoming)) => Ok(incoming.text.trim().eq_ignore_ascii_case("yes")),
            Ok(None) => {
                tracing::warn!("confirm channel closed — denying secret request");
                Ok(false)
            }
            Err(_) => {
                // Report the configured timeout rather than a hard-coded value so the
                // log stays accurate if `CONFIRM_TIMEOUT` changes.
                tracing::warn!("confirm timed out after {}s — denied", timeout.as_secs());
                Ok(false)
            }
        }
    }

    /// Collects values for an MCP elicitation request, one field per reply.
    ///
    /// Each reply waits up to `crate::ELICITATION_TIMEOUT`.
    /// - `/cancel` or a timeout yields `Cancelled`.
    /// - A closed channel or a value that cannot be coerced to the field type
    ///   yields `Declined`.
    /// - Otherwise the result is `Accepted` with an object keyed by the raw
    ///   (unsanitized) field names.
    async fn elicit(
        &mut self,
        request: ElicitationRequest,
    ) -> Result<ElicitationResponse, ChannelError> {
        let timeout = crate::ELICITATION_TIMEOUT;
        self.send(&format!(
            "*[MCP server '{}' is requesting input]*\n{}\n\n_Reply /cancel to cancel. \
             Timeout: {}s._",
            sanitize_markdown(&request.server_name),
            sanitize_markdown(&request.message),
            timeout.as_secs(),
        ))
        .await?;
        let mut values = serde_json::Map::new();
        for field in &request.fields {
            let prompt = build_telegram_field_prompt(field);
            self.send(&prompt).await?;
            let incoming = match tokio::time::timeout(timeout, self.rx.recv()).await {
                Ok(Some(msg)) => msg,
                Ok(None) => {
                    tracing::warn!(server = request.server_name, "elicitation channel closed");
                    return Ok(ElicitationResponse::Declined);
                }
                Err(_) => {
                    tracing::warn!(server = request.server_name, "elicitation timed out");
                    let _ = self
                        .send("Elicitation timed out — request cancelled.")
                        .await;
                    return Ok(ElicitationResponse::Cancelled);
                }
            };
            let text = incoming.text.trim().to_owned();
            if text.eq_ignore_ascii_case("/cancel") {
                let _ = self.send("Elicitation cancelled.").await;
                return Ok(ElicitationResponse::Cancelled);
            }
            let Some(value) = coerce_telegram_field(&text, &field.field_type) else {
                let _ = self
                    .send(&format!(
                        "Invalid value for '{}'. Declining.",
                        sanitize_markdown(&field.name)
                    ))
                    .await;
                return Ok(ElicitationResponse::Declined);
            };
            // Key by the raw field name: the server expects its own names back.
            values.insert(field.name.clone(), value);
        }
        Ok(ElicitationResponse::Accepted(serde_json::Value::Object(
            values,
        )))
    }
}
/// Strips Markdown emphasis/link/code markers and ESC from untrusted text before
/// it is embedded in a formatted prompt.
///
/// Removed characters: `*`, `_`, `[`, `]`, `` ` `` and `\x1b`; everything else is
/// kept in order.
fn sanitize_markdown(s: &str) -> String {
    const STRIPPED: [char; 6] = ['*', '_', '[', ']', '`', '\x1b'];
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if !STRIPPED.contains(&ch) {
            cleaned.push(ch);
        }
    }
    cleaned
}
fn build_telegram_field_prompt(field: &ElicitationField) -> String {
let req = if field.required { " (required)" } else { "" };
let name = sanitize_markdown(&field.name);
match &field.field_type {
ElicitationFieldType::Boolean => {
format!("*{name}*{req}: Reply *yes* or *no*")
}
ElicitationFieldType::Enum(opts) => {
let list: String = opts
.iter()
.enumerate()
.map(|(i, o)| format!("{}: {}", i + 1, sanitize_markdown(o)))
.collect::<Vec<_>>()
.join("\n");
format!("*{name}*{req}: Reply with the number:\n{list}")
}
ElicitationFieldType::Integer => {
format!("*{name}*{req}: Reply with an integer")
}
ElicitationFieldType::Number => {
format!("*{name}*{req}: Reply with a number")
}
ElicitationFieldType::String => {
format!("*{name}*{req}: Reply with text")
}
}
}
fn coerce_telegram_field(text: &str, kind: &ElicitationFieldType) -> Option<serde_json::Value> {
match kind {
ElicitationFieldType::String => Some(serde_json::Value::String(text.to_owned())),
ElicitationFieldType::Boolean => {
if text.eq_ignore_ascii_case("yes") || text == "1" {
Some(serde_json::Value::Bool(true))
} else if text.eq_ignore_ascii_case("no") || text == "0" {
Some(serde_json::Value::Bool(false))
} else {
None
}
}
ElicitationFieldType::Integer => text
.parse::<i64>()
.ok()
.map(|n| serde_json::Value::Number(n.into())),
ElicitationFieldType::Number => text
.parse::<f64>()
.ok()
.and_then(|n| serde_json::Number::from_f64(n).map(serde_json::Value::Number)),
ElicitationFieldType::Enum(opts) => {
if let Ok(idx) = text.parse::<usize>()
&& idx >= 1
&& idx <= opts.len()
{
return Some(serde_json::Value::String(opts[idx - 1].clone()));
}
opts.iter()
.find(|o| o.eq_ignore_ascii_case(text))
.map(|o| serde_json::Value::String(o.clone()))
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the Telegram channel. Network-facing paths run against a
    // wiremock server that answers every request with a successful `Message`.
    use std::time::Instant;
    use wiremock::matchers::any;
    use wiremock::{Mock, MockServer, ResponseTemplate};
    use super::*;

    /// Minimal successful Bot API response that deserializes as a `Message`.
    fn tg_ok_message() -> serde_json::Value {
        serde_json::json!({
            "ok": true,
            "result": {
                "message_id": 42,
                "date": 1_700_000_000_i64,
                "chat": {"id": 1, "type": "private"}
            }
        })
    }

    /// Builds a channel whose bot talks to `server` and whose active chat is `ChatId(1)`.
    async fn make_mocked_channel(
        server: &MockServer,
        allowed_users: Vec<String>,
    ) -> (TelegramChannel, mpsc::Sender<IncomingMessage>) {
        Mock::given(any())
            .respond_with(ResponseTemplate::new(200).set_body_json(tg_ok_message()))
            .mount(server)
            .await;
        let api_url = reqwest::Url::parse(&server.uri()).unwrap();
        let bot = Bot::new("test_token").set_api_url(api_url);
        let (tx, rx) = mpsc::channel(64);
        let channel = TelegramChannel {
            bot,
            chat_id: Some(ChatId(1)),
            rx,
            allowed_users,
            accumulated: String::new(),
            last_edit: None,
            message_id: None,
            supervisor: None,
        };
        (channel, tx)
    }

    fn plain_message(text: &str) -> IncomingMessage {
        IncomingMessage {
            chat_id: ChatId(1),
            text: text.to_string(),
            attachments: vec![],
        }
    }

    #[test]
    fn is_user_authorized_empty_allowed_permits_all() {
        assert!(is_user_authorized(None, &[]));
        assert!(is_user_authorized(Some("anyone"), &[]));
    }

    #[test]
    fn is_user_authorized_known_user_is_permitted() {
        let allowed = vec!["alice".to_string(), "bob".to_string()];
        assert!(is_user_authorized(Some("alice"), &allowed));
        assert!(is_user_authorized(Some("bob"), &allowed));
    }

    #[test]
    fn is_user_authorized_unknown_user_is_rejected() {
        let allowed = vec!["alice".to_string()];
        assert!(!is_user_authorized(Some("eve"), &allowed));
        assert!(!is_user_authorized(None, &allowed));
    }

    #[test]
    fn is_command_detection() {
        assert_eq!(TelegramChannel::is_command("/start"), Some("/start"));
        assert_eq!(TelegramChannel::is_command("/reset now"), Some("/reset"));
        assert_eq!(TelegramChannel::is_command("hello"), None);
        assert_eq!(TelegramChannel::is_command(""), None);
    }

    #[test]
    fn should_send_update_first_chunk() {
        let channel = TelegramChannel::new("test_token".to_string(), Vec::new());
        assert!(channel.should_send_update());
    }

    #[test]
    fn should_send_update_time_threshold() {
        let mut channel = TelegramChannel::new("test_token".to_string(), Vec::new());
        channel.accumulated = "test".to_string();
        channel.last_edit = Some(Instant::now().checked_sub(Duration::from_secs(4)).unwrap());
        assert!(channel.should_send_update());
    }

    #[test]
    fn should_not_send_update_within_threshold() {
        let mut channel = TelegramChannel::new("test_token".to_string(), Vec::new());
        channel.last_edit = Some(
            Instant::now()
                .checked_sub(Duration::from_millis(500))
                .unwrap(),
        );
        assert!(!channel.should_send_update());
    }

    #[test]
    fn max_image_bytes_is_20_mib() {
        assert_eq!(MAX_IMAGE_BYTES, 20 * 1024 * 1024);
    }

    #[test]
    fn sanitize_markdown_strips_formatting_characters() {
        assert_eq!(
            sanitize_markdown("*bold* _it_ [link] `code`\x1b"),
            "bold it link code"
        );
        assert_eq!(sanitize_markdown("plain text"), "plain text");
        assert_eq!(sanitize_markdown(""), "");
    }

    #[test]
    fn coerce_boolean_accepts_words_and_digits() {
        assert_eq!(
            coerce_telegram_field("YES", &ElicitationFieldType::Boolean),
            Some(serde_json::Value::Bool(true))
        );
        assert_eq!(
            coerce_telegram_field("1", &ElicitationFieldType::Boolean),
            Some(serde_json::Value::Bool(true))
        );
        assert_eq!(
            coerce_telegram_field("No", &ElicitationFieldType::Boolean),
            Some(serde_json::Value::Bool(false))
        );
        assert_eq!(
            coerce_telegram_field("0", &ElicitationFieldType::Boolean),
            Some(serde_json::Value::Bool(false))
        );
        assert_eq!(coerce_telegram_field("maybe", &ElicitationFieldType::Boolean), None);
    }

    #[test]
    fn coerce_numeric_fields() {
        assert_eq!(
            coerce_telegram_field("42", &ElicitationFieldType::Integer),
            Some(serde_json::json!(42))
        );
        assert_eq!(coerce_telegram_field("4.2", &ElicitationFieldType::Integer), None);
        assert_eq!(
            coerce_telegram_field("2.5", &ElicitationFieldType::Number),
            Some(serde_json::json!(2.5))
        );
        assert_eq!(coerce_telegram_field("NaN", &ElicitationFieldType::Number), None);
        assert_eq!(coerce_telegram_field("abc", &ElicitationFieldType::Number), None);
    }

    #[test]
    fn build_field_prompt_marks_required_and_sanitizes_name() {
        let field = ElicitationField {
            name: "api_key".to_owned(),
            description: None,
            field_type: ElicitationFieldType::Integer,
            required: true,
        };
        assert_eq!(
            build_telegram_field_prompt(&field),
            "*apikey* (required): Reply with an integer"
        );
    }

    #[test]
    fn start_rejects_empty_allowed_users() {
        let result = TelegramChannel::new("test_token".to_string(), Vec::new()).start();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ChannelError::Other(_)));
    }

    #[tokio::test]
    async fn recv_returns_channel_message_when_injected() {
        let (mut channel, tx) = TelegramChannel::new_test(vec![]);
        tx.send(plain_message("hello world")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "hello world");
        assert!(msg.attachments.is_empty());
    }

    #[tokio::test]
    async fn recv_reset_command_routed_correctly() {
        let (mut channel, tx) = TelegramChannel::new_test(vec![]);
        tx.send(plain_message("/reset")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "/reset");
    }

    #[tokio::test]
    async fn recv_skills_command_routed_correctly() {
        let (mut channel, tx) = TelegramChannel::new_test(vec![]);
        tx.send(plain_message("/skills")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "/skills");
    }

    #[tokio::test]
    async fn recv_unknown_command_passed_through() {
        let (mut channel, tx) = TelegramChannel::new_test(vec![]);
        tx.send(plain_message("/unknown_cmd arg")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "/unknown_cmd arg");
    }

    #[tokio::test]
    async fn recv_returns_none_when_sender_dropped() {
        let (mut channel, tx) = TelegramChannel::new_test(vec![]);
        drop(tx);
        let result = channel.recv().await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn send_chunk_accumulates_text_without_api_call() {
        let (mut channel, _tx) = TelegramChannel::new_test(vec![]);
        channel.last_edit = Some(Instant::now());
        channel.send_chunk("hello").await.unwrap();
        channel.send_chunk(" world").await.unwrap();
        assert_eq!(channel.accumulated, "hello world");
    }

    #[tokio::test]
    async fn flush_chunks_clears_state_when_no_message_id() {
        let (mut channel, _tx) = TelegramChannel::new_test(vec![]);
        channel.accumulated = "some text".to_string();
        channel.last_edit = Some(Instant::now());
        channel.flush_chunks().await.unwrap();
        assert!(channel.accumulated.is_empty());
        assert!(channel.last_edit.is_none());
        assert!(channel.message_id.is_none());
    }

    #[tokio::test]
    async fn recv_start_consumed_internally_without_returning_to_caller() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        tx.send(plain_message("/start")).await.unwrap();
        tx.send(plain_message("hello after start")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "hello after start");
    }

    #[tokio::test]
    async fn flush_chunks_calls_edit_and_clears_state_when_message_id_set() {
        let server = MockServer::start().await;
        let (mut channel, _tx) = make_mocked_channel(&server, vec![]).await;
        channel.accumulated = "partial response".to_string();
        channel.last_edit = Some(Instant::now());
        channel.message_id = Some(teloxide::types::MessageId(42));
        channel.flush_chunks().await.unwrap();
        assert!(channel.accumulated.is_empty());
        assert!(channel.last_edit.is_none());
        assert!(channel.message_id.is_none());
    }

    #[tokio::test]
    async fn confirm_timeout_logic_denies_on_timeout() {
        tokio::time::pause();
        let (_tx, mut rx) = mpsc::channel::<IncomingMessage>(1);
        let timeout_fut = tokio::time::timeout(crate::CONFIRM_TIMEOUT, rx.recv());
        tokio::time::advance(crate::CONFIRM_TIMEOUT + Duration::from_millis(1)).await;
        let result = timeout_fut.await;
        assert!(result.is_err(), "expected timeout Err, got recv result");
    }

    #[tokio::test]
    async fn confirm_close_logic_denies_on_channel_close() {
        let (tx, mut rx) = mpsc::channel::<IncomingMessage>(1);
        drop(tx);
        let result = tokio::time::timeout(crate::CONFIRM_TIMEOUT, rx.recv()).await;
        assert!(result.is_ok(), "should not time out");
        assert!(
            result.unwrap().is_none(),
            "closed channel should yield None"
        );
    }

    #[tokio::test]
    async fn confirm_yes_logic_accepts_yes_response() {
        let (tx, mut rx) = mpsc::channel::<IncomingMessage>(1);
        tx.send(plain_message("yes")).await.unwrap();
        let result = tokio::time::timeout(crate::CONFIRM_TIMEOUT, rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert!(result.text.trim().eq_ignore_ascii_case("yes"));
    }

    #[tokio::test]
    async fn confirm_no_logic_denies_non_yes_response() {
        let (tx, mut rx) = mpsc::channel::<IncomingMessage>(1);
        tx.send(plain_message("no")).await.unwrap();
        let result = tokio::time::timeout(crate::CONFIRM_TIMEOUT, rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert!(!result.text.trim().eq_ignore_ascii_case("yes"));
    }

    #[tokio::test]
    async fn confirm_accepts_yes_reply_case_insensitively() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        tx.send(plain_message("  YES ")).await.unwrap();
        assert!(channel.confirm("Proceed?").await.unwrap());
    }

    #[tokio::test]
    async fn confirm_denies_any_other_reply() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        tx.send(plain_message("nope")).await.unwrap();
        assert!(!channel.confirm("Proceed?").await.unwrap());
    }

    #[tokio::test]
    async fn confirm_denies_when_sender_dropped() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        drop(tx);
        assert!(!channel.confirm("Proceed?").await.unwrap());
    }

    #[tokio::test]
    async fn send_or_edit_splits_long_message_into_multiple_sends() {
        let server = MockServer::start().await;
        let (mut channel, _tx) = make_mocked_channel(&server, vec![]).await;
        let long_text = "a".repeat(MAX_MESSAGE_LEN + 1);
        channel.accumulated = long_text;
        channel.send_or_edit().await.unwrap();
        let requests = server.received_requests().await.unwrap();
        assert!(
            requests.len() >= 2,
            "expected ≥2 API calls for oversized message, got {}",
            requests.len()
        );
    }

    #[tokio::test]
    async fn send_or_edit_single_message_when_within_limit() {
        let server = MockServer::start().await;
        let (mut channel, _tx) = make_mocked_channel(&server, vec![]).await;
        channel.accumulated = "short text".to_string();
        channel.send_or_edit().await.unwrap();
        let requests = server.received_requests().await.unwrap();
        assert_eq!(
            requests.len(),
            1,
            "expected exactly 1 API call for short message"
        );
        assert!(channel.message_id.is_some());
    }

    #[tokio::test]
    async fn send_or_edit_splits_when_edit_overflows() {
        let server = MockServer::start().await;
        let (mut channel, _tx) = make_mocked_channel(&server, vec![]).await;
        channel.message_id = Some(teloxide::types::MessageId(42));
        let long_text = "b".repeat(MAX_MESSAGE_LEN + 1);
        channel.accumulated = long_text;
        channel.send_or_edit().await.unwrap();
        let requests = server.received_requests().await.unwrap();
        assert!(
            requests.len() >= 2,
            "expected edit + at least 1 overflow send, got {}",
            requests.len()
        );
    }

    #[tokio::test]
    async fn elicit_happy_path_string_field_returns_accepted() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        let request = ElicitationRequest {
            server_name: "test-server".to_owned(),
            message: "Please provide your name".to_owned(),
            fields: vec![ElicitationField {
                name: "username".to_owned(),
                description: None,
                field_type: ElicitationFieldType::String,
                required: true,
            }],
        };
        tx.send(plain_message("alice")).await.unwrap();
        let response = channel.elicit(request).await.unwrap();
        match response {
            ElicitationResponse::Accepted(val) => {
                assert_eq!(val["username"], "alice");
            }
            other => panic!("expected Accepted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn elicit_field_key_uses_raw_name_not_sanitized() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        let request = ElicitationRequest {
            server_name: "test-server".to_owned(),
            message: "Provide credentials".to_owned(),
            fields: vec![ElicitationField {
                name: "pass phrase".to_owned(),
                description: None,
                field_type: ElicitationFieldType::String,
                required: true,
            }],
        };
        tx.send(plain_message("hunter2")).await.unwrap();
        let response = channel.elicit(request).await.unwrap();
        match response {
            ElicitationResponse::Accepted(val) => {
                assert_eq!(
                    val["pass phrase"], "hunter2",
                    "raw field name must be the map key"
                );
                assert!(
                    val.get("passphrase").is_none(),
                    "sanitized key must not appear in response"
                );
            }
            other => panic!("expected Accepted, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn elicit_cancel_command_returns_cancelled() {
        let server = MockServer::start().await;
        let (mut channel, tx) = make_mocked_channel(&server, vec![]).await;
        let request = ElicitationRequest {
            server_name: "test-server".to_owned(),
            message: "Provide a value".to_owned(),
            fields: vec![ElicitationField {
                name: "token".to_owned(),
                description: None,
                field_type: ElicitationFieldType::String,
                required: true,
            }],
        };
        tx.send(plain_message("/cancel")).await.unwrap();
        let response = channel.elicit(request).await.unwrap();
        assert!(
            matches!(response, ElicitationResponse::Cancelled),
            "expected Cancelled, got {response:?}"
        );
    }

    #[tokio::test]
    async fn elicit_timeout_logic_cancels_on_timeout() {
        tokio::time::pause();
        let (_tx, mut rx) = mpsc::channel::<IncomingMessage>(1);
        let timeout_fut = tokio::time::timeout(crate::ELICITATION_TIMEOUT, rx.recv());
        tokio::time::advance(crate::ELICITATION_TIMEOUT + Duration::from_millis(1)).await;
        let result = timeout_fut.await;
        assert!(
            result.is_err(),
            "expected Err(Elapsed) for elicitation timeout, got recv result"
        );
    }

    #[tokio::test]
    async fn with_supervisor_registers_listener_task() {
        use tokio_util::sync::CancellationToken;
        let cancel = CancellationToken::new();
        let sup = zeph_common::TaskSupervisor::new(cancel.clone());
        let (channel, _tx) = TelegramChannel::new_test(vec!["user".to_string()]);
        let channel = channel.with_supervisor(sup.clone());
        channel
            .start()
            .expect("start() must succeed with non-empty allowed_users");
        tokio::task::yield_now().await;
        let snapshot = sup.snapshot();
        let names: Vec<&str> = snapshot.iter().map(|s| s.name.as_ref()).collect();
        assert!(
            names.contains(&"telegram_listener"),
            "expected 'telegram_listener' in supervisor snapshot, got: {names:?}"
        );
        cancel.cancel();
    }
}