Skip to main content

brainwires_mcp_server/
tasks.rs

1//! MCP Tasks primitive (SEP-1686).
2//!
3//! Provides a standardised lifecycle for long-running asynchronous tool calls:
4//!
5//! ```text
6//! Working → Completed
7//!          ↘ Failed
8//!          ↘ Cancelled
9//! Working → InputRequired → Working (loop)
10//! ```
11//!
12//! [`McpTaskStore`] is a thread-safe in-memory store with optional TTL-based
13//! expiry. Wire it into [`McpServer`](crate::McpServer) to expose
14//! `tasks/create`, `tasks/get`, and `tasks/cancel` JSON-RPC methods.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use tokio::sync::RwLock;
21use uuid::Uuid;
22
23/// Default maximum number of retries before a task transitions to `Failed`.
24pub const DEFAULT_MAX_RETRIES: u32 = 3;
25
26/// All possible states in the MCP task lifecycle.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub enum McpTaskState {
29    /// Task is actively running.
30    Working,
31    /// Task is paused waiting for additional input from the caller.
32    InputRequired,
33    /// Task finished successfully.
34    Completed,
35    /// Task finished with an error.
36    Failed,
37    /// Task was explicitly cancelled by the caller.
38    Cancelled,
39}
40
41impl std::fmt::Display for McpTaskState {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        let s = match self {
44            Self::Working => "working",
45            Self::InputRequired => "input_required",
46            Self::Completed => "completed",
47            Self::Failed => "failed",
48            Self::Cancelled => "cancelled",
49        };
50        write!(f, "{}", s)
51    }
52}
53
54/// A single MCP task entry.
55#[derive(Debug, Clone)]
56pub struct McpTask {
57    /// Unique task identifier (UUID v4).
58    pub id: String,
59    /// Current lifecycle state.
60    pub state: McpTaskState,
61    /// Wall-clock creation time.
62    pub created_at: Instant,
63    /// When this task entry expires and will be evicted (if set).
64    pub expires_at: Option<Instant>,
65    /// Result payload for completed tasks.
66    pub result: Option<serde_json::Value>,
67    /// Error message for failed tasks.
68    pub error: Option<String>,
69    /// Number of times execution has been retried.
70    pub retry_count: u32,
71    /// Maximum allowed retries before the task automatically fails.
72    pub max_retries: u32,
73}
74
75impl McpTask {
76    /// Create a new task in the `Working` state.
77    pub fn new() -> Self {
78        Self {
79            id: Uuid::new_v4().to_string(),
80            state: McpTaskState::Working,
81            created_at: Instant::now(),
82            expires_at: None,
83            result: None,
84            error: None,
85            retry_count: 0,
86            max_retries: DEFAULT_MAX_RETRIES,
87        }
88    }
89
90    /// Set a TTL so the store evicts this task after the given duration.
91    pub fn with_ttl(mut self, ttl: Duration) -> Self {
92        self.expires_at = Some(Instant::now() + ttl);
93        self
94    }
95
96    /// Whether this task entry has expired.
97    pub fn is_expired(&self) -> bool {
98        self.expires_at
99            .map(|exp| Instant::now() >= exp)
100            .unwrap_or(false)
101    }
102}
103
104impl Default for McpTask {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110/// Thread-safe in-memory store for [`McpTask`] entries.
111///
112/// Spawn with [`McpTaskStore::new`]; wire JSON-RPC dispatch into
113/// [`McpServer`](crate::McpServer) by calling [`insert`], [`get`], and
114/// [`cancel`] from your handler's `tasks/*` method implementations.
115///
116/// [`insert`]: McpTaskStore::insert
117/// [`get`]: McpTaskStore::get
118/// [`cancel`]: McpTaskStore::cancel
119#[derive(Clone)]
120pub struct McpTaskStore {
121    inner: Arc<RwLock<HashMap<String, McpTask>>>,
122}
123
124impl McpTaskStore {
125    /// Create a new empty store.
126    pub fn new() -> Self {
127        Self {
128            inner: Arc::new(RwLock::new(HashMap::new())),
129        }
130    }
131
132    /// Insert a task and return its ID.
133    pub async fn insert(&self, task: McpTask) -> String {
134        let id = task.id.clone();
135        self.inner.write().await.insert(id.clone(), task);
136        id
137    }
138
139    /// Retrieve a task by ID, returning `None` if not found or expired.
140    pub async fn get(&self, id: &str) -> Option<McpTask> {
141        let map = self.inner.read().await;
142        let task = map.get(id)?;
143        if task.is_expired() {
144            None
145        } else {
146            Some(task.clone())
147        }
148    }
149
150    /// Transition a task to `Cancelled`. Returns `false` if the task is not
151    /// found, expired, or already in a terminal state.
152    pub async fn cancel(&self, id: &str) -> bool {
153        let mut map = self.inner.write().await;
154        match map.get_mut(id) {
155            Some(task)
156                if !task.is_expired()
157                    && !matches!(
158                        task.state,
159                        McpTaskState::Completed | McpTaskState::Failed | McpTaskState::Cancelled
160                    ) =>
161            {
162                task.state = McpTaskState::Cancelled;
163                true
164            }
165            _ => false,
166        }
167    }
168
169    /// Update the state of a task. Returns `false` if the task is not found or expired.
170    pub async fn update_state(&self, id: &str, state: McpTaskState) -> bool {
171        let mut map = self.inner.write().await;
172        match map.get_mut(id) {
173            Some(task) if !task.is_expired() => {
174                task.state = state;
175                true
176            }
177            _ => false,
178        }
179    }
180
181    /// Set a completed result on a task, transitioning it to `Completed`.
182    pub async fn complete(&self, id: &str, result: serde_json::Value) -> bool {
183        let mut map = self.inner.write().await;
184        match map.get_mut(id) {
185            Some(task) if !task.is_expired() => {
186                task.state = McpTaskState::Completed;
187                task.result = Some(result);
188                true
189            }
190            _ => false,
191        }
192    }
193
194    /// Fail a task with an error message, transitioning it to `Failed`.
195    pub async fn fail(&self, id: &str, error: impl Into<String>) -> bool {
196        let mut map = self.inner.write().await;
197        match map.get_mut(id) {
198            Some(task) if !task.is_expired() => {
199                task.state = McpTaskState::Failed;
200                task.error = Some(error.into());
201                true
202            }
203            _ => false,
204        }
205    }
206
207    /// Evict all expired task entries.
208    pub async fn evict_expired(&self) -> usize {
209        let mut map = self.inner.write().await;
210        let before = map.len();
211        map.retain(|_, task| !task.is_expired());
212        before - map.len()
213    }
214
215    /// Return the number of tasks currently in the store (including expired).
216    pub async fn len(&self) -> usize {
217        self.inner.read().await.len()
218    }
219
220    /// Return `true` if the store has no tasks.
221    pub async fn is_empty(&self) -> bool {
222        self.inner.read().await.is_empty()
223    }
224}
225
226impl Default for McpTaskStore {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232#[cfg(test)]
233mod tests {
234    use super::*;
235
236    #[tokio::test]
237    async fn test_task_lifecycle_working_to_completed() {
238        let store = McpTaskStore::new();
239        let task = McpTask::new();
240        let id = store.insert(task).await;
241
242        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Working);
243        store.complete(&id, serde_json::json!({"ok": true})).await;
244        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Completed);
245    }
246
247    #[tokio::test]
248    async fn test_task_lifecycle_working_to_failed() {
249        let store = McpTaskStore::new();
250        let id = store.insert(McpTask::new()).await;
251        store.fail(&id, "timeout").await;
252        let task = store.get(&id).await.unwrap();
253        assert_eq!(task.state, McpTaskState::Failed);
254        assert_eq!(task.error.as_deref(), Some("timeout"));
255    }
256
257    #[tokio::test]
258    async fn test_task_lifecycle_working_to_cancelled() {
259        let store = McpTaskStore::new();
260        let id = store.insert(McpTask::new()).await;
261        assert!(store.cancel(&id).await);
262        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Cancelled);
263    }
264
265    #[tokio::test]
266    async fn test_cancel_terminal_task_returns_false() {
267        let store = McpTaskStore::new();
268        let id = store.insert(McpTask::new()).await;
269        store.complete(&id, serde_json::json!({})).await;
270        // Already completed — cancel should return false
271        assert!(!store.cancel(&id).await);
272    }
273
274    #[tokio::test]
275    async fn test_input_required_state() {
276        let store = McpTaskStore::new();
277        let id = store.insert(McpTask::new()).await;
278        store.update_state(&id, McpTaskState::InputRequired).await;
279        assert_eq!(
280            store.get(&id).await.unwrap().state,
281            McpTaskState::InputRequired
282        );
283    }
284
285    #[tokio::test]
286    async fn test_ttl_expiry_eviction() {
287        let store = McpTaskStore::new();
288        let task = McpTask::new().with_ttl(Duration::from_millis(1));
289        let id = store.insert(task).await;
290
291        // Wait for TTL to elapse
292        tokio::time::sleep(Duration::from_millis(5)).await;
293
294        // get() returns None for expired task
295        assert!(store.get(&id).await.is_none());
296
297        // evict_expired cleans up the backing map
298        let evicted = store.evict_expired().await;
299        assert_eq!(evicted, 1);
300        assert_eq!(store.len().await, 0);
301    }
302}