Skip to main content

reifydb_type/value/constraint/
mod.rs

1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 ReifyDB
3
4use serde::{Deserialize, Serialize};
5
6use crate::{
7	error::{ConstraintKind, Error, TypeError},
8	fragment::Fragment,
9	value::{
10		Value,
11		constraint::{bytes::MaxBytes, precision::Precision, scale::Scale},
12		dictionary::DictionaryId,
13		sumtype::SumTypeId,
14		r#type::Type,
15	},
16};
17
18pub mod bytes;
19pub mod precision;
20pub mod scale;
21
/// Represents a type with optional constraints.
///
/// Pairs a base [`Type`] with an optional [`Constraint`] that further
/// restricts the values accepted for that type. Both fields are private.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypeConstraint {
	/// The declared type of the column/value.
	base_type: Type,
	/// Extra restriction on `base_type`; `None` means the type is unconstrained.
	constraint: Option<Constraint>,
}
28
/// Constraint types for different data types.
///
/// A constraint is only meaningful in combination with the base type stored
/// next to it in [`TypeConstraint`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Constraint {
	/// Maximum number of bytes for UTF8, BLOB, INT, UINT
	MaxBytes(MaxBytes),
	/// Precision and scale for DECIMAL
	PrecisionScale(Precision, Scale),
	/// Dictionary constraint: (catalog dictionary ID, id_type).
	/// `TypeConstraint::storage_type` reports `id_type` as the storage type
	/// when the base type is `Type::DictionaryId`.
	Dictionary(DictionaryId, Type),
	/// Sum type constraint: links a logical column to a catalog SumTypeDef
	SumType(SumTypeId),
}
41
/// FFI-safe representation of a TypeConstraint.
///
/// Flattens the base type and the optional constraint into plain integer
/// fields with a C-compatible layout. Identifiers are narrowed to 32 bits on
/// the way in (see `TypeConstraint::to_ffi`), so only their low 32 bits
/// survive a round trip.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct FFITypeConstraint {
	/// Base type code (Type::to_u8)
	pub base_type: u8,
	/// Constraint type: 0=None, 1=MaxBytes, 2=PrecisionScale, 3=Dictionary, 4=SumType.
	/// `TypeConstraint::from_ffi` treats every other value as "no constraint".
	pub constraint_type: u8,
	/// First constraint param: MaxBytes value OR precision OR dictionary_id low 32 bits
	/// OR sum type id low 32 bits; 0 when there is no constraint
	pub constraint_param1: u32,
	/// Second constraint param: scale (PrecisionScale) OR id_type code (Dictionary);
	/// 0 for every other constraint type
	pub constraint_param2: u32,
}
55
56impl TypeConstraint {
57	/// Create an unconstrained type
58	pub const fn unconstrained(ty: Type) -> Self {
59		Self {
60			base_type: ty,
61			constraint: None,
62		}
63	}
64
65	/// Create a type with a constraint
66	pub fn with_constraint(ty: Type, constraint: Constraint) -> Self {
67		Self {
68			base_type: ty,
69			constraint: Some(constraint),
70		}
71	}
72
73	/// Create a dictionary type constraint
74	pub fn dictionary(dictionary_id: DictionaryId, id_type: Type) -> Self {
75		Self {
76			base_type: Type::DictionaryId,
77			constraint: Some(Constraint::Dictionary(dictionary_id, id_type)),
78		}
79	}
80
81	/// Create a sum type constraint (tag stored as Uint1)
82	pub fn sumtype(id: SumTypeId) -> Self {
83		Self {
84			base_type: Type::Uint1,
85			constraint: Some(Constraint::SumType(id)),
86		}
87	}
88
89	/// Get the base type
90	pub fn get_type(&self) -> Type {
91		self.base_type.clone()
92	}
93
94	/// Get the storage type. For DictionaryId with a Dictionary constraint,
95	/// returns the id_type (e.g. Uint4). For all other types, returns the base_type.
96	pub fn storage_type(&self) -> Type {
97		match (&self.base_type, &self.constraint) {
98			(Type::DictionaryId, Some(Constraint::Dictionary(_, id_type))) => id_type.clone(),
99			_ => self.base_type.clone(),
100		}
101	}
102
103	/// Get the constraint
104	pub fn constraint(&self) -> &Option<Constraint> {
105		&self.constraint
106	}
107
108	/// Convert to FFI representation
109	pub fn to_ffi(&self) -> FFITypeConstraint {
110		let base_type = self.base_type.to_u8();
111		match &self.constraint {
112			None => FFITypeConstraint {
113				base_type,
114				constraint_type: 0,
115				constraint_param1: 0,
116				constraint_param2: 0,
117			},
118			Some(Constraint::MaxBytes(max)) => FFITypeConstraint {
119				base_type,
120				constraint_type: 1,
121				constraint_param1: max.value(),
122				constraint_param2: 0,
123			},
124			Some(Constraint::PrecisionScale(p, s)) => FFITypeConstraint {
125				base_type,
126				constraint_type: 2,
127				constraint_param1: p.value() as u32,
128				constraint_param2: s.value() as u32,
129			},
130			Some(Constraint::Dictionary(dict_id, id_type)) => FFITypeConstraint {
131				base_type,
132				constraint_type: 3,
133				constraint_param1: dict_id.to_u64() as u32,
134				constraint_param2: id_type.to_u8() as u32,
135			},
136			Some(Constraint::SumType(id)) => FFITypeConstraint {
137				base_type,
138				constraint_type: 4,
139				constraint_param1: id.to_u64() as u32,
140				constraint_param2: 0,
141			},
142		}
143	}
144
145	/// Create from FFI representation
146	pub fn from_ffi(ffi: FFITypeConstraint) -> Self {
147		let ty = Type::from_u8(ffi.base_type);
148		match ffi.constraint_type {
149			1 => Self::with_constraint(ty, Constraint::MaxBytes(MaxBytes::new(ffi.constraint_param1))),
150			2 => Self::with_constraint(
151				ty,
152				Constraint::PrecisionScale(
153					Precision::new(ffi.constraint_param1 as u8),
154					Scale::new(ffi.constraint_param2 as u8),
155				),
156			),
157			3 => Self::with_constraint(
158				ty,
159				Constraint::Dictionary(
160					DictionaryId::from(ffi.constraint_param1 as u64),
161					Type::from_u8(ffi.constraint_param2 as u8),
162				),
163			),
164			4 => Self::with_constraint(
165				ty,
166				Constraint::SumType(SumTypeId::from(ffi.constraint_param1 as u64)),
167			),
168			_ => Self::unconstrained(ty),
169		}
170	}
171
172	/// Validate a value against this type constraint
173	pub fn validate(&self, value: &Value) -> Result<(), Error> {
174		// First check type compatibility
175		let value_type = value.get_type();
176		if value_type != self.base_type && !matches!(value, Value::None { .. }) {
177			// For Option types, also accept values matching the inner type
178			if let Type::Option(inner) = &self.base_type {
179				if value_type != **inner {
180					unimplemented!()
181				}
182			} else {
183				unimplemented!()
184			}
185		}
186
187		// If None, only allow for Option types
188		if matches!(value, Value::None { .. }) {
189			if self.base_type.is_option() {
190				return Ok(());
191			} else {
192				return Err(TypeError::ConstraintViolation {
193					kind: ConstraintKind::NoneNotAllowed {
194						column_type: self.base_type.clone(),
195					},
196					message: format!(
197						"Cannot insert none into non-optional column of type {}. Declare the column as Option({}) to allow none values.",
198						self.base_type, self.base_type
199					),
200					fragment: Fragment::None,
201				}
202				.into());
203			}
204		}
205
206		// Check constraints if present
207		match (&self.base_type, &self.constraint) {
208			(Type::Utf8, Some(Constraint::MaxBytes(max))) => {
209				if let Value::Utf8(s) = value {
210					let byte_len = s.as_bytes().len();
211					let max_value: usize = (*max).into();
212					if byte_len > max_value {
213						return Err(TypeError::ConstraintViolation {
214							kind: ConstraintKind::Utf8MaxBytes {
215								actual: byte_len,
216								max: max_value,
217							},
218							message: format!(
219								"UTF8 value exceeds maximum byte length: {} bytes (max: {} bytes)",
220								byte_len, max_value
221							),
222							fragment: Fragment::None,
223						}
224						.into());
225					}
226				}
227			}
228			(Type::Blob, Some(Constraint::MaxBytes(max))) => {
229				if let Value::Blob(blob) = value {
230					let byte_len = blob.len();
231					let max_value: usize = (*max).into();
232					if byte_len > max_value {
233						return Err(TypeError::ConstraintViolation {
234							kind: ConstraintKind::BlobMaxBytes {
235								actual: byte_len,
236								max: max_value,
237							},
238							message: format!(
239								"BLOB value exceeds maximum byte length: {} bytes (max: {} bytes)",
240								byte_len, max_value
241							),
242							fragment: Fragment::None,
243						}
244						.into());
245					}
246				}
247			}
248			(Type::Int, Some(Constraint::MaxBytes(max))) => {
249				if let Value::Int(vi) = value {
250					// Calculate byte size of Int by
251					// converting to string and estimating
252					// This is a rough estimate: each
253					// decimal digit needs ~3.32 bits, so
254					// ~0.415 bytes
255					let str_len = vi.to_string().len();
256					let byte_len = (str_len * 415 / 1000) + 1; // Rough estimate
257					let max_value: usize = (*max).into();
258					if byte_len > max_value {
259						return Err(TypeError::ConstraintViolation {
260							kind: ConstraintKind::IntMaxBytes {
261								actual: byte_len,
262								max: max_value,
263							},
264							message: format!(
265								"INT value exceeds maximum byte length: {} bytes (max: {} bytes)",
266								byte_len, max_value
267							),
268							fragment: Fragment::None,
269						}
270						.into());
271					}
272				}
273			}
274			(Type::Uint, Some(Constraint::MaxBytes(max))) => {
275				if let Value::Uint(vu) = value {
276					// Calculate byte size of Uint by
277					// converting to string and estimating
278					// This is a rough estimate: each
279					// decimal digit needs ~3.32 bits, so
280					// ~0.415 bytes
281					let str_len = vu.to_string().len();
282					let byte_len = (str_len * 415 / 1000) + 1; // Rough estimate
283					let max_value: usize = (*max).into();
284					if byte_len > max_value {
285						return Err(TypeError::ConstraintViolation {
286							kind: ConstraintKind::UintMaxBytes {
287								actual: byte_len,
288								max: max_value,
289							},
290							message: format!(
291								"UINT value exceeds maximum byte length: {} bytes (max: {} bytes)",
292								byte_len, max_value
293							),
294							fragment: Fragment::None,
295						}
296						.into());
297					}
298				}
299			}
300			(Type::Decimal, Some(Constraint::PrecisionScale(precision, scale))) => {
301				if let Value::Decimal(decimal) = value {
302					// Calculate precision and scale from
303					// BigDecimal
304					let decimal_str = decimal.to_string();
305
306					// Calculate scale (digits after decimal
307					// point)
308					let decimal_scale: u8 = if let Some(dot_pos) = decimal_str.find('.') {
309						let after_dot = &decimal_str[dot_pos + 1..];
310						after_dot.len().min(255) as u8
311					} else {
312						0
313					};
314
315					// Calculate precision (total number of
316					// significant digits)
317					let decimal_precision: u8 =
318						decimal_str.chars().filter(|c| c.is_ascii_digit()).count().min(255)
319							as u8;
320
321					let scale_value: u8 = (*scale).into();
322					let precision_value: u8 = (*precision).into();
323
324					if decimal_scale > scale_value {
325						return Err(TypeError::ConstraintViolation {
326							kind: ConstraintKind::DecimalScale {
327								actual: decimal_scale,
328								max: scale_value,
329							},
330							message: format!(
331								"DECIMAL value exceeds maximum scale: {} decimal places (max: {} decimal places)",
332								decimal_scale, scale_value
333							),
334							fragment: Fragment::None,
335						}
336						.into());
337					}
338					if decimal_precision > precision_value {
339						return Err(TypeError::ConstraintViolation {
340							kind: ConstraintKind::DecimalPrecision {
341								actual: decimal_precision,
342								max: precision_value,
343							},
344							message: format!(
345								"DECIMAL value exceeds maximum precision: {} digits (max: {} digits)",
346								decimal_precision, precision_value
347							),
348							fragment: Fragment::None,
349						}
350						.into());
351					}
352				}
353			}
354			// No constraint or non-applicable constraint
355			_ => {}
356		}
357
358		Ok(())
359	}
360
361	/// Check if this type is unconstrained
362	pub fn is_unconstrained(&self) -> bool {
363		self.constraint.is_none()
364	}
365
366	/// Get a human-readable string representation
367	pub fn to_string(&self) -> String {
368		match &self.constraint {
369			None => format!("{}", self.base_type),
370			Some(Constraint::MaxBytes(max)) => {
371				format!("{}({})", self.base_type, max)
372			}
373			Some(Constraint::PrecisionScale(p, s)) => {
374				format!("{}({},{})", self.base_type, p, s)
375			}
376			Some(Constraint::Dictionary(dict_id, id_type)) => {
377				format!("DictionaryId(dict={}, {})", dict_id, id_type)
378			}
379			Some(Constraint::SumType(id)) => {
380				format!("SumType({})", id)
381			}
382		}
383	}
384}
385
#[cfg(test)]
pub mod tests {
	use super::*;

