Skip to main content

reifydb_function/text/
substring.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::{container::utf8::Utf8Container, r#type::Type};
6
7use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
8
/// Stateless scalar function that extracts a character-based substring from a
/// UTF-8 column (see the `ScalarFunction` impl for the exact semantics).
#[derive(Debug, Default, Clone, Copy)]
pub struct TextSubstring;

impl TextSubstring {
	/// Creates the function. Equivalent to `TextSubstring::default()`.
	pub fn new() -> Self {
		Self
	}
}
16
17impl ScalarFunction for TextSubstring {
18	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
19		if let Some(result) = propagate_options(self, &ctx) {
20			return result;
21		}
22
23		let columns = ctx.columns;
24		let row_count = ctx.row_count;
25
26		// Validate exactly 3 arguments
27		if columns.len() != 3 {
28			return Err(ScalarFunctionError::ArityMismatch {
29				function: ctx.fragment.clone(),
30				expected: 3,
31				actual: columns.len(),
32			});
33		}
34
35		let text_column = columns.get(0).unwrap();
36		let start_column = columns.get(1).unwrap();
37		let length_column = columns.get(2).unwrap();
38
39		match (text_column.data(), start_column.data(), length_column.data()) {
40			(
41				ColumnData::Utf8 {
42					container: text_container,
43					max_bytes,
44				},
45				ColumnData::Int4(start_container),
46				ColumnData::Int4(length_container),
47			) => {
48				let mut result_data = Vec::with_capacity(text_container.data().len());
49
50				for i in 0..row_count {
51					if text_container.is_defined(i)
52						&& start_container.is_defined(i) && length_container.is_defined(i)
53					{
54						let original_str = &text_container[i];
55						let start_pos = start_container.get(i).copied().unwrap_or(0);
56						let length = length_container.get(i).copied().unwrap_or(0);
57
58						// Get the substring with proper Unicode handling
59						let chars: Vec<char> = original_str.chars().collect();
60						let chars_len = chars.len();
61
62						// Convert negative start to positive index from end
63						let start_idx = if start_pos < 0 {
64							chars_len.saturating_sub((-start_pos) as usize)
65						} else {
66							start_pos as usize
67						};
68						let length_usize = if length < 0 {
69							0
70						} else {
71							length as usize
72						};
73
74						let substring = if start_idx >= chars_len {
75							// Start position is beyond string length
76							String::new()
77						} else {
78							let end_idx = (start_idx + length_usize).min(chars_len);
79							chars[start_idx..end_idx].iter().collect()
80						};
81
82						result_data.push(substring);
83					} else {
84						result_data.push(String::new());
85					}
86				}
87
88				Ok(ColumnData::Utf8 {
89					container: Utf8Container::new(result_data),
90					max_bytes: *max_bytes,
91				})
92			}
93			// Handle cases where start/length are different integer types
94			(
95				ColumnData::Utf8 {
96					container: text_container,
97					max_bytes,
98				},
99				start_data,
100				length_data,
101			) => {
102				let mut result_data = Vec::with_capacity(text_container.data().len());
103
104				for i in 0..row_count {
105					if text_container.is_defined(i) {
106						let original_str = &text_container[i];
107
108						// Extract start position from various integer types
109						let start_pos = match start_data {
110							ColumnData::Int1(container) => {
111								container.get(i).map(|&v| v as i32).unwrap_or(0)
112							}
113							ColumnData::Int2(container) => {
114								container.get(i).map(|&v| v as i32).unwrap_or(0)
115							}
116							ColumnData::Int4(container) => {
117								container.get(i).copied().unwrap_or(0)
118							}
119							ColumnData::Int8(container) => {
120								container.get(i).map(|&v| v as i32).unwrap_or(0)
121							}
122							_ => 0,
123						};
124
125						// Extract length from various integer types
126						let length = match length_data {
127							ColumnData::Int1(container) => {
128								container.get(i).map(|&v| v as i32).unwrap_or(0)
129							}
130							ColumnData::Int2(container) => {
131								container.get(i).map(|&v| v as i32).unwrap_or(0)
132							}
133							ColumnData::Int4(container) => {
134								container.get(i).copied().unwrap_or(0)
135							}
136							ColumnData::Int8(container) => {
137								container.get(i).map(|&v| v as i32).unwrap_or(0)
138							}
139							_ => 0,
140						};
141
142						// Get the substring with proper Unicode handling
143						let chars: Vec<char> = original_str.chars().collect();
144						let chars_len = chars.len();
145
146						// Convert negative start to positive index from end
147						let start_idx = if start_pos < 0 {
148							chars_len.saturating_sub((-start_pos) as usize)
149						} else {
150							start_pos as usize
151						};
152						let length_usize = if length < 0 {
153							0
154						} else {
155							length as usize
156						};
157
158						let substring = if start_idx >= chars_len {
159							// Start position is beyond string length
160							String::new()
161						} else {
162							let end_idx = (start_idx + length_usize).min(chars_len);
163							chars[start_idx..end_idx].iter().collect()
164						};
165
166						result_data.push(substring);
167					} else {
168						result_data.push(String::new());
169					}
170				}
171
172				Ok(ColumnData::Utf8 {
173					container: Utf8Container::new(result_data),
174					max_bytes: *max_bytes,
175				})
176			}
177			(other, _, _) => Err(ScalarFunctionError::InvalidArgumentType {
178				function: ctx.fragment.clone(),
179				argument_index: 0,
180				expected: vec![Type::Utf8],
181				actual: other.get_type(),
182			}),
183		}
184	}
185
186	fn return_type(&self, _input_types: &[Type]) -> Type {
187		Type::Utf8
188	}
189}