use crate::types::{Artifact, Message, PushNotificationConfig, Task, TaskState, TaskStatus};
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
/// A stored [`Task`] plus the store-side bookkeeping attached to it.
#[derive(Clone)]
pub struct TaskRecord {
    /// The task itself.
    pub task: Task,
    /// Push-notification configurations registered for this task, keyed by config id.
    pub push_configs: HashMap<String, PushNotificationConfig>,
    /// Whether a cancellation has been requested for this task; starts out `false`.
    pub cancel_requested: bool,
}

impl TaskRecord {
    /// Wraps `task` in a fresh record with no push configs and no pending cancellation.
    pub fn new(task: Task) -> Self {
        let push_configs = HashMap::new();
        Self {
            cancel_requested: false,
            push_configs,
            task,
        }
    }
}
/// Async persistence interface for tasks and their push-notification configs.
///
/// Lookups by id return `Option`; [`InMemoryTaskStore`] returns `None` whenever no
/// record exists for the given id.
#[async_trait::async_trait]
pub trait TaskStore: Send + Sync {
    /// Stores `task` as a new record and returns that record.
    async fn create(&self, task: Task) -> TaskRecord;

    /// Fetches a snapshot of the record for `task_id`.
    async fn get(&self, task_id: &str) -> Option<TaskRecord>;

    /// Requests a status change to `state` and returns the resulting record.
    ///
    /// Implementations may refuse to leave a terminal state (the in-memory store does).
    async fn update_state(&self, task_id: &str, state: TaskState) -> Option<TaskRecord>;

    /// Appends `artifacts` to the task and returns the resulting record.
    async fn append_artifacts(
        &self,
        task_id: &str,
        artifacts: Vec<Artifact>,
    ) -> Option<TaskRecord>;

    /// Appends `message` to the task's history and returns the resulting record.
    async fn append_history(&self, task_id: &str, message: Message) -> Option<TaskRecord>;

    /// Returns the records selected by `filter`.
    async fn list(&self, filter: ListFilter) -> Vec<TaskRecord>;

    /// Marks the task as cancel-requested and returns the resulting record.
    async fn request_cancel(&self, task_id: &str) -> Option<TaskRecord>;

    /// Registers (or replaces) the push config `config_id` on the task.
    async fn add_push_config(
        &self,
        task_id: &str,
        config_id: String,
        config: PushNotificationConfig,
    ) -> Option<TaskRecord>;

    /// Removes the push config `config_id` from the task, if present.
    async fn remove_push_config(&self, task_id: &str, config_id: &str) -> Option<TaskRecord>;

    /// Fetches a single push config of the task.
    async fn get_push_config(
        &self,
        task_id: &str,
        config_id: &str,
    ) -> Option<PushNotificationConfig>;

    /// Lists all `(config_id, config)` pairs registered on the task.
    async fn list_push_configs(&self, task_id: &str) -> Vec<(String, PushNotificationConfig)>;
}
/// Selection criteria for [`TaskStore::list`]; a `None` field does not restrict the result.
#[derive(Clone, Debug, Default)]
pub struct ListFilter {
    /// Keep only tasks whose `context_id` equals this value.
    pub context_id: Option<String>,
    /// Keep only tasks whose current status state equals this value.
    pub state: Option<TaskState>,
    /// Maximum number of records to return.
    pub limit: Option<u32>,
}
#[derive(Default)]
pub struct InMemoryTaskStore {
inner: RwLock<HashMap<String, TaskRecord>>,
}
impl InMemoryTaskStore {
pub fn new() -> Self {
Self::default()
}
pub fn shared() -> Arc<Self> {
Arc::new(Self::new())
}
}
#[async_trait::async_trait]
impl TaskStore for InMemoryTaskStore {
    /// Inserts a fresh [`TaskRecord`] for `task` and returns a copy of it.
    ///
    /// An existing record with the same id is overwritten, which also discards its
    /// push configs and `cancel_requested` flag.
    async fn create(&self, task: Task) -> TaskRecord {
        let record = TaskRecord::new(task);
        self.inner
            .write()
            .await
            .insert(record.task.id.clone(), record.clone());
        record
    }

