// oxify_storage/execution_store.rs

1//! Execution storage implementation for SQLite
2
3use crate::{DatabasePool, Result, StorageError};
4use oxify_model::{ExecutionContext, ExecutionState, WorkflowId};
5use sqlx::Row;
6use uuid::Uuid;
7
/// Maximum number of variables allowed in an execution context.
/// Guards against serializing unbounded JSON payloads into the database.
const MAX_VARIABLES: usize = 1000;
10
/// Execution storage layer: CRUD and query operations for workflow
/// execution records persisted in SQLite via sqlx.
#[derive(Clone)]
pub struct ExecutionStore {
    // Shared database handle; cloning the store clones this pool handle.
    pool: DatabasePool,
}
16
17impl ExecutionStore {
18    /// Create a new execution store
19    pub fn new(pool: DatabasePool) -> Self {
20        Self { pool }
21    }
22
23    /// Create a new execution record
24    #[tracing::instrument(skip(self, ctx), fields(execution_id = %ctx.execution_id, workflow_id = %ctx.workflow_id))]
25    pub async fn create(&self, ctx: &ExecutionContext) -> Result<Uuid> {
26        // Validate variable count
27        if ctx.variables.len() > MAX_VARIABLES {
28            return Err(StorageError::ValidationError(format!(
29                "Execution has {} variables, which exceeds the maximum of {}",
30                ctx.variables.len(),
31                MAX_VARIABLES
32            )));
33        }
34
35        let id = ctx.execution_id.to_string();
36        let workflow_id = ctx.workflow_id.to_string();
37        let started_at = ctx.started_at.to_rfc3339();
38        let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
39        let state = format!("{:?}", ctx.state);
40        let context_json = serde_json::to_string(ctx)?;
41        let node_results = serde_json::to_string(&ctx.node_results)?;
42        let variables = serde_json::to_string(&ctx.variables)?;
43
44        sqlx::query(
45            r#"
46            INSERT INTO executions (id, workflow_id, started_at, completed_at, state, context, node_results, variables)
47            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
48            "#,
49        )
50        .bind(&id)
51        .bind(&workflow_id)
52        .bind(Some(&started_at))
53        .bind(&completed_at)
54        .bind(&state)
55        .bind(&context_json)
56        .bind(&node_results)
57        .bind(&variables)
58        .execute(self.pool.pool())
59        .await?;
60
61        Ok(ctx.execution_id)
62    }
63
64    /// Batch create multiple execution records
65    ///
66    /// This is more efficient than calling `create()` multiple times
67    /// as it uses a single database transaction.
68    ///
69    /// Returns the number of executions created.
70    #[tracing::instrument(skip(self, contexts), fields(batch_size = contexts.len()))]
71    pub async fn batch_create(&self, contexts: &[ExecutionContext]) -> Result<u64> {
72        if contexts.is_empty() {
73            return Ok(0);
74        }
75
76        // Validate all contexts first
77        for ctx in contexts {
78            if ctx.variables.len() > MAX_VARIABLES {
79                return Err(StorageError::ValidationError(format!(
80                    "Execution {} has {} variables, which exceeds the maximum of {}",
81                    ctx.execution_id,
82                    ctx.variables.len(),
83                    MAX_VARIABLES
84                )));
85            }
86        }
87
88        let mut tx = self.pool.pool().begin().await?;
89
90        for ctx in contexts {
91            let id = ctx.execution_id.to_string();
92            let workflow_id = ctx.workflow_id.to_string();
93            let started_at = ctx.started_at.to_rfc3339();
94            let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
95            let state = format!("{:?}", ctx.state);
96            let context_json = serde_json::to_string(ctx)?;
97            let node_results = serde_json::to_string(&ctx.node_results)?;
98            let variables = serde_json::to_string(&ctx.variables)?;
99
100            sqlx::query(
101                r#"
102                INSERT INTO executions (id, workflow_id, started_at, completed_at, state, context, node_results, variables)
103                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
104                "#,
105            )
106            .bind(&id)
107            .bind(&workflow_id)
108            .bind(Some(&started_at))
109            .bind(&completed_at)
110            .bind(&state)
111            .bind(&context_json)
112            .bind(&node_results)
113            .bind(&variables)
114            .execute(&mut *tx)
115            .await?;
116        }
117
118        tx.commit().await?;
119
120        Ok(contexts.len() as u64)
121    }
122
123    /// Get an execution by ID
124    #[tracing::instrument(skip(self), fields(execution_id = %id))]
125    pub async fn get(&self, id: &Uuid) -> Result<Option<ExecutionContext>> {
126        let id_str = id.to_string();
127        let row = sqlx::query(
128            r#"
129            SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
130            FROM executions
131            WHERE id = ?
132            "#,
133        )
134        .bind(&id_str)
135        .fetch_optional(self.pool.pool())
136        .await?;
137
138        match row {
139            Some(row) => {
140                let context_str: String = row.get("context");
141                let ctx: ExecutionContext = serde_json::from_str(&context_str)?;
142                Ok(Some(ctx))
143            }
144            None => Ok(None),
145        }
146    }
147
148    /// List all executions
149    pub async fn list(&self) -> Result<Vec<(Uuid, ExecutionContext)>> {
150        let rows = sqlx::query(
151            r#"
152            SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
153            FROM executions
154            ORDER BY started_at DESC
155            "#,
156        )
157        .fetch_all(self.pool.pool())
158        .await?;
159
160        let executions: Vec<(Uuid, ExecutionContext)> = rows
161            .into_iter()
162            .filter_map(|row| {
163                let id_str: String = row.get("id");
164                let context_str: String = row.get("context");
165                let id = Uuid::parse_str(&id_str).ok()?;
166                let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
167                Some((id, ctx))
168            })
169            .collect();
170
171        Ok(executions)
172    }
173
174    /// List executions for a specific workflow
175    pub async fn list_by_workflow(
176        &self,
177        workflow_id: &WorkflowId,
178    ) -> Result<Vec<(Uuid, ExecutionContext)>> {
179        let workflow_id_str = workflow_id.to_string();
180        let rows = sqlx::query(
181            r#"
182            SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
183            FROM executions
184            WHERE workflow_id = ?
185            ORDER BY started_at DESC
186            "#,
187        )
188        .bind(&workflow_id_str)
189        .fetch_all(self.pool.pool())
190        .await?;
191
192        let executions: Vec<(Uuid, ExecutionContext)> = rows
193            .into_iter()
194            .filter_map(|row| {
195                let id_str: String = row.get("id");
196                let context_str: String = row.get("context");
197                let id = Uuid::parse_str(&id_str).ok()?;
198                let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
199                Some((id, ctx))
200            })
201            .collect();
202
203        Ok(executions)
204    }
205
206    /// List executions with pagination
207    pub async fn list_paginated(
208        &self,
209        limit: i64,
210        offset: i64,
211    ) -> Result<Vec<(Uuid, ExecutionContext)>> {
212        let rows = sqlx::query(
213            r#"
214            SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
215            FROM executions
216            ORDER BY started_at DESC
217            LIMIT ? OFFSET ?
218            "#,
219        )
220        .bind(limit)
221        .bind(offset)
222        .fetch_all(self.pool.pool())
223        .await?;
224
225        let executions: Vec<(Uuid, ExecutionContext)> = rows
226            .into_iter()
227            .filter_map(|row| {
228                let id_str: String = row.get("id");
229                let context_str: String = row.get("context");
230                let id = Uuid::parse_str(&id_str).ok()?;
231                let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
232                Some((id, ctx))
233            })
234            .collect();
235
236        Ok(executions)
237    }
238
239    /// Update an execution
240    #[tracing::instrument(skip(self, ctx), fields(execution_id = %id, new_state = ?ctx.state))]
241    pub async fn update(&self, id: &Uuid, ctx: &ExecutionContext) -> Result<bool> {
242        // Validate variable count
243        if ctx.variables.len() > MAX_VARIABLES {
244            return Err(StorageError::ValidationError(format!(
245                "Execution has {} variables, which exceeds the maximum of {}",
246                ctx.variables.len(),
247                MAX_VARIABLES
248            )));
249        }
250
251        let id_str = id.to_string();
252        let state = format!("{:?}", ctx.state);
253        let completed_at = ctx.completed_at.map(|t| t.to_rfc3339());
254        let context_json = serde_json::to_string(ctx)?;
255        let node_results = serde_json::to_string(&ctx.node_results)?;
256        let variables = serde_json::to_string(&ctx.variables)?;
257
258        // Extract error message if state is Failed
259        let error_message = match &ctx.state {
260            ExecutionState::Failed(msg) => Some(msg.clone()),
261            _ => None,
262        };
263
264        let result = sqlx::query(
265            r#"
266            UPDATE executions
267            SET completed_at = ?, state = ?, context = ?, node_results = ?, variables = ?, error_message = ?
268            WHERE id = ?
269            "#,
270        )
271        .bind(&completed_at)
272        .bind(&state)
273        .bind(&context_json)
274        .bind(&node_results)
275        .bind(&variables)
276        .bind(&error_message)
277        .bind(&id_str)
278        .execute(self.pool.pool())
279        .await?;
280
281        Ok(result.rows_affected() > 0)
282    }
283
284    /// Delete an execution
285    #[tracing::instrument(skip(self), fields(execution_id = %id))]
286    pub async fn delete(&self, id: &Uuid) -> Result<bool> {
287        let id_str = id.to_string();
288        let result = sqlx::query(
289            r#"
290            DELETE FROM executions
291            WHERE id = ?
292            "#,
293        )
294        .bind(&id_str)
295        .execute(self.pool.pool())
296        .await?;
297
298        Ok(result.rows_affected() > 0)
299    }
300
301    /// Count executions by state
302    pub async fn count_by_state(&self, state: &str) -> Result<i64> {
303        let row = sqlx::query(
304            r#"
305            SELECT COUNT(*) as count
306            FROM executions
307            WHERE state = ?
308            "#,
309        )
310        .bind(state)
311        .fetch_one(self.pool.pool())
312        .await?;
313
314        let count: i64 = row.get("count");
315        Ok(count)
316    }
317
318    /// Get active executions (Running or Paused)
319    pub async fn get_active(&self) -> Result<Vec<(Uuid, ExecutionContext)>> {
320        let rows = sqlx::query(
321            r#"
322            SELECT id, workflow_id, started_at, completed_at, state, context, node_results, variables, error_message
323            FROM executions
324            WHERE state IN ('Running', 'Paused')
325            ORDER BY started_at DESC
326            "#,
327        )
328        .fetch_all(self.pool.pool())
329        .await?;
330
331        let executions: Vec<(Uuid, ExecutionContext)> = rows
332            .into_iter()
333            .filter_map(|row| {
334                let id_str: String = row.get("id");
335                let context_str: String = row.get("context");
336                let id = Uuid::parse_str(&id_str).ok()?;
337                let ctx: ExecutionContext = serde_json::from_str(&context_str).ok()?;
338                Some((id, ctx))
339            })
340            .collect();
341
342        Ok(executions)
343    }
344
345    /// Delete all executions for a specific workflow
346    /// Returns the number of executions deleted
347    #[tracing::instrument(skip(self), fields(workflow_id = %workflow_id))]
348    pub async fn delete_by_workflow(&self, workflow_id: &WorkflowId) -> Result<u64> {
349        let workflow_id_str = workflow_id.to_string();
350        let result = sqlx::query(
351            r#"
352            DELETE FROM executions WHERE workflow_id = ?
353            "#,
354        )
355        .bind(&workflow_id_str)
356        .execute(self.pool.pool())
357        .await?;
358
359        Ok(result.rows_affected())
360    }
361
362    /// Archive completed executions older than the specified date
363    /// Returns the number of executions archived (deleted)
364    #[tracing::instrument(skip(self), fields(before = %before))]
365    pub async fn archive_completed(&self, before: chrono::DateTime<chrono::Utc>) -> Result<u64> {
366        let before_str = before.to_rfc3339();
367        let result = sqlx::query(
368            r#"
369            DELETE FROM executions
370            WHERE completed_at IS NOT NULL
371            AND completed_at < ?
372            "#,
373        )
374        .bind(&before_str)
375        .execute(self.pool.pool())
376        .await?;
377
378        Ok(result.rows_affected())
379    }
380}
381
#[cfg(test)]
mod tests {
    // `use super::*` already brings in ExecutionContext, ExecutionState,
    // Uuid, etc. from the parent module; no explicit re-import needed.
    use super::*;

    /// Build a pool against `DATABASE_URL`, defaulting to an in-memory
    /// SQLite database when the variable is unset.
    async fn setup_test_pool() -> Result<DatabasePool> {
        let config = crate::DatabaseConfig {
            database_url: std::env::var("DATABASE_URL")
                .unwrap_or_else(|_| "sqlite::memory:".to_string()),
            ..Default::default()
        };
        DatabasePool::new(config).await
    }

    /// End-to-end create → get → update → delete round trip.
    #[tokio::test]
    #[ignore] // Requires database
    async fn test_execution_crud() -> Result<()> {
        let pool = setup_test_pool().await?;
        pool.migrate().await?;

        let store = ExecutionStore::new(pool);

        // Create test execution
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        // Create
        let id = store.create(&ctx).await?;
        assert_eq!(id, ctx.execution_id);

        // Get
        let fetched = store.get(&id).await?;
        assert!(fetched.is_some());

        // Update
        ctx.state = ExecutionState::Completed;
        ctx.mark_completed();
        let result = store.update(&id, &ctx).await?;
        assert!(result);

        // Delete
        let result = store.delete(&id).await?;
        assert!(result);

        Ok(())
    }
}