1use crate::{DatabasePool, Result, StorageError};
4use oxify_model::{ExecutionContext, ExecutionState, WorkflowId};
5use sqlx::Row;
6use uuid::Uuid;
7
/// Maximum number of variables a single execution context may carry.
///
/// `create`, `batch_create` and `update` compare `ctx.variables.len()` against
/// this limit and reject larger contexts with `StorageError::ValidationError`.
const MAX_VARIABLES: usize = 1000;
10
/// Data-access object for rows of the `executions` table.
///
/// Cloning the store clones the wrapped [`DatabasePool`] handle.
#[derive(Clone)]
pub struct ExecutionStore {
    /// Pool every query of this store is executed against.
    pool: DatabasePool,
}
16
17impl ExecutionStore {
18 pub fn new(pool: DatabasePool) -> Self {
20 Self { pool }
21 }
22
23 #[tracing::instrument(skip(self, ctx), fields(execution_id = %ctx.execution_id, workflow_id = %ctx.workflow_id))]
25 pub async fn create(&self, ctx: &ExecutionContext) -> Result<Uuid> {
26 if ctx.variables.len() > MAX_VARIABLES {
28 return Err(StorageError::ValidationError(format!(
29 "Execution has {} variables, which exceeds the maximum of {}",
30 ctx.variables.len(),
31 MAX_VARIABLES
32 )));
33 }
34
35 let id = ctx.execution_id.to_string();
36 let workflow_id = ctx.workflow_id.to_string();
37 let started_at = ctx.started_at.to_rfc3339();
38 let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
39 let state = format!("{:?}", ctx.state);
40 let context_json = serde_json::to_string(ctx)?;
41 let node_results = serde_json::to_string(&ctx.node_results)?;
42 let variables = serde_json::to_string(&ctx.variables)?;
43
44 sqlx::query(
45 r#"
46 INSERT INTO executions (id, workflow_id, started_at, completed_at, state, context, node_results, variables)
47 VALUES (?, ?, ?, ?, ?, ?, ?, ?)
48 "#,
49 )
50 .bind(&id)
51 .bind(&workflow_id)
52 .bind(Some(&started_at))
53 .bind(&completed_at)
54 .bind(&state)
55 .bind(&context_json)
56 .bind(&node_results)
57 .bind(&variables)
58 .execute(self.pool.pool())
59 .await?;
60
61 Ok(ctx.execution_id)
62 }
63
64 #[tracing::instrument(skip(self, contexts), fields(batch_size = contexts.len()))]
71 pub async fn batch_create(&self, contexts: &[ExecutionContext]) -> Result<u64> {
72 if contexts.is_empty() {
73 return Ok(0);
74 }
75
76 for ctx in contexts {
78 if ctx.variables.len() > MAX_VARIABLES {
79 return Err(StorageError::ValidationError(format!(
80 "Execution {} has {} variables, which exceeds the maximum of {}",
81 ctx.execution_id,
82 ctx.variables.len(),
83 MAX_VARIABLES
84 )));
85 }
86 }
87
88 let mut tx = self.pool.pool().begin().await?;
89
90 for ctx in contexts {
91 let id = ctx.execution_id.to_string();
92 let workflow_id = ctx.workflow_id.to_string();
93 let started_at = ctx.started_at.to_rfc3339();
94 let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
95 let state = format!("{:?}", ctx.state);
96 let context_json = serde_json::to_string(ctx)?;
97 let node_results = serde_json::to_string(&ctx.node_results)?;
98 let variables = serde_json::to_string(&ctx.variables)?;
99
100 sqlx::query(
101 r#"
102 INSERT INTO executions (id, workflow_id, started_at, completed_at, state, context, node_results, variables)
103 VALUES (?, ?, ?, ?, ?, ?, ?, ?)
104 "#,
105 )
106 .bind(&id)
107 .bind(&workflow_id)
108 .bind(Some(&started_at))
109 .bind(&completed_at)
110 .bind(&state)
111 .bind(&context_json)
112 .bind(&node_results)
113 .bind(&variables)
114 .execute(&mut *tx)
115 .await?;
116 }
117
118 tx.commit().await?;
119
120 Ok(contexts.len() as u64)
121 }
122
123 #[tracing::instrument(skip(self), fields(execution_id = %id))]
125 pub async fn get(&self, id: &Uuid) -> Result<Option<ExecutionContext>> {
126 let id_str = id.to_string();
127 let row = sqlx::query(
128 r#"
129 SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
130 FROM executions
131 WHERE id = ?
132 "#,
133 )
134 .bind(&id_str)
135 .fetch_optional(self.pool.pool())
136 .await?;
137
138 match row {
139 Some(row) => {
140 let context_str: String = row.get("context");
141 let ctx: ExecutionContext = serde_json::from_str(&context_str)?;
142 Ok(Some(ctx))
143 }
144 None => Ok(None),
145 }
146 }
147
148 pub async fn list(&self) -> Result<Vec<(Uuid, ExecutionContext)>> {
150 let rows = sqlx::query(
151 r#"
152 SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
153 FROM executions
154 ORDER BY started_at DESC
155 "#,
156 )
157 .fetch_all(self.pool.pool())
158 .await?;
159
160 let executions: Vec<(Uuid, ExecutionContext)> = rows
161 .into_iter()
162 .filter_map(|row| {
163 let id_str: String = row.get("id");
164 let context_str: String = row.get("context");
165 let id = Uuid::parse_str(&id_str).ok()?;
166 let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
167 Some((id, ctx))
168 })
169 .collect();
170
171 Ok(executions)
172 }
173
174 pub async fn list_by_workflow(
176 &self,
177 workflow_id: &WorkflowId,
178 ) -> Result<Vec<(Uuid, ExecutionContext)>> {
179 let workflow_id_str = workflow_id.to_string();
180 let rows = sqlx::query(
181 r#"
182 SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
183 FROM executions
184 WHERE workflow_id = ?
185 ORDER BY started_at DESC
186 "#,
187 )
188 .bind(&workflow_id_str)
189 .fetch_all(self.pool.pool())
190 .await?;
191
192 let executions: Vec<(Uuid, ExecutionContext)> = rows
193 .into_iter()
194 .filter_map(|row| {
195 let id_str: String = row.get("id");
196 let context_str: String = row.get("context");
197 let id = Uuid::parse_str(&id_str).ok()?;
198 let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
199 Some((id, ctx))
200 })
201 .collect();
202
203 Ok(executions)
204 }
205
206 pub async fn list_paginated(
208 &self,
209 limit: i64,
210 offset: i64,
211 ) -> Result<Vec<(Uuid, ExecutionContext)>> {
212 let rows = sqlx::query(
213 r#"
214 SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
215 FROM executions
216 ORDER BY started_at DESC
217 LIMIT ? OFFSET ?
218 "#,
219 )
220 .bind(limit)
221 .bind(offset)
222 .fetch_all(self.pool.pool())
223 .await?;
224
225 let executions: Vec<(Uuid, ExecutionContext)> = rows
226 .into_iter()
227 .filter_map(|row| {
228 let id_str: String = row.get("id");
229 let context_str: String = row.get("context");
230 let id = Uuid::parse_str(&id_str).ok()?;
231 let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
232 Some((id, ctx))
233 })
234 .collect();
235
236 Ok(executions)
237 }
238
239 #[tracing::instrument(skip(self, ctx), fields(execution_id = %id, new_state = ?ctx.state))]
241 pub async fn update(&self, id: &Uuid, ctx: &ExecutionContext) -> Result<bool> {
242 if ctx.variables.len() > MAX_VARIABLES {
244 return Err(StorageError::ValidationError(format!(
245 "Execution has {} variables, which exceeds the maximum of {}",
246 ctx.variables.len(),
247 MAX_VARIABLES
248 )));
249 }
250
251 let id_str = id.to_string();
252 let state = format!("{:?}", ctx.state);
253 let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
254 let context_json = serde_json::to_string(ctx)?;
255 let node_results = serde_json::to_string(&ctx.node_results)?;
256 let variables = serde_json::to_string(&ctx.variables)?;
257
258 let error_message = match &ctx.state {
260 ExecutionState::Failed(msg) => Some(msg.clone()),
261 _ => None,
262 };
263
264 let result = sqlx::query(
265 r#"
266 UPDATE executions
267 SET completed_at = ?, state = ?, context = ?, node_results = ?, variables = ?, error_message = ?
268 WHERE id = ?
269 "#,
270 )
271 .bind(&completed_at)
272 .bind(&state)
273 .bind(&context_json)
274 .bind(&node_results)
275 .bind(&variables)
276 .bind(&error_message)
277 .bind(&id_str)
278 .execute(self.pool.pool())
279 .await?;
280
281 Ok(result.rows_affected() > 0)
282 }
283
284 #[tracing::instrument(skip(self), fields(execution_id = %id))]
286 pub async fn delete(&self, id: &Uuid) -> Result<bool> {
287 let id_str = id.to_string();
288 let result = sqlx::query(
289 r#"
290 DELETE FROM executions
291 WHERE id = ?
292 "#,
293 )
294 .bind(&id_str)
295 .execute(self.pool.pool())
296 .await?;
297
298 Ok(result.rows_affected() > 0)
299 }
300
301 pub async fn count_by_state(&self, state: &str) -> Result<i64> {
303 let row = sqlx::query(
304 r#"
305 SELECT COUNT(*) as count
306 FROM executions
307 WHERE state = ?
308 "#,
309 )
310 .bind(state)
311 .fetch_one(self.pool.pool())
312 .await?;
313
314 let count: i64 = row.get("count");
315 Ok(count)
316 }
317
318 pub async fn get_active(&self) -> Result<Vec<(Uuid, ExecutionContext)>> {
320 let rows = sqlx::query(
321 r#"
322 SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
323 FROM executions
324 WHERE state IN ('Running', 'Paused')
325 ORDER BY started_at DESC
326 "#,
327 )
328 .fetch_all(self.pool.pool())
329 .await?;
330
331 let executions: Vec<(Uuid, ExecutionContext)> = rows
332 .into_iter()
333 .filter_map(|row| {
334 let id_str: String = row.get("id");
335 let context_str: String = row.get("context");
336 let id = Uuid::parse_str(&id_str).ok()?;
337 let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
338 Some((id, ctx))
339 })
340 .collect();
341
342 Ok(executions)
343 }
344
345 #[tracing::instrument(skip(self), fields(workflow_id = %workflow_id))]
348 pub async fn delete_by_workflow(&self, workflow_id: &WorkflowId) -> Result<u64> {
349 let workflow_id_str = workflow_id.to_string();
350 let result = sqlx::query(
351 r#"
352 DELETE FROM executions WHERE workflow_id = ?
353 "#,
354 )
355 .bind(&workflow_id_str)
356 .execute(self.pool.pool())
357 .await?;
358
359 Ok(result.rows_affected())
360 }
361
362 #[tracing::instrument(skip(self), fields(before = %before))]
365 pub async fn archive_completed(&self, before: chrono::DateTime<chrono::Utc>) -> Result<u64> {
366 let before_str = before.to_rfc3339();
367 let result = sqlx::query(
368 r#"
369 DELETE FROM executions
370 WHERE completed_at IS NOT NULL
371 AND completed_at < ?
372 "#,
373 )
374 .bind(&before_str)
375 .execute(self.pool.pool())
376 .await?;
377
378 Ok(result.rows_affected())
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use oxify_model::ExecutionContext;

    /// Builds a pool from `DATABASE_URL`, falling back to an in-memory SQLite
    /// URL when the variable is not set.
    async fn setup_test_pool() -> Result<DatabasePool> {
        let config = crate::DatabaseConfig {
            database_url: std::env::var("DATABASE_URL")
                .unwrap_or_else(|_| "sqlite::memory:".to_string()),
            ..Default::default()
        };
        DatabasePool::new(config).await
    }

    /// Walks one execution through create, get, update and delete, then checks
    /// that the row is really gone.
    // NOTE(review): presumably ignored because it needs a reachable, migratable
    // database — confirm before enabling it in CI.
    #[tokio::test]
    #[ignore]
    async fn test_execution_crud() -> Result<()> {
        let pool = setup_test_pool().await?;
        pool.migrate().await?;

        let store = ExecutionStore::new(pool);

        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        let id = store.create(&ctx).await?;
        assert_eq!(id, ctx.execution_id);

        let fetched = store.get(&id).await?;
        assert!(fetched.is_some());

        ctx.state = ExecutionState::Completed;
        ctx.mark_completed();
        let result = store.update(&id, &ctx).await?;
        assert!(result);

        let result = store.delete(&id).await?;
        assert!(result);

        // After the delete the row must not be readable, and deleting it a
        // second time must report that nothing was removed.
        assert!(store.get(&id).await?.is_none());
        assert!(!store.delete(&id).await?);

        Ok(())
    }
}