Skip to main content

brainwires_storage/stores/
task_store.rs

1//! Task Store - Persists tasks via a backend-agnostic storage layer.
2//!
3//! Also includes agent state persistence for background task agents.
4
5use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use crate::databases::{
9    FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
10};
11use brainwires_core::{Task, TaskPriority, TaskStatus};
12
/// Name of the backend table that holds one row per persisted task.
const TASK_TABLE: &str = "tasks";
/// Name of the backend table that holds one row per persisted agent state.
const AGENT_STATE_TABLE: &str = "agent_states";
15
16// ── Schema helpers ──────────────────────────────────────────────────────
17
18fn tasks_field_defs() -> Vec<FieldDef> {
19    vec![
20        FieldDef::required("task_id", FieldType::Utf8),
21        FieldDef::required("conversation_id", FieldType::Utf8),
22        FieldDef::optional("plan_id", FieldType::Utf8),
23        FieldDef::required("description", FieldType::Utf8),
24        FieldDef::required("status", FieldType::Utf8),
25        FieldDef::optional("parent_id", FieldType::Utf8),
26        FieldDef::required("children", FieldType::Utf8), // JSON array
27        FieldDef::required("depends_on", FieldType::Utf8), // JSON array
28        FieldDef::required("priority", FieldType::Utf8),
29        FieldDef::optional("assigned_to", FieldType::Utf8),
30        FieldDef::required("iterations", FieldType::Int32),
31        FieldDef::optional("summary", FieldType::Utf8),
32        FieldDef::required("created_at", FieldType::Int64),
33        FieldDef::required("updated_at", FieldType::Int64),
34        FieldDef::optional("started_at", FieldType::Int64),
35        FieldDef::optional("completed_at", FieldType::Int64),
36    ]
37}
38
39fn agent_states_field_defs() -> Vec<FieldDef> {
40    vec![
41        FieldDef::required("agent_id", FieldType::Utf8),
42        FieldDef::required("task_id", FieldType::Utf8),
43        FieldDef::required("conversation_id", FieldType::Utf8),
44        FieldDef::required("status", FieldType::Utf8),
45        FieldDef::required("iteration", FieldType::Int32),
46        FieldDef::required("context_json", FieldType::Utf8),
47        FieldDef::required("created_at", FieldType::Int64),
48        FieldDef::required("updated_at", FieldType::Int64),
49    ]
50}
51
52// ── Record conversion helpers ───────────────────────────────────────────
53
54fn task_to_record(m: &TaskMetadata) -> Record {
55    vec![
56        ("task_id".into(), FieldValue::Utf8(Some(m.task_id.clone()))),
57        (
58            "conversation_id".into(),
59            FieldValue::Utf8(Some(m.conversation_id.clone())),
60        ),
61        ("plan_id".into(), FieldValue::Utf8(m.plan_id.clone())),
62        (
63            "description".into(),
64            FieldValue::Utf8(Some(m.description.clone())),
65        ),
66        ("status".into(), FieldValue::Utf8(Some(m.status.clone()))),
67        ("parent_id".into(), FieldValue::Utf8(m.parent_id.clone())),
68        (
69            "children".into(),
70            FieldValue::Utf8(Some(m.children.clone())),
71        ),
72        (
73            "depends_on".into(),
74            FieldValue::Utf8(Some(m.depends_on.clone())),
75        ),
76        (
77            "priority".into(),
78            FieldValue::Utf8(Some(m.priority.clone())),
79        ),
80        (
81            "assigned_to".into(),
82            FieldValue::Utf8(m.assigned_to.clone()),
83        ),
84        ("iterations".into(), FieldValue::Int32(Some(m.iterations))),
85        ("summary".into(), FieldValue::Utf8(m.summary.clone())),
86        ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
87        ("updated_at".into(), FieldValue::Int64(Some(m.updated_at))),
88        ("started_at".into(), FieldValue::Int64(m.started_at)),
89        ("completed_at".into(), FieldValue::Int64(m.completed_at)),
90    ]
91}
92
93fn task_from_record(r: &Record) -> Result<TaskMetadata> {
94    Ok(TaskMetadata {
95        task_id: record_get(r, "task_id")
96            .and_then(|v| v.as_str())
97            .context("missing task_id")?
98            .to_string(),
99        conversation_id: record_get(r, "conversation_id")
100            .and_then(|v| v.as_str())
101            .context("missing conversation_id")?
102            .to_string(),
103        plan_id: record_get(r, "plan_id")
104            .and_then(|v| v.as_str())
105            .map(String::from),
106        description: record_get(r, "description")
107            .and_then(|v| v.as_str())
108            .context("missing description")?
109            .to_string(),
110        status: record_get(r, "status")
111            .and_then(|v| v.as_str())
112            .context("missing status")?
113            .to_string(),
114        parent_id: record_get(r, "parent_id")
115            .and_then(|v| v.as_str())
116            .map(String::from),
117        children: record_get(r, "children")
118            .and_then(|v| v.as_str())
119            .context("missing children")?
120            .to_string(),
121        depends_on: record_get(r, "depends_on")
122            .and_then(|v| v.as_str())
123            .context("missing depends_on")?
124            .to_string(),
125        priority: record_get(r, "priority")
126            .and_then(|v| v.as_str())
127            .context("missing priority")?
128            .to_string(),
129        assigned_to: record_get(r, "assigned_to")
130            .and_then(|v| v.as_str())
131            .map(String::from),
132        iterations: record_get(r, "iterations")
133            .and_then(|v| v.as_i32())
134            .context("missing iterations")?,
135        summary: record_get(r, "summary")
136            .and_then(|v| v.as_str())
137            .map(String::from),
138        created_at: record_get(r, "created_at")
139            .and_then(|v| v.as_i64())
140            .context("missing created_at")?,
141        updated_at: record_get(r, "updated_at")
142            .and_then(|v| v.as_i64())
143            .context("missing updated_at")?,
144        started_at: record_get(r, "started_at").and_then(|v| v.as_i64()),
145        completed_at: record_get(r, "completed_at").and_then(|v| v.as_i64()),
146    })
147}
148
149fn state_to_record(s: &AgentStateMetadata) -> Record {
150    vec![
151        (
152            "agent_id".into(),
153            FieldValue::Utf8(Some(s.agent_id.clone())),
154        ),
155        ("task_id".into(), FieldValue::Utf8(Some(s.task_id.clone()))),
156        (
157            "conversation_id".into(),
158            FieldValue::Utf8(Some(s.conversation_id.clone())),
159        ),
160        ("status".into(), FieldValue::Utf8(Some(s.status.clone()))),
161        ("iteration".into(), FieldValue::Int32(Some(s.iteration))),
162        (
163            "context_json".into(),
164            FieldValue::Utf8(Some(s.context_json.clone())),
165        ),
166        ("created_at".into(), FieldValue::Int64(Some(s.created_at))),
167        ("updated_at".into(), FieldValue::Int64(Some(s.updated_at))),
168    ]
169}
170
171fn state_from_record(r: &Record) -> Result<AgentStateMetadata> {
172    Ok(AgentStateMetadata {
173        agent_id: record_get(r, "agent_id")
174            .and_then(|v| v.as_str())
175            .context("missing agent_id")?
176            .to_string(),
177        task_id: record_get(r, "task_id")
178            .and_then(|v| v.as_str())
179            .context("missing task_id")?
180            .to_string(),
181        conversation_id: record_get(r, "conversation_id")
182            .and_then(|v| v.as_str())
183            .context("missing conversation_id")?
184            .to_string(),
185        status: record_get(r, "status")
186            .and_then(|v| v.as_str())
187            .context("missing status")?
188            .to_string(),
189        iteration: record_get(r, "iteration")
190            .and_then(|v| v.as_i32())
191            .context("missing iteration")?,
192        context_json: record_get(r, "context_json")
193            .and_then(|v| v.as_str())
194            .context("missing context_json")?
195            .to_string(),
196        created_at: record_get(r, "created_at")
197            .and_then(|v| v.as_i64())
198            .context("missing created_at")?,
199        updated_at: record_get(r, "updated_at")
200            .and_then(|v| v.as_i64())
201            .context("missing updated_at")?,
202    })
203}
204
205// ── TaskMetadata ────────────────────────────────────────────────────────
206
207/// Metadata for storing tasks
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct TaskMetadata {
210    /// Unique task identifier.
211    pub task_id: String,
212    /// Conversation this task belongs to.
213    pub conversation_id: String,
214    /// Plan this task belongs to.
215    pub plan_id: Option<String>,
216    /// Task description.
217    pub description: String,
218    /// Current task status.
219    pub status: String,
220    /// Parent task identifier.
221    pub parent_id: Option<String>,
222    /// Child task IDs (JSON array).
223    pub children: String, // JSON array
224    /// Task dependency IDs (JSON array).
225    pub depends_on: String, // JSON array
226    /// Task priority level.
227    pub priority: String,
228    /// Agent assigned to this task.
229    pub assigned_to: Option<String>,
230    /// Number of iterations completed.
231    pub iterations: i32,
232    /// Task completion summary.
233    pub summary: Option<String>,
234    /// Creation timestamp (Unix seconds).
235    pub created_at: i64,
236    /// Last update timestamp (Unix seconds).
237    pub updated_at: i64,
238    /// Start timestamp (Unix seconds).
239    pub started_at: Option<i64>,
240    /// Completion timestamp (Unix seconds).
241    pub completed_at: Option<i64>,
242}
243
244impl TaskMetadata {
245    /// Convert from Task
246    pub fn from_task(task: &Task, conversation_id: &str) -> Self {
247        Self {
248            task_id: task.id.clone(),
249            conversation_id: conversation_id.to_string(),
250            plan_id: task.plan_id.clone(),
251            description: task.description.clone(),
252            status: format!("{:?}", task.status).to_lowercase(),
253            parent_id: task.parent_id.clone(),
254            children: serde_json::to_string(&task.children).unwrap_or_default(),
255            depends_on: serde_json::to_string(&task.depends_on).unwrap_or_default(),
256            priority: format!("{:?}", task.priority).to_lowercase(),
257            assigned_to: task.assigned_to.clone(),
258            iterations: task.iterations as i32,
259            summary: task.summary.clone(),
260            created_at: task.created_at,
261            updated_at: task.updated_at,
262            started_at: task.started_at,
263            completed_at: task.completed_at,
264        }
265    }
266
267    /// Convert to Task
268    pub fn to_task(&self) -> Task {
269        let status = match self.status.as_str() {
270            "pending" => TaskStatus::Pending,
271            "inprogress" => TaskStatus::InProgress,
272            "completed" => TaskStatus::Completed,
273            "failed" => TaskStatus::Failed,
274            "blocked" => TaskStatus::Blocked,
275            _ => TaskStatus::Pending,
276        };
277
278        let priority = match self.priority.as_str() {
279            "low" => TaskPriority::Low,
280            "normal" => TaskPriority::Normal,
281            "high" => TaskPriority::High,
282            "urgent" => TaskPriority::Urgent,
283            _ => TaskPriority::Normal,
284        };
285
286        let children: Vec<String> = serde_json::from_str(&self.children).unwrap_or_default();
287        let depends_on: Vec<String> = serde_json::from_str(&self.depends_on).unwrap_or_default();
288
289        Task {
290            id: self.task_id.clone(),
291            description: self.description.clone(),
292            status,
293            plan_id: self.plan_id.clone(),
294            parent_id: self.parent_id.clone(),
295            children,
296            depends_on,
297            priority,
298            assigned_to: self.assigned_to.clone(),
299            iterations: self.iterations as u32,
300            summary: self.summary.clone(),
301            created_at: self.created_at,
302            updated_at: self.updated_at,
303            started_at: self.started_at,
304            completed_at: self.completed_at,
305        }
306    }
307}
308
309// ── TaskStore ───────────────────────────────────────────────────────────
310
311/// Store for managing tasks
312#[derive(Clone)]
313pub struct TaskStore<B: StorageBackend + 'static = crate::databases::lance::LanceDatabase> {
314    backend: Arc<B>,
315}
316
317impl<B: StorageBackend + 'static> TaskStore<B> {
318    /// Create a new task store
319    pub fn new(backend: Arc<B>) -> Self {
320        Self { backend }
321    }
322
323    /// Ensure the underlying table exists.
324    pub async fn ensure_table(&self) -> Result<()> {
325        self.backend
326            .ensure_table(TASK_TABLE, &tasks_field_defs())
327            .await
328    }
329
330    /// Save a task
331    pub async fn save(&self, task: &Task, conversation_id: &str) -> Result<()> {
332        let metadata = TaskMetadata::from_task(task, conversation_id);
333
334        // First try to delete existing task with same ID
335        let _ = self.delete(&task.id).await;
336
337        self.backend
338            .insert(TASK_TABLE, vec![task_to_record(&metadata)])
339            .await
340            .context("Failed to save task")?;
341
342        Ok(())
343    }
344
345    /// Get a task by ID
346    pub async fn get(&self, task_id: &str) -> Result<Option<Task>> {
347        let filter = Filter::Eq(
348            "task_id".into(),
349            FieldValue::Utf8(Some(task_id.to_string())),
350        );
351        let records = self
352            .backend
353            .query(TASK_TABLE, Some(&filter), Some(1))
354            .await?;
355
356        match records.first() {
357            Some(r) => Ok(Some(task_from_record(r)?.to_task())),
358            None => Ok(None),
359        }
360    }
361
362    /// Get all tasks for a conversation
363    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<Task>> {
364        let filter = Filter::Eq(
365            "conversation_id".into(),
366            FieldValue::Utf8(Some(conversation_id.to_string())),
367        );
368        let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
369
370        records
371            .iter()
372            .map(|r| task_from_record(r).map(|m| m.to_task()))
373            .collect()
374    }
375
376    /// Get all tasks for a plan
377    pub async fn get_by_plan(&self, plan_id: &str) -> Result<Vec<Task>> {
378        let filter = Filter::Eq(
379            "plan_id".into(),
380            FieldValue::Utf8(Some(plan_id.to_string())),
381        );
382        let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
383
384        records
385            .iter()
386            .map(|r| task_from_record(r).map(|m| m.to_task()))
387            .collect()
388    }
389
390    /// Delete a task
391    pub async fn delete(&self, task_id: &str) -> Result<()> {
392        let filter = Filter::Eq(
393            "task_id".into(),
394            FieldValue::Utf8(Some(task_id.to_string())),
395        );
396        self.backend
397            .delete(TASK_TABLE, &filter)
398            .await
399            .context("Failed to delete task")?;
400        Ok(())
401    }
402
403    /// Delete all tasks for a conversation
404    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
405        let filter = Filter::Eq(
406            "conversation_id".into(),
407            FieldValue::Utf8(Some(conversation_id.to_string())),
408        );
409        self.backend
410            .delete(TASK_TABLE, &filter)
411            .await
412            .context("Failed to delete tasks for conversation")?;
413        Ok(())
414    }
415
416    /// Delete all tasks for a plan
417    pub async fn delete_by_plan(&self, plan_id: &str) -> Result<()> {
418        let filter = Filter::Eq(
419            "plan_id".into(),
420            FieldValue::Utf8(Some(plan_id.to_string())),
421        );
422        self.backend
423            .delete(TASK_TABLE, &filter)
424            .await
425            .context("Failed to delete tasks for plan")?;
426        Ok(())
427    }
428
429    /// Schema for the tasks table as backend-agnostic field definitions.
430    pub fn tasks_schema() -> Vec<FieldDef> {
431        tasks_field_defs()
432    }
433
434    /// Schema for the tasks table as an Arrow schema (for backward compat with LanceDatabase).
435    #[cfg(feature = "native")]
436    pub fn tasks_arrow_schema() -> Arc<arrow_schema::Schema> {
437        use arrow_schema::{DataType, Field, Schema};
438        Arc::new(Schema::new(vec![
439            Field::new("task_id", DataType::Utf8, false),
440            Field::new("conversation_id", DataType::Utf8, false),
441            Field::new("plan_id", DataType::Utf8, true),
442            Field::new("description", DataType::Utf8, false),
443            Field::new("status", DataType::Utf8, false),
444            Field::new("parent_id", DataType::Utf8, true),
445            Field::new("children", DataType::Utf8, false),
446            Field::new("depends_on", DataType::Utf8, false),
447            Field::new("priority", DataType::Utf8, false),
448            Field::new("assigned_to", DataType::Utf8, true),
449            Field::new("iterations", DataType::Int32, false),
450            Field::new("summary", DataType::Utf8, true),
451            Field::new("created_at", DataType::Int64, false),
452            Field::new("updated_at", DataType::Int64, false),
453            Field::new("started_at", DataType::Int64, true),
454            Field::new("completed_at", DataType::Int64, true),
455        ]))
456    }
457}
458
459// ── AgentStateMetadata ──────────────────────────────────────────────────
460
461/// Metadata for storing agent state
462#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
463pub struct AgentStateMetadata {
464    /// Unique agent identifier.
465    pub agent_id: String,
466    /// Task the agent is working on.
467    pub task_id: String,
468    /// Conversation context.
469    pub conversation_id: String,
470    /// Current agent status.
471    pub status: String,
472    /// Current iteration number.
473    pub iteration: i32,
474    /// Serialized agent context (JSON).
475    pub context_json: String, // Serialized AgentContext
476    /// Creation timestamp (Unix seconds).
477    pub created_at: i64,
478    /// Last update timestamp (Unix seconds).
479    pub updated_at: i64,
480}
481
482// ── AgentStateStore ─────────────────────────────────────────────────────
483
484/// Store for managing agent state persistence
485pub struct AgentStateStore<B: StorageBackend + 'static = crate::databases::lance::LanceDatabase> {
486    backend: Arc<B>,
487}
488
489impl<B: StorageBackend + 'static> AgentStateStore<B> {
490    /// Create a new agent state store
491    pub fn new(backend: Arc<B>) -> Self {
492        Self { backend }
493    }
494
495    /// Ensure the underlying table exists.
496    pub async fn ensure_table(&self) -> Result<()> {
497        self.backend
498            .ensure_table(AGENT_STATE_TABLE, &agent_states_field_defs())
499            .await
500    }
501
502    /// Save agent state
503    pub async fn save(&self, state: &AgentStateMetadata) -> Result<()> {
504        // First try to delete existing state with same agent ID
505        let _ = self.delete(&state.agent_id).await;
506
507        self.backend
508            .insert(AGENT_STATE_TABLE, vec![state_to_record(state)])
509            .await
510            .context("Failed to save agent state")?;
511
512        Ok(())
513    }
514
515    /// Get agent state by ID
516    pub async fn get(&self, agent_id: &str) -> Result<Option<AgentStateMetadata>> {
517        let filter = Filter::Eq(
518            "agent_id".into(),
519            FieldValue::Utf8(Some(agent_id.to_string())),
520        );
521        let records = self
522            .backend
523            .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
524            .await?;
525
526        match records.first() {
527            Some(r) => Ok(Some(state_from_record(r)?)),
528            None => Ok(None),
529        }
530    }
531
532    /// Get all agent states for a conversation
533    pub async fn get_by_conversation(
534        &self,
535        conversation_id: &str,
536    ) -> Result<Vec<AgentStateMetadata>> {
537        let filter = Filter::Eq(
538            "conversation_id".into(),
539            FieldValue::Utf8(Some(conversation_id.to_string())),
540        );
541        let records = self
542            .backend
543            .query(AGENT_STATE_TABLE, Some(&filter), None)
544            .await?;
545
546        records.iter().map(state_from_record).collect()
547    }
548
549    /// Get agent state by task ID
550    pub async fn get_by_task(&self, task_id: &str) -> Result<Option<AgentStateMetadata>> {
551        let filter = Filter::Eq(
552            "task_id".into(),
553            FieldValue::Utf8(Some(task_id.to_string())),
554        );
555        let records = self
556            .backend
557            .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
558            .await?;
559
560        match records.first() {
561            Some(r) => Ok(Some(state_from_record(r)?)),
562            None => Ok(None),
563        }
564    }
565
566    /// Delete agent state
567    pub async fn delete(&self, agent_id: &str) -> Result<()> {
568        let filter = Filter::Eq(
569            "agent_id".into(),
570            FieldValue::Utf8(Some(agent_id.to_string())),
571        );
572        self.backend
573            .delete(AGENT_STATE_TABLE, &filter)
574            .await
575            .context("Failed to delete agent state")?;
576        Ok(())
577    }
578
579    /// Delete all agent states for a conversation
580    pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
581        let filter = Filter::Eq(
582            "conversation_id".into(),
583            FieldValue::Utf8(Some(conversation_id.to_string())),
584        );
585        self.backend
586            .delete(AGENT_STATE_TABLE, &filter)
587            .await
588            .context("Failed to delete agent states for conversation")?;
589        Ok(())
590    }
591
592    /// Schema for the agent_states table as backend-agnostic field definitions.
593    pub fn agent_states_schema() -> Vec<FieldDef> {
594        agent_states_field_defs()
595    }
596
597    /// Schema for the agent_states table as an Arrow schema (for backward compat).
598    #[cfg(feature = "native")]
599    pub fn agent_states_arrow_schema() -> Arc<arrow_schema::Schema> {
600        use arrow_schema::{DataType, Field, Schema};
601        Arc::new(Schema::new(vec![
602            Field::new("agent_id", DataType::Utf8, false),
603            Field::new("task_id", DataType::Utf8, false),
604            Field::new("conversation_id", DataType::Utf8, false),
605            Field::new("status", DataType::Utf8, false),
606            Field::new("iteration", DataType::Int32, false),
607            Field::new("context_json", DataType::Utf8, false),
608            Field::new("created_at", DataType::Int64, false),
609            Field::new("updated_at", DataType::Int64, false),
610        ]))
611    }
612}