Skip to main content

construct/tools/
reaction.rs

//! Emoji reaction tool for cross-channel message reactions.
//!
//! Exposes `add_reaction` and `remove_reaction` from the [`Channel`] trait as an
//! agent-callable tool. The tool holds a late-binding channel map handle that is
//! populated once channels are initialized (after tool construction). This mirrors
//! the pattern used by `DelegateTool` for its parent-tools handle.

8use super::traits::{Tool, ToolResult};
9use crate::channels::traits::Channel;
10use crate::security::SecurityPolicy;
11use crate::security::policy::ToolOperation;
12use async_trait::async_trait;
13use parking_lot::RwLock;
14use serde_json::json;
15use std::collections::HashMap;
16use std::sync::Arc;
17
18/// Shared handle to the channel map. Starts empty; populated once channels boot.
19pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
20
21/// Agent-callable tool for adding or removing emoji reactions on messages.
22pub struct ReactionTool {
23    channels: ChannelMapHandle,
24    security: Arc<SecurityPolicy>,
25}
26
27impl ReactionTool {
28    /// Create a new reaction tool with an empty channel map.
29    /// Call [`populate`] or write to the returned [`ChannelMapHandle`] once channels
30    /// are available.
31    pub fn new(security: Arc<SecurityPolicy>) -> Self {
32        Self {
33            channels: Arc::new(RwLock::new(HashMap::new())),
34            security,
35        }
36    }
37
38    /// Return the shared handle so callers can populate it after channel init.
39    pub fn channel_map_handle(&self) -> ChannelMapHandle {
40        Arc::clone(&self.channels)
41    }
42
43    /// Convenience: populate the channel map from a pre-built map.
44    pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
45        *self.channels.write() = map;
46    }
47}
48
49#[async_trait]
50impl Tool for ReactionTool {
51    fn name(&self) -> &str {
52        "reaction"
53    }
54
55    fn description(&self) -> &str {
56        "Add or remove an emoji reaction on a message in any active channel. \
57         Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
58         the platform message ID, and the emoji (Unicode character or platform shortcode)."
59    }
60
61    fn parameters_schema(&self) -> serde_json::Value {
62        json!({
63            "type": "object",
64            "properties": {
65                "channel": {
66                    "type": "string",
67                    "description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
68                },
69                "channel_id": {
70                    "type": "string",
71                    "description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
72                },
73                "message_id": {
74                    "type": "string",
75                    "description": "Platform-scoped message identifier to react to"
76                },
77                "emoji": {
78                    "type": "string",
79                    "description": "Emoji to react with (Unicode character or platform shortcode)"
80                },
81                "action": {
82                    "type": "string",
83                    "enum": ["add", "remove"],
84                    "description": "Whether to add or remove the reaction (default: 'add')"
85                }
86            },
87            "required": ["channel", "channel_id", "message_id", "emoji"]
88        })
89    }
90
91    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
92        // Security gate
93        if let Err(error) = self
94            .security
95            .enforce_tool_operation(ToolOperation::Act, "reaction")
96        {
97            return Ok(ToolResult {
98                success: false,
99                output: String::new(),
100                error: Some(error),
101            });
102        }
103
104        let channel_name = args
105            .get("channel")
106            .and_then(|v| v.as_str())
107            .ok_or_else(|| anyhow::anyhow!("Missing 'channel' parameter"))?;
108
109        let channel_id = args
110            .get("channel_id")
111            .and_then(|v| v.as_str())
112            .ok_or_else(|| anyhow::anyhow!("Missing 'channel_id' parameter"))?;
113
114        let message_id = args
115            .get("message_id")
116            .and_then(|v| v.as_str())
117            .ok_or_else(|| anyhow::anyhow!("Missing 'message_id' parameter"))?;
118
119        let emoji = args
120            .get("emoji")
121            .and_then(|v| v.as_str())
122            .ok_or_else(|| anyhow::anyhow!("Missing 'emoji' parameter"))?;
123
124        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
125
126        if action != "add" && action != "remove" {
127            return Ok(ToolResult {
128                success: false,
129                output: String::new(),
130                error: Some(format!(
131                    "Invalid action '{action}': must be 'add' or 'remove'"
132                )),
133            });
134        }
135
136        // Read-lock the channel map to find the target channel.
137        let channel = {
138            let map = self.channels.read();
139            if map.is_empty() {
140                return Ok(ToolResult {
141                    success: false,
142                    output: String::new(),
143                    error: Some("No channels available yet (channels not initialized)".to_string()),
144                });
145            }
146            match map.get(channel_name) {
147                Some(ch) => Arc::clone(ch),
148                None => {
149                    let available: Vec<String> = map.keys().cloned().collect();
150                    return Ok(ToolResult {
151                        success: false,
152                        output: String::new(),
153                        error: Some(format!(
154                            "Channel '{channel_name}' not found. Available channels: {}",
155                            available.join(", ")
156                        )),
157                    });
158                }
159            }
160        };
161
162        let result = if action == "add" {
163            channel.add_reaction(channel_id, message_id, emoji).await
164        } else {
165            channel.remove_reaction(channel_id, message_id, emoji).await
166        };
167
168        let past_tense = if action == "remove" {
169            "removed"
170        } else {
171            "added"
172        };
173
174        match result {
175            Ok(()) => Ok(ToolResult {
176                success: true,
177                output: format!(
178                    "Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
179                ),
180                error: None,
181            }),
182            Err(e) => Ok(ToolResult {
183                success: false,
184                output: String::new(),
185                error: Some(format!("Failed to {action} reaction: {e}")),
186            }),
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::traits::{ChannelMessage, SendMessage};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Test double that records which reaction hook ran and the channel id it received.
    struct MockChannel {
        reaction_added: AtomicBool,
        reaction_removed: AtomicBool,
        last_channel_id: parking_lot::Mutex<Option<String>>,
        /// When set, `add_reaction` fails with a simulated API error.
        fail_on_add: bool,
    }

    impl MockChannel {
        /// Shared constructor behind [`MockChannel::new`] and [`MockChannel::failing`].
        fn with_add_failure(fail_on_add: bool) -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_channel_id: parking_lot::Mutex::new(None),
                fail_on_add,
            }
        }

        fn new() -> Self {
            Self::with_add_failure(false)
        }

        fn failing() -> Self {
            Self::with_add_failure(true)
        }

        /// Remember `channel_id`, then raise `flag`.
        fn record(&self, channel_id: &str, flag: &AtomicBool) {
            *self.last_channel_id.lock() = Some(channel_id.to_owned());
            flag.store(true, Ordering::SeqCst);
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            "mock"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            if self.fail_on_add {
                anyhow::bail!("API error: rate limited");
            }
            self.record(channel_id, &self.reaction_added);
            Ok(())
        }

        async fn remove_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            self.record(channel_id, &self.reaction_removed);
            Ok(())
        }
    }

    /// Erase a mock into the trait object the tool stores.
    fn as_channel(mock: MockChannel) -> Arc<dyn Channel> {
        Arc::new(mock)
    }

    /// Build a tool whose channel map holds exactly `channels`.
    fn tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let mut map: HashMap<String, Arc<dyn Channel>> = HashMap::new();
        for (name, channel) in channels {
            map.insert(name.to_owned(), channel);
        }
        tool.populate(map);
        tool
    }

    /// Execute `args`, panicking if the tool returns `Err` instead of a `ToolResult`.
    async fn run(tool: &ReactionTool, args: serde_json::Value) -> ToolResult {
        tool.execute(args)
            .await
            .expect("execute should return a ToolResult")
    }

    /// Standard arguments (message `msg_1`, emoji U+2705) for the given channel pair.
    fn react_args(channel: &str, channel_id: &str) -> serde_json::Value {
        json!({
            "channel": channel,
            "channel_id": channel_id,
            "message_id": "msg_1",
            "emoji": "\u{2705}"
        })
    }

    #[test]
    fn tool_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        assert_eq!(tool.name(), "reaction");
        assert!(!tool.description().is_empty());

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for prop in ["channel", "channel_id", "message_id", "emoji", "action"] {
            assert!(schema["properties"][prop].is_object(), "missing property {prop}");
        }

        let required = schema["required"].as_array().unwrap();
        let is_required = |name: &str| required.iter().any(|v| v == name);
        for name in ["channel", "channel_id", "message_id", "emoji"] {
            assert!(is_required(name), "{name} should be required");
        }
        // `action` is optional (defaults to "add").
        assert!(!is_required("action"));
    }

    #[tokio::test]
    async fn add_reaction_success() {
        let tool = tool_with_channels(vec![("discord", as_channel(MockChannel::new()))]);

        let result = run(
            &tool,
            json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_123",
                "emoji": "\u{2705}"
            }),
        )
        .await;

        assert!(result.success);
        assert!(result.output.contains("added"));
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn remove_reaction_success() {
        let tool = tool_with_channels(vec![("slack", as_channel(MockChannel::new()))]);

        let result = run(
            &tool,
            json!({
                "channel": "slack",
                "channel_id": "C0123SLACK",
                "message_id": "msg_456",
                "emoji": "\u{1F440}",
                "action": "remove"
            }),
        )
        .await;

        assert!(result.success);
        assert!(result.output.contains("removed"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = tool_with_channels(vec![("discord", as_channel(MockChannel::new()))]);

        let result = run(&tool, react_args("nonexistent", "ch_x")).await;

        assert!(!result.success);
        let message = result.error.as_deref().unwrap();
        assert!(message.contains("not found"));
        assert!(message.contains("discord"));
    }

    #[tokio::test]
    async fn invalid_action_returns_error() {
        let tool = tool_with_channels(vec![("discord", as_channel(MockChannel::new()))]);
        let mut args = react_args("discord", "ch_001");
        args["action"] = json!("toggle");

        let result = run(&tool, args).await;

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("toggle"));
    }

    #[tokio::test]
    async fn channel_error_propagated() {
        let tool = tool_with_channels(vec![("discord", as_channel(MockChannel::failing()))]);

        let result = run(&tool, react_args("discord", "ch_001")).await;

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("rate limited"));
    }

    #[tokio::test]
    async fn missing_required_params() {
        let tool = tool_with_channels(vec![("test", as_channel(MockChannel::new()))]);

        let cases = [
            // Missing channel
            json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}),
            // Missing channel_id
            json!({"channel": "test", "message_id": "1", "emoji": "x"}),
            // Missing message_id
            json!({"channel": "a", "channel_id": "c1", "emoji": "x"}),
            // Missing emoji
            json!({"channel": "a", "channel_id": "c1", "message_id": "1"}),
        ];
        for args in cases {
            let shown = args.to_string();
            assert!(tool.execute(args).await.is_err(), "expected Err for {shown}");
        }
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        // No channels populated.
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));

        let result = run(&tool, react_args("discord", "ch_001")).await;

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn default_action_is_add() {
        let mock = Arc::new(MockChannel::new());
        let shared: Arc<dyn Channel> = mock.clone();
        let tool = tool_with_channels(vec![("test", shared)]);

        let result = run(&tool, react_args("test", "ch_test")).await;

        assert!(result.success);
        assert!(mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn channel_id_passed_to_trait_not_channel_name() {
        let mock = Arc::new(MockChannel::new());
        let shared: Arc<dyn Channel> = mock.clone();
        let tool = tool_with_channels(vec![("discord", shared)]);

        let result = run(&tool, react_args("discord", "123456789")).await;

        assert!(result.success);
        // The trait must receive the platform channel_id, not the channel name.
        assert_eq!(
            mock.last_channel_id.lock().as_deref(),
            Some("123456789"),
            "add_reaction must receive channel_id, not channel name"
        );
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let handle = tool.channel_map_handle();

        // Initially empty: the tool reports that channels are not initialized.
        let before = run(&tool, react_args("slack", "C0123")).await;
        assert!(!before.success);

        // Register a channel through the shared handle only.
        handle
            .write()
            .insert("slack".to_owned(), as_channel(MockChannel::new()));

        // Now the tool can route to the channel.
        let after = run(&tool, react_args("slack", "C0123")).await;
        assert!(after.success);
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
        let spec = tool.spec();
        assert_eq!(spec.name, "reaction");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }
}