	/// UTF8 type limited to `limit` bytes.
	fn utf8_capped(limit: u32) -> TypeConstraint {
		TypeConstraint::with_constraint(Type::Utf8, Constraint::MaxBytes(MaxBytes::new(limit)))
	}

	/// DECIMAL(10,2) type.
	fn decimal_10_2() -> TypeConstraint {
		TypeConstraint::with_constraint(
			Type::Decimal,
			Constraint::PrecisionScale(Precision::new(10), Scale::new(2)),
		)
	}

	/// Owned UTF8 value holding `text`.
	fn utf8_value(text: &str) -> Value {
		Value::Utf8(text.to_string())
	}

	#[test]
	fn test_unconstrained_type() {
		let plain = TypeConstraint::unconstrained(Type::Utf8);
		assert_eq!(plain.base_type, Type::Utf8);
		assert_eq!(plain.constraint, None);
		assert!(plain.is_unconstrained());
	}

	#[test]
	fn test_constrained_utf8() {
		let capped = utf8_capped(50);
		let expected = Some(Constraint::MaxBytes(MaxBytes::new(50)));
		assert_eq!(capped.base_type, Type::Utf8);
		assert_eq!(capped.constraint, expected);
		assert!(!capped.is_unconstrained());
	}

	#[test]
	fn test_constrained_decimal() {
		let decimal = decimal_10_2();
		let expected = Some(Constraint::PrecisionScale(Precision::new(10), Scale::new(2)));
		assert_eq!(decimal.base_type, Type::Decimal);
		assert_eq!(decimal.constraint, expected);
	}

