//! `brainwires_mcp_server/tasks.rs`: in-memory MCP task tracking with TTL-based expiry.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::RwLock;
use uuid::Uuid;
/// Retry budget that [`McpTask::new`] assigns to [`McpTask::max_retries`].
pub const DEFAULT_MAX_RETRIES: u32 = 3;
/// Lifecycle state of an [`McpTask`].
///
/// `Working` and `InputRequired` are in-progress states; `Completed`, `Failed` and
/// `Cancelled` are terminal (see [`McpTaskState::is_terminal`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum McpTaskState {
    /// The task is in progress.
    Working,
    /// The task is waiting for input before it can continue.
    InputRequired,
    /// The task finished successfully.
    Completed,
    /// The task finished with an error.
    Failed,
    /// The task was cancelled.
    Cancelled,
}

impl McpTaskState {
    /// Returns `true` for the terminal states `Completed`, `Failed` and `Cancelled`.
    #[must_use]
    pub const fn is_terminal(self) -> bool {
        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
    }

    /// Lower-case snake_case name of the state, as emitted by the `Display` impl.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Working => "working",
            Self::InputRequired => "input_required",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Cancelled => "cancelled",
        }
    }
}

impl std::fmt::Display for McpTaskState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
54#[derive(Debug, Clone)]
56pub struct McpTask {
57 pub id: String,
59 pub state: McpTaskState,
61 pub created_at: Instant,
63 pub expires_at: Option<Instant>,
65 pub result: Option<serde_json::Value>,
67 pub error: Option<String>,
69 pub retry_count: u32,
71 pub max_retries: u32,
73}
74
75impl McpTask {
76 pub fn new() -> Self {
78 Self {
79 id: Uuid::new_v4().to_string(),
80 state: McpTaskState::Working,
81 created_at: Instant::now(),
82 expires_at: None,
83 result: None,
84 error: None,
85 retry_count: 0,
86 max_retries: DEFAULT_MAX_RETRIES,
87 }
88 }
89
90 pub fn with_ttl(mut self, ttl: Duration) -> Self {
92 self.expires_at = Some(Instant::now() + ttl);
93 self
94 }
95
96 pub fn is_expired(&self) -> bool {
98 self.expires_at
99 .map(|exp| Instant::now() >= exp)
100 .unwrap_or(false)
101 }
102}
103
104impl Default for McpTask {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110#[derive(Clone)]
120pub struct McpTaskStore {
121 inner: Arc<RwLock<HashMap<String, McpTask>>>,
122}
123
124impl McpTaskStore {
125 pub fn new() -> Self {
127 Self {
128 inner: Arc::new(RwLock::new(HashMap::new())),
129 }
130 }
131
132 pub async fn insert(&self, task: McpTask) -> String {
134 let id = task.id.clone();
135 self.inner.write().await.insert(id.clone(), task);
136 id
137 }
138
139 pub async fn get(&self, id: &str) -> Option<McpTask> {
141 let map = self.inner.read().await;
142 let task = map.get(id)?;
143 if task.is_expired() {
144 None
145 } else {
146 Some(task.clone())
147 }
148 }
149
150 pub async fn cancel(&self, id: &str) -> bool {
153 let mut map = self.inner.write().await;
154 match map.get_mut(id) {
155 Some(task)
156 if !task.is_expired()
157 && !matches!(
158 task.state,
159 McpTaskState::Completed | McpTaskState::Failed | McpTaskState::Cancelled
160 ) =>
161 {
162 task.state = McpTaskState::Cancelled;
163 true
164 }
165 _ => false,
166 }
167 }
168
169 pub async fn update_state(&self, id: &str, state: McpTaskState) -> bool {
171 let mut map = self.inner.write().await;
172 match map.get_mut(id) {
173 Some(task) if !task.is_expired() => {
174 task.state = state;
175 true
176 }
177 _ => false,
178 }
179 }
180
181 pub async fn complete(&self, id: &str, result: serde_json::Value) -> bool {
183 let mut map = self.inner.write().await;
184 match map.get_mut(id) {
185 Some(task) if !task.is_expired() => {
186 task.state = McpTaskState::Completed;
187 task.result = Some(result);
188 true
189 }
190 _ => false,
191 }
192 }
193
194 pub async fn fail(&self, id: &str, error: impl Into<String>) -> bool {
196 let mut map = self.inner.write().await;
197 match map.get_mut(id) {
198 Some(task) if !task.is_expired() => {
199 task.state = McpTaskState::Failed;
200 task.error = Some(error.into());
201 true
202 }
203 _ => false,
204 }
205 }
206
207 pub async fn evict_expired(&self) -> usize {
209 let mut map = self.inner.write().await;
210 let before = map.len();
211 map.retain(|_, task| !task.is_expired());
212 before - map.len()
213 }
214
215 pub async fn len(&self) -> usize {
217 self.inner.read().await.len()
218 }
219
220 pub async fn is_empty(&self) -> bool {
222 self.inner.read().await.is_empty()
223 }
224}
225
226impl Default for McpTaskStore {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_task_lifecycle_working_to_completed() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;

        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Working);
        assert!(store.complete(&id, serde_json::json!({"ok": true})).await);

        let task = store.get(&id).await.unwrap();
        assert_eq!(task.state, McpTaskState::Completed);
        assert_eq!(task.result, Some(serde_json::json!({"ok": true})));
        assert!(task.error.is_none());
    }

    #[tokio::test]
    async fn test_task_lifecycle_working_to_failed() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.fail(&id, "timeout").await);

        let task = store.get(&id).await.unwrap();
        assert_eq!(task.state, McpTaskState::Failed);
        assert_eq!(task.error.as_deref(), Some("timeout"));
        assert!(task.result.is_none());
    }

    #[tokio::test]
    async fn test_task_lifecycle_working_to_cancelled() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.cancel(&id).await);
        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_terminal_task_returns_false() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.complete(&id, serde_json::json!({})).await);
        assert!(!store.cancel(&id).await);
        // A refused cancel must leave the terminal state untouched.
        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Completed);
    }

    #[tokio::test]
    async fn test_cancel_twice_returns_false() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.cancel(&id).await);
        assert!(!store.cancel(&id).await);
    }

    #[tokio::test]
    async fn test_input_required_state() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.update_state(&id, McpTaskState::InputRequired).await);
        assert_eq!(
            store.get(&id).await.unwrap().state,
            McpTaskState::InputRequired
        );
    }

    #[tokio::test]
    async fn test_unknown_id_is_rejected() {
        let store = McpTaskStore::new();
        assert!(store.get("missing").await.is_none());
        assert!(!store.cancel("missing").await);
        assert!(!store.update_state("missing", McpTaskState::Working).await);
        assert!(!store.complete("missing", serde_json::json!(null)).await);
        assert!(!store.fail("missing", "boom").await);
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_ttl_expiry_eviction() {
        let store = McpTaskStore::new();
        let task = McpTask::new().with_ttl(Duration::from_millis(1));
        let id = store.insert(task).await;

        // `sleep` waits at least this long, comfortably past the 1 ms TTL.
        tokio::time::sleep(Duration::from_millis(5)).await;

        assert!(store.get(&id).await.is_none());
        // Expired tasks reject mutation too, not just lookup...
        assert!(!store.cancel(&id).await);
        // ...but remain in the map until explicitly evicted.
        assert_eq!(store.len().await, 1);

        assert_eq!(store.evict_expired().await, 1);
        assert_eq!(store.len().await, 0);
        assert!(store.is_empty().await);
    }

    #[test]
    fn test_task_without_ttl_never_expires() {
        let task = McpTask::new();
        assert!(task.expires_at.is_none());
        assert!(!task.is_expired());
    }

    #[test]
    fn test_state_display() {
        assert_eq!(McpTaskState::Working.to_string(), "working");
        assert_eq!(McpTaskState::InputRequired.to_string(), "input_required");
        assert_eq!(McpTaskState::Completed.to_string(), "completed");
        assert_eq!(McpTaskState::Failed.to_string(), "failed");
        assert_eq!(McpTaskState::Cancelled.to_string(), "cancelled");
    }
}