1use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_core::{Task, TaskPriority, TaskStatus};
9use brainwires_storage::databases::{
10 FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
11};
12
/// Table name used by `TaskStore` for task rows (schema: `tasks_field_defs`).
const TASK_TABLE: &str = "tasks";
/// Table name used by `AgentStateStore` for agent-state rows (schema: `agent_states_field_defs`).
const AGENT_STATE_TABLE: &str = "agent_states";
15
16fn tasks_field_defs() -> Vec<FieldDef> {
19 vec![
20 FieldDef::required("task_id", FieldType::Utf8),
21 FieldDef::required("conversation_id", FieldType::Utf8),
22 FieldDef::optional("plan_id", FieldType::Utf8),
23 FieldDef::required("description", FieldType::Utf8),
24 FieldDef::required("status", FieldType::Utf8),
25 FieldDef::optional("parent_id", FieldType::Utf8),
26 FieldDef::required("children", FieldType::Utf8), FieldDef::required("depends_on", FieldType::Utf8), FieldDef::required("priority", FieldType::Utf8),
29 FieldDef::optional("assigned_to", FieldType::Utf8),
30 FieldDef::required("iterations", FieldType::Int32),
31 FieldDef::optional("summary", FieldType::Utf8),
32 FieldDef::required("created_at", FieldType::Int64),
33 FieldDef::required("updated_at", FieldType::Int64),
34 FieldDef::optional("started_at", FieldType::Int64),
35 FieldDef::optional("completed_at", FieldType::Int64),
36 ]
37}
38
39fn agent_states_field_defs() -> Vec<FieldDef> {
40 vec![
41 FieldDef::required("agent_id", FieldType::Utf8),
42 FieldDef::required("task_id", FieldType::Utf8),
43 FieldDef::required("conversation_id", FieldType::Utf8),
44 FieldDef::required("status", FieldType::Utf8),
45 FieldDef::required("iteration", FieldType::Int32),
46 FieldDef::required("context_json", FieldType::Utf8),
47 FieldDef::required("created_at", FieldType::Int64),
48 FieldDef::required("updated_at", FieldType::Int64),
49 ]
50}
51
52fn task_to_record(m: &TaskMetadata) -> Record {
55 vec![
56 ("task_id".into(), FieldValue::Utf8(Some(m.task_id.clone()))),
57 (
58 "conversation_id".into(),
59 FieldValue::Utf8(Some(m.conversation_id.clone())),
60 ),
61 ("plan_id".into(), FieldValue::Utf8(m.plan_id.clone())),
62 (
63 "description".into(),
64 FieldValue::Utf8(Some(m.description.clone())),
65 ),
66 ("status".into(), FieldValue::Utf8(Some(m.status.clone()))),
67 ("parent_id".into(), FieldValue::Utf8(m.parent_id.clone())),
68 (
69 "children".into(),
70 FieldValue::Utf8(Some(m.children.clone())),
71 ),
72 (
73 "depends_on".into(),
74 FieldValue::Utf8(Some(m.depends_on.clone())),
75 ),
76 (
77 "priority".into(),
78 FieldValue::Utf8(Some(m.priority.clone())),
79 ),
80 (
81 "assigned_to".into(),
82 FieldValue::Utf8(m.assigned_to.clone()),
83 ),
84 ("iterations".into(), FieldValue::Int32(Some(m.iterations))),
85 ("summary".into(), FieldValue::Utf8(m.summary.clone())),
86 ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
87 ("updated_at".into(), FieldValue::Int64(Some(m.updated_at))),
88 ("started_at".into(), FieldValue::Int64(m.started_at)),
89 ("completed_at".into(), FieldValue::Int64(m.completed_at)),
90 ]
91}
92
93fn task_from_record(r: &Record) -> Result<TaskMetadata> {
94 Ok(TaskMetadata {
95 task_id: record_get(r, "task_id")
96 .and_then(|v| v.as_str())
97 .context("missing task_id")?
98 .to_string(),
99 conversation_id: record_get(r, "conversation_id")
100 .and_then(|v| v.as_str())
101 .context("missing conversation_id")?
102 .to_string(),
103 plan_id: record_get(r, "plan_id")
104 .and_then(|v| v.as_str())
105 .map(String::from),
106 description: record_get(r, "description")
107 .and_then(|v| v.as_str())
108 .context("missing description")?
109 .to_string(),
110 status: record_get(r, "status")
111 .and_then(|v| v.as_str())
112 .context("missing status")?
113 .to_string(),
114 parent_id: record_get(r, "parent_id")
115 .and_then(|v| v.as_str())
116 .map(String::from),
117 children: record_get(r, "children")
118 .and_then(|v| v.as_str())
119 .context("missing children")?
120 .to_string(),
121 depends_on: record_get(r, "depends_on")
122 .and_then(|v| v.as_str())
123 .context("missing depends_on")?
124 .to_string(),
125 priority: record_get(r, "priority")
126 .and_then(|v| v.as_str())
127 .context("missing priority")?
128 .to_string(),
129 assigned_to: record_get(r, "assigned_to")
130 .and_then(|v| v.as_str())
131 .map(String::from),
132 iterations: record_get(r, "iterations")
133 .and_then(|v| v.as_i32())
134 .context("missing iterations")?,
135 summary: record_get(r, "summary")
136 .and_then(|v| v.as_str())
137 .map(String::from),
138 created_at: record_get(r, "created_at")
139 .and_then(|v| v.as_i64())
140 .context("missing created_at")?,
141 updated_at: record_get(r, "updated_at")
142 .and_then(|v| v.as_i64())
143 .context("missing updated_at")?,
144 started_at: record_get(r, "started_at").and_then(|v| v.as_i64()),
145 completed_at: record_get(r, "completed_at").and_then(|v| v.as_i64()),
146 })
147}
148
149fn state_to_record(s: &AgentStateMetadata) -> Record {
150 vec![
151 (
152 "agent_id".into(),
153 FieldValue::Utf8(Some(s.agent_id.clone())),
154 ),
155 ("task_id".into(), FieldValue::Utf8(Some(s.task_id.clone()))),
156 (
157 "conversation_id".into(),
158 FieldValue::Utf8(Some(s.conversation_id.clone())),
159 ),
160 ("status".into(), FieldValue::Utf8(Some(s.status.clone()))),
161 ("iteration".into(), FieldValue::Int32(Some(s.iteration))),
162 (
163 "context_json".into(),
164 FieldValue::Utf8(Some(s.context_json.clone())),
165 ),
166 ("created_at".into(), FieldValue::Int64(Some(s.created_at))),
167 ("updated_at".into(), FieldValue::Int64(Some(s.updated_at))),
168 ]
169}
170
171fn state_from_record(r: &Record) -> Result<AgentStateMetadata> {
172 Ok(AgentStateMetadata {
173 agent_id: record_get(r, "agent_id")
174 .and_then(|v| v.as_str())
175 .context("missing agent_id")?
176 .to_string(),
177 task_id: record_get(r, "task_id")
178 .and_then(|v| v.as_str())
179 .context("missing task_id")?
180 .to_string(),
181 conversation_id: record_get(r, "conversation_id")
182 .and_then(|v| v.as_str())
183 .context("missing conversation_id")?
184 .to_string(),
185 status: record_get(r, "status")
186 .and_then(|v| v.as_str())
187 .context("missing status")?
188 .to_string(),
189 iteration: record_get(r, "iteration")
190 .and_then(|v| v.as_i32())
191 .context("missing iteration")?,
192 context_json: record_get(r, "context_json")
193 .and_then(|v| v.as_str())
194 .context("missing context_json")?
195 .to_string(),
196 created_at: record_get(r, "created_at")
197 .and_then(|v| v.as_i64())
198 .context("missing created_at")?,
199 updated_at: record_get(r, "updated_at")
200 .and_then(|v| v.as_i64())
201 .context("missing updated_at")?,
202 })
203}
204
/// Flat, string-encoded snapshot of a [`Task`] plus the conversation it was saved under.
///
/// Built by `TaskMetadata::from_task` and converted back by `TaskMetadata::to_task`.
/// Each field is written to the like-named column of the tasks table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskMetadata {
    /// Copy of `Task::id`.
    pub task_id: String,
    /// Conversation the task was saved under (supplied separately from the `Task`).
    pub conversation_id: String,
    /// Copy of `Task::plan_id`.
    pub plan_id: Option<String>,
    /// Copy of `Task::description`.
    pub description: String,
    /// Lowercased `Debug` name of the task's `TaskStatus` (e.g. `"inprogress"`).
    pub status: String,
    /// Copy of `Task::parent_id`.
    pub parent_id: Option<String>,
    /// JSON-encoded `Task::children` list (presumably child task ids).
    pub children: String,
    /// JSON-encoded `Task::depends_on` list (presumably task ids).
    pub depends_on: String,
    /// Lowercased `Debug` name of the task's `TaskPriority` (e.g. `"normal"`).
    pub priority: String,
    /// Copy of `Task::assigned_to`.
    pub assigned_to: Option<String>,
    /// `Task::iterations` (a `u32`) stored via an `as i32` cast.
    pub iterations: i32,
    /// Copy of `Task::summary`.
    pub summary: Option<String>,
    /// Copy of `Task::created_at` (units as defined by `Task` — TODO confirm).
    pub created_at: i64,
    /// Copy of `Task::updated_at`.
    pub updated_at: i64,
    /// Copy of `Task::started_at`.
    pub started_at: Option<i64>,
    /// Copy of `Task::completed_at`.
    pub completed_at: Option<i64>,
}
243
244impl TaskMetadata {
245 pub fn from_task(task: &Task, conversation_id: &str) -> Self {
247 Self {
248 task_id: task.id.clone(),
249 conversation_id: conversation_id.to_string(),
250 plan_id: task.plan_id.clone(),
251 description: task.description.clone(),
252 status: format!("{:?}", task.status).to_lowercase(),
253 parent_id: task.parent_id.clone(),
254 children: serde_json::to_string(&task.children).unwrap_or_default(),
255 depends_on: serde_json::to_string(&task.depends_on).unwrap_or_default(),
256 priority: format!("{:?}", task.priority).to_lowercase(),
257 assigned_to: task.assigned_to.clone(),
258 iterations: task.iterations as i32,
259 summary: task.summary.clone(),
260 created_at: task.created_at,
261 updated_at: task.updated_at,
262 started_at: task.started_at,
263 completed_at: task.completed_at,
264 }
265 }
266
267 pub fn to_task(&self) -> Task {
269 let status = match self.status.as_str() {
270 "pending" => TaskStatus::Pending,
271 "inprogress" => TaskStatus::InProgress,
272 "completed" => TaskStatus::Completed,
273 "failed" => TaskStatus::Failed,
274 "blocked" => TaskStatus::Blocked,
275 _ => TaskStatus::Pending,
276 };
277
278 let priority = match self.priority.as_str() {
279 "low" => TaskPriority::Low,
280 "normal" => TaskPriority::Normal,
281 "high" => TaskPriority::High,
282 "urgent" => TaskPriority::Urgent,
283 _ => TaskPriority::Normal,
284 };
285
286 let children: Vec<String> = serde_json::from_str(&self.children).unwrap_or_default();
287 let depends_on: Vec<String> = serde_json::from_str(&self.depends_on).unwrap_or_default();
288
289 Task {
290 id: self.task_id.clone(),
291 description: self.description.clone(),
292 status,
293 plan_id: self.plan_id.clone(),
294 parent_id: self.parent_id.clone(),
295 children,
296 depends_on,
297 priority,
298 assigned_to: self.assigned_to.clone(),
299 iterations: self.iterations as u32,
300 summary: self.summary.clone(),
301 created_at: self.created_at,
302 updated_at: self.updated_at,
303 started_at: self.started_at,
304 completed_at: self.completed_at,
305 }
306 }
307}
308
/// Persists [`Task`]s (as `TaskMetadata` rows) in the `tasks` table of a
/// [`StorageBackend`].
///
/// The backend type parameter defaults to LanceDB's `LanceDatabase`. The
/// backend is held behind an `Arc`.
pub struct TaskStore<
    B: StorageBackend + 'static = brainwires_storage::databases::lance::LanceDatabase,