	#[test]
	fn test_validate_utf8_within_limit() {
		let outcome = utf8_capped(10).validate(&utf8_value("hello"));
		assert!(outcome.is_ok());
	}

	#[test]
	fn test_validate_utf8_exceeds_limit() {
		let outcome = utf8_capped(5).validate(&utf8_value("hello world"));
		assert!(outcome.is_err());
	}

	#[test]
	fn test_validate_unconstrained() {
		let plain = TypeConstraint::unconstrained(Type::Utf8);
		let outcome = plain.validate(&utf8_value("any length string is fine here"));
		assert!(outcome.is_ok());
	}

	#[test]
	fn test_validate_none_rejected_for_non_option() {
		let outcome = utf8_capped(5).validate(&Value::none());
		assert!(outcome.is_err());
	}

	#[test]
	fn test_validate_none_accepted_for_option() {
		let optional = TypeConstraint::unconstrained(Type::Option(Box::new(Type::Utf8)));
		let outcome = optional.validate(&Value::none());
		assert!(outcome.is_ok());
	}

	#[test]
	fn test_to_string() {
		let plain = TypeConstraint::unconstrained(Type::Utf8);
		assert_eq!(plain.to_string(), "Utf8");

		assert_eq!(utf8_capped(50).to_string(), "Utf8(50)");

		assert_eq!(decimal_10_2().to_string(), "Decimal(10,2)");
	}
}