Skip to main content

mqdb_core/
constraint.rs

1// Copyright 2025-2026 LabOverWire. All rights reserved.
2// SPDX-License-Identifier: AGPL-3.0-only
3
4use crate::entity::Entity;
5use crate::error::{Error, Result};
6use crate::keys;
7use crate::storage::{BatchWriter, Storage};
8use crate::types::OwnershipConfig;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::HashMap;
12
/// Policy applied to records that reference (through a foreign key) a record
/// that is being deleted; consulted by `ConstraintManager::validate_delete`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OnDeleteAction {
    /// Reject the delete while any referencing record exists.
    Restrict,
    /// Schedule every referencing record for deletion as well.
    Cascade,
    /// Schedule a `SetNullOperation` on the referencing field of each referencing record.
    SetNull,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct UniqueConstraint {
22    pub entity: String,
23    pub fields: Vec<String>,
24    pub name: String,
25}
26
27impl UniqueConstraint {
28    pub fn new(entity: impl Into<String>, fields: Vec<String>) -> Self {
29        let entity = entity.into();
30        let name = format!("{}_{}_unique", entity, fields.join("_"));
31        Self {
32            entity,
33            fields,
34            name,
35        }
36    }
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ForeignKeyConstraint {
41    pub source_entity: String,
42    pub source_field: String,
43    pub target_entity: String,
44    pub target_field: String,
45    pub on_delete: OnDeleteAction,
46    pub name: String,
47}
48
49impl ForeignKeyConstraint {
50    pub fn new(
51        source_entity: impl Into<String>,
52        source_field: impl Into<String>,
53        target_entity: impl Into<String>,
54        target_field: impl Into<String>,
55        on_delete: OnDeleteAction,
56    ) -> Self {
57        let source_entity = source_entity.into();
58        let source_field = source_field.into();
59        let target_entity = target_entity.into();
60        let target_field = target_field.into();
61        let name = format!("{source_entity}_{source_field}_{target_entity}_fk");
62        Self {
63            source_entity,
64            source_field,
65            target_entity,
66            target_field,
67            on_delete,
68            name,
69        }
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct NotNullConstraint {
75    pub entity: String,
76    pub field: String,
77    pub name: String,
78}
79
80impl NotNullConstraint {
81    pub fn new(entity: impl Into<String>, field: impl Into<String>) -> Self {
82        let entity = entity.into();
83        let field = field.into();
84        let name = format!("{entity}_{field}_notnull");
85        Self {
86            entity,
87            field,
88            name,
89        }
90    }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub enum Constraint {
95    Unique(UniqueConstraint),
96    ForeignKey(ForeignKeyConstraint),
97    NotNull(NotNullConstraint),
98}
99
100impl Constraint {
101    #[must_use]
102    pub fn name(&self) -> &str {
103        match self {
104            Constraint::Unique(c) => &c.name,
105            Constraint::ForeignKey(c) => &c.name,
106            Constraint::NotNull(c) => &c.name,
107        }
108    }
109
110    #[must_use]
111    pub fn entity(&self) -> &str {
112        match self {
113            Constraint::Unique(c) => &c.entity,
114            Constraint::ForeignKey(c) => &c.source_entity,
115            Constraint::NotNull(c) => &c.entity,
116        }
117    }
118
119    #[must_use]
120    fn constraint_type(&self) -> &str {
121        match self {
122            Constraint::Unique(_) => "unique",
123            Constraint::ForeignKey(_) => "fk",
124            Constraint::NotNull(_) => "notnull",
125        }
126    }
127
128    #[must_use]
129    pub fn to_api_value(&self) -> Value {
130        match self {
131            Constraint::Unique(c) => json!({
132                "name": c.name,
133                "type": "unique",
134                "fields": c.fields,
135            }),
136            Constraint::ForeignKey(c) => json!({
137                "name": c.name,
138                "type": "fk",
139                "field": c.source_field,
140                "target_entity": c.target_entity,
141                "target_field": c.target_field,
142                "on_delete": match c.on_delete {
143                    OnDeleteAction::Restrict => "restrict",
144                    OnDeleteAction::Cascade => "cascade",
145                    OnDeleteAction::SetNull => "set_null",
146                }
147            }),
148            Constraint::NotNull(c) => json!({
149                "name": c.name,
150                "type": "notnull",
151                "field": c.field,
152            }),
153        }
154    }
155}
156
/// A record to delete because a record it references (through a cascading
/// foreign key) is being deleted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CascadeOperation {
    /// Entity name of the record to delete.
    pub entity: String,
    /// Id of the record to delete.
    pub id: String,
}

/// A referencing field to null out because the record it points at is being deleted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SetNullOperation {
    /// Entity name of the referencing record.
    pub entity: String,
    /// Id of the referencing record.
    pub id: String,
    /// Foreign-key field on that record.
    pub field: String,
}

/// One side effect produced by delete planning (`ConstraintManager::validate_delete`).
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so planned operations can be logged,
/// copied and compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeleteOperation {
    Cascade(CascadeOperation),
    SetNull(SetNullOperation),
}
172
/// Identity and ownership settings used when planning a delete on behalf of a sender.
pub struct OwnershipContext<'a> {
    /// Identity of the requester; compared against record owner values and
    /// passed to `OwnershipConfig::is_admin`.
    pub sender: &'a str,
    /// Ownership configuration: provides the admin check and the per-entity owner field.
    pub ownership: &'a OwnershipConfig,
}
177
/// A record that a cascade would have deleted but whose owner differs from the
/// sender. It is collected during delete planning and resolved afterwards
/// (nulled, or the delete is blocked).
struct CrossOwnedRef {
    // Entity name of the referencing record.
    entity: String,
    // Id of the referencing record.
    id: String,
    // Foreign-key field on that record that points at the record being deleted.
    field: String,
    // Owner value read from the record's owner field ("" when missing or not a string).
    owner: String,
}
184
/// In-memory registry of constraints, used to validate creates, updates and deletes.
pub struct ConstraintManager {
    // Constraints keyed by `Constraint::entity()` (the source entity for foreign keys).
    constraints: HashMap<String, Vec<Constraint>>,
}
188
189impl ConstraintManager {
190    #[allow(clippy::must_use_candidate)]
191    pub fn new() -> Self {
192        Self {
193            constraints: HashMap::new(),
194        }
195    }
196
197    pub fn add_constraint(&mut self, constraint: Constraint) {
198        let entity = constraint.entity().to_string();
199        self.constraints.entry(entity).or_default().push(constraint);
200    }
201
202    #[must_use]
203    pub fn get_constraints(&self, entity: &str) -> &[Constraint] {
204        self.constraints.get(entity).map_or(&[], Vec::as_slice)
205    }
206
207    /// # Errors
208    /// Returns an error if any constraint is violated.
209    pub fn validate_create(
210        &self,
211        entity: &Entity,
212        _batch: &mut BatchWriter,
213        storage: &Storage,
214    ) -> Result<()> {
215        let constraints = self.get_constraints(&entity.name);
216
217        for constraint in constraints {
218            match constraint {
219                Constraint::NotNull(c) => Self::validate_not_null(entity, c)?,
220                Constraint::Unique(c) => Self::validate_unique(entity, c, storage)?,
221                Constraint::ForeignKey(c) => Self::validate_foreign_key(entity, c, storage)?,
222            }
223        }
224
225        Ok(())
226    }
227
228    /// # Errors
229    /// Returns an error if any constraint is violated.
230    pub fn validate_update(
231        &self,
232        entity: &Entity,
233        old_entity: &Entity,
234        _batch: &mut BatchWriter,
235        storage: &Storage,
236    ) -> Result<()> {
237        let constraints = self.get_constraints(&entity.name);
238
239        for constraint in constraints {
240            match constraint {
241                Constraint::NotNull(c) => Self::validate_not_null(entity, c)?,
242                Constraint::Unique(c) => {
243                    Self::validate_unique_update(entity, old_entity, c, storage)?;
244                }
245                Constraint::ForeignKey(c) => Self::validate_foreign_key(entity, c, storage)?,
246            }
247        }
248
249        Ok(())
250    }
251
252    /// # Errors
253    /// Returns an error if a foreign key constraint prevents deletion.
254    pub fn validate_delete(
255        &self,
256        entity: &Entity,
257        storage: &Storage,
258        ownership_ctx: Option<&OwnershipContext<'_>>,
259    ) -> Result<Vec<DeleteOperation>> {
260        use std::collections::HashSet;
261        let mut all_operations = Vec::new();
262        let mut visited = HashSet::new();
263        let mut cross_owned = Vec::new();
264        self.collect_delete_operations(
265            entity,
266            storage,
267            &mut all_operations,
268            &mut visited,
269            &mut cross_owned,
270            ownership_ctx,
271        )?;
272
273        self.classify_cross_owned_danglers(
274            entity,
275            &cross_owned,
276            &mut all_operations,
277            ownership_ctx,
278        )?;
279
280        Ok(all_operations)
281    }
282
283    #[allow(clippy::too_many_arguments)]
284    fn collect_delete_operations(
285        &self,
286        entity: &Entity,
287        storage: &Storage,
288        all_operations: &mut Vec<DeleteOperation>,
289        visited: &mut std::collections::HashSet<String>,
290        cross_owned: &mut Vec<CrossOwnedRef>,
291        ownership_ctx: Option<&OwnershipContext<'_>>,
292    ) -> Result<()> {
293        let entity_key = format!("{}/{}", entity.name, entity.id);
294        if visited.contains(&entity_key) {
295            return Ok(());
296        }
297        visited.insert(entity_key);
298
299        let all_constraints: Vec<&Constraint> = self.constraints.values().flatten().collect();
300
301        for constraint in all_constraints {
302            if let Constraint::ForeignKey(fk) = constraint
303                && fk.target_entity == entity.name
304            {
305                match fk.on_delete {
306                    OnDeleteAction::Restrict => {
307                        let referencing = Self::find_referencing_entities(
308                            storage,
309                            &fk.source_entity,
310                            &fk.source_field,
311                            &entity.id,
312                        )?;
313                        if !referencing.is_empty() {
314                            return Err(Error::ForeignKeyRestrict {
315                                entity: entity.name.clone(),
316                                id: entity.id.clone(),
317                                referencing_entity: fk.source_entity.clone(),
318                            });
319                        }
320                    }
321                    OnDeleteAction::Cascade => {
322                        let referencing = Self::find_referencing_entities(
323                            storage,
324                            &fk.source_entity,
325                            &fk.source_field,
326                            &entity.id,
327                        )?;
328                        for id in referencing {
329                            if Self::is_cross_owned(
330                                storage,
331                                &fk.source_entity,
332                                &fk.source_field,
333                                &id,
334                                ownership_ctx,
335                                cross_owned,
336                            )? {
337                                continue;
338                            }
339
340                            let cascade_key = keys::encode_data_key(&fk.source_entity, &id);
341                            if let Some(cascade_data) = storage.get(&cascade_key)? {
342                                let cascade_entity = Entity::deserialize(
343                                    fk.source_entity.clone(),
344                                    id.clone(),
345                                    &cascade_data,
346                                )?;
347                                self.collect_delete_operations(
348                                    &cascade_entity,
349                                    storage,
350                                    all_operations,
351                                    visited,
352                                    cross_owned,
353                                    ownership_ctx,
354                                )?;
355                            }
356                            all_operations.push(DeleteOperation::Cascade(CascadeOperation {
357                                entity: fk.source_entity.clone(),
358                                id,
359                            }));
360                        }
361                    }
362                    OnDeleteAction::SetNull => {
363                        let referencing = Self::find_referencing_entities(
364                            storage,
365                            &fk.source_entity,
366                            &fk.source_field,
367                            &entity.id,
368                        )?;
369                        for id in referencing {
370                            all_operations.push(DeleteOperation::SetNull(SetNullOperation {
371                                entity: fk.source_entity.clone(),
372                                id,
373                                field: fk.source_field.clone(),
374                            }));
375                        }
376                    }
377                }
378            }
379        }
380
381        Ok(())
382    }
383
384    fn is_cross_owned(
385        storage: &Storage,
386        source_entity: &str,
387        source_field: &str,
388        ref_id: &str,
389        ownership_ctx: Option<&OwnershipContext<'_>>,
390        cross_owned: &mut Vec<CrossOwnedRef>,
391    ) -> Result<bool> {
392        let Some(ctx) = ownership_ctx else {
393            return Ok(false);
394        };
395
396        if ctx.ownership.is_admin(ctx.sender) {
397            return Ok(false);
398        }
399
400        let Some(owner_field) = ctx.ownership.owner_field(source_entity) else {
401            return Ok(false);
402        };
403
404        let key = keys::encode_data_key(source_entity, ref_id);
405        let Some(data) = storage.get(&key)? else {
406            return Ok(false);
407        };
408        let entity = Entity::deserialize(source_entity.to_string(), ref_id.to_string(), &data)?;
409        let owner = entity
410            .data
411            .get(owner_field)
412            .and_then(|v| v.as_str())
413            .unwrap_or("");
414
415        if owner == ctx.sender {
416            return Ok(false);
417        }
418
419        cross_owned.push(CrossOwnedRef {
420            entity: source_entity.to_string(),
421            id: ref_id.to_string(),
422            field: source_field.to_string(),
423            owner: owner.to_string(),
424        });
425        Ok(true)
426    }
427
428    fn classify_cross_owned_danglers(
429        &self,
430        root_entity: &Entity,
431        cross_owned: &[CrossOwnedRef],
432        all_operations: &mut Vec<DeleteOperation>,
433        ownership_ctx: Option<&OwnershipContext<'_>>,
434    ) -> Result<()> {
435        if ownership_ctx.is_none() || cross_owned.is_empty() {
436            return Ok(());
437        }
438
439        for co in cross_owned {
440            if self.has_not_null_constraint(&co.entity, &co.field) {
441                return Err(Error::CascadeBlocked(Box::new(
442                    crate::error::CascadeBlockedInfo {
443                        entity: root_entity.name.clone(),
444                        id: root_entity.id.clone(),
445                        blocked_entity: co.entity.clone(),
446                        blocked_id: co.id.clone(),
447                        blocked_field: co.field.clone(),
448                        blocked_owner: co.owner.clone(),
449                    },
450                )));
451            }
452
453            all_operations.push(DeleteOperation::SetNull(SetNullOperation {
454                entity: co.entity.clone(),
455                id: co.id.clone(),
456                field: co.field.clone(),
457            }));
458        }
459
460        Ok(())
461    }
462
463    fn has_not_null_constraint(&self, entity: &str, field: &str) -> bool {
464        let Some(constraints) = self.constraints.get(entity) else {
465            return false;
466        };
467        constraints
468            .iter()
469            .any(|c| matches!(c, Constraint::NotNull(nn) if nn.field == field))
470    }
471
472    fn validate_not_null(entity: &Entity, constraint: &NotNullConstraint) -> Result<()> {
473        match entity.get_field(&constraint.field) {
474            Some(value) if !value.is_null() => Ok(()),
475            _ => Err(Error::NotNullViolation {
476                entity: entity.name.clone(),
477                field: constraint.field.clone(),
478            }),
479        }
480    }
481
482    fn validate_unique(
483        entity: &Entity,
484        constraint: &UniqueConstraint,
485        storage: &Storage,
486    ) -> Result<()> {
487        for field in &constraint.fields {
488            if let Some(value) = entity.get_field(field) {
489                let value_bytes = keys::encode_value_for_index(value)?;
490                let prefix = keys::encode_index_prefix(&entity.name, field, Some(&value_bytes));
491                let existing = storage.prefix_scan(&prefix)?;
492
493                for (key, _) in existing {
494                    if let Some(existing_id) = Self::extract_id_from_index_key(&key)
495                        && existing_id != entity.id
496                    {
497                        return Err(Error::UniqueViolation {
498                            entity: entity.name.clone(),
499                            field: field.clone(),
500                            value: String::from_utf8_lossy(&value_bytes).to_string(),
501                        });
502                    }
503                }
504            }
505        }
506
507        Ok(())
508    }
509
510    fn validate_unique_update(
511        entity: &Entity,
512        _old_entity: &Entity,
513        constraint: &UniqueConstraint,
514        storage: &Storage,
515    ) -> Result<()> {
516        Self::validate_unique(entity, constraint, storage)
517    }
518
519    fn validate_foreign_key(
520        entity: &Entity,
521        constraint: &ForeignKeyConstraint,
522        storage: &Storage,
523    ) -> Result<()> {
524        if let Some(fk_value) = entity.get_field(&constraint.source_field)
525            && !fk_value.is_null()
526        {
527            let target_id = fk_value.as_str().ok_or(Error::InvalidForeignKey)?;
528
529            let target_key = keys::encode_data_key(&constraint.target_entity, target_id);
530
531            if storage.get(&target_key)?.is_none() {
532                return Err(Error::ForeignKeyViolation {
533                    entity: entity.name.clone(),
534                    field: constraint.source_field.clone(),
535                    target_entity: constraint.target_entity.clone(),
536                    target_id: target_id.to_string(),
537                });
538            }
539        }
540
541        Ok(())
542    }
543
544    fn find_referencing_entities(
545        storage: &Storage,
546        source_entity: &str,
547        source_field: &str,
548        target_id: &str,
549    ) -> Result<Vec<String>> {
550        use crate::entity::Entity;
551
552        let prefix = keys::encode_data_key(source_entity, "");
553        let items = storage.prefix_scan(&prefix)?;
554
555        let mut referencing_ids = Vec::new();
556
557        for (key, value) in items {
558            if let Ok((_entity, id)) = keys::decode_data_key(&key)
559                && let Ok(entity) =
560                    Entity::deserialize(source_entity.to_string(), id.clone(), &value)
561                && let Some(fk_value) = entity.get_field(source_field)
562                && let Some(fk_str) = fk_value.as_str()
563                && fk_str == target_id
564            {
565                referencing_ids.push(id);
566            }
567        }
568
569        Ok(referencing_ids)
570    }
571
572    fn extract_id_from_index_key(key: &[u8]) -> Option<String> {
573        if let Some(last_slash) = key.iter().rposition(|&b| b == b'/') {
574            String::from_utf8(key[last_slash + 1..].to_vec()).ok()
575        } else {
576            None
577        }
578    }
579
580    /// # Errors
581    /// Returns an error if serialization fails.
582    pub fn persist_constraint(
583        &self,
584        batch: &mut BatchWriter,
585        constraint: &Constraint,
586    ) -> Result<()> {
587        let key = keys::encode_constraint_key(
588            constraint.constraint_type(),
589            constraint.entity(),
590            constraint.name(),
591        );
592        let value = serde_json::to_vec(constraint)?;
593        batch.insert(key, value);
594        Ok(())
595    }
596
597    /// # Errors
598    /// Returns an error if reading or deserializing constraints fails.
599    pub fn load_constraints(&mut self, storage: &Storage) -> Result<()> {
600        let prefix = b"meta/constraint/";
601        let items = storage.prefix_scan(prefix)?;
602
603        for (_key, value) in items {
604            let constraint: Constraint = serde_json::from_slice(&value)?;
605            self.add_constraint(constraint);
606        }
607
608        Ok(())
609    }
610
611    #[must_use]
612    pub fn entity_names(&self) -> Vec<String> {
613        self.constraints.keys().cloned().collect()
614    }
615
616    #[must_use]
617    pub fn all_constraints(&self) -> &HashMap<String, Vec<Constraint>> {
618        &self.constraints
619    }
620
621    pub fn remove_constraint(&mut self, batch: &mut BatchWriter, entity: &str, name: &str) {
622        if let Some(constraints) = self.constraints.get_mut(entity)
623            && let Some(pos) = constraints.iter().position(|c| c.name() == name)
624        {
625            let constraint = constraints.remove(pos);
626            let key = keys::encode_constraint_key(
627                constraint.constraint_type(),
628                constraint.entity(),
629                constraint.name(),
630            );
631            batch.remove(key);
632        }
633    }
634}
635
636impl Default for ConstraintManager {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constraint_creation() {
        let unique = UniqueConstraint::new("users", vec!["email".to_string()]);
        assert_eq!(unique.entity, "users");
        assert_eq!(unique.fields, vec!["email"]);
        assert_eq!(unique.name, "users_email_unique");

        let fk =
            ForeignKeyConstraint::new("posts", "author_id", "users", "id", OnDeleteAction::Cascade);
        assert_eq!(fk.source_entity, "posts");
        assert_eq!(fk.source_field, "author_id");
        assert_eq!(fk.target_entity, "users");
        assert_eq!(fk.target_field, "id");
        assert_eq!(fk.on_delete, OnDeleteAction::Cascade);
        assert_eq!(fk.name, "posts_author_id_users_fk");

        let not_null = NotNullConstraint::new("users", "email");
        assert_eq!(not_null.entity, "users");
        assert_eq!(not_null.field, "email");
        assert_eq!(not_null.name, "users_email_notnull");
    }

    #[test]
    fn test_composite_unique_name() {
        let unique =
            UniqueConstraint::new("orders", vec!["tenant".to_string(), "number".to_string()]);
        assert_eq!(unique.name, "orders_tenant_number_unique");
    }

    #[test]
    fn test_constraint_accessors() {
        let fk = Constraint::ForeignKey(ForeignKeyConstraint::new(
            "posts",
            "author_id",
            "users",
            "id",
            OnDeleteAction::SetNull,
        ));
        // Foreign keys are registered under their source entity.
        assert_eq!(fk.entity(), "posts");
        assert_eq!(fk.name(), "posts_author_id_users_fk");

        let not_null = Constraint::NotNull(NotNullConstraint::new("users", "email"));
        assert_eq!(not_null.entity(), "users");
        assert_eq!(not_null.name(), "users_email_notnull");
    }

    #[test]
    fn test_to_api_value() {
        let fk = Constraint::ForeignKey(ForeignKeyConstraint::new(
            "posts",
            "author_id",
            "users",
            "id",
            OnDeleteAction::SetNull,
        ))
        .to_api_value();
        assert_eq!(fk["name"], "posts_author_id_users_fk");
        assert_eq!(fk["type"], "fk");
        assert_eq!(fk["field"], "author_id");
        assert_eq!(fk["target_entity"], "users");
        assert_eq!(fk["target_field"], "id");
        assert_eq!(fk["on_delete"], "set_null");

        let unique =
            Constraint::Unique(UniqueConstraint::new("users", vec!["email".to_string()]))
                .to_api_value();
        assert_eq!(unique["type"], "unique");
        assert_eq!(unique["fields"][0], "email");

        let not_null = Constraint::NotNull(NotNullConstraint::new("users", "email")).to_api_value();
        assert_eq!(not_null["type"], "notnull");
        assert_eq!(not_null["field"], "email");
    }

    #[test]
    fn test_constraint_manager() {
        let mut manager = ConstraintManager::new();

        let constraint =
            Constraint::Unique(UniqueConstraint::new("users", vec!["email".to_string()]));
        manager.add_constraint(constraint);

        let constraints = manager.get_constraints("users");
        assert_eq!(constraints.len(), 1);
    }

    #[test]
    fn test_constraint_manager_grouping() {
        let mut manager = ConstraintManager::default();
        manager.add_constraint(Constraint::Unique(UniqueConstraint::new(
            "users",
            vec!["email".to_string()],
        )));
        manager.add_constraint(Constraint::NotNull(NotNullConstraint::new("users", "email")));
        manager.add_constraint(Constraint::ForeignKey(ForeignKeyConstraint::new(
            "posts",
            "author_id",
            "users",
            "id",
            OnDeleteAction::Restrict,
        )));

        assert_eq!(manager.get_constraints("users").len(), 2);
        assert_eq!(manager.get_constraints("posts").len(), 1);
        assert!(manager.get_constraints("missing").is_empty());

        let mut names = manager.entity_names();
        names.sort();
        assert_eq!(names, vec!["posts", "users"]);

        assert!(manager.has_not_null_constraint("users", "email"));
        assert!(!manager.has_not_null_constraint("users", "name"));
        assert!(!manager.has_not_null_constraint("posts", "email"));
    }
}