> {
    /// Shared storage backend handle.
    backend: Arc<B>,
}
317
318impl<B: StorageBackend + 'static> Clone for TaskStore<B> {
320 fn clone(&self) -> Self {
321 Self {
322 backend: Arc::clone(&self.backend),
323 }
324 }
325}
326
327impl<B: StorageBackend + 'static> TaskStore<B> {
328 pub fn new(backend: Arc<B>) -> Self {
330 Self { backend }
331 }
332
333 pub async fn ensure_table(&self) -> Result<()> {
335 self.backend
336 .ensure_table(TASK_TABLE, &tasks_field_defs())
337 .await
338 }
339
340 pub async fn save(&self, task: &Task, conversation_id: &str) -> Result<()> {
342 let metadata = TaskMetadata::from_task(task, conversation_id);
343
344 let _ = self.delete(&task.id).await;
346
347 self.backend
348 .insert(TASK_TABLE, vec![task_to_record(&metadata)])
349 .await
350 .context("Failed to save task")?;
351
352 Ok(())
353 }
354
355 pub async fn get(&self, task_id: &str) -> Result<Option<Task>> {
357 let filter = Filter::Eq(
358 "task_id".into(),
359 FieldValue::Utf8(Some(task_id.to_string())),
360 );
361 let records = self
362 .backend
363 .query(TASK_TABLE, Some(&filter), Some(1))
364 .await?;
365
366 match records.first() {
367 Some(r) => Ok(Some(task_from_record(r)?.to_task())),
368 None => Ok(None),
369 }
370 }
371
372 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<Task>> {
374 let filter = Filter::Eq(
375 "conversation_id".into(),
376 FieldValue::Utf8(Some(conversation_id.to_string())),
377 );
378 let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
379
380 records
381 .iter()
382 .map(|r| task_from_record(r).map(|m| m.to_task()))
383 .collect()
384 }
385
386 pub async fn get_by_plan(&self, plan_id: &str) -> Result<Vec<Task>> {
388 let filter = Filter::Eq(
389 "plan_id".into(),
390 FieldValue::Utf8(Some(plan_id.to_string())),
391 );
392 let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
393
394 records
395 .iter()
396 .map(|r| task_from_record(r).map(|m| m.to_task()))
397 .collect()
398 }
399
400 pub async fn delete(&self, task_id: &str) -> Result<()> {
402 let filter = Filter::Eq(
403 "task_id".into(),
404 FieldValue::Utf8(Some(task_id.to_string())),
405 );
406 self.backend
407 .delete(TASK_TABLE, &filter)
408 .await
409 .context("Failed to delete task")?;
410 Ok(())
411 }
412
413 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
415 let filter = Filter::Eq(
416 "conversation_id".into(),
417 FieldValue::Utf8(Some(conversation_id.to_string())),
418 );
419 self.backend
420 .delete(TASK_TABLE, &filter)
421 .await
422 .context("Failed to delete tasks for conversation")?;
423 Ok(())
424 }
425
426 pub async fn delete_by_plan(&self, plan_id: &str) -> Result<()> {
428 let filter = Filter::Eq(
429 "plan_id".into(),
430 FieldValue::Utf8(Some(plan_id.to_string())),
431 );
432 self.backend
433 .delete(TASK_TABLE, &filter)
434 .await
435 .context("Failed to delete tasks for plan")?;
436 Ok(())
437 }
438
439 pub fn tasks_schema() -> Vec<FieldDef> {
441 tasks_field_defs()
442 }
443
444 pub fn tasks_arrow_schema() -> Arc<arrow_schema::Schema> {
446 use arrow_schema::{DataType, Field, Schema};
447 Arc::new(Schema::new(vec![
448 Field::new("task_id", DataType::Utf8, false),
449 Field::new("conversation_id", DataType::Utf8, false),
450 Field::new("plan_id", DataType::Utf8, true),
451 Field::new("description", DataType::Utf8, false),
452 Field::new("status", DataType::Utf8, false),
453 Field::new("parent_id", DataType::Utf8, true),
454 Field::new("children", DataType::Utf8, false),
455 Field::new("depends_on", DataType::Utf8, false),
456 Field::new("priority", DataType::Utf8, false),
457 Field::new("assigned_to", DataType::Utf8, true),
458 Field::new("iterations", DataType::Int32, false),
459 Field::new("summary", DataType::Utf8, true),
460 Field::new("created_at", DataType::Int64, false),
461 Field::new("updated_at", DataType::Int64, false),
462 Field::new("started_at", DataType::Int64, true),
463 Field::new("completed_at", DataType::Int64, true),
464 ]))
465 }
466}
467
/// One persisted agent-state row of the `agent_states` table.
///
/// Fields map one-to-one to the table's columns (see `state_to_record`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AgentStateMetadata {
    /// Agent identifier; `AgentStateStore::save` replaces rows by this key.
    pub agent_id: String,
    /// Task the agent is working on.
    pub task_id: String,
    /// Conversation the state belongs to.
    pub conversation_id: String,
    /// Free-form status string (values not constrained here).
    pub status: String,
    /// Agent iteration counter.
    pub iteration: i32,
    /// Serialized context, presumably JSON as the name suggests; not parsed here.
    pub context_json: String,
    /// Creation timestamp (units not defined here — TODO confirm).
    pub created_at: i64,
    /// Last-update timestamp.
    pub updated_at: i64,
}
490
/// Persists [`AgentStateMetadata`] rows in the `agent_states` table of a
/// [`StorageBackend`].
///
/// The backend type parameter defaults to LanceDB's `LanceDatabase`.
pub struct AgentStateStore<
    B: StorageBackend + 'static = brainwires_storage::databases::lance::LanceDatabase,
