Skip to main content

brainwires_stores/
task_store.rs

1//! Task Store - Persists tasks via a backend-agnostic storage layer.
2//!
3//! Also includes agent state persistence for background task agents.
4
5use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_core::{Task, TaskPriority, TaskStatus};
9use brainwires_storage::databases::{
10    FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
11};
12
/// Name of the backend table that `TaskStore` reads from and writes to.
const TASK_TABLE: &str = "tasks";
/// Name of the backend table that `AgentStateStore` reads from and writes to.
const AGENT_STATE_TABLE: &str = "agent_states";
15
16// ── Schema helpers ──────────────────────────────────────────────────────
17
18fn tasks_field_defs() -> Vec<FieldDef> {
19    vec![
20        FieldDef::required("task_id", FieldType::Utf8),
21        FieldDef::required("conversation_id", FieldType::Utf8),
22        FieldDef::optional("plan_id", FieldType::Utf8),
23        FieldDef::required("description", FieldType::Utf8),
24        FieldDef::required("status", FieldType::Utf8),
25        FieldDef::optional("parent_id", FieldType::Utf8),
26        FieldDef::required("children", FieldType::Utf8), // JSON array
27        FieldDef::required("depends_on", FieldType::Utf8), // JSON array
28        FieldDef::required("priority", FieldType::Utf8),
29        FieldDef::optional("assigned_to", FieldType::Utf8),
30        FieldDef::required("iterations", FieldType::Int32),
31        FieldDef::optional("summary", FieldType::Utf8),
32        FieldDef::required("created_at", FieldType::Int64),
33        FieldDef::required("updated_at", FieldType::Int64),
34        FieldDef::optional("started_at", FieldType::Int64),
35        FieldDef::optional("completed_at", FieldType::Int64),
36    ]
37}
38
39fn agent_states_field_defs() -> Vec<FieldDef> {
40    vec![
41        FieldDef::required("agent_id", FieldType::Utf8),
42        FieldDef::required("task_id", FieldType::Utf8),
43        FieldDef::required("conversation_id", FieldType::Utf8),
44        FieldDef::required("status", FieldType::Utf8),
45        FieldDef::required("iteration", FieldType::Int32),
46        FieldDef::required("context_json", FieldType::Utf8),
47        FieldDef::required("created_at", FieldType::Int64),
48        FieldDef::required("updated_at", FieldType::Int64),
49    ]
50}
51
52// ── Record conversion helpers ───────────────────────────────────────────
53
54fn task_to_record(m: &TaskMetadata) -> Record {
55    vec![
56        ("task_id".into(), FieldValue::Utf8(Some(m.task_id.clone()))),
57        (
58            "conversation_id".into(),
59            FieldValue::Utf8(Some(m.conversation_id.clone())),
60        ),
61        ("plan_id".into(), FieldValue::Utf8(m.plan_id.clone())),
62        (
63            "description".into(),
64            FieldValue::Utf8(Some(m.description.clone())),
65        ),
66        ("status".into(), FieldValue::Utf8(Some(m.status.clone()))),
67        ("parent_id".into(), FieldValue::Utf8(m.parent_id.clone())),
68        (
69            "children".into(),
70            FieldValue::Utf8(Some(m.children.clone())),
71        ),
72        (
73            "depends_on".into(),
74            FieldValue::Utf8(Some(m.depends_on.clone())),
75        ),
76        (
77            "priority".into(),
78            FieldValue::Utf8(Some(m.priority.clone())),
79        ),
80        (
81            "assigned_to".into(),
82            FieldValue::Utf8(m.assigned_to.clone()),
83        ),
84        ("iterations".into(), FieldValue::Int32(Some(m.iterations))),
85        ("summary".into(), FieldValue::Utf8(m.summary.clone())),
86        ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
87        ("updated_at".into(), FieldValue::Int64(Some(m.updated_at))),
88        ("started_at".into(), FieldValue::Int64(m.started_at)),
89        ("completed_at".into(), FieldValue::Int64(m.completed_at)),
90    ]
91}
92
93fn task_from_record(r: &Record) -> Result<TaskMetadata> {
94    Ok(TaskMetadata {
95        task_id: record_get(r, "task_id")
96            .and_then(|v| v.as_str())
97            .context("missing task_id")?
98            .to_string(),
99        conversation_id: record_get(r, "conversation_id")
100            .and_then(|v| v.as_str())
101            .context("missing conversation_id")?
102            .to_string(),
103        plan_id: record_get(r, "plan_id")
104            .and_then(|v| v.as_str())
105            .map(String::from),
106        description: record_get(r, "description")
107            .and_then(|v| v.as_str())
108            .context("missing description")?
109            .to_string(),
110        status: record_get(r, "status")
111            .and_then(|v| v.as_str())
112            .context("missing status")?
113            .to_string(),
114        parent_id: record_get(r, "parent_id")
115            .and_then(|v| v.as_str())
116            .map(String::from),
117        children: record_get(r, "children")
118            .and_then(|v| v.as_str())
119            .context("missing children")?
120            .to_string(),
121        depends_on: record_get(r, "depends_on")
122            .and_then(|v| v.as_str())
123            .context("missing depends_on")?
124            .to_string(),
125        priority: record_get(r, "priority")
126            .and_then(|v| v.as_str())
127            .context("missing priority")?
128            .to_string(),
129        assigned_to: record_get(r, "assigned_to")
130            .and_then(|v| v.as_str())
131            .map(String::from),
132        iterations: record_get(r, "iterations")
133            .and_then(|v| v.as_i32())
134            .context("missing iterations")?,
135        summary: record_get(r, "summary")
136            .and_then(|v| v.as_str())
137            .map(String::from),
138        created_at: record_get(r, "created_at")
139            .and_then(|v| v.as_i64())
140            .context("missing created_at")?,
141        updated_at: record_get(r, "updated_at")
142            .and_then(|v| v.as_i64())
143            .context("missing updated_at")?,
144        started_at: record_get(r, "started_at").and_then(|v| v.as_i64()),
145        completed_at: record_get(r, "completed_at").and_then(|v| v.as_i64()),
146    })
147}
148
149fn state_to_record(s: &AgentStateMetadata) -> Record {
150    vec![
151        (
152            "agent_id".into(),
153            FieldValue::Utf8(Some(s.agent_id.clone())),
154        ),
155        ("task_id".into(), FieldValue::Utf8(Some(s.task_id.clone()))),
156        (
157            "conversation_id".into(),
158            FieldValue::Utf8(Some(s.conversation_id.clone())),
159        ),
160        ("status".into(), FieldValue::Utf8(Some(s.status.clone()))),
161        ("iteration".into(), FieldValue::Int32(Some(s.iteration))),
162        (
163            "context_json".into(),
164            FieldValue::Utf8(Some(s.context_json.clone())),
165        ),
166        ("created_at".into(), FieldValue::Int64(Some(s.created_at))),
167        ("updated_at".into(), FieldValue::Int64(Some(s.updated_at))),
168    ]
169}
170
171fn state_from_record(r: &Record) -> Result<AgentStateMetadata> {
172    Ok(AgentStateMetadata {
173        agent_id: record_get(r, "agent_id")
174            .and_then(|v| v.as_str())
175            .context("missing agent_id")?
176            .to_string(),
177        task_id: record_get(r, "task_id")
178            .and_then(|v| v.as_str())
179            .context("missing task_id")?
180            .to_string(),
181        conversation_id: record_get(r, "conversation_id")
182            .and_then(|v| v.as_str())
183            .context("missing conversation_id")?
184            .to_string(),
185        status: record_get(r, "status")
186            .and_then(|v| v.as_str())
187            .context("missing status")?
188            .to_string(),
189        iteration: record_get(r, "iteration")
190            .and_then(|v| v.as_i32())
191            .context("missing iteration")?,
192        context_json: record_get(r, "context_json")
193            .and_then(|v| v.as_str())
194            .context("missing context_json")?
195            .to_string(),
196        created_at: record_get(r, "created_at")
197            .and_then(|v| v.as_i64())
198            .context("missing created_at")?,
199        updated_at: record_get(r, "updated_at")
200            .and_then(|v| v.as_i64())
201            .context("missing updated_at")?,
202    })
203}
204
205// ── TaskMetadata ────────────────────────────────────────────────────────
206
/// Flat, string/integer-only representation of a [`Task`] as stored in the
/// `tasks` table.
///
/// Built from a `Task` with `TaskMetadata::from_task` and converted back with
/// `TaskMetadata::to_task`. Enum-valued fields are stored as lower-cased
/// variant names, and ID lists as JSON text.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskMetadata {
    /// Unique task identifier.
    pub task_id: String,
    /// Conversation this task belongs to.
    pub conversation_id: String,
    /// Plan this task belongs to, if any.
    pub plan_id: Option<String>,
    /// Task description.
    pub description: String,
    /// Lower-cased `Debug` name of the task's `TaskStatus`
    /// (e.g. `"pending"`, `"inprogress"`).
    pub status: String,
    /// Parent task identifier, if any.
    pub parent_id: Option<String>,
    /// Child task IDs, encoded as a JSON array of strings.
    pub children: String, // JSON array
    /// IDs of tasks this one depends on, encoded as a JSON array of strings.
    pub depends_on: String, // JSON array
    /// Lower-cased `Debug` name of the task's `TaskPriority`
    /// (e.g. `"low"`, `"normal"`).
    pub priority: String,
    /// Agent assigned to this task, if any.
    pub assigned_to: Option<String>,
    /// Number of iterations completed. This is the task's `u32` count
    /// reinterpreted as `i32`, because the column type is `Int32`.
    pub iterations: i32,
    /// Task completion summary, if any.
    pub summary: Option<String>,
    /// Creation timestamp (Unix seconds).
    pub created_at: i64,
    /// Last update timestamp (Unix seconds).
    pub updated_at: i64,
    /// Start timestamp (Unix seconds), if started.
    pub started_at: Option<i64>,
    /// Completion timestamp (Unix seconds), if completed.
    pub completed_at: Option<i64>,
}
243
244impl TaskMetadata {
245    /// Convert from Task
246    pub fn from_task(task: &Task, conversation_id: &str) -> Self {
247        Self {
248            task_id: task.id.clone(),
249            conversation_id: conversation_id.to_string(),
250            plan_id: task.plan_id.clone(),
251            description: task.description.clone(),
252            status: format!("{:?}", task.status).to_lowercase(),
253            parent_id: task.parent_id.clone(),
254            children: serde_json::to_string(&task.children).unwrap_or_default(),
255            depends_on: serde_json::to_string(&task.depends_on).unwrap_or_default(),
256            priority: format!("{:?}", task.priority).to_lowercase(),
257            assigned_to: task.assigned_to.clone(),
258            iterations: task.iterations as i32,
259            summary: task.summary.clone(),
260            created_at: task.created_at,
261            updated_at: task.updated_at,
262            started_at: task.started_at,
263            completed_at: task.completed_at,
264        }
265    }
266
267    /// Convert to Task
268    pub fn to_task(&self) -> Task {
269        let status = match self.status.as_str() {
270            "pending" => TaskStatus::Pending,
271            "inprogress" => TaskStatus::InProgress,
272            "completed" => TaskStatus::Completed,
273            "failed" => TaskStatus::Failed,
274            "blocked" => TaskStatus::Blocked,
275            _ => TaskStatus::Pending,
276        };
277
278        let priority = match self.priority.as_str() {
279            "low" => TaskPriority::Low,
280            "normal" => TaskPriority::Normal,
281            "high" => TaskPriority::High,
282            "urgent" => TaskPriority::Urgent,
283            _ => TaskPriority::Normal,
284        };
285
286        let children: Vec<String> = serde_json::from_str(&self.children).unwrap_or_default();
287        let depends_on: Vec<String> = serde_json::from_str(&self.depends_on).unwrap_or_default();
288
289        Task {
290            id: self.task_id.clone(),
291            description: self.description.clone(),
292            status,
293            plan_id: self.plan_id.clone(),
294            parent_id: self.parent_id.clone(),
295            children,
296            depends_on,
297            priority,
298            assigned_to: self.assigned_to.clone(),
299            iterations: self.iterations as u32,
300            summary: self.summary.clone(),
301            created_at: self.created_at,
302            updated_at: self.updated_at,
303            started_at: self.started_at,
304            completed_at: self.completed_at,
305        }
306    }
307}
308
309// ── TaskStore ───────────────────────────────────────────────────────────
310
311/// Store for managing tasks
312pub struct TaskStore<
313    B: StorageBackend + 'static = brainwires_storage::databases::lance::LanceDatabase,
314> {
315    backend: Arc<B>,
316}
317
318// Manual Clone impl: Arc<B> is always Clone regardless of B
319impl<B: StorageBackend + 'static> Clone for TaskStore<B> {
320    fn clone(&self) -> Self {
321        Self {
322            backend: Arc::clone(&self.backend),
323        }
324    }
325}
326
327impl<B: StorageBackend + 'static> TaskStore<B> {
328    /// Create a new task store
329    pub fn new(backend: Arc<B>) -> Self {
330        Self { backend }
331    }
332
333    /// Ensure the underlying table exists.
334    pub async fn ensure_table(&self) -> Result<()> {
335        self.backend
336            .ensure_table(TASK_TABLE, &tasks_field_defs())
337            .await
338    }
339
340    /// Save a task
341    pub async fn save(&self, task: &Task, conversation_id: &str) -> Result<()> {
342        let metadata = TaskMetadata::from_task(task, conversation_id);
343
344        // First try to delete existing task with same ID
345        let _ = self.delete(&task.id).await;
346
347        self.backend
348            .insert(TASK_TABLE, vec![task_to_record(&metadata)])
349            .await
350            .context("Failed to save task")?;
351
352        Ok(())
353    }
354
355    /// Get a task by ID
356    pub async fn get(&self, task_id: &str) -> Result<Option<Task>> {
357        let filter = Filter::Eq(
358            "task_id".into(),
359            FieldValue::Utf8(Some(task_id.to_string())),
360        );
361        let records = self
362            .backend
363            .query(TASK_TABLE, Some(&filter), Some(1))
364            .await?;
365
366        match records.first() {
367            Some(r) => Ok(Some(task_from_record(r)?.to_task())),
368            None => Ok(None),
369        }
370    }
371
372    /// Get all tasks for a conversation
373    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<Task>> {
374        let filter = Filter::Eq(
375            "conversation_id".into(),
376            FieldValue::Utf8(Some(conversation_id.to_string())),
377        );
378        let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
379
380        records
381            .iter()
382            .map(|r| task_from_record(r).map(|m| m.to_task()))
383            .collect()
384    }
385
386    /// Get all tasks for a plan
387    pub async fn get_by_plan(&self, plan_id: &str) -> Result<Vec<Task>> {
388        let filter = Filter::Eq(
389            "plan_id".into(),
390            FieldValue::Utf8(Some(plan_id.to_string())),
391        );
392        let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
393
394        records
395            .iter()
396            .map(|r| task_from_record(r).map(|m| m.to_task()))
397            .collect()
398    }
399
400    /// Delete a task
401    pub async fn delete(&self, task_id: &str) -> Result<()> {
402        let filter = Filter::Eq(
403            "task_id".into(),
404            FieldValue::Utf8(Some(task_id.to_string())),
405        );
406        self.backend
407            .delete(TASK_TABLE, &filter)
408            .await
409            .context("Failed to delete task")?;
410        Ok(())
411    }
412
413    /// Delete all tasks for a conversation
414    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
415        let filter = Filter::Eq(
416            "conversation_id".into(),
417            FieldValue::Utf8(Some(conversation_id.to_string())),
418        );
419        self.backend
420            .delete(TASK_TABLE, &filter)
421            .await
422            .context("Failed to delete tasks for conversation")?;
423        Ok(())
424    }
425
426    /// Delete all tasks for a plan
427    pub async fn delete_by_plan(&self, plan_id: &str) -> Result<()> {
428        let filter = Filter::Eq(
429            "plan_id".into(),
430            FieldValue::Utf8(Some(plan_id.to_string())),
431        );
432        self.backend
433            .delete(TASK_TABLE, &filter)
434            .await
435            .context("Failed to delete tasks for plan")?;
436        Ok(())
437    }
438
439    /// Schema for the tasks table as backend-agnostic field definitions.
440    pub fn tasks_schema() -> Vec<FieldDef> {
441        tasks_field_defs()
442    }
443
444    /// Arrow schema for the tasks table, used by `LanceDatabase` table creation.
445    pub fn tasks_arrow_schema() -> Arc<arrow_schema::Schema> {
446        use arrow_schema::{DataType, Field, Schema};
447        Arc::new(Schema::new(vec![
448            Field::new("task_id", DataType::Utf8, false),
449            Field::new("conversation_id", DataType::Utf8, false),
450            Field::new("plan_id", DataType::Utf8, true),
451            Field::new("description", DataType::Utf8, false),
452            Field::new("status", DataType::Utf8, false),
453            Field::new("parent_id", DataType::Utf8, true),
454            Field::new("children", DataType::Utf8, false),
455            Field::new("depends_on", DataType::Utf8, false),
456            Field::new("priority", DataType::Utf8, false),
457            Field::new("assigned_to", DataType::Utf8, true),
458            Field::new("iterations", DataType::Int32, false),
459            Field::new("summary", DataType::Utf8, true),
460            Field::new("created_at", DataType::Int64, false),
461            Field::new("updated_at", DataType::Int64, false),
462            Field::new("started_at", DataType::Int64, true),
463            Field::new("completed_at", DataType::Int64, true),
464        ]))
465    }
466}
467
468// ── AgentStateMetadata ──────────────────────────────────────────────────
469
/// Persisted state of a background task agent, stored as one row of the
/// `agent_states` table.
///
/// `AgentStateStore` saves and loads this struct as-is. Every field maps to a
/// non-nullable column.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AgentStateMetadata {
    /// Unique agent identifier; `AgentStateStore::save` replaces rows by this ID.
    pub agent_id: String,
    /// Task the agent is working on.
    pub task_id: String,
    /// Conversation the agent belongs to.
    pub conversation_id: String,
    /// Current agent status. Free-form text: this file does not constrain its
    /// values.
    pub status: String,
    /// Current iteration number.
    pub iteration: i32,
    /// Serialized agent context as JSON text.
    /// NOTE(review): presumably an `AgentContext`, per the original comment;
    /// the type is not visible here.
    pub context_json: String, // Serialized AgentContext
    /// Creation timestamp (Unix seconds).
    pub created_at: i64,
    /// Last update timestamp (Unix seconds).
    pub updated_at: i64,
}
490
491// ── AgentStateStore ─────────────────────────────────────────────────────
492
493/// Store for managing agent state persistence
494pub struct AgentStateStore<
495    B: StorageBackend + 'static = brainwires_storage::databases::lance::LanceDatabase,
496> {
497    backend: Arc<B>,
498}
499
500impl<B: StorageBackend + 'static> AgentStateStore<B> {
501    /// Create a new agent state store
502    pub fn new(backend: Arc<B>) -> Self {
503        Self { backend }
504    }
505
506    /// Ensure the underlying table exists.
507    pub async fn ensure_table(&self) -> Result<()> {
508        self.backend
509            .ensure_table(AGENT_STATE_TABLE, &agent_states_field_defs())
510            .await
511    }
512
513    /// Save agent state
514    pub async fn save(&self, state: &AgentStateMetadata) -> Result<()> {
515        // First try to delete existing state with same agent ID
516        let _ = self.delete(&state.agent_id).await;
517
518        self.backend
519            .insert(AGENT_STATE_TABLE, vec![state_to_record(state)])
520            .await
521            .context("Failed to save agent state")?;
522
523        Ok(())
524    }
525
526    /// Get agent state by ID
527    pub async fn get(&self, agent_id: &str) -> Result<Option<AgentStateMetadata>> {
528        let filter = Filter::Eq(
529            "agent_id".into(),
530            FieldValue::Utf8(Some(agent_id.to_string())),
531        );
532        let records = self
533            .backend
534            .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
535            .await?;
536
537        match records.first() {
538            Some(r) => Ok(Some(state_from_record(r)?)),
539            None => Ok(None),
540        }
541    }
542
543    /// Get all agent states for a conversation
544    pub async fn get_by_conversation(
545        &self,
546        conversation_id: &str,
547    ) -> Result<Vec<AgentStateMetadata>> {
548        let filter = Filter::Eq(
549            "conversation_id".into(),
550            FieldValue::Utf8(Some(conversation_id.to_string())),
551        );
552        let records = self
553            .backend
554            .query(AGENT_STATE_TABLE, Some(&filter), None)
555            .await?;
556
557        records.iter().map(state_from_record).collect()
558    }
559
560    /// Get agent state by task ID
561    pub async fn get_by_task(&self, task_id: &str) -> Result<Option<AgentStateMetadata>> {
562        let filter = Filter::Eq(
563            "task_id".into(),
564            FieldValue::Utf8(Some(task_id.to_string())),
565        );
566        let records = self
567            .backend
568            .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
569            .await?;
570
571        match records.first() {
572            Some(r) => Ok(Some(state_from_record(r)?)),
573            None => Ok(None),
574        }
575    }
576
577    /// Delete agent state
578    pub async fn delete(&self, agent_id: &str) -> Result<()> {
579        let filter = Filter::Eq(
580            "agent_id".into(),
581            FieldValue::Utf8(Some(agent_id.to_string())),
582        );
583        self.backend
584            .delete(AGENT_STATE_TABLE, &filter)
585            .await
586            .context("Failed to delete agent state")?;
587        Ok(())
588    }
589
590    /// Delete all agent states for a conversation
591    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
592        let filter = Filter::Eq(
593            "conversation_id".into(),
594            FieldValue::Utf8(Some(conversation_id.to_string())),
595        );
596        self.backend
597            .delete(AGENT_STATE_TABLE, &filter)
598            .await
599            .context("Failed to delete agent states for conversation")?;
600        Ok(())
601    }
602
603    /// Schema for the agent_states table as backend-agnostic field definitions.
604    pub fn agent_states_schema() -> Vec<FieldDef> {
605        agent_states_field_defs()
606    }
607
608    /// Arrow schema for the agent_states table, used by `LanceDatabase` table creation.
609    pub fn agent_states_arrow_schema() -> Arc<arrow_schema::Schema> {
610        use arrow_schema::{DataType, Field, Schema};
611        Arc::new(Schema::new(vec![
612            Field::new("agent_id", DataType::Utf8, false),
613            Field::new("task_id", DataType::Utf8, false),
614            Field::new("conversation_id", DataType::Utf8, false),
615            Field::new("status", DataType::Utf8, false),
616            Field::new("iteration", DataType::Int32, false),
617            Field::new("context_json", DataType::Utf8, false),
618            Field::new("created_at", DataType::Int64, false),
619            Field::new("updated_at", DataType::Int64, false),
620        ]))
621    }
622}