1use super::traits::{Tool, ToolResult};
2use crate::security::{SecurityPolicy, policy::ToolOperation};
3use async_trait::async_trait;
4use serde_json::json;
5use std::sync::Arc;
6
// --- Notion endpoint configuration ---

/// Root URL of the Notion REST API; endpoint paths are appended to it.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` header on every request.
const NOTION_VERSION: &str = "2022-06-28";

// --- Request limits ---

/// Per-request timeout, in seconds, applied to every Notion call.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Maximum length (passed to `truncate_with_ellipsis`) of a Notion error
/// response body that is echoed into an error message.
const MAX_ERROR_BODY_CHARS: usize = 500;
12
13pub struct NotionTool {
17 api_key: String,
18 http: reqwest::Client,
19 security: Arc<SecurityPolicy>,
20}
21
22impl NotionTool {
23 pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
25 Self {
26 api_key,
27 http: reqwest::Client::new(),
28 security,
29 }
30 }
31
32 fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
34 let mut headers = reqwest::header::HeaderMap::new();
35 headers.insert(
36 "Authorization",
37 format!("Bearer {}", self.api_key)
38 .parse()
39 .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
40 );
41 headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
42 headers.insert("Content-Type", "application/json".parse().unwrap());
43 Ok(headers)
44 }
45
46 async fn query_database(
48 &self,
49 database_id: &str,
50 filter: Option<&serde_json::Value>,
51 ) -> anyhow::Result<serde_json::Value> {
52 let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
53 let mut body = json!({});
54 if let Some(f) = filter {
55 body["filter"] = f.clone();
56 }
57 let resp = self
58 .http
59 .post(&url)
60 .headers(self.headers()?)
61 .json(&body)
62 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
63 .send()
64 .await?;
65 let status = resp.status();
66 if !status.is_success() {
67 let text = resp.text().await.unwrap_or_default();
68 let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
69 anyhow::bail!("Notion query_database failed ({status}): {truncated}");
70 }
71 resp.json().await.map_err(Into::into)
72 }
73
74 async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
76 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
77 let resp = self
78 .http
79 .get(&url)
80 .headers(self.headers()?)
81 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
82 .send()
83 .await?;
84 let status = resp.status();
85 if !status.is_success() {
86 let text = resp.text().await.unwrap_or_default();
87 let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
88 anyhow::bail!("Notion read_page failed ({status}): {truncated}");
89 }
90 resp.json().await.map_err(Into::into)
91 }
92
93 async fn create_page(
95 &self,
96 properties: &serde_json::Value,
97 database_id: Option<&str>,
98 ) -> anyhow::Result<serde_json::Value> {
99 let url = format!("{NOTION_API_BASE}/pages");
100 let mut body = json!({ "properties": properties });
101 if let Some(db_id) = database_id {
102 body["parent"] = json!({ "database_id": db_id });
103 }
104 let resp = self
105 .http
106 .post(&url)
107 .headers(self.headers()?)
108 .json(&body)
109 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
110 .send()
111 .await?;
112 let status = resp.status();
113 if !status.is_success() {
114 let text = resp.text().await.unwrap_or_default();
115 let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
116 anyhow::bail!("Notion create_page failed ({status}): {truncated}");
117 }
118 resp.json().await.map_err(Into::into)
119 }
120
121 async fn update_page(
123 &self,
124 page_id: &str,
125 properties: &serde_json::Value,
126 ) -> anyhow::Result<serde_json::Value> {
127 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
128 let body = json!({ "properties": properties });
129 let resp = self
130 .http
131 .patch(&url)
132 .headers(self.headers()?)
133 .json(&body)
134 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
135 .send()
136 .await?;
137 let status = resp.status();
138 if !status.is_success() {
139 let text = resp.text().await.unwrap_or_default();
140 let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
141 anyhow::bail!("Notion update_page failed ({status}): {truncated}");
142 }
143 resp.json().await.map_err(Into::into)
144 }
145
146 async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
148 let url = format!("{NOTION_API_BASE}/search");
149 let body = json!({ "query": query });
150 let resp = self
151 .http
152 .post(&url)
153 .headers(self.headers()?)
154 .json(&body)
155 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
156 .send()
157 .await?;
158 let status = resp.status();
159 if !status.is_success() {
160 let text = resp.text().await.unwrap_or_default();
161 let truncated = crate::util::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
162 anyhow::bail!("Notion search failed ({status}): {truncated}");
163 }
164 resp.json().await.map_err(Into::into)
165 }
166}
167
168#[async_trait]
169impl Tool for NotionTool {
170 fn name(&self) -> &str {
171 "notion"
172 }
173
174 fn description(&self) -> &str {
175 "Interact with Notion: query databases, read/create/update pages, and search the workspace."
176 }
177
178 fn parameters_schema(&self) -> serde_json::Value {
179 json!({
180 "type": "object",
181 "properties": {
182 "action": {
183 "type": "string",
184 "enum": ["query_database", "read_page", "create_page", "update_page", "search"],
185 "description": "The Notion API action to perform"
186 },
187 "database_id": {
188 "type": "string",
189 "description": "Database ID (required for query_database, optional for create_page)"
190 },
191 "page_id": {
192 "type": "string",
193 "description": "Page ID (required for read_page and update_page)"
194 },
195 "filter": {
196 "type": "object",
197 "description": "Notion filter object for query_database"
198 },
199 "properties": {
200 "type": "object",
201 "description": "Properties object for create_page and update_page"
202 },
203 "query": {
204 "type": "string",
205 "description": "Search query string for the search action"
206 }
207 },
208 "required": ["action"]
209 })
210 }
211
212 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
213 let action = match args.get("action").and_then(|v| v.as_str()) {
214 Some(a) => a,
215 None => {
216 return Ok(ToolResult {
217 success: false,
218 output: String::new(),
219 error: Some("Missing required parameter: action".into()),
220 });
221 }
222 };
223
224 let operation = match action {
226 "query_database" | "read_page" | "search" => ToolOperation::Read,
227 "create_page" | "update_page" => ToolOperation::Act,
228 _ => {
229 return Ok(ToolResult {
230 success: false,
231 output: String::new(),
232 error: Some(format!(
233 "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
234 )),
235 });
236 }
237 };
238
239 if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
240 return Ok(ToolResult {
241 success: false,
242 output: String::new(),
243 error: Some(error),
244 });
245 }
246
247 let result = match action {
248 "query_database" => {
249 let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
250 Some(id) => id,
251 None => {
252 return Ok(ToolResult {
253 success: false,
254 output: String::new(),
255 error: Some("query_database requires database_id parameter".into()),
256 });
257 }
258 };
259 let filter = args.get("filter");
260 self.query_database(database_id, filter).await
261 }
262 "read_page" => {
263 let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
264 Some(id) => id,
265 None => {
266 return Ok(ToolResult {
267 success: false,
268 output: String::new(),
269 error: Some("read_page requires page_id parameter".into()),
270 });
271 }
272 };
273 self.read_page(page_id).await
274 }
275 "create_page" => {
276 let properties = match args.get("properties") {
277 Some(p) => p,
278 None => {
279 return Ok(ToolResult {
280 success: false,
281 output: String::new(),
282 error: Some("create_page requires properties parameter".into()),
283 });
284 }
285 };
286 let database_id = args.get("database_id").and_then(|v| v.as_str());
287 self.create_page(properties, database_id).await
288 }
289 "update_page" => {
290 let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
291 Some(id) => id,
292 None => {
293 return Ok(ToolResult {
294 success: false,
295 output: String::new(),
296 error: Some("update_page requires page_id parameter".into()),
297 });
298 }
299 };
300 let properties = match args.get("properties") {
301 Some(p) => p,
302 None => {
303 return Ok(ToolResult {
304 success: false,
305 output: String::new(),
306 error: Some("update_page requires properties parameter".into()),
307 });
308 }
309 };
310 self.update_page(page_id, properties).await
311 }
312 "search" => {
313 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
314 self.search(query).await
315 }
316 _ => unreachable!(), };
318
319 match result {
320 Ok(value) => Ok(ToolResult {
321 success: true,
322 output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
323 error: None,
324 }),
325 Err(e) => Ok(ToolResult {
326 success: false,
327 output: String::new(),
328 error: Some(e.to_string()),
329 }),
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::SecurityPolicy;

    /// Tool under test: dummy API key and the default security policy.
    fn test_tool() -> NotionTool {
        let policy = Arc::new(SecurityPolicy::default());
        NotionTool::new(String::from("test-key"), policy)
    }

    /// Runs `execute` with `args`, asserts the call reported failure, and
    /// returns the error message it carried.
    async fn failure_message(args: serde_json::Value) -> String {
        let outcome = test_tool().execute(args).await.unwrap();
        assert!(!outcome.success);
        outcome
            .error
            .expect("a failed ToolResult should carry an error message")
    }

    #[test]
    fn tool_name_is_notion() {
        assert_eq!(test_tool().name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let schema = test_tool().parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("action")));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let schema = test_tool().parameters_schema();
        let declared: Vec<&str> = schema["properties"]["action"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(serde_json::Value::as_str)
            .collect();
        for expected in [
            "query_database",
            "read_page",
            "create_page",
            "update_page",
            "search",
        ] {
            assert!(declared.contains(&expected), "missing action {expected}");
        }
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        assert!(failure_message(json!({})).await.contains("action"));
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        let message = failure_message(json!({"action": "invalid"})).await;
        assert!(message.contains("Unknown action"));
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        let message = failure_message(json!({"action": "query_database"})).await;
        assert!(message.contains("database_id"));
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        let message = failure_message(json!({"action": "read_page"})).await;
        assert!(message.contains("page_id"));
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        let message = failure_message(json!({"action": "create_page"})).await;
        assert!(message.contains("properties"));
    }

    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        let message =
            failure_message(json!({"action": "update_page", "properties": {}})).await;
        assert!(message.contains("page_id"));
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        let message =
            failure_message(json!({"action": "update_page", "page_id": "test-id"})).await;
        assert!(message.contains("properties"));
    }
}