> {
    /// Shared storage backend handle.
    backend: Arc<B>,
}
499
500impl<B: StorageBackend + 'static> AgentStateStore<B> {
501 pub fn new(backend: Arc<B>) -> Self {
503 Self { backend }
504 }
505
506 pub async fn ensure_table(&self) -> Result<()> {
508 self.backend
509 .ensure_table(AGENT_STATE_TABLE, &agent_states_field_defs())
510 .await
511 }
512
513 pub async fn save(&self, state: &AgentStateMetadata) -> Result<()> {
515 let _ = self.delete(&state.agent_id).await;
517
518 self.backend
519 .insert(AGENT_STATE_TABLE, vec![state_to_record(state)])
520 .await
521 .context("Failed to save agent state")?;
522
523 Ok(())
524 }
525
526 pub async fn get(&self, agent_id: &str) -> Result<Option<AgentStateMetadata>> {
528 let filter = Filter::Eq(
529 "agent_id".into(),
530 FieldValue::Utf8(Some(agent_id.to_string())),
531 );
532 let records = self
533 .backend
534 .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
535 .await?;
536
537 match records.first() {
538 Some(r) => Ok(Some(state_from_record(r)?)),
539 None => Ok(None),
540 }
541 }
542
543 pub async fn get_by_conversation(
545 &self,
546 conversation_id: &str,
547 ) -> Result<Vec<AgentStateMetadata>> {
548 let filter = Filter::Eq(
549 "conversation_id".into(),
550 FieldValue::Utf8(Some(conversation_id.to_string())),
551 );
552 let records = self
553 .backend
554 .query(AGENT_STATE_TABLE, Some(&filter), None)
555 .await?;
556
557 records.iter().map(state_from_record).collect()
558 }
559
560 pub async fn get_by_task(&self, task_id: &str) -> Result<Option<AgentStateMetadata>> {
562 let filter = Filter::Eq(
563 "task_id".into(),
564 FieldValue::Utf8(Some(task_id.to_string())),
565 );
566 let records = self
567 .backend
568 .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
569 .await?;
570
571 match records.first() {
572 Some(r) => Ok(Some(state_from_record(r)?)),
573 None => Ok(None),
574 }
575 }
576
577 pub async fn delete(&self, agent_id: &str) -> Result<()> {
579 let filter = Filter::Eq(
580 "agent_id".into(),
581 FieldValue::Utf8(Some(agent_id.to_string())),
582 );
583 self.backend
584 .delete(AGENT_STATE_TABLE, &filter)
585 .await
586 .context("Failed to delete agent state")?;
587 Ok(())
588 }
589
590 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
592 let filter = Filter::Eq(
593 "conversation_id".into(),
594 FieldValue::Utf8(Some(conversation_id.to_string())),
595 );
596 self.backend
597 .delete(AGENT_STATE_TABLE, &filter)
598 .await
599 .context("Failed to delete agent states for conversation")?;
600 Ok(())
601 }
602
603 pub fn agent_states_schema() -> Vec<FieldDef> {
605 agent_states_field_defs()
606 }
607
608 pub fn agent_states_arrow_schema() -> Arc<arrow_schema::Schema> {
610 use arrow_schema::{DataType, Field, Schema};
611 Arc::new(Schema::new(vec![
612 Field::new("agent_id", DataType::Utf8, false),
613 Field::new("task_id", DataType::Utf8, false),
614 Field::new("conversation_id", DataType::Utf8, false),
615 Field::new("status", DataType::Utf8, false),
616 Field::new("iteration", DataType::Int32, false),
617 Field::new("context_json", DataType::Utf8, false),
618 Field::new("created_at", DataType::Int64, false),
619 Field::new("updated_at", DataType::Int64, false),
620 ]))
621 }
622}