1use super::traits::{Tool, ToolResult};
9use crate::channels::traits::Channel;
10use crate::security::SecurityPolicy;
11use crate::security::policy::ToolOperation;
12use async_trait::async_trait;
13use parking_lot::RwLock;
14use serde_json::json;
15use std::collections::HashMap;
16use std::sync::Arc;
17
/// Shared, lock-guarded map from channel name to the channel registered under it.
///
/// Cloning the handle clones the `Arc`, so every clone observes the same map.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
20
/// Tool that adds or removes an emoji reaction on a message through a registered channel.
pub struct ReactionTool {
    /// Channels the tool can dispatch to, keyed by channel name.
    channels: ChannelMapHandle,
    /// Policy checked at the start of `execute` before any reaction is attempted.
    security: Arc<SecurityPolicy>,
}
26
27impl ReactionTool {
28 pub fn new(security: Arc<SecurityPolicy>) -> Self {
32 Self {
33 channels: Arc::new(RwLock::new(HashMap::new())),
34 security,
35 }
36 }
37
38 pub fn channel_map_handle(&self) -> ChannelMapHandle {
40 Arc::clone(&self.channels)
41 }
42
43 pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
45 *self.channels.write() = map;
46 }
47}
48
49#[async_trait]
50impl Tool for ReactionTool {
51 fn name(&self) -> &str {
52 "reaction"
53 }
54
55 fn description(&self) -> &str {
56 "Add or remove an emoji reaction on a message in any active channel. \
57 Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
58 the platform message ID, and the emoji (Unicode character or platform shortcode)."
59 }
60
61 fn parameters_schema(&self) -> serde_json::Value {
62 json!({
63 "type": "object",
64 "properties": {
65 "channel": {
66 "type": "string",
67 "description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
68 },
69 "channel_id": {
70 "type": "string",
71 "description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
72 },
73 "message_id": {
74 "type": "string",
75 "description": "Platform-scoped message identifier to react to"
76 },
77 "emoji": {
78 "type": "string",
79 "description": "Emoji to react with (Unicode character or platform shortcode)"
80 },
81 "action": {
82 "type": "string",
83 "enum": ["add", "remove"],
84 "description": "Whether to add or remove the reaction (default: 'add')"
85 }
86 },
87 "required": ["channel", "channel_id", "message_id", "emoji"]
88 })
89 }
90
91 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
92 if let Err(error) = self
94 .security
95 .enforce_tool_operation(ToolOperation::Act, "reaction")
96 {
97 return Ok(ToolResult {
98 success: false,
99 output: String::new(),
100 error: Some(error),
101 });
102 }
103
104 let channel_name = args
105 .get("channel")
106 .and_then(|v| v.as_str())
107 .ok_or_else(|| anyhow::anyhow!("Missing 'channel' parameter"))?;
108
109 let channel_id = args
110 .get("channel_id")
111 .and_then(|v| v.as_str())
112 .ok_or_else(|| anyhow::anyhow!("Missing 'channel_id' parameter"))?;
113
114 let message_id = args
115 .get("message_id")
116 .and_then(|v| v.as_str())
117 .ok_or_else(|| anyhow::anyhow!("Missing 'message_id' parameter"))?;
118
119 let emoji = args
120 .get("emoji")
121 .and_then(|v| v.as_str())
122 .ok_or_else(|| anyhow::anyhow!("Missing 'emoji' parameter"))?;
123
124 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
125
126 if action != "add" && action != "remove" {
127 return Ok(ToolResult {
128 success: false,
129 output: String::new(),
130 error: Some(format!(
131 "Invalid action '{action}': must be 'add' or 'remove'"
132 )),
133 });
134 }
135
136 let channel = {
138 let map = self.channels.read();
139 if map.is_empty() {
140 return Ok(ToolResult {
141 success: false,
142 output: String::new(),
143 error: Some("No channels available yet (channels not initialized)".to_string()),
144 });
145 }
146 match map.get(channel_name) {
147 Some(ch) => Arc::clone(ch),
148 None => {
149 let available: Vec<String> = map.keys().cloned().collect();
150 return Ok(ToolResult {
151 success: false,
152 output: String::new(),
153 error: Some(format!(
154 "Channel '{channel_name}' not found. Available channels: {}",
155 available.join(", ")
156 )),
157 });
158 }
159 }
160 };
161
162 let result = if action == "add" {
163 channel.add_reaction(channel_id, message_id, emoji).await
164 } else {
165 channel.remove_reaction(channel_id, message_id, emoji).await
166 };
167
168 let past_tense = if action == "remove" {
169 "removed"
170 } else {
171 "added"
172 };
173
174 match result {
175 Ok(()) => Ok(ToolResult {
176 success: true,
177 output: format!(
178 "Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
179 ),
180 error: None,
181 }),
182 Err(e) => Ok(ToolResult {
183 success: false,
184 output: String::new(),
185 error: Some(format!("Failed to {action} reaction: {e}")),
186 }),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::traits::{ChannelMessage, SendMessage};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Arguments captured from the most recent successful reaction call on a mock.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ReactionCall {
        channel_id: String,
        message_id: String,
        emoji: String,
    }

    impl ReactionCall {
        fn new(channel_id: &str, message_id: &str, emoji: &str) -> Self {
            Self {
                channel_id: channel_id.to_string(),
                message_id: message_id.to_string(),
                emoji: emoji.to_string(),
            }
        }
    }

    /// Channel double that records reaction calls and can be told to fail on add.
    struct MockChannel {
        reaction_added: AtomicBool,
        reaction_removed: AtomicBool,
        last_call: parking_lot::Mutex<Option<ReactionCall>>,
        fail_on_add: bool,
    }

    impl MockChannel {
        fn new() -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_call: parking_lot::Mutex::new(None),
                fail_on_add: false,
            }
        }

        /// Same as [`MockChannel::new`], except that `add_reaction` always errors.
        fn failing() -> Self {
            Self {
                fail_on_add: true,
                ..Self::new()
            }
        }

        fn was_added(&self) -> bool {
            self.reaction_added.load(Ordering::SeqCst)
        }

        fn was_removed(&self) -> bool {
            self.reaction_removed.load(Ordering::SeqCst)
        }

        fn recorded_call(&self) -> Option<ReactionCall> {
            self.last_call.lock().clone()
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            "mock"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            channel_id: &str,
            message_id: &str,
            emoji: &str,
        ) -> anyhow::Result<()> {
            if self.fail_on_add {
                return Err(anyhow::anyhow!("API error: rate limited"));
            }
            *self.last_call.lock() = Some(ReactionCall::new(channel_id, message_id, emoji));
            self.reaction_added.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn remove_reaction(
            &self,
            channel_id: &str,
            message_id: &str,
            emoji: &str,
        ) -> anyhow::Result<()> {
            *self.last_call.lock() = Some(ReactionCall::new(channel_id, message_id, emoji));
            self.reaction_removed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Wraps `mock` so a test keeps a concrete handle for assertions while the tool
    /// receives the same instance as a trait object.
    fn share(mock: MockChannel) -> (Arc<MockChannel>, Arc<dyn Channel>) {
        let concrete = Arc::new(mock);
        let erased = Arc::clone(&concrete) as Arc<dyn Channel>;
        (concrete, erased)
    }

    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let map: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|(name, ch)| (name.to_string(), ch))
            .collect();
        tool.populate(map);
        tool
    }

    #[test]
    fn tool_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        assert_eq!(tool.name(), "reaction");
        assert!(!tool.description().is_empty());

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for property in ["channel", "channel_id", "message_id", "emoji", "action"] {
            assert!(
                schema["properties"][property].is_object(),
                "schema must describe '{property}'"
            );
        }

        let required = schema["required"].as_array().unwrap();
        for name in ["channel", "channel_id", "message_id", "emoji"] {
            assert!(
                required.iter().any(|v| v == name),
                "'{name}' must be required"
            );
        }
        assert!(!required.iter().any(|v| v == "action"));
    }

    #[tokio::test]
    async fn add_reaction_success() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", channel)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_123",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("added"));
        assert!(result.error.is_none());
        // The tool must actually have called the channel, with arguments in order.
        assert!(mock.was_added());
        assert!(!mock.was_removed());
        assert_eq!(
            mock.recorded_call(),
            Some(ReactionCall::new("ch_001", "msg_123", "\u{2705}"))
        );
    }

    #[tokio::test]
    async fn remove_reaction_success() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("slack", channel)]);

        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123SLACK",
                "message_id": "msg_456",
                "emoji": "\u{1F440}",
                "action": "remove"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("removed"));
        assert!(result.error.is_none());
        assert!(mock.was_removed());
        assert!(!mock.was_added());
        assert_eq!(
            mock.recorded_call(),
            Some(ReactionCall::new("C0123SLACK", "msg_456", "\u{1F440}"))
        );
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", channel)]);

        let result = tool
            .execute(json!({
                "channel": "nonexistent",
                "channel_id": "ch_x",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("not found"));
        assert!(err.contains("discord"));
        assert!(mock.recorded_call().is_none());
    }

    #[tokio::test]
    async fn invalid_action_returns_error() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", channel)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}",
                "action": "toggle"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("toggle"));
        // An invalid action must be rejected before the channel is touched.
        assert!(!mock.was_added());
        assert!(!mock.was_removed());
    }

    #[tokio::test]
    async fn channel_error_propagated() {
        let (mock, channel) = share(MockChannel::failing());
        let tool = make_tool_with_channels(vec![("discord", channel)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.output.is_empty());
        assert!(result.error.as_deref().unwrap().contains("rate limited"));
        assert!(!mock.was_added());
    }

    #[tokio::test]
    async fn missing_required_params() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        // Missing `channel`.
        let result = tool
            .execute(json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing `channel_id`.
        let result = tool
            .execute(json!({"channel": "test", "message_id": "1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing `message_id`.
        let result = tool
            .execute(json!({"channel": "a", "channel_id": "c1", "emoji": "x"}))
            .await;
        assert!(result.is_err());

        // Missing `emoji`.
        let result = tool
            .execute(json!({"channel": "a", "channel_id": "c1", "message_id": "1"}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn default_action_is_add() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("test", channel)]);

        let result = tool
            .execute(json!({
                "channel": "test",
                "channel_id": "ch_test",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(mock.was_added());
        assert!(!mock.was_removed());
    }

    #[tokio::test]
    async fn channel_id_passed_to_trait_not_channel_name() {
        let (mock, channel) = share(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", channel)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "123456789",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        let call = mock
            .recorded_call()
            .expect("add_reaction should have been invoked");
        assert_eq!(
            call.channel_id, "123456789",
            "add_reaction must receive channel_id, not channel name"
        );
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let handle = tool.channel_map_handle();

        // Nothing registered yet: the call must fail.
        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();
        assert!(!result.success);

        // Register a channel through the handle; the guard is dropped with the block.
        {
            let mut map = handle.write();
            map.insert(
                "slack".to_string(),
                Arc::new(MockChannel::new()) as Arc<dyn Channel>,
            );
        }

        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();
        assert!(result.success);
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let spec = tool.spec();
        assert_eq!(spec.name, "reaction");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }
}