    /// Returns a snapshot of the record for `task_id`, if any.
    async fn get(&self, task_id: &str) -> Option<TaskRecord> {
        self.inner.read().await.get(task_id).cloned()
    }

    /// Moves the task to `state`, stamping the status with the current UTC time and
    /// clearing its status message.
    ///
    /// Terminal states are sticky: if the task is already terminal it is returned
    /// unchanged, so a late update cannot overwrite e.g. a cancellation.
    async fn update_state(&self, task_id: &str, state: TaskState) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        if record.task.status.state.is_terminal() {
            return Some(record.clone());
        }
        record.task.status = TaskStatus {
            state,
            message: None,
            timestamp: Utc::now(),
        };
        Some(record.clone())
    }

    /// Appends `artifacts` to the task. No terminal-state check is applied.
    async fn append_artifacts(
        &self,
        task_id: &str,
        artifacts: Vec<Artifact>,
    ) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        record.task.artifacts.extend(artifacts);
        Some(record.clone())
    }

    /// Appends `message` to the task's history. No terminal-state check is applied.
    async fn append_history(&self, task_id: &str, message: Message) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        record.task.history.push(message);
        Some(record.clone())
    }

    /// Returns the records matching every set field of `filter`, newest status
    /// timestamp first, truncated to `filter.limit` when given.
    ///
    /// Records with equal timestamps are ordered by task id. Without that tie-break
    /// the order, and therefore which records survive the limit, would depend on
    /// `HashMap` iteration order. Only the records that survive the limit are cloned.
    async fn list(&self, filter: ListFilter) -> Vec<TaskRecord> {
        let guard = self.inner.read().await;
        let mut matches: Vec<&TaskRecord> = guard
            .values()
            .filter(|r| {
                filter
                    .context_id
                    .as_ref()
                    .map_or(true, |ctx| &r.task.context_id == ctx)
            })
            .filter(|r| filter.state.map_or(true, |s| r.task.status.state == s))
            .collect();
        // Task ids are unique map keys, so this ordering is total and an unstable
        // sort is still deterministic.
        matches.sort_unstable_by(|a, b| {
            b.task
                .status
                .timestamp
                .cmp(&a.task.status.timestamp)
                .then_with(|| a.task.id.cmp(&b.task.id))
        });
        if let Some(limit) = filter.limit {
            // u32 -> usize is lossless on 32- and 64-bit targets.
            matches.truncate(limit as usize);
        }
        matches.into_iter().cloned().collect()
    }

    /// Cancels a non-terminal task: sets `cancel_requested` and moves its status to
    /// [`TaskState::Canceled`] with a fresh timestamp.
    ///
    /// Terminal tasks are returned unchanged. This only updates the stored record;
    /// it does not itself signal any [`AbortRegistry`] token.
    async fn request_cancel(&self, task_id: &str) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        if !record.task.status.state.is_terminal() {
            record.cancel_requested = true;
            record.task.status = TaskStatus {
                state: TaskState::Canceled,
                message: None,
                timestamp: Utc::now(),
            };
        }
        Some(record.clone())
    }

    /// Inserts `config` under `config_id`, replacing any config with the same id.
    async fn add_push_config(
        &self,
        task_id: &str,
        config_id: String,
        config: PushNotificationConfig,
    ) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        record.push_configs.insert(config_id, config);
        Some(record.clone())
    }

    /// Removes `config_id` from the task. Returns the record even if no such config
    /// existed; `None` only when the task itself is unknown.
    async fn remove_push_config(&self, task_id: &str, config_id: &str) -> Option<TaskRecord> {
        let mut guard = self.inner.write().await;
        let record = guard.get_mut(task_id)?;
        record.push_configs.remove(config_id);
        Some(record.clone())
    }

    /// Returns a copy of push config `config_id` of the task, if both exist.
    async fn get_push_config(
        &self,
        task_id: &str,
        config_id: &str,
    ) -> Option<PushNotificationConfig> {
        self.inner
            .read()
            .await
            .get(task_id)
            .and_then(|r| r.push_configs.get(config_id).cloned())
    }

    /// Returns the task's `(config_id, config)` pairs sorted by config id.
    /// Returns an empty vector when the task does not exist.
    async fn list_push_configs(&self, task_id: &str) -> Vec<(String, PushNotificationConfig)> {
        // The read guard is a temporary of this statement, so it is released before sorting.
        let mut configs: Vec<(String, PushNotificationConfig)> = self
            .inner
            .read()
            .await
            .get(task_id)
            .map(|r| {
                r.push_configs
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone()))
                    .collect()
            })
            .unwrap_or_default();
        // Config ids are unique map keys, so an unstable sort is deterministic.
        configs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        configs
    }
}
#[derive(Default)]
pub struct AbortRegistry {
inner: Mutex<HashMap<String, CancellationToken>>,
}
impl AbortRegistry {
pub fn new() -> Self {
Self::default()
}
pub async fn register(&self, task_id: String, token: CancellationToken) {
self.inner.lock().await.insert(task_id, token);
}
pub async fn clear(&self, task_id: &str) {
self.inner.lock().await.remove(task_id);
}
pub async fn abort(&self, task_id: &str) -> bool {
match self.inner.lock().await.remove(task_id) {
Some(token) => {
token.cancel();
true
}
None => false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::task_with_status;

    /// Builds a task via `task_with_status`, passing empty vectors for both trailing
    /// collection arguments.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        task_with_status(id.to_string(), ctx.to_string(), state, vec![], vec![])
    }

    #[tokio::test]
    async fn terminal_state_guard_blocks_clobber() {
        let store = InMemoryTaskStore::new();
        store
            .create(make_task("t-guard", "ctx", TaskState::Submitted))
            .await;
        store
            .update_state("t-guard", TaskState::Canceled)
            .await
            .unwrap();

        // Once terminal, a later transition must not take effect.
        let after = store
            .update_state("t-guard", TaskState::Completed)
            .await
            .unwrap();
        assert_eq!(after.task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn abort_registry_signals_token() {
        let registry = AbortRegistry::new();
        let token = CancellationToken::new();
        registry.register("t-1".into(), token.clone()).await;

        assert!(registry.abort("t-1").await);
        assert!(token.is_cancelled());
        // The first abort removed the entry, so there is nothing left to abort.
        assert!(!registry.abort("t-1").await);
    }

    #[tokio::test]
    async fn create_and_update_lifecycle() {
        let store = InMemoryTaskStore::new();
        store
            .create(make_task("t-1", "ctx-1", TaskState::Submitted))
            .await;

        let updated = store
            .update_state("t-1", TaskState::Working)
            .await
            .expect("found");
        assert_eq!(updated.task.status.state, TaskState::Working);

        let canceled = store.request_cancel("t-1").await.expect("found");
        assert_eq!(canceled.task.status.state, TaskState::Canceled);
        assert!(canceled.cancel_requested);
    }

    #[tokio::test]
    async fn list_filters_by_state_and_limit() {
        let store = InMemoryTaskStore::new();
        // Even indices (0, 2, 4) are Working, odd indices (1, 3) are Completed.
        for i in 0..5 {
            let state = if i % 2 == 0 {
                TaskState::Working
            } else {
                TaskState::Completed
            };
            store
                .create(make_task(&format!("t-{}", i), "ctx-a", state))
                .await;
        }

        let by_state = ListFilter {
            state: Some(TaskState::Working),
            ..Default::default()
        };
        assert_eq!(store.list(by_state).await.len(), 3);

        let by_limit = ListFilter {
            limit: Some(2),
            ..Default::default()
        };
        assert_eq!(store.list(by_limit).await.len(), 2);
    }

    #[tokio::test]
    async fn push_config_crud() {
        let store = InMemoryTaskStore::new();
        store
            .create(make_task("t-9", "ctx-9", TaskState::Submitted))
            .await;

        let hook = PushNotificationConfig {
            url: "https://example.com/hook".into(),
            token: None,
            authentication: None,
        };
        store
            .add_push_config("t-9", "cfg-1".into(), hook)
            .await
            .expect("task");

        let configs = store.list_push_configs("t-9").await;
        assert_eq!(configs.len(), 1);
        assert_eq!(configs[0].0, "cfg-1");

        store
            .remove_push_config("t-9", "cfg-1")
            .await
            .expect("task");
        assert!(store.list_push_configs("t-9").await.is_empty());
    }
}