Skip to main content

mqdb_core/
constraint.rs

1// Copyright 2025-2026 LabOverWire. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::entity::Entity;
5use crate::error::{Error, Result};
6use crate::keys;
7use crate::storage::{BatchWriter, Storage};
8use crate::types::OwnershipConfig;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::HashMap;
12
/// Policy applied to entities whose foreign key references an entity being deleted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OnDeleteAction {
    /// Refuse the delete while any referencing entity still exists.
    Restrict,
    /// Delete the referencing entities too, walking their own references recursively.
    Cascade,
    /// Schedule a set-null update of the referencing field on each referencing entity.
    SetNull,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct UniqueConstraint {
22    pub entity: String,
23    pub fields: Vec<String>,
24    pub name: String,
25}
26
27impl UniqueConstraint {
28    pub fn new(entity: impl Into<String>, fields: Vec<String>) -> Self {
29        let entity = entity.into();
30        let name = format!("{}_{}_unique", entity, fields.join("_"));
31        Self {
32            entity,
33            fields,
34            name,
35        }
36    }
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ForeignKeyConstraint {
41    pub source_entity: String,
42    pub source_field: String,
43    pub target_entity: String,
44    pub target_field: String,
45    pub on_delete: OnDeleteAction,
46    pub name: String,
47}
48
49impl ForeignKeyConstraint {
50    pub fn new(
51        source_entity: impl Into<String>,
52        source_field: impl Into<String>,
53        target_entity: impl Into<String>,
54        target_field: impl Into<String>,
55        on_delete: OnDeleteAction,
56    ) -> Self {
57        let source_entity = source_entity.into();
58        let source_field = source_field.into();
59        let target_entity = target_entity.into();
60        let target_field = target_field.into();
61        let name = format!("{source_entity}_{source_field}_{target_entity}_fk");
62        Self {
63            source_entity,
64            source_field,
65            target_entity,
66            target_field,
67            on_delete,
68            name,
69        }
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct NotNullConstraint {
75    pub entity: String,
76    pub field: String,
77    pub name: String,
78}
79
80impl NotNullConstraint {
81    pub fn new(entity: impl Into<String>, field: impl Into<String>) -> Self {
82        let entity = entity.into();
83        let field = field.into();
84        let name = format!("{entity}_{field}_notnull");
85        Self {
86            entity,
87            field,
88            name,
89        }
90    }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub enum Constraint {
95    Unique(UniqueConstraint),
96    ForeignKey(ForeignKeyConstraint),
97    NotNull(NotNullConstraint),
98}
99
100impl Constraint {
101    #[must_use]
102    pub fn name(&self) -> &str {
103        match self {
104            Constraint::Unique(c) => &c.name,
105            Constraint::ForeignKey(c) => &c.name,
106            Constraint::NotNull(c) => &c.name,
107        }
108    }
109
110    #[must_use]
111    pub fn entity(&self) -> &str {
112        match self {
113            Constraint::Unique(c) => &c.entity,
114            Constraint::ForeignKey(c) => &c.source_entity,
115            Constraint::NotNull(c) => &c.entity,
116        }
117    }
118
119    #[must_use]
120    fn constraint_type(&self) -> &str {
121        match self {
122            Constraint::Unique(_) => "unique",
123            Constraint::ForeignKey(_) => "fk",
124            Constraint::NotNull(_) => "notnull",
125        }
126    }
127
128    #[must_use]
129    pub fn to_api_value(&self) -> Value {
130        match self {
131            Constraint::Unique(c) => json!({
132                "name": c.name,
133                "type": "unique",
134                "fields": c.fields,
135            }),
136            Constraint::ForeignKey(c) => json!({
137                "name": c.name,
138                "type": "fk",
139                "field": c.source_field,
140                "target_entity": c.target_entity,
141                "target_field": c.target_field,
142                "on_delete": match c.on_delete {
143                    OnDeleteAction::Restrict => "restrict",
144                    OnDeleteAction::Cascade => "cascade",
145                    OnDeleteAction::SetNull => "set_null",
146                }
147            }),
148            Constraint::NotNull(c) => json!({
149                "name": c.name,
150                "type": "notnull",
151                "field": c.field,
152            }),
153        }
154    }
155}
156
/// A referencing entity scheduled for deletion because a cascading foreign key
/// points at an entity being deleted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CascadeOperation {
    /// Entity type of the row to delete.
    pub entity: String,
    /// Id of the row to delete.
    pub id: String,
}

/// A referencing entity whose foreign-key field is scheduled to be set to null.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SetNullOperation {
    /// Entity type of the row to update.
    pub entity: String,
    /// Id of the row to update.
    pub id: String,
    /// Field to clear.
    pub field: String,
}

/// One side effect produced by planning a delete.
// Debug/Clone/PartialEq let callers log, copy and compare planned operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeleteOperation {
    Cascade(CascadeOperation),
    SetNull(SetNullOperation),
}
172
/// Who is performing a delete, so cascades can respect per-row ownership.
pub struct OwnershipContext<'a> {
    /// Requester identity; compared against a child's owner field value.
    pub sender: &'a str,
    /// Ownership rules (admin check and per-entity owner-field lookup).
    pub ownership: &'a OwnershipConfig,
}

