1use super::traits::{Channel, ChannelMessage, SendMessage};
2use anyhow::{Result, bail};
3use async_trait::async_trait;
4use std::collections::HashSet;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Base URL that every Notion REST endpoint used in this file is built from.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` request header.
const NOTION_VERSION: &str = "2022-06-28";
/// Upper bound, in bytes, on the text written into the result property.
// NOTE(review): presumably mirrors Notion's per-text-object content limit — confirm.
const MAX_RESULT_LENGTH: usize = 2000;
/// Number of attempts made per API call before giving up.
const MAX_RETRIES: u32 = 3;
/// Base backoff delay in milliseconds; doubled for each subsequent attempt.
const RETRY_BASE_DELAY_MS: u64 = 2000;
/// Limit passed to `truncate_with_ellipsis` when embedding an error response body in an error.
const MAX_ERROR_BODY_CHARS: usize = 500;
15
/// Returns the largest byte index `<= max_bytes` at which `s` can be sliced
/// without splitting a UTF-8 code point.
///
/// When `max_bytes` reaches or exceeds the length of `s`, the full length is
/// returned. Otherwise candidate indices are scanned downwards from
/// `max_bytes`; index `0` is always a boundary, so the search cannot fail.
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
    let total = s.len();
    if max_bytes >= total {
        return total;
    }
    (0..=max_bytes)
        .rev()
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(0)
}
27
/// Channel that polls a Notion database for tasks and writes results back to
/// the task pages.
pub struct NotionChannel {
    /// Token sent as `Authorization: Bearer <api_key>` on every request.
    api_key: String,
    /// Identifier of the database that is inspected and queried.
    database_id: String,
    /// Seconds to sleep between two polls of the database.
    poll_interval_secs: u64,
    /// Name of the page property holding the task state
    /// (`pending` / `running` / `done`).
    status_property: String,
    /// Name of the `title` or `rich_text` property the task input is read from.
    input_property: String,
    /// Name of the property that receives the result text.
    result_property: String,
    /// Maximum number of page ids that may be claimed at the same time.
    max_concurrent: usize,
    /// Type of the status property; starts as `"select"` and is overwritten
    /// with the type detected from the database schema when listening starts.
    status_type: Arc<RwLock<String>>,
    /// Page ids that have been claimed and not yet released.
    inflight: Arc<RwLock<HashSet<String>>>,
    /// HTTP client shared by all API calls.
    http: reqwest::Client,
    /// Whether tasks found in the `running` state are reset to `pending`
    /// when listening starts.
    recover_stale: bool,
}
46
47impl NotionChannel {
48 pub fn new(
50 api_key: String,
51 database_id: String,
52 poll_interval_secs: u64,
53 status_property: String,
54 input_property: String,
55 result_property: String,
56 max_concurrent: usize,
57 recover_stale: bool,
58 ) -> Self {
59 Self {
60 api_key,
61 database_id,
62 poll_interval_secs,
63 status_property,
64 input_property,
65 result_property,
66 max_concurrent,
67 status_type: Arc::new(RwLock::new("select".to_string())),
68 inflight: Arc::new(RwLock::new(HashSet::new())),
69 http: reqwest::Client::new(),
70 recover_stale,
71 }
72 }
73
74 fn headers(&self) -> Result<reqwest::header::HeaderMap> {
76 let mut headers = reqwest::header::HeaderMap::new();
77 headers.insert(
78 "Authorization",
79 format!("Bearer {}", self.api_key)
80 .parse()
81 .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
82 );
83 headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
84 headers.insert("Content-Type", "application/json".parse().unwrap());
85 Ok(headers)
86 }
87
88 async fn api_call(
90 &self,
91 method: reqwest::Method,
92 url: &str,
93 body: Option<serde_json::Value>,
94 ) -> Result<serde_json::Value> {
95 let mut last_err = None;
96 for attempt in 0..MAX_RETRIES {
97 let mut req = self
98 .http
99 .request(method.clone(), url)
100 .headers(self.headers()?);
101 if let Some(ref b) = body {
102 req = req.json(b);
103 }
104 match req.send().await {
105 Ok(resp) => {
106 let status = resp.status();
107 if status.is_success() {
108 return resp
109 .json()
110 .await
111 .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
112 }
113 let status_code = status.as_u16();
114 if status_code != 429 && (400..500).contains(&status_code) {
116 let body_text = resp.text().await.unwrap_or_default();
117 let truncated =
118 crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
119 bail!("Notion API error {status_code}: {truncated}");
120 }
121 last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
122 }
123 Err(e) => {
124 last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
125 }
126 }
127 let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
128 tracing::warn!(
129 "Notion API call failed (attempt {}/{}), retrying in {}ms",
130 attempt + 1,
131 MAX_RETRIES,
132 delay
133 );
134 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
135 }
136 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
137 }
138
139 async fn detect_status_type(&self) -> Result<String> {
141 let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
142 let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
143 let status_type = resp
144 .get("properties")
145 .and_then(|p| p.get(&self.status_property))
146 .and_then(|s| s.get("type"))
147 .and_then(|t| t.as_str())
148 .unwrap_or("select")
149 .to_string();
150 Ok(status_type)
151 }
152
153 async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
155 let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
156 let status_type = self.status_type.read().await.clone();
157 let filter = build_status_filter(&self.status_property, &status_type, "pending");
158 let resp = self
159 .api_call(
160 reqwest::Method::POST,
161 &url,
162 Some(serde_json::json!({ "filter": filter })),
163 )
164 .await?;
165 Ok(resp
166 .get("results")
167 .and_then(|r| r.as_array())
168 .cloned()
169 .unwrap_or_default())
170 }
171
172 async fn claim_task(&self, page_id: &str) -> bool {
174 let mut inflight = self.inflight.write().await;
175 if inflight.contains(page_id) {
176 return false;
177 }
178 if inflight.len() >= self.max_concurrent {
179 return false;
180 }
181 inflight.insert(page_id.to_string());
182 true
183 }
184
185 async fn release_task(&self, page_id: &str) {
187 let mut inflight = self.inflight.write().await;
188 inflight.remove(page_id);
189 }
190
191 async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
193 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
194 let status_type = self.status_type.read().await.clone();
195 let payload = serde_json::json!({
196 "properties": {
197 &self.status_property: build_status_payload(&status_type, status_value),
198 }
199 });
200 self.api_call(reqwest::Method::PATCH, &url, Some(payload))
201 .await?;
202 Ok(())
203 }
204
205 async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
207 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
208 let payload = serde_json::json!({
209 "properties": {
210 &self.result_property: build_rich_text_payload(result_text),
211 }
212 });
213 self.api_call(reqwest::Method::PATCH, &url, Some(payload))
214 .await?;
215 Ok(())
216 }
217
218 async fn recover_stale(&self) -> Result<()> {
220 let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
221 let status_type = self.status_type.read().await.clone();
222 let filter = build_status_filter(&self.status_property, &status_type, "running");
223 let resp = self
224 .api_call(
225 reqwest::Method::POST,
226 &url,
227 Some(serde_json::json!({ "filter": filter })),
228 )
229 .await?;
230 let stale = resp
231 .get("results")
232 .and_then(|r| r.as_array())
233 .cloned()
234 .unwrap_or_default();
235 if stale.is_empty() {
236 return Ok(());
237 }
238 tracing::warn!(
239 "Found {} stale task(s) in 'running' state, resetting to 'pending'",
240 stale.len()
241 );
242 for task in &stale {
243 if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
244 let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
245 let payload = serde_json::json!({
246 "properties": {
247 &self.status_property: build_status_payload(&status_type, "pending"),
248 &self.result_property: build_rich_text_payload(
249 "Reset: poller restarted while task was running"
250 ),
251 }
252 });
253 let short_id_end = floor_utf8_char_boundary(page_id, 8);
254 let short_id = &page_id[..short_id_end];
255 if let Err(e) = self
256 .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
257 .await
258 {
259 tracing::error!("Could not reset stale task {short_id}: {e}");
260 } else {
261 tracing::info!("Reset stale task {short_id} to pending");
262 }
263 }
264 }
265 Ok(())
266 }
267}
268
269#[async_trait]
270impl Channel for NotionChannel {
271 fn name(&self) -> &str {
272 "notion"
273 }
274
275 async fn send(&self, message: &SendMessage) -> Result<()> {
276 let page_id = &message.recipient;
278 let status_type = self.status_type.read().await.clone();
279 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
280 let payload = serde_json::json!({
281 "properties": {
282 &self.status_property: build_status_payload(&status_type, "done"),
283 &self.result_property: build_rich_text_payload(&message.content),
284 }
285 });
286 self.api_call(reqwest::Method::PATCH, &url, Some(payload))
287 .await?;
288 self.release_task(page_id).await;
289 Ok(())
290 }
291
292 async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
293 match self.detect_status_type().await {
295 Ok(st) => {
296 tracing::info!("Notion status property type: {st}");
297 *self.status_type.write().await = st;
298 }
299 Err(e) => {
300 bail!("Failed to detect Notion database schema: {e}");
301 }
302 }
303
304 if self.recover_stale {
306 if let Err(e) = self.recover_stale().await {
307 tracing::error!("Notion stale task recovery failed: {e}");
308 }
309 }
310
311 loop {
313 match self.query_pending().await {
314 Ok(tasks) => {
315 if !tasks.is_empty() {
316 tracing::info!("Notion: found {} pending task(s)", tasks.len());
317 }
318 for task in tasks {
319 let page_id = match task.get("id").and_then(|v| v.as_str()) {
320 Some(id) => id.to_string(),
321 None => continue,
322 };
323
324 let input_text = extract_text_from_property(
325 task.get("properties")
326 .and_then(|p| p.get(&self.input_property)),
327 );
328
329 if input_text.trim().is_empty() {
330 let short_end = floor_utf8_char_boundary(&page_id, 8);
331 tracing::warn!(
332 "Notion: empty input for task {}, skipping",
333 &page_id[..short_end]
334 );
335 continue;
336 }
337
338 if !self.claim_task(&page_id).await {
339 continue;
340 }
341
342 if let Err(e) = self.set_status(&page_id, "running").await {
344 tracing::error!("Notion: failed to set running status: {e}");
345 self.release_task(&page_id).await;
346 continue;
347 }
348
349 let timestamp = std::time::SystemTime::now()
350 .duration_since(std::time::UNIX_EPOCH)
351 .unwrap_or_default()
352 .as_secs();
353
354 if tx
355 .send(ChannelMessage {
356 id: page_id.clone(),
357 sender: "notion".into(),
358 reply_target: page_id,
359 content: input_text,
360 channel: "notion".into(),
361 timestamp,
362 thread_ts: None,
363 interruption_scope_id: None,
364 attachments: vec![],
365 })
366 .await
367 .is_err()
368 {
369 tracing::info!("Notion channel shutting down");
370 return Ok(());
371 }
372 }
373 }
374 Err(e) => {
375 tracing::error!("Notion poll error: {e}");
376 }
377 }
378
379 tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
380 }
381 }
382
383 async fn health_check(&self) -> bool {
384 let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
385 self.api_call(reqwest::Method::GET, &url, None)
386 .await
387 .is_ok()
388 }
389}
390
391fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
395 if status_type == "status" {
396 serde_json::json!({
397 "property": property,
398 "status": { "equals": value }
399 })
400 } else {
401 serde_json::json!({
402 "property": property,
403 "select": { "equals": value }
404 })
405 }
406}
407
408fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
410 if status_type == "status" {
411 serde_json::json!({ "status": { "name": value } })
412 } else {
413 serde_json::json!({ "select": { "name": value } })
414 }
415}
416
417fn build_rich_text_payload(value: &str) -> serde_json::Value {
419 let truncated = truncate_result(value);
420 serde_json::json!({
421 "rich_text": [{
422 "text": { "content": truncated }
423 }]
424 })
425}
426
427fn truncate_result(value: &str) -> String {
429 if value.len() <= MAX_RESULT_LENGTH {
430 return value.to_string();
431 }
432 let cut = MAX_RESULT_LENGTH.saturating_sub(30);
433 let end = floor_utf8_char_boundary(value, cut);
435 format!("{}\n\n... [output truncated]", &value[..end])
436}
437
438fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
440 let Some(prop) = prop else {
441 return String::new();
442 };
443 let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
444 let array_key = match ptype {
445 "title" => "title",
446 "rich_text" => "rich_text",
447 _ => return String::new(),
448 };
449 prop.get(array_key)
450 .and_then(|arr| arr.as_array())
451 .map(|items| {
452 items
453 .iter()
454 .filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
455 .collect::<Vec<_>>()
456 .join("")
457 })
458 .unwrap_or_default()
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel with placeholder settings and the given concurrency limit.
    fn channel_with_limit(max_concurrent: usize) -> NotionChannel {
        NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            max_concurrent,
            false,
        )
    }

    /// Asserts that `text` fits the limit and carries the truncation marker.
    fn assert_truncated(text: &str) {
        assert!(text.len() <= MAX_RESULT_LENGTH);
        assert!(text.ends_with("... [output truncated]"));
    }

    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = channel_with_limit(4);

        // First claim wins, a repeated claim for the same page is refused.
        assert!(channel.claim_task("page-1").await);
        assert!(!channel.claim_task("page-1").await);
        // A different page is unaffected.
        assert!(channel.claim_task("page-2").await);

        // Releasing makes the page claimable again.
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }

    #[test]
    fn result_truncation_within_limit() {
        let short = "hello world";
        assert_eq!(truncate_result(short), short);
    }

    #[test]
    fn result_truncation_over_limit() {
        let long = "a".repeat(MAX_RESULT_LENGTH + 100);
        assert_truncated(&truncate_result(&long));
    }

    #[test]
    fn result_truncation_multibyte_safe() {
        // 700 three-byte characters: 2100 bytes, over the limit.
        let s = "\u{6E2C}".repeat(700);
        assert_truncated(&truncate_result(&s));
    }

    #[test]
    fn status_payload_select_type() {
        let expected = serde_json::json!({ "select": { "name": "pending" } });
        assert_eq!(build_status_payload("select", "pending"), expected);
    }

    #[test]
    fn status_payload_status_type() {
        let expected = serde_json::json!({ "status": { "name": "done" } });
        assert_eq!(build_status_payload("status", "done"), expected);
    }

    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert_eq!(text, "test output");
    }

    #[test]
    fn status_filter_select_type() {
        let expected = serde_json::json!({
            "property": "Status",
            "select": { "equals": "pending" }
        });
        assert_eq!(build_status_filter("Status", "select", "pending"), expected);
    }

    #[test]
    fn status_filter_status_type() {
        let expected = serde_json::json!({
            "property": "Status",
            "status": { "equals": "running" }
        });
        assert_eq!(build_status_filter("Status", "status", "running"), expected);
    }

    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }

    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }

    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }

    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = channel_with_limit(2);

        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        // Limit reached: a third page is refused.
        assert!(!channel.claim_task("page-3").await);

        // Freeing one slot lets the third page in.
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}