Skip to main content

construct/tools/
notion_tool.rs

1use super::traits::{Tool, ToolResult};
2use crate::security::{SecurityPolicy, policy::ToolOperation};
3use async_trait::async_trait;
4use serde_json::json;
5use std::sync::Arc;
6
/// Root URL that every Notion endpoint path used by this tool is appended to.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` header on every request.
const NOTION_VERSION: &str = "2022-06-28";
/// Timeout, in seconds, applied individually to each Notion HTTP request.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Upper bound on how many characters of a failed response's body are echoed in errors.
const MAX_ERROR_BODY_CHARS: usize = 500;
12
13/// Tool for interacting with the Notion API — query databases, read/create/update pages,
14/// and search the workspace. Each action is gated by the appropriate security operation
15/// (Read for queries, Act for mutations).
16pub struct NotionTool {
17    api_key: String,
18    http: reqwest::Client,
19    security: Arc<SecurityPolicy>,
20}
21
22impl NotionTool {
23    /// Create a new Notion tool with the given API key and security policy.
24    pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
25        Self {
26            api_key,
27            http: reqwest::Client::new(),
28            security,
29        }
30    }
31
32    /// Build the standard Notion API headers (Authorization, version, content-type).
33    fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
34        let mut headers = reqwest::header::HeaderMap::new();
35        headers.insert(
36            "Authorization",
37            format!("Bearer {}", self.api_key)
38                .parse()
39                .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
40        );
41        headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
42        headers.insert("Content-Type", "application/json".parse().unwrap());
43        Ok(headers)
44    }
45
46    /// Query a Notion database with an optional filter.
47    async fn query_database(
48        &self,
49        database_id: &str,
50        filter: Option<&serde_json::Value>,
51    ) -> anyhow::Result<serde_json::Value> {
52        let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
53        let mut body = json!({});
54        if let Some(f) = filter {
55            body["filter"] = f.clone();
56        }
57        let resp = self
58            .http
59            .post(&url)
60            .headers(self.headers()?)
61            .json(&body)
62            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
63            .send()
64            .await?;
65        let status = resp.status();
66        if !status.is_success() {
67            let text = resp.text().await.unwrap_or_default();
68            let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
69            anyhow::bail!("Notion query_database failed ({status}): {truncated}");
70        }
71        resp.json().await.map_err(Into::into)
72    }
73
74    /// Read a single Notion page by ID.
75    async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
76        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
77        let resp = self
78            .http
79            .get(&url)
80            .headers(self.headers()?)
81            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
82            .send()
83            .await?;
84        let status = resp.status();
85        if !status.is_success() {
86            let text = resp.text().await.unwrap_or_default();
87            let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
88            anyhow::bail!("Notion read_page failed ({status}): {truncated}");
89        }
90        resp.json().await.map_err(Into::into)
91    }
92
93    /// Create a new Notion page, optionally within a database.
94    async fn create_page(
95        &self,
96        properties: &serde_json::Value,
97        database_id: Option<&str>,
98    ) -> anyhow::Result<serde_json::Value> {
99        let url = format!("{NOTION_API_BASE}/pages");
100        let mut body = json!({ "properties": properties });
101        if let Some(db_id) = database_id {
102            body["parent"] = json!({ "database_id": db_id });
103        }
104        let resp = self
105            .http
106            .post(&url)
107            .headers(self.headers()?)
108            .json(&body)
109            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
110            .send()
111            .await?;
112        let status = resp.status();
113        if !status.is_success() {
114            let text = resp.text().await.unwrap_or_default();
115            let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
116            anyhow::bail!("Notion create_page failed ({status}): {truncated}");
117        }
118        resp.json().await.map_err(Into::into)
119    }
120
121    /// Update an existing Notion page's properties.
122    async fn update_page(
123        &self,
124        page_id: &str,
125        properties: &serde_json::Value,
126    ) -> anyhow::Result<serde_json::Value> {
127        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
128        let body = json!({ "properties": properties });
129        let resp = self
130            .http
131            .patch(&url)
132            .headers(self.headers()?)
133            .json(&body)
134            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
135            .send()
136            .await?;
137        let status = resp.status();
138        if !status.is_success() {
139            let text = resp.text().await.unwrap_or_default();
140            let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
141            anyhow::bail!("Notion update_page failed ({status}): {truncated}");
142        }
143        resp.json().await.map_err(Into::into)
144    }
145
146    /// Search the Notion workspace by query string.
147    async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
148        let url = format!("{NOTION_API_BASE}/search");
149        let body = json!({ "query": query });
150        let resp = self
151            .http
152            .post(&url)
153            .headers(self.headers()?)
154            .json(&body)
155            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
156            .send()
157            .await?;
158        let status = resp.status();
159        if !status.is_success() {
160            let text = resp.text().await.unwrap_or_default();
161            let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
162            anyhow::bail!("Notion search failed ({status}): {truncated}");
163        }
164        resp.json().await.map_err(Into::into)
165    }
166}
167
168#[async_trait]
169impl Tool for NotionTool {
170    fn name(&self) -> &str {
171        "notion"
172    }
173
174    fn description(&self) -> &str {
175        "Interact with Notion: query databases, read/create/update pages, and search the workspace."
176    }
177
178    fn parameters_schema(&self) -> serde_json::Value {
179        json!({
180            "type": "object",
181            "properties": {
182                "action": {
183                    "type": "string",
184                    "enum": ["query_database", "read_page", "create_page", "update_page", "search"],
185                    "description": "The Notion API action to perform"
186                },
187                "database_id": {
188                    "type": "string",
189                    "description": "Database ID (required for query_database, optional for create_page)"
190                },
191                "page_id": {
192                    "type": "string",
193                    "description": "Page ID (required for read_page and update_page)"
194                },
195                "filter": {
196                    "type": "object",
197                    "description": "Notion filter object for query_database"
198                },
199                "properties": {
200                    "type": "object",
201                    "description": "Properties object for create_page and update_page"
202                },
203                "query": {
204                    "type": "string",
205                    "description": "Search query string for the search action"
206                }
207            },
208            "required": ["action"]
209        })
210    }
211
212    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
213        let action = match args.get("action").and_then(|v| v.as_str()) {
214            Some(a) => a,
215            None => {
216                return Ok(ToolResult {
217                    success: false,
218                    output: String::new(),
219                    error: Some("Missing required parameter: action".into()),
220                });
221            }
222        };
223
224        // Enforce granular security: Read for queries, Act for mutations
225        let operation = match action {
226            "query_database" | "read_page" | "search" => ToolOperation::Read,
227            "create_page" | "update_page" => ToolOperation::Act,
228            _ => {
229                return Ok(ToolResult {
230                    success: false,
231                    output: String::new(),
232                    error: Some(format!(
233                        "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
234                    )),
235                });
236            }
237        };
238
239        if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
240            return Ok(ToolResult {
241                success: false,
242                output: String::new(),
243                error: Some(error),
244            });
245        }
246
247        let result = match action {
248            "query_database" => {
249                let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
250                    Some(id) => id,
251                    None => {
252                        return Ok(ToolResult {
253                            success: false,
254                            output: String::new(),
255                            error: Some("query_database requires database_id parameter".into()),
256                        });
257                    }
258                };
259                let filter = args.get("filter");
260                self.query_database(database_id, filter).await
261            }
262            "read_page" => {
263                let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
264                    Some(id) => id,
265                    None => {
266                        return Ok(ToolResult {
267                            success: false,
268                            output: String::new(),
269                            error: Some("read_page requires page_id parameter".into()),
270                        });
271                    }
272                };
273                self.read_page(page_id).await
274            }
275            "create_page" => {
276                let properties = match args.get("properties") {
277                    Some(p) => p,
278                    None => {
279                        return Ok(ToolResult {
280                            success: false,
281                            output: String::new(),
282                            error: Some("create_page requires properties parameter".into()),
283                        });
284                    }
285                };
286                let database_id = args.get("database_id").and_then(|v| v.as_str());
287                self.create_page(properties, database_id).await
288            }
289            "update_page" => {
290                let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
291                    Some(id) => id,
292                    None => {
293                        return Ok(ToolResult {
294                            success: false,
295                            output: String::new(),
296                            error: Some("update_page requires page_id parameter".into()),
297                        });
298                    }
299                };
300                let properties = match args.get("properties") {
301                    Some(p) => p,
302                    None => {
303                        return Ok(ToolResult {
304                            success: false,
305                            output: String::new(),
306                            error: Some("update_page requires properties parameter".into()),
307                        });
308                    }
309                };
310                self.update_page(page_id, properties).await
311            }
312            "search" => {
313                let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
314                self.search(query).await
315            }
316            _ => unreachable!(), // Already handled above
317        };
318
319        match result {
320            Ok(value) => Ok(ToolResult {
321                success: true,
322                output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
323                error: None,
324            }),
325            Err(e) => Ok(ToolResult {
326                success: false,
327                output: String::new(),
328                error: Some(e.to_string()),
329            }),
330        }
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;

    /// A tool with a dummy key and the default security policy. The `execute` tests below
    /// all stop at argument validation, before any request is built.
    fn test_tool() -> NotionTool {
        NotionTool::new("test-key".into(), Arc::new(SecurityPolicy::default()))
    }

    /// Run `execute` with `args` and assert it reports failure with an error mentioning `needle`.
    async fn expect_failure(args: serde_json::Value, needle: &str) {
        let outcome = test_tool().execute(args).await.unwrap();
        assert!(!outcome.success);
        let message = outcome.error.as_deref().unwrap();
        assert!(message.contains(needle), "unexpected error message: {message}");
    }

    #[test]
    fn tool_name_is_notion() {
        assert_eq!(test_tool().name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let schema = test_tool().parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|entry| entry.as_str() == Some("action")));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let schema = test_tool().parameters_schema();
        let declared: Vec<&str> = schema["properties"]["action"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(serde_json::Value::as_str)
            .collect();
        for expected in [
            "query_database",
            "read_page",
            "create_page",
            "update_page",
            "search",
        ] {
            assert!(declared.contains(&expected), "schema is missing action {expected}");
        }
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        expect_failure(json!({}), "action").await;
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        expect_failure(json!({"action": "invalid"}), "Unknown action").await;
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        expect_failure(json!({"action": "query_database"}), "database_id").await;
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        expect_failure(json!({"action": "read_page"}), "page_id").await;
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        expect_failure(json!({"action": "create_page"}), "properties").await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        expect_failure(json!({"action": "update_page", "properties": {}}), "page_id").await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        expect_failure(
            json!({"action": "update_page", "page_id": "test-id"}),
            "properties",
        )
        .await;
    }
}