/// A cascade child owned by someone other than the sender. It is not deleted;
/// it is recorded so its referencing field can be nulled, or the delete blocked.
struct CrossOwnedRef {
    /// Entity type of the child.
    entity: String,
    /// Id of the child.
    id: String,
    /// FK field on the child that references the entity being deleted.
    field: String,
    /// Value of the child's owner field ("" when absent or not a string).
    owner: String,
}

/// Outcome of inspecting one entity reached through a cascading foreign key.
enum CascadeChild {
    /// Owned by another user: left in place and recorded as a `CrossOwnedRef`.
    Skip,
    /// No stored row under the child's data key.
    Missing,
    /// Loaded child to delete; its own incoming references are walked next.
    Recurse(Entity),
}
190
191pub struct ConstraintManager {
192    constraints: HashMap<String, Vec<Constraint>>,
193}
194
195impl ConstraintManager {
196    #[allow(clippy::must_use_candidate)]
197    pub fn new() -> Self {
198        Self {
199            constraints: HashMap::new(),
200        }
201    }
202
203    pub fn add_constraint(&mut self, constraint: Constraint) {
204        let entity = constraint.entity().to_string();
205        self.constraints.entry(entity).or_default().push(constraint);
206    }
207
208    #[must_use]
209    pub fn get_constraints(&self, entity: &str) -> &[Constraint] {
210        self.constraints.get(entity).map_or(&[], Vec::as_slice)
211    }
212
213    /// # Errors
214    /// Returns an error if any constraint is violated.
215    pub fn validate_create(
216        &self,
217        entity: &Entity,
218        _batch: &mut BatchWriter,
219        storage: &Storage,
220    ) -> Result<()> {
221        let constraints = self.get_constraints(&entity.name);
222
223        for constraint in constraints {
224            match constraint {
225                Constraint::NotNull(c) => Self::validate_not_null(entity, c)?,
226                Constraint::Unique(c) => Self::validate_unique(entity, c, storage)?,
227                Constraint::ForeignKey(c) => Self::validate_foreign_key(entity, c, storage)?,
228            }
229        }
230
231        Ok(())
232    }
233
234    /// # Errors
235    /// Returns an error if any constraint is violated.
236    pub fn validate_update(
237        &self,
238        entity: &Entity,
239        old_entity: &Entity,
240        _batch: &mut BatchWriter,
241        storage: &Storage,
242    ) -> Result<()> {
243        let constraints = self.get_constraints(&entity.name);
244
245        for constraint in constraints {
246            match constraint {
247                Constraint::NotNull(c) => Self::validate_not_null(entity, c)?,
248                Constraint::Unique(c) => {
249                    Self::validate_unique_update(entity, old_entity, c, storage)?;
250                }
251                Constraint::ForeignKey(c) => Self::validate_foreign_key(entity, c, storage)?,
252            }
253        }
254
255        Ok(())
256    }
257
258    /// # Errors
259    /// Returns an error if a foreign key constraint prevents deletion.
260    pub fn validate_delete(
261        &self,
262        entity: &Entity,
263        storage: &Storage,
264        ownership_ctx: Option<&OwnershipContext<'_>>,
265    ) -> Result<Vec<DeleteOperation>> {
266        use std::collections::HashSet;
267        let mut all_operations = Vec::new();
268        let mut visited = HashSet::new();
269        let mut cross_owned = Vec::new();
270        self.collect_delete_operations(
271            entity,
272            storage,
273            &mut all_operations,
274            &mut visited,
275            &mut cross_owned,
276            ownership_ctx,
277        )?;
278
279        self.classify_cross_owned_danglers(
280            entity,
281            &cross_owned,
282            &mut all_operations,
283            ownership_ctx,
284        )?;
285
286        Ok(all_operations)
287    }
288
289    #[allow(clippy::too_many_arguments)]
290    fn collect_delete_operations(
291        &self,
292        entity: &Entity,
293        storage: &Storage,
294        all_operations: &mut Vec<DeleteOperation>,
295        visited: &mut std::collections::HashSet<String>,
296        cross_owned: &mut Vec<CrossOwnedRef>,
297        ownership_ctx: Option<&OwnershipContext<'_>>,
298    ) -> Result<()> {
299        let entity_key = format!("{}/{}", entity.name, entity.id);
300        if visited.contains(&entity_key) {
301            return Ok(());
302        }
303        visited.insert(entity_key);
304
305        let all_constraints: Vec<&Constraint> = self.constraints.values().flatten().collect();
306
307        for constraint in all_constraints {
308            if let Constraint::ForeignKey(fk) = constraint
309                && fk.target_entity == entity.name
310            {
311                match fk.on_delete {
312                    OnDeleteAction::Restrict => {
313                        let referencing = Self::find_referencing_entities(
314                            storage,
315                            &fk.source_entity,
316                            &fk.source_field,
317                            &entity.id,
318                        )?;
319                        if !referencing.is_empty() {
320                            return Err(Error::ForeignKeyRestrict {
321                                entity: entity.name.clone(),
322                                id: entity.id.clone(),
323                                referencing_entity: fk.source_entity.clone(),
324                            });
325                        }
326                    }
327                    OnDeleteAction::Cascade => {
328                        let referencing = Self::find_referencing_entities(
329                            storage,
330                            &fk.source_entity,
331                            &fk.source_field,
332                            &entity.id,
333                        )?;
334                        for id in referencing {
335                            match Self::classify_cascade_child(
336                                storage,
337                                &fk.source_entity,
338                                &fk.source_field,
339                                &id,
340                                ownership_ctx,
341                                cross_owned,
342                            )? {
343                                CascadeChild::Skip => continue,
344                                CascadeChild::Missing => {}
345                                CascadeChild::Recurse(child_entity) => {
346                                    self.collect_delete_operations(
347                                        &child_entity,
348                                        storage,
349                                        all_operations,
350                                        visited,
351                                        cross_owned,
352                                        ownership_ctx,
353                                    )?;
354                                }
355                            }
356                            all_operations.push(DeleteOperation::Cascade(CascadeOperation {
357                                entity: fk.source_entity.clone(),
358                                id,
359                            }));
360                        }
361                    }
362                    OnDeleteAction::SetNull => {
363                        let referencing = Self::find_referencing_entities(
364                            storage,
365                            &fk.source_entity,
366                            &fk.source_field,
367                            &entity.id,
368                        )?;
369                        for id in referencing {
370                            all_operations.push(DeleteOperation::SetNull(SetNullOperation {
371                                entity: fk.source_entity.clone(),
372                                id,
373                                field: fk.source_field.clone(),
374                            }));
375                        }
376                    }
377                }
378            }
379        }
380
381        Ok(())
382    }
383
384    fn classify_cascade_child(
385        storage: &Storage,
386        source_entity: &str,
387        source_field: &str,
388        ref_id: &str,
389        ownership_ctx: Option<&OwnershipContext<'_>>,
390        cross_owned: &mut Vec<CrossOwnedRef>,
391    ) -> Result<CascadeChild> {
392        let key = keys::encode_data_key(source_entity, ref_id);
393        let Some(data) = storage.get(&key)? else {
394            return Ok(CascadeChild::Missing);
395        };
396        let entity = Entity::deserialize(source_entity.to_string(), ref_id.to_string(), &data)?;
397
398        let Some(ctx) = ownership_ctx else {
399            return Ok(CascadeChild::Recurse(entity));
400        };
401        if ctx.ownership.is_admin(ctx.sender) {
402            return Ok(CascadeChild::Recurse(entity));
403        }
404        let Some(owner_field) = ctx.ownership.owner_field(source_entity) else {
405            return Ok(CascadeChild::Recurse(entity));
406        };
407
408        let owner = entity
409            .data
410            .get(owner_field)
411            .and_then(|v| v.as_str())
412            .unwrap_or("");
413
414        if owner == ctx.sender {
415            return Ok(CascadeChild::Recurse(entity));
416        }
417
418        cross_owned.push(CrossOwnedRef {
419            entity: source_entity.to_string(),
420            id: ref_id.to_string(),
421            field: source_field.to_string(),
422            owner: owner.to_string(),
423        });
424        Ok(CascadeChild::Skip)
425    }
426
427    fn classify_cross_owned_danglers(
428        &self,
429        root_entity: &Entity,
430        cross_owned: &[CrossOwnedRef],
431        all_operations: &mut Vec<DeleteOperation>,
432        ownership_ctx: Option<&OwnershipContext<'_>>,
433    ) -> Result<()> {
434        if ownership_ctx.is_none() || cross_owned.is_empty() {
435            return Ok(());
436        }
437
438        for co in cross_owned {
439            if self.has_not_null_constraint(&co.entity, &co.field) {
440                return Err(Error::CascadeBlocked(Box::new(
441                    crate::error::CascadeBlockedInfo {
442                        entity: root_entity.name.clone(),
443                        id: root_entity.id.clone(),
444                        blocked_entity: co.entity.clone(),
445                        blocked_id: co.id.clone(),
446                        blocked_field: co.field.clone(),
447                        blocked_owner: co.owner.clone(),
448                    },
449                )));
450            }
451
452            all_operations.push(DeleteOperation::SetNull(SetNullOperation {
453                entity: co.entity.clone(),
454                id: co.id.clone(),
455                field: co.field.clone(),
456            }));
457        }
458
459        Ok(())
460    }
461
462    fn has_not_null_constraint(&self, entity: &str, field: &str) -> bool {
463        let Some(constraints) = self.constraints.get(entity) else {
464            return false;
465        };
466        constraints
467            .iter()
468            .any(|c| matches!(c, Constraint::NotNull(nn) if nn.field == field))
469    }
470
471    fn validate_not_null(entity: &Entity, constraint: &NotNullConstraint) -> Result<()> {
472        match entity.get_field(&constraint.field) {
473            Some(value) if !value.is_null() => Ok(()),
474            _ => Err(Error::NotNullViolation {
475                entity: entity.name.clone(),
476                field: constraint.field.clone(),
477            }),
478        }
479    }
480
481    fn validate_unique(
482        entity: &Entity,
483        constraint: &UniqueConstraint,
484        storage: &Storage,
485    ) -> Result<()> {
486        for field in &constraint.fields {
487            if let Some(value) = entity.get_field(field) {
488                let value_bytes = keys::encode_value_for_index(value)?;
489                let prefix = keys::encode_index_prefix(&entity.name, field, Some(&value_bytes));
490                let existing = storage.prefix_scan(&prefix)?;
491
492                for (key, _) in existing {
493                    if let Some(existing_id) = Self::extract_id_from_index_key(&key)
494                        && existing_id != entity.id
495                    {
496                        return Err(Error::UniqueViolation {
497                            entity: entity.name.clone(),
498                            field: field.clone(),
499                            value: String::from_utf8_lossy(&value_bytes).to_string(),
500                        });
501                    }
502                }
503            }
504        }
505
506        Ok(())
507    }
508
509    fn validate_unique_update(
510        entity: &Entity,
511        _old_entity: &Entity,
512        constraint: &UniqueConstraint,
513        storage: &Storage,
514    ) -> Result<()> {
515        Self::validate_unique(entity, constraint, storage)
516    }
517
518    fn validate_foreign_key(
519        entity: &Entity,
520        constraint: &ForeignKeyConstraint,
521        storage: &Storage,
522    ) -> Result<()> {
523        if let Some(fk_value) = entity.get_field(&constraint.source_field)
524            && !fk_value.is_null()
525        {
526            let target_id = fk_value.as_str().ok_or(Error::InvalidForeignKey)?;
527
528            let target_key = keys::encode_data_key(&constraint.target_entity, target_id);
529
530            if storage.get(&target_key)?.is_none() {
531                return Err(Error::ForeignKeyViolation {
532                    entity: entity.name.clone(),
533                    field: constraint.source_field.clone(),
534                    target_entity: constraint.target_entity.clone(),
535                    target_id: target_id.to_string(),
536                });
537            }
538        }
539
540        Ok(())
541    }
542
543    fn find_referencing_entities(
544        storage: &Storage,
545        source_entity: &str,
546        source_field: &str,
547        target_id: &str,
548    ) -> Result<Vec<String>> {
549        use crate::entity::Entity;
550
551        let prefix = keys::encode_data_key(source_entity, "");
552        let items = storage.prefix_scan(&prefix)?;
553
554        let mut referencing_ids = Vec::new();
555
556        for (key, value) in items {
557            if let Ok((_entity, id)) = keys::decode_data_key(&key)
558                && let Ok(entity) =
559                    Entity::deserialize(source_entity.to_string(), id.clone(), &value)
560                && let Some(fk_value) = entity.get_field(source_field)
561                && let Some(fk_str) = fk_value.as_str()
562                && fk_str == target_id
563            {
564                referencing_ids.push(id);
565            }
566        }
567
568        Ok(referencing_ids)
569    }
570
571    fn extract_id_from_index_key(key: &[u8]) -> Option<String> {
572        if let Some(last_slash) = key.iter().rposition(|&b| b == b'/') {
573            String::from_utf8(key[last_slash + 1..].to_vec()).ok()
574        } else {
575            None
576        }
577    }
578
579    /// # Errors
580    /// Returns an error if serialization fails.
581    pub fn persist_constraint(
582        &self,
583        batch: &mut BatchWriter,
584        constraint: &Constraint,
585    ) -> Result<()> {
586        let key = keys::encode_constraint_key(
587            constraint.constraint_type(),
588            constraint.entity(),
589            constraint.name(),
590        );
591        let value = serde_json::to_vec(constraint)?;
592        batch.insert(key, value);
593        Ok(())
594    }
595
596    /// # Errors
597    /// Returns an error if reading or deserializing constraints fails.
598    pub fn load_constraints(&mut self, storage: &Storage) -> Result<()> {
599        let prefix = b"meta/constraint/";
600        let items = storage.prefix_scan(prefix)?;
601
602        for (_key, value) in items {
603            let constraint: Constraint = serde_json::from_slice(&value)?;
604            self.add_constraint(constraint);
605        }
606
607        Ok(())
608    }
609
610    #[must_use]
611    pub fn entity_names(&self) -> Vec<String> {
612        self.constraints.keys().cloned().collect()
613    }
614
615    #[must_use]
616    pub fn all_constraints(&self) -> &HashMap<String, Vec<Constraint>> {
617        &self.constraints
618    }
619
620    pub fn remove_constraint(&mut self, batch: &mut BatchWriter, entity: &str, name: &str) {
621        if let Some(constraints) = self.constraints.get_mut(entity)
622            && let Some(pos) = constraints.iter().position(|c| c.name() == name)
623        {
624            let constraint = constraints.remove(pos);
625            let key = keys::encode_constraint_key(
626                constraint.constraint_type(),
627                constraint.entity(),
628                constraint.name(),
629            );
630            batch.remove(key);
631        }
632    }
633}
634
635impl Default for ConstraintManager {
636    fn default() -> Self {
637        Self::new()
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constraint_creation() {
        let unique_email = UniqueConstraint::new("users", vec![String::from("email")]);
        assert_eq!(unique_email.entity, "users");
        assert_eq!(unique_email.fields, vec!["email"]);
        assert_eq!(unique_email.name, "users_email_unique");

        let author_fk = ForeignKeyConstraint::new(
            "posts",
            "author_id",
            "users",
            "id",
            OnDeleteAction::Cascade,
        );
        assert_eq!(author_fk.source_entity, "posts");
        assert_eq!(author_fk.source_field, "author_id");
        assert_eq!(author_fk.target_entity, "users");
        assert_eq!(author_fk.on_delete, OnDeleteAction::Cascade);

        let email_required = NotNullConstraint::new("users", "email");
        assert_eq!(email_required.entity, "users");
        assert_eq!(email_required.field, "email");
    }

    #[test]
    fn test_constraint_manager() {
        let mut manager = ConstraintManager::new();
        manager.add_constraint(Constraint::Unique(UniqueConstraint::new(
            "users",
            vec![String::from("email")],
        )));

        assert_eq!(manager.get_constraints("users").len(), 1);
    }
}