// construct/tools/discord_search.rs

1use super::traits::{Tool, ToolResult};
2use crate::memory::Memory;
3use async_trait::async_trait;
4use serde_json::json;
5use std::fmt::Write;
6use std::sync::Arc;
7
8/// Search Discord message history stored in discord.db.
9pub struct DiscordSearchTool {
10    discord_memory: Arc<dyn Memory>,
11}
12
13impl DiscordSearchTool {
14    pub fn new(discord_memory: Arc<dyn Memory>) -> Self {
15        Self { discord_memory }
16    }
17}
18
19#[async_trait]
20impl Tool for DiscordSearchTool {
21    fn name(&self) -> &str {
22        "discord_search"
23    }
24
25    fn description(&self) -> &str {
26        "Search Discord message history. Returns messages matching a keyword query, optionally filtered by channel_id, author_id, or time range."
27    }
28
29    fn parameters_schema(&self) -> serde_json::Value {
30        json!({
31            "type": "object",
32            "properties": {
33                "query": {
34                    "type": "string",
35                    "description": "Keywords or phrase to search for in Discord messages (optional if since/until provided)"
36                },
37                "limit": {
38                    "type": "integer",
39                    "description": "Max results to return (default: 10)"
40                },
41                "channel_id": {
42                    "type": "string",
43                    "description": "Filter results to a specific Discord channel ID"
44                },
45                "since": {
46                    "type": "string",
47                    "description": "Filter messages at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
48                },
49                "until": {
50                    "type": "string",
51                    "description": "Filter messages at or before this time (RFC 3339)"
52                }
53            }
54        })
55    }
56
57    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
58        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
59        let channel_id = args.get("channel_id").and_then(|v| v.as_str());
60        let since = args.get("since").and_then(|v| v.as_str());
61        let until = args.get("until").and_then(|v| v.as_str());
62
63        if query.trim().is_empty() && since.is_none() && until.is_none() {
64            return Ok(ToolResult {
65                success: false,
66                output: String::new(),
67                error: Some(
68                    "Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
69                ),
70            });
71        }
72
73        if let Some(s) = since {
74            if chrono::DateTime::parse_from_rfc3339(s).is_err() {
75                return Ok(ToolResult {
76                    success: false,
77                    output: String::new(),
78                    error: Some(format!(
79                        "Invalid 'since' date: {s}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
80                    )),
81                });
82            }
83        }
84        if let Some(u) = until {
85            if chrono::DateTime::parse_from_rfc3339(u).is_err() {
86                return Ok(ToolResult {
87                    success: false,
88                    output: String::new(),
89                    error: Some(format!(
90                        "Invalid 'until' date: {u}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
91                    )),
92                });
93            }
94        }
95        if let (Some(s), Some(u)) = (since, until) {
96            if let (Ok(s_dt), Ok(u_dt)) = (
97                chrono::DateTime::parse_from_rfc3339(s),
98                chrono::DateTime::parse_from_rfc3339(u),
99            ) {
100                if s_dt >= u_dt {
101                    return Ok(ToolResult {
102                        success: false,
103                        output: String::new(),
104                        error: Some("'since' must be before 'until'".into()),
105                    });
106                }
107            }
108        }
109
110        #[allow(clippy::cast_possible_truncation)]
111        let limit = args
112            .get("limit")
113            .and_then(serde_json::Value::as_u64)
114            .map_or(10, |v| v as usize);
115
116        match self
117            .discord_memory
118            .recall(query, limit, channel_id, since, until)
119            .await
120        {
121            Ok(entries) if entries.is_empty() => Ok(ToolResult {
122                success: true,
123                output: "No Discord messages found.".into(),
124                error: None,
125            }),
126            Ok(entries) => {
127                let mut output = format!("Found {} Discord messages:\n", entries.len());
128                for entry in &entries {
129                    let score = entry
130                        .score
131                        .map_or_else(String::new, |s| format!(" [{s:.0}%]"));
132                    let _ = writeln!(output, "- {}{score}", entry.content);
133                }
134                Ok(ToolResult {
135                    success: true,
136                    output,
137                    error: None,
138                })
139            }
140            Err(e) => Ok(ToolResult {
141                success: false,
142                output: String::new(),
143                error: Some(format!("Discord search failed: {e}")),
144            }),
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::NoneMemory;

    /// Memory backend that yields no entries, so successful searches report
    /// "No Discord messages found".
    fn noop_mem() -> Arc<dyn Memory> {
        Arc::new(NoneMemory::new())
    }

    #[tokio::test]
    async fn search_empty() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool.execute(json!({"query": "hello"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No Discord messages found"));
    }

    #[tokio::test]
    async fn search_requires_query_or_time() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("at least"));
    }

    #[tokio::test]
    async fn search_whitespace_query_counts_as_missing() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool.execute(json!({"query": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("at least"));
    }

    #[tokio::test]
    async fn search_time_range_only() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool
            .execute(json!({"since": "2025-03-01T00:00:00Z"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("No Discord messages found"));
    }

    #[tokio::test]
    async fn search_rejects_invalid_since() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool
            .execute(json!({"query": "hi", "since": "yesterday"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_ref()
            .unwrap()
            .contains("Invalid 'since' date"));
    }

    #[tokio::test]
    async fn search_rejects_invalid_until() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool
            .execute(json!({"query": "hi", "until": "2025-13-01"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_ref()
            .unwrap()
            .contains("Invalid 'until' date"));
    }

    #[tokio::test]
    async fn search_rejects_inverted_range() {
        let tool = DiscordSearchTool::new(noop_mem());
        let result = tool
            .execute(json!({
                "since": "2025-03-02T00:00:00Z",
                "until": "2025-03-01T00:00:00Z"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_ref()
            .unwrap()
            .contains("'since' must be before 'until'"));
    }

    #[test]
    fn name_and_schema() {
        let tool = DiscordSearchTool::new(noop_mem());
        assert_eq!(tool.name(), "discord_search");
        let schema = tool.parameters_schema();
        for key in ["query", "limit", "channel_id", "since", "until"] {
            assert!(schema["properties"][key].is_object(), "missing {key}");
        }
    }
}
180}