1use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use crate::databases::{
9 FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
10};
11use brainwires_core::{Task, TaskPriority, TaskStatus};
12
/// Backend table that holds task rows (columns defined by `tasks_field_defs`).
const TASK_TABLE: &str = "tasks";
/// Backend table that holds agent-state rows (columns defined by `agent_states_field_defs`).
const AGENT_STATE_TABLE: &str = "agent_states";
15
16fn tasks_field_defs() -> Vec<FieldDef> {
19 vec![
20 FieldDef::required("task_id", FieldType::Utf8),
21 FieldDef::required("conversation_id", FieldType::Utf8),
22 FieldDef::optional("plan_id", FieldType::Utf8),
23 FieldDef::required("description", FieldType::Utf8),
24 FieldDef::required("status", FieldType::Utf8),
25 FieldDef::optional("parent_id", FieldType::Utf8),
26 FieldDef::required("children", FieldType::Utf8), FieldDef::required("depends_on", FieldType::Utf8), FieldDef::required("priority", FieldType::Utf8),
29 FieldDef::optional("assigned_to", FieldType::Utf8),
30 FieldDef::required("iterations", FieldType::Int32),
31 FieldDef::optional("summary", FieldType::Utf8),
32 FieldDef::required("created_at", FieldType::Int64),
33 FieldDef::required("updated_at", FieldType::Int64),
34 FieldDef::optional("started_at", FieldType::Int64),
35 FieldDef::optional("completed_at", FieldType::Int64),
36 ]
37}
38
39fn agent_states_field_defs() -> Vec<FieldDef> {
40 vec![
41 FieldDef::required("agent_id", FieldType::Utf8),
42 FieldDef::required("task_id", FieldType::Utf8),
43 FieldDef::required("conversation_id", FieldType::Utf8),
44 FieldDef::required("status", FieldType::Utf8),
45 FieldDef::required("iteration", FieldType::Int32),
46 FieldDef::required("context_json", FieldType::Utf8),
47 FieldDef::required("created_at", FieldType::Int64),
48 FieldDef::required("updated_at", FieldType::Int64),
49 ]
50}
51
52fn task_to_record(m: &TaskMetadata) -> Record {
55 vec![
56 ("task_id".into(), FieldValue::Utf8(Some(m.task_id.clone()))),
57 (
58 "conversation_id".into(),
59 FieldValue::Utf8(Some(m.conversation_id.clone())),
60 ),
61 ("plan_id".into(), FieldValue::Utf8(m.plan_id.clone())),
62 (
63 "description".into(),
64 FieldValue::Utf8(Some(m.description.clone())),
65 ),
66 ("status".into(), FieldValue::Utf8(Some(m.status.clone()))),
67 ("parent_id".into(), FieldValue::Utf8(m.parent_id.clone())),
68 (
69 "children".into(),
70 FieldValue::Utf8(Some(m.children.clone())),
71 ),
72 (
73 "depends_on".into(),
74 FieldValue::Utf8(Some(m.depends_on.clone())),
75 ),
76 (
77 "priority".into(),
78 FieldValue::Utf8(Some(m.priority.clone())),
79 ),
80 (
81 "assigned_to".into(),
82 FieldValue::Utf8(m.assigned_to.clone()),
83 ),
84 ("iterations".into(), FieldValue::Int32(Some(m.iterations))),
85 ("summary".into(), FieldValue::Utf8(m.summary.clone())),
86 ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
87 ("updated_at".into(), FieldValue::Int64(Some(m.updated_at))),
88 ("started_at".into(), FieldValue::Int64(m.started_at)),
89 ("completed_at".into(), FieldValue::Int64(m.completed_at)),
90 ]
91}
92
93fn task_from_record(r: &Record) -> Result<TaskMetadata> {
94 Ok(TaskMetadata {
95 task_id: record_get(r, "task_id")
96 .and_then(|v| v.as_str())
97 .context("missing task_id")?
98 .to_string(),
99 conversation_id: record_get(r, "conversation_id")
100 .and_then(|v| v.as_str())
101 .context("missing conversation_id")?
102 .to_string(),
103 plan_id: record_get(r, "plan_id")
104 .and_then(|v| v.as_str())
105 .map(String::from),
106 description: record_get(r, "description")
107 .and_then(|v| v.as_str())
108 .context("missing description")?
109 .to_string(),
110 status: record_get(r, "status")
111 .and_then(|v| v.as_str())
112 .context("missing status")?
113 .to_string(),
114 parent_id: record_get(r, "parent_id")
115 .and_then(|v| v.as_str())
116 .map(String::from),
117 children: record_get(r, "children")
118 .and_then(|v| v.as_str())
119 .context("missing children")?
120 .to_string(),
121 depends_on: record_get(r, "depends_on")
122 .and_then(|v| v.as_str())
123 .context("missing depends_on")?
124 .to_string(),
125 priority: record_get(r, "priority")
126 .and_then(|v| v.as_str())
127 .context("missing priority")?
128 .to_string(),
129 assigned_to: record_get(r, "assigned_to")
130 .and_then(|v| v.as_str())
131 .map(String::from),
132 iterations: record_get(r, "iterations")
133 .and_then(|v| v.as_i32())
134 .context("missing iterations")?,
135 summary: record_get(r, "summary")
136 .and_then(|v| v.as_str())
137 .map(String::from),
138 created_at: record_get(r, "created_at")
139 .and_then(|v| v.as_i64())
140 .context("missing created_at")?,
141 updated_at: record_get(r, "updated_at")
142 .and_then(|v| v.as_i64())
143 .context("missing updated_at")?,
144 started_at: record_get(r, "started_at").and_then(|v| v.as_i64()),
145 completed_at: record_get(r, "completed_at").and_then(|v| v.as_i64()),
146 })
147}
148
149fn state_to_record(s: &AgentStateMetadata) -> Record {
150 vec![
151 (
152 "agent_id".into(),
153 FieldValue::Utf8(Some(s.agent_id.clone())),
154 ),
155 ("task_id".into(), FieldValue::Utf8(Some(s.task_id.clone()))),
156 (
157 "conversation_id".into(),
158 FieldValue::Utf8(Some(s.conversation_id.clone())),
159 ),
160 ("status".into(), FieldValue::Utf8(Some(s.status.clone()))),
161 ("iteration".into(), FieldValue::Int32(Some(s.iteration))),
162 (
163 "context_json".into(),
164 FieldValue::Utf8(Some(s.context_json.clone())),
165 ),
166 ("created_at".into(), FieldValue::Int64(Some(s.created_at))),
167 ("updated_at".into(), FieldValue::Int64(Some(s.updated_at))),
168 ]
169}
170
171fn state_from_record(r: &Record) -> Result<AgentStateMetadata> {
172 Ok(AgentStateMetadata {
173 agent_id: record_get(r, "agent_id")
174 .and_then(|v| v.as_str())
175 .context("missing agent_id")?
176 .to_string(),
177 task_id: record_get(r, "task_id")
178 .and_then(|v| v.as_str())
179 .context("missing task_id")?
180 .to_string(),
181 conversation_id: record_get(r, "conversation_id")
182 .and_then(|v| v.as_str())
183 .context("missing conversation_id")?
184 .to_string(),
185 status: record_get(r, "status")
186 .and_then(|v| v.as_str())
187 .context("missing status")?
188 .to_string(),
189 iteration: record_get(r, "iteration")
190 .and_then(|v| v.as_i32())
191 .context("missing iteration")?,
192 context_json: record_get(r, "context_json")
193 .and_then(|v| v.as_str())
194 .context("missing context_json")?
195 .to_string(),
196 created_at: record_get(r, "created_at")
197 .and_then(|v| v.as_i64())
198 .context("missing created_at")?,
199 updated_at: record_get(r, "updated_at")
200 .and_then(|v| v.as_i64())
201 .context("missing updated_at")?,
202 })
203}
204
/// Flat, string/integer-only representation of a [`Task`] as stored in the
/// tasks table.
///
/// Built with `TaskMetadata::from_task` and converted back with
/// `TaskMetadata::to_task`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskMetadata {
    /// Copy of `Task::id`.
    pub task_id: String,
    /// Conversation the task was saved under (supplied by the caller of `from_task`).
    pub conversation_id: String,
    /// Copy of `Task::plan_id`.
    pub plan_id: Option<String>,
    /// Copy of `Task::description`.
    pub description: String,
    /// Lower-cased `Debug` name of the `TaskStatus` variant (e.g. `"inprogress"`).
    pub status: String,
    /// Copy of `Task::parent_id`.
    pub parent_id: Option<String>,
    /// `Task::children` encoded as a JSON array of strings.
    pub children: String,
    /// `Task::depends_on` encoded as a JSON array of strings.
    pub depends_on: String,
    /// Lower-cased `Debug` name of the `TaskPriority` variant (e.g. `"high"`).
    pub priority: String,
    /// Copy of `Task::assigned_to`.
    pub assigned_to: Option<String>,
    /// `Task::iterations` stored in a signed 32-bit column.
    pub iterations: i32,
    /// Copy of `Task::summary`.
    pub summary: Option<String>,
    /// Copy of `Task::created_at` (unit defined by `Task`).
    pub created_at: i64,
    /// Copy of `Task::updated_at` (unit defined by `Task`).
    pub updated_at: i64,
    /// Copy of `Task::started_at`.
    pub started_at: Option<i64>,
    /// Copy of `Task::completed_at`.
    pub completed_at: Option<i64>,
}
243
244impl TaskMetadata {
245 pub fn from_task(task: &Task, conversation_id: &str) -> Self {
247 Self {
248 task_id: task.id.clone(),
249 conversation_id: conversation_id.to_string(),
250 plan_id: task.plan_id.clone(),
251 description: task.description.clone(),
252 status: format!("{:?}", task.status).to_lowercase(),
253 parent_id: task.parent_id.clone(),
254 children: serde_json::to_string(&task.children).unwrap_or_default(),
255 depends_on: serde_json::to_string(&task.depends_on).unwrap_or_default(),
256 priority: format!("{:?}", task.priority).to_lowercase(),
257 assigned_to: task.assigned_to.clone(),
258 iterations: task.iterations as i32,
259 summary: task.summary.clone(),
260 created_at: task.created_at,
261 updated_at: task.updated_at,
262 started_at: task.started_at,
263 completed_at: task.completed_at,
264 }
265 }
266
267 pub fn to_task(&self) -> Task {
269 let status = match self.status.as_str() {
270 "pending" => TaskStatus::Pending,
271 "inprogress" => TaskStatus::InProgress,
272 "completed" => TaskStatus::Completed,
273 "failed" => TaskStatus::Failed,
274 "blocked" => TaskStatus::Blocked,
275 _ => TaskStatus::Pending,
276 };
277
278 let priority = match self.priority.as_str() {
279 "low" => TaskPriority::Low,
280 "normal" => TaskPriority::Normal,
281 "high" => TaskPriority::High,
282 "urgent" => TaskPriority::Urgent,
283 _ => TaskPriority::Normal,
284 };
285
286 let children: Vec<String> = serde_json::from_str(&self.children).unwrap_or_default();
287 let depends_on: Vec<String> = serde_json::from_str(&self.depends_on).unwrap_or_default();
288
289 Task {
290 id: self.task_id.clone(),
291 description: self.description.clone(),
292 status,
293 plan_id: self.plan_id.clone(),
294 parent_id: self.parent_id.clone(),
295 children,
296 depends_on,
297 priority,
298 assigned_to: self.assigned_to.clone(),
299 iterations: self.iterations as u32,
300 summary: self.summary.clone(),
301 created_at: self.created_at,
302 updated_at: self.updated_at,
303 started_at: self.started_at,
304 completed_at: self.completed_at,
305 }
306 }
307}
308
309#[derive(Clone)]
313pub struct TaskStore<B: StorageBackend + 'static = crate::databases::lance::LanceDatabase> {
314 backend: Arc<B>,
315}
316
317impl<B: StorageBackend + 'static> TaskStore<B> {
318 pub fn new(backend: Arc<B>) -> Self {
320 Self { backend }
321 }
322
323 pub async fn ensure_table(&self) -> Result<()> {
325 self.backend
326 .ensure_table(TASK_TABLE, &tasks_field_defs())
327 .await
328 }
329
330 pub async fn save(&self, task: &Task, conversation_id: &str) -> Result<()> {
332 let metadata = TaskMetadata::from_task(task, conversation_id);
333
334 let _ = self.delete(&task.id).await;
336
337 self.backend
338 .insert(TASK_TABLE, vec![task_to_record(&metadata)])
339 .await
340 .context("Failed to save task")?;
341
342 Ok(())
343 }
344
345 pub async fn get(&self, task_id: &str) -> Result<Option<Task>> {
347 let filter = Filter::Eq(
348 "task_id".into(),
349 FieldValue::Utf8(Some(task_id.to_string())),
350 );
351 let records = self
352 .backend
353 .query(TASK_TABLE, Some(&filter), Some(1))
354 .await?;
355
356 match records.first() {
357 Some(r) => Ok(Some(task_from_record(r)?.to_task())),
358 None => Ok(None),
359 }
360 }
361
362 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<Task>> {
364 let filter = Filter::Eq(
365 "conversation_id".into(),
366 FieldValue::Utf8(Some(conversation_id.to_string())),
367 );
368 let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
369
370 records
371 .iter()
372 .map(|r| task_from_record(r).map(|m| m.to_task()))
373 .collect()
374 }
375
376 pub async fn get_by_plan(&self, plan_id: &str) -> Result<Vec<Task>> {
378 let filter = Filter::Eq(
379 "plan_id".into(),
380 FieldValue::Utf8(Some(plan_id.to_string())),
381 );
382 let records = self.backend.query(TASK_TABLE, Some(&filter), None).await?;
383
384 records
385 .iter()
386 .map(|r| task_from_record(r).map(|m| m.to_task()))
387 .collect()
388 }
389
390 pub async fn delete(&self, task_id: &str) -> Result<()> {
392 let filter = Filter::Eq(
393 "task_id".into(),
394 FieldValue::Utf8(Some(task_id.to_string())),
395 );
396 self.backend
397 .delete(TASK_TABLE, &filter)
398 .await
399 .context("Failed to delete task")?;
400 Ok(())
401 }
402
403 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
405 let filter = Filter::Eq(
406 "conversation_id".into(),
407 FieldValue::Utf8(Some(conversation_id.to_string())),
408 );
409 self.backend
410 .delete(TASK_TABLE, &filter)
411 .await
412 .context("Failed to delete tasks for conversation")?;
413 Ok(())
414 }
415
416 pub async fn delete_by_plan(&self, plan_id: &str) -> Result<()> {
418 let filter = Filter::Eq(
419 "plan_id".into(),
420 FieldValue::Utf8(Some(plan_id.to_string())),
421 );
422 self.backend
423 .delete(TASK_TABLE, &filter)
424 .await
425 .context("Failed to delete tasks for plan")?;
426 Ok(())
427 }
428
429 pub fn tasks_schema() -> Vec<FieldDef> {
431 tasks_field_defs()
432 }
433
434 #[cfg(feature = "native")]
436 pub fn tasks_arrow_schema() -> Arc<arrow_schema::Schema> {
437 use arrow_schema::{DataType, Field, Schema};
438 Arc::new(Schema::new(vec![
439 Field::new("task_id", DataType::Utf8, false),
440 Field::new("conversation_id", DataType::Utf8, false),
441 Field::new("plan_id", DataType::Utf8, true),
442 Field::new("description", DataType::Utf8, false),
443 Field::new("status", DataType::Utf8, false),
444 Field::new("parent_id", DataType::Utf8, true),
445 Field::new("children", DataType::Utf8, false),
446 Field::new("depends_on", DataType::Utf8, false),
447 Field::new("priority", DataType::Utf8, false),
448 Field::new("assigned_to", DataType::Utf8, true),
449 Field::new("iterations", DataType::Int32, false),
450 Field::new("summary", DataType::Utf8, true),
451 Field::new("created_at", DataType::Int64, false),
452 Field::new("updated_at", DataType::Int64, false),
453 Field::new("started_at", DataType::Int64, true),
454 Field::new("completed_at", DataType::Int64, true),
455 ]))
456 }
457}
458
/// One persisted agent-state row, stored in the agent-states table via
/// `AgentStateStore`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AgentStateMetadata {
    /// Agent identifier; `AgentStateStore::save` replaces rows by this key.
    pub agent_id: String,
    /// Task the agent state belongs to (queried by `AgentStateStore::get_by_task`).
    pub task_id: String,
    /// Conversation the agent state belongs to.
    pub conversation_id: String,
    /// Status string; its encoding is not defined in this file.
    pub status: String,
    /// Iteration counter stored in a signed 32-bit column.
    pub iteration: i32,
    /// Serialized context. The name suggests JSON, but it is not validated
    /// here — TODO confirm with callers.
    pub context_json: String,
    /// Creation timestamp (unit defined by callers).
    pub created_at: i64,
    /// Last-update timestamp (unit defined by callers).
    pub updated_at: i64,
}
481
482pub struct AgentStateStore<B: StorageBackend + 'static = crate::databases::lance::LanceDatabase> {
486 backend: Arc<B>,
487}
488
489impl<B: StorageBackend + 'static> AgentStateStore<B> {
490 pub fn new(backend: Arc<B>) -> Self {
492 Self { backend }
493 }
494
495 pub async fn ensure_table(&self) -> Result<()> {
497 self.backend
498 .ensure_table(AGENT_STATE_TABLE, &agent_states_field_defs())
499 .await
500 }
501
502 pub async fn save(&self, state: &AgentStateMetadata) -> Result<()> {
504 let _ = self.delete(&state.agent_id).await;
506
507 self.backend
508 .insert(AGENT_STATE_TABLE, vec![state_to_record(state)])
509 .await
510 .context("Failed to save agent state")?;
511
512 Ok(())
513 }
514
515 pub async fn get(&self, agent_id: &str) -> Result<Option<AgentStateMetadata>> {
517 let filter = Filter::Eq(
518 "agent_id".into(),
519 FieldValue::Utf8(Some(agent_id.to_string())),
520 );
521 let records = self
522 .backend
523 .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
524 .await?;
525
526 match records.first() {
527 Some(r) => Ok(Some(state_from_record(r)?)),
528 None => Ok(None),
529 }
530 }
531
532 pub async fn get_by_conversation(
534 &self,
535 conversation_id: &str,
536 ) -> Result<Vec<AgentStateMetadata>> {
537 let filter = Filter::Eq(
538 "conversation_id".into(),
539 FieldValue::Utf8(Some(conversation_id.to_string())),
540 );
541 let records = self
542 .backend
543 .query(AGENT_STATE_TABLE, Some(&filter), None)
544 .await?;
545
546 records.iter().map(state_from_record).collect()
547 }
548
549 pub async fn get_by_task(&self, task_id: &str) -> Result<Option<AgentStateMetadata>> {
551 let filter = Filter::Eq(
552 "task_id".into(),
553 FieldValue::Utf8(Some(task_id.to_string())),
554 );
555 let records = self
556 .backend
557 .query(AGENT_STATE_TABLE, Some(&filter), Some(1))
558 .await?;
559
560 match records.first() {
561 Some(r) => Ok(Some(state_from_record(r)?)),
562 None => Ok(None),
563 }
564 }
565
566 pub async fn delete(&self, agent_id: &str) -> Result<()> {
568 let filter = Filter::Eq(
569 "agent_id".into(),
570 FieldValue::Utf8(Some(agent_id.to_string())),
571 );
572 self.backend
573 .delete(AGENT_STATE_TABLE, &filter)
574 .await
575 .context("Failed to delete agent state")?;
576 Ok(())
577 }
578
579 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
581 let filter = Filter::Eq(
582 "conversation_id".into(),
583 FieldValue::Utf8(Some(conversation_id.to_string())),
584 );
585 self.backend
586 .delete(AGENT_STATE_TABLE, &filter)
587 .await
588 .context("Failed to delete agent states for conversation")?;
589 Ok(())
590 }
591
592 pub fn agent_states_schema() -> Vec<FieldDef> {
594 agent_states_field_defs()
595 }
596
597 #[cfg(feature = "native")]
599 pub fn agent_states_arrow_schema() -> Arc<arrow_schema::Schema> {
600 use arrow_schema::{DataType, Field, Schema};
601 Arc::new(Schema::new(vec![
602 Field::new("agent_id", DataType::Utf8, false),
603 Field::new("task_id", DataType::Utf8, false),
604 Field::new("conversation_id", DataType::Utf8, false),
605 Field::new("status", DataType::Utf8, false),
606 Field::new("iteration", DataType::Int32, false),
607 Field::new("context_json", DataType::Utf8, false),
608 Field::new("created_at", DataType::Int64, false),
609 Field::new("updated_at", DataType::Int64, false),
610 ]))
611 }
612}