Skip to main content

construct/channels/
notion.rs

1use super::traits::{Channel, ChannelMessage, SendMessage};
2use anyhow::{Result, bail};
3use async_trait::async_trait;
4use std::collections::HashSet;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Base URL shared by every Notion REST endpoint this channel calls.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` header on every request.
const NOTION_VERSION: &str = "2022-06-28";
/// Longest result text, measured in bytes, written into a rich-text cell before truncation.
const MAX_RESULT_LENGTH: usize = 2_000;
/// Total number of attempts `api_call` makes before giving up on a retryable failure.
const MAX_RETRIES: u32 = 3;
/// Backoff before the first retry; it doubles with every further attempt.
const RETRY_BASE_DELAY_MS: u64 = 2_000;
/// Maximum number of characters to include from an error response body.
const MAX_ERROR_BODY_CHARS: usize = 500;
15
/// Return the greatest byte offset `<= max_bytes` at which `s` can be sliced
/// without splitting a UTF-8 encoded character.
///
/// Offsets at or past the end of `s` are clamped to `s.len()`.
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
    if s.len() <= max_bytes {
        s.len()
    } else {
        // Offset 0 is always a char boundary, so the search never comes up empty.
        (0..=max_bytes)
            .rev()
            .find(|&offset| s.is_char_boundary(offset))
            .unwrap_or(0)
    }
}
27
28/// Notion channel — polls a Notion database for pending tasks and writes results back.
29///
30/// The channel connects to the Notion API, queries a database for rows with a "pending"
31/// status, dispatches them as channel messages, and writes results back when processing
32/// completes. It supports crash recovery by resetting stale "running" tasks on startup.
33pub struct NotionChannel {
34    api_key: String,
35    database_id: String,
36    poll_interval_secs: u64,
37    status_property: String,
38    input_property: String,
39    result_property: String,
40    max_concurrent: usize,
41    status_type: Arc<RwLock<String>>,
42    inflight: Arc<RwLock<HashSet<String>>>,
43    http: reqwest::Client,
44    recover_stale: bool,
45}
46
47impl NotionChannel {
48    /// Create a new Notion channel with the given configuration.
49    pub fn new(
50        api_key: String,
51        database_id: String,
52        poll_interval_secs: u64,
53        status_property: String,
54        input_property: String,
55        result_property: String,
56        max_concurrent: usize,
57        recover_stale: bool,
58    ) -> Self {
59        Self {
60            api_key,
61            database_id,
62            poll_interval_secs,
63            status_property,
64            input_property,
65            result_property,
66            max_concurrent,
67            status_type: Arc::new(RwLock::new("select".to_string())),
68            inflight: Arc::new(RwLock::new(HashSet::new())),
69            http: reqwest::Client::new(),
70            recover_stale,
71        }
72    }
73
74    /// Build the standard Notion API headers (Authorization, version, content-type).
75    fn headers(&self) -> Result<reqwest::header::HeaderMap> {
76        let mut headers = reqwest::header::HeaderMap::new();
77        headers.insert(
78            "Authorization",
79            format!("Bearer {}", self.api_key)
80                .parse()
81                .map_err(|e| anyhow::anyhow!("Invalid Notion API key header value: {e}"))?,
82        );
83        headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
84        headers.insert("Content-Type", "application/json".parse().unwrap());
85        Ok(headers)
86    }
87
88    /// Make a Notion API call with automatic retry on rate-limit (429) and server errors (5xx).
89    async fn api_call(
90        &self,
91        method: reqwest::Method,
92        url: &str,
93        body: Option<serde_json::Value>,
94    ) -> Result<serde_json::Value> {
95        let mut last_err = None;
96        for attempt in 0..MAX_RETRIES {
97            let mut req = self
98                .http
99                .request(method.clone(), url)
100                .headers(self.headers()?);
101            if let Some(ref b) = body {
102                req = req.json(b);
103            }
104            match req.send().await {
105                Ok(resp) => {
106                    let status = resp.status();
107                    if status.is_success() {
108                        return resp
109                            .json()
110                            .await
111                            .map_err(|e| anyhow::anyhow!("Failed to parse response: {e}"));
112                    }
113                    let status_code = status.as_u16();
114                    // Only retry on 429 (rate limit) or 5xx (server errors)
115                    if status_code != 429 && (400..500).contains(&status_code) {
116                        let body_text = resp.text().await.unwrap_or_default();
117                        let truncated =
118                            crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
119                        bail!("Notion API error {status_code}: {truncated}");
120                    }
121                    last_err = Some(anyhow::anyhow!("Notion API error: {status_code}"));
122                }
123                Err(e) => {
124                    last_err = Some(anyhow::anyhow!("HTTP request failed: {e}"));
125                }
126            }
127            let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
128            tracing::warn!(
129                "Notion API call failed (attempt {}/{}), retrying in {}ms",
130                attempt + 1,
131                MAX_RETRIES,
132                delay
133            );
134            tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
135        }
136        Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Notion API call failed after retries")))
137    }
138
139    /// Query the database schema and detect whether Status uses "select" or "status" type.
140    async fn detect_status_type(&self) -> Result<String> {
141        let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
142        let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
143        let status_type = resp
144            .get("properties")
145            .and_then(|p| p.get(&self.status_property))
146            .and_then(|s| s.get("type"))
147            .and_then(|t| t.as_str())
148            .unwrap_or("select")
149            .to_string();
150        Ok(status_type)
151    }
152
153    /// Query for rows where Status = "pending".
154    async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
155        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
156        let status_type = self.status_type.read().await.clone();
157        let filter = build_status_filter(&self.status_property, &status_type, "pending");
158        let resp = self
159            .api_call(
160                reqwest::Method::POST,
161                &url,
162                Some(serde_json::json!({ "filter": filter })),
163            )
164            .await?;
165        Ok(resp
166            .get("results")
167            .and_then(|r| r.as_array())
168            .cloned()
169            .unwrap_or_default())
170    }
171
172    /// Atomically claim a task. Returns true if this caller got it.
173    async fn claim_task(&self, page_id: &str) -> bool {
174        let mut inflight = self.inflight.write().await;
175        if inflight.contains(page_id) {
176            return false;
177        }
178        if inflight.len() >= self.max_concurrent {
179            return false;
180        }
181        inflight.insert(page_id.to_string());
182        true
183    }
184
185    /// Release a task from the inflight set.
186    async fn release_task(&self, page_id: &str) {
187        let mut inflight = self.inflight.write().await;
188        inflight.remove(page_id);
189    }
190
191    /// Update a row's status.
192    async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
193        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
194        let status_type = self.status_type.read().await.clone();
195        let payload = serde_json::json!({
196            "properties": {
197                &self.status_property: build_status_payload(&status_type, status_value),
198            }
199        });
200        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
201            .await?;
202        Ok(())
203    }
204
205    /// Write result text to the Result column.
206    async fn set_result(&self, page_id: &str, result_text: &str) -> Result<()> {
207        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
208        let payload = serde_json::json!({
209            "properties": {
210                &self.result_property: build_rich_text_payload(result_text),
211            }
212        });
213        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
214            .await?;
215        Ok(())
216    }
217
218    /// On startup, reset "running" tasks back to "pending" for crash recovery.
219    async fn recover_stale(&self) -> Result<()> {
220        let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
221        let status_type = self.status_type.read().await.clone();
222        let filter = build_status_filter(&self.status_property, &status_type, "running");
223        let resp = self
224            .api_call(
225                reqwest::Method::POST,
226                &url,
227                Some(serde_json::json!({ "filter": filter })),
228            )
229            .await?;
230        let stale = resp
231            .get("results")
232            .and_then(|r| r.as_array())
233            .cloned()
234            .unwrap_or_default();
235        if stale.is_empty() {
236            return Ok(());
237        }
238        tracing::warn!(
239            "Found {} stale task(s) in 'running' state, resetting to 'pending'",
240            stale.len()
241        );
242        for task in &stale {
243            if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
244                let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
245                let payload = serde_json::json!({
246                    "properties": {
247                        &self.status_property: build_status_payload(&status_type, "pending"),
248                        &self.result_property: build_rich_text_payload(
249                            "Reset: poller restarted while task was running"
250                        ),
251                    }
252                });
253                let short_id_end = floor_utf8_char_boundary(page_id, 8);
254                let short_id = &page_id[..short_id_end];
255                if let Err(e) = self
256                    .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
257                    .await
258                {
259                    tracing::error!("Could not reset stale task {short_id}: {e}");
260                } else {
261                    tracing::info!("Reset stale task {short_id} to pending");
262                }
263            }
264        }
265        Ok(())
266    }
267}
268
269#[async_trait]
270impl Channel for NotionChannel {
271    fn name(&self) -> &str {
272        "notion"
273    }
274
275    async fn send(&self, message: &SendMessage) -> Result<()> {
276        // recipient is the page_id for Notion
277        let page_id = &message.recipient;
278        let status_type = self.status_type.read().await.clone();
279        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
280        let payload = serde_json::json!({
281            "properties": {
282                &self.status_property: build_status_payload(&status_type, "done"),
283                &self.result_property: build_rich_text_payload(&message.content),
284            }
285        });
286        self.api_call(reqwest::Method::PATCH, &url, Some(payload))
287            .await?;
288        self.release_task(page_id).await;
289        Ok(())
290    }
291
292    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
293        // Detect status property type
294        match self.detect_status_type().await {
295            Ok(st) => {
296                tracing::info!("Notion status property type: {st}");
297                *self.status_type.write().await = st;
298            }
299            Err(e) => {
300                bail!("Failed to detect Notion database schema: {e}");
301            }
302        }
303
304        // Crash recovery
305        if self.recover_stale {
306            if let Err(e) = self.recover_stale().await {
307                tracing::error!("Notion stale task recovery failed: {e}");
308            }
309        }
310
311        // Polling loop
312        loop {
313            match self.query_pending().await {
314                Ok(tasks) => {
315                    if !tasks.is_empty() {
316                        tracing::info!("Notion: found {} pending task(s)", tasks.len());
317                    }
318                    for task in tasks {
319                        let page_id = match task.get("id").and_then(|v| v.as_str()) {
320                            Some(id) => id.to_string(),
321                            None => continue,
322                        };
323
324                        let input_text = extract_text_from_property(
325                            task.get("properties")
326                                .and_then(|p| p.get(&self.input_property)),
327                        );
328
329                        if input_text.trim().is_empty() {
330                            let short_end = floor_utf8_char_boundary(&page_id, 8);
331                            tracing::warn!(
332                                "Notion: empty input for task {}, skipping",
333                                &page_id[..short_end]
334                            );
335                            continue;
336                        }
337
338                        if !self.claim_task(&page_id).await {
339                            continue;
340                        }
341
342                        // Set status to running
343                        if let Err(e) = self.set_status(&page_id, "running").await {
344                            tracing::error!("Notion: failed to set running status: {e}");
345                            self.release_task(&page_id).await;
346                            continue;
347                        }
348
349                        let timestamp = std::time::SystemTime::now()
350                            .duration_since(std::time::UNIX_EPOCH)
351                            .unwrap_or_default()
352                            .as_secs();
353
354                        if tx
355                            .send(ChannelMessage {
356                                id: page_id.clone(),
357                                sender: "notion".into(),
358                                reply_target: page_id,
359                                content: input_text,
360                                channel: "notion".into(),
361                                timestamp,
362                                thread_ts: None,
363                                interruption_scope_id: None,
364                                attachments: vec![],
365                            })
366                            .await
367                            .is_err()
368                        {
369                            tracing::info!("Notion channel shutting down");
370                            return Ok(());
371                        }
372                    }
373                }
374                Err(e) => {
375                    tracing::error!("Notion poll error: {e}");
376                }
377            }
378
379            tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
380        }
381    }
382
383    async fn health_check(&self) -> bool {
384        let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
385        self.api_call(reqwest::Method::GET, &url, None)
386            .await
387            .is_ok()
388    }
389}
390
391// ── Helper functions ──────────────────────────────────────────────
392
393/// Build a Notion API filter object for the given status property.
394fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
395    if status_type == "status" {
396        serde_json::json!({
397            "property": property,
398            "status": { "equals": value }
399        })
400    } else {
401        serde_json::json!({
402            "property": property,
403            "select": { "equals": value }
404        })
405    }
406}
407
408/// Build a Notion API property-update payload for a status field.
409fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
410    if status_type == "status" {
411        serde_json::json!({ "status": { "name": value } })
412    } else {
413        serde_json::json!({ "select": { "name": value } })
414    }
415}
416
417/// Build a Notion API rich-text property payload, truncating if necessary.
418fn build_rich_text_payload(value: &str) -> serde_json::Value {
419    let truncated = truncate_result(value);
420    serde_json::json!({
421        "rich_text": [{
422            "text": { "content": truncated }
423        }]
424    })
425}
426
427/// Truncate result text to fit within the Notion rich-text content limit.
428fn truncate_result(value: &str) -> String {
429    if value.len() <= MAX_RESULT_LENGTH {
430        return value.to_string();
431    }
432    let cut = MAX_RESULT_LENGTH.saturating_sub(30);
433    // Ensure we cut on a char boundary
434    let end = floor_utf8_char_boundary(value, cut);
435    format!("{}\n\n... [output truncated]", &value[..end])
436}
437
438/// Extract plain text from a Notion property (title or rich_text type).
439fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
440    let Some(prop) = prop else {
441        return String::new();
442    };
443    let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
444    let array_key = match ptype {
445        "title" => "title",
446        "rich_text" => "rich_text",
447        _ => return String::new(),
448    };
449    prop.get(array_key)
450        .and_then(|arr| arr.as_array())
451        .map(|items| {
452            items
453                .iter()
454                .filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
455                .collect::<Vec<_>>()
456                .join("")
457        })
458        .unwrap_or_default()
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            4,
            false,
        );

        assert!(channel.claim_task("page-1").await);
        // Second claim for same page should fail
        assert!(!channel.claim_task("page-1").await);
        // Different page should succeed
        assert!(channel.claim_task("page-2").await);

        // After release, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }

    #[test]
    fn floor_char_boundary_ascii_and_multibyte() {
        assert_eq!(floor_utf8_char_boundary("hello", 3), 3);
        assert_eq!(floor_utf8_char_boundary("hello", 99), 5);
        // 'é' occupies bytes 1..3, so byte 2 is mid-character.
        assert_eq!(floor_utf8_char_boundary("héllo", 2), 1);
        assert_eq!(floor_utf8_char_boundary("", 0), 0);
        assert_eq!(floor_utf8_char_boundary("€", 2), 0);
    }

    #[test]
    fn result_truncation_within_limit() {
        let short = "hello world";
        assert_eq!(truncate_result(short), short);
    }

    #[test]
    fn result_truncation_exactly_at_limit() {
        let exact = "a".repeat(MAX_RESULT_LENGTH);
        assert_eq!(truncate_result(&exact), exact);
    }

    #[test]
    fn result_truncation_over_limit() {
        let long = "a".repeat(MAX_RESULT_LENGTH + 100);
        let truncated = truncate_result(&long);
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }

    #[test]
    fn result_truncation_multibyte_safe() {
        // Build a string that would cut in the middle of a multibyte char
        let mut s = String::new();
        for _ in 0..700 {
            s.push('\u{6E2C}'); // 3-byte UTF-8 char
        }
        let truncated = truncate_result(&s);
        // Slicing mid-character would have panicked; the result must fit the limit.
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with("... [output truncated]"));
    }

    #[test]
    fn status_payload_select_type() {
        let payload = build_status_payload("select", "pending");
        assert_eq!(
            payload,
            serde_json::json!({ "select": { "name": "pending" } })
        );
    }

    #[test]
    fn status_payload_status_type() {
        let payload = build_status_payload("status", "done");
        assert_eq!(payload, serde_json::json!({ "status": { "name": "done" } }));
    }

    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert_eq!(text, "test output");
    }

    #[test]
    fn rich_text_payload_truncates_long_input() {
        let payload = build_rich_text_payload(&"x".repeat(MAX_RESULT_LENGTH + 1));
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert!(text.len() <= MAX_RESULT_LENGTH);
        assert!(text.ends_with("... [output truncated]"));
    }

    #[test]
    fn status_filter_select_type() {
        let filter = build_status_filter("Status", "select", "pending");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "select": { "equals": "pending" }
            })
        );
    }

    #[test]
    fn status_filter_status_type() {
        let filter = build_status_filter("Status", "status", "running");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "status": { "equals": "running" }
            })
        );
    }

    #[test]
    fn status_filter_unknown_type_falls_back_to_select() {
        let filter = build_status_filter("Status", "multi_select", "pending");
        assert_eq!(
            filter,
            serde_json::json!({
                "property": "Status",
                "select": { "equals": "pending" }
            })
        );
    }

    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }

    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }

    #[test]
    fn extract_text_from_property_missing_array() {
        let prop = serde_json::json!({ "type": "title" });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }

    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = NotionChannel::new(
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            2, // max_concurrent = 2
            false,
        );

        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        // Third claim should be rejected (at capacity)
        assert!(!channel.claim_task("page-3").await);

        // After releasing one, can claim again
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}