Skip to main content

reifydb_function/text/
substring.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::{container::utf8::Utf8Container, r#type::Type};
6
7use crate::{
8	ScalarFunction, ScalarFunctionContext,
9	error::{ScalarFunctionError, ScalarFunctionResult},
10	propagate_options,
11};
12
/// Scalar function that extracts a character-based substring from a text value.
///
/// The type carries no state, so every instance behaves identically; it can be
/// created with either [`TextSubstring::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct TextSubstring;

impl TextSubstring {
	/// Creates a new `TextSubstring` function instance.
	pub fn new() -> Self {
		Self
	}
}
20
21impl ScalarFunction for TextSubstring {
22	fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
23		if let Some(result) = propagate_options(self, &ctx) {
24			return result;
25		}
26
27		let columns = ctx.columns;
28		let row_count = ctx.row_count;
29
30		// Validate exactly 3 arguments
31		if columns.len() != 3 {
32			return Err(ScalarFunctionError::ArityMismatch {
33				function: ctx.fragment.clone(),
34				expected: 3,
35				actual: columns.len(),
36			});
37		}
38
39		let text_column = columns.get(0).unwrap();
40		let start_column = columns.get(1).unwrap();
41		let length_column = columns.get(2).unwrap();
42
43		match (text_column.data(), start_column.data(), length_column.data()) {
44			(
45				ColumnData::Utf8 {
46					container: text_container,
47					max_bytes,
48				},
49				ColumnData::Int4(start_container),
50				ColumnData::Int4(length_container),
51			) => {
52				let mut result_data = Vec::with_capacity(text_container.data().len());
53
54				for i in 0..row_count {
55					if text_container.is_defined(i)
56						&& start_container.is_defined(i) && length_container.is_defined(i)
57					{
58						let original_str = &text_container[i];
59						let start_pos = start_container.get(i).copied().unwrap_or(0);
60						let length = length_container.get(i).copied().unwrap_or(0);
61
62						// Get the substring with proper Unicode handling
63						let chars: Vec<char> = original_str.chars().collect();
64						let chars_len = chars.len();
65
66						// Convert negative start to positive index from end
67						let start_idx = if start_pos < 0 {
68							chars_len.saturating_sub((-start_pos) as usize)
69						} else {
70							start_pos as usize
71						};
72						let length_usize = if length < 0 {
73							0
74						} else {
75							length as usize
76						};
77
78						let substring = if start_idx >= chars_len {
79							// Start position is beyond string length
80							String::new()
81						} else {
82							let end_idx = (start_idx + length_usize).min(chars_len);
83							chars[start_idx..end_idx].iter().collect()
84						};
85
86						result_data.push(substring);
87					} else {
88						result_data.push(String::new());
89					}
90				}
91
92				Ok(ColumnData::Utf8 {
93					container: Utf8Container::new(result_data),
94					max_bytes: *max_bytes,
95				})
96			}
97			// Handle cases where start/length are different integer types
98			(
99				ColumnData::Utf8 {
100					container: text_container,
101					max_bytes,
102				},
103				start_data,
104				length_data,
105			) => {
106				let mut result_data = Vec::with_capacity(text_container.data().len());
107
108				for i in 0..row_count {
109					if text_container.is_defined(i) {
110						let original_str = &text_container[i];
111
112						// Extract start position from various integer types
113						let start_pos = match start_data {
114							ColumnData::Int1(container) => {
115								container.get(i).map(|&v| v as i32).unwrap_or(0)
116							}
117							ColumnData::Int2(container) => {
118								container.get(i).map(|&v| v as i32).unwrap_or(0)
119							}
120							ColumnData::Int4(container) => {
121								container.get(i).copied().unwrap_or(0)
122							}
123							ColumnData::Int8(container) => {
124								container.get(i).map(|&v| v as i32).unwrap_or(0)
125							}
126							_ => 0,
127						};
128
129						// Extract length from various integer types
130						let length = match length_data {
131							ColumnData::Int1(container) => {
132								container.get(i).map(|&v| v as i32).unwrap_or(0)
133							}
134							ColumnData::Int2(container) => {
135								container.get(i).map(|&v| v as i32).unwrap_or(0)
136							}
137							ColumnData::Int4(container) => {
138								container.get(i).copied().unwrap_or(0)
139							}
140							ColumnData::Int8(container) => {
141								container.get(i).map(|&v| v as i32).unwrap_or(0)
142							}
143							_ => 0,
144						};
145
146						// Get the substring with proper Unicode handling
147						let chars: Vec<char> = original_str.chars().collect();
148						let chars_len = chars.len();
149
150						// Convert negative start to positive index from end
151						let start_idx = if start_pos < 0 {
152							chars_len.saturating_sub((-start_pos) as usize)
153						} else {
154							start_pos as usize
155						};
156						let length_usize = if length < 0 {
157							0
158						} else {
159							length as usize
160						};
161
162						let substring = if start_idx >= chars_len {
163							// Start position is beyond string length
164							String::new()
165						} else {
166							let end_idx = (start_idx + length_usize).min(chars_len);
167							chars[start_idx..end_idx].iter().collect()
168						};
169
170						result_data.push(substring);
171					} else {
172						result_data.push(String::new());
173					}
174				}
175
176				Ok(ColumnData::Utf8 {
177					container: Utf8Container::new(result_data),
178					max_bytes: *max_bytes,
179				})
180			}
181			(other, _, _) => Err(ScalarFunctionError::InvalidArgumentType {
182				function: ctx.fragment.clone(),
183				argument_index: 0,
184				expected: vec![Type::Utf8],
185				actual: other.get_type(),
186			}),
187		}
188	}
189
190	fn return_type(&self, _input_types: &[Type]) -> Type {
191		Type::Utf8
192